# ELSR-torch
Implementation of the paper ["ELSR: Extreme Low-Power Super Resolution Network For Mobile Devices"](https://arxiv.org/abs/2208.14600) using PyTorch. The code replicates the method proposed in the paper, but it is meant to be trained on resource-limited hardware. For that purpose the dataset is drastically smaller and the training procedure is much simpler.

### Requirements
 - pytorch=1.13.1
 - opencv=4.7.0
 - pillow=9.4.0
 - matplotlib

## Model
The ELSR model is a small sub-pixel convolutional neural network with 6 layers, only 5 of which have learnable parameters. The architecture is shown in the image below, and the code follows it:

![elsr](./plots/elsr.png "Model architecture")

In [1]:
import math
from torch import nn


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)  # consumes conv1's output
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.prelu(out)
        out = self.conv2(out)
        return x + out


class ELSR(nn.Module):
    def __init__(self, upscale_factor):
        super(ELSR, self).__init__()
        self.layer1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.layer2_4 = ResBlock(6, 6)
        self.layer5 = nn.Conv2d(6, 3 * (upscale_factor ** 2), kernel_size=3, padding=1)     # 6 -> 3*r^2 channels (48 for x4)
        self.layer6 = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is self.layer5:  # last conv for any upscale factor (it has 48 channels only for x4)
                    nn.init.normal_(m.weight.data, mean=0.0, std=0.001)
                    nn.init.zeros_(m.bias.data)
                else:
                    nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                    nn.init.zeros_(m.bias.data)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2_4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        return x
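
To make the size of the network concrete, the parameter count can be worked out by hand from the layer definitions above. This is a minimal sketch in plain Python; `conv_params` and `elsr_params` are hypothetical helpers introduced here, and the sketch assumes every convolution has a bias (the `nn.Conv2d` default) and that `nn.PReLU()` holds a single learnable slope.

```python
def conv_params(in_ch, out_ch, k=3):
    """Weights plus biases of a k x k convolution."""
    return in_ch * out_ch * k * k + out_ch


def elsr_params(upscale_factor):
    """Learnable parameters of the ELSR model as defined above."""
    layer1 = conv_params(3, 6)
    res_block = conv_params(6, 6) + 1 + conv_params(6, 6)  # conv, PReLU slope, conv
    layer5 = conv_params(6, 3 * upscale_factor ** 2)
    return layer1 + res_block + layer5  # PixelShuffle adds no parameters


for r in (2, 4):
    n = elsr_params(r)
    print(f"x{r}: {n} parameters, about {n * 4 / 1024:.1f} KiB as float32")
```

For x4 this gives 3469 parameters, which is only a few kilobytes of raw float32 weights before any checkpoint serialization overhead.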


### PixelShuffle
The PixelShuffle block (also known as depth2space) performs computationally efficient upsampling by rearranging the elements of a tensor, trading channels for spatial resolution. Formally, let **x** be a tensor of size (**batch_size**, **C_in**, **H_in**, **W_in**), where **C_in** is the number of input channels and **H_in** and **W_in** are the height and width of the input. PixelShuffle upsamples the spatial resolution of **x** by a factor of **r**, so the output is a tensor of size (**batch_size**, **C_out**, **H_in** * **r**, **W_in** * **r**), where **C_out** = **C_in** // **r^2**.
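
The rearrangement can be written out with plain array operations. The sketch below uses NumPy rather than PyTorch to show the index mapping; `pixel_shuffle` is a hypothetical helper written for illustration, intended to follow the same channel ordering that `nn.PixelShuffle` documents.

```python
import numpy as np


def pixel_shuffle(x, r):
    """Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r)."""
    n, c_in, h, w = x.shape
    c_out = c_in // (r * r)
    x = x.reshape(n, c_out, r, r, h, w)   # split channels into (C_out, r, r)
    x = x.transpose(0, 1, 4, 2, 5, 3)     # reorder to (N, C_out, H, r, W, r)
    return x.reshape(n, c_out, h * r, w * r)


x = np.arange(1 * 48 * 5 * 7, dtype=np.float32).reshape(1, 48, 5, 7)
y = pixel_shuffle(x, r=4)
print(x.shape, "->", y.shape)
```

Output pixel `(h*r + i, w*r + j)` of channel `c` is taken from input channel `c*r*r + i*r + j` at position `(h, w)`, so 48 input channels become 3 output channels with 4 times the height and width.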

In [5]:
model = ELSR(upscale_factor=4)
print(type(model.layer6))

<class 'torch.nn.modules.pixelshuffle.PixelShuffle'>


## Dataset
ELSR is trained on the REDS dataset, which is composed of sets of 300 videos, each set with a different degradation. My model is trained on a drastically reduced version of the dataset, containing only 30 videos at a lower resolution (the original dataset was too big for me to train on). The dataset (h5 files) is available at the following link: [https://drive.google.com/drive/folders/158bbeXr6EtCiuLI5wSh3SYRWMaWxK0Mq?usp=sharing](https://drive.google.com/drive/folders/158bbeXr6EtCiuLI5wSh3SYRWMaWxK0Mq?usp=sharing). The Dataset classes are defined as follows:

In [6]:
import h5py
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return f['lr'][idx] / 255., f['hr'][idx] / 255.

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

class ValDataset(Dataset):
    def __init__(self, h5_file):
        super(ValDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return f['lr'][idx] / 255., f['hr'][idx] / 255.

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

To generate my dataset I downscaled some videos from the REDS dataset using the following functions:

In [8]:
import os
import cv2
import matplotlib.pyplot as plt

def generate_training_data(data_path, out_path, scale):

    if(not os.path.exists(out_path)):
        os.makedirs(out_path)

    c = 0
    if len(os.listdir(out_path)) > 0: c += len(os.listdir(out_path))
    for folder, j in zip(sorted(os.listdir(data_path)), range(10)):
        for image in sorted(os.listdir(os.path.join(data_path, folder))):
            image_path = os.path.join(data_path, folder, image)
            resized_image = resize_image(image_path, scale)
            plt.imsave(os.path.join(out_path, f'{c}.png'), resized_image)  # works with or without a trailing slash
            c = c+1

def generate_validation_set(data_path_X, out_path_X, data_path_Y, out_path_Y, scale):

    if(not os.path.exists(out_path_X)):
        os.makedirs(out_path_X)

    c = 0
    for folder, j in zip(sorted(os.listdir(data_path_X)), range(10)):
        for image in sorted(os.listdir(os.path.join(data_path_X, folder))):
            image_path = os.path.join(data_path_X, folder, image)
            resized_image = resize_image(image_path, scale)
            plt.imsave(os.path.join(out_path_X, f'{c}.png'), resized_image)  # works with or without a trailing slash
            c = c+1

    if(not os.path.exists(out_path_Y)):
        os.makedirs(out_path_Y)

    c = 0
    for folder, j in zip(sorted(os.listdir(data_path_Y)), range(10)):
        for image in sorted(os.listdir(os.path.join(data_path_Y, folder))):
            image_path = os.path.join(data_path_Y, folder, image)
            resized_image = resize_image(image_path, scale)
            plt.imsave(os.path.join(out_path_Y, f'{c}.png'), resized_image)  # works with or without a trailing slash
            c = c+1

def resize_image(img_path, scale):
    image = cv2.imread(img_path)
    resized_image = cv2.resize(image, dsize=(image.shape[1]//scale, image.shape[0]//scale), interpolation=cv2.INTER_CUBIC)
    resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
    return resized_image

### Data augmentation
To reduce overfitting and improve training results, I augmented the video frames in my dataset by randomly choosing between flipping, rotation, zoom, or no change, as in the code below.

**Notice the augmentation is the same for the (low_res, high_res) pairs.**

In [11]:
import os
import random
import cv2
import matplotlib.pyplot as plt

def augment_data(low_res, high_res):
    # Read images
    low_res = cv2.cvtColor(cv2.imread(low_res), cv2.COLOR_BGR2RGB)
    high_res = cv2.cvtColor(cv2.imread(high_res), cv2.COLOR_BGR2RGB)

    # Randomly choose a type of augmentation
    aug_type = random.choice(["flip", "rotate", "zoom", "none"])

    # Perform the chosen type of augmentation
    if aug_type == "flip":
        low_res = cv2.flip(low_res, 1)
        high_res = cv2.flip(high_res, 1)
    elif aug_type == "rotate":
        angle = random.uniform(-30, 30)
        rowsLR, colsLR = low_res.shape[:2]
        MLR = cv2.getRotationMatrix2D((colsLR/2, rowsLR/2), angle, 1)
        low_res = cv2.warpAffine(low_res, MLR, (colsLR, rowsLR))
        rowsHR, colsHR = high_res.shape[:2]
        MHR = cv2.getRotationMatrix2D((colsHR/2, rowsHR/2), angle, 1)
        high_res = cv2.warpAffine(high_res, MHR, (colsHR, rowsHR))
    elif aug_type == "zoom":
        zoom_scale = random.uniform(0.8, 1.2)
        rowsLR, colsLR = low_res.shape[:2]
        MLR = cv2.getRotationMatrix2D((colsLR/2, rowsLR/2), 0, zoom_scale)
        low_res = cv2.warpAffine(low_res, MLR, (colsLR, rowsLR))
        rowsHR, colsHR = high_res.shape[:2]
        MHR = cv2.getRotationMatrix2D((colsHR/2, rowsHR/2), 0, zoom_scale)
        high_res = cv2.warpAffine(high_res, MHR, (colsHR, rowsHR))
    
    return low_res, high_res
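
Why the pair must share one transform can be checked numerically. The sketch below is an illustration with synthetic NumPy arrays, not the repository's pipeline: `downscale` is a hypothetical block-averaging stand-in for the bicubic downscaling used here, chosen only because it makes the check exact.

```python
import numpy as np


def downscale(img, scale):
    """Average non-overlapping scale x scale blocks of an (H, W, C) image."""
    h, w, c = img.shape
    return img.reshape(h // scale, scale, w // scale, scale, c).mean(axis=(1, 3))


rng = np.random.default_rng(0)
high_res = rng.random((8, 8, 3))
low_res = downscale(high_res, 2)

# Same horizontal flip on both images: the pair stays aligned.
flipped_hr, flipped_lr = high_res[:, ::-1], low_res[:, ::-1]
aligned = np.allclose(downscale(flipped_hr, 2), flipped_lr)

# Flip only the high-res image: the low-res input no longer matches its target.
mismatched = np.allclose(downscale(flipped_hr, 2), low_res)

print(aligned, mismatched)
```

When only one image of the pair is transformed, the network is asked to map an input to a target that shows different content, which is why `augment_data` draws the augmentation type and its parameters once per pair.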

## Training
In the paper, the training of the ELSR model is split into 6 steps that use different loss functions and different frame patch sizes. In this implementation the images in the dataset are much smaller, so full-size images can be used and only 3 steps are needed. Note that the number of epochs is reduced and that the learning rate scheduler of the first training step is reused in the other steps. PSNR is used as the validation metric.
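
For reference, PSNR on images scaled to [0, 1] is 10 * log10(1 / MSE), so a smaller mean squared error gives a higher score. A minimal NumPy sketch of that arithmetic follows; `psnr_np` is a hypothetical name, and the training code itself uses a PyTorch version of the same formula.

```python
import numpy as np


def psnr_np(img1, img2):
    """PSNR in dB for arrays with values in [0, 1]."""
    mse = np.mean((img1 - img2) ** 2)
    return 10.0 * np.log10(1.0 / mse)


reference = np.zeros((4, 4, 3))
print(psnr_np(reference, reference + 0.1))   # every pixel off by 0.1, MSE = 0.01
print(psnr_np(reference, reference + 0.01))  # every pixel off by 0.01, MSE = 0.0001
```

The two cases come out at roughly 20 dB and 40 dB: an error ten times smaller raises PSNR by 20 dB.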

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader

def psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

def train(model, dataloader, loss_fn, optimizer, device, scheduler):
    model.train()
    train_loss = 0
    for i, data in enumerate(dataloader):
        lr, hr = data
        lr, hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()
        sr = model(lr)
        np.save("plot_data/lr.npy", lr[0].cpu().numpy().transpose(1,2,0))
        np.save("plot_data/hr.npy", hr[0].cpu().numpy().transpose(1,2,0))
        np.save("plot_data/sr.npy", sr[0].detach().cpu().numpy().transpose(1,2,0))
        loss = loss_fn(sr, hr)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(dataloader)
    return avg_train_loss

def validate(model, dataloader, device):
    model.eval()
    psnr_sum = 0
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            lr, hr = data
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)
            psnr_sum += psnr(sr, hr).item()  # accumulate a Python float, not a tensor

    avg_psnr = psnr_sum / len(dataloader)
    return avg_psnr


# [...] See training.py for full implementation
import argparse

parser = argparse.ArgumentParser()
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ELSR(upscale_factor=args.scale).to(device)
criterion = nn.MSELoss() if args.loss == 'mse' else nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

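# Learning rate scheduler.
# Note: LambdaLR multiplies the optimizer's initial lr by the value returned by lambda1,
# so the effective rate is args.lr * lambda1(step). Since train() calls scheduler.step()
# once per batch, the `epoch` argument below counts scheduler steps, not epochs.
lambda1 = lambda epoch: args.lr*0.5 if epoch > args.epochs // 5 * 2 else args.lr
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda1)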

train_dataset = TrainDataset(args.train)
train_dataloader = DataLoader(dataset=train_dataset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                pin_memory=True)

val_dataset = ValDataset(args.val)
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False)

best_psnr = 0.0
train_losses = []
psnrs = []

for epoch in range(1, args.epochs+1):
    train_loss = train(model=model, dataloader=train_dataloader, loss_fn=criterion, optimizer=optimizer, device=device, scheduler=scheduler)
    val_psnr = validate(model=model, dataloader=val_dataloader, device=device)

    train_losses.append(train_loss)
    psnrs.append(val_psnr)

    if val_psnr > best_psnr:
        best_psnr = val_psnr
        torch.save(model.state_dict(), os.path.join(args.out,f'best_X{args.scale}_model.pth'))
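
How the scheduler above translates into actual learning rates is easy to misread, because PyTorch's `LambdaLR` multiplies the optimizer's initial learning rate by whatever the lambda returns. The plain-Python sketch below mimics that rule for the step 1 settings (`--lr 0.01`, `--epochs 300`); `effective_lr` is a hypothetical helper, and `step` counts calls to `scheduler.step()`, which `train()` makes once per batch.

```python
def effective_lr(base_lr, epochs, step):
    """Rate LambdaLR would set after `step` scheduler steps (initial lr * factor)."""
    # Same rule as lambda1: the "factor" is itself a learning rate value.
    factor = base_lr * 0.5 if step > epochs // 5 * 2 else base_lr
    return base_lr * factor


for step in (0, 120, 121):
    print(step, effective_lr(base_lr=0.01, epochs=300, step=step))
```

With these settings the rate is about 1e-4 for the first 120 scheduler steps and about 5e-5 afterwards, rather than 0.01 and 0.005.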


### Training step 1
Train the model on the x2 dataset using the L1 loss:
```bash
python training.py \
	--train "datasets/h5/train_X2.h5" \
	--val "datasets/h5/val_X2.h5" \
	--out "checkpoints/" \
	--scale 2 \
	--epochs 300 \
	--loss "mae" \
	--lr 0.01
```

### Training step 2
Fine-tune the pre-trained model from step 1 on the x4 dataset, using the L1 loss and a higher learning rate. In the paper this is done in 2 steps with different patch sizes.
```bash
python training.py \
	--train "datasets/h5/train_X4.h5" \
	--val "datasets/h5/val_X4.h5" \
	--out "checkpoints/" \
	--scale 4 \
	--epochs 50 \
	--loss "mae" \
	--lr 0.05 \
	--weights "best_X2_model.pth"
```

### Training step 3
Fine-tune the pre-trained model from step 2 on the x4 dataset, using the MSE loss and a lower learning rate. In the paper this is done in 3 steps with different patch sizes.
```bash
python training.py \
	--train "datasets/h5/train_X4.h5" \
	--val "datasets/h5/val_X4.h5" \
	--out "checkpoints/" \
	--scale 4 \
	--epochs 250 \
	--loss "mse" \
	--lr 5e-3 \
	--weights "best_X4_model.pth"
```
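
The three commands above imply a command-line interface roughly like the one sketched here. This is an assumption reconstructed from the flags used in the commands, not a copy of `training.py`: the types and defaults are guesses based on the example values, and `--batch_size` with its default of 16 is hypothetical (the training loop reads `args.batch_size`, but no command sets it).

```python
import argparse


def build_parser():
    """Hypothetical parser matching the flags used in the training commands."""
    parser = argparse.ArgumentParser(description="Train ELSR")
    parser.add_argument("--train", required=True, help="path to the training h5 file")
    parser.add_argument("--val", required=True, help="path to the validation h5 file")
    parser.add_argument("--out", required=True, help="directory for checkpoints")
    # All defaults below are assumptions made for this sketch.
    parser.add_argument("--scale", type=int, default=4, help="upscale factor")
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--loss", choices=["mae", "mse"], default="mae")
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--weights", default=None, help="optional checkpoint to fine-tune from")
    parser.add_argument("--batch_size", type=int, default=16)
    return parser


# Parse the step 3 command line from a list instead of sys.argv.
args = build_parser().parse_args([
    "--train", "datasets/h5/train_X4.h5",
    "--val", "datasets/h5/val_X4.h5",
    "--out", "checkpoints/",
    "--scale", "4", "--epochs", "250", "--loss", "mse", "--lr", "5e-3",
    "--weights", "best_X4_model.pth",
])
print(args.scale, args.loss, args.lr, args.weights)
```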

## Results
Because of the reduced dataset I wasn't able to replicate the paper's results, but the results still show that video super-resolution can be done with such a small model. The graphs below show the training loss during each training step:

![](./plots/training_losses.png)

### Tests

Single-frame super-resolution is tested as follows (video SR is achieved by applying SR to every frame):
 1. Resize the input image to (image.height // upscale_factor, image.width // upscale_factor) using bicubic interpolation.
 2. Compute the bicubic_upsampled image by upscaling that low-resolution image by the same factor with bicubic interpolation.
 3. Use the low-resolution image to predict the sr_image.
 4. Calculate PSNR between sr_image and bicubic_upsampled.

The results are shown below:

![](./plots/sanremo_upscaled.png)

The PSNR of the generated image turned out to be lower, but the resulting images are smoother, which makes enlarged images look better:

![](./plots/sonic_upscaled.png)

Blurring stands out in pixelated images:

![](./plots/pika_upscaled.png)

### Real-time video super-resolution
The model was also tested on videos. To achieve "real-time" video SR, the model should be able to predict at least 30 frames per second (FPS). The results are shown below.
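
The real-time criterion reduces to a per-frame time budget: at 30 FPS each frame must be upscaled in at most 1/30 s, about 33 ms. A small sketch of that arithmetic, with hypothetical helper names and made-up timings used only for illustration:

```python
def frames_per_second(total_seconds, n_frames):
    """Average throughput over a clip."""
    return n_frames / total_seconds


def is_real_time(total_seconds, n_frames, target_fps=30.0):
    """True when the average throughput reaches the target frame rate."""
    return frames_per_second(total_seconds, n_frames) >= target_fps


# Illustrative numbers, not measurements: 100 frames in 2.5 s and in 5.0 s.
print(frames_per_second(2.5, 100), is_real_time(2.5, 100))
print(frames_per_second(5.0, 100), is_real_time(5.0, 100))
```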

In [18]:
from test_video import test_video
from preprocessing import psnr
import torch
import torch.backends.cudnn as cudnn

SCALE = 4
WEIGHTS = "./checkpoints/best_X4_model.pth"
INPUT = "test/video/"

cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ELSR(upscale_factor=SCALE).to(device)

state_dict = torch.load(WEIGHTS, map_location=device)  # also loads on CPU-only machines
model.load_state_dict(state_dict=state_dict)
model.eval()

video = []
for frame_path in sorted(os.listdir(INPUT)):  # os.listdir order is arbitrary; read frames in name order
    frame = cv2.cvtColor(cv2.imread(os.path.join(INPUT, frame_path)), cv2.COLOR_BGR2RGB)
    video.append(frame)

sr_video, bicubic_video, video, t = test_video(model, device, video, upscale_factor=SCALE)

avg_psnr = 0.0
for sr_img, image in zip(sr_video, video):
    avg_psnr += psnr(sr_img, image)
avg_psnr /= len(sr_video)

bicubic_psnr = 0.0
for bicubic_image, image in zip(bicubic_video, video): 
    bicubic_psnr += psnr(bicubic_image, image)
bicubic_psnr /= len(bicubic_video)

print(f"PSNR of Bicubic upscaled: {bicubic_psnr} dB")
print(f"PSNR of Super-resoluted video: {avg_psnr} dB")
print(f'FPS: {1/(t/len(sr_video)):.1f}')

PSNR of Bicubic upscaled: 28.802486419677734 dB
PSNR of Super-resoluted video: 28.447572708129883 dB
FPS: 2632.0


| Bicubic GIF: 28.80 dB  | ELSR GIF: 28.45 dB    |
| ------------- | ------------- |
| ![](./out/bicubic_video.gif)  | ![](./out/sr_video.gif)  |

## Conclusions
To me it's incredible that such a small model (17 KB) is able to outperform some bigger models using just a CNN. Xiaomi researchers proved that GAN models, which are far too computationally expensive for mobile devices, can be easily replaced by something like this. I'm not disappointed by my results, because I believe they would have been better with a better dataset. In particular, the "ground truth" images in my training set were replaced by video frames downscaled using bicubic interpolation, so I didn't expect the SR output to get a higher PSNR than the bicubic-upsampled one. Hence, I think the model learnt to reproduce bicubic interpolation with much less computational power, combined with some smoothing that comes from the convolutions.