# Reproducing ESPCN from the paper "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"

In this blog post, we reproduce Table 1 of the paper https://arxiv.org/abs/1609.05158. Starting from the publicly available code at https://github.com/yjn870/ESPCN-pytorch, we made several improvements and also reproduced the experiments for 4K and video.

## Why do we need ESPCN, and how does it work?

Super-resolution, the task of upscaling a low-resolution image into high-resolution space, has become one of the most popular topics in deep learning research. The technique can be used to restore image quality and for general image processing, and it is widely applied in fields such as face recognition, medical imaging, and satellite imaging.

However, in earlier approaches the super-resolution operation is carried out in high-resolution space, which, according to the authors of the ESPCN paper, is unnecessary and increases the overall computational complexity: the low-resolution image is upscaled before the enhancement step, so every subsequent layer has to process the larger image. In a CNN, this extra complexity can severely reduce the speed of the implementation. Moreover, the traditional interpolation methods used for this upscaling do not bring in any additional information to help solve the ill-posed reconstruction problem.
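
To make the complexity argument concrete, here is a rough back-of-the-envelope sketch. It assumes a stride-1, same-padded convolution and counts multiply-accumulate operations; the function names and the 180x320 example size are our own illustrative choices, not something taken from the paper.

```python
def conv_macs(height, width, in_ch, out_ch, kernel):
    """Multiply-accumulate count for one stride-1, same-padded 2D convolution."""
    return height * width * in_ch * out_ch * kernel * kernel


def relative_cost(lr_height, lr_width, scale, in_ch=64, out_ch=32, kernel=3):
    """Cost of a layer on the pre-upscaled image divided by its cost on the LR image."""
    hr = conv_macs(lr_height * scale, lr_width * scale, in_ch, out_ch, kernel)
    lr = conv_macs(lr_height, lr_width, in_ch, out_ch, kernel)
    return hr / lr


if __name__ == "__main__":
    # Hypothetical 180x320 low-resolution frame, upscaled by 2x, 3x and 4x.
    for scale in (2, 3, 4):
        print(scale, relative_cost(180, 320, scale))
```

Under these assumptions the per-layer cost grows with the square of the upscaling factor, which is why keeping the convolutions in low-resolution space pays off.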

In ESPCN, by contrast, upscaling is performed only in the final layer of the network, which greatly increases the efficiency of the model. ESPCN can also obtain additional gains in certain cases. Furthermore, ESPCN uses no explicit interpolation filter, so the network can learn a better mapping from low-resolution to high-resolution images than a single fixed filter would provide.

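The paper relates this learned upscaling to the deconvolution layer, which is a more generic form of the interpolation filter and can therefore provide more information than a fixed filter. In the network reproduced here, the upscaling step is the final convolution followed by `nn.PixelShuffle`, not a separate deconvolution layer.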

The ESPCN paper also proposes an efficient way to implement the sub-pixel convolution layer. *TBC*
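
As a rough illustration of the rearrangement behind the sub-pixel convolution layer, the sketch below shuffles `scale**2` low-resolution feature maps into a single high-resolution channel using plain Python lists. The function name `pixel_shuffle` and the list-of-rows representation are our own assumptions for illustration; the channel ordering is intended to mirror what `nn.PixelShuffle` does for one output channel.

```python
def pixel_shuffle(channels, scale):
    """Rearrange scale**2 low-resolution feature maps into one high-resolution map.

    `channels` is a list of scale**2 feature maps, each given as a list of rows.
    Output pixel (h*scale + i, w*scale + j) is read from channel i*scale + j
    at position (h, w).
    """
    if len(channels) != scale * scale:
        raise ValueError("expected scale**2 input channels")
    height, width = len(channels[0]), len(channels[0][0])
    out = [[0] * (width * scale) for _ in range(height * scale)]
    for i in range(scale):
        for j in range(scale):
            source = channels[i * scale + j]
            for h in range(height):
                for w in range(width):
                    out[h * scale + i][w * scale + j] = source[h][w]
    return out


if __name__ == "__main__":
    # Four 1x2 feature maps become one 2x4 image for scale 2.
    maps = [[[0, 4]], [[1, 5]], [[2, 6]], [[3, 7]]]
    for row in pixel_shuffle(maps, 2):
        print(row)
```

No arithmetic happens in this step: the final convolution produces all the values in low-resolution space, and the shuffle only moves them to their high-resolution positions.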

To verify that ESPCN actually outperforms earlier super-resolution algorithms, we reproduce the paper with the following experiments.

## Experiment Setup

The two image datasets used for evaluation are publicly available benchmark datasets. The first is the Timofte dataset, which contains 91 training images and two test datasets. The second consists of 50,000 randomly selected images from ImageNet, used for training.

For the video experiments, the authors of the paper use the publicly available Xiph database. *INPUT OUR VIDEO DATASET HERE*

According to the paper, the authors ran their experiments on a K2 GPU, whereas we ran ours on a local computer with *INPUT YOUR COMPUTER GPU HERE*.

## Running the experiment
### Network framework

'''
This is the design of our ESPCN network, including the weight
initialization and the forward method.
'''
import math
from torch import nn


class ESPCN(nn.Module):
    def __init__(self, scale_factor, num_channels=1):
        super(ESPCN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
            nn.Tanh(),
        )
        self.last_part = nn.Sequential(
            nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.in_channels == 32:
                    # Last convolution (feeds PixelShuffle): small fixed std.
                    nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
                    nn.init.zeros_(m.bias.data)
                else:
                    # Other convolutions: std = sqrt(2 / (out_channels * kernel_h * kernel_w)).
                    nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                    nn.init.zeros_(m.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.last_part(x)
        return x
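
As a quick sanity check on the architecture above, the following sketch counts its trainable parameters and computes the output shape by hand. The helper names are our own and purely illustrative; the layer sizes are read directly from the `ESPCN` class.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights plus biases of one Conv2d layer with a square kernel."""
    return in_ch * out_ch * kernel * kernel + out_ch


def espcn_params(scale_factor, num_channels=1):
    """Total trainable parameters of the three-layer network above."""
    return (
        conv_params(num_channels, 64, 5)
        + conv_params(64, 32, 3)
        + conv_params(32, num_channels * scale_factor ** 2, 3)
    )


def output_shape(height, width, scale_factor, num_channels=1):
    """Shape (C, H, W) after PixelShuffle; the padded convolutions keep H and W unchanged."""
    return (num_channels, height * scale_factor, width * scale_factor)


if __name__ == "__main__":
    print(espcn_params(3))          # 22729 parameters for scale 3, one channel
    print(output_shape(85, 85, 3))  # (1, 255, 255)
```

For scale 3 and a single input channel this gives 22,729 parameters, so the network is small enough to train comfortably on a single GPU.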


### Hyperparameter setting

We chose the following hyperparameters:

scale: 3

learning rate: 1e-3

batch size: 16

number of epochs: 200

number of workers: 8
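
The training script later in this post decays the learning rate in a single step after 80% of the epochs. The sketch below reproduces that rule in isolation so the effect of the settings above is easy to see; the function name `learning_rate` is our own.

```python
def learning_rate(epoch, base_lr=1e-3, num_epochs=200):
    """Step schedule: multiply the base rate by 0.1 once 80% of the epochs have passed."""
    step = int(num_epochs * 0.8)
    return base_lr * (0.1 ** (epoch // step))


if __name__ == "__main__":
    # Print the rate on either side of the decay boundary.
    for epoch in (0, 159, 160, 199):
        print(epoch, learning_rate(epoch))
```

With 200 epochs, the rate stays at 1e-3 for epochs 0 to 159 and drops to roughly 1e-4 for epochs 160 to 199.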

### Image super resolution
Training

import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import ESPCN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=200)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = ESPCN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

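    for epoch in range(args.num_epochs):
        # Step decay: use args.lr for the first 80% of the epochs, then args.lr * 0.1.
        # Note that this assigns the same rate to every parameter group, so it
        # overrides the 0.1x rate given to last_part when the optimizer was created.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * (0.1 ** (epoch // int(args.num_epochs * 0.8)))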

        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
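
The script imports `calc_psnr` and `AverageMeter` from `utils`, which is not shown in this post. The sketch below gives plain-Python stand-ins for what such helpers typically compute; `psnr_db` and `RunningAverage` are hypothetical names, and we assume pixel values in the range [0, 1], which is consistent with the `clamp(0.0, 1.0)` applied to the predictions. The real helpers operate on tensors and may differ in detail.

```python
import math


def psnr_db(pred, target, max_value=1.0):
    """PSNR in dB between two equally sized sequences of pixel values."""
    if len(pred) != len(target):
        raise ValueError("inputs must have the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


class RunningAverage:
    """Weighted running average, e.g. of per-batch losses or PSNR values."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # `n` is the number of samples that `value` was averaged over.
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


if __name__ == "__main__":
    meter = RunningAverage()
    meter.update(psnr_db([0.0, 0.5], [0.1, 0.5]), n=1)
    print("average PSNR: {:.2f} dB".format(meter.avg))
```

Weighting each update by the batch size keeps the epoch average correct even when the last batch is smaller than the others.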

Starting from this open-source code, we made some changes to the network structure ourselves, which gave better results than the original implementation.

### Video super resolution