In [1]:
import numpy as np
import tensorflow as tf

# NOTE(review): Wgan.py (providing build_generator / build_critic) must be
# importable -- e.g. placed in the notebook's working directory or added to
# sys.path -- otherwise this import raises ModuleNotFoundError.
from Wgan import build_generator,build_critic

# Define loss functions
def critic_loss(real_score, fake_score):
    """Wasserstein critic loss: ``mean(fake_score) - mean(real_score)``.

    Minimizing this value pushes the critic to score real samples higher
    than generated ones. Both inputs are reduced over all axes, so the
    result is a scalar tensor.
    """
    mean_fake = tf.reduce_mean(fake_score)
    mean_real = tf.reduce_mean(real_score)
    return mean_fake - mean_real

def generator_loss(fake_score):
    """Wasserstein generator loss: ``-mean(fake_score)``.

    Minimizing this value raises the critic's average score on generated
    samples. Returns a scalar tensor.
    """
    mean_fake = tf.reduce_mean(fake_score)
    return -mean_fake

# Build the two networks via the project-local Wgan module.
generator = build_generator()
critic = build_critic()

# Bound used to clip every critic weight after each critic update.
CLIP_VALUE = 0.01

# One RMSprop optimizer per network; the generator's learning rate (1e-4)
# is set higher than the critic's (5e-5).
g_opt = tf.keras.optimizers.RMSprop(1e-4)
c_opt = tf.keras.optimizers.RMSprop(5e-5)

# Track both models and both optimizers. The "discriminator"/"d_optimizer"
# keys name the critic and its optimizer; they are kept as-is so checkpoints
# already written to ./wgan_ckpts still restore. Only the 3 newest are kept.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=critic,
    g_optimizer=g_opt,
    d_optimizer=c_opt,
)
manager = tf.train.CheckpointManager(ckpt, './wgan_ckpts', max_to_keep=3)

latest_path = manager.latest_checkpoint
if not latest_path:
    print('training from scratch')
else:
    ckpt.restore(latest_path)
    print('Restored from', latest_path)

from keras.preprocessing.image import array_to_img
import matplotlib.pyplot as plt
import os

os.makedirs("samples", exist_ok=True)

def save_sample_output(epoch, generator, val_dataset, out_dir="samples"):
    """Write a side-by-side ``[input | generated | target]`` PNG for one example.

    Takes the first batch of ``val_dataset``, runs ``generator`` on its
    inputs in inference mode (``training=False``), and saves the first
    example of the batch to ``<out_dir>/epoch_<epoch:03>.png``.

    Args:
        epoch: number embedded in the filename, zero-padded to 3 digits.
        generator: model mapping an input batch to an output batch.
        val_dataset: presumably a ``tf.data.Dataset`` (``.take`` is used)
            yielding ``(input_batch, target_batch)`` pairs -- confirm.
        out_dir: directory for the PNG; created if it does not exist.
            Defaults to the previously hard-coded ``"samples"``.
    """
    os.makedirs(out_dir, exist_ok=True)
    # take(1) yields at most one batch, so no explicit break is needed.
    for lr_img, hr_img in val_dataset.take(1):
        sr_img = generator(lr_img, training=False)
        sr = array_to_img(sr_img[0])
        hr = array_to_img(hr_img[0])
        # Resize the input to the target's height/width (previously a fixed
        # 128x128) so np.hstack below always sees panels of equal height.
        target_hw = tf.shape(hr_img)[1:3]
        lr = array_to_img(tf.image.resize(lr_img[0], target_hw))

        canvas = np.hstack([np.array(lr), np.array(sr), np.array(hr)])
        plt.imsave(
            os.path.join(out_dir, f"epoch_{epoch:03}.png"),
            canvas.astype("uint8"),
        )

# Call each model once on an all-zeros (1, 128, 128, 3) batch so their
# variables are created eagerly, before the @tf.function steps below trace.
_warmup = tf.zeros((1, 128, 128, 3))
generator(_warmup)
critic([_warmup, _warmup])
del _warmup

@tf.function
def train_critic(corrupted_batch, clean_batch):
    """Run one critic update on a batch and return the critic loss.

    Generated images are produced outside the gradient tape, so only the
    critic's variables receive gradients. The critic scores
    ``[corrupted, clean]`` pairs as real and ``[corrupted, generated]``
    pairs as fake. After the ``c_opt`` step, every trainable critic
    variable is clipped to ``[-CLIP_VALUE, CLIP_VALUE]`` (WGAN-style
    weight clipping).
    """
    generated = generator(corrupted_batch, training=True)

    with tf.GradientTape() as tape:
        score_real = critic([corrupted_batch, clean_batch], training=True)
        score_fake = critic([corrupted_batch, generated], training=True)
        c_loss = critic_loss(score_real, score_fake)

    critic_vars = critic.trainable_variables
    gradients = tape.gradient(c_loss, critic_vars)
    c_opt.apply_gradients(zip(gradients, critic_vars))

    # Clip only after the optimizer step, so the stored weights stay bounded.
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -CLIP_VALUE, CLIP_VALUE))
    return c_loss

@tf.function
def train_generator(corrupted_batch):
    """Run one generator update and return the generator loss.

    Generates outputs from ``corrupted_batch`` and scores them with the
    critic, which also receives ``corrupted_batch``. One ``g_opt`` step is
    then applied to the generator's trainable variables only.
    """
    with tf.GradientTape() as tape:
        generated = generator(corrupted_batch, training=True)
        score_fake = critic([corrupted_batch, generated], training=True)
        g_loss = generator_loss(score_fake)

    gen_vars = generator.trainable_variables
    gradients = tape.gradient(g_loss, gen_vars)
    g_opt.apply_gradients(zip(gradients, gen_vars))
    return g_loss

def train(train_dataset, val_dataset, epochs, start_epoch=0, CRTIC_ITER=5,
          save_dir="/kaggle/working/checkpoints"):
    """Run the WGAN training loop for 0-based epochs ``start_epoch .. epochs-1``.

    For every ``(lr_batch, hr_batch)`` pair in ``train_dataset`` the critic
    is updated ``CRTIC_ITER`` times, then the generator once. Mean losses
    are printed per epoch. On epoch indices divisible by 10 and on the last
    epoch, this function:

    * writes a sample image via ``save_sample_output``;
    * saves a checkpoint numbered ``epoch + 1`` through the global
      ``manager``;
    * exports both models as ``.h5`` files into ``save_dir``.

    Args:
        train_dataset: iterable of ``(corrupted, clean)`` batch pairs.
        val_dataset: dataset forwarded to ``save_sample_output``.
        epochs: exclusive upper bound on the 0-based epoch index.
        start_epoch: 0-based epoch to start at; ``get_latest_epoch(manager)``
            gives the value to resume from the newest checkpoint.
        CRTIC_ITER: number of critic updates per generator update. The
            misspelled name is kept for backward compatibility with
            keyword callers.
        save_dir: directory for the ``.h5`` model exports; created if
            missing. Defaults to the original Kaggle path.
    """
    critic_iters = CRTIC_ITER

    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        g_loss_metric = tf.keras.metrics.Mean()
        c_loss_metric = tf.keras.metrics.Mean()

        for lr_batch, hr_batch in train_dataset:
            # NOTE(review): every critic iteration reuses this same batch;
            # the WGAN paper's algorithm samples a fresh batch per critic step.
            for _ in range(critic_iters):
                critic_loss_val = train_critic(lr_batch, hr_batch)
                c_loss_metric.update_state(critic_loss_val)

            gen_loss_val = train_generator(lr_batch)
            g_loss_metric.update_state(gen_loss_val)

        # Plain floats so the :.4f formatting below does not depend on tensor __format__.
        avg_gen_loss = float(g_loss_metric.result())
        avg_critic_loss = float(c_loss_metric.result())

        if epoch % 10 == 0 or epoch == epochs - 1:
            save_sample_output(epoch, generator, val_dataset)
            manager.save(checkpoint_number=epoch + 1)
            # Make sure the export directory exists before writing into it.
            os.makedirs(save_dir, exist_ok=True)
            generator.save(os.path.join(save_dir, f"wgan_generator_epoch_{epoch+1}.h5"))
            critic.save(os.path.join(save_dir, f"wgan_critic_epoch_{epoch+1}.h5"))

        print(f"Epoch {epoch+1}: Gen Loss: {avg_gen_loss:.4f}, Disc Loss: {avg_critic_loss:.4f}")

def get_latest_epoch(manager):
    """Return the number suffix of the manager's newest checkpoint, or 0.

    The number is whatever follows the last ``-`` in the checkpoint's base
    name (e.g. ``ckpt-11`` -> 11). Returns 0 when there is no checkpoint or
    the suffix is not an integer.
    """
    latest = manager.latest_checkpoint
    if not latest:
        return 0
    _, _, suffix = os.path.basename(latest).rpartition('-')
    try:
        return int(suffix)
    except ValueError:
        return 0

# Resume from the number parsed out of the newest checkpoint name
# (0 when there is none), then train until epoch index 200 is reached.
# NOTE(review): train_dataset / val_dataset are not defined in this snippet;
# presumably they are created in an earlier cell -- confirm.
latest_epoch = get_latest_epoch(manager)
train(
    train_dataset,
    val_dataset,
    epochs=200,
    start_epoch=latest_epoch,
)

ModuleNotFoundError: No module named 'Wgan'