# In [None]:
import os
import glob
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from google.colab import drive
drive.mount('/content/drive')
DATASET_ROOT = "/content/drive/MyDrive/data/HQ-50K"
# Preprocessed 64x64 RGB images. The pipeline below assumes they are already
# scaled to [-1, 1] (noisy inputs are clipped to that range and the generator
# ends in tanh) -- NOTE(review): confirm against the preprocessing step.
processed_images = np.load("/content/drive/MyDrive/data/data_processed/images_64.npy")

BASE = "/content/drive/MyDrive/GAN_results"

psnr_values = []
train_gen_losses, train_disc_losses = [], []
val_gen_losses, val_disc_losses = [], []

# Split the 64x64 images into 90% training and 10% held-out test data
# with a fixed random seed.
train_images, test_images = train_test_split(processed_images, test_size=0.1, random_state=42)

BATCH_SIZE = 32
EPOCHS = 30
NOISE_STDDEV = 0.1  # std-dev of the additive Gaussian noise the generator learns to remove

# Fixed comparison set used for validation losses, per-epoch images and PSNR:
# 16 held-out clean images plus a seeded noisy copy of each, so that
# (sample_noisy[i], sample_clean[i]) is a genuine (input, target) pair.
sample_clean = test_images[0:16]
_sample_noise = np.random.default_rng(42).normal(0.0, NOISE_STDDEV, size=sample_clean.shape)
sample_noisy = np.clip(sample_clean + _sample_noise, -1.0, 1.0).astype(sample_clean.dtype, copy=False)


def _add_training_noise(clean):
    """Map one clean image to a (noisy, clean) training pair.

    Gaussian noise with std NOISE_STDDEV is added and the result clipped to
    [-1, 1]. Both returned tensors are float32.
    """
    # tf.random.normal produces float32; cast so the addition cannot hit a
    # dtype mismatch if the .npy file was saved as float64.
    clean = tf.cast(clean, tf.float32)
    noisy = clean + tf.random.normal(tf.shape(clean), stddev=NOISE_STDDEV)
    return tf.clip_by_value(noisy, -1.0, 1.0), clean


# Training pipeline. The noise op runs inside `map`, so a fresh noise sample
# is drawn each time the dataset is iterated (i.e. every epoch).
dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .map(_add_training_noise)
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


# Unlike a DCGAN generator, which upsamples from a random noise vector, this generator takes a noisy 64x64 image as input and outputs a denoised 64x64 image.
def make_generator_model():
    """Build the image-to-image generator (64x64x3 in, 64x64x3 out).

    Two stride-2 convolutions downsample 64 -> 32 -> 16. Two stride-2
    transposed convolutions upsample back 16 -> 32 -> 64. A final 5x5
    convolution with tanh maps to 3 channels in (-1, 1).
    """

    def conv_stage(tensor, filters, *, transpose=False, batch_norm=False):
        # One 5x5 stride-2 "same" (transposed) conv, optional BN, then LeakyReLU.
        conv_cls = layers.Conv2DTranspose if transpose else layers.Conv2D
        tensor = conv_cls(filters, 5, strides=2, padding="same")(tensor)
        if batch_norm:
            tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU()(tensor)

    image_in = tf.keras.Input(shape=(64, 64, 3))

    hidden = conv_stage(image_in, 64)                                   # 64 -> 32
    hidden = conv_stage(hidden, 128, batch_norm=True)                   # 32 -> 16
    hidden = conv_stage(hidden, 128, transpose=True, batch_norm=True)   # 16 -> 32
    hidden = conv_stage(hidden, 64, transpose=True)                     # 32 -> 64

    image_out = layers.Conv2D(3, 5, activation="tanh", padding="same")(hidden)
    return tf.keras.Model(image_in, image_out)


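# Conditional discriminator: takes the noisy input together with a candidate clean image (either the real target or the generator's output) and returns one unnormalized real/fake logit.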
def make_discriminator_model():
    """Build the conditional discriminator.

    Inputs: [noisy image, candidate clean image], each 64x64x3. They are
    concatenated along channels and passed through three 5x5 stride-2
    conv + LeakyReLU stages (64, 128, 256 filters). The output is a single
    logit per example (Dense(1), no activation).
    """
    noisy_in = tf.keras.Input(shape=(64, 64, 3))
    candidate_in = tf.keras.Input(shape=(64, 64, 3))

    features = layers.Concatenate()([noisy_in, candidate_in])
    for width in (64, 128, 256):
        features = layers.Conv2D(width, 5, strides=2, padding="same")(features)
        features = layers.LeakyReLU()(features)

    flat = layers.Flatten()(features)
    logit = layers.Dense(1)(flat)

    return tf.keras.Model([noisy_in, candidate_in], logit)


# Standard GAN adversarial loss: binary cross-entropy computed on the
# discriminator's raw logits (its final Dense(1) layer has no activation,
# hence from_logits=True).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real pairs should score 1, generated pairs 0.

    Args:
        real_output: discriminator logits for (noisy, real clean) pairs.
        fake_output: discriminator logits for (noisy, generated) pairs.

    Returns:
        Scalar tensor: BCE on the real logits plus BCE on the fake logits.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def make_generator_loss(LAMBDA):
    """Return a generator loss closure weighting the L1 term by ``LAMBDA``.

    The returned function computes
    ``BCE(ones, fake_output) + LAMBDA * mean(|real_image - generated_image|)``:
    an adversarial term rewarding fooling the discriminator plus a
    pixel-wise L1 reconstruction term.
    """
    def generator_loss(fake_output, real_image, generated_image):
        fooling_term = cross_entropy(tf.ones_like(fake_output), fake_output)
        reconstruction_term = tf.reduce_mean(tf.abs(real_image - generated_image))
        return fooling_term + LAMBDA * reconstruction_term

    return generator_loss


def evaluate_psnr(generator, clean_images, stddev=0.1, seed=None):
    """Return the mean PSNR (dB) of the generator's denoised output.

    Gaussian noise (std ``stddev``) is added to ``clean_images`` and clipped
    to [-1, 1]. The result is passed through ``generator`` in inference mode.
    Outputs and clean images are rescaled from [-1, 1] to [0, 1] and
    compared per image with ``data_range=1.0``.

    Args:
        generator: model called as ``generator(noisy, training=False)``.
        clean_images: batch of clean images, presumably in [-1, 1]
            (the generator ends in tanh) -- confirm against preprocessing.
        stddev: noise std; the default 0.1 matches the training noise.
        seed: optional int. When given, the noise is drawn with
            ``tf.random.stateless_normal`` so repeated calls see identical
            noisy inputs (e.g. to compare epochs fairly). ``None`` draws
            fresh noise on every call.

    Returns:
        Mean PSNR over the batch as a Python float.
    """
    # Work in float32 so the noise (float32) and the images always agree.
    clean = tf.cast(clean_images, tf.float32)
    shape = tf.shape(clean)
    if seed is None:
        noise = tf.random.normal(shape, stddev=stddev)
    else:
        noise = tf.random.stateless_normal(shape, seed=[seed, 0], stddev=stddev)
    noisy_images = tf.clip_by_value(clean + noise, -1.0, 1.0)

    generated = generator(noisy_images, training=False).numpy()

    clean_rescaled = (clean.numpy() + 1.0) / 2.0
    gen_rescaled = (generated + 1.0) / 2.0

    scores = [
        compare_psnr(clean_img, gen_img, data_range=1.0)
        for clean_img, gen_img in zip(clean_rescaled, gen_rescaled)
    ]
    return float(np.mean(scores))


@tf.function
def train_step(noisy_batch, clean_batch,
               generator, discriminator,
               generator_optimizer, discriminator_optimizer,
               generator_loss_fn):
    """Run one simultaneous adversarial update of generator and discriminator.

    The generator denoises ``noisy_batch``. The discriminator scores
    (noisy, clean) pairs as real and (noisy, generated) pairs as fake. Each
    network's gradients come from its own tape and are applied with its own
    optimizer.

    Args:
        noisy_batch: batch of noisy generator inputs.
        clean_batch: the matching clean targets.
        generator, discriminator: the Keras models being trained.
        generator_optimizer, discriminator_optimizer: their optimizers.
        generator_loss_fn: callable ``(fake_output, clean_batch,
            generated_images) -> scalar`` (see ``make_generator_loss``).

    Returns:
        ``(gen_loss, disc_loss)``: the batch losses computed before the
        weight update. Callers that ignore the result are unaffected.

    NOTE: the models and optimizers are Python-object arguments, so calling
    this with a new set retraces the function. tf.function only allows
    variables to be created on its first call, so a new optimizer creating
    its variables inside such a retrace fails. For each new model/optimizer
    set, wrap ``train_step.python_function`` in a fresh ``tf.function``.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noisy_batch, training=True)

        real_output = discriminator([noisy_batch, clean_batch], training=True)
        fake_output = discriminator([noisy_batch, generated_images], training=True)

        gen_loss = generator_loss_fn(fake_output, clean_batch, generated_images)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each model is updated only from its own loss.
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Average the generator and discriminator losses over a dataset in inference mode (no weight updates); called once per epoch to track progress.
def compute_losses(dataset, generator, discriminator, generator_loss_fn):
    """Return the per-example average (generator, discriminator) loss.

    Both models run with ``training=False``, so no weights change. Each
    batch's loss is weighted by its batch size. A smaller final batch
    therefore counts in proportion to its size, giving the true dataset
    mean rather than a mean of batch means.

    Args:
        dataset: iterable of ``(noisy_batch, clean_batch)`` pairs.
        generator, discriminator: the models to evaluate.
        generator_loss_fn: callable ``(fake_output, clean_batch,
            generated_images) -> scalar``.

    Returns:
        ``(avg_gen_loss, avg_disc_loss)`` as Python floats.

    Raises:
        ValueError: if ``dataset`` yields no examples (previously a bare
            ZeroDivisionError).
    """
    gen_total = 0.0
    disc_total = 0.0
    n_examples = 0

    for noisy_batch, clean_batch in dataset:
        generated_images = generator(noisy_batch, training=False)
        fake_output = discriminator([noisy_batch, generated_images], training=False)
        real_output = discriminator([noisy_batch, clean_batch], training=False)

        gen_loss = generator_loss_fn(fake_output, clean_batch, generated_images)
        disc_loss = discriminator_loss(real_output, fake_output)

        # Both losses are per-batch means; re-weight by batch size.
        batch_size = int(noisy_batch.shape[0])
        gen_total += float(gen_loss) * batch_size
        disc_total += float(disc_loss) * batch_size
        n_examples += batch_size

    if n_examples == 0:
        raise ValueError("compute_losses: dataset yielded no examples")

    return gen_total / n_examples, disc_total / n_examples

def save_loss_curve(train_gen_losses, train_disc_losses,
                    val_gen_losses, val_disc_losses,
                    epoch, outdir):
    """Plot train/val generator and discriminator losses and save as a PNG.

    Each series is plotted against its own 1-based index. Series whose
    length differs from ``epoch`` therefore no longer raise a matplotlib
    shape-mismatch error.

    Args:
        train_gen_losses, train_disc_losses: per-epoch training losses.
        val_gen_losses, val_disc_losses: per-epoch validation losses.
        epoch: current epoch; used only in the file name.
        outdir: output directory, created if missing. The file is
            ``loss_curve_epoch_{epoch:03d}.png``.
    """
    os.makedirs(outdir, exist_ok=True)

    series = [
        (train_gen_losses, "Train Gen Loss", 'o'),
        (train_disc_losses, "Train Disc Loss", 'o'),
        (val_gen_losses, "Val Gen Loss", 'x'),
        (val_disc_losses, "Val Disc Loss", 'x'),
    ]

    fig = plt.figure(figsize=(8, 5))
    try:
        for values, label, marker in series:
            plt.plot(range(1, len(values) + 1), values, label=label, marker=marker)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        path = os.path.join(outdir, f"loss_curve_epoch_{epoch:03d}.png")
        plt.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def save_epoch_image(model, epoch, sample, outdir):
    """Run ``model`` on ``sample`` and save the outputs as an image grid.

    Outputs are mapped from [-1, 1] to [0, 1] for display. The grid is the
    smallest near-square layout that fits every prediction (4x4 for the
    usual 16 samples), so samples larger than 16 no longer crash.

    Args:
        model: generator, called as ``model(sample, training=False)``.
        epoch: used only in the file name ``epoch_{epoch:04d}.png``.
        sample: batch of input images.
        outdir: output directory, created if missing.
    """
    os.makedirs(outdir, exist_ok=True)

    preds = np.asarray(model(sample, training=False))
    n_images = preds.shape[0]
    cols = max(1, int(np.ceil(np.sqrt(n_images))))
    rows = max(1, int(np.ceil(n_images / cols)))

    fig = plt.figure(figsize=(4, 4))
    try:
        for i in range(n_images):
            plt.subplot(rows, cols, i + 1)
            # Clip guards against tiny excursions outside [0, 1], which
            # imshow would otherwise clip with a warning.
            plt.imshow(np.clip(preds[i] * 0.5 + 0.5, 0.0, 1.0))
            plt.axis("off")

        path = os.path.join(outdir, f"epoch_{epoch:04d}.png")
        plt.savefig(path)
    finally:
        plt.close(fig)

def update_psnr_log(generator, clean_samples, epoch, psnr_values, outdir):
    """Append the generator's current PSNR to ``psnr_values`` and redraw the curve.

    The x-axis is 1..len(psnr_values). ``epoch`` is accepted but not used
    here. The plot is written (overwritten) to ``outdir/psnr_curve.png``.

    Returns:
        The same ``psnr_values`` list, mutated in place.
    """
    psnr_values.append(evaluate_psnr(generator, clean_samples))
    epochs_axis = range(1, len(psnr_values) + 1)

    fig, ax = plt.subplots()
    ax.plot(epochs_axis, psnr_values, marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("PSNR (dB)")
    ax.set_title("PSNR per Epoch")
    ax.grid(True)

    os.makedirs(outdir, exist_ok=True)
    fig.savefig(os.path.join(outdir, "psnr_curve.png"))
    plt.close(fig)

    return psnr_values

# Hyperparameter sweep: train one generator/discriminator pair for every
# (LAMBDA, Adam learning rate) combination. Plots and models are saved
# under BASE/lambda{LAMBDA}_adam{lr}, and the final PSNR goes in `results`.
#
# Running more than one combination used to crash after the first one:
#   1. `train_step` is one module-level tf.function. Calling it with new
#      models/optimizers retraces it, and the new optimizers create their
#      variables inside that retrace, which tf.function only permits on the
#      first call. Each combination now gets its own tf.function wrapper.
#   2. The loss/PSNR history lists were shared across combinations, so the
#      loss plot received more values than epochs. They are now reset for
#      every combination.

lambda_values = [80]
adam_values = [1e-4]

results = {}

# The validation pairs never change, so build their dataset once.
val_dataset = tf.data.Dataset.from_tensor_slices((sample_noisy, sample_clean)).batch(BATCH_SIZE)

# Plain Python step function (unwrapped if train_step is a tf.function),
# so each combination can be traced independently.
_train_step_py = getattr(train_step, "python_function", train_step)

for LAMBDA in lambda_values:
    for adam_lr in adam_values:
        print(f"\n\n=== Training with LAMBDA={LAMBDA}, Adam LR={adam_lr} ===")

        img_dir = f"{BASE}/lambda{LAMBDA}_adam{adam_lr}"

        # Fresh per-combination histories (rebinds the module-level names).
        train_gen_losses, train_disc_losses = [], []
        val_gen_losses, val_disc_losses = [], []
        psnr_values = []

        generator = make_generator_model()
        discriminator = make_discriminator_model()

        generator_optimizer = tf.keras.optimizers.Adam(adam_lr)
        discriminator_optimizer = tf.keras.optimizers.Adam(adam_lr)

        generator_loss_fn = make_generator_loss(LAMBDA)

        # New tf.function per combination: the new optimizers create their
        # variables on this wrapper's first call, which is allowed.
        step_fn = tf.function(_train_step_py)

        # Training loop
        for epoch in range(1, EPOCHS + 1):
            for noisy_batch, clean_batch in dataset:
                step_fn(noisy_batch, clean_batch,
                        generator, discriminator,
                        generator_optimizer, discriminator_optimizer,
                        generator_loss_fn)

            # Re-evaluate the losses over the whole training set after the epoch.
            avg_train_gen, avg_train_disc = compute_losses(dataset, generator, discriminator, generator_loss_fn)
            train_gen_losses.append(avg_train_gen)
            train_disc_losses.append(avg_train_disc)

            # Log all outputs and scores
            avg_val_gen, avg_val_disc = compute_losses(val_dataset, generator, discriminator, generator_loss_fn)
            val_gen_losses.append(avg_val_gen)
            val_disc_losses.append(avg_val_disc)

            save_loss_curve(train_gen_losses, train_disc_losses,
                            val_gen_losses, val_disc_losses,
                            epoch, img_dir)

            print(f"Epoch {epoch}/{EPOCHS} complete")

            # `epoch` is already 1-based, so it is used as-is in the file name.
            save_epoch_image(generator, epoch, sample_noisy, img_dir)
            psnr_values = update_psnr_log(generator,
                                          sample_clean,
                                          epoch,
                                          psnr_values,
                                          img_dir)

        # Print directly to console for immediate viewing of results
        psnr_score = evaluate_psnr(generator, sample_clean)
        print(f"[RESULT] LAMBDA={LAMBDA}, Adam={adam_lr} → PSNR={psnr_score:.2f}")

        results[(LAMBDA, adam_lr)] = psnr_score

        gen_path = f"{img_dir}/generator.keras"
        disc_path = f"{img_dir}/discriminator.keras"

        # Save the final models
        generator.save(gen_path)
        discriminator.save(disc_path)

        print(f"Saved models to {gen_path} and {disc_path}")
