In [None]:
# Mount Google Drive into the runtime at /content/drive.
# NOTE: google.colab is only importable inside Google Colab; skip this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# cd drive/My\ Drive/.....  (uncomment and point at the project folder that contains ./Data, ./Weights and Results)

In [None]:
'''
Pranath Reddy
Benchmark Notebook for Superresolution
Model: EDSR
'''

# Import required libraries
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd
from torchvision import models
import torch.utils.model_zoo as model_zoo
import math
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch.utils.data import TensorDataset, DataLoader

# Load training data
# High-Resolution lensing data, reshaped to (N, 1, 150, 150) float32
x_trainHR = np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1, 1, 150, 150)
# Low-Resolution lensing data, reshaped to (N, 1, 75, 75) float32
x_trainLR = np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1, 1, 75, 75)
# torch.from_numpy shares memory with the float32 arrays above instead of
# making a second full copy of each dataset (as the legacy torch.Tensor(...) did)
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
# Print data dimensions
print(x_trainHR.shape)
print(x_trainLR.shape)

# Create dataset and dataloader for efficient data loading and batching.
# Each item is an (LR, HR) pair; shuffle=True draws a fresh random batch order
# every epoch instead of replaying the same fixed order.
dataset = TensorDataset(x_trainLR, x_trainHR)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Define ResidualBlock class for EDSR
class ResidualBlock(nn.Module):
    """Residual unit: conv -> ReLU -> conv, plus an identity shortcut.

    Computes ``conv2(relu(conv1(x))) + x``. Both convolutions are 3x3 with
    padding 1 and keep ``num_features`` channels, so the output has the same
    shape as the input.
    """

    def __init__(self, num_features):
        super().__init__()
        kernel = 3
        pad = kernel // 2  # keeps spatial size for a stride-1 3x3 conv
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size=kernel, padding=pad)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=kernel, padding=pad)

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        # Accumulate the shortcut into the freshly computed branch tensor;
        # the caller's input tensor is never modified.
        branch += x
        return branch

# Define EDSR (Enhanced Deep Super-Resolution) class
class EDSR(nn.Module):
    """EDSR super-resolution network.

    Pipeline: head conv (``input``) -> stack of ``ResidualBlock``s ->
    body conv (``output_conv``) -> global skip connection from the head ->
    conv + PixelShuffle upsampler (``upscale``) -> output conv (``final``).

    Args:
        scale_factor: integer upscaling factor applied by ``nn.PixelShuffle``.
        num_channels: number of image channels in the input and the output.
        num_features: channel width of the internal feature maps.
        num_blocks: number of residual blocks in the body.
    """

    def __init__(self, scale_factor=2, num_channels=1, num_features=64, num_blocks=16):
        super(EDSR, self).__init__()

        self.input = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=3 // 2)

        self.residual_blocks = nn.Sequential(*[ResidualBlock(num_features) for _ in range(num_blocks)])

        # Closing conv of the residual body, applied before the global skip.
        self.output_conv = nn.Conv2d(num_features, num_features, kernel_size=3, padding=3 // 2)

        # Sub-pixel upsampler: expand channels by scale_factor**2, then
        # PixelShuffle rearranges them into an image scale_factor times larger.
        self.upscale = nn.Sequential(
            nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor),
        )

        self.final = nn.Conv2d(num_features, num_channels, kernel_size=3, padding=3 // 2)

    def forward(self, x):
        """Map (N, num_channels, H, W) to (N, num_channels, H*scale, W*scale)."""
        initial = self.input(x)

        # Residual body followed by output_conv (previously built but never
        # applied), then the global skip. Out-of-place add keeps `initial` intact.
        body = self.output_conv(self.residual_blocks(initial))
        residual = body + initial

        upscaled = self.upscale(residual)
        return self.final(upscaled)

import os

# Set the device to use for training: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass the model to the device
model = EDSR().to(device)

# Set the loss criterion (pixel-wise MSE) and optimizer
criteria = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Set the number of training epochs and the optional one-cycle LR schedule.
# NOTE: merely constructing OneCycleLR sets the optimizer's lr to
# max_lr / div_factor (2e-4 / 25 = 8e-6 with the default div_factor), so it
# must be stepped every batch; set use_scheduler = False to train at a constant 2e-4.
n_epochs = 30
use_scheduler = True
scheduler = None
if use_scheduler:
    # OneCycleLR counts optimizer steps, i.e. batches per epoch, not samples
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=2e-4, epochs=n_epochs, steps_per_epoch=len(dataloader)
    )

# Make sure the output directories exist before the first checkpoint is written
os.makedirs('./Weights', exist_ok=True)
os.makedirs('Results', exist_ok=True)

# Training loop
loss_array = []
n_samples = len(dataloader.dataset)
for epoch in tqdm(range(1, n_epochs + 1)):
    model.train()
    train_loss = 0.0

    for datalr, datahr in dataloader:
        # Move the LR input batch and HR target batch to the device
        datalr = datalr.to(device)
        datahr = datahr.to(device)

        # Forward pass: predict HR images from the LR inputs
        outputs = model(datalr)
        # Calculate the loss
        loss = criteria(outputs, datahr)

        # Reset the gradients, backpropagate, and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advance the one-cycle schedule once per optimizer step (per batch)
        if scheduler is not None:
            scheduler.step()

        # MSELoss returns a batch mean; scale by batch size to accumulate a
        # per-sample sum that is averaged over the whole dataset below
        train_loss += loss.item() * datahr.size(0)

    # Average training loss over all samples for this epoch
    train_loss = train_loss / n_samples
    loss_array.append(train_loss)
    print(f'Epoch {epoch}/{n_epochs}  train loss: {train_loss:.6e}')

    # Save model and training loss after every epoch
    torch.save(model, './Weights/EDSR.pth')
    np.save('Results/EDSR_Loss.npy', loss_array)