# Task
Train an SRGAN model using the image data located in the "midterm data" folder on Google Drive.

## Mount google drive

### Subtask:
Mount Google Drive to access the data.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**Reasoning**:
Verify that the "midterm data" folder is accessible after mounting.



In [3]:
import os

folder_path = '/content/drive/MyDrive/midterm/data'

if os.path.exists(folder_path):
    print(f"The folder '{folder_path}' is accessible.")
    print("Contents of the folder:")
    for item in os.listdir(folder_path):
        print(item)
else:
    print(f"The folder '{folder_path}' was not found.")

The folder '/content/drive/MyDrive/midterm/data' is accessible.
Contents of the folder:
lowres_32
processed_128


## Load and preprocess data

### Subtask:
Load the images from the specified folder and preprocess them for training the SRGAN model. This may include resizing, normalization, and splitting into training and validation sets.


In [4]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
from sklearn.model_selection import train_test_split

lowres_dir = '/content/drive/MyDrive/midterm/data/lowres_32'
highres_dir = '/content/drive/MyDrive/midterm/data/processed_128'

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 10707702549266160311
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 40419328000
locality {
  bus_id: 1
  links {
  }
}
incarnation: 11050755551203360114
physical_device_desc: "device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0"
xla_global_id: 416903419
]


# Task
Train an SRGAN model using the data in the folder "/content/drive/MyDrive/midterm data".

## Load and preprocess data

### Subtask:
Load the images from the specified folder and preprocess them for training the SRGAN model. This may include resizing, normalization, and splitting into training and validation sets.


**Reasoning**:
Define a function to load and preprocess images, then use it to load low and high-resolution images, split them into training and validation sets, and print the shapes of the resulting sets.



In [5]:
# Configuration - align with your assignment requirements
LR_SHAPE = (32, 32, 3)    # Low-resolution input
HR_SHAPE = (128, 128, 3)  # High-resolution target
SCALING_FACTOR = 4         # 4x upscaling (32→128)
BATCH_SIZE = 32
EPOCHS = 150

In [6]:
def load_and_preprocess_images():

    # Load without batching first
    highres_ds = tf.keras.preprocessing.image_dataset_from_directory(
        '/content/drive/MyDrive/midterm/data/processed_128',
        labels=None,
        image_size=(128, 128),
        batch_size=None,  # No batching initially
        shuffle=True,
        seed=123
    )

    def create_hr_lr_pair(hr_img):
        # Normalize to [-1, 1]
        hr_img = tf.cast(hr_img, tf.float32)
        hr_img = (hr_img / 127.5) - 1.0

        # Create low-res version
        lr_img = tf.image.resize(hr_img, [32, 32], method='area')

        return lr_img, hr_img

    # Create paired dataset
    paired_ds = highres_ds.map(create_hr_lr_pair, num_parallel_calls=tf.data.AUTOTUNE)

    # Apply batching ONLY ONCE
    paired_ds = paired_ds.batch(BATCH_SIZE)  # Batch size 32

    # Get dataset info
    dataset_size = 25000  # Count unbatched elements
    train_size = 17500

    print(f"Total samples: {dataset_size}")
    print(f"Training samples: {train_size}")
    print(f"Test samples: {dataset_size - train_size}")

    # Split dataset
    train_ds = paired_ds.take(train_size)
    test_ds = paired_ds.skip(train_size)

    # Prefetch for performance
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

    return train_ds, test_ds

# Test the fix
print("Creating dataset...")
train_ds, test_ds = load_and_preprocess_images()

def verify_fixed_shapes(dataset, name="Dataset"):
    """Verify the dataset shapes"""
    print(f"\n{name} Shape Verification:")
    for hr_batch, lr_batch in dataset.take(1):
        print(f"HR batch shape: {hr_batch.shape}")
        print(f"LR batch shape: {lr_batch.shape}")

        # Check individual sample shapes
        print(f"HR sample shape: {hr_batch[0].shape}")
        print(f"LR sample shape: {lr_batch[0].shape}")

        # Check value ranges
        print(f"HR value range: [{tf.reduce_min(hr_batch[0]):.3f}, {tf.reduce_max(hr_batch[0]):.3f}]")
        print(f"LR value range: [{tf.reduce_min(lr_batch[0]):.3f}, {tf.reduce_max(lr_batch[0]):.3f}]")
        break

verify_fixed_shapes(train_ds, "Training")

Creating dataset...
Found 25000 files.
Total samples: 25000
Training samples: 17500
Test samples: 7500

Training Shape Verification:
HR batch shape: (32, 32, 32, 3)
LR batch shape: (32, 128, 128, 3)
HR sample shape: (32, 32, 3)
LR sample shape: (128, 128, 3)
HR value range: [-1.000, 0.913]
LR value range: [-1.000, 1.000]


# Task
Build an SRGAN model to generate 128x128 images from 32x32 images using the dataset in the "midterm data" folder in Google Drive, train a binary classifier on the original 128x128 images, train another binary classifier on the SRGAN-generated 128x128 images, and compare their performance. The SRGAN should be trained for at least 150 epochs. The dataset should be split into 70% training and 30% testing, and normalization and image transformations should be applied. The data is located in the "midterm data" folder in Google Drive, with subfolders "lowres_32" for 32x32 images and "processed_128" for 128x128 images.

In [7]:
# Verify your normalization
def verify_normalization(dataset):
    for lr_batch, hr_batch in dataset.take(1):
        print(f"LR range: [{tf.reduce_min(lr_batch):.3f}, {tf.reduce_max(lr_batch):.3f}]")
        print(f"HR range: [{tf.reduce_min(hr_batch):.3f}, {tf.reduce_max(hr_batch):.3f}]")

verify_normalization(train_ds)

LR range: [-1.000, 1.000]
HR range: [-1.000, 1.000]


In [8]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, PReLU, Add, UpSampling2D, Dense, Flatten, LeakyReLU

# Define the generator model (based on the SRResNet architecture)
def build_generator(input_shape):
    def residual_block(x):
        filters = 64
        kernel_size = 3
        strides = 1
        padding = 'same'
        gamma_init = tf.random_normal_initializer(1., 0.02)

        y = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
        y = BatchNormalization(gamma_initializer=gamma_init)(y)
        y = PReLU(shared_axes=[1, 2])(y)
        y = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding)(y)
        y = BatchNormalization(gamma_initializer=gamma_init)(y)
        out = Add()([x, y])
        return out

    def upscale_block(x):
        filters = 256
        kernel_size = 3
        strides = 1
        padding = 'same'

        y = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
        y = UpSampling2D(size=2)(y)
        y = PReLU(shared_axes=[1, 2])(y)
        return y

    inputs = Input(shape=input_shape)

    # Initial convolutional layer
    x = Conv2D(filters=64, kernel_size=9, strides=1, padding='same')(inputs)
    x = PReLU(shared_axes=[1, 2])(x)

    # Residual blocks
    for _ in range(16): # As in the original SRResNet paper
        x = residual_block(x)

    # Second convolutional layer with skip connection
    x = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([x, Conv2D(filters=64, kernel_size=9, strides=1, padding='same')(inputs)]) # Add skip connection from initial layer

    # Upscaling blocks
    x = upscale_block(x)
    x = upscale_block(x)

    # Output convolutional layer
    outputs = Conv2D(filters=3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)

    generator = Model(inputs=inputs, outputs=outputs)
    print(f"Generator built: input {generator.input_shape}, output {generator.output_shape}")
    return generator

# Define the discriminator model
def build_discriminator(input_shape):
    def discriminator_block(x, filters, strides, batchnorm=True):
        x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
        return x

    inputs = Input(shape=input_shape)

    x = discriminator_block(inputs, 64, 2, batchnorm=False)
    x = discriminator_block(x, 128, 2)
    x = discriminator_block(x, 256, 2)
    x = discriminator_block(x, 512, 2)
    x = discriminator_block(x, 512, 2)
    x = discriminator_block(x, 512, 2)
    x = discriminator_block(x, 1024, 1)
    x = discriminator_block(x, 1024, 1)


    x = Flatten()(x)
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    outputs = Dense(1, activation='sigmoid')(x)

    discriminator = Model(inputs=inputs, outputs=outputs)
    print(f"Discriminator built: input {discriminator.input_shape}, output {discriminator.output_shape}")
    return discriminator

# Define the input shapes
lowres_shape = (32, 32, 3)
highres_shape = (128, 128, 3)

# Build the generator and discriminator models
generator = build_generator(lowres_shape)
discriminator = build_discriminator(highres_shape)

generator.summary()
discriminator.summary()

Generator built: input (None, 32, 32, 3), output (None, 128, 128, 3)
Discriminator built: input (None, 128, 128, 3), output (None, 1)




In [9]:
import tensorflow as tf
from tensorflow.keras.applications import VGG19
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
import keras
from tensorflow.keras.applications.vgg19 import preprocess_input # Import preprocess_input

# Define the combined SRGAN model as a Keras Model subclass
class SRGAN(keras.Model):
    def __init__(self, generator, discriminator, vgg):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.vgg = vgg # VGG model for perceptual loss

    def compile(self, generator_optimizer, discriminator_optimizer):
        super().compile()
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        # Define loss functions here as well, matching the train_step
        self.adversarial_loss = keras.losses.BinaryCrossentropy(from_logits=False) # Use from_logits=False since discriminator output has sigmoid
        self.pixel_loss = keras.losses.MeanSquaredError()
        # perceptual_loss function is defined outside the class, using self.vgg

    def perceptual_loss(self, highres_true, highres_gen):
        # Cast to float32
        highres_true = tf.cast(highres_true, tf.float32)
        highres_gen = tf.cast(highres_gen, tf.float32)

        # Resize to VGG input (128x128) if necessary
        if highres_true.shape[1:3] != (128, 128):
            highres_true = tf.image.resize(highres_true, (128, 128))
        if highres_gen.shape[1:3] != (128, 128):
            highres_gen = tf.image.resize(highres_gen, (128, 128))

        # Scale from [-1,1] -> [0,255]
        highres_true = tf.clip_by_value((highres_true + 1.0) * 127.5, 0.0, 255.0)
        highres_gen = tf.clip_by_value((highres_gen + 1.0) * 127.5, 0.0, 255.0)

        # Preprocess for VGG
        highres_true_vgg = preprocess_input(highres_true)
        highres_gen_vgg = preprocess_input(highres_gen)

        # Compute MSE in feature space
        return tf.reduce_mean(tf.square(self.vgg(highres_true_vgg) - self.vgg(highres_gen_vgg)))


    @tf.function
    def train_step(self, data):
        lowres_images, highres_images = data

        # Train Discriminator (alternating updates - train discriminator half as often as generator)
        # This logic is handled outside this function in the training loop or dataset
        with tf.GradientTape() as tape:
            generated_highres_images = self.generator(lowres_images, training=False)
            real_output = self.discriminator(highres_images, training=True)
            fake_output = self.discriminator(generated_highres_images, training=True)

            # Use label smoothing for real and fake labels
            real_labels = tf.ones_like(real_output) * 0.9 # Smooth real labels to 0.9
            fake_labels = tf.zeros_like(fake_output) + 0.1 # Smooth fake labels to 0.1

            real_loss = tf.cast(self.adversarial_loss(real_labels, real_output), tf.float32)
            fake_loss = tf.cast(self.adversarial_loss(fake_labels, fake_output), tf.float32)
            discriminator_loss = (tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss))


            disc_gradients = tape.gradient(discriminator_loss, discriminator.trainable_variables)
            disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 5.0)
            self.discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

        # Train Generator
        with tf.GradientTape() as tape:
            generated_highres = self.generator(lowres_images, training=True)
            fake_output = self.discriminator(generated_highres, training=False)

            perc_loss = tf.cast(self.perceptual_loss(highres_images, generated_highres), tf.float32)
            pix_loss = tf.cast(self.pixel_loss(highres_images, generated_highres), tf.float32)
            gan_loss = tf.cast(self.adversarial_loss(tf.ones_like(fake_output), fake_output), tf.float32)

            perc_loss = tf.reduce_mean(perc_loss)
            pix_loss = tf.reduce_mean(pix_loss)

            generator_loss = 0.01 * perc_loss + 0.005 * gan_loss + 0.01 * pix_loss

            gen_gradients = tape.gradient(generator_loss, generator.trainable_variables)
            gen_gradients, _ = tf.clip_by_global_norm(gen_gradients, 5.0)
            self.generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))

        return {
            "discriminator_loss": discriminator_loss,
            "generator_loss": generator_loss,
            "perceptual_loss": perc_loss,
            "pixel_loss": pix_loss,
            "gan_loss": gan_loss
        }

    @tf.function
    def test_step(self, data):
        lowres_images, highres_images = data
        generated_highres = self.generator(lowres_images, training=False)

        real_output = self.discriminator(highres_images, training=False)
        fake_output = self.discriminator(generated_highres, training=False)

        real_loss = tf.cast(self.adversarial_loss(tf.ones_like(real_output), real_output), tf.float32)
        fake_loss = tf.cast(self.adversarial_loss(tf.zeros_like(fake_output), fake_output), tf.float32)
        discriminator_loss = real_loss + fake_loss

        perc_loss = tf.cast(self.perceptual_loss(highres_images, generated_highres), tf.float32)
        pix_loss = tf.cast(self.pixel_loss(highres_images, generated_highres), tf.float32)
        gan_loss = tf.cast(self.adversarial_loss(tf.ones_like(fake_output), fake_output), tf.float32)

        generator_loss = perc_loss + 1e-2 * gan_loss + 1e-2 * pix_loss

        return {
            "discriminator_loss": discriminator_loss,
            "generator_loss": generator_loss,
            "perceptual_loss": perc_loss,
            "pixel_loss": pix_loss,
            "gan_loss": gan_loss
        }


# Assuming generator and discriminator models are built in a previous cell
# Define the input shapes
lowres_shape = (32, 32, 3)
highres_shape = (128, 128, 3)

# Build the generator and discriminator models
# These lines should be in a previous cell, but are included here for completeness
# generator = build_generator(lowres_shape)
# discriminator = build_discriminator(highres_shape)

# Perceptual loss using VGG19
vgg = VGG19(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
vgg.trainable = False
vgg_output_layer = vgg.get_layer('block5_conv4').output # Using the output of a specific VGG layer
vgg_model_for_loss = Model(inputs=vgg.input, outputs=vgg_output_layer)

# Instantiate the SRGAN model
srgan = SRGAN(generator=generator, discriminator=discriminator, vgg=vgg_model_for_loss)

# Define optimizers (as in article)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

# Compile the model
srgan.compile(
     generator_optimizer=generator_optimizer,
     discriminator_optimizer=discriminator_optimizer
     # Losses are defined within the SRGAN class's compile method
 )

In [10]:
import os
import tensorflow as tf
import keras

# Enable mixed precision training globally for speedup on compatible hardware
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Custom callback to save generator and discriminator weights
class CustomGANCheckpoint(keras.callbacks.Callback):
    def __init__(self, checkpoint_dir, gen_filename='generator_epoch_{epoch:03d}.weights.h5', disc_filename='discriminator_epoch_{epoch:03d}.weights.h5'):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.gen_filename = gen_filename
        self.disc_filename = disc_filename

    def on_epoch_end(self, epoch, logs=None):
        # self.model here refers to the SRGAN instance being trained
        gen_path = os.path.join(self.checkpoint_dir, self.gen_filename.format(epoch=epoch + 1))
        disc_path = os.path.join(self.checkpoint_dir, self.disc_filename.format(epoch=epoch + 1))

        self.model.generator.save_weights(gen_path)
        self.model.discriminator.save_weights(disc_path)


def train_srgan_enhanced(load_from_checkpoint_path=None, initial_epoch=0):
    # Load weights from checkpoint if specified and initial_epoch > 0
    if load_from_checkpoint_path and initial_epoch > 0:
        print(f"Attempting to load weights from checkpoint for epoch {initial_epoch:03d}")
        try:
            # Expecting .h5 weights files now
            generator_weights_path = os.path.join(load_from_checkpoint_path, f'generator_epoch_{initial_epoch:03d}.weights.h5')
            discriminator_weights_path = os.path.join(load_from_checkpoint_path, f'discriminator_epoch_{initial_epoch:03d}.weights.h5')

            if os.path.exists(generator_weights_path) and os.path.exists(discriminator_weights_path):
                generator.load_weights(generator_weights_path)
                discriminator.load_weights(discriminator_weights_path)
                print("✓ Model weights loaded successfully from checkpoint!")
            else:
                print("✗ Checkpoint files not found. Starting training from scratch.")
                initial_epoch = 0 # Reset initial epoch if files don't exist
        except Exception as e:
            print(f"✗ Error loading weights: {e}. Starting training from scratch.")
            initial_epoch = 0 # Reset initial epoch if loading fails

    # Create SRGAN instance with the (potentially loaded) generator and discriminator
    srgan = SRGAN(generator=generator, discriminator=discriminator, vgg=vgg)

    # Optimizers (as in article)
    generator_optimizer = keras.optimizers.Adam(2e-5, beta_1=0.5)
    discriminator_optimizer = keras.optimizers.Adam(1e-6, beta_1=0.5)

    # Compile the model
    srgan.compile(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer
    )

    print("✓ SRGAN compiled successfully!")

    # Enhanced callbacks (like the article)
    checkpoint_dir = "/content/drive/MyDrive/midterm/checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    callbacks = [
        # Custom checkpointing for generator and discriminator weights
        CustomGANCheckpoint(checkpoint_dir),

        # Learning rate scheduler
        keras.callbacks.ReduceLROnPlateau(
            monitor='generator_loss',
            factor=0.1,
            patience=10,
            min_lr=1e-7,
            verbose=1
        ),

        # TensorBoard
        keras.callbacks.TensorBoard(
            log_dir=os.path.join(checkpoint_dir, "logs"),
            histogram_freq=1,
            update_freq='epoch'
        ),
    ]

    # Test forward pass before training
    print("Testing forward pass...")
    try:
        # Ensure we take the correct (HR, LR) pair from the dataset
        for lr_batch, hr_batch in train_ds.take(1):
            # Test generator: takes LR input
            fake_hr = generator(lr_batch)
            print(f"✓ Generator test: {lr_batch.shape} -> {fake_hr.shape}")

            # Test discriminator: takes HR input
            disc_out = discriminator(hr_batch)
            print(f"✓ Discriminator test: {hr_batch.shape} -> {disc_out.shape}")

            # Test VGG: takes HR input
            vgg_out = vgg(hr_batch)
            print(f"✓ VGG test: {hr_batch.shape} -> {vgg_out.shape}")
            break
    except Exception as e:
        print(f"✗ Forward pass test failed: {e}")
        return None, None

    # Train the model
    print(f"\n Starting SRGAN training for {EPOCHS} epochs...")

    history = srgan.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=test_ds,
        callbacks=callbacks,
        verbose=1,
        initial_epoch=initial_epoch
    )

    print("✅ Training completed successfully!")
    return srgan, history

# Start enhanced training
# If loading from checkpoint, set initial_epoch to the epoch number to resume from
# Example: if checkpoint is from epoch 28, set initial_epoch=28
srgan_model, history = train_srgan_enhanced(load_from_checkpoint_path='/content/drive/MyDrive/midterm/checkpoints',
                                            initial_epoch=0
)

✓ SRGAN compiled successfully!
Testing forward pass...
✓ Generator test: (32, 32, 32, 3) -> (32, 128, 128, 3)
✓ Discriminator test: (32, 128, 128, 3) -> (32, 1)
✓ VGG test: (32, 128, 128, 3) -> (32, 4, 4, 512)

 Starting SRGAN training for 150 epochs...
Epoch 1/150
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 308ms/step - discriminator_loss: 0.7841 - gan_loss: 0.6436 - generator_loss: 2.0137 - perceptual_loss: 200.8529 - pixel_loss: 0.1944



[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m362s[0m 360ms/step - discriminator_loss: 0.7840 - gan_loss: 0.6441 - generator_loss: 2.0127 - perceptual_loss: 200.7584 - pixel_loss: 0.1943 - learning_rate: 0.0010
Epoch 2/150
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m117s[0m 149ms/step - discriminator_loss: 0.6956 - gan_loss: 1.0887 - generator_loss: 1.6447 - perceptual_loss: 163.8401 - pixel_loss: 0.0887 - learning_rate: 0.0010
Epoch 3/150
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m117s[0m 150ms/step - discriminator_loss: 0.6926 - gan_loss: 0.9832 - generator_loss: 1.3528 - perceptual_loss: 134.7136 - pixel_loss: 0.0725 - learning_rate: 0.0010
Epoch 4/150
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 150ms/step - discriminator_loss: 0.6893 - gan_loss: 0.9376 - generator_loss: 1.1777 - perceptual_loss: 117.2359 - pixel_loss: 0.0650 - learning_rate: 0.0010
Epoch 5/150
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[