# AIMasters Computer Vision and Video Processing &mdash; Image deblurring task

<img src="images/blur.png">

In the first task of the course you will have to implement an image deblurring method.  
We will start with the simplest possible network and by the end of the task you will implement SOTA deblurring architecture.

## Preliminaries

In [2]:
!pip install utils

Collecting utils
  Downloading utils-1.0.2.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: utils
  Building wheel for utils (setup.py) ... [?25l[?25hdone
  Created wheel for utils: filename=utils-1.0.2-py2.py3-none-any.whl size=13906 sha256=89db1c1c1548861dab18d83953a76e421df6de8a69c091490b64f74cb06688c4
  Stored in directory: /root/.cache/pip/wheels/b8/39/f5/9d0ca31dba85773ececf0a7f5469f18810e1c8a8ed9da28ca7
Successfully built utils
Installing collected packages: utils
Successfully installed utils-1.0.2


In [3]:
from utils import download_from_yadisk, test_model
import os

TARGET_DIR = '.'
FILENAME = "GoPro.zip"

if not os.path.exists(os.path.join(TARGET_DIR, FILENAME)):
    # we are going to download 5.2 gb file, downloading will take some time
    download_from_yadisk(
        short_url='https://disk.yandex.ru/d/v7O5TOsstnsscw',
        filename=FILENAME,
        target_dir=TARGET_DIR
    )

ImportError: cannot import name 'download_from_yadisk' from 'utils' (/usr/local/lib/python3.10/dist-packages/utils/__init__.py)

In [None]:
!unzip -qq GoPro.zip

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
plt.rcParams.update({'axes.titlesize': 'small'})

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.utils import make_grid

from torchvision import transforms
from PIL import Image
import os

print(f"GPU: {torch.cuda.is_available()}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Prepare Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path

class GoProDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        blur_path = Path(f"{root_dir}/blur")
        sharp_path = Path(f"{root_dir}/sharp")
        self.blurry_images = [p.name for p in list(blur_path.glob("*.png"))]

    def __len__(self):
        return len(self.blurry_images)

    def __getitem__(self, idx):
        blurry_path = Path(self.root_dir, 'blur', self.blurry_images[idx])
        sharp_path = Path(self.root_dir, 'sharp', self.blurry_images[idx])

        blurry_image = np.array(Image.open(blurry_path).convert('RGB'))
        sharp_image = np.array(Image.open(sharp_path).convert('RGB'))

        if self.transform:
            blurry_image = blurry_image.astype(np.float32)/255.0
            sharp_image = sharp_image.astype(np.float32)/255.0
            augmented = self.transform(image=blurry_image, image1=sharp_image)
            blurry_image = augmented['image']
            blurry_image = blurry_image.to(torch.float32)
            sharp_image = augmented['image1']
            sharp_image = sharp_image.to(torch.float32)

        return blurry_image, sharp_image

transform = A.Compose([
    A.RandomCrop(256, 256, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    ToTensorV2(),
], additional_targets={
        'image1': 'image'}, p=1)

transform_test = A.Compose([
    ToTensorV2(),
], additional_targets={
        'image1': 'image'}, p=1)

train_dataset = GoProDataset(root_dir='GoPro/train', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=16)

test_dataset = GoProDataset(root_dir='GoPro/test', transform=transform_test)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
samples = torch.stack([train_dataset[i][0] for i in range(32, 48)], dim=0)

plt.figure(figsize=(10, 10))
plt.imshow(make_grid(samples, nrow=4).permute(1, 2, 0))
plt.show()

## Baseline

### Basic network (1 point)
We will start off by implementing the simplest possible convolutional neural network.  
A few convolutional layers, without pooling.

Remenicent of early works in [Super-Resolution](https://arxiv.org/abs/1501.00092) and [Denoising](https://arxiv.org/abs/1608.03981)

**Important to note that our network will learn the residual for debluring**  
$I_{deblur}=I_{input} + f_{\theta}(I_{input})$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        """Simple convolutional block

        Your task is to fill in the following modules:

            conv + bn + relu

        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # Your code vvv
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        # Your code ^^^


class Baseline(nn.Module):
    def __init__(self, block, n_blocks=5, n_filters=64):
        """Basic convolutional model

        Your task is to implement the following architecture:

            input_conv(3, n_filters)
            n_blocks * ConvBnBlock
            output_conv(n_filters, 3)

        """

        super().__init__()
        # Your code vvv
        # Your code ^^^

    def forward(self, x):
        inp = x
        # Your code vvv
        # Your code ^^^
        return x + inp

In [None]:
baseline = Baseline(ConvBnBlock)
test_input = torch.rand(1, 3, 256, 256)
test_output = baseline(test_input)
assert test_input.shape == test_output.shape

### PSNR Loss (0.5 points)

We will directly optimize the PSNR quality metric instead of plain MSE  
$PSNR = 10 * \log_{10}(\frac{MAX_I^2}{MSE})$  
In out case the images are normed to [0, 1], so the $MAX_I=1$

Be careful when averaging, first calculate the MSE for each pair of images, only then apply log and after that average along batch axis.

In [None]:
class PSNRLoss(nn.Module):
    def __init__(self):
        """Peak signal-to-noise ratio loss function

        NOTE: during training we minimize the loss but greater PSNR mean better
        So you may carry the -1 into the log when implementing
        thus elimating the need for division
        """
        super().__init__()
        self.eps = 1e-8  # use eps to prevent 0 in log

    def forward(self, pred, target):
        """
            Implement the following calculation:

                10 * mean(log10(mse(pred, target)))

        """
        # Your code vvv
        # Your code ^^^

In [None]:
criterion = PSNRLoss()
a = torch.tensor([[[[0.1632, 0.0024, 0.9913, 0.8892],
          [0.5655, 0.4472, 0.4592, 0.2013],
          [0.7722, 0.9089, 0.1708, 0.3654],
          [0.6147, 0.9567, 0.7018, 0.2376]]]])
b = torch.tensor([[[[0.8498, 0.1168, 0.3987, 0.6781],
          [0.7864, 0.9762, 0.3694, 0.9926],
          [0.9000, 0.0293, 0.0454, 0.0984],
          [0.9478, 0.3730, 0.9617, 0.5052]]]])
assert torch.isclose(criterion(a, b), torch.tensor(-6.8417))

### Training (0.5 points)
Basic training pipeline.

Note the use of gradient clipping, while not required to train the simplest model it will greately help for later parts of the task!

Hint: use [`torch.nn.utils.clip_grad_norm_`](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html), `max_norm=0.05` worked fine

In [None]:
def train_model(model, train_dataloader, optimizer, criterion, scheduler, num_epochs=200, checkpoints_path="./checkpoints", use_grad_clip=True):
    # Create checkpoints folder
    PATH = checkpoints_path
    os.makedirs(PATH, exist_ok = True)

    # Training loop
    for epoch in range(num_epochs):
        for inputs, targets in train_dataloader:
            # Your code vvv

            if use_grad_clip:
                pass

            # Your code ^^^

        torch.save({
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict()
                  }, os.path.join(PATH, f"epoch_{epoch+1}.tar"))

        # Update the learning rate
        scheduler.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}, LR: {scheduler.get_last_lr()[0]}')

In [None]:
from utils import get_scheduler

torch.manual_seed(11)

model = Baseline(ConvBnBlock)
model = model.to(device)

criterion = PSNRLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
use_grad_clip = True
scheduler = get_scheduler(optimizer)

train_model(model, train_dataloader, optimizer, criterion, scheduler, num_epochs=50)

In [None]:
result = test_model(model, device, test_dataloader)
assert result <= 0.00252
print("Congrats!")

In [None]:
pics = test_dataset[2]
blurred, gt = pics[0], pics[1]
plt.figure(figsize=(20, 4))
plt.suptitle("Visual inspection")
plt.subplot(131)
plt.title("Blurred")
plt.imshow(blurred.permute(1, 2, 0))
plt.subplot(132)
plt.title("Model result")
with torch.no_grad():
    output = model(blurred.unsqueeze(0).to(device)).cpu().squeeze(0).permute(1, 2, 0).numpy()
output = np.clip(output, 0, 1)
plt.imshow(output)
plt.subplot(133)
plt.imshow(gt.permute(1, 2, 0))
plt.title("GT")
plt.show()

## UNet

The next step is to process our image in multiple scales, like the multi-scale methods of old school CV.

The most common model that uses multiple scales is the [UNet](https://arxiv.org/abs/1505.04597), first proposed for medical segmentation and containing many strange design decisions.  
Nonetheless, it has proven to be useful for many CV tasks and is videly used in image and video restoration.

U-Nets are named as such because they have this U-like shape, where the input image is first reduced in dimensionality in the downsizing portion, then increased in dimensionality back to its original size in the upsizing portion.

<img src="images/unet.png">


UNet may be implemented with either concatenation of skip-connections or with the summation.  
The results usually do not differ that much.

### Double convolution (0.5 points)
In UNet each convolution block is created with two convolutions.  
You have to implement a simple stack of two layers of convolution with two ReLU activations

In [None]:
class UNetBlock(nn.Module):
    def __init__(self, in_channels):
        """Basic building block of UNet architecture

        Your task is to fill in the following modules:

            conv + bn + relu + conv + bn + relu

        """
        super().__init__()
        # Your code vvv
        # Your code ^^^

    def forward(self, x):
        # Your code vvv
        # Your code ^^^

In [None]:
block = UNetBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 64, 256, 256])

### UNet Down block (0.5 points)

In [None]:
class UNetDownBlock(nn.Module):
    def __init__(self, chan):
        """Downsampling block in the encoder of UNet

            Your task is to fill in the following modules:

                AvgPool + Conv 1x1

            Input spatial dimension is **reduced** by a factor of 2
            The number of channels is **increased** by a factor of 2
        """
        super().__init__()
        # Your code vvv
        # Your code ^^^

    def forward(self, x):
        # Your code vvv
        # Your code ^^^

In [None]:
block = UNetDownBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 128, 128, 128])

### UNet Up block (0.5 points)

In [None]:
class UNetUpBlock(nn.Module):
    def __init__(self, chan):
        """Upsampling block in the encoder of UNet

            Your task is to fill in the following modules:

                Upsample + Conv2d

            Input spatial dimension is **increased** by a factor of 2
            The number of channels is **reduced** by a factor of 2
        """
        super(UNetUpBlock, self).__init__()
        # Your code vvv
        # Your code ^^^

    def forward(self, x):
        # Your code vvv
        # Your code ^^^

In [None]:
block = UNetUpBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 32, 512, 512])

### Generalized Unet arch (1.5 points)
Over the years the UNet architecture has proven itself useful.  
Many follow up papers retain macro level architecture and change individual blocks, i.e. only changing the downscaling operation, changing number of convs in block, adding attention to bottleneck, etc.  

Therefore we will also implement a generalized architecture.

**Your model must pass the quality assert to get points for this part**

In [None]:
class GeneralizedUNet(nn.Module):
    def __init__(self, block, downBlock, upBlock, img_channel=3, width=16,
                 middle_blk_num=1, enc_blk_nums=[1,1,1,1], dec_blk_nums=[1,1,1,1]):
        """GeneralizedUNet architecture

            This part is implemented for you
            But feel free to change it

        """
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width

        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[block(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                downBlock(chan)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[block(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                upBlock(chan)
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[block(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        """Performs forward pass of a UNet-like model

        Your task is to implement the following steps:

            1. Intro convolution
            2. Encoder blocks + save skip connections + downsample
            3. Middle blocks (also known as bottleneck)
            4. Decoder blocks + add skip connection
            5. Ending
            6. Inp + ending out (we're still using residual lerning)

        """
        activations = []
        # Your code vvv
        # Your code ^^^

In [None]:
model = GeneralizedUNet(UNetBlock, UNetDownBlock, UNetUpBlock)
vec = torch.rand(16, 3, 256, 256)
assert model(vec).shape == torch.Size([16, 3, 256, 256])

### Train UNet

In [None]:
torch.manual_seed(11)

model = GeneralizedUNet(UNetBlock, UNetDownBlock, UNetUpBlock)
model = model.to(device)

criterion = PSNRLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
use_grad_clip = True
scheduler = get_scheduler(optimizer)

train_model(model, train_dataloader, optimizer, criterion, scheduler, num_epochs=100, use_grad_clip=use_grad_clip)

In [None]:
result = test_model(model, device, test_dataloader)
assert result <= 0.00235

In [None]:
pics = test_dataset[2]
blurred, gt = pics[0], pics[1]
plt.figure(figsize=(20, 4))
plt.suptitle("Visual inspection")
plt.subplot(131)
plt.title("Blurred")
plt.imshow(blurred.permute(1, 2, 0))
plt.subplot(132)
plt.title("Model result")
with torch.no_grad():
    output = model(blurred.unsqueeze(0).to(device)).cpu().squeeze(0).permute(1, 2, 0).numpy()
output = np.clip(output, 0, 1)
plt.imshow(output)
plt.subplot(133)
plt.imshow(gt.permute(1, 2, 0))
plt.title("GT")
plt.show()

## NAFNet
As discussed in the previous section, many networks build upon UNet and change individual blocks.  
[NAFNet](https://arxiv.org/abs/2204.04676) was for a long time a SOTA deblurring approach, even thought it was presented as a baseline.

The main differences are:
* No regular nonlinearities (no ReLU/GELU/ELU/etc.)
* Simplified channel attention
* LayerNorm instead of BatchNorm
* and many more training tricks taken from transformer papers

In a way this work is similar to [ConvNeXt](https://arxiv.org/abs/2201.03545), where authors also shook off the dust of ResNet and trained a SOTA classification CNN.


**And now you are tasked with implementing NAFNet :)**

### Simple gate block (0.5 points)

<img src="images/gate.png">

Instead of using regular non-linearities NAFNet proposes the use of "Simple Gates" which perform pointwise multiplication of feature maps.

Simple gate layer splits input feature map in 2 part along the channel axis and multiplies them.

You may find [`torch.chunk`](https://pytorch.org/docs/stable/generated/torch.chunk.html) to be useful

In [None]:
class SimpleGate(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """

            Split the input in 2, along the channel axis
            Return multiplication of these 2 parts

        """
        # Your code vvv
        # Your code ^^^

In [None]:
block = SimpleGate()
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 32, 256, 256])

### NAFNet up block (0.5 points)

In [None]:
class NAFNetUpBlock(nn.Module):
    def __init__(self, channels):
        """NAFNet upsampling block

        Implement and use the following modules:

            conv 1x1 (chan, 2 * chan)
            pixelshuffle(2)

        """
        super().__init__()
        # Your code vvv
        # Your code ^^^
    def forward(self, x):
        # Your code vvv
        # Your code ^^^

In [None]:
block = NAFNetUpBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 32, 512, 512])

### NAFNet down block (0.5 points)

In [None]:
class NAFNetDownBlock(nn.Module):
    def __init__(self, channels):
        """NAFNet downsampling block

        Implement and use the following modules:

            conv with stride 2, **mind the padding**

        """
        super().__init__()
        # Your code vvv
        # Your code ^^^

    def forward(self, x):
        # Your code vvv
        # Your code ^^^

In [None]:
block = NAFNetDownBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 128, 128, 128])

### Simplified channel attention (0.5 points)

Regular channel attention produces weights for each channel of input feature map using 2-layer MLP.  
NAFNet authors propose to remove the small MLP on channels by just pooling and linearly projecting the feature map to get channel weights.

<img src="images/sca.png">

In [None]:
class SCA(nn.Module):
    def __init__(self, in_channels, out_channels):
        """Simplified channel attention module

        Implement and use the following modules:

            adaptiveavgpool to get 1x1 feature map
            conv 1x1 projection layer

        """
        super().__init__()
        # Your code vvv
        # Your code ^^^
    def forward(self, x):
        """

            Return only the attention weights

        """
        # Your code vvv
        # Your code ^^^

In [None]:
block = SCA(64, 32)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 32, 256, 256])

### NAFNet block (3 points)

**The final boss of the task**.

The diagram shows the intra-block structure:

<img src="images/nafnet_block.png" width=620>

Note the use of learnable skip-connection scales `beta` and `gamma`, use [`nn.Parameter`](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html)

**Your model must pass the quality assert to get points for this part**

In [None]:
from utils import LayerNorm2d  # use this layernorm

class NAFNetBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2):
        super().__init__()

        dw_channel = c * DW_Expand
        ffn_channel = FFN_Expand * c

        self.beta = nn.Parameter(...)
        self.gamma = nn.Parameter(...)

        # Your code vvv
        # Your code ^^^

    def forward(self, inp):
        # Your code vvv
        # Your code ^^^

In [None]:
block = NAFNetBlock(64)
vec = torch.rand(16, 64, 256, 256)
assert block(vec).shape == torch.Size([16, 64, 256, 256])

### Train NAFNet

In [None]:
torch.manual_seed(11)

model = GeneralizedUNet(NAFNetBlock, NAFNetDownBlock, NAFNetUpBlock, enc_blk_nums=[1,2,2,28], dec_blk_nums=[2,2,2,1])
model = model.to(device)

criterion = PSNRLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
use_grad_clip = True
scheduler = get_scheduler(optimizer)

train_model(model, train_dataloader, optimizer, criterion, scheduler)

In [None]:
from utils import test_model
result = test_model(model, device, test_dataloader)
assert result <= 0.002

## Check model results

In [None]:
pics = test_dataset[2]
blurred, gt = pics[0], pics[1]
plt.figure(figsize=(20, 4))
plt.suptitle("Visual inspection")
plt.subplot(131)
plt.title("Blurred")
plt.imshow(blurred.permute(1, 2, 0))
plt.subplot(132)
plt.title("Model result")
with torch.no_grad():
    output = model(blurred.unsqueeze(0).to(device)).cpu().squeeze(0).permute(1, 2, 0).numpy()
output = np.clip(output, 0, 1)
plt.imshow(output)
plt.subplot(133)
plt.imshow(gt.permute(1, 2, 0))
plt.title("GT")
plt.show()

# Your thoughts

Please share your thoughts on the task.  
Did you like it, did you hate it?

If so, why?

**your text goes here**