# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import array_to_img

from srgan_model import Generator, Discriminator
from losses import build_vgg, bce_loss


# === Loss Functions ===
# Module-level loss objects shared by generator_loss and discriminator_loss.
# bce_loss is project-defined (losses.py); from_logits=True presumably means the
# discriminator emits raw logits rather than probabilities — TODO confirm
# against Discriminator in srgan_model.
bce = bce_loss(from_logits=True)
# Feature extractor used for the perceptual (content) term of generator_loss.
# build_vgg is project-defined; presumably a pretrained VGG truncated at an
# intermediate layer — TODO confirm in losses.py.
vgg = build_vgg()

def generator_loss(high_res, super_res, fake_output, adversarial_weight=1e-3):
    """Return the SRGAN generator objective: content loss + weighted adversarial loss.

    Args:
        high_res: Ground-truth high-resolution batch.
        super_res: Generator output for the matching low-resolution batch.
        fake_output: Discriminator output for ``super_res``.
        adversarial_weight: Scale applied to the adversarial term. Defaults to
            1e-3, the value that was previously hard-coded here.

    Returns:
        Scalar tensor ``content_loss + adversarial_weight * adversarial_loss``.
    """
    # Content loss: mean squared difference between the VGG feature maps of
    # the real and the super-resolved images.
    content_loss = tf.reduce_mean(tf.square(vgg(high_res) - vgg(super_res)))
    # Adversarial loss: the generator is rewarded when the discriminator
    # scores its output against the "real" target (ones).
    adversarial_loss = bce(tf.ones_like(fake_output), fake_output)
    return content_loss + adversarial_weight * adversarial_loss

def discriminator_loss(real_output, fake_output):
    """Return the discriminator's summed BCE loss on real and generated batches.

    Real samples are scored against a target of ones and generated samples
    against a target of zeros; the two BCE terms are added together.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return bce(real_targets, real_output) + bce(fake_targets, fake_output)


# === Setup ===
# Seed TensorFlow's global RNG before any model is constructed so that weight
# initialization is reproducible across runs.
tf.random.set_seed(42)

# "samples/" receives comparison PNGs from save_sample_output; "checkpoints/"
# receives the per-epoch .h5 model files written by train().
os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# Generator/Discriminator come from srgan_model. The meaning of the argument
# 10 is not visible here (presumably the number of residual blocks — TODO
# confirm in srgan_model).
generator = Generator(10)
discriminator = Discriminator()

# The discriminator is trained with a 10x smaller learning rate than the generator.
gen_optimizer = keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9)
disc_optimizer = keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.9)

# Dummy forward pass to initialize variables
# (both models get their weights built before the checkpoint below is created
# and restored). The shapes suggest 64x64 LR inputs and 128x128 HR images,
# i.e. a 2x upscale — NOTE(review): confirm against the Generator architecture.
generator(tf.zeros((1, 64, 64, 3)))
discriminator(tf.zeros((1, 128, 128, 3)))


# === Training Step Function ===
@tf.function
def train_step(low_res, high_res, disc_steps=5):
    """Run one SRGAN optimisation step on a single batch.

    The discriminator is updated ``disc_steps`` times. The generator is then
    updated once against the freshly-updated discriminator.

    Args:
        low_res: Low-resolution input batch fed to the generator.
        high_res: Matching ground-truth high-resolution batch.
        disc_steps: Discriminator updates per generator update. Must be a
            Python int >= 1: the loop is unrolled at trace time, and each
            distinct value retraces the function.

    Returns:
        ``(gen_loss, disc_loss)``: the generator loss, and the discriminator
        loss from the final discriminator update.

    Raises:
        ValueError: If ``disc_steps`` < 1. Previously this surfaced as an
            undefined-variable error while tracing.
    """
    if disc_steps < 1:
        raise ValueError(f"disc_steps must be >= 1, got {disc_steps}")

    # The generator's weights do not change while the discriminator is being
    # updated, so the super-resolved batch is computed once rather than once
    # per discriminator step. It is computed outside any tape because the
    # discriminator gradients never need to flow back into the generator.
    # NOTE: if the generator has batch-norm/dropout layers, their training-mode
    # side effects now happen once here instead of once per discriminator step.
    super_res = generator(low_res, training=True)

    for _ in range(disc_steps):
        with tf.GradientTape() as disc_tape:
            real_output = discriminator(high_res, training=True)
            fake_output = discriminator(super_res, training=True)
            disc_loss = discriminator_loss(real_output, fake_output)

        grads_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        disc_optimizer.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

    # Fresh generator pass under its own tape so the gradients reach the generator.
    with tf.GradientTape() as gen_tape:
        super_res = generator(low_res, training=True)
        fake_output = discriminator(super_res, training=True)
        gen_loss = generator_loss(high_res, super_res, fake_output)

    grads_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(grads_gen, generator.trainable_variables))

    return gen_loss, disc_loss


# === Checkpointing ===
# Track both models together with both optimizers, so that weights and
# optimizer state are saved and restored as a unit. The manager keeps only the
# 5 newest checkpoints in ./srgan_ckpts.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=gen_optimizer,
    discriminator_optimizer=disc_optimizer,
)
manager = tf.train.CheckpointManager(ckpt, './srgan_ckpts', max_to_keep=5)

# Resume from the newest checkpoint if the manager found one.
if not manager.latest_checkpoint:
    print("🆕 Training from scratch")
else:
    ckpt.restore(manager.latest_checkpoint)
    print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}")


# === Save Sample Output ===
def save_sample_output(epoch, generator, val_dataset):
    """Write an LR | SR | HR comparison strip for the first validation example.

    Only the first example of the first batch from ``val_dataset`` is used.
    The strip is saved to ``samples/epoch_{epoch:03}.png``. Nothing is written
    if ``val_dataset`` yields no batches.

    Args:
        epoch: Epoch number embedded in the output filename.
        generator: Model mapping a low-resolution batch to a super-resolved batch.
        val_dataset: Dataset (presumably ``tf.data.Dataset``) yielding
            ``(lr_batch, hr_batch)`` pairs of batched NHWC images.
    """
    # take(1) yields at most one batch, so no explicit break is needed.
    for lr_img, hr_img in val_dataset.take(1):
        sr_img = generator(lr_img, training=False)

        # Upscale the LR input to the HR image's height/width so the three
        # panels share a height for np.hstack. This used to be hard-coded to
        # (128, 128), which broke for any other HR resolution.
        hr_size = tf.shape(hr_img)[1:3]
        lr_upscaled = tf.image.resize(lr_img[0], hr_size)

        # NOTE: array_to_img rescales each image independently to 0-255 by
        # default, so panel brightness is not directly comparable.
        lr = array_to_img(lr_upscaled)
        sr = array_to_img(sr_img[0])
        hr = array_to_img(hr_img[0])

        stacked = np.hstack([np.array(lr), np.array(sr), np.array(hr)])
        plt.imsave(f"samples/epoch_{epoch:03}.png", stacked.astype("uint8"))


# === Resume Helper ===
def get_latest_epoch(manager):
    """Return the epoch number encoded in the manager's newest checkpoint name.

    train() saves checkpoints with ``checkpoint_number=epoch + 1``, so a
    latest checkpoint named e.g. ``ckpt-20`` yields 20: the count of completed
    epochs, usable directly as ``start_epoch``. Returns 0 in two cases: when
    there is no checkpoint, or when the text after the last ``-`` in the file
    name is not an integer.
    """
    latest_path = manager.latest_checkpoint
    if not latest_path:
        return 0

    _, _, number_text = os.path.basename(latest_path).rpartition('-')
    try:
        return int(number_text)
    except ValueError:
        return 0


# === Training Loop ===
def train(train_dataset, val_dataset, epochs, patience=10, start_epoch=0):
    """Train the SRGAN from epoch index ``start_epoch`` up to ``epochs``.

    After every 10th epoch, and after the final epoch:
    - a sample comparison image is written via ``save_sample_output``;
    - a tf.train checkpoint is saved via ``manager``, numbered ``epoch + 1``
      so ``get_latest_epoch`` can resume from it;
    - both models are saved as .h5 files under ``checkpoints/``.

    Args:
        train_dataset: Iterable of ``(lr_batch, hr_batch)`` training pairs.
        val_dataset: Dataset passed to ``save_sample_output``.
        epochs: Total epoch count (absolute, not relative to ``start_epoch``).
        patience: Currently unused — early stopping is not implemented. The
            parameter is kept so existing callers that pass it keep working.
        start_epoch: Zero-based index of the first epoch to run. Pass
            ``get_latest_epoch(manager)`` to resume.
    """
    for epoch in range(start_epoch, epochs):
        epoch_number = epoch + 1  # 1-based; used for display and file names
        print(f"\nEpoch {epoch_number}/{epochs}")
        g_loss_metric = tf.keras.metrics.Mean()
        d_loss_metric = tf.keras.metrics.Mean()

        for lr_batch, hr_batch in train_dataset:
            gen_loss, disc_loss = train_step(lr_batch, hr_batch)
            g_loss_metric.update_state(gen_loss)
            d_loss_metric.update_state(disc_loss)

        # Convert to Python floats so the format spec does not rely on
        # tensor __format__ support.
        avg_gen_loss = float(g_loss_metric.result())
        avg_disc_loss = float(d_loss_metric.result())
        print(f"Epoch {epoch_number} Summary → Gen Loss: {avg_gen_loss:.4f}, Disc Loss: {avg_disc_loss:.4f}")

        if epoch_number % 10 == 0 or epoch == epochs - 1:
            save_sample_output(epoch_number, generator, val_dataset)
            manager.save(checkpoint_number=epoch_number)
            # NOTE(review): Keras supports .h5 saving only for Functional /
            # Sequential models. If Generator/Discriminator are subclassed
            # Models these calls will raise — confirm in srgan_model.
            generator.save(f"checkpoints/generator_epoch_{epoch_number}.h5")
            discriminator.save(f"checkpoints/discriminator_epoch_{epoch_number}.h5")


# === Start Training ===
# Number of epochs already completed according to the newest checkpoint in
# ./srgan_ckpts (0 when there is none); pass it as train(..., start_epoch=...).
latest_epoch = get_latest_epoch(manager=manager)

# Define your datasets before running the training:
# train_dataset = ...
# val_dataset = ...
# train(train_dataset, val_dataset, epochs=50, start_epoch=latest_epoch)

