In [None]:
# Mount Google Drive into the Colab runtime so its files appear under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
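# cd drive/My\ Drive/.....
# (Change into the project folder on Google Drive first: the relative ./Data, ./Weights
#  and Results paths used below are resolved against the current directory.)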

In [None]:
# Install e2cnn, which provides the `gspaces` and `nn` modules imported in the next cell.
!pip3 install e2cnn

In [None]:
'''
Pranath Reddy
Benchmark Notebook for Superresolution
Model: Equivariant FSRCNN
'''

# Import required libraries
import torch
import numpy as np
#import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import tqdm
from torch import autograd
from torchvision import models
import torch.utils.model_zoo as model_zoo
import math
from skimage.metrics import structural_similarity as ssim
from sklearn.utils import shuffle
from torch.utils.data import TensorDataset, DataLoader
import torch
from e2cnn import gspaces
from e2cnn import nn

# Load training data
# High-Resolution lensing data, reshaped to (N, 1, 150, 150)
x_trainHR = np.load('./Data/train_HR.npy').astype(np.float32).reshape(-1,1,150,150)
# Low-Resolution lensing data, reshaped to (N, 1, 75, 75)
x_trainLR = np.load('./Data/train_LR.npy').astype(np.float32).reshape(-1,1,75,75)
x_trainHR = torch.Tensor(x_trainHR)
x_trainLR = torch.Tensor(x_trainLR)
# Print data dimensions
print(x_trainHR.shape)
print(x_trainLR.shape)

# Pair each LR input with its HR target (TensorDataset requires both tensors
# to have the same number of samples along dim 0).
dataset = TensorDataset(x_trainLR, x_trainHR)
# shuffle=True re-orders the samples every epoch; without it every epoch sees
# exactly the same minibatches in exactly the same order.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

class Equivariant_FSRCNN(torch.nn.Module):
    """FSRCNN-style super-resolution network built from e2cnn equivariant layers.

    Layout (every hidden block is R2Conv(bias=False) -> InnerBatchNorm -> ReLU):
      * first_part : MaskModule, then a 5x5 conv to ``d`` regular fields.
      * mid_part   : a 1x1 conv to ``s`` regular fields, then ``m`` 3x3 convs.
      * last_part  : a 1x1 conv back to ``d`` regular fields, then a 4x4
                     R2ConvTransposed with stride ``scale_factor``.
      * final_layer: a 3x3 unpadded conv to one trivial (scalar) field.

    Spatial size: the 5x5 (padding 2), 1x1 and 3x3 (padding 1) convs keep H;
    the transposed conv gives (H - 1) * scale_factor + 4; the final unpadded
    3x3 conv removes 2. For scale_factor=2 the output is exactly 2 * H
    (75 -> 150).

    NOTE(review): the output goes through InnerBatchNorm and ReLU, so
    predictions are batch-normalised and clamped to be non-negative. Confirm
    this is intended for the target intensity range.

    Args:
        sym_group: 'Dihedral' (or the legacy spelling 'Dihyderal', kept as
            the default) selects ``gspaces.FlipRot2dOnR2``; 'Circular'
            selects ``gspaces.Rot2dOnR2``.
        N: order parameter passed to the chosen gspace.
        scale_factor: stride of the transposed convolution.
        num_channels: number of scalar input channels.
        d: number of regular fields in the first and last feature blocks.
        s: number of regular fields in the mid block.
        m: number of 3x3 convs in the mid block (after its 1x1 conv).
        input_size: spatial side length of the (square) input, used to size
            the MaskModule. It must match the inputs passed to ``forward``.

    Raises:
        ValueError: if ``sym_group`` is not one of the supported names.
    """

    def __init__(self, sym_group = "Dihyderal", N = 2, scale_factor=2, num_channels=1, d=16, s=64, m=4, input_size=75):

        super(Equivariant_FSRCNN, self).__init__()

        if sym_group in ('Dihyderal', 'Dihedral'):
            self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        elif sym_group == 'Circular':
            self.r2_act = gspaces.Rot2dOnR2(N=N)
        else:
            # Previously an unknown name left self.r2_act unset, which only
            # surfaced later as a confusing AttributeError.
            raise ValueError(
                f"Unsupported sym_group {sym_group!r}; expected 'Dihedral' "
                "(or the legacy spelling 'Dihyderal') or 'Circular'"
            )

        in_type = nn.FieldType(self.r2_act, num_channels * [self.r2_act.trivial_repr])
        self.input_type = in_type

        # The MaskModule size must equal the input's spatial size.
        out_type = self._regular_type(d)
        self.first_part = nn.SequentialModule(
            nn.MaskModule(in_type, input_size, margin=1),
            *self._conv_bn_relu(in_type, out_type, kernel_size=5, padding=5 // 2),
        )

        in_type = self.first_part.out_type
        mid_type = self._regular_type(s)
        mid_part = self._conv_bn_relu(in_type, mid_type, kernel_size=1)
        for _ in range(m):
            mid_part.extend(
                self._conv_bn_relu(mid_type, mid_type, kernel_size=3, padding=3 // 2)
            )
        self.mid_part = nn.SequentialModule(*mid_part)

        in_type = self.mid_part.out_type
        out_type = self._regular_type(d)
        self.last_part = nn.SequentialModule(
            *self._conv_bn_relu(in_type, out_type, kernel_size=1),
            nn.R2ConvTransposed(out_type, out_type, kernel_size=4, stride=scale_factor),
        )

        in_type = self.last_part.out_type
        out_type = nn.FieldType(self.r2_act, 1 * [self.r2_act.trivial_repr])
        # padding=0 trims the 2 extra pixels produced by the transposed conv.
        self.final_layer = nn.SequentialModule(
            *self._conv_bn_relu(in_type, out_type, kernel_size=3, padding=0)
        )

    def _regular_type(self, n_fields):
        """Return a FieldType made of ``n_fields`` copies of the regular representation."""
        return nn.FieldType(self.r2_act, n_fields * [self.r2_act.regular_repr])

    @staticmethod
    def _conv_bn_relu(in_type, out_type, kernel_size, padding=0):
        """Return a new list [R2Conv (no bias), InnerBatchNorm, ReLU] mapping in_type -> out_type."""
        return [
            nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True),
        ]

    def forward(self, input: torch.Tensor):
        """Super-resolve a batch of images.

        Args:
            input: tensor whose channel dim equals ``num_channels`` and whose
                spatial size matches ``input_size``. Presumably
                (B, num_channels, input_size, input_size); confirm against
                the caller.

        Returns:
            A plain torch.Tensor with a single output channel. For
            scale_factor=2 its spatial size is twice the input's.
        """
        x = nn.GeometricTensor(input, self.input_type)
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        x = self.final_layer(x)
        return x.tensor
       
import os

# Use the GPU when one is available and fall back to the CPU otherwise
# (a bare torch.device("cuda") fails at .to(device) on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Equivariant_FSRCNN().to(device)

# Set the loss criterion and optimizer
criteria = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Set the number of training epochs and learning rate scheduler.
# scheduler.step() is called once per batch below, so OneCycleLR must be told
# the number of *batches* per epoch. Passing the number of samples stretched
# the schedule by a factor of batch_size, so it never finished its cycle.
n_epochs = 50
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, 2e-4, epochs=n_epochs, steps_per_epoch=len(dataloader)
)

# Create the output directories up front so the per-epoch saves cannot fail.
os.makedirs('./Weights', exist_ok=True)
os.makedirs('Results', exist_ok=True)

# Training loop
loss_array = []
for epoch in tqdm(range(1, n_epochs+1)):
    model.train()
    train_loss = 0.0

    for datalr, datahr in dataloader:
        # Move the LR inputs and HR targets to the training device
        datalr = datalr.to(device)
        datahr = datahr.to(device)

        # Forward pass and loss
        outputs = model(datalr)
        loss = criteria(outputs, datahr)

        # Backward pass, parameter update, then one LR-schedule step per batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # MSELoss returns a mean over the batch; weight by batch size so the
        # epoch average below is a true per-sample mean.
        train_loss += loss.item() * datahr.size(0)

    # Average training loss over all samples seen this epoch
    train_loss = train_loss / len(dataloader.dataset)
    loss_array.append(train_loss)
    print(f'Epoch {epoch}/{n_epochs}  train loss: {train_loss:.6f}')

    # Save model and training loss
    torch.save(model.state_dict(), './Weights/EFSRCNN.pth')
    np.save('Results/EFSRCNN_Loss.npy', loss_array)