The actual training was done in Google Colab

### preparing the dataset

In [18]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [21]:
def downsample_image(image, scale_factor):
    return tf.image.resize(image, [image.shape[0] // scale_factor, image.shape[1] // scale_factor], method=tf.image.ResizeMethod.BICUBIC)

def load_dataset(directory, image_size, scale_factor, batch_size):
    datagen = ImageDataGenerator(rescale=1./255)
    dataset = datagen.flow_from_directory(directory, target_size=image_size, batch_size=batch_size, class_mode=None)
    
    lr_images = []
    hr_images = []
    
    for img_batch in dataset:
        hr_images.extend(img_batch)
        lr_images.extend([downsample_image(img, scale_factor) for img in img_batch])
        if len(hr_images) >= len(dataset) * batch_size:
            break
    
    return tf.convert_to_tensor(lr_images), tf.convert_to_tensor(hr_images)

# Example usage
lr_images, hr_images = load_dataset('Div2K/DIV2K_train_HR', image_size=(256, 256), scale_factor=4, batch_size=16)


Found 800 images belonging to 1 classes.


In [24]:
type(lr_images)

tensorflow.python.framework.ops.EagerTensor

In [16]:
from tensorflow.keras.layers import GlobalAveragePooling2D, Input, Conv2D, LeakyReLU, PReLU, BatchNormalization, UpSampling2D, Dense, Flatten, Add, Lambda
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam

### Generator 

In [11]:
from tensorflow.keras.layers import Input, Conv2D, PReLU, BatchNormalization, Add, Lambda
from tensorflow.keras.models import Model
import tensorflow as tf

def pixel_shuffle(scale):
    def _pixel_shuffle(x):
        return tf.nn.depth_to_space(x, scale)
    
    def _output_shape(input_shape):
        batch_size, height, width, channels = input_shape
        if height is None:
            height = None
        else:
            height = height * scale
        
        if width is None:
            width = None
        else:
            width = width * scale
        
        channels = channels // (scale ** 2)
        return (batch_size, height, width, channels)
    
    return Lambda(_pixel_shuffle, output_shape=_output_shape)


##  The output img of the generator model is 4x the size of input img
def build_main_generator(input_shape=(None, None, 3)):
    
    input_layer = Input(shape=input_shape)
    x = Conv2D(64, (9, 9), padding='same')(input_layer)
    x = PReLU(shared_axes=[1, 2])(x)
    
    residual = x
    for _ in range(16):
        res = Conv2D(64, (3, 3), padding='same')(residual)
        res = BatchNormalization(momentum=0.8)(res)
        res = PReLU(shared_axes=[1, 2])(res)
        res = Conv2D(64, (3, 3), padding='same')(res)
        res = BatchNormalization(momentum=0.8)(res)
        residual = Add()([residual, res])
    
    x = Conv2D(64, (3, 3), padding='same')(residual)
    x = BatchNormalization(momentum=0.8)(x)
    x = Add()([x, residual])
    
    for _ in range(2):
        x = Conv2D(256, (3, 3), padding='same')(x)
        x = pixel_shuffle(scale=2)(x)
        x = PReLU(shared_axes=[1, 2])(x)
    
    output_layer = Conv2D(3, (9, 9), padding='same')(x)
    
    return Model(inputs=input_layer, outputs=output_layer)

generator = build_main_generator()
generator.summary()

We are using the functional api call and not the Sequential layer because sequential layer requires fixed input shape
and using this method allows flexibility

### Discriminator

In [17]:
def build_discriminator_2(input_shape=(None, None, 3)):
    input_layer = Input(shape=input_shape)
    
    x = Conv2D(64, (3, 3), strides=1, padding='same')(input_layer)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(64, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(128, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(128, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(512, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(512, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    output_layer = Dense(1, activation='sigmoid')(x)
    
    return Model(inputs=input_layer, outputs=output_layer)

discriminator = build_discriminator_2()
discriminator.summary()

### Pre-Training Generator 

In [None]:
# now let's train the generator like a normal neural network so 
# it can give reasonably high res iamges before we start adversarial
# training using the discriminator

In [26]:
mse = MeanSquaredError()
optimizer = Adam(0.0002, 0.5)

In [27]:
generator.compile(loss=mse, optimizer=optimizer)

In [28]:
epochs = 10
batch_size = 16

for epoch in range(epochs):
    for i in range(0, len(lr_images), batch_size):
        lr_batch = lr_images[i:i + batch_size]
        hr_batch = hr_images[i:i + batch_size]
        
        generated_images = generator.predict(lr_batch)
        mse_loss  = mse(hr_batch, generated_images)
        
        generator.train_on_batch(lr_batch, hr_batch)
        
    print(f"Epoch {epoch + 1}/{epochs}, MSE Loss: {mse_loss}")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step


KeyboardInterrupt: 

In [29]:
@tf.function
def train_step(lr_batch, hr_batch):
    with tf.GradientTape() as tape:
        generated_images = generator(lr_batch, training=True)
        loss = mse(hr_batch, generated_images)
    gradients = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(gradients, generator.trainable_variables))
    return loss

# Pre-train generator
epochs = 10
batch_size = 16
for epoch in range(epochs):
    for i in range(0, len(lr_images), batch_size):
        lr_batch = lr_images[i:i + batch_size]
        hr_batch = hr_images[i:i + batch_size]
        
        mse_loss = train_step(lr_batch, hr_batch)
    
    print(f"Epoch {epoch + 1}/{epochs}, MSE Loss: {mse_loss.numpy()}")

KeyboardInterrupt: 

Stopped the above training due to computational intensity

Currently, we have a generator that can produce reasonable high-res images,(we did that by doing some mse-training on it).
Now, it's time to train the Actual GAN using adversarial loss+ VGG loss and simulataneously train the discriminator.

### Importing VGG-19

In [31]:
import tensorflow as tf
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [35]:
# Loading pre-trained VGG-19 model and extracting feature maps to 
# get VGG and add it to adversarial loss to get final Generator Loss

def build_vgg_model():
    vgg = VGG19(weights='imagenet', include_top=False)
    vgg.trainable = False
    chosen_layers = ['block5_conv4']
    outputs  =  [vgg.get_layer(name).output for name in chosen_layers]
    model = Model(inputs=vgg.inputs, outputs = outputs)
    return model


vgg_feature_extractor = build_vgg_model()

# Computing VGG Loss
def compute_vgg_loss(hr_images, generated_images):
    hr_features = vgg_feature_extractor(hr_images)
    generated_features = vgg_feature_extractor(generated_images)
    mse = MeanSquaredError()
    return mse(hr_features, generated_features)

# Final Generator Loss
def generator_loss(disc_output, hr_images, generated_images, lambda_vgg=1):
    adversarial_loss = BinaryCrossentropy(from_logits=True)(tf.ones_like(disc_output), disc_output)
    vgg_loss = compute_vgg_loss(hr_images, generated_images)
    total_loss = adversarial_loss + lambda_vgg * vgg_loss
    return total_loss

# Discriminator Loss
def discriminator_loss(real_output, fake_output):
    real_loss = BinaryCrossentropy(from_logits=True)(tf.ones_like(real_output), real_output)
    fake_loss = BinaryCrossentropy(from_logits=True)(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m80134624/80134624[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 0us/step


In [None]:
discriminator_2 = build_discriminator()
optimizer_gen = Adam(0.0001, 0.5)
optimizer_disc = Adam(0.0001, 0.5)

@tf.function
def train_step(generator, discriminator, hr_images, lr_images, optimizer_gen, optimizer_disc):
    with tf.GradientTape() as tape_gen, tf.GradientTape() as tape_disc:
        generated_images = generator(lr_images, training=True)
        disc_real_output = discriminator(hr_images, training=True)
        disc_fake_output = discriminator(generated_images, training=True)

        # Calculate losses
        gen_loss = generator_loss(disc_fake_output, hr_images, generated_images)
        disc_loss = discriminator_loss(disc_real_output, disc_fake_output)

    # Apply gradients using the gradient tape
    # it is used to do automatic differentiation and since we are implementing kind of 
    grads_gen = tape_gen.gradient(gen_loss, generator.trainable_variables)
    grads_disc = tape_disc.gradient(disc_loss, discriminator.trainable_variables)
    optimizer_gen.apply_gradients(zip(grads_gen, generator.trainable_variables))
    optimizer_disc.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Example training loop
epochs = 100
batch_size = 16
for epoch in range(epochs):
    for i in range(0, len(lr_images), batch_size):
        lr_batch = lr_images[i:i + batch_size]
        hr_batch = hr_images[i:i + batch_size]
        gen_loss, disc_loss = train_step(generator, discriminator_2, hr_batch, lr_batch, optimizer_gen, optimizer_disc)
    print(f'Epoch {epoch+1}, Gen Loss: {gen_loss}, Disc Loss: {disc_loss}')
    
## couldnt train on my laptop due to high computational intensity
## did it on google Colab

### Note
#### we never had to worry about gradient tape i.e calculating the gradients in a simple ANN, but since the GAN network is 
#### complex and backpropagation takes place AFTER we get the classification from the discriminator, so we have store
#### the gradients in gradient tape and calculate the gradients and optimise the function After doing the whole
#### forward propagation from generator to the end of discriminator.

### Note
#### this was handled by the model.fit() in a normal ANN, it abstracts away these low level details

# Final code

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

#######

def downsample_image(image, scale_factor):
    return tf.image.resize(image, [image.shape[0] // scale_factor, image.shape[1] // scale_factor], method=tf.image.ResizeMethod.BICUBIC)

def load_dataset(directory, image_size, scale_factor, batch_size):
    datagen = ImageDataGenerator(rescale=1./255)
    dataset = datagen.flow_from_directory(directory, target_size=image_size, batch_size=batch_size, class_mode=None)
    
    lr_images = []
    hr_images = []
    
    for img_batch in dataset:
        hr_images.extend(img_batch)
        lr_images.extend([downsample_image(img, scale_factor) for img in img_batch])
        if len(hr_images) >= len(dataset) * batch_size:
            break
    
    return tf.convert_to_tensor(lr_images), tf.convert_to_tensor(hr_images)

# Example usage
lr_images, hr_images = load_dataset('Div2K/DIV2K_train_HR', image_size=(256, 256), scale_factor=4, batch_size=16)

######

type(lr_images)

######

from tensorflow.keras.layers import GlobalAveragePooling2D, Input, Conv2D, LeakyReLU, PReLU, BatchNormalization, UpSampling2D, Dense, Flatten, Add, Lambda
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam

#####

from tensorflow.keras.layers import Input, Conv2D, PReLU, BatchNormalization, Add, Lambda
from tensorflow.keras.models import Model
import tensorflow as tf

def pixel_shuffle(scale):
    def _pixel_shuffle(x):
        return tf.nn.depth_to_space(x, scale)
    
    def _output_shape(input_shape):
        batch_size, height, width, channels = input_shape
        if height is None:
            height = None
        else:
            height = height * scale
        
        if width is None:
            width = None
        else:
            width = width * scale
        
        channels = channels // (scale ** 2)
        return (batch_size, height, width, channels)
    
    return Lambda(_pixel_shuffle, output_shape=_output_shape)


##  The output img of the generator model is 4x the size of input img
def build_main_generator(input_shape=(None, None, 3)):
    
    input_layer = Input(shape=input_shape)
    x = Conv2D(64, (9, 9), padding='same')(input_layer)
    x = PReLU(shared_axes=[1, 2])(x)
    
    residual = x
    for _ in range(16):
        res = Conv2D(64, (3, 3), padding='same')(residual)
        res = BatchNormalization(momentum=0.8)(res)
        res = PReLU(shared_axes=[1, 2])(res)
        res = Conv2D(64, (3, 3), padding='same')(res)
        res = BatchNormalization(momentum=0.8)(res)
        residual = Add()([residual, res])
    
    x = Conv2D(64, (3, 3), padding='same')(residual)
    x = BatchNormalization(momentum=0.8)(x)
    x = Add()([x, residual])
    
    for _ in range(2):
        x = Conv2D(256, (3, 3), padding='same')(x)
        x = pixel_shuffle(scale=2)(x)
        x = PReLU(shared_axes=[1, 2])(x)
    
    output_layer = Conv2D(3, (9, 9), padding='same')(x)
    
    return Model(inputs=input_layer, outputs=output_layer)

generator = build_main_generator()
generator.summary()

#####

def build_discriminator_2(input_shape=(None, None, 3)):
    input_layer = Input(shape=input_shape)
    
    x = Conv2D(64, (3, 3), strides=1, padding='same')(input_layer)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(64, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(128, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(128, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(512, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(512, (3, 3), strides=2, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    output_layer = Dense(1, activation='sigmoid')(x)
    
    return Model(inputs=input_layer, outputs=output_layer)

discriminator = build_discriminator_2()
discriminator.summary()

#####

mse = MeanSquaredError()
optimizer = Adam(0.0002, 0.5)

#####

generator.compile(loss=mse, optimizer=optimizer)

#####

@tf.function
def train_step(lr_batch, hr_batch):
    with tf.GradientTape() as tape:
        generated_images = generator(lr_batch, training=True)
        loss = mse(hr_batch, generated_images)
    gradients = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(gradients, generator.trainable_variables))
    return loss

# Pre-train generator
epochs = 10
batch_size = 16
for epoch in range(epochs):
    for i in range(0, len(lr_images), batch_size):
        lr_batch = lr_images[i:i + batch_size]
        hr_batch = hr_images[i:i + batch_size]
        
        mse_loss = train_step(lr_batch, hr_batch)
    
    print(f"Epoch {epoch + 1}/{epochs}, MSE Loss: {mse_loss.numpy()}")

#####

import tensorflow as tf
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

####

# Loading pre-trained VGG-19 model and extracting feature maps to 
# get VGG Loss

def build_vgg_model():
    vgg = VGG19(weights='imagenet', include_top=False)
    vgg.trainable = False
    chosen_layers = ['block5_conv4']
    outputs  =  [vgg.get_layer(name).output for name in chosen_layers]
    model = Model(inputs=vgg.inputs, outputs = outputs)
    return model


vgg_feature_extractor = build_vgg_model()

# Compute VGG Loss
def compute_vgg_loss(hr_images, generated_images):
    hr_features = vgg_feature_extractor(hr_images)
    generated_features = vgg_feature_extractor(generated_images)
    mse = MeanSquaredError()
    return mse(hr_features, generated_features)

# Generator Loss
def generator_loss(disc_output, hr_images, generated_images, lambda_vgg=1):
    adversarial_loss = BinaryCrossentropy(from_logits=True)(tf.ones_like(disc_output), disc_output)
    vgg_loss = compute_vgg_loss(hr_images, generated_images)
    total_loss = adversarial_loss + lambda_vgg * vgg_loss
    return total_loss

# Discriminator Loss
def discriminator_loss(real_output, fake_output):
    real_loss = BinaryCrossentropy(from_logits=True)(tf.ones_like(real_output), real_output)
    fake_loss = BinaryCrossentropy(from_logits=True)(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


####

discriminator_2 = build_discriminator()
optimizer_gen = Adam(0.0001, 0.5)
optimizer_disc = Adam(0.0001, 0.5)

@tf.function
def train_step(generator, discriminator, hr_images, lr_images, optimizer_gen, optimizer_disc):
    with tf.GradientTape() as tape_gen, tf.GradientTape() as tape_disc:
        generated_images = generator(lr_images, training=True)
        disc_real_output = discriminator(hr_images, training=True)
        disc_fake_output = discriminator(generated_images, training=True)

        # Calculate losses
        gen_loss = generator_loss(disc_fake_output, hr_images, generated_images)
        disc_loss = discriminator_loss(disc_real_output, disc_fake_output)

    # Apply gradients
    grads_gen = tape_gen.gradient(gen_loss, generator.trainable_variables)
    grads_disc = tape_disc.gradient(disc_loss, discriminator.trainable_variables)
    optimizer_gen.apply_gradients(zip(grads_gen, generator.trainable_variables))
    optimizer_disc.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Example training loop
epochs = 100
batch_size = 16
for epoch in range(epochs):
    for i in range(0, len(lr_images), batch_size):
        lr_batch = lr_images[i:i + batch_size]
        hr_batch = hr_images[i:i + batch_size]
        gen_loss, disc_loss = train_step(generator, discriminator_2, hr_batch, lr_batch, optimizer_gen, optimizer_disc)
    print(f'Epoch {epoch+1}, Gen Loss: {gen_loss}, Disc Loss: {disc_loss}')

####