In [None]:
# Mount Google Drive inside the Colab runtime so files under /content/drive are reachable
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# cd drive/My \Drive/.....

In [None]:
'''
Pranath Reddy
Benchmark Notebook for Superresolution
Model: SRResNet
'''

# Import required libraries
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd
from torchvision import models
import torch.utils.model_zoo as model_zoo
import math
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch.utils.data import TensorDataset, DataLoader

# Load the training pairs from disk as float32 tensors in (N, 1, H, W) layout
# High-resolution lensing images: 150x150, single channel
x_trainHR = torch.Tensor(
    np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1, 1, 150, 150)
)
# Low-resolution lensing images: 75x75, single channel
x_trainLR = torch.Tensor(
    np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1, 1, 75, 75)
)
# Report the tensor dimensions (HR first, then LR)
print(x_trainHR.shape)
print(x_trainLR.shape)

# Pair every LR image with its HR target and serve the pairs in mini-batches of 8
dataset = TensorDataset(x_trainLR, x_trainHR)
dataloader = DataLoader(dataset, batch_size=8)

class Conv(nn.Module):
    """2-D convolution followed by a channel-wise PReLU activation.

    Args:
        in_c: Number of input channels of the convolution.
        out_c: Number of output channels; also the number of learnable
            PReLU slopes (one per output channel).
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_c, out_c, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_c, out_c, **kwargs)
        # The attribute name ``actication`` (sic) is kept as-is: it is part of
        # the module's parameter names and therefore of its state_dict keys.
        self.actication = nn.PReLU(num_parameters=out_c)

    def forward(self, x):
        """Apply the convolution, then the PReLU, and return the result."""
        return self.actication(self.cnn(x))

class Upsample(nn.Module):
    """Sub-pixel upsampling stage: convolution -> PixelShuffle -> PReLU.

    The 3x3 convolution expands the channel count by ``scale_factor ** 2``;
    ``nn.PixelShuffle`` then rearranges those extra channels into space, so the
    output has ``in_c`` channels again and a spatial size multiplied by
    ``scale_factor``.

    Args:
        in_c: Number of channels entering (and leaving) the stage.
        scale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_c = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_c, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        # After the shuffle the tensor is back to in_c channels, hence in_c slopes.
        self.activation = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        """Return the upsampled, activated feature map for ``x``."""
        shuffled = self.pixel_shuffle(self.conv(x))
        return self.activation(shuffled)

class RBlock(nn.Module):
    """Residual block: two 3x3 ``Conv`` units plus an identity skip connection.

    Both units keep the channel count and (via stride 1, padding 1) the spatial
    size, so the input can be added directly to the block output.

    Args:
        in_c: Number of channels of the input and output feature maps.
    """

    def __init__(self, in_c):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.b1 = Conv(in_c, in_c, **same_size)
        self.b2 = Conv(in_c, in_c, **same_size)

    def forward(self, x):
        """Return ``b2(b1(x)) + x``."""
        return self.b2(self.b1(x)) + x

class Generator(nn.Module):
    """SRResNet-style generator that upscales its input by a factor of 2.

    Pipeline: 9x9 ``Conv`` -> ``no_blocks`` residual blocks -> 3x3 ``Conv`` with a
    long skip connection from the first conv -> ``Upsample`` (x2) -> 3x3
    ``nn.Conv2d`` projecting back to ``in_c`` channels.

    Args:
        in_c: Number of channels of the input image; the output image has the
            same number of channels.
        out_c: Number of feature channels used throughout the network body.
        no_blocks: Number of residual blocks.
    """

    def __init__(self, in_c=1, out_c=64, no_blocks=18):
        super().__init__()
        self.first_conv = Conv(in_c, out_c, kernel_size=9, stride=1, padding=4)
        # The residual blocks must operate on out_c channels. They were previously
        # built with a hard-coded 64, which broke any non-default out_c.
        self.res_blocks = nn.Sequential(*(RBlock(out_c) for _ in range(no_blocks)))
        self.conv1 = Conv(out_c, out_c, kernel_size=3, stride=1, padding=1)

        self.upsampling = Upsample(out_c, scale_factor=2)
        self.last_conv = nn.Conv2d(out_c, in_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a batch ``(N, in_c, H, W)`` to a batch ``(N, in_c, 2H, 2W)``."""
        first_conv = self.first_conv(x)
        x = self.res_blocks(first_conv)
        # Long skip connection around the whole residual trunk.
        x = self.conv1(x) + first_conv
        x = self.upsampling(x)
        x = self.last_conv(x)
        return x
     
# Use the GPU when one is available and build the model directly on that device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Generator().to(device)

# Set the loss criterion and optimizer
criteria = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Set the number of training epochs and learning rate scheduler.
# scheduler.step() is called once per mini-batch below, so steps_per_epoch must be
# the number of batches per epoch (len(dataloader)), not the number of samples.
# With the sample count the one-cycle schedule spans far more steps than are ever
# taken, and training never gets past the warm-up phase.
n_epochs = 50
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, 2e-4, epochs=n_epochs, steps_per_epoch=len(dataloader)
)

# Training loop
loss_array = []
for epoch in tqdm(range(1, n_epochs+1)):
    train_loss = 0.0

    for datalr, datahr in dataloader:

        # Move the LR inputs and HR targets to the training device
        datalr = datalr.to(device)
        datahr = datahr.to(device)

        # Forward pass: compute predicted outputs by passing inputs to the model
        outputs = model(datalr)
        # Calculate the loss
        loss = criteria(outputs, datahr)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a backward pass (backpropagation)
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Update the learning rate (once per batch, as OneCycleLR expects)
        scheduler.step()

        # Accumulate the summed loss; MSELoss returns the batch mean, so weight
        # it by the batch size to stay correct for a smaller final batch
        train_loss += (loss.item()*datahr.size(0))

    # Average the summed loss over all training samples
    train_loss = train_loss/x_trainHR.shape[0]
    loss_array.append(train_loss)

    # Save model and training loss after every epoch.
    # NOTE: this pickles the whole model object (not just a state_dict), so
    # loading it later needs the same class definitions to be importable.
    torch.save(model, './Weights/SRResNet.pth')
    np.save('Results/SRResNet_Loss.npy', loss_array)