In [None]:
# Colab only: mount Google Drive at /content/drive so files stored on it are
# reachable from this notebook.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# cd drive/My\ Drive/<project-folder>   (path relative to /content; the folder holding Data/ (read) and Weights/ (written))

In [None]:
'''
Pranath Reddy
Benchmark Notebook for Super-Resolution
Model: RCAN generator, trained with an adversarial (GAN) loss plus an L1 content loss
'''

# Import required libraries
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.model_zoo as model_zoo
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch import autograd
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
from torchvision import models
from tqdm.notebook import tqdm

# Load training data: single-channel HR targets (150x150) and LR inputs (75x75),
# i.e. a 2x super-resolution pairing.
x_trainHR = np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1, 1, 150, 150)
x_trainLR = np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1, 1, 75, 75)
# from_numpy wraps the float32 arrays in place instead of making a second full
# copy of the dataset (torch.Tensor(ndarray) copies).
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
print(x_trainHR.shape)
print(x_trainLR.shape)

# Each sample is an (LR input, HR target) pair; TensorDataset checks that both
# tensors have the same number of samples.
dataset = TensorDataset(x_trainLR, x_trainHR)
# Reshuffle every epoch so the networks do not see batches in the same fixed
# file order each time.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: rescale every feature map by a learned gate in (0, 1).

    Global average pooling squeezes each channel to a single value. A 1x1 conv
    bottleneck (``channel -> channel // reduction -> channel``) with ReLU in
    between, followed by a sigmoid, turns those values into per-channel
    weights that multiply the input.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio; presumably ``channel // reduction`` should
            be at least 1 — TODO confirm for small channel counts.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, squeezed, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, kernel_size=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # (N, C, 1, 1) gates broadcast over the spatial dimensions of x
        gates = self.ca(self.avg_pool(x))
        return x * gates

# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Conv-ReLU-Conv followed by channel attention, wrapped in an identity skip.

    Both 3x3 convs use stride 1 and padding 1, so channel count and spatial
    size are preserved and the skip can add the input directly.

    Args:
        channel: number of feature maps in and out.
        reduction: bottleneck ratio handed to the ``CALayer``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        first_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
        second_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
        self.body = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            CALayer(channel, reduction),
        )

    def forward(self, x):
        return x + self.body(x)

# Residual Group (RG)
class RG(nn.Module):
    """Residual group: a chain of RCABs closed by a 3x3 conv, plus a skip connection.

    Args:
        channel: number of feature maps, preserved end to end.
        num_rcab: how many RCAB blocks to stack.
        reduction: channel-attention bottleneck ratio passed to each RCAB.
    """

    def __init__(self, channel, num_rcab, reduction):
        super().__init__()
        blocks = [RCAB(channel, reduction) for _ in range(num_rcab)]
        closing_conv = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.body = nn.Sequential(*blocks, closing_conv)

    def forward(self, x):
        return x + self.body(x)

# Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    """RCAN super-resolution generator.

    Pipeline:
      * ``head``: a 3x3 conv lifting ``num_channels`` inputs to ``num_features`` maps.
      * ``body``: ``num_rg`` residual groups, wrapped in a long skip connection.
      * ``tail``: a 3x3 conv, then a conv to ``num_channels * scale_factor**2``
        maps, then ``PixelShuffle``. This rearranges them into
        ``(N, num_channels, H * scale_factor, W * scale_factor)``.

    Args:
        num_rg: number of residual groups.
        num_rcab: RCABs per residual group.
        num_channels: image channels in and out.
        num_features: width of the internal feature maps.
        scale_factor: spatial upsampling factor.
        reduction: channel-attention bottleneck ratio.
    """

    def __init__(self, num_rg=2, num_rcab=4, num_channels=1, num_features=64, scale_factor=2, reduction=16):
        super().__init__()
        self.sf = scale_factor
        self.num_features = num_features

        self.head = nn.Conv2d(num_channels, num_features, 3, padding=1)

        groups = []
        for _ in range(num_rg):
            groups.append(RG(num_features, num_rcab, reduction))
        self.body = nn.Sequential(*groups)

        shuffle_channels = num_channels * scale_factor ** 2
        self.tail = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.Conv2d(num_features, shuffle_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        shallow = self.head(x)
        # Long skip: the residual groups learn a correction on top of the head features
        deep = self.body(shallow) + shallow
        return self.tail(deep)
    
class Discriminator(nn.Module):
    """Real/fake classifier returning one probability per image, shape ``(batch, 1)``.

    Architecture:
      * Four 4x4 stride-2 convs (64-128-256-512 channels), each followed by
        LeakyReLU(0.2), with BatchNorm on all but the first; each one halves
        the spatial side.
      * A 4x4 valid conv down to one channel.
      * Flatten, then a linear layer and a sigmoid.

    Args:
        in_channels: channels of the input images. The default of 1 matches
            the originally hard-coded value.
        input_size: side length of the square input images. The width of the
            final ``nn.Linear`` depends on it. The default of 150 reproduces the
            original ``nn.Linear(36, 1)``.

    Raises:
        ValueError: if ``input_size`` is too small for the conv stack to yield
            at least a 1x1 map (i.e. below 64).
    """

    def __init__(self, in_channels=1, input_size=150):
        super(Discriminator, self).__init__()
        # Conv2d(k=4, s=2, p=1) maps a side of n to (n + 2 - 4) // 2 + 1 == n // 2.
        # The final Conv2d(k=4, s=1, p=0) maps n to n - 3.
        # For 150: 150 -> 75 -> 37 -> 18 -> 9 -> 6, so 6 * 6 = 36 features.
        side = input_size
        for _ in range(4):
            side = (side + 2 * 1 - 4) // 2 + 1
        side = side - 4 + 1
        if side < 1:
            raise ValueError(
                f"Discriminator input_size={input_size} is too small; "
                f"it must be at least 64 so the conv stack yields a >=1x1 map."
            )

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
            nn.Linear(side * side, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

# Pick the best available accelerator. Hard-coding "mps" fails on any machine
# without Apple's Metal backend, including the Colab runtime this notebook
# mounts Google Drive on. CUDA is checked first, so torch.backends.mps is only
# queried when no NVIDIA GPU is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = RCAN().to(device)
discriminator = Discriminator().to(device)

# Define the optimizers (Adam, lr 2e-4, default betas)
optimizer_G = torch.optim.Adam(model.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# Define the loss functions: BCE for the adversarial term (the discriminator
# ends in a sigmoid), L1 for the pixel-wise content term
criterion_GAN = torch.nn.BCELoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)

# BCE targets: 1.0 = real HR image, 0.0 = generated image
real_label = 1.
fake_label = 0.

n_epochs = 30

# Training: for every batch, one discriminator update followed by one
# generator update (alternating GAN updates). Both networks are checkpointed
# at the end of every epoch.

# torch.save does not create missing directories, so make sure ./Weights exists.
os.makedirs('./Weights', exist_ok=True)

# Print running losses/scores every `log_every` batches; 0 disables the report
# (it was commented out in the original notebook). The scalar stats are only
# read back from the device when they are printed, because every .item() call
# forces a host/device synchronisation.
log_every = 0

for epoch in tqdm(range(1, n_epochs+1)):
    for i, data in enumerate(dataloader):

        # Get the LR and HR images
        datalr = data[0].to(device)
        datahr = data[1].to(device)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        discriminator.zero_grad()
        # One BCE target per image; refilled in place for the fake and G passes below
        label = torch.full((datalr.size(0),), real_label, device=device)
        output_real = discriminator(datahr).view(-1)
        errD_real = criterion_GAN(output_real, label)
        errD_real.backward()

        ## Train with all-fake batch
        fake = model(datalr)
        label.fill_(fake_label)
        # detach() keeps the discriminator loss from back-propagating into the generator
        output_fake = discriminator(fake.detach()).view(-1)
        errD_fake = criterion_GAN(output_fake, label)
        # These gradients accumulate on top of the real-batch ones; one step uses both
        errD_fake.backward()
        optimizer_D.step()

        ############################
        # (2) Update G network: maximize log(D(G(z))) and minimize the L1 content loss
        ###########################
        model.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output_gen = discriminator(fake).view(-1)
        errG_GAN = criterion_GAN(output_gen, label)
        errG_content = criterion_content(fake, datahr)
        errG = errG_GAN + errG_content
        # NOTE: this backward also writes gradients into D's parameters; they are
        # cleared by discriminator.zero_grad() before D's next update.
        errG.backward()
        optimizer_G.step()

        # Output training stats (opt-in, see log_every above)
        if log_every and i % log_every == 0:
            errD = errD_real + errD_fake
            D_x = output_real.mean().item()
            D_G_z1 = output_fake.mean().item()
            D_G_z2 = output_gen.mean().item()
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, n_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Save the models (whole pickled modules, overwritten every epoch)
    torch.save(model, './Weights/RCAN_GAN.pth')
    torch.save(discriminator, './Weights/RCAN_GAN_Discriminator.pth')