<a href="https://colab.research.google.com/github/PrakritiShetty/DS303_Project_Paper_Implementation/blob/main/Paper_Implementation_DS303_RMSProp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dataset Loading and Preprocessing

In [None]:
import os
import numpy as np
import random

from PIL import Image
import tensorflow as tf
from tensorflow import keras
import albumentations as A
import sys

import argparse
import yaml
import tensorflow as tf
from tensorflow import keras


In [None]:
from PIL import Image

IMAGE_FORMAT = ".png"
COLOR_CHANNELS = 3
UPSCALING_FACTOR = 4

# LR inputs are generated from HR crops by bicubic downsampling. For comparison:
# box averaging is an unweighted mean of neighbouring pixels, bicubic is a
# weighted mean, and nearest-neighbour subsampling substitutes a single pixel.
DOWNSAMPLE_MODE = Image.BICUBIC

# HR crop size. Both sides are equal, so (height, width) vs (width, height)
# ordering does not matter at this size; the model reads it as (height, width).
HR_IMG_SIZE = (648, 648)
# Each LR side is the HR side integer-divided by the upscaling factor.
LR_IMG_SIZE = tuple(side // UPSCALING_FACTOR for side in HR_IMG_SIZE)

In [None]:
class DIV2K_Dataset(keras.utils.Sequence):
    """Batches of (LR, HR) image pairs built on the fly from DIV2K HR PNGs.

    ``keras.utils.Sequence`` is a data generator: images are loaded and
    augmented per batch instead of being held in memory all at once.

    The sorted listing of ``IMAGE_FORMAT`` files in ``hr_image_folder`` is
    split by position:
      * "train": every file except the last 200
      * "val":   the 100 files before the last 100
      * anything else (e.g. "test"): the last 100 files

    Each HR image is randomly cropped to ``HR_IMG_SIZE`` (train/val also get
    random 90-degree rotations and horizontal flips). The augmented crop is
    then downsampled with ``DOWNSAMPLE_MODE`` to ``LR_IMG_SIZE`` to form the
    network input. Both images are scaled to [0, 1].

    ``HR_IMG_SIZE`` / ``LR_IMG_SIZE`` are interpreted as (height, width),
    matching the batch array shapes and the model's ``input_shape``.

    Args:
        hr_image_folder: Directory containing the HR images.
        batch_size: Number of image pairs per batch.
        set_type: "train", "val", or any other value for the test split.
    """

    def __init__(self, hr_image_folder: str, batch_size: int, set_type: str):
        super().__init__()
        self.batch_size = batch_size
        self.hr_image_folder = hr_image_folder
        self.images = np.sort(
            [x for x in os.listdir(hr_image_folder) if x.endswith(IMAGE_FORMAT)]
        )

        # Positional split of the sorted listing (e.g. 600 / 100 / 100 for an
        # 800-image folder, 700 / 100 / 100 for a 900-image folder).
        if set_type == "train":
            self.images = self.images[:-200]
        elif set_type == "val":
            self.images = self.images[-200:-100]
        else:
            self.images = self.images[-100:]

        hr_height, hr_width = HR_IMG_SIZE
        if set_type in ["train", "val"]:
            # Augmentations act on the HR image only; the LR input is derived
            # from the augmented crop in __getitem__, so the pair stays aligned.
            # RandomRotate90 swaps height and width, so HR_IMG_SIZE must stay
            # square for this pipeline.
            self.transform = A.Compose(
                [
                    A.RandomCrop(height=hr_height, width=hr_width, p=1.0),
                    A.RandomRotate90(),
                    A.HorizontalFlip(p=0.5),
                ]
            )
        else:
            self.transform = A.Compose(
                [A.RandomCrop(height=hr_height, width=hr_width, p=1.0)]
            )

        self.to_float = A.ToFloat(max_value=255)

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.images) // self.batch_size

    def on_epoch_end(self):
        """Reshuffle the file order in place so batches differ between epochs."""
        random.shuffle(self.images)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(lr_batch, hr_batch)`` float32 arrays in [0, 1].

        Shapes are ``(batch_size,) + LR_IMG_SIZE + (COLOR_CHANNELS,)`` and
        ``(batch_size,) + HR_IMG_SIZE + (COLOR_CHANNELS,)``.
        """
        start = idx * self.batch_size
        batch_files = self.images[start : start + self.batch_size]
        # float32 is what the model consumes anyway and halves memory compared
        # with numpy's float64 default.
        batch_hr_images = np.zeros(
            (self.batch_size,) + HR_IMG_SIZE + (COLOR_CHANNELS,), dtype=np.float32
        )
        batch_lr_images = np.zeros(
            (self.batch_size,) + LR_IMG_SIZE + (COLOR_CHANNELS,), dtype=np.float32
        )
        # PIL's resize expects (width, height); the size constants are (height, width).
        lr_size_wh = (LR_IMG_SIZE[1], LR_IMG_SIZE[0])

        for slot, filename in enumerate(batch_files):
            # The context manager closes the file handle; np.array forces the
            # pixel data to load before the file is closed.
            with Image.open(os.path.join(self.hr_image_folder, filename)) as img:
                hr_image = np.array(img)

            # Augment the HR image first, then build the LR input from the
            # augmented crop so both images show the same content.
            hr_crop = self.transform(image=hr_image)["image"]
            lr_crop = np.array(
                Image.fromarray(hr_crop).resize(lr_size_wh, resample=DOWNSAMPLE_MODE)
            )

            batch_hr_images[slot] = self.to_float(image=hr_crop)["image"]
            batch_lr_images[slot] = self.to_float(image=lr_crop)["image"]

        return (batch_lr_images, batch_hr_images)

    

Model Building

In [None]:
from keras import Sequential, initializers
from keras.layers import Conv2D, Conv2DTranspose, InputLayer, PReLU, Activation


In [None]:
def create_model(
    d: int,
    s: int,
    m: int,
    input_size: tuple = LR_IMG_SIZE,
    upscaling_factor: int = UPSCALING_FACTOR,
    color_channels: int = COLOR_CHANNELS,
):
    """Build an FSRCNN-style super-resolution network as a Keras ``Sequential``.

    Every convolution except the final deconvolution is followed by a PReLU
    with one learnable slope per channel:
      1. feature extraction: 5x5 conv, ``d`` filters
      2. shrinking:          1x1 conv, ``s`` filters
      3. mapping:            ``m`` x (3x3 conv, ``s`` filters)
      4. expanding:          1x1 conv, ``d`` filters
      5. deconvolution:      9x9 transposed conv, stride ``upscaling_factor``,
                             ``color_channels`` output filters

    Args:
        d: Filters in the feature-extraction and expanding layers.
        s: Filters in the shrinking and mapping layers.
        m: Number of 3x3 mapping layers.
        input_size: (height, width) of the LR input.
        upscaling_factor: Stride of the final transposed convolution.
        color_channels: Channels of the input and output images.

    Returns:
        An uncompiled ``Sequential`` model whose output is ``upscaling_factor``
        times the input's spatial size ("same" padding on the transposed conv).
    """
    model = Sequential()
    model.add(InputLayer(input_shape=(input_size[0], input_size[1], color_channels)))

    def add_conv_prelu(kernel_size: int, filters: int) -> None:
        # "same"-padded He-normal conv, then a PReLU whose slopes are shared
        # over the spatial axes (one slope per channel) and start at 0 (ReLU).
        model.add(
            Conv2D(
                kernel_size=kernel_size,
                filters=filters,
                padding="same",
                kernel_initializer=initializers.HeNormal(),
            )
        )
        model.add(PReLU(alpha_initializer="zeros", shared_axes=[1, 2]))

    # feature extraction (f1 = 5, n1 = d)
    add_conv_prelu(kernel_size=5, filters=d)

    # shrinking
    add_conv_prelu(kernel_size=1, filters=s)

    # non-linear mapping: each of the m layers gets its own PReLU; without it
    # the stacked 3x3 convolutions would compose into a single linear map.
    for _ in range(m):
        add_conv_prelu(kernel_size=3, filters=s)

    # expanding (He-initialised, consistent with the other conv layers)
    add_conv_prelu(kernel_size=1, filters=d)

    # deconvolution: upsample back to full resolution; small-std init and no
    # activation so the output is an unbounded linear reconstruction.
    model.add(
        Conv2DTranspose(
            kernel_size=9,
            filters=color_channels,
            strides=upscaling_factor,
            padding="same",
            kernel_initializer=initializers.RandomNormal(mean=0, stddev=0.001),
        )
    )

    return model


Model Training

In [None]:
# Colab only: expose Google Drive (where the DIV2K HR images live, see
# data_path below) under /content/gdrive.
import google.colab.drive as drive

drive.mount("/content/gdrive")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# Folder of DIV2K high-resolution PNGs on the mounted Drive. Set the
# DIV2K_HR_DIR environment variable to use another folder without editing
# this cell.
data_path = os.environ.get("DIV2K_HR_DIR", "/content/gdrive/MyDrive/DIV2K_train_HR/")



In [None]:
# Alternative local copy of the data: /content/data/DIV2K_train_valid_HR/

# --- architecture (d, s, m as passed to create_model) ---
model_d = 56  # filters in the feature-extraction and expanding layers
model_s = 12  # filters in the shrinking and mapping layers
model_m = 4   # number of 3x3 mapping layers

# --- optimisation ---
lr_init = 0.001
epochs = 500

# --- batching ---
batch_size = 30
steps_per_epoch = 20
val_batch_size = 20
validation_steps = 4

# Checkpoint path template; Keras substitutes the epoch number.
weights_fn = "/content/model_{epoch:05d}.h5"

In [None]:
def train() -> "keras.callbacks.History":
    """Train the model on the DIV2K train split, validating on the val split.

    Reads the module-level settings from earlier cells: ``data_path``,
    ``batch_size``, ``val_batch_size``, ``model_d``/``model_s``/``model_m``,
    ``lr_init``, ``epochs``, ``steps_per_epoch``, ``validation_steps`` and
    ``weights_fn``.

    Callbacks:
      * ReduceLROnPlateau halves the learning rate after 20 epochs without
        training-loss improvement (floor 1e-5).
      * EarlyStopping stops after 40 epochs without val_loss improvement and
        restores the best-val_loss weights into the in-memory model.
      * ModelCheckpoint saves the full model to ``weights_fn`` whenever
        val_loss improves, so the newest checkpoint file is selected by the
        same criterion EarlyStopping restores.

    Returns:
        The ``History`` returned by ``model.fit``; the trained model is
        available as ``history.model``.
    """
    train_dataset = DIV2K_Dataset(
        hr_image_folder=data_path,
        batch_size=batch_size,
        set_type="train",
    )
    val_dataset = DIV2K_Dataset(
        hr_image_folder=data_path,
        batch_size=val_batch_size,
        set_type="val",
    )

    model = create_model(d=model_d, s=model_s, m=model_m)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=lr_init),
        loss="mean_squared_error",
    )

    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="loss", factor=0.5, patience=20, min_lr=1e-5, verbose=1
    )
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=1e-5,
        patience=40,
        verbose=0,
        restore_best_weights=True,
    )
    # Monitor val_loss (not the augmented training loss) so the "best" file on
    # disk agrees with the weights EarlyStopping restores.
    save = keras.callbacks.ModelCheckpoint(
        filepath=weights_fn,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    )

    return model.fit(
        train_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[reduce_lr, early_stopping, save],
        validation_data=val_dataset,
        validation_steps=validation_steps,
    )


if __name__ == '__main__':
    # Keep the History (and, via training_history.model, the trained model)
    # in the notebook namespace for later cells.
    training_history = train()

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
 2/20 [==>...........................] - ETA: 3:56 - loss: 0.0107

Model Evaluation

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Prefer an explicitly saved /content/model.h5. Otherwise fall back to the
# newest checkpoint: the training cell names them model_{epoch:05d}.h5, so
# lexicographic order is epoch order, and with save_best_only=True the last
# file written is the best one so far.
model_path = "/content/model.h5"
if not os.path.exists(model_path):
    checkpoints = sorted(
        name for name in os.listdir("/content")
        if name.startswith("model_") and name.endswith(".h5")
    )
    if not checkpoints:
        raise FileNotFoundError(
            "Neither /content/model.h5 nor any /content/model_*.h5 checkpoint "
            "exists; run the training cell first."
        )
    model_path = os.path.join("/content", checkpoints[-1])

model = keras.models.load_model(model_path)

NameError: ignored

In [None]:
# Read from the same folder as training so the test split (the last 100 files
# of the sorted listing) is disjoint from the train/val splits by construction.
# batch_size must be the integer variable: the original passed the string
# "val_batch_size", which makes len(test_dataset) raise TypeError.
test_dataset = DIV2K_Dataset(
    hr_image_folder=data_path,
    batch_size=val_batch_size,
    set_type="test",
)

In [None]:
# The test pipeline uses RandomCrop, so every pass sees different crops;
# averaging several passes reduces the variance of the estimate.
n_runs = 5
psnrs = []

for _ in range(n_runs):
    # Index explicitly: DIV2K_Dataset.__getitem__ never raises IndexError
    # (out-of-range slices just yield zero batches), so plain iteration would
    # not terminate unless the base class supplies a bounded __iter__.
    for batch_idx in range(len(test_dataset)):
        lr_batch, hr_batch = test_dataset[batch_idx]
        preds = model.predict(lr_batch, verbose=0)
        # Per-image PSNR on RGB values in [0, 1].
        # NOTE(review): the FSRCNN paper reports PSNR on the Y channel, so
        # these RGB numbers may not be directly comparable — confirm.
        psnrs.extend(tf.image.psnr(hr_batch, preds, max_val=1.0).numpy().tolist())

print("Mean PSNR: {:.3f}".format(np.mean(psnrs)))

Visualisations

In [None]:
# Fetch one test batch (lr, hr) and super-resolve its LR images for the plots below.
batch_id = 0
batch = test_dataset[batch_id]
preds = model.predict(batch[0])

In [None]:
img_id = 1

lr_img = batch[0][img_id]
hr_img = batch[1][img_id]
# The model output is unbounded; clip to the displayable [0, 1] range
# (imshow would clip anyway, but with a warning).
sr_img = np.clip(preds[img_id], 0.0, 1.0)

# Baseline: upsample the LR input with the same PIL filter (DOWNSAMPLE_MODE,
# bicubic) used to create it. Round before casting so values like 254.9999
# map to 255 instead of truncating, and pass PIL its (width, height) order.
lr_uint8 = np.clip(np.round(lr_img * 255), 0, 255).astype(np.uint8)
bicubic_img = Image.fromarray(lr_uint8).resize(
    (HR_IMG_SIZE[1], HR_IMG_SIZE[0]), resample=DOWNSAMPLE_MODE
)

fig, axes = plt.subplots(2, 2, figsize=(15, 15))
panels = [
    (lr_img, "LR Image"),
    (hr_img, "HR Image"),
    (sr_img, "Restored Image"),
    (bicubic_img, "Bicubic Upsampling"),
]
for ax, (image, title) in zip(axes.flat, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()