# ESRGAN
Ali Ghanbari - 40110524



- exc_env: Sets the environment manually. If the environment is not local, the script needs to download the datasets.
- hr_shape: The resolution of the target images. 2040x2040 in the original paper. I set it to 800x800 because that is the highest resolution that fits in 16 GB of VRAM.
- network_depth: The number of "Basic Blocks" as described in the paper.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

exc_env = 'colab'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hr_shape = (800, 800) # (2040, 2040)
network_depth = 16
device

device(type='cuda')

----

## Dataset
1. Download:

I uploaded the dataset to my Google Drive because downloading and unzipping all the datasets takes a very long time.

In [None]:
if exc_env != 'local':
    # On Colab, mount Google Drive so the dataset and checkpoints are reachable
    from google.colab import drive
    drive.mount('/content/drive')

Mounted at /content/drive


2. Create dataset access object:

- mean, std: The ImageNet channel statistics, which are the recommended values when working with ImageNet pre-trained models such as VGG-19. [See this link](https://stackoverflow.com/questions/58151507/why-pytorch-officially-use-mean-0-485-0-456-0-406-and-std-0-229-0-224-0-2)
- transforms: Because we intend to do 4x super-resolution, we add two transforms that produce the HR and LR images from the ground-truth image.

The rest is fairly self-explanatory: it is just a dataset object that returns an HR and LR pair for each image file on my Google Drive.

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import glob
import random

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def denormalize(tensors):
    """ Denormalizes image tensors in place using mean and std """
    for c in range(3):
        tensors[:, c].mul_(std[c]).add_(mean[c])
    # Denormalized values are back in [0, 1], the range save_image expects
    return torch.clamp(tensors, 0, 1)

class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # Transforms for low resolution images and high resolution images
        self.lr_transform = transforms.Compose(
            [
                # Resize takes (height, width); the LR image is 4x smaller per side
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        self.files = sorted(glob.glob(root + "/*_HR/*.*"))

    def __getitem__(self, index):
        # Force 3 channels so grayscale or RGBA files match the 3-channel mean/std
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)
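
As a sanity check on the normalization constants, here is a minimal sketch (not part of the original notebook) that normalizes a random batch by hand and then inverts it. The helper names `normalize_batch` and `denormalize_batch` are hypothetical, and unlike `denormalize` above they return new tensors instead of modifying the input in place.

```python
import torch

# ImageNet channel statistics, the same values used by the dataset above
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize_batch(x):
    """(x - mean) / std per channel, for a batch shaped (N, 3, H, W) in [0, 1]."""
    return (x - MEAN) / STD

def denormalize_batch(x):
    """Inverse of normalize_batch, clamped to the [0, 1] range of image tensors."""
    return (x * STD + MEAN).clamp(0, 1)

torch.manual_seed(0)
images = torch.rand(2, 3, 8, 8)  # stand-in for ToTensor() output
restored = denormalize_batch(normalize_batch(images))
print(torch.allclose(restored, images, atol=1e-5))
```

The round trip should recover the input up to floating-point error, which is why the denormalized tensors can be passed to `save_image` directly.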

---

# Architectural Components

### 1. FeatureExtractor

A model built from the layers of the pre-trained VGG-19 network, cut off just before the last ReLU activation.

The name `vgg19_54` refers to the 4th convolution before the 5th max-pooling layer. Its feature maps are taken before the activation, as proposed in the ESRGAN paper.

VGG-19 is pre-trained on ImageNet, which contains more than a million images.

We use the output of this network as the basis for the perceptual loss.

In [None]:
from torchvision.models import vgg19, VGG19_Weights

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(weights=VGG19_Weights.DEFAULT)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, img):
        return self.vgg19_54(img)
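
To see where the `[:35]` slice comes from, the sketch below lists the VGG-19 feature layers in order and locates `conv5_4`. The configuration list is an assumption that mirrors the layout of torchvision's `vgg19().features` (without batch norm), and `layer_names` is a hypothetical helper written only for this illustration.

```python
# Layer layout of VGG-19 without batch norm: numbers are conv output channels,
# "M" is a max-pooling layer. Assumed to mirror torchvision's vgg19().features.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def layer_names(cfg):
    """Expand the config into an ordered list of layer names."""
    names, block, conv = [], 1, 0
    for item in cfg:
        if item == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 0
        else:
            conv += 1
            names.append(f"conv{block}_{conv}")  # each conv is followed by a ReLU
            names.append(f"relu{block}_{conv}")
    return names

names = layer_names(VGG19_CFG)
cut = names.index("conv5_4") + 1  # slice end that keeps conv5_4 and drops its ReLU
print(cut, names[cut - 1], names[cut])
```

Under this layout the slice ends right after `conv5_4`, so the extracted features are pre-activation.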

### 2. DenseResidualBlock

![RRDB](https://drive.google.com/file/d/1w88_O0QS-EVtKWscmCQqxoU8a0JNiNST/view?usp=share_link)

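As seen in the figure above, a dense block in the RRDB consists of 5 convolution layers, where the first 4 are followed by a leaky ReLU activation. Each layer receives the concatenation of the block input and the outputs of all previous layers.

The output of each block is multiplied by β, the residual scaling factor, before it is added back to the input. It is one of the techniques used to keep the training of very deep networks stable.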

In [None]:
class DenseResidualBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)

        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        return out.mul(self.res_scale) + x
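
The channel bookkeeping of the dense connections can be checked with a small stand-alone sketch. It is a simplified version of the block above (activations are omitted, and `filters = 4` is an arbitrary small value chosen for speed), so it only illustrates how the input width of each convolution grows.

```python
import torch
import torch.nn as nn

filters = 4  # small value for a quick check
x = torch.randn(1, filters, 8, 8)

# Five 3x3 convolutions; the k-th one sees k * filters input channels
convs = nn.ModuleList(
    [nn.Conv2d(k * filters, filters, kernel_size=3, padding=1) for k in range(1, 6)]
)

inputs, seen_channels = x, []
with torch.no_grad():
    for conv in convs:
        seen_channels.append(inputs.shape[1])     # channels entering this conv
        out = conv(inputs)
        inputs = torch.cat([inputs, out], dim=1)  # dense connection
    y = out.mul(0.2) + x                          # residual scaling with beta = 0.2

print(seen_channels, tuple(y.shape))
```

The output keeps the shape of the input, which is what allows several of these blocks to be stacked inside an RRDB.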

### 3. ResidualInResidualDenseBlock

![RRDB](https://aliqanbari.s3.ir-thr-at1.arvanstorage.ir/RRDB.jpg)

We chain 3 DenseResidualBlocks, apply residual scaling to their combined output (`mul(self.res_scale)`), and add the result back to the input (`+ x`).

In [None]:
class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
        )

    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x

----

## Model

1. Generator:

![SRResNet](https://aliqanbari.s3.ir-thr-at1.arvanstorage.ir/assets/SRResNet.jpg)

The generator uses the SRResNet structure, but with the `RRDB` as the basic block.

In [None]:
class Generator(nn.Module):
    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(Generator, self).__init__()

        # First layer
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Residual blocks
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Second conv layer post residual blocks
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # Upsampling layers
        upsample_layers = []
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor=2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        # Final output block
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out
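
The 4x upscaling comes from two stages, each of which expands the channels by 4 and then rearranges them into a 2x larger feature map with `nn.PixelShuffle`. The sketch below checks the shapes in isolation. `make_stage` is a hypothetical helper, and `filters = 8` is an arbitrary small value.

```python
import torch
import torch.nn as nn

def make_stage(filters):
    """One upsampling stage: conv to 4x channels, activation, 2x pixel shuffle."""
    return nn.Sequential(
        nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(),
        nn.PixelShuffle(upscale_factor=2),  # (4C, H, W) -> (C, 2H, 2W)
    )

filters = 8
upsampling = nn.Sequential(make_stage(filters), make_stage(filters))

x = torch.randn(1, filters, 10, 10)
with torch.no_grad():
    after_one = upsampling[0](x)  # one stage: 2x per side
    after_two = upsampling(x)     # two stages: 4x per side

print(tuple(after_one.shape), tuple(after_two.shape))
```

The channel count is unchanged after each stage, so the final output block can use the same `filters` width as the rest of the network.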

2. Discriminator:

![Discriminator](https://aliqanbari.s3.ir-thr-at1.arvanstorage.ir/assets/srgan-discrimenator.jpg)

The discriminator is very similar to the one in SRGAN, but instead of outputting a single real/fake score, it outputs an activation map that indicates which regions of the image look more real or more fake.

In [None]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Each stride-2 convolution rounds odd sizes up, so use ceiling division
        patch_h, patch_w = -(-in_height // 2 ** 4), -(-in_width // 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
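
The size of the output map can be derived layer by layer from the standard convolution output formula. The sketch below does this for the four discriminator blocks and compares the result with plain truncating division by 16. `conv_out` and `patch_size` are hypothetical helpers, not part of the notebook.

```python
def conv_out(n, kernel=3, stride=1, padding=1):
    """Output length of a convolution along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def patch_size(n, num_blocks=4):
    """Spatial size after the discriminator blocks and the final conv."""
    for _ in range(num_blocks):
        n = conv_out(n, stride=1)  # first conv of a block keeps the size
        n = conv_out(n, stride=2)  # second conv halves it, rounding up
    return conv_out(n, stride=1)   # final 1-channel conv keeps the size

for size in (800, 2040):
    print(size, patch_size(size), size // 16)
```

For 800 both numbers agree (50), but for 2040 the layers produce 128 while truncating division gives 127. The `valid` and `fake` label tensors have to match the rounded-up size, otherwise the loss raises a shape mismatch at the paper's resolution.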

----

## Training

### 1. Setup:

In [None]:
from torch.utils.data import DataLoader
import torch

epochs = 500
epoch = 4
checkpoint_interval = 1
sample_interval = 50
learning_rate = 0.0002
dataset_name = 'drive/MyDrive/Datasets/Images'

# Initialize generator and discriminator
generator = Generator(3, filters=64, num_res_blocks=network_depth).to(device)
discriminator = Discriminator(input_shape=(3, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

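if epoch != 0:
    # Resume from the checkpoints of a previous run; map_location keeps loading
    # working when the checkpoint was saved on a different device
    generator.load_state_dict(torch.load(f"/content/drive/MyDrive/Models/generator_{hr_shape[0]}_{network_depth}_{epoch}.pth", map_location=device))
    discriminator.load_state_dict(torch.load(f"/content/drive/MyDrive/Models/discriminator_{hr_shape[0]}_{network_depth}_{epoch}.pth", map_location=device))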

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

dataloader = DataLoader(
    ImageDataset(dataset_name, hr_shape=hr_shape),
    shuffle=True,
    batch_size=1,
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

### 2. Train:

During the first epoch, we train only the generator, using the pixel-wise L1 loss. Once the generator has warmed up, we add the discriminator for the adversarial loss and the FeatureExtractor for the perceptual loss.

The total generator loss is then the weighted sum of all three:
```
loss_G = loss_content + 5e-3 * loss_GAN + 1e-2 * loss_pixel
```
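
As a worked example of the weighting, the snippet below plugs in the three loss terms from the first batch of the training log further down and recomputes the total.

```python
# Loss terms copied from the first logged batch of the training output
loss_content, loss_GAN, loss_pixel = 1.278195, 8.160275, 0.221699

weighted_adv = 5e-3 * loss_GAN      # adversarial term after weighting
weighted_pixel = 1e-2 * loss_pixel  # pixel term after weighting
loss_G = loss_content + weighted_adv + weighted_pixel

print(round(weighted_adv, 6), round(weighted_pixel, 6), round(loss_G, 6))
```

Although the raw adversarial loss is the largest of the three, it contributes only about 0.04 after weighting, so the content loss dominates the total.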

In [None]:
from torchvision.utils import save_image

for epoch in range(epoch, epochs):
    for i, imgs in enumerate(dataloader):

        batches_done = epoch * len(dataloader) + i

        # Configure model input (Variable is deprecated; plain tensors track gradients)
        imgs_lr = imgs["lr"].to(device)
        imgs_hr = imgs["hr"].to(device)

        # Adversarial ground truths, one label per discriminator output patch
        valid = torch.ones((imgs_lr.size(0), *discriminator.output_shape), device=device)
        fake = torch.zeros((imgs_lr.size(0), *discriminator.output_shape), device=device)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)

        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)

        warmup_batches = len(dataloader) * 1
        if batches_done < warmup_batches:
            # Warm-up (pixel-wise loss only)
            loss_pixel.backward()
            optimizer_G.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                % (epoch, epochs, i, len(dataloader), loss_pixel.item())
            )
            continue

        # Extract validity predictions from discriminator
        pred_real = discriminator(imgs_hr).detach()
        pred_fake = discriminator(gen_hr)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr).detach()
        loss_content = criterion_content(gen_features, real_features)

        # Total generator loss
        loss_G = loss_content + 5e-3 * loss_GAN + 1e-2 * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())

        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
            % (
                epoch,
                epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_content.item(),
                loss_GAN.item(),
                loss_pixel.item(),
            )
        )

        if batches_done % sample_interval == 0:
            # Save image grid with upsampled inputs and ESRGAN outputs
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))
            save_image(img_grid, f"/content/drive/MyDrive/Training/{hr_shape[0]}_{epoch}_{network_depth}_{i * epoch}.png", nrow=1, normalize=False)

        if batches_done % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), f"/content/drive/MyDrive/Models/generator_{hr_shape[0]}_{network_depth}_{epoch}.pth")
            torch.save(discriminator.state_dict(), f"/content/drive/MyDrive/Models/discriminator_{hr_shape[0]}_{network_depth}_{epoch}.pth")

[Epoch 3/500] [Batch 0/1300] [D loss: 0.000909] [G loss: 1.321213, content: 1.278195, adv: 8.160275, pixel: 0.221699]
[Epoch 3/500] [Batch 1/1300] [D loss: 0.000250] [G loss: 0.897165, content: 0.844847, adv: 10.150295, pixel: 0.156680]
[Epoch 3/500] [Batch 2/1300] [D loss: 0.185972] [G loss: 1.054695, content: 1.043352, adv: 1.919887, pixel: 0.174269]
[Epoch 3/500] [Batch 3/1300] [D loss: 0.014866] [G loss: 0.769045, content: 0.743570, adv: 4.838839, pixel: 0.128063]
[Epoch 3/500] [Batch 4/1300] [D loss: 0.000094] [G loss: 0.936002, content: 0.872000, adv: 12.475327, pixel: 0.162478]
[Epoch 3/500] [Batch 5/1300] [D loss: 0.000051] [G loss: 0.871796, content: 0.812017, adv: 11.789556, pixel: 0.083095]
[Epoch 3/500] [Batch 6/1300] [D loss: 0.000082] [G loss: 1.367742, content: 1.305396, adv: 12.043975, pixel: 0.212588]
[Epoch 3/500] [Batch 7/1300] [D loss: 0.001087] [G loss: 1.585505, content: 1.538128, adv: 9.099475, pixel: 0.187948]
[Epoch 3/500] [Batch 8/1300] [D loss: 0.017901] [G l

KeyboardInterrupt: ignored
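
The relativistic average losses used in the loop can be explored with toy logits. The sketch below is an illustration only: the values `5.0` and `-5.0` are made-up discriminator outputs for a discriminator that separates real from generated images confidently.

```python
import torch

criterion = torch.nn.BCEWithLogitsLoss()

# Toy discriminator logits: real patches score high, generated patches score low
pred_real = torch.full((1, 1, 4, 4), 5.0)
pred_fake = torch.full((1, 1, 4, 4), -5.0)
valid = torch.ones_like(pred_real)
fake = torch.zeros_like(pred_real)

# Discriminator: real should look more realistic than the average fake, and vice versa
loss_real = criterion(pred_real - pred_fake.mean(0, keepdim=True), valid)
loss_fake = criterion(pred_fake - pred_real.mean(0, keepdim=True), fake)
loss_D = (loss_real + loss_fake) / 2

# Generator: fake should look more realistic than the average real
loss_GAN = criterion(pred_fake - pred_real.mean(0, keepdim=True), valid)

print(f"D loss: {loss_D.item():.6f}, adv: {loss_GAN.item():.6f}")
```

A near-zero discriminator loss together with an adversarial loss around 10 is the same pattern that appears in several lines of the log above, which suggests the discriminator was well ahead of the generator at that point in training.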

---

## Testing

In [None]:
import os
import glob

input_test_images = glob.glob("/content/test/*.png")


input_test_images

['/content/test/cat.png',
 '/content/test/birds.png',
 '/content/test/building.png',
 '/content/test/baboon.png',
 '/content/test/baby.png']

In [None]:
from PIL import Image
from torchvision.utils import save_image

os.makedirs("output", exist_ok=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

generator.eval()
for test_img in input_test_images:
    # Force 3 channels: PNG files may be grayscale or carry an alpha channel
    image_tensor = transform(Image.open(test_img).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Naively upsampled input, placed next to the generator output for comparison
        lr_img = nn.functional.interpolate(image_tensor, scale_factor=4)
        sr_image = generator(image_tensor)
        img_grid = denormalize(torch.cat((lr_img, sr_image), -1))
    file_name = os.path.basename(test_img)
    save_image(img_grid, f"output/sr-{file_name}")