# PSTAT 197A Final Project Report: ESPCN (ISR)

## Introduction 

Our objective is to develop a model for image reconstruction and de-blurring. Specifically, we use ESPCN, an Image Super-Resolution (ISR) model designed to recover high-resolution (HR) images from their corresponding low-resolution (LR) counterparts. Our dataset comprises approximately 4,000 high-resolution images of wild animals, sourced from Kaggle (https://www.kaggle.com/datasets/dimensi0n/afhq-512?select=wild), each of size 3x512x512 (three RGB color channels, 512 pixels high by 512 pixels wide). To simulate recovering high-resolution images from low-resolution inputs, we systematically downscaled these high-resolution images to generate low-resolution counterparts. We then configured the ESPCN model with appropriately chosen hyperparameters and trained it on the prepared dataset: the low-resolution images are fed into the model, and its parameters are iteratively refined so that its outputs match the original high-resolution images.

## Dataset and Pre-processing

To prepare the images for the model, the dataset is formatted into pairs of input and target images. Each image is duplicated: one copy is downsampled by a specified scale factor to serve as the input, and the full-size copy serves as the target. In this project the scale factor hyperparameter is set to 4, so each 512x512 image yields a 128x128 input. The convolution and pixel shuffle layers in our model do not constrain the input shape, so images of other sizes, including non-square ones, can also be used, provided the height and width are divisible by the scale factor (so that the upscaled output matches the target) and all images within a batch share one size. The image pairs are stored as tensors and served to the training loop through a PyTorch DataLoader.
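The pairing step and the shape flexibility can be illustrated with plain tensors. This is a sketch, not the project's `DatasetFromFolder` (which uses `torchvision.transforms.Resize`): it assumes `torch.nn.functional.interpolate` as a stand-in for the bicubic resize, uses a random tensor in place of a real photo, and `net` is a hypothetical one-convolution toy rather than the project model.

```python
import torch
import torch.nn.functional as F
from torch import nn

scale_factor = 4

# Stand-in for one image after transforms.ToTensor(): 3 x 512 x 512 with values in [0, 1]
hr = torch.rand(3, 512, 512)

# Bicubic downsampling by the scale factor; interpolate expects a batch axis
h, w = hr.shape[-2:]
lr = F.interpolate(hr.unsqueeze(0), size=(h // scale_factor, w // scale_factor),
                   mode="bicubic", align_corners=False, antialias=True).squeeze(0)
print(tuple(lr.shape), tuple(hr.shape))  # (3, 128, 128) (3, 512, 512)

# Convolutions with "same" padding plus a pixel shuffle accept any input size,
# so a non-square 96 x 160 input is upscaled to 384 x 640
net = nn.Sequential(nn.Conv2d(3, 3 * scale_factor ** 2, kernel_size=3, padding=1),
                    nn.PixelShuffle(scale_factor))
out = net(torch.rand(1, 3, 96, 160))
print(tuple(out.shape))  # (1, 3, 384, 640)
```

The pair `(lr, hr)` is exactly the (input, target) form the training loop consumes.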

## Model

![ESPCN architecture diagram](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/model_explanation.png)

We use an Efficient Sub-Pixel Convolutional Neural Network (ESPCN). As illustrated in the diagram, the model takes a low-resolution image as input and applies convolutional (CNN) layers to produce feature maps with more channels than the original three. These feature maps are then rearranged by the pixel shuffle operation, the key component of the sub-pixel convolution layer. The underlying idea is that each low-resolution pixel covers a block of finer "sub-pixels" in the high-resolution image, so the network can predict the values of those sub-pixels as extra channels on the low-resolution grid. As the diagram shows, each pixel of the final feature maps corresponds to one sub-pixel of the high-resolution image. The model has three convolution layers, the last of which raises the channel count to 48 ($3 \times \text{scale factor}^2 = 3 \times 4^2$). The pixel shuffle layer then rearranges these 48 channels of size 128x128 into a 3-channel 512x512 high-resolution image. We use the ReLU activation function because it introduces non-linearity at very low computational cost, which keeps each training epoch fast. The Mean Squared Error (MSE) serves as our loss function, with the original 512x512 image as the label and the downsampled 128x128 image as the input. The loss function is:
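$$\ell(W_{1:L}, b_{1:L}) = \frac{1}{r^2HW}\sum_{x=1}^{rH}\sum_{y=1}^{rW}\left( I^{HR}_{x,y} - f^L_{x,y}(I^{LR}) \right)^2$$

where $r$ is the scale factor, $H$ and $W$ are the height and width of the low-resolution input, $W_{1:L}$ and $b_{1:L}$ are the weights and biases of the $L$ layers, $I^{HR}$ is the high-resolution target, and $f^L(I^{LR})$ is the network's output for the low-resolution input $I^{LR}$.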
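A minimal sketch of the two computations in this section, using random tensors in place of real feature maps and images. `manual_pixel_shuffle` is a hypothetical helper written for illustration and checked against PyTorch's `nn.PixelShuffle`; the loss is computed per color channel as in the equation and checked against `nn.MSELoss`, which with its default `reduction='mean'` also averages over the color channels and the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
r, C, H, W = 4, 3, 128, 128   # scale factor, color channels, low-resolution size

def manual_pixel_shuffle(x, r):
    """Rearrange (N, C*r^2, H, W) feature maps into an (N, C, r*H, r*W) image."""
    n, cr2, h, w = x.shape
    c = cr2 // (r * r)
    # Split the channel axis into (C, r, r): the r x r sub-pixel offsets of each LR pixel
    x = x.reshape(n, c, r, r, h, w)
    # Interleave the offsets with the spatial axes -> (N, C, H, r, W, r)
    x = x.permute(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c, h * r, w * r)

# Stand-in for the last convolution's output: C * r^2 = 48 channels on the 128x128 grid
features = torch.randn(1, C * r ** 2, H, W, dtype=torch.float64)
output = manual_pixel_shuffle(features, r)                # plays the role of f^L(I^LR)
print(output.shape)                                       # torch.Size([1, 3, 512, 512])
print(torch.equal(output, nn.PixelShuffle(r)(features)))  # True

# Loss from the equation: sum over the rH x rW pixels divided by r^2 H W, per channel
target = torch.rand(1, C, r * H, r * W, dtype=torch.float64)  # plays the role of I^HR
per_channel = ((target - output) ** 2).sum(dim=(2, 3)) / (r ** 2 * H * W)
loss_eq = per_channel.mean()   # then average over color channels (and batch)

loss_torch = nn.MSELoss()(output, target)
print(torch.allclose(loss_eq, loss_torch))  # True
```

The pixel shuffle is pure data movement (a reshape and a permute), which is why it adds no parameters and why the outputs match exactly.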

## Result

We assess the model by comparing low-resolution inputs with the corresponding reconstructed high-resolution outputs, and we monitor its learning progress with a loss graph.

![Low-resolution input, model output, and high-resolution target](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/reference.jpg)

The figure above summarizes the model's performance on a reference image, with each column showing the model's result after one training epoch. The top row displays the downsampled 128x128 low-resolution input. The second row shows the model's output, which is visibly sharper and more detailed than the low-resolution input, indicating that the model recovers detail that is not discernible in the input. The bottom row shows the original 512x512 high-resolution image for direct comparison. The visible gain in quality over the low-resolution input supports the effectiveness of our model in reconstructing low-resolution images.

![Training and test loss curves](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/loss.jpg)

The second figure shows the training and test loss curves (on a log scale), which track the model's learning over time. The blue line is the training loss, which reflects how well the model fits the training data and is expected to decrease gradually as the model learns. The orange line is the test loss, which is intended to measure how well the model generalizes to unseen data. Both are recorded every 5 iterations inside the `train` function: `loss_history` stores the loss of the current training batch, and `test_loss_history` stores the loss on one batch drawn from `test_dataloader`. Both curves trend steadily downward as training advances, indicating that the model is learning the super-resolution task without visible signs of overfitting at this stage. Note, however, that in the appendix code `test_data` is loaded from the same `../data/train/wild` folder as `train_data`, so the test curve is computed on training images; a held-out split would be needed to confirm the absence of overfitting.
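One way to obtain a genuine held-out set without a separate folder is to split the images with `torch.utils.data.random_split`. This is a sketch under assumptions: a `TensorDataset` of random small tensors stands in for `DatasetFromFolder`, and the 10% test fraction and seed are arbitrary illustrative choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)
# Hypothetical stand-in for DatasetFromFolder("../data/train/wild", scale_factor=4):
# 100 random (LR, HR) pairs at a small size to keep the example light
full_data = TensorDataset(torch.rand(100, 3, 16, 16), torch.rand(100, 3, 64, 64))

# Hold out 10% of the images; a seeded generator makes the split reproducible
n_test = len(full_data) // 10
train_data, test_data = random_split(
    full_data, [len(full_data) - n_test, n_test],
    generator=torch.Generator().manual_seed(0),
)

train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

print(len(train_data), len(test_data))                        # 90 10
print(set(train_data.indices).isdisjoint(test_data.indices))  # True
```

Because `random_split` returns index-based subsets, the same approach would work with any map-style dataset, including `DatasetFromFolder`.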

## Conclusion

These results are promising, indicating that our model holds the potential to substantially enhance the quality of low-resolution images. The observed trends in both training and test loss reinforce the notion that the model is learning effectively and progressing toward successful image reconstruction.

However, there is a need for an additional interpretable metric, besides MSE, to precisely quantify the dissimilarity between the generated images and their original counterparts, and evaluate the model performance. 
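One common candidate is the peak signal-to-noise ratio (PSNR), which rescales MSE into decibels relative to the maximum pixel value; SSIM is another widely used option. The sketch below is an illustration, not part of the project code; `psnr` is a hypothetical helper and assumes images in the [0, 1] range produced by `ToTensor` and the model's clamp.

```python
import torch

def psnr(output, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]."""
    mse = torch.mean((output - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

target = torch.zeros(3, 4, 4)
output = torch.full((3, 4, 4), 0.1)   # every pixel off by 0.1, so MSE = 0.01
value = psnr(output, target).item()
print(round(value, 3))                # 20.0
```

Higher PSNR means a closer reconstruction; each 10 dB gain corresponds to a tenfold reduction in MSE.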

Moreover, the apparent absence of overfitting suggests that increasing the model size may yield better results. We have not explored this due to hardware constraints, so model scalability and its impact on overall performance remain a direction for future investigation.
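To gauge how model size grows, one can count parameters as the width of the two hidden layers (64 in the appendix model) is increased. `espcn_like` and `width` are hypothetical names for this sketch; ReLU, clamp, and pixel shuffle carry no parameters, so at width 64 the count matches the appendix `Model`.

```python
from torch import nn

def espcn_like(scale_factor=4, width=64):
    # Same layer shapes as the appendix model; `width` is a hypothetical scaling knob
    return nn.Sequential(
        nn.Conv2d(3, width, kernel_size=5, padding=2), nn.ReLU(),
        nn.Conv2d(width, width, kernel_size=5, padding=2), nn.ReLU(),
        nn.Conv2d(width, 3 * scale_factor ** 2, kernel_size=3, padding=1),
        nn.PixelShuffle(scale_factor),
    )

def n_params(model):
    return sum(p.numel() for p in model.parameters())

for width in (64, 128, 256):
    print(width, n_params(espcn_like(width=width)))
# 64 135024
# 128 474800
# 256 1768752
```

The middle 5x5 convolution has width² × 25 weights, so doubling the width roughly quadruples the dominant term, which matters under the hardware constraints mentioned above.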

## Code Appendix 

Run the scripts to reproduce the results; do not run this notebook.

### dataset

```python
from os import listdir
from os.path import join

from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    # Convert to RGB so every image yields a 3-channel tensor (drops alpha, expands grayscale)
    img = Image.open(filepath).convert("RGB")
    return img

class DatasetFromFolder(Dataset):
    """Returns (low-resolution input, high-resolution target) tensor pairs for the images in a folder."""

    def __init__(self, image_dir, scale_factor):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]
        self.tensor = transforms.ToTensor()
        self.scale_factor = scale_factor

    def __getitem__(self, index):
        # The full-size image is the target; a bicubic-downsampled copy is the input
        target = self.tensor(load_img(self.image_filenames[index]))

        # Tensors are C x H x W, so the last two dimensions are height and width
        height, width = target.shape[-2:]
        resize = transforms.Resize((height // self.scale_factor, width // self.scale_factor),
                                   interpolation=transforms.InterpolationMode.BICUBIC,
                                   antialias=True)
        lr_input = resize(target)

        return lr_input, target

    def __len__(self):
        return len(self.image_filenames)
```

### model

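```python
import torch
from torch import nn


class Model(nn.Module):
    """ESPCN-style network: three convolutions followed by a sub-pixel (pixel shuffle) layer."""

    def __init__(self, scale_factor):
        super(Model, self).__init__()
        # Feature extraction on the low-resolution grid (padding keeps H x W unchanged)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        # 3 * r^2 output channels: r^2 sub-pixel values for each RGB channel
        self.conv3 = nn.Conv2d(64, 3 * scale_factor ** 2, kernel_size=3, padding=1)
        # Rearranges (N, 3*r^2, H, W) into (N, 3, r*H, r*W)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.clamp(x, 0.0, 1.0)  # keep pixel intensities in the valid [0, 1] range
        x = self.pixel_shuffle(x)
        return x
```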

### main

```python
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader


# import dataset.py and model.py
from dataset import *
from model import *


# main helper functions
def swap(img):
    # Reorder a C x H x W image to H x W x C for matplotlib
    img = img.swapaxes(0, 1)
    img = img.swapaxes(1, 2)
    return img

def train(epoch, model):
    epoch_loss = 0
    epoch_loss_history = []
    epoch_test_loss_history = []

    for iteration, batch in enumerate(train_dataloader, 1):
        img, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        loss = criterion(model(img), target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Every 5 iterations: record this batch's training loss and the loss on one random test batch
        if iteration % 5 == 0:
            tbatch = next(iter(test_dataloader))
            timg, ttarget = tbatch[0].to(device), tbatch[1].to(device)
            with torch.no_grad():  # evaluation only, so skip building the autograd graph
                tloss = criterion(model(timg), ttarget)

            epoch_loss_history.append(loss.item())
            epoch_test_loss_history.append(tloss.item())

            print("===> Epoch[{}]({}/{}): Loss: {:.6f}, Test Loss: {:.6f}".format(
                epoch+1, iteration, len(train_dataloader), loss.item(), tloss.item()))

    print("===> Epoch {} Complete: Avg. Loss: {:.6f}".format(epoch+1, epoch_loss / len(train_dataloader)))

    return epoch_loss_history, epoch_test_loss_history


# set device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')


# set hyperparameters
scale_factor    = 4
batch_size      = 32
num_epochs      = 5
learning_rate   = 0.0003
criterion       = nn.MSELoss()


# load data
train_data = DatasetFromFolder("../data/train/wild", scale_factor=scale_factor)
# NOTE: test_data reads the same folder as train_data, so the "test" loss is measured on
# training images; point it at a held-out folder for a genuine test set
test_data = DatasetFromFolder("../data/train/wild", scale_factor=scale_factor)
ref_data = DatasetFromFolder("../data/loss", scale_factor=scale_factor)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
ref_dataloader = DataLoader(ref_data, batch_size=1, shuffle=False)


# create model
model = Model(scale_factor=scale_factor).to(device)


# train
# rows: LR input, model output, HR target; one column per epoch (squeeze=False keeps ax 2-D)
figure, ax = plt.subplots(3, num_epochs, squeeze=False)
figure.set_size_inches(20, 15)

loss_history = []
test_loss_history = []

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for i in range(num_epochs):
    epoch_loss_history, epoch_test_loss_history = train(i, model)
    loss_history += epoch_loss_history
    test_loss_history += epoch_test_loss_history

    # ref_dataloader is not shuffled, so the same reference image is upscaled after every epoch
    ref, ref_tgt = next(iter(ref_dataloader))
    with torch.no_grad():
        ref_fit = model(ref.to(device)).cpu()

    ax[0, i].imshow(swap(ref.squeeze()))
    ax[1, i].imshow(swap(ref_fit.numpy().squeeze()))
    ax[2, i].imshow(swap(ref_tgt.squeeze()))

figure.savefig('../image/reference.jpg')


# plot loss on its own figure so it is not drawn into the last image subplot
loss_fig, loss_ax = plt.subplots()
loss_ax.set_yscale("log")
loss_ax.plot(loss_history, label="train loss")
loss_ax.plot(test_loss_history, label="test loss")
loss_ax.set_xlabel("record (every 5 iterations)")
loss_ax.set_ylabel("MSE loss")
loss_ax.legend()
loss_fig.savefig('../image/loss.jpg')


# save model
torch.save(model.state_dict(), '../model/model.pt')
```