In [None]:
# Mount Google Drive at /content/drive so files stored on Drive can be read
# and written from this notebook (google.colab is provided by the Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# cd drive/My\ Drive/.....

In [None]:
'''
Pranath Reddy
Benchmark Notebook for Superresolution
Model: FSRCNN
'''

# Import required libraries
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd
from torchvision import models
import torch.utils.model_zoo as model_zoo
import math
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch.utils.data import TensorDataset, DataLoader

# Load the paired training images from disk.
# Both arrays are cast to float32 and reshaped to (N, 1, H, W): one channel,
# 150x150 pixels for the high-resolution targets and 75x75 pixels for the
# low-resolution inputs.
x_trainHR = torch.Tensor(
    np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1, 1, 150, 150)
)
x_trainLR = torch.Tensor(
    np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1, 1, 75, 75)
)

# Report the tensor dimensions (HR first, then LR), one per line
print(x_trainHR.shape, x_trainLR.shape, sep='\n')

# Pair every LR input with its HR target and serve them in batches of 8.
# NOTE(review): batches are drawn in file order (no shuffle) - confirm this is intended.
dataset = TensorDataset(x_trainLR, x_trainHR)
dataloader = DataLoader(dataset, batch_size=8)

class FSRCNN(nn.Module):
    """FSRCNN super-resolution network.

    The model is built from three stages:

    * ``first_part``: a 5x5 convolution producing ``d`` feature maps.
    * ``mid_part``: a 1x1 convolution reducing to ``s`` channels, ``m``
      3x3 convolutions on ``s`` channels, and a 1x1 convolution expanding
      back to ``d`` channels.
    * ``last_part``: a 9x9 transposed convolution with stride
      ``scale_factor`` that maps back to ``num_channels`` and enlarges the
      spatial size by ``scale_factor``.

    Each ``nn.Conv2d`` is followed by a per-channel ``nn.PReLU``.

    Args:
        scale_factor: Stride of the transposed convolution (upscaling factor).
        num_channels: Channel count of the input and output images.
        d: Number of feature maps in the first and last stage.
        s: Number of feature maps inside the middle stage.
        m: Number of 3x3 convolutions in the middle stage.
    """

    def __init__(self, scale_factor=2, num_channels=1, d=64, s=12, m=4):
        super().__init__()

        # Stage 1: feature extraction
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
            nn.PReLU(d),
        )

        # Stage 2: 1x1 reduce -> m x (3x3) -> 1x1 expand
        layers = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            layers += [nn.Conv2d(s, s, kernel_size=3, padding=1), nn.PReLU(s)]
        layers += [nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]
        self.mid_part = nn.Sequential(*layers)

        # Stage 3: learned upsampling; with padding=4 and
        # output_padding=scale_factor-1 the output is exactly
        # scale_factor times the input height/width.
        self.last_part = nn.ConvTranspose2d(
            d,
            num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise all convolution weights and zero all their biases.

        ``nn.Conv2d`` weights are drawn from a zero-mean normal distribution
        with ``std = sqrt(2 / (out_channels * kernel_height * kernel_width))``;
        the transposed convolution uses a fixed ``std`` of 0.001.
        """
        for stage in (self.first_part, self.mid_part):
            for layer in stage:
                if not isinstance(layer, nn.Conv2d):
                    continue
                # weight[0][0] is a single 2-D kernel, so numel() is its area
                fan = layer.out_channels * layer.weight.data[0][0].numel()
                nn.init.normal_(layer.weight.data, mean=0.0, std=math.sqrt(2 / fan))
                nn.init.zeros_(layer.bias.data)

        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        """Run ``x`` through the three stages and return the upscaled image."""
        features = self.first_part(x)
        mapped = self.mid_part(features)
        return self.last_part(mapped)
    
# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FSRCNN(scale_factor=2).to(device)

# Set the loss criterion and optimizer
criteria = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Set the number of training epochs and learning rate scheduler.
# The scheduler is stepped once per batch, so steps_per_epoch must be the
# number of batches per epoch (len(dataloader)), not the number of samples.
# Passing the sample count made the planned schedule batch_size times longer
# than the real run, so the learning rate never left the warm-up phase.
n_epochs = 100
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 2e-4, epochs=n_epochs, steps_per_epoch=len(dataloader))

# Training loop: one pass over the dataloader per epoch, with the model and
# the loss history written to disk after every epoch.
loss_array = []
for epoch in tqdm(range(1, n_epochs+1)):
    train_loss = 0.0

    for lr_batch, hr_batch in dataloader:
        # Each batch is an (LR input, HR target) pair; move both to the device
        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass and loss against the high-resolution target
        prediction = model(lr_batch)
        batch_loss = criteria(prediction, hr_batch)

        # Backpropagate, update the parameters, then advance the LR schedule
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        # Accumulate the loss weighted by batch size so the epoch average
        # stays per-sample even when the last batch is smaller
        train_loss += batch_loss.item() * hr_batch.size(0)

    # Record the average per-sample loss for this epoch
    train_loss /= x_trainHR.shape[0]
    loss_array.append(train_loss)

    # Save model and training loss
    torch.save(model, './Weights/FSRCNN.pth')
    np.save('Results/FSRCNN_Loss.npy', loss_array)