# PSTAT 197A Final Project Report - ESPCN (ISR)

## Introduction 

Our goal is to construct a model that can reconstruct and de-blur images. The ESPCN model is a type of image super-resolution (ISR) model designed for image restoration: it aims to recover a high-resolution (HR) image from its corresponding low-resolution (LR) counterpart. Our dataset consists of around 4,000 images of wild animals from [Kaggle](https://www.kaggle.com/datasets/dimensi0n/afhq-512?select=wild), each with dimensions 3x512x512 (three RGB color channels, since the images are in color rather than grayscale, and 512x512 pixels in height and width). To simulate the process of recovering high-resolution images from low-resolution inputs, we downscaled these high-resolution images to create the low-resolution inputs. Then we set up the ESPCN model with appropriate hyperparameters. Lastly, we trained the model on the prepared dataset: the model takes the low-resolution images as input and is trained to output high-resolution images that match the original 512x512 versions.

## Dataset and Pre-processing

To prepare the images for our model, we load the dataset as (input, target) pairs that the model can read. For each image we keep the original as the target and downsample a copy by the scale factor to obtain the input. The scale factor is a hyperparameter, and we chose 4, so the resulting input image has size 128x128. Note that the convolution and pixel shuffle layers do not restrict the input image shape, so images of any size, not only the square images we use, can be fed into the model. The images are converted into tensors and served to the model through a PyTorch `DataLoader` object.
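The shape bookkeeping behind this pairing can be sketched in a few lines of plain Python. The helper name `paired_shapes` is hypothetical and not part of the project code; it only assumes that the input is obtained by integer division of the target size and that the network enlarges each side by exactly the scale factor.

```python
def paired_shapes(hr_height, hr_width, scale_factor):
    """Return (input_shape, output_shape) for one high-resolution image.

    Hypothetical helper for illustration only: the low-resolution input is
    the target size divided by the scale factor (rounded down), and the
    network output is the input size multiplied back by the scale factor.
    """
    lr_shape = (hr_height // scale_factor, hr_width // scale_factor)
    out_shape = (lr_shape[0] * scale_factor, lr_shape[1] * scale_factor)
    return lr_shape, out_shape


# The 512x512 images used in this report with scale factor 4.
lr_shape, out_shape = paired_shapes(512, 512, 4)
print(lr_shape, out_shape)  # (128, 128) (512, 512)

# A non-square image whose height is not a multiple of 4.
print(paired_shapes(510, 300, 4))  # ((127, 75), (508, 300))
```

The second call shows a practical caveat: the model accepts any input shape, but the output only matches the target pixel for pixel when both target dimensions are divisible by the scale factor.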


## Model

![alt text](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/model_explanation.png)

The model we use is called an efficient sub-pixel convolutional neural network (ESPCN). As the diagram shows, we first feed the model a low-resolution image and apply convolutional layers to obtain feature maps. These feature maps have more channels than the original 3. We then apply the pixel shuffle operation to these channels, which is the most important step of the sub-pixel convolution network. The idea behind sub-pixels is that between any two physical pixels of the low-resolution image there are finer pixels that the image does not capture, and the purpose of the sub-pixel convolution network is to recover them. As the diagram shows, each pixel of the feature maps corresponds to one sub-pixel of the high-resolution image. Our model contains three convolution layers, the last of which brings the channel count to $3 \times \text{scale factor}^2 = 48$. The sub-pixel convolution (pixel shuffle) layer then rearranges these channels into the high-resolution picture. The activation function we used was ReLU, because it introduces non-linearity and is computationally cheap, which lets us train for more epochs in less time. The loss function we used was the mean squared error (MSE), with the original 512 x 512 image as the label and the down-sampled 128 x 128 image as the input. The loss function is shown below:<br> 
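$\ell(W_{1:L}, b_{1:L}) = \frac{1}{r^2HW}\sum\limits_{x=1}^{rH}\sum\limits_{y=1}^{rW}\left( I^{HR}_{x,y} - f^L_{x,y}(I^{LR}) \right)^2$

Here $r$ is the scale factor, $H \times W$ is the size of the low-resolution input $I^{LR}$, $I^{HR}$ is the high-resolution target, and $f^L$ is the output of the $L$-layer network with weights $W_{1:L}$ and biases $b_{1:L}$.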
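To make the pixel shuffle step concrete, here is a minimal dependency-free sketch of the rearrangement. It is an illustration written for this report, not the library implementation: the function `pixel_shuffle` and its nested-list representation are assumptions, and the channel ordering is chosen to follow the layout documented for PyTorch's `nn.PixelShuffle`.

```python
def pixel_shuffle(feature_maps, r):
    """Rearrange a (C*r*r, H, W) nested list into a (C, H*r, W*r) nested list.

    Illustrative sketch only. Each group of r*r input channels supplies the
    r x r block of sub-pixels for one output channel.
    """
    in_channels = len(feature_maps)
    height = len(feature_maps[0])
    width = len(feature_maps[0][0])
    if in_channels % (r * r) != 0:
        raise ValueError("channel count must be divisible by r * r")
    out_channels = in_channels // (r * r)

    out = [[[0.0] * (width * r) for _ in range(height * r)]
           for _ in range(out_channels)]
    for c in range(out_channels):
        for h in range(height):
            for w in range(width):
                for i in range(r):
                    for j in range(r):
                        # Input channel holding sub-pixel (i, j) of output channel c.
                        src = c * r * r + i * r + j
                        out[c][h * r + i][w * r + j] = feature_maps[src][h][w]
    return out


# Four 1x1 feature maps with r = 2 become a single 2x2 image.
maps = [[[1.0]], [[2.0]], [[3.0]], [[4.0]]]
shuffled = pixel_shuffle(maps, 2)
print(shuffled)  # [[[1.0, 2.0], [3.0, 4.0]]]
```

With the settings used in this report (48 channels, scale factor 4), the same rearrangement turns a 48x128x128 feature tensor into a 3x512x512 image.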

## Result

The efficacy of our model is assessed through a comparative analysis of the input low-resolution images and the reconstructed high-resolution outputs, alongside the monitoring of the model's learning progress through a loss graph.

![alt text](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/reference.jpg)

The first image above provides a visual representation of the model's performance, with one column per training epoch. The top row shows the downsized low-resolution image that is fed into the model. The second row shows the output of our model. The enhanced clarity and detail in the reconstructed images are evident when compared to their low-resolution counterparts, showcasing the model's ability to recover fine details and sharpness that are not discernible in the low-resolution inputs. The bottom row shows the corresponding original high-resolution images at 512x512 pixels for comparison. Overall, the model performs quite well at recovering low-resolution images.

![alt text](https://raw.githubusercontent.com/PSTAT197-F23/vignette-super-resolution/main/image/loss.jpg)

The second image depicts the training and test loss curves, which are essential indicators of the model's learning process over time. The blue line represents the training loss and the orange line the test loss. The training loss measures how well the model fits the training data and is expected to decrease as the model learns. It is stored in the variable `loss_history`, which collects the batch loss recorded inside the `train` function every 5 iterations. The test loss measures how well the model generalizes to new, unseen data. It is stored in the variable `test_loss_history`, which is updated on the same 5-iteration schedule, when `tloss` is computed on a batch drawn from `test_dataloader`. A steady decline in both loss values as training progresses is a positive sign, indicating that our model is effectively learning the task of super-resolution and shows no signs of overfitting at the current stage.

These results are promising, suggesting that our model is capable of significantly improving the quality of low-resolution images. The training and test loss trends further affirm that the model is learning as expected and is on the right path to achieving image reconstruction.

However, besides MSE, we still need a more interpretable metric to quantify the distance between the generated images and the original images and to evaluate model performance. In addition, the absence of overfitting suggests that increasing the model size may yield better results, which we have not explored due to hardware constraints.
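One common candidate for such a metric is the peak signal-to-noise ratio (PSNR), which rescales MSE to decibels. The sketch below is an assumption about how it could be computed, not something this project measured: the helper names `mse` and `psnr` are hypothetical, and pixel values are assumed to lie in [0, 1], as produced by `transforms.ToTensor()`.

```python
import math


def mse(pred, target):
    """Mean squared error between two equally sized flat lists of pixel values."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def psnr(mse_value, max_value=1.0):
    """Peak signal-to-noise ratio in decibels for a given MSE.

    max_value is the largest possible pixel value (1.0 is assumed here).
    """
    if mse_value == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse_value)


# Toy example: every predicted pixel is off by 0.1.
error = mse([0.5, 0.5], [0.4, 0.6])
print(round(error, 6))        # 0.01
print(round(psnr(error), 2))  # 20.0
```

Higher PSNR means a closer reconstruction. Because it is a monotonic function of MSE, it ranks models the same way as the training loss but is reported on a scale that is easier to compare across papers and datasets.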

## Code Appendix 

### Run the scripts to reproduce the result. Do not run this notebook.

### dataset

In [None]:
from os import listdir
from os.path import join

from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms


def is_image_file(filename):
    # Accept common image extensions regardless of case.
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))

def load_img(filepath):
    # Force three RGB channels so every tensor has shape 3xHxW.
    return Image.open(filepath).convert("RGB")

class DatasetFromFolder(Dataset):
    def __init__(self, image_dir, scale_factor):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]
        self.tensor = transforms.ToTensor()
        self.scale_factor = scale_factor

    def __getitem__(self, index):
        # The full-resolution image is the target; ToTensor scales it to [0, 1].
        target = self.tensor(load_img(self.image_filenames[index]))

        # Tensors are laid out as (channels, height, width).
        # (get_image_size returns [width, height], so read the shape directly.)
        height, width = target.shape[-2:]
        resize = transforms.Resize(
            (height // self.scale_factor, width // self.scale_factor),
            transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )
        # Bicubic downsampling of the target produces the low-resolution input.
        input = resize(target)

        return input, target
    
    def __len__(self):
        return len(self.image_filenames)

### model

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, scale_factor):
        super().__init__()
        # Two feature-extraction layers that operate at low resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        # Produce 3 * r^2 channels: one per RGB channel and sub-pixel position.
        self.conv3 = nn.Conv2d(64, 3 * scale_factor ** 2, kernel_size=3, padding=1)
        # Rearrange (3 * r^2, H, W) into (3, r * H, r * W).
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Keep predictions inside the valid pixel range [0, 1].
        x = torch.clamp(x, 0.0, 1.0)
        x = self.pixel_shuffle(x)
        return x

### main

In [None]:
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader


# import the dataset and model classes from dataset.py and model.py
from dataset import DatasetFromFolder
from model import Model


# main helper functions
def swap(img):
    # Move channels last: (C, H, W) -> (H, W, C), the layout imshow expects.
    img = img.swapaxes(0, 1)
    img = img.swapaxes(1, 2)
    return img

def train(epoch, model):
    epoch_loss = 0
    epoch_loss_history = []
    epoch_test_loss_history = []
    
    for iteration, batch in enumerate(train_dataloader, 1):
        img, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        loss = criterion(model(img), target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        if iteration % 5 == 0:
            # Evaluate on one test batch without tracking gradients.
            with torch.no_grad():
                tbatch = next(iter(test_dataloader))
                timg, ttarget = tbatch[0].to(device), tbatch[1].to(device)
                tloss = criterion(model(timg), ttarget)
            
            epoch_loss_history.append(loss.item())
            epoch_test_loss_history.append(tloss.item())
            
            print("===> Epoch[{}]({}/{}): Loss: {:.6f}, Test Loss: {:.6f}".format(
                epoch+1, iteration, len(train_dataloader), loss.item(), tloss.item()))

    print("===> Epoch {} Complete: Avg. Loss: {:.6f}".format(epoch+1, epoch_loss / len(train_dataloader)))
    
    return epoch_loss_history, epoch_test_loss_history


# set device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

    
# set hyperparameter
scale_factor    = 4
batch_size      = 32
epoch           = 5
learning_rate   = 0.0003
criterion       = nn.MSELoss()


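# load data
# NOTE: test_data reads from the same folder as train_data, so the "test" loss
# is measured on training images; point it at a held-out folder to measure
# generalization.
train_data = DatasetFromFolder("../data/train/wild", scale_factor=scale_factor)
test_data = DatasetFromFolder("../data/train/wild", scale_factor=scale_factor)
ref_data = DatasetFromFolder("../data/loss", scale_factor=scale_factor)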

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
ref_dataloader = DataLoader(ref_data, batch_size=1, shuffle=False)


# create model
model = Model(scale_factor=scale_factor).to(device)


# train
# squeeze=False keeps ax two-dimensional even when epoch == 1
figure, ax = plt.subplots(3, epoch, squeeze=False)
figure.set_size_inches(20, 15)

loss_history = []
test_loss_history = []

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for i in range(epoch):
    epoch_loss_history, epoch_test_loss_history = train(i, model)
    loss_history = loss_history + epoch_loss_history
    test_loss_history = test_loss_history + epoch_test_loss_history
    del(epoch_loss_history)
    del(epoch_test_loss_history)
    
    # Run the reference image through the model without tracking gradients.
    ref, ref_tgt = next(iter(ref_dataloader))
    with torch.no_grad():
        ref_fit = model(ref.to(device)).cpu()
    
    ref = swap(ref.squeeze())
    ref_tgt = swap(ref_tgt.squeeze())
    ref_fit = swap(ref_fit.numpy().squeeze())

    ax[0, i].imshow(ref)
    ax[1, i].imshow(ref_fit)
    ax[2, i].imshow(ref_tgt)

figure.savefig('../image/reference.jpg')


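# plot loss on a new figure so it does not draw over the reference figure
plt.figure()
plt.yscale("log")
plt.plot(loss_history, label="training loss")
plt.plot(test_loss_history, label="test loss")
plt.xlabel("logged step (every 5 iterations)")
plt.ylabel("MSE loss")
plt.legend()
plt.savefig('../image/loss.jpg')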


# save model
torch.save(model.state_dict(), '../model/model.pt')