In [None]:
# Mount Google Drive at /content/drive (Colab only); checkpoints are written under
# /content/drive/MyDrive later in the notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# imports
import torch 
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

import PIL
from PIL import Image

import h5py
import sys, os, math,random,time
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
!pip install icecream
from icecream import ic 

# !pip install ipyplot
# import ipyplot

from IPython.display import clear_output 

In [None]:
# tensorboard writer instantiation (with no log_dir it writes under ./runs/)
writer = SummaryWriter()

# global variables declaration
SEED = 0
# Seed every RNG the notebook draws from: `random` (patch crop offsets), numpy, and
# torch (weight init, DataLoader shuffling; torch.manual_seed also seeds CUDA).
# This does not make cuDNN kernels deterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dataset_dir = "<PATH TO DIV2k Data>"  # TODO: directory holding the DIV2K .h5 files

In [None]:
def get_patch(lr_img, hr_img, patch_size=96, scale=4, ix=-1, iy=-1):
    """Crop spatially aligned low- and high-resolution patches from an image pair.

    Args:
        lr_img: low-resolution image indexed as [row, col, channel].
        hr_img: high-resolution image indexed the same way; it must be large enough
            to hold the `scale`-times-larger crop (checked below).
        patch_size, scale: accepted for API compatibility but currently pinned to
            96 and 4 (see the NOTE below).
        ix, iy: top-left column/row of the LR patch. A negative value picks a
            uniformly random valid offset.

    Returns:
        (lr_patch, hr_patch): a patch_size x patch_size LR crop and the matching
        (patch_size*scale) x (patch_size*scale) HR crop.

    Raises:
        ValueError: if the LR image is smaller than the patch, an explicit offset
            is out of range, or the HR image is too small for the scaled crop.
    """
    # NOTE(review): the two arguments are deliberately overridden. Other code in
    # this notebook assumes 96px LR / 384px HR patches, and positional callers may
    # pass (scale, patch_size) in the wrong order. Only honour the arguments once
    # every caller passes them by keyword.
    patch_size = 96
    scale = 4

    ih, iw = lr_img.shape[0], lr_img.shape[1]
    max_x = iw - patch_size
    max_y = ih - patch_size
    if max_x < 0 or max_y < 0:
        raise ValueError(f'LR image {ih}x{iw} is smaller than a {patch_size}px patch')

    # randrange's stop is exclusive, so +1 keeps the last valid offset reachable.
    if ix < 0:
        ix = random.randrange(0, max_x + 1)
    if iy < 0:
        iy = random.randrange(0, max_y + 1)
    if ix > max_x or iy > max_y:
        raise ValueError(f'patch offset ({ix}, {iy}) exceeds the valid range (0..{max_x}, 0..{max_y})')

    tp = patch_size * scale
    tx, ty = scale * ix, scale * iy
    if hr_img.shape[0] < ty + tp or hr_img.shape[1] < tx + tp:
        raise ValueError(f'HR image {hr_img.shape[0]}x{hr_img.shape[1]} is too small for a x{scale} crop')

    lr_patch = lr_img[iy:iy + patch_size, ix:ix + patch_size, :]
    hr_patch = hr_img[ty:ty + tp, tx:tx + tp, :]
    return lr_patch, hr_patch

In [None]:
class DIV2KDataset(Dataset):
    """DIV2K training pairs read from HDF5 files, returned as aligned random crops.

    Image `idx` is read under the key str(idx) from the 'X4' group of
    DIV2K_CAR_LR_X4_train.h5 and the 'HR' group of DIV2K_HR_train.h5, then cropped
    with get_patch.

    Args:
        dataset_dir: directory containing the two .h5 files.
        dataset_length: number of images; valid indices are 0..dataset_length-1.
        scale: upscaling factor forwarded to get_patch. NOTE(review): the LR file
            and group names are hardcoded to X4 regardless of this value.
        patch_size: LR patch size forwarded to get_patch.
    """
    def __init__(self, dataset_dir, dataset_length, scale=4, patch_size=96):
        # Keep the File objects so close() can release them.
        self._lr_file = h5py.File(f'{dataset_dir}/DIV2K_CAR_LR_X4_train.h5', 'r')
        self._hr_file = h5py.File(f'{dataset_dir}/DIV2K_HR_train.h5', 'r')
        self.lr_h5_file = self._lr_file['X4']
        self.hr_h5_file = self._hr_file['HR']

        self.dataset_dir = dataset_dir
        self.dataset_length = dataset_length
        self.patch_size = patch_size
        self.scale = scale

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        """Return {'lr_img', 'hr_img'} tensors holding one random aligned crop of image idx."""
        if not 0 <= idx < self.dataset_length:
            raise IndexError(f'index {idx} out of range for dataset of length {self.dataset_length}')
        img_name = str(idx)

        lr_img = self.lr_h5_file[img_name][()]
        hr_img = self.hr_h5_file[img_name][()]

        # Keywords, not positions: get_patch's positional order is (patch_size, scale).
        lr_patch, hr_patch = get_patch(lr_img, hr_img, patch_size=self.patch_size, scale=self.scale)

        return {
            'lr_img': torch.from_numpy(lr_patch),
            'hr_img': torch.from_numpy(hr_patch),
        }

    def close(self):
        """Close both HDF5 files; the dataset cannot be indexed afterwards."""
        self._lr_file.close()
        self._hr_file.close()

In [None]:
# Training data: 800 DIV2K image pairs served as random crops, shuffled, batches of 8.
div2k_dataset = DIV2KDataset(
    dataset_dir=dataset_dir,
    dataset_length=800,
    scale=4,
)
dataloader = DataLoader(
    div2k_dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Sanity check: fetch sample 1 and show its LR crop followed by the matching HR crop.
patches = div2k_dataset[1]
display(
    Image.fromarray(patches['lr_img'].numpy(), mode="RGB"),
    Image.fromarray(patches['hr_img'].numpy(), mode="RGB"),
)

In [None]:
class VGGLoss(torch.nn.Module):
    """Perceptual loss: mean L1 distance between frozen VGG19 feature maps.

    Args:
        feature_layer: how many leading modules of `vgg19().features` to keep.

    NOTE(review): torchvision's pretrained VGG weights expect ImageNet-normalized
    inputs; no normalization is applied here, so confirm the range callers pass.
    """
    def __init__(self, feature_layer= 35):
        super().__init__()
        backbone = torchvision.models.vgg19(pretrained=True)
        # Slicing an nn.Sequential keeps its leading child modules.
        truncated = backbone.features[:feature_layer]
        # Fixed feature extractor: nothing in it is trained.
        truncated.requires_grad_(False)
        self.features = truncated.eval()

    def forward(self, source, target) -> torch.Tensor:
        source_feats = self.features(source)
        target_feats = self.features(target)
        return torch.nn.functional.l1_loss(source_feats, target_feats)

In [None]:
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Stride-1 Conv2d padded by kernel_size // 2 ('same' size for odd kernels)."""
    same_padding = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=same_padding,
        bias=bias,
    )

In [None]:
# Normalization Module
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv computing (x + sign * rgb_range * mean) / std per RGB channel.

    Args:
        rgb_range: scale of the input values the mean is expressed in.
        rgb_mean: per-channel mean (3 values).
        rgb_std: per-channel std (3 values).
        sign: -1 subtracts the mean (normalize), +1 adds it back.
    """
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Freeze the real parameters. `self.requires_grad = False` only set an unused
        # module attribute, so these constants were being updated by the optimizer.
        for p in self.parameters():
            p.requires_grad = False

In [None]:
# Residual Block Module
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> act -> conv -> [BN], scaled and added to the input.

    Args:
        conv: factory called as conv(in_ch, out_ch, kernel_size, bias=...).
        n_feat: channel count of input and output.
        kernel_size: convolution kernel size.
        bias: whether the convolutions have a bias term.
        bn: insert BatchNorm2d after each convolution.
        act: activation module placed after the first convolution. The default
            instance is shared by every block that uses it (fine for stateless ReLU).
        res_scale: multiplier on the residual branch before the skip add; the
            default 1 keeps the previous behaviour.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        if self.res_scale != 1:
            res = res.mul(self.res_scale)
        return res + x

In [None]:
# Self Attention Module
class SelfAttentionBlock(nn.Module):
    """Sigmoid-gated attention along the H axis of each channel.

    For a 4-D input X of shape (B, C, H, W), with Q, K, V the outputs of three 1x1
    convs, returns sigmoid(Q @ K.transpose(2, 3)) @ V. The attention map is
    (B, C, H, H) and the output is (B, C, H, W).
    """
    def __init__(self, n_feats):
        super(SelfAttentionBlock, self).__init__()
        # Kept only for backward compatibility. forward() no longer caches its input
        # here: that kept the last input tensors (and their autograd history) alive.
        self.X_copy_1 = None
        self.X_copy_2 = None
        self.conv_layer1 = default_conv(n_feats, n_feats, 1)
        self.conv_layer2 = default_conv(n_feats, n_feats, 1)
        self.conv_layer3 = default_conv(n_feats, n_feats, 1)
        self.sig = nn.Sigmoid()

    def forward(self, X):
        # Convolutions do not modify their input, so no defensive clones are needed.
        query = self.conv_layer1(X)
        key = self.conv_layer2(X)
        attention_map = self.sig(torch.matmul(query, key.transpose(2, 3)))
        return torch.matmul(attention_map, self.conv_layer3(X))

In [None]:
# Channel Attention
class CALayer(nn.Module):
    """Channel attention whose per-channel weights come from a Transformer encoder layer.

    Global average pooling gives one C-dim descriptor per sample. The descriptors go
    through `conv_du` as a (1, B, C) tensor, and the result multiplies `x` channel-wise.
    No sigmoid is applied, so the weights are unbounded.
    """
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        self.channel = channel
        self.reduction = reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Despite the historical name, conv_du is a TransformerEncoderLayer. `reduction`
        # is used both as the head count (so `channel` must be divisible by it) and to
        # size the feed-forward width.
        self.conv_du = nn.TransformerEncoderLayer(
            d_model=channel,
            nhead=reduction,
            dim_feedforward=channel // reduction,
            dropout=0,
            activation='gelu',
        )

    def forward(self, x):
        descriptors = self.avg_pool(x).flatten(1).unsqueeze(0)   # (1, B, C)
        channel_weights = self.conv_du(descriptors)[0]            # (B, C)
        return x * channel_weights[:, :, None, None]

In [None]:
# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """conv -> act -> conv -> CALayer, added back onto the block input."""
    def __init__(self, conv, n_feat, kernel_size, reduction, bias=True, act=nn.ReLU(True)):
        super(RCAB, self).__init__()
        self.body = nn.Sequential(
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
            CALayer(n_feat, reduction),
        )

    def forward(self, x):
        return self.body(x) + x

In [None]:
# Residual Group (RG)
class ResidualGroup(nn.Module):
    """`n_resblocks` RCABs followed by a conv, wrapped in a long skip connection.

    Args:
        conv: convolution factory.
        n_feat: channel count.
        kernel_size: convolution kernel size.
        reduction: forwarded to each RCAB's CALayer.
        act: activation module used inside every RCAB.
        n_resblocks: number of RCABs.
    """
    def __init__(self, conv, n_feat, kernel_size, reduction, act, n_resblocks):
        super(ResidualGroup, self).__init__()
        # Use the caller's activation; it used to be ignored in favour of a fresh ReLU.
        modules_body = [RCAB(conv, n_feat, kernel_size, reduction, bias=True, act=act) for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

In [None]:
# Upsampling Module
class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler for scales that are powers of two or 3.

    Args:
        conv: factory called as conv(in_ch, out_ch, kernel_size, bias).
        scale: upscaling factor (1, 2, 4, 8, ... or 3).
        n_feat: channel count of input and output.
        bn: insert BatchNorm2d after each shuffle stage.
        act: activation class called with no arguments after each stage, or False.
        bias: whether the convolutions have a bias term.

    Raises:
        NotImplementedError: for any other scale. Previously this only printed a
            message and produced an empty module that did not upsample.
    """
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        m = []
        if scale >= 1 and (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            # math.log2 is exact for powers of two.
            for _ in range(int(math.log2(scale))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError(f"Scale {scale} not implemented")
        super(Upsampler, self).__init__(*m)

In [None]:
# GAN Generator
class SRGenerator(nn.Module):
    """Two-branch x4 super-resolution generator.

    The mean-shifted input runs through a 3x3-kernel branch and a 5x5-kernel branch
    (feature conv -> residual groups -> conv, plus a long skip). The two 128-channel
    maps are concatenated and fused by a 1x1 conv followed by CALayer. The result is
    upsampled x4 with PixelShuffle, projected back to 3 channels, and the mean is
    added back.

    Args:
        conv: convolution factory used for the residual groups and the tail.
    """
    def __init__(self, conv=default_conv):
        super(SRGenerator, self).__init__()

        n_resgroups = 2
        n_resblocks = 4
        n_feats = 128
        kernel_size = 3
        reduction = 16
        scale = 4
        act = nn.ReLU(True)

        self.small_kernel_size = 3
        self.large_kernel_size = 5
        # Number of parallel branches concatenated before the tail (channel multiplier).
        self.kernel_sizes = 2

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        # NOTE(review): rgb_range=1 assumes inputs in [0, 1]. The training cell feeds
        # .float() of the stored patches, which are displayed as 8-bit images, so
        # confirm the intended range.
        self.sub_mean = MeanShift(1, rgb_mean, rgb_std)

        # The sa_* blocks are not called in forward() (their calls were commented out),
        # but they are kept so parameter names match existing checkpoints.
        self.feature_small = default_conv(3, n_feats, self.small_kernel_size)
        self.sa_small_1 = SelfAttentionBlock(n_feats=n_feats)
        self.sa_small_2 = SelfAttentionBlock(n_feats=n_feats)

        self.feature_large = default_conv(3, n_feats, self.large_kernel_size)
        self.sa_large_1 = SelfAttentionBlock(n_feats=n_feats)
        self.sa_large_2 = SelfAttentionBlock(n_feats=n_feats)

        # define body module: residual groups followed by a conv, one stack per branch
        modules_body_1 = [ResidualGroup(conv, n_feats, kernel_size, reduction, act=act, n_resblocks=n_resblocks) for _ in range(n_resgroups)]
        modules_body_2 = [ResidualGroup(conv, n_feats, kernel_size, reduction, act=act, n_resblocks=n_resblocks) for _ in range(n_resgroups)]

        modules_body_1.append(conv(n_feats, n_feats, kernel_size))
        modules_body_2.append(conv(n_feats, n_feats, kernel_size))

        # Fuse the concatenated branches: 1x1 conv, then channel attention.
        conv_1x1 = default_conv(n_feats * self.kernel_sizes, n_feats * self.kernel_sizes, kernel_size=1)
        ca = CALayer(n_feats * self.kernel_sizes)
        self.rec_input = nn.Sequential(conv_1x1, ca)

        # define tail module: PixelShuffle upsampling, then projection to RGB
        modules_tail = [Upsampler(conv, scale, n_feats * self.kernel_sizes, act=False), conv(n_feats * self.kernel_sizes, 3, kernel_size)]

        self.add_mean = MeanShift(1, rgb_mean, rgb_std, 1)

        self.res_body_1 = nn.Sequential(*modules_body_1)
        self.res_body_2 = nn.Sequential(*modules_body_2)
        self.tail = nn.Sequential(*modules_tail)

    # Branch with the 3x3 feature extractor.
    def small_proc(self, X):
        features_ext = self.feature_small(X)
        resblocks_out = self.res_body_1(features_ext)
        resblocks_out += features_ext  # long skip connection
        return resblocks_out

    # Branch with the 5x5 feature extractor.
    def large_proc(self, X):
        features_ext = self.feature_large(X)
        resblocks_out = self.res_body_2(features_ext)
        resblocks_out += features_ext  # long skip connection
        return resblocks_out

    def forward(self, x):
        x = self.sub_mean(x)

        # Both branches see the same normalized input; neither modifies it in place.
        small_Y = self.small_proc(x)
        large_Y = self.large_proc(x)

        concat_output = torch.cat((small_Y, large_Y), 1)
        rec_input = self.rec_input(concat_output)

        x = self.tail(rec_input)
        return self.add_mean(x)

In [None]:
# PatchGAN discriminator
# Snippet borrowed from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/f13aab8148bd5f15b9eb47b690496df8dadbab0c/models/networks.py#L538
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator producing a 1-channel map of real/fake logits.

    Layout: a stride-2 conv + LeakyReLU, then (n_layers - 1) stride-2
    conv/norm/LeakyReLU stages, then one stride-1 conv/norm/LeakyReLU stage, then a
    stride-1 conv down to 1 channel. Channel width is ndf * min(2**depth, 8).
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        kernel, pad = 4, 1

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # (output width, stride) for each conv/norm/activation stage.
        stages = [(ndf * min(2 ** depth, 8), 2) for depth in range(1, n_layers)]
        stages.append((ndf * min(2 ** n_layers, 8), 1))

        width = ndf
        for out_width, stride in stages:
            layers += [
                nn.Conv2d(width, out_width, kernel_size=kernel, stride=stride, padding=pad, bias=True),
                norm_layer(out_width),
                nn.LeakyReLU(0.2, True),
            ]
            width = out_width

        # Final 1-channel prediction map.
        layers.append(nn.Conv2d(width, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Build the generator first, then the discriminator, and move both to `device`.
gen_model, disc_model = SRGenerator().to(device), NLayerDiscriminator().to(device)

In [None]:
def count_parameters(model):
    """Total number of scalar values in `model`'s parameters that have requires_grad set."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))

# Trainable-parameter counts, printed with thousands separators.
print('gen_model', format(count_parameters(gen_model), ','))
print('disc_model', format(count_parameters(disc_model), ','))

In [None]:
# Generator losses: perceptual (VGG19 feature L1), pixel-wise L1, and the
# adversarial BCE-with-logits used for the relativistic GAN terms.
vgg_loss_criterion = VGGLoss().to(device)
l1_loss_criterion = nn.L1Loss(reduction='mean')
adv_criterion = nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
# Generator optimizer. StepLR multiplies the LR by 0.9 every 125 scheduler steps;
# the training loop calls step() once per batch.
gen_optimizer = torch.optim.Adam(params=gen_model.parameters(), lr=1.5e-4)
gen_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=gen_optimizer, step_size=125, gamma=0.9)

In [None]:
# Discriminator optimizer (alternative LR tried: 0.00002), same StepLR schedule as the generator.
disc_optimizer = torch.optim.Adam(params=disc_model.parameters(), lr=1e-4)
disc_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=disc_optimizer, step_size=125, gamma=0.9)

In [None]:
# Checkpoint paths. Both files share one timestamp so a generator/discriminator pair
# can be matched later. ':' is kept out of the names, matching the examples below.
_run_stamp = datetime.now().strftime('%Y-%m-%d %H_%M_%S.%f')
GEN_MODEL_PATH = f'/content/drive/MyDrive/trained_models/test_model_GEN_{_run_stamp}.pt'
DISC_MODEL_PATH = f'/content/drive/MyDrive/trained_models/test_model_DISC_{_run_stamp}.pt'
# To resume, point both paths at an existing pair and set load_model=True, e.g.:
# GEN_MODEL_PATH=f'/content/drive/MyDrive/trained_models/test_model_GEN_2021-04-24 19_26_37.116762.pt'
# DISC_MODEL_PATH=f'/content/drive/MyDrive/trained_models/test_model_DISC_2021-04-24 19_26_37.116821.pt'
load_model = False
save_model = True  # NOTE: the training cell runs only when this is True.
if load_model:
    gen_model.load_state_dict(torch.load(GEN_MODEL_PATH, map_location=device))
    disc_model.load_state_dict(torch.load(DISC_MODEL_PATH, map_location=device))
if save_model:
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(GEN_MODEL_PATH), exist_ok=True)
print(GEN_MODEL_PATH)
print(DISC_MODEL_PATH)

In [None]:
epochs = 200  # full passes over `dataloader`

In [None]:
# GAN training loop. NOTE: gated by `save_model`, which is set with the checkpoint paths.
if save_model:
    steps_per_epoch = len(dataloader)
    for epoch in tqdm(range(epochs)):
        for k, data in enumerate(dataloader):
            # One TensorBoard step per batch (previously hardcoded as 100*epoch + k).
            global_step = steps_per_epoch * epoch + k

            # Batching: the dataset yields HWC patches; the networks take NCHW floats.
            lr_batch = data['lr_img'].permute(0, 3, 1, 2).to(device).float()
            hr_batch = data['hr_img'].permute(0, 3, 1, 2).to(device).float()

            # ---------------- Discriminator update ----------------
            disc_model.zero_grad()

            real_output = disc_model(hr_batch)
            sr = gen_model(lr_batch)
            fake_output = disc_model(sr.detach())  # no generator gradients in the D step

            # Real label is 1, generated label is 0. The labels are shaped like the
            # PatchGAN output (46x46 for 384px HR patches), so any patch size works.
            real_label = torch.ones_like(real_output)
            fake_label = torch.zeros_like(real_output)

            # Adversarial loss for real and fake images (relativistic average GAN).
            d_loss_real = adv_criterion(real_output - torch.mean(fake_output), real_label)
            d_loss_fake = adv_criterion(fake_output - torch.mean(real_output), fake_label)
            disc_loss = (d_loss_real + d_loss_fake) / 2

            disc_loss.backward()
            disc_optimizer.step()

            print("[EPOCH]: %i, [BATCH]: %i [DISC_LOSS]: %.6f" % (epoch,k, disc_loss.item()))

            # ---------------- Generator update ----------------
            gen_optimizer.zero_grad()

            content_loss = vgg_loss_criterion(sr, hr_batch)  # VGG19 feature L1
            pixel_loss = l1_loss_criterion(sr, hr_batch)      # pixel-wise L1

            # The real-image logits are only a reference for the relativistic term.
            # Nothing needs to backpropagate through them.
            with torch.no_grad():
                real_output = disc_model(hr_batch)
            fake_output = disc_model(sr)

            # Adversarial loss (relativistic average GAN, generator side).
            adversarial_loss = adv_criterion(fake_output - torch.mean(real_output), real_label)

            # Next experiment to try: content_loss + 0.01*adversarial_loss + 0.1*pixel_loss
            g_loss = content_loss + 0.005 * adversarial_loss + 0.01 * pixel_loss

            print("[EPOCH]: %i, [BATCH]: %i [GEN_LOSS]: %.6f" % (epoch,k, g_loss.item()))
            print("content_loss:",content_loss.item(),"pixel_loss",pixel_loss.item(),"adversarial_loss",adversarial_loss.item())

            g_loss.backward()
            gen_optimizer.step()

            # write values to tensorboard
            writer.add_scalar('Loss/g_loss', g_loss.item(), global_step)
            writer.add_scalar('Loss/d_loss', disc_loss.item(), global_step)
            writer.add_scalar('Loss/content_loss', content_loss.item(), global_step)
            writer.add_scalar('Loss/pixel_loss', pixel_loss.item(), global_step)
            writer.add_scalar('Loss/adversarial_loss', adversarial_loss.item(), global_step)

            if k % 10 == 0:
                # Clip and convert to uint8 for display. The old astype(np.int8) relied
                # on byte reinterpretation and wrapped out-of-range generator outputs.
                hr_show = np.clip(data['hr_img'][0].numpy(), 0, 255).astype(np.uint8)
                display(PIL.Image.fromarray(hr_show, mode="RGB"))

                show_sr = sr[0].detach().permute(1, 2, 0).cpu().numpy()
                show_sr = np.clip(show_sr, 0, 255).round().astype(np.uint8)
                print(show_sr.shape)
                display(PIL.Image.fromarray(show_sr, mode="RGB"))

                torch.save(gen_model.state_dict(), GEN_MODEL_PATH)
                torch.save(disc_model.state_dict(), DISC_MODEL_PATH)
                print("saved!")

            # Schedulers step once per batch, so StepLR's step_size counts batches.
            gen_scheduler.step()
            disc_scheduler.step()
            print("")
        if epoch != epochs - 1:
            clear_output()

In [None]:
# Evaluation dataloader: one randomly chosen image pair per batch.
# NOTE(review): DIV2KDataset reads the *_train.h5 files, so this evaluates on the
# same images used for training. Switch to a held-out split if one is available.
test_dataset = DIV2KDataset(
    dataset_dir=dataset_dir,
    dataset_length=800,
    scale=4,
)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1, pin_memory=True)

In [None]:
def test_model():
    """Super-resolve one random sample from `test_loader` with `gen_model` and display it.

    Displays the LR input, the SR output (clipped to [0, 255]) and the HR target.

    Returns:
        (lr, hr, sr): the LR batch as a float NCHW tensor on `device`, the HR batch
        as yielded by the loader (NHWC, on CPU), and the generator's NCHW output.
    """
    batch = next(iter(test_loader))
    lr_hwc = batch['lr_img']
    display(PIL.Image.fromarray(lr_hwc[0].numpy(), mode="RGB"))

    lr_nchw = lr_hwc.permute(0, 3, 1, 2).to(device).float()
    hr_hwc = batch['hr_img']

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        sr_imgs = gen_model(lr_nchw)

    # Clip before converting; astype(np.int8) wrapped out-of-range SR values.
    show_sr_img = sr_imgs[0].permute(1, 2, 0).cpu().numpy()
    show_sr_img = np.clip(show_sr_img, 0, 255).round().astype(np.uint8)
    print(show_sr_img.shape)

    display(PIL.Image.fromarray(show_sr_img, mode="RGB"))
    display(PIL.Image.fromarray(hr_hwc[0].numpy(), mode="RGB"))

    return (lr_nchw, hr_hwc, sr_imgs)

In [None]:
# PSNR (dB) of 10 random SR outputs against their HR targets, on full uint8 images.
# Previously both arrays had an extra [0], which scored only the first pixel row,
# and an int8 cast, which wrapped values above 127.
psnrs = []
for _ in range(10):
    lr, hr, sr = test_model()
    sr_img = np.clip(sr[0].permute(1, 2, 0).detach().cpu().numpy(), 0, 255).round().astype(np.uint8)
    hr_img = np.clip(hr[0].detach().cpu().numpy(), 0, 255).astype(np.uint8)
    psnrs.append(cv2.PSNR(sr_img, hr_img))

In [None]:
# Per-image PSNR values (dB, from cv2.PSNR) collected by the evaluation loop.
psnrs

In [None]:
# writer.close()
%load_ext tensorboard
# %reload_ext tensorboard
%tensorboard --logdir "/content/runs"