# 16x Image Enlargement (4x per side: 96x96 -> 384x384) using a Super-Resolution Generative Adversarial Network (SRGAN)

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from PIL import Image
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Flatten, Reshape, LeakyReLU, Dropout, BatchNormalization, GlobalAveragePooling2D, UpSampling2D, Add, PReLU, Lambda
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.applications import VGG19
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
# If using colab
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
def residual_block(x, filters=64):
    """Apply one residual block: conv-BN-PReLU-conv-BN plus an identity skip.

    Args:
        x: input feature map. Its channel count must equal ``filters`` so the
            final ``Add`` sees matching shapes (both convs use 'same' padding).
        filters: number of 3x3 filters in each convolution.

    Returns:
        ``x + F(x)``, with the same shape as ``x``.
    """
    identity = x

    y = Conv2D(filters, (3, 3), padding='same')(x)
    y = BatchNormalization()(y)
    y = PReLU(shared_axes=[1, 2])(y)

    y = Conv2D(filters, (3, 3), padding='same')(y)
    y = BatchNormalization()(y)

    return Add()([identity, y])

def upsample_pixel_shuffle(x, filters=64, scale=2):
    """Upsample ``x`` by ``scale`` along height and width and project to ``filters`` channels.

    Despite the name, this is not a sub-pixel (depth_to_space) shuffle. The steps are:
    1. A 3x3 conv widens to ``filters * scale**2`` channels.
    2. Nearest-neighbour ``UpSampling2D`` repeats every pixel ``scale`` times per axis.
    3. A second 3x3 conv + PReLU reduces back to ``filters`` channels.

    NOTE(review): the widened tensor is upsampled before it is reduced. With
    filters=256 and scale=2, that means a 1024-channel activation at the output
    resolution, which is memory heavy. Replacing it with ``tf.nn.depth_to_space``
    would change weight shapes and so break existing checkpoints.

    Args:
        x: input feature map.
        filters: channel count of the returned tensor.
        scale: integer spatial upsampling factor.

    Returns:
        Feature map ``scale`` times larger in height and width with ``filters`` channels.
    """
    widened = Conv2D(filters * (scale**2), (3, 3), padding='same')(x)
    enlarged = tf.keras.layers.UpSampling2D(size=scale, interpolation='nearest')(widened)
    projected = Conv2D(filters, (3, 3), padding='same')(enlarged)
    return PReLU(shared_axes=[1, 2])(projected)

def build_generator(input_shape=(96,96,3), num_res_blocks=8):
    """Build the super-resolution generator as a Keras functional model.

    Layout:
    - a 9x9 conv + PReLU head;
    - ``num_res_blocks`` residual blocks;
    - a conv + BN with a long skip back to the head output;
    - two 2x upsampling stages;
    - a 9x9 conv with tanh, so outputs lie in [-1, 1].

    With the default 96x96 input the output is 384x384x3.

    NOTE(review): built after the global ``mixed_float16`` policy is set, so
    layers presumably compute in float16, including the tanh output — confirm
    this is acceptable downstream.

    Args:
        input_shape: (height, width, channels) of the LR input.
        num_res_blocks: number of residual blocks in the trunk.

    Returns:
        ``Model`` named "Generator".
    """
    lr_in = Input(shape=input_shape)

    # Head; its output is also the long skip connection around the trunk.
    head = Conv2D(64, (9, 9), padding='same')(lr_in)
    head = PReLU(shared_axes=[1, 2])(head)

    trunk = head
    for _ in range(num_res_blocks):
        trunk = residual_block(trunk, 64)

    trunk = Conv2D(64, (3, 3), padding='same')(trunk)
    trunk = BatchNormalization()(trunk)
    trunk = Add()([trunk, head])

    # Two 2x stages: 96x96 -> 192x192 -> 384x384 for the default input.
    upscaled = trunk
    for _ in range(2):
        upscaled = upsample_pixel_shuffle(upscaled, 256, scale=2)

    sr_out = Conv2D(3, (9, 9), padding='same', activation='tanh')(upscaled)
    return Model(lr_in, sr_out, name="Generator")

def build_discriminator(input_shape=(384, 384, 3)):
    """Build the discriminator: a conv stack, global average pooling and a scalar head.

    The final ``Dense(1)`` has no activation, so the model returns an unbounded
    score per image. ``train_step`` regresses that score onto target labels
    with MSE. Four stride-2 convs reduce the default 384x384 input to 24x24
    before pooling.

    NOTE(review): the stride-2 convs at 64 and 128 filters use 1x1 kernels,
    which only sample every other pixel. Confirm this is intended rather than
    3x3 kernels.

    Args:
        input_shape: (height, width, channels) of the HR / SR images.

    Returns:
        ``Model`` named "Discriminator".
    """
    img_in = Input(shape=input_shape)

    # Stem: conv + LeakyReLU with no batch norm.
    feat = Conv2D(64, (3, 3), strides=1, padding='same')(img_in)
    feat = LeakyReLU(negative_slope=0.2)(feat)

    # (filters, kernel, stride) for each conv-BN-LeakyReLU stage, in order.
    stages = (
        (64, (1, 1), 2),
        (128, (3, 3), 1),
        (128, (1, 1), 2),
        (256, (3, 3), 2),
        (512, (3, 3), 2),
    )
    for n_filters, kernel, stride in stages:
        feat = Conv2D(n_filters, kernel, strides=stride, padding='same')(feat)
        feat = BatchNormalization()(feat)
        feat = LeakyReLU(negative_slope=0.2)(feat)

    feat = GlobalAveragePooling2D()(feat)
    feat = Dense(128)(feat)
    feat = LeakyReLU(negative_slope=0.2)(feat)
    score = Dense(1)(feat)

    return Model(img_in, score, name="Discriminator")

# Instantiate both networks (discriminator first, as before) and report their sizes.
disc_model = build_discriminator()
gen_model = build_generator()
print(*(net.count_params() for net in (gen_model, disc_model)))

In [None]:
def build_vgg_feature_extractor(layer_name='block5_conv4'):
    """Return a frozen ImageNet VGG19 truncated at ``layer_name``.

    Both the backbone and the wrapping model are marked non-trainable, so the
    extractor only supplies features for the content loss.

    Args:
        layer_name: name of the VGG19 layer whose output becomes the model output.

    Returns:
        ``Model`` mapping VGG19's input to ``layer_name``'s activations.
    """
    backbone = VGG19(weights='imagenet', include_top=False)
    backbone.trainable = False

    extractor = Model(backbone.input, backbone.get_layer(layer_name).output)
    extractor.trainable = False
    return extractor

# Frozen VGG19 feature extractor for the perceptual (content) loss.
vgg_feat = build_vgg_feature_extractor(layer_name='block5_conv4')

# Shared MSE loss object, used for the VGG feature loss and the adversarial losses.
mse = tf.keras.losses.MeanSquaredError()

def content_loss_fn(hr, sr):
    """Perceptual loss: MSE between ``vgg_feat`` features of ``hr`` and ``sr``.

    Both inputs are mapped from [-1, 1] to [0, 255] and then passed through
    ``vgg19.preprocess_input`` before feature extraction.

    Args:
        hr: ground-truth HR batch in [-1, 1].
        sr: generated SR batch in [-1, 1].

    Returns:
        Scalar mean-squared error between the two feature maps.
    """
    vgg19_prep = tf.keras.applications.vgg19.preprocess_input
    hr_feat, sr_feat = (
        vgg_feat(vgg19_prep((batch + 1.0) * 127.5))  # [-1,1] -> [0,255] -> VGG input
        for batch in (hr, sr)
    )
    return mse(hr_feat, sr_feat)

print("D trainable weights:", len(disc_model.trainable_weights))

# Generator loss = content_weight * VGG content loss + adv_weight * adversarial loss.
content_weight, adv_weight = 1e-1, 5e-3
n_critic = 2 # Discriminator updates per Generator update

# Adam(learning_rate, beta_1): the discriminator learns 10x slower than the generator.
d_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

@tf.function(reduce_retracing=True)
def train_step(lr_batch, hr_batch):
    """Run one SRGAN optimisation step on a single (LR, HR) batch.

    Performs ``n_critic`` discriminator updates followed by one generator
    update, all on the same batch. ``n_critic``, ``content_weight`` and
    ``adv_weight`` are Python globals, so their values are baked in when the
    function is traced. Changing them later has no effect until a retrace.

    NOTE(review): the global policy is ``mixed_float16``, but the optimizers
    are plain Adam with no loss scaling. Keras' mixed-precision guide
    recommends a LossScaleOptimizer in custom training loops to avoid float16
    gradient underflow — confirm this is intended.

    Args:
        lr_batch: low-resolution generator inputs.
        hr_batch: matching high-resolution targets.

    Returns:
        (d_loss, g_loss, cont_loss, adv_loss). ``d_loss`` comes from the last
        of the ``n_critic`` discriminator updates only.
    """
    batch_size = tf.shape(lr_batch)[0]
    # Two-sided label smoothing: "real" targets are 0.9 and "fake" targets are 0.1.
    real_labels = tf.ones((batch_size, 1)) * 0.9
    fake_labels = tf.zeros((batch_size, 1)) + 0.1

    # Python loop: unrolled n_critic times into the traced graph.
    for _ in range(n_critic):
        with tf.GradientTape() as tape_d:
            # The generator runs in training mode here (BatchNorm uses batch statistics),
            # but only discriminator variables receive gradients below.
            sr = gen_model(lr_batch, training=True)
            real_out = disc_model(hr_batch, training=True)
            fake_out = disc_model(sr, training=True)

            # LSGAN discriminator loss with smoothed targets:
            # 0.5 * ( mean((D(x) - 0.9)^2) + mean((D(G(z)) - 0.1)^2) )
            d_loss_real = mse(real_labels, real_out)
            d_loss_fake = mse(fake_labels, fake_out)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)

        grads = tape_d.gradient(d_loss, disc_model.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, disc_model.trainable_variables))

    with tf.GradientTape() as tape_g:
        sr = gen_model(lr_batch, training=True)
        # Discriminator in inference mode (BatchNorm uses its moving statistics)
        # while the generator is being updated.
        fake_out_for_g = disc_model(sr, training=False)
        # Generator tries to push D(G(z)) toward the smoothed "real" target.
        adv_loss = mse(real_labels, fake_out_for_g)
        cont_loss = content_loss_fn(hr_batch, sr)
        g_loss = content_weight * cont_loss + adv_weight * adv_loss

    grads = tape_g.gradient(g_loss, gen_model.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, gen_model.trainable_variables))

    return d_loss, g_loss, cont_loss, adv_loss


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint

# checkpoint_dir = '/content/drive/MyDrive/models/checkpoints' If using colab
# Local checkpoint location (an absolute path at the filesystem root).
checkpoint_dir = '/models/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Everything needed to resume training: both networks and both optimizer states.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=g_optimizer,
    discriminator_optimizer=d_optimizer,
    generator=gen_model,
    discriminator=disc_model,
)

# Keeps only the 5 most recent checkpoints in checkpoint_dir.
checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=5)

def train_gan_full(lr_images, hr_images, epochs=1000, batch_size=2,
                   save_interval=50, start_epoch=0, samples_per_epoch=100):
    """Train the SRGAN from ``start_epoch`` up to ``epochs`` and checkpoint it.

    Each epoch reshuffles all (LR, HR) pairs and trains on the first
    ``samples_per_epoch`` of them (every pair when ``None``). Batches of
    ``batch_size`` are fed to ``train_step`` one call per batch.

    Checkpoints go through ``checkpoint_manager`` with ``checkpoint_number``
    set to the 1-based epoch just completed. The ``ckpt-<N>`` suffix is
    therefore the epoch count, which ``resume_training`` reads back.
    Generator/discriminator ``.keras`` snapshots are written alongside.

    Args:
        lr_images: generator inputs (scaled to [-1, 1] by the loaders in this file).
        hr_images: matching HR targets; same length as ``lr_images``.
        epochs: total number of epochs to have completed when training stops.
        batch_size: pairs per ``train_step`` call.
        save_interval: checkpoint and snapshot every this many epochs.
        start_epoch: number of epochs already completed (non-zero when resuming).
        samples_per_epoch: pairs drawn per epoch after shuffling. The default of
            100 matches the previous hard-coded ``take(100)``; ``None`` uses
            every pair.

    Raises:
        ValueError: if there are no training pairs.
    """
    if len(hr_images) == 0:
        raise ValueError("train_gan_full needs at least one (LR, HR) training pair")

    dataset = tf.data.Dataset.from_tensor_slices((lr_images, hr_images)).shuffle(
        len(hr_images), reshuffle_each_iteration=True
    )
    if samples_per_epoch is not None:
        dataset = dataset.take(samples_per_epoch)
    dataset = dataset.batch(batch_size).prefetch(4)

    last_saved_epoch = None
    for epoch in range(start_epoch, epochs):
        epoch_number = epoch + 1
        d_losses, g_losses, cont_losses, adv_losses = [], [], [], []

        for lr_batch, hr_batch in dataset:
            d_loss, g_loss, cont_loss, adv_loss = train_step(lr_batch, hr_batch)
            d_losses.append(d_loss.numpy())
            g_losses.append(g_loss.numpy())
            cont_losses.append(cont_loss.numpy())
            adv_losses.append(adv_loss.numpy())

        # All four figures are means over the epoch's batches, so they are comparable.
        print(f"Epoch {epoch_number}/{epochs} D_loss={np.mean(d_losses):.4f} G_loss={np.mean(g_losses):.4f} "
              f"content={np.mean(cont_losses):.4f} adv={np.mean(adv_losses):.4f}")

        if epoch_number % save_interval == 0:
            # Tie the checkpoint suffix to the epoch so resume_training can recover it.
            save_path = checkpoint_manager.save(checkpoint_number=epoch_number)
            last_saved_epoch = epoch_number
            print(f"Checkpoint saved at epoch {epoch_number}: {save_path}")
            gen_model.save(f'{checkpoint_dir}/gen_epoch_{epoch_number}.keras')
            disc_model.save(f'{checkpoint_dir}/disc_epoch_{epoch_number}.keras')

    # Epochs the weights now reflect (start_epoch when the loop did not run).
    completed_epochs = max(start_epoch, epochs)
    # Skip a duplicate checkpoint when the final epoch was already saved above.
    if completed_epochs != last_saved_epoch:
        checkpoint_manager.save(checkpoint_number=completed_epochs)
    gen_model.save(f'{checkpoint_dir}/gen_final.keras')
    disc_model.save(f'{checkpoint_dir}/disc_final.keras')
    print("Training complete. Final model saved.")

def resume_training(lr_images, hr_images, epochs=2000, batch_size=2,
                    save_interval=50):
    """Load latest checkpoint and resume training.

    Restores the latest checkpoint in ``checkpoint_dir`` into ``checkpoint``
    and continues training up to ``epochs`` total epochs. The resume epoch is
    parsed from the checkpoint suffix (``.../ckpt-<N>``), which
    ``train_gan_full`` sets to the completed epoch count. If there is no
    checkpoint, or the suffix is not an integer, training starts from epoch 0.

    NOTE(review): checkpoints saved without an explicit ``checkpoint_number``
    carry the manager's save counter as their suffix, not an epoch count —
    confirm before resuming from one of those.

    Args:
        lr_images, hr_images: training pairs, forwarded to ``train_gan_full``.
        epochs: total epoch count to train up to.
        batch_size: forwarded to ``train_gan_full``.
        save_interval: forwarded to ``train_gan_full``.
    """
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    start_epoch = 0

    if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        print(f"Restored from checkpoint: {latest_checkpoint}")

        try:
            start_epoch = int(latest_checkpoint.split('-')[-1])
        except ValueError:
            print("Could not determine epoch number, starting from 0")
        else:
            print(f"Resuming from epoch {start_epoch}")
    else:
        print("No checkpoint found. Starting fresh training.")

    if start_epoch >= epochs:
        print(f"Checkpoint is already at epoch {start_epoch} (epochs={epochs}); "
              f"pass a larger `epochs` to keep training.")
        return

    train_gan_full(lr_images, hr_images, epochs, batch_size,
                   save_interval=save_interval, start_epoch=start_epoch)

In [None]:
def _load_normalized_png(path, size):
    """Open ``path``, convert to RGB, LANCZOS-resize to ``size`` and scale to [-1, 1] float32."""
    # Context manager closes the underlying file; the converted/resized image is independent of it.
    with Image.open(path) as img:
        resized = img.convert("RGB").resize(size, Image.LANCZOS)
    return np.array(resized, dtype=np.float32) / 127.5 - 1.0

def _load_single_pair(args):
    """Load one LR/HR pair that shares a file stem.

    Takes a single tuple so it can be used directly with ``executor.map``.

    Args:
        args: ``(name, lr_dir, hr_dir, lr_size, hr_size)``. ``name`` is the file
            stem (``<name>.png`` is read from both directories), and the sizes
            are PIL ``(width, height)`` tuples.

    Returns:
        ``(lr_arr, hr_arr)``: float32 RGB arrays scaled to [-1, 1].
    """
    name, lr_dir, hr_dir, lr_size, hr_size = args
    lr_arr = _load_normalized_png(os.path.join(lr_dir, f"{name}.png"), lr_size)
    hr_arr = _load_normalized_png(os.path.join(hr_dir, f"{name}.png"), hr_size)
    return lr_arr, hr_arr

def load_image_pairs_fast(lr_dir, hr_dir, pair_filenames, lr_size=(96,96), hr_size=(384,384), save_dir=None, max_workers=8):
    """Load LR/HR PNG pairs on a thread pool and stack them into two arrays.

    Args:
        lr_dir: directory holding the LR ``<name>.png`` files.
        hr_dir: directory holding the HR ``<name>.png`` files.
        pair_filenames: file stems to load, in the order they should appear.
        lr_size: PIL ``(width, height)`` that LR images are resized to.
        hr_size: PIL ``(width, height)`` that HR images are resized to.
        save_dir: if truthy, the stacks are also written there as
            ``lr_images.npz`` / ``hr_images.npz`` (created if missing).
        max_workers: thread-pool size.

    Returns:
        ``(lr_images, hr_images)``: stacked float32 arrays scaled to [-1, 1].

    Raises:
        ValueError: from ``np.stack`` when ``pair_filenames`` is empty.
    """
    jobs = [(stem, lr_dir, hr_dir, lr_size, hr_size) for stem in pair_filenames]

    # executor.map preserves input order, so pairs line up with pair_filenames.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pairs = list(tqdm(pool.map(_load_single_pair, jobs), total=len(jobs), desc="Loading images"))

    lr_images = np.stack([lr_arr for lr_arr, _hr in pairs])
    hr_images = np.stack([hr_arr for _lr, hr_arr in pairs])

    if not save_dir:
        return lr_images, hr_images

    os.makedirs(save_dir, exist_ok=True)
    lr_out = os.path.join(save_dir, "lr_images.npz")
    hr_out = os.path.join(save_dir, "hr_images.npz")

    # Stored positionally, so each archive keeps its array under 'arr_0'.
    np.savez_compressed(lr_out, lr_images)
    np.savez_compressed(hr_out, hr_images)

    print(f"Saved LR to {lr_out}")
    print(f"Saved HR to {hr_out}")

    return lr_images, hr_images

lr_dir = "processed/lr"
hr_dir = "processed/hr"
save_dir = "processed"

# Each LR PNG defines one pair; its HR counterpart is expected under the same stem.
lr_files = list(filter(lambda fname: fname.endswith(".png"), os.listdir(lr_dir)))
pair_filenames = [stem for stem, _ext in map(os.path.splitext, lr_files)]

lr_images, hr_images = load_image_pairs_fast(lr_dir, hr_dir, pair_filenames, save_dir=save_dir)

print("LR shape:", lr_images.shape)
print("HR shape:", hr_images.shape)


In [None]:
def load_npz_images(lr_npz_path, hr_npz_path):
    """Load the LR and HR image stacks cached as ``.npz`` archives.

    Each archive is expected to hold one unnamed array stored under the default
    key ``'arr_0'``, as written by ``np.savez_compressed(path, array)``.

    Args:
        lr_npz_path: path to the LR archive.
        hr_npz_path: path to the HR archive.

    Returns:
        ``(lr_images, hr_images)`` as fully loaded numpy arrays.

    Raises:
        FileNotFoundError: if either path does not exist.
        KeyError: if an archive has no ``'arr_0'`` entry.
    """
    for label, path in (("LR", lr_npz_path), ("HR", hr_npz_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    # np.load on an .npz returns an NpzFile that keeps the file open; use it as a
    # context manager so the handle is released once the array has been read.
    with np.load(lr_npz_path) as lr_data:
        lr_images = lr_data['arr_0']
    with np.load(hr_npz_path) as hr_data:
        hr_images = hr_data['arr_0']

    print(f"Loaded LR images: {lr_images.shape}")
    print(f"Loaded HR images: {hr_images.shape}")

    return lr_images, hr_images

# Colab
# lr_npz = "/content/drive/MyDrive/processed/lr_images.npz"
# hr_npz = "/content/drive/MyDrive/processed/hr_images.npz"

# Reload the cached arrays written by the loading cell instead of re-decoding every PNG.
lr_npz, hr_npz = "processed/lr_images.npz", "processed/hr_images.npz"

lr_images, hr_images = load_npz_images(lr_npz, hr_npz)

print("Subset LR shape:", lr_images.shape)
print("Subset HR shape:", hr_images.shape)


In [None]:
x = input("T for training, R for resume: ")
# Normalise once: tolerate surrounding whitespace/case and make the branches exclusive.
choice = x.strip().lower()
if choice == "t":
    train_gan_full(lr_images, hr_images, epochs=100, batch_size=4, save_interval=50)
elif choice == "r":
    resume_training(lr_images, hr_images, epochs=100, batch_size=4)
else:
    print(f"Unrecognised option {x!r}: expected 'T' (train) or 'R' (resume); nothing was run.")


In [None]:
# Colab
# gen_model.save('/content/drive/MyDrive/models/srgan_generator.keras')
# disc_model.save('/content/drive/MyDrive/models/srgan_discriminator.keras')
# Saving may fail if the parent directory is missing, so create it up front.
# NOTE(review): the inference cells load from '/models/...' (filesystem root), not this
# relative 'models/' directory — confirm which location is intended.
os.makedirs('models', exist_ok=True)
gen_model.save('models/srgan_generator.keras')
disc_model.save('models/srgan_discriminator.keras')

In [None]:
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
import time

# gen = load_model('/content/drive/MyDrive/models/srgan_generator.keras', compile=False, safe_mode=False)
# NOTE(review): safe_mode=False lets Keras deserialise arbitrary code from the file — only
# load trusted models. Also, the save cell writes 'models/...' (relative) while this reads
# '/models/...' (filesystem root) — confirm which location is intended.
gen = load_model('/models/srgan_generator.keras', compile=False, safe_mode=False)
timestamp = int(time.time())


def preprocess_img(path):
    """Load an image as a batch of one float32 RGB array scaled to [-1, 1].

    The image is converted to RGB first, matching the training loader, so RGBA,
    greyscale and palette PNGs still give the three channels the generator expects.
    No resizing is done.

    Args:
        path: image file path.

    Returns:
        Array of shape ``(1, H, W, 3)``.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    arr = (np.array(rgb, dtype=np.float32) / 127.5) - 1.0
    return arr[np.newaxis, ...]


test_file_name = '/content/1750465599751830_p000'
lr_input = preprocess_img(f'{test_file_name}.png')
sr_output = gen.predict(lr_input)

# Map the tanh output from [-1, 1] back to 8-bit pixels.
sr_img = ((sr_output[0] + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
# test_file_name already includes its directory, so no extra '/' is prepended.
Image.fromarray(sr_img).save(f'{test_file_name}super_resolved{timestamp}.png')


In [None]:
# Fresh timestamp for this evaluation run; comprehensive_evaluation reads this module
# global when naming its saved figure.
timestamp = int(time.time())
# gen = load_model('/content/drive/MyDrive/models/srgan_generator.keras', compile=False, safe_mode=False)
# NOTE(review): safe_mode=False lets Keras deserialise arbitrary code (e.g. Lambda layers)
# from the file — only load model files you trust.
gen = load_model('/models/srgan_generator.keras', compile=False, safe_mode=False)
def comprehensive_evaluation(lr_test, hr_test, gen_model, save_path=None):
    """Super-resolve ``lr_test``, print PSNR/SSIM statistics and plot LR | SR | HR strips.

    Args:
        lr_test: batch of LR inputs scaled to [-1, 1].
        hr_test: matching HR targets scaled to [-1, 1]. Their height/width also set
            the size that LR images are nearest-neighbour upscaled to for the plot.
        gen_model: generator mapping ``lr_test`` to ``hr_test``'s resolution.
        save_path: where to write the figure. ``None`` (default) keeps the original
            Colab location, suffixed with the module-level ``timestamp``.

    Returns:
        ``(psnr_vals, ssim_vals)``: per-image numpy arrays (PSNR in dB).

    Raises:
        ValueError: if ``lr_test`` is empty.
    """
    n_images = len(lr_test)
    if n_images == 0:
        raise ValueError("comprehensive_evaluation needs at least one test image")
    if save_path is None:
        save_path = f'/content/drive/MyDrive/colab_data/all_test_results{timestamp}.png'

    sr_test = gen_model.predict(lr_test, verbose=0)

    # Map [-1, 1] -> [0, 1]. Cast to float32 because the generator may emit float16
    # under the mixed-precision policy.
    hr_01 = tf.cast((hr_test + 1.0) / 2.0, tf.float32)
    sr_01 = tf.cast((sr_test + 1.0) / 2.0, tf.float32)
    lr_01 = tf.cast((lr_test + 1.0) / 2.0, tf.float32)

    psnr_vals = tf.image.psnr(sr_01, hr_01, max_val=1.0).numpy()
    ssim_vals = tf.image.ssim(sr_01, hr_01, max_val=1.0).numpy()

    print("=" * 50)
    print("SRGAN Evaluation Results")
    print("=" * 50)
    print(f"Mean PSNR: {np.mean(psnr_vals):.2f} dB")
    print(f"Mean SSIM: {np.mean(ssim_vals):.4f}")
    print(f"Median PSNR: {np.median(psnr_vals):.2f} dB")
    print(f"Median SSIM: {np.median(ssim_vals):.4f}")
    print(f"Best PSNR: {np.max(psnr_vals):.2f} dB")
    print(f"Worst PSNR: {np.min(psnr_vals):.2f} dB")
    print("=" * 50)

    # Upscale LR to the actual HR size (not a hard-coded 384x384) so the strip lines up.
    hr_size = [int(dim) for dim in hr_01.shape[1:3]]

    n_cols = 5
    n_rows = (n_images + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows))
    axes = axes.flatten()
    for i in range(n_images):
        lr_upscaled = tf.image.resize(lr_01[i:i+1], hr_size, method='nearest')[0].numpy()
        comparison = np.hstack([
            lr_upscaled,       # Low res (nearest-neighbour upscaled)
            sr_01[i].numpy(),  # Super resolved
            hr_01[i].numpy(),  # Ground truth image
        ])
        axes[i].imshow(comparison)
        axes[i].set_title(
            f'Image {i+1}\nPSNR: {psnr_vals[i]:.2f} dB | SSIM: {ssim_vals[i]:.3f}',
            fontsize=10
        )
        axes[i].axis('off')
    # Hide the unused cells in the last row.
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated evaluations don't accumulate open figures.
    plt.close(fig)

    return psnr_vals, ssim_vals
# Images 400:420 are the ones compared most often across past model versions.
psnr_scores, ssim_scores = comprehensive_evaluation(
    lr_images[400:420],
    hr_images[400:420],
    gen_model=gen,
)

print("\nPSNR scores:", psnr_scores)
print("SSIM scores:", ssim_scores)