In [None]:
# Mount Google Drive at /content/drive (Colab-only helper module).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Change into the project folder on Google Drive (the one containing ./Data and ./Weights):
# %cd drive/My\ Drive/...

In [None]:
'''
Pranath Reddy
Benchmark notebook for 2x single-image super-resolution
Model: RCAN trained with a combined loss (VGG19 content loss + pixel-wise MSE)
'''

# Import required libraries
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd
from torchvision import models
import torch.utils.model_zoo as model_zoo
import math
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch.utils.data import TensorDataset, DataLoader

# Load training data: HR targets are 1x150x150, LR inputs are 1x75x75 (2x downscale).
x_trainHR = np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1, 1, 150, 150)
x_trainLR = np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1, 1, 75, 75)
# from_numpy wraps the float32 arrays directly instead of making the extra
# copy that torch.Tensor(ndarray) performs.
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
print(x_trainHR.shape)
print(x_trainLR.shape)

# (LR input, HR target) pairs. Shuffle every epoch so mini-batches are not
# drawn in the same fixed order each time, as is standard for SGD-style training.
dataset = TensorDataset(x_trainLR, x_trainHR)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention (squeeze-and-excitation style).

    Global-average-pools each channel of an (N, C, H, W) input, passes the
    pooled (N, C, 1, 1) descriptor through a 1x1-conv bottleneck
    (C -> C // reduction -> C) ending in a sigmoid, and rescales the input
    channels by the resulting weights in (0, 1).

    Args:
        channel: Number of input/output feature channels.
        reduction: Bottleneck reduction ratio. The bottleneck width is clamped
            to at least 1 so that ``channel < reduction`` still builds a valid,
            input-dependent attention layer.
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # Without the clamp, channel < reduction yields a 0-channel bottleneck,
        # so no information flows through it and the attention no longer
        # depends on the input.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Per-channel weights of shape (N, C, 1, 1), broadcast over H and W.
        y = self.ca(self.avg_pool(x))
        return x * y

# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel-attention block.

    Computes conv3x3 -> ReLU -> conv3x3 -> channel attention on the input and
    adds the input back through an identity skip connection. Spatial size and
    channel count are preserved.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        layers = [
            nn.Conv2d(channel, channel, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, channel, 3, 1, 1),
            CALayer(channel, reduction),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.body(x)

# Residual Group (RG)
class RG(nn.Module):
    """Residual group: ``num_rcab`` RCABs followed by a 3x3 conv.

    A skip connection adds the group's input to the output of the whole stack.
    """

    def __init__(self, channel, num_rcab, reduction):
        super().__init__()
        blocks = []
        for _ in range(num_rcab):
            blocks.append(RCAB(channel, reduction))
        blocks.append(nn.Conv2d(channel, channel, 3, padding=1))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x) + x

# Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    """RCAN super-resolution network.

    head: 3x3 conv lifting ``num_channels`` input channels to ``num_features``.
    body: ``num_rg`` residual groups, wrapped in a long skip connection.
    tail: two 3x3 convs followed by PixelShuffle, producing ``num_channels``
        output channels at ``scale_factor`` times the input height and width.
    """

    def __init__(self, num_rg=2, num_rcab=4, num_channels=1, num_features=64, scale_factor=2, reduction=16):
        super().__init__()
        self.sf = scale_factor
        self.num_features = num_features

        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)

        groups = (RG(num_features, num_rcab, reduction) for _ in range(num_rg))
        self.body = nn.Sequential(*groups)

        # The final conv emits scale_factor**2 sub-pixel maps per output
        # channel; PixelShuffle rearranges them into the upscaled image.
        upsample_channels = num_channels * scale_factor ** 2
        self.tail = nn.Sequential(
            nn.Conv2d(num_features, num_features, 3, padding=1),
            nn.Conv2d(num_features, upsample_channels, 3, padding=1),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        shallow = self.head(x)
        deep = self.body(shallow) + shallow
        return self.tail(deep)

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA
# (a hard-coded torch.device("cuda") fails as soon as .to(device) is called).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RCAN().to(device)

# Pretrained VGG19 convolutional stack, frozen and in eval mode, used as a fixed
# feature extractor for the content-loss term.
# NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
# `weights=models.VGG19_Weights.IMAGENET1K_V1`; kept for older-version compatibility.
vgg = models.vgg19(pretrained=True).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

class CombinedLoss(nn.Module):
    """Weighted sum of a feature-space (content) MSE and a pixel-space MSE.

    loss = alpha * MSE(vgg(x), vgg(y)) + (1 - alpha) * MSE(x, y)

    Args:
        vgg: Feature extractor applied to both prediction and target. The
            torchvision VGG19 ``features`` stack expects 3-channel input.
        alpha: Weight of the content term; ``1 - alpha`` weights the pixel MSE.
    """

    def __init__(self, vgg, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.vgg = vgg
        self.mse_loss = nn.MSELoss()
        self.alpha = alpha

    @staticmethod
    def _to_three_channels(t):
        """Replicate a single-channel 4-D batch to 3 channels; pass anything else through."""
        if t.dim() == 4 and t.size(1) == 1:
            return t.repeat(1, 3, 1, 1)
        return t

    def forward(self, x, y):
        """Return the combined loss between prediction ``x`` and target ``y``.

        Single-channel (N, 1, H, W) inputs are replicated to 3 channels before
        the feature extractor, so grayscale batches can be passed directly.
        3-channel inputs are used unchanged, exactly as before. The pixel MSE
        is computed on the inputs as given.
        """
        # NOTE(review): inputs are not ImageNet mean/std normalised before the
        # VGG stack; confirm the data range matches what the pretrained weights expect.
        x_vgg = self.vgg(self._to_three_channels(x))
        y_vgg = self.vgg(self._to_three_channels(y))
        content_loss = self.mse_loss(x_vgg, y_vgg)
        mse_loss = self.mse_loss(x, y)
        return self.alpha * content_loss + (1 - self.alpha) * mse_loss

# Loss: alpha-weighted VGG19 content loss + pixel MSE (alpha defaults to 0.5); Adam at lr=2e-4.
criteria = CombinedLoss(vgg)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

n_epochs = 100

# Sample-weighted mean training loss, one entry per epoch.
loss_array = []
for epoch in tqdm(range(1, n_epochs + 1)):
    # Explicit, in case an earlier cell left the model in eval mode.
    model.train()
    train_loss = 0.0
    for datalr, datahr in dataloader:
        datalr = datalr.to(device)
        datahr = datahr.to(device)

        outputs = model(datalr)
        # The VGG feature stack takes 3-channel input, so the single-channel
        # SR output and HR target are replicated across the channel dimension.
        outputs_rgb = outputs.repeat(1, 3, 1, 1)
        datahr_rgb = datahr.repeat(1, 3, 1, 1)
        loss = criteria(outputs_rgb, datahr_rgb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is exact when the last batch is smaller.
        train_loss += loss.item() * datahr.size(0)

    train_loss /= len(dataloader.dataset)
    loss_array.append(train_loss)

    # Checkpoint after every epoch, overwriting the same file.
    # NOTE(review): this pickles the whole module; model.state_dict() is more
    # portable, but loading code elsewhere may expect the full object.
    torch.save(model, './Weights/RCAN_VGG.pth')