# In [None]:
from srwgan_model import Generator, Discriminator
from losses import build_vgg, perceptual_loss, wasserstein_loss, mse_loss
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import array_to_img
import matplotlib.pyplot as plt

# Create the output directory that save_sample_output() writes comparison images into.
os.makedirs("samples", exist_ok=True)

# NOTE(review): the meaning of the 12 passed to Generator is defined in srwgan_model — confirm.
generator = Generator(12)
discriminator = Discriminator()
# Feature extractor later passed to perceptual_loss() inside train_step().
vgg_model = build_vgg()
# Warm up / build the VGG model with one random 128x128x3 forward pass.
dummy_input = tf.random.normal((1, 128, 128, 3))
vgg_model(dummy_input) 
# Generator: Adam (lr=1e-4, beta_1=0.5). Critic: RMSprop (lr=5e-5).
g_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
d_optimizer = tf.keras.optimizers.RMSprop(5e-5)
# Dummy forward passes so both models create their variables up front.
# The 64x64 generator input vs 128x128 discriminator input suggests a 2x
# upscaling generator (assumption — confirm in srwgan_model).
generator(tf.zeros((1, 64, 64, 3)))
discriminator(tf.zeros((1, 128, 128, 3)))

# === Training Step Function ===
@tf.function
def train_step(lr_batch, hr_batch, critic_iter):
    """Run one WGAN update: ``critic_iter`` critic steps, then one generator step.

    The critic (``discriminator``) is trained with ``wasserstein_loss`` using
    -1 targets for real HR images and +1 targets for generated ones, followed
    by weight clipping to [-0.05, 0.05]. The generator is then trained on a
    weighted sum of adversarial, MSE content and VGG perceptual losses.

    Args:
        lr_batch: Batch of low-resolution inputs fed to ``generator``.
        hr_batch: Matching batch of high-resolution targets; also used as the
            "real" samples for the critic.
        critic_iter: Number of critic updates per generator update. When it is
            a Python int (as ``train()`` passes), the loop is unrolled at trace
            time and each distinct value triggers a retrace.

    Returns:
        ``(g_loss, d_loss)``: the generator's total loss for this step and the
        critic loss from the last critic iteration (0.0 if ``critic_iter`` is 0).
    """
    batch_size = tf.shape(lr_batch)[0]
    # Critic targets: -1 for real samples, +1 for generated samples.
    real_labels = -tf.ones((batch_size, 1))
    fake_labels = tf.ones((batch_size, 1))

    # Defined up front so there is still something to return when critic_iter == 0.
    d_loss = tf.constant(0.0)

    # === Train Discriminator (critic) critic_iter times ===
    for _ in range(critic_iter):
        # Only discriminator variables are differentiated in this step, so the
        # generator forward pass runs outside the tape. The gradients are
        # identical, but the tape no longer records and retains the generator's
        # activations on every critic iteration.
        super_res = generator(lr_batch, training=True)
        with tf.GradientTape() as disc_tape:
            real_output = discriminator(hr_batch, training=True)
            fake_output = discriminator(super_res, training=True)

            d_loss_real = wasserstein_loss(real_labels, real_output)
            d_loss_fake = wasserstein_loss(fake_labels, fake_output)
            d_loss = d_loss_real + d_loss_fake

        disc_grads = disc_tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
        # === Weight Clipping ===
        # Clip every critic weight (WGAN's way of approximately enforcing the
        # Lipschitz constraint).
        # NOTE(review): this also clips any normalization parameters the critic has.
        for var in discriminator.trainable_variables:
            var.assign(tf.clip_by_value(var, -0.05, 0.05))

    # === Train Generator once ===
    with tf.GradientTape() as gen_tape:
        super_res = generator(lr_batch, training=True)
        fake_output = discriminator(super_res, training=True)

        # Adversarial term uses the *real* labels: the generator is rewarded
        # when the critic scores its output like a real image.
        adv_loss = wasserstein_loss(real_labels, fake_output)
        cont_loss = mse_loss(hr_batch, super_res)
        perc_loss = perceptual_loss(vgg_model, hr_batch, super_res)

        # Pixel MSE dominates; adversarial and perceptual terms are down-weighted.
        g_loss = 1e-3 * adv_loss + 0.01 * perc_loss + cont_loss

    gen_grads = gen_tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))

    return g_loss, d_loss

# === Sample Image Saving ===
def save_sample_output(epoch, generator, val_dataset, index=5):
    """Save an LR | SR | HR comparison strip for one validation image.

    Takes the first batch of ``val_dataset`` and super-resolves the image at
    ``index`` with ``generator`` in inference mode. The LR input is upsampled
    to the HR image's spatial size, and the three images are written side by
    side to ``samples/epoch_{epoch:03}.png``. Nothing is written if
    ``val_dataset`` yields no batch, or the first batch is empty.

    Args:
        epoch: Epoch number embedded in the output filename.
        generator: Model mapping a batch of LR images to SR images.
        val_dataset: Batched dataset yielding ``(lr_batch, hr_batch)`` pairs.
        index: Position within the first batch of the image to visualise.
            It is clamped into ``[0, batch_size - 1]``, so a validation batch
            smaller than ``index + 1`` no longer raises an IndexError.
    """
    for lr_img, hr_img in val_dataset.take(1):
        batch_size = int(tf.shape(lr_img)[0])
        if batch_size == 0:
            return
        i = min(max(index, 0), batch_size - 1)

        # Only the selected image is needed, so run the generator on a
        # batch of one instead of the whole validation batch.
        sr = generator(lr_img[i:i + 1], training=False)[0]
        hr = hr_img[i]
        # Upsample LR to the HR height/width (rather than a fixed 128x128) so
        # the three panels always have matching heights for np.hstack.
        lr_up = tf.image.resize(lr_img[i], tf.shape(hr)[:2])

        # array_to_img (default scale=True) rescales each panel to 0-255
        # independently, so panel brightness is not directly comparable.
        panels = [np.array(array_to_img(img)) for img in (lr_up, sr, hr)]
        stacked = np.hstack(panels)
        plt.imsave(f"samples/epoch_{epoch:03}.png", stacked.astype("uint8"))

def train(train_dataset, val_dataset, epochs, patience=10, start_epoch=0, critic_iter=5,
          save_every=10, checkpoint_dir="/kaggle/working/checkpoints"):
    """Run the SR-WGAN training loop over epoch indices ``start_epoch`` .. ``epochs - 1``.

    Each epoch iterates ``train_dataset`` once, calling ``train_step`` for
    every ``(lr_batch, hr_batch)`` pair, and prints the epoch's mean generator
    and critic losses. Every ``save_every`` epochs, and after the final epoch,
    it does three things:
    - writes a sample image via ``save_sample_output``;
    - saves a ``tf.train`` checkpoint through the module-level ``manager``;
    - saves whole-model ``.h5`` files for both networks into ``checkpoint_dir``.

    Args:
        train_dataset: Iterable of ``(lr_batch, hr_batch)`` pairs.
        val_dataset: Dataset forwarded to ``save_sample_output``.
        epochs: Total number of epochs (exclusive upper bound on the epoch index).
        patience: Accepted for backward compatibility but currently unused —
            no early stopping is implemented.
        start_epoch: Zero-based epoch index to resume from.
        critic_iter: Critic updates per generator update, forwarded to ``train_step``.
        save_every: Save samples/checkpoints every this many epochs; must be > 0.
        checkpoint_dir: Directory for the ``.h5`` model files; created if missing.
    """
    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        # Fresh metrics each epoch so the printed summary is a per-epoch mean.
        g_loss_metric = tf.keras.metrics.Mean()
        d_loss_metric = tf.keras.metrics.Mean()

        for lr_batch, hr_batch in train_dataset:
            g_loss, d_loss = train_step(lr_batch, hr_batch, critic_iter)
            g_loss_metric.update_state(g_loss)
            d_loss_metric.update_state(d_loss)

        # Convert to Python floats so the ":.4f" format spec does not depend
        # on tf.Tensor.__format__ support in the installed TF version.
        avg_gen_loss = float(g_loss_metric.result())
        avg_disc_loss = float(d_loss_metric.result())
        print(f"Epoch {epoch + 1} Summary → Gen Loss: {avg_gen_loss:.4f} Disc Loss: {avg_disc_loss:.4f}")

        # === Save samples and checkpoints ===
        is_last_epoch = epoch == epochs - 1
        if (epoch + 1) % save_every == 0 or is_last_epoch:
            save_sample_output(epoch + 1, generator, val_dataset)
            # The checkpoint number is the 1-based epoch; get_latest_epoch()
            # parses it back out as the start_epoch of a resumed run.
            manager.save(checkpoint_number=epoch + 1)
            os.makedirs(checkpoint_dir, exist_ok=True)
            # NOTE(review): Keras whole-model .h5 saving only supports Functional/
            # Sequential models; if Generator/Discriminator are subclassed models
            # these calls will raise — confirm, or switch to save_weights().
            generator.save(os.path.join(checkpoint_dir, f"srw_generator_epoch_{epoch + 1}.h5"))
            discriminator.save(os.path.join(checkpoint_dir, f"srw_discriminator_epoch_{epoch + 1}.h5"))

# Checkpointing: track both networks and both optimizers so a restarted run
# resumes from the same training state. Up to 5 checkpoints are kept.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=g_optimizer,
    discriminator_optimizer=d_optimizer,
)
manager = tf.train.CheckpointManager(ckpt, directory='./srgan_ckpts', max_to_keep=5)

# Restore the newest checkpoint found in ./srgan_ckpts, if there is one.
if not manager.latest_checkpoint:
    print("🆕 Training from scratch")
else:
    ckpt.restore(manager.latest_checkpoint)
    print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}")

# === Resume Support ===
def get_latest_epoch(manager):
    """Return the epoch number encoded in the newest checkpoint's name.

    The number is taken from the text after the last '-' in the checkpoint's
    basename (e.g. ``ckpt-10`` -> 10). Returns 0 when ``manager`` has no
    checkpoint or that suffix is not an integer.
    """
    latest = manager.latest_checkpoint
    if not latest:
        return 0
    _, _, suffix = os.path.basename(latest).rpartition('-')
    try:
        return int(suffix)
    except ValueError:
        return 0

# Resume from the epoch number in the newest checkpoint name (0 on a fresh start).
latest_epoch = get_latest_epoch(manager)
# === Start Training ===
# NOTE(review): train_dataset / val_dataset are not defined in this cell;
# presumably created in an earlier notebook cell — confirm.
train(
    train_dataset,
    val_dataset,
    epochs=50,
    patience=10,
    start_epoch=latest_epoch,
)