# SRGAN

In [None]:
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers.schedules import PiecewiseConstantDecay
from keras.optimizers import Adam
from keras.models import Model
from keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
from keras.layers import layers, Input, Convolution2D, MaxPool2D, Dense, Flatten, Dropout, BatchNormalization, Add, Lambda, LeakyReLU
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
from tensorflow.python.keras.layers import PReLU
from keras.applications.vgg19 import VGG19, preprocess_input
from keras.utils import plot_model
from keras.metrics import Mean
from PIL import Image
import time

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from pathlib import Path
import pandas as pd
import numpy as np
import cv2

from datetime import datetime
import os
import tensorflow as tf

from datasets.div2k.parameters import Div2kParameters 
from datasets.div2k.loader import create_training_and_validation_datasets
from utils.normalization import normalize_m11, normalize_01, denormalize_m11
from utils.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from utils.metrics import psnr_metric
from utils.config import config
from utils.callbacks import SaveCustomCheckpoint

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

In [None]:
# Configuration - align with your assignment requirements
LR_SHAPE = (32, 32, 3)  # low-resolution input: (height, width, channels)
SCALING_FACTOR = 4  # upscaling factor per spatial dimension (32 -> 128)
# High-resolution target: LR spatial size multiplied by the scaling factor,
# same channel count -> (128, 128, 3).
HR_SHAPE = (
    LR_SHAPE[0] * SCALING_FACTOR,
    LR_SHAPE[1] * SCALING_FACTOR,
    LR_SHAPE[2],
)

In [None]:
# Residual Block (basic building block)
class ResidualBlock(layers.Layer):
    """Conv -> BN -> PReLU -> Conv -> BN block with an identity skip connection.

    Args:
        filters: Number of output channels of both 3x3 convolutions. It must
            match the channel count of the block's input, otherwise the final
            skip addition fails on a shape mismatch.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ...), forwarded to the base class so the layer can be named and
            re-created from its config.
    """

    def __init__(self, filters=64, **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() can report the constructor argument.
        self.filters = filters
        self.conv1 = layers.Conv2D(filters, 3, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.prelu = layers.PReLU()
        self.conv2 = layers.Conv2D(filters, 3, padding='same')
        self.bn2 = layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Apply the residual branch and add it back onto ``inputs``.

        Args:
            inputs: Feature map with ``filters`` channels.
            training: Forwarded to both BatchNormalization layers so they
                pick batch statistics (training) or moving averages
                (inference). ``None`` lets Keras decide from the call context.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.prelu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        return layers.add([inputs, x])

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        base_config = super().get_config()
        base_config.update({"filters": self.filters})
        return base_config

# Generator Network
def build_generator():
    """Build the SRGAN generator mapping an ``LR_SHAPE`` image to a 4x larger one.

    Architecture: a 9x9 feature-extraction convolution, 16 residual blocks,
    a conv + batch-norm stage joined to the first feature map by a long skip
    connection, two 2x sub-pixel upsampling stages, and a 9x9 ``tanh`` output
    convolution (so outputs lie in [-1, 1]).

    Returns:
        A Keras ``Model`` from the low-resolution input to the generated image.
    """
    lr_input = Input(shape=LR_SHAPE)

    # Initial feature extraction; kept aside for the long skip connection.
    head = layers.Conv2D(64, 9, padding='same', activation='relu')(lr_input)

    # Residual trunk (16 blocks, as in the reference architecture).
    trunk = head
    for _ in range(16):
        trunk = ResidualBlock(64)(trunk)

    # Post-residual conv + BN, then the long skip back to the head features.
    trunk = layers.Conv2D(64, 3, padding='same')(trunk)
    trunk = layers.BatchNormalization()(trunk)
    features = layers.add([head, trunk])

    # Two identical 2x upsampling stages (4x in total). Each conv emits 256
    # channels, which depth_to_space rearranges into a 2x larger map with a
    # quarter of the channels (PixelShuffle).
    for _ in range(2):
        features = layers.Conv2D(256, 3, padding='same')(features)
        features = layers.Lambda(lambda t: tf.nn.depth_to_space(t, 2))(features)
        features = layers.PReLU()(features)

    # Final RGB projection.
    sr_output = layers.Conv2D(3, 9, padding='same', activation='tanh')(features)

    return Model(lr_input, sr_output)

# Discriminator Network  
def build_discriminator():
    """Build the discriminator that scores ``HR_SHAPE`` images.

    Architecture: a 3x3 convolution with LeakyReLU, three stride-2
    conv + batch-norm + LeakyReLU stages (128, 256, 512 filters), then a
    dense head ending in a single sigmoid unit.

    Returns:
        A Keras ``Model`` mapping an image of shape ``HR_SHAPE`` to one
        sigmoid output in (0, 1).
    """
    inputs = Input(shape=HR_SHAPE)

    # Feature extraction with increasing filters
    x = layers.Conv2D(64, 3, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    for filters in [128, 256, 512]:
        # Each stride-2 convolution halves the spatial resolution.
        x = layers.Conv2D(filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    # Classification head
    x = layers.Flatten()(x)
    # The Dense layer is deliberately linear: a ReLU here would make the
    # following LeakyReLU a no-op, because ReLU outputs are never negative.
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(1, activation='sigmoid')(x)

    return Model(inputs, x)

# VGG-based Perceptual Loss
def build_vgg_loss():
    """Build a frozen VGG19 feature extractor for the perceptual loss.

    Loads ImageNet-pretrained VGG19 without its classification head, sized
    for ``HR_SHAPE`` inputs, and marks it non-trainable.

    Returns:
        A Keras ``Model`` mapping the VGG19 input to the activations of its
        ``block5_conv4`` layer.
    """
    backbone = tf.keras.applications.VGG19(
        weights='imagenet',
        include_top=False,
        input_shape=HR_SHAPE,
    )
    # Frozen: the network only serves as a fixed feature space for the loss.
    backbone.trainable = False

    feature_map = backbone.get_layer('block5_conv4').output
    return Model(backbone.input, feature_map)

# Combined SRGAN Model
class SRGAN(tf.keras.Model):
    """Generator, discriminator and VGG feature extractor trained jointly.

    ``train_step`` alternates one discriminator update and one generator
    update per batch. The generator loss is
    ``1e-3 * adversarial_loss + perceptual_loss``, where the perceptual loss
    is the mean squared error between VGG feature maps of the real and the
    generated high-resolution images.

    Args:
        generator: Model mapping low-resolution images to generated images.
        discriminator: Model mapping high-resolution images to a sigmoid
            score.
        vgg: Frozen feature extractor used for the perceptual loss.
    """

    # The SRGAN paper rescales VGG feature maps by 1/12.75 so that the
    # perceptual loss has a scale comparable to a pixel-wise MSE.
    VGG_FEATURE_SCALE = 1.0 / 12.75

    def __init__(self, generator, discriminator, vgg):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.vgg = vgg

    def compile(self, g_optimizer, d_optimizer, **kwargs):
        """Configure the model with one optimizer per sub-network.

        Args:
            g_optimizer: Optimizer applied to the generator's variables.
            d_optimizer: Optimizer applied to the discriminator's variables.
            **kwargs: Forwarded unchanged to ``Model.compile``.
        """
        super().compile(**kwargs)
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    def _vgg_features(self, imgs):
        """Return rescaled VGG features for images assumed to be in [-1, 1].

        VGG19's ImageNet weights expect the input format produced by
        ``vgg19.preprocess_input`` applied to 0-255 RGB pixels, so images are
        mapped back to that range first.
        NOTE(review): the [-1, 1] range matches the generator's tanh output;
        confirm the real images fed to ``train_step`` use the same range.
        """
        imgs_255 = (imgs + 1.0) * 127.5
        vgg_input = tf.keras.applications.vgg19.preprocess_input(imgs_255)
        return self.vgg(vgg_input, training=False) * self.VGG_FEATURE_SCALE

    def train_step(self, batch_data):
        """Run one discriminator update followed by one generator update.

        Args:
            batch_data: Tuple ``(hr_imgs, lr_imgs)`` of high-resolution
                targets and their low-resolution counterparts.

        Returns:
            Dict with the scalar losses ``d_loss``, ``g_loss``,
            ``g_adv_loss`` and ``g_perceptual_loss``.
        """
        hr_imgs, lr_imgs = batch_data
        bce = tf.keras.losses.binary_crossentropy

        # Train Discriminator. The fake images are produced outside the tape:
        # only discriminator variables receive gradients in this step.
        generated_imgs = self.generator(lr_imgs, training=False)
        with tf.GradientTape() as d_tape:
            real_output = self.discriminator(hr_imgs, training=True)
            fake_output = self.discriminator(generated_imgs, training=True)

            # Real images are labelled 1, generated images 0.
            d_real_loss = bce(tf.ones_like(real_output), real_output)
            d_fake_loss = bce(tf.zeros_like(fake_output), fake_output)
            d_loss = tf.reduce_mean(d_real_loss + d_fake_loss) / 2

        d_vars = self.discriminator.trainable_variables
        d_grads = d_tape.gradient(d_loss, d_vars)
        self.d_optimizer.apply_gradients(zip(d_grads, d_vars))

        # Train Generator. Features of the real images do not depend on the
        # generator, so they are computed outside the tape.
        real_features = self._vgg_features(hr_imgs)
        with tf.GradientTape() as g_tape:
            generated_imgs = self.generator(lr_imgs, training=True)
            fake_output = self.discriminator(generated_imgs, training=False)

            # Adversarial loss: the generator wants its images labelled 1.
            g_adv_loss = tf.reduce_mean(
                bce(tf.ones_like(fake_output), fake_output))

            # Perceptual loss (VGG-based)
            fake_features = self._vgg_features(generated_imgs)
            g_perceptual_loss = tf.reduce_mean(
                tf.keras.losses.mean_squared_error(real_features, fake_features))

            # Total generator loss
            g_loss = 1e-3 * g_adv_loss + g_perceptual_loss

        g_vars = self.generator.trainable_variables
        g_grads = g_tape.gradient(g_loss, g_vars)
        self.g_optimizer.apply_gradients(zip(g_grads, g_vars))

        return {"d_loss": d_loss, "g_loss": g_loss, "g_adv_loss": g_adv_loss, "g_perceptual_loss": g_perceptual_loss}

In [None]:
# Data preparation function
def prepare_srgan_data(dataset_path, split_ratio=0.7, batch_size=32):
    """Prepare batched (HR, LR) image pairs for SRGAN training and testing.

    Images are loaded at ``HR_SHAPE`` resolution, scaled to [-1, 1], and
    paired with an area-downsampled copy at ``LR_SHAPE`` resolution.

    The train/test split is done by the loader on the file list
    (``validation_split`` / ``subset`` with a shared seed). Splitting with
    ``take``/``skip`` on a dataset that reshuffles on every iteration would
    let samples move between the two sets from one epoch to the next.

    Args:
        dataset_path: Directory containing the images.
        split_ratio: Fraction of the images used for training; the rest is
            used for testing. Must be strictly between 0 and 1.
        batch_size: Number of image pairs per batch.

    Returns:
        Tuple ``(train_ds, test_ds)`` of ``tf.data.Dataset`` objects yielding
        ``(hr_batch, lr_batch)`` float32 tensors.

    Raises:
        ValueError: If ``split_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < split_ratio < 1.0:
        raise ValueError(
            f"split_ratio must be strictly between 0 and 1, got {split_ratio!r}")

    def to_hr_lr_pair(hr_batch):
        # Normalize 0-255 pixels to [-1, 1], then create the LR version.
        hr_imgs = tf.cast(hr_batch, tf.float32) / 127.5 - 1.0
        lr_imgs = tf.image.resize(hr_imgs, LR_SHAPE[:2], method='area')  # Downsample
        return hr_imgs, lr_imgs

    def load_subset(subset):
        # The loader batches its output itself, so no further .batch() call
        # is applied; batching twice would yield rank-5 tensors.
        images = tf.keras.preprocessing.image_dataset_from_directory(
            dataset_path,
            label_mode=None,
            color_mode='rgb',
            image_size=HR_SHAPE[:2],
            batch_size=batch_size,
            shuffle=True,
            seed=123,  # same seed for both subsets keeps the split disjoint
            validation_split=1.0 - split_ratio,
            subset=subset,
        )
        pairs = images.map(to_hr_lr_pair, num_parallel_calls=tf.data.AUTOTUNE)
        return pairs.prefetch(tf.data.AUTOTUNE)

    train_ds = load_subset('training')
    test_ds = load_subset('validation')

    return train_ds, test_ds

# Training execution
def train_srgan(
    dataset_path='C:/Users/lolze/Documents/Github/Midterm_AppliedAI/data/processed_128',
    epochs=150,
    checkpoint_path='srgan_weights_epoch_{epoch:02d}.h5',
):
    """Build, compile and train the SRGAN, saving weights after every epoch.

    Args:
        dataset_path: Directory with the training images, passed to
            ``prepare_srgan_data``.
        epochs: Number of training epochs.
        checkpoint_path: Filename template for the per-epoch weight files.
            NOTE(review): Keras 3 is understood to require weight-only
            checkpoint paths to end in ``.weights.h5``; pass such a path
            there. Confirm against the installed Keras version.

    Returns:
        Tuple ``(srgan, history)``: the trained model and the ``History``
        object returned by ``fit``.
    """
    # Build models
    generator = build_generator()
    discriminator = build_discriminator()
    vgg = build_vgg_loss()

    # Create SRGAN
    srgan = SRGAN(generator, discriminator, vgg)

    # Compile with optimizers
    srgan.compile(
        g_optimizer=tf.keras.optimizers.Adam(1e-4),
        d_optimizer=tf.keras.optimizers.Adam(1e-4),
    )

    # Only the training split is consumed here; the test split is left for
    # a separate evaluation step.
    train_ds, _test_ds = prepare_srgan_data(dataset_path)

    # Save the weights at the end of every epoch so an interrupted run can
    # be resumed.
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        save_weights_only=True,
        save_freq='epoch'
    )

    history = srgan.fit(
        train_ds,
        epochs=epochs,
        callbacks=[checkpoint_cb],
        verbose=1
    )

    return srgan, history