# Strong Lensing Challenge - Image Super-Resolution

Gravitational lensing, discussed in Einstein’s calculations in 1936 and first observed in 1979, has since become a cornerstone of many cosmology experiments and studies. One area of particular interest is the study of dark matter via substructure in strong lensing images. In this challenge, we explore the potential of ML models to enhance the resolution of lensing images.

This is an example notebook for the Image Super-Resolution Challenge. In this notebook, we demonstrate a simple CNN model implemented using the PyTorch library to solve the task of super-resolution of strong lensing images.

### Dataset

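The dataset consists of paired high-resolution (HR) and low-resolution (LR) strong lensing images stored as `.npy` files (HR: 128x128 pixels, LR: 64x64 pixels). The images have been normalized using min-max normalization, but you are free to use any normalization or data augmentation methods to improve your results.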

Link to the Dataset: https://drive.google.com/file/d/1lUOGo2B0Rhxwj_TGZSVEdZJ79GdI7awa/view?usp=sharing
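For reference, per-image min-max normalization rescales each image to [0, 1]. When adding augmentation, the same geometric transform has to be applied to an LR image and its HR partner, otherwise the pair no longer corresponds. The sketch below assumes each sample is a NumPy array of shape `(1, H, W)` (as the later cells suggest); `min_max_normalize` and `augment_pair` are hypothetical helpers, not part of the dataset or any library.

```python
import numpy as np


def min_max_normalize(img, eps=1e-12):
    """Rescale an image to [0, 1] using its own minimum and maximum."""
    img = img.astype(np.float32)
    lo, hi = img.min(), img.max()
    return (img - lo) / max(hi - lo, eps)


def augment_pair(lr, hr, rng):
    """Apply the same random 90-degree rotation and left-right flip to an LR/HR pair.

    Both arrays are assumed to have shape (1, H, W); only axes 1 and 2 are transformed.
    """
    k = int(rng.integers(0, 4))      # number of 90-degree rotations (0 to 3)
    flip = bool(rng.integers(0, 2))  # whether to mirror left-right
    out = []
    for img in (lr, hr):
        img = np.rot90(img, k, axes=(1, 2))
        if flip:
            img = np.flip(img, axis=2)
        # Copy to a contiguous array: torch.from_numpy rejects negative strides
        out.append(np.ascontiguousarray(img))
    return out[0], out[1]


# Example with random data shaped like the challenge images
rng = np.random.default_rng(0)
lr_demo = min_max_normalize(rng.random((1, 64, 64)) * 5.0)
hr_demo = min_max_normalize(rng.random((1, 128, 128)) * 5.0)
lr_aug, hr_aug = augment_pair(lr_demo, hr_demo, rng)
print(lr_aug.shape, hr_aug.shape)
```

Inside `SuperResolutionDataset.__getitem__`, `augment_pair` would be called on the loaded arrays before converting them to tensors, and only for the training set.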

### Evaluation Metrics

* Mean Squared Error (MSE), Structural Similarity Index (SSIM) and Peak Signal-to-Noise Ratio (PSNR)

Model performance will be evaluated on the hidden test dataset using the above metrics.
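For intuition, MSE averages the squared pixel error, and PSNR expresses that error on a logarithmic scale relative to the data range $R$: $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$, so lower MSE means higher PSNR. The sketch below computes both with NumPy; the function names are illustrative, and the hidden-test evaluation may use different settings (for example a different `data_range`). SSIM is more involved, so the notebook relies on `skimage.metrics.structural_similarity` for it.

```python
import numpy as np


def mse(sr, hr):
    """Mean squared error between a super-resolved image and its HR target."""
    sr = np.asarray(sr, dtype=np.float64)
    hr = np.asarray(hr, dtype=np.float64)
    return float(np.mean((sr - hr) ** 2))


def psnr_db(sr, hr, data_range=1.0):
    """Peak signal-to-noise ratio in decibels: 10 * log10(data_range**2 / MSE)."""
    err = mse(sr, hr)
    if err == 0.0:
        return float("inf")  # identical images have unbounded PSNR
    return float(10.0 * np.log10(data_range ** 2 / err))


# An SR image that is off by 0.1 everywhere: MSE = 0.01 and PSNR = 20 dB
# for data_range = 1 (up to floating-point rounding)
hr_demo = np.zeros((128, 128))
sr_demo = np.full((128, 128), 0.1)
print(mse(sr_demo, hr_demo), psnr_db(sr_demo, hr_demo))
```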

### Instructions for using the notebook

1. Use GPU acceleration: (Edit --> Notebook settings --> Hardware accelerator --> GPU)
2. Run the cells: (Runtime --> Run all)

In [None]:
# Check if the dataset folder is missing
import os
if not os.path.exists('./dataset'):
    # Download and extract the dataset
    !gdown http://drive.google.com/uc?id=1lUOGo2B0Rhxwj_TGZSVEdZJ79GdI7awa
    !unzip -q dataset.zip

## Image Super-Resolution using a Supervised Model

### 1. Data Visualization and Preprocessing

#### 1.1 Import all the necessary libraries

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from tqdm.notebook import tqdm
import torch.utils.data as data
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
%matplotlib inline

#### 1.2 Preview the Data

In [None]:
# Define the input paths
train_hr_path = './dataset/train/HR'
train_lr_path = './dataset/train/LR'
# Sort both listings (os.listdir order is arbitrary) so that, assuming matching
# pairs share a file name, the i-th HR file corresponds to the i-th LR file
train_hr_files = sorted(os.path.join(train_hr_path, f) for f in os.listdir(train_hr_path) if f.endswith(".npy"))
train_lr_files = sorted(os.path.join(train_lr_path, f) for f in os.listdir(train_lr_path) if f.endswith(".npy"))

# Number of samples to display
n = 5

# Plot HR samples in the top row and the corresponding LR samples in the bottom row
fig, axes = plt.subplots(2, n, figsize=(14, 6))
for j in range(n):
    axes[0, j].imshow(np.load(train_hr_files[j]).reshape(128, 128), cmap='gray')
    axes[0, j].set_title('HR')
    axes[1, j].imshow(np.load(train_lr_files[j]).reshape(64, 64), cmap='gray')
    axes[1, j].set_title('LR')
for ax in axes.ravel():
    ax.axis('off')
plt.show()

#### 1.3 Import Training and Validation Data

In [None]:
# Set Batch Size
batch_size = 100

# Define Data Loaders
class SuperResolutionDataset(data.Dataset):
    def __init__(self, lr_path, hr_path):
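        # Sort both listings (os.listdir order is arbitrary) so that, assuming
        # matching LR/HR files share a name, index idx refers to the same pair
        self.lr_files = sorted(os.path.join(lr_path, f) for f in os.listdir(lr_path) if f.endswith(".npy"))
        self.hr_files = sorted(os.path.join(hr_path, f) for f in os.listdir(hr_path) if f.endswith(".npy"))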
        
    def __len__(self):
        return len(self.lr_files)
    
    def __getitem__(self, idx):
        lr_image = np.load(self.lr_files[idx])
        hr_image = np.load(self.hr_files[idx])
        return torch.from_numpy(lr_image).float(), torch.from_numpy(hr_image).float()

train_data = SuperResolutionDataset('./dataset/train/LR', './dataset/train/HR')
train_data_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
val_data = SuperResolutionDataset('./dataset/val/LR', './dataset/val/HR')
# No shuffling for validation, so predictions and targets come out in a fixed, reproducible order
val_data_loader = data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
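Because `SuperResolutionDataset` pairs files by their position in two separate directory listings, it is worth confirming that the LR and HR folders contain the same file names and the expected shapes before training. The sketch below assumes that matching pairs share a file name in `LR/` and `HR/` and that LR images are `(1, 64, 64)` and HR images `(1, 128, 128)`, as the model and plotting code suggest; `check_pairs` is a hypothetical helper.

```python
import os
import numpy as np


def check_pairs(lr_dir, hr_dir, lr_shape=(1, 64, 64), hr_shape=(1, 128, 128), n_check=10):
    """Check that LR/HR folders hold the same .npy names and the expected shapes.

    Returns the number of pairs; raises ValueError on the first mismatch found.
    """
    lr_names = sorted(f for f in os.listdir(lr_dir) if f.endswith(".npy"))
    hr_names = sorted(f for f in os.listdir(hr_dir) if f.endswith(".npy"))
    if lr_names != hr_names:
        unmatched = set(lr_names) ^ set(hr_names)
        raise ValueError(f"{len(unmatched)} file names appear in only one folder")
    # Loading every file can be slow, so only the first n_check pairs are inspected
    for name in lr_names[:n_check]:
        lr_arr = np.load(os.path.join(lr_dir, name))
        hr_arr = np.load(os.path.join(hr_dir, name))
        if lr_arr.shape != lr_shape or hr_arr.shape != hr_shape:
            raise ValueError(f"{name}: got LR {lr_arr.shape}, HR {hr_arr.shape}")
    return len(lr_names)


# Usage in the notebook:
# print(check_pairs('./dataset/train/LR', './dataset/train/HR'))
```

If the check fails only because the two folders use different naming schemes, the pairing rule in `SuperResolutionDataset` needs to be adapted accordingly.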

### 2. Training

#### 2.1 Defining a Super-Resolution CNN Model

You may refer to this [article](https://medium.com/@RaghavPrabhu/understanding-of-convolutional-neural-network-cnn-deep-learning-99760835f148) for an introduction to Convolutional Neural Networks (CNNs).

In [None]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        # Feature extraction, non-linear mapping and reconstruction layers;
        # "same" padding keeps the 64x64 LR spatial size throughout
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        # Upsample the reconstructed LR-sized image to the 128x128 HR size
        x = F.interpolate(x, size=(128, 128), mode='bicubic', align_corners=False)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNN().to(device)
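Each `Conv2d` layer in `SRCNN` uses stride 1 with `padding = (kernel_size - 1) / 2`, so the spatial size is unchanged: $H_{out} = H + 2p - k + 1$, e.g. $64 + 2 \cdot 4 - 9 + 1 = 64$. All learned processing therefore happens at the 64x64 LR resolution, and only the final `F.interpolate` call doubles the size to 128x128. The sketch below traces a dummy batch through layers with the same hyperparameters to confirm the shapes; the batch size of 2 is arbitrary, and the ReLUs are left out because they do not change shapes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Layers with the same hyperparameters as SRCNN above
layers = [
    nn.Conv2d(1, 64, kernel_size=9, padding=4),
    nn.Conv2d(64, 32, kernel_size=5, padding=2),
    nn.Conv2d(32, 1, kernel_size=5, padding=2),
]

x = torch.zeros(2, 1, 64, 64)  # dummy batch of two single-channel LR images
shapes = [tuple(x.shape)]
with torch.no_grad():
    for conv in layers:
        x = conv(x)  # stride 1 and "same" padding keep the 64x64 size
        shapes.append(tuple(x.shape))
    x = F.interpolate(x, size=(128, 128), mode="bicubic", align_corners=False)
    shapes.append(tuple(x.shape))

for s in shapes:
    print(s)
```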

#### 2.2 Training the Super-Resolution CNN Model

In [None]:
# Loss Function
criteria = nn.MSELoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Train the model
n_epochs = 20  # Number of Training Epochs
loss_array = []  # Average training loss for each epoch
pbar = tqdm(range(1, n_epochs + 1))
for epoch in pbar:
    model.train()
    train_loss = 0.0

    for lr, hr in train_data_loader:
        # Move the batch to the selected device (the Dataset already returns float32 tensors)
        lr = lr.to(device)
        hr = hr.to(device)
        optimizer.zero_grad()
        outputs = model(lr)
        loss = criteria(outputs, hr)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss = train_loss / len(train_data_loader)
    loss_array.append(train_loss)
    # Display the Training Stats
    pbar.set_postfix({'Training Loss': train_loss})
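The loop above only tracks the training loss. A common addition is to measure the loss on the validation set after each epoch to spot overfitting. The sketch below defines a hypothetical `evaluate_loss` helper that averages per-batch losses, the same convention as the training loop; calling `evaluate_loss(model, val_data_loader, criteria, device)` at the end of each epoch would give a validation curve alongside `loss_array`.

```python
import torch


def evaluate_loss(model, loader, criterion, device):
    """Average per-batch loss of `model` over `loader`, without updating any weights."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for lr, hr in loader:
            outputs = model(lr.to(device))
            total += criterion(outputs, hr.to(device)).item()
            n_batches += 1
    model.train()  # switch back to training mode for the next epoch
    return total / max(n_batches, 1)
```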

### 3. Testing

#### 3.1 Testing the Super-Resolution CNN Model on Validation Data - Calculate Quantitative Metrics

In [None]:
# Calculate Metrics

# Get predictions and the matching ground truth in a single pass, so each
# SR output stays paired with its HR target even if the loader shuffles
model.eval()
out = []
val_hr = []
with torch.no_grad():
    for lr, hr in val_data_loader:
        recon = model(lr.to(device))
        out.append(recon.cpu().numpy())
        val_hr.append(hr.numpy())
dataSR = np.concatenate(out, axis=0)
val_hr = np.concatenate(val_hr, axis=0)

# Calculate metrics
print("Metrics:")
criteria = nn.MSELoss()
criteria2 = nn.L1Loss()
losses = []
losses2 = []
Ssim = []
Psnr = []

for i in range(dataSR.shape[0]):
    sr_t = torch.from_numpy(dataSR[i])
    hr_t = torch.from_numpy(val_hr[i])
    losses.append(criteria(sr_t, hr_t).item())
    losses2.append(criteria2(sr_t, hr_t).item())
    # Take data_range from the ground-truth HR image rather than from the prediction
    data_range = val_hr[i][0].max() - val_hr[i][0].min()
    Ssim.append(ssim(val_hr[i][0], dataSR[i][0], data_range=data_range))
    Psnr.append(psnr(val_hr[i][0], dataSR[i][0], data_range=data_range))

print(f"Average MSE super resolution samples: {np.average(losses):.7f}")
print(f"Average L1 super resolution samples: {np.average(losses2):.7f}")
print(f"Average SSIM super resolution samples: {np.average(Ssim):.5f}")
print(f"Average PSNR super resolution samples: {np.average(Psnr):.5f}")
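A useful reference point is plain bicubic upsampling of the LR input: the same interpolation `SRCNN` applies as its final step, but without any learned layers. If the CNN does not clearly beat this baseline on MSE, SSIM and PSNR, the learned layers are adding little. The sketch below wraps `F.interpolate` in a hypothetical `bicubic_upsample` helper. Note that bicubic interpolation can slightly overshoot the [0, 1] range near sharp edges.

```python
import torch
import torch.nn.functional as F


def bicubic_upsample(lr_batch, size=(128, 128)):
    """Bicubic upsampling of an (N, 1, H, W) batch to `size`, returned as a NumPy array."""
    lr_t = torch.as_tensor(lr_batch, dtype=torch.float32)
    with torch.no_grad():
        up = F.interpolate(lr_t, size=size, mode="bicubic", align_corners=False)
    return up.cpu().numpy()


# Usage in the notebook: collect the baseline and its targets in the same pass, then
# rerun the metric loop above with `baseline` and `base_hr` in place of `dataSR` and `val_hr`.
# baseline, base_hr = [], []
# for lr, hr in val_data_loader:
#     baseline.append(bicubic_upsample(lr))
#     base_hr.append(hr.numpy())
# baseline, base_hr = np.concatenate(baseline), np.concatenate(base_hr)
```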

#### 3.2 Visualize Outputs for Qualitative Analysis

In [None]:
# Visualize Outputs
with torch.no_grad():
    for lr, hr in val_data_loader:
        lr = lr.to(device)
        hr = hr.to(device)
        output = model(lr)

        lr = lr.cpu().numpy()
        output = output.cpu().numpy()
        hr = hr.cpu().numpy()

        # Display the results
        plt.figure(figsize=(12, 8))
        for i in range(5):
            plt.subplot(3, 5, i + 1)
            plt.imshow(lr[i].reshape(64, 64), cmap='gray')
            plt.title('Low Res')
            plt.axis('off')
            plt.subplot(3, 5, i + 6)
            plt.imshow(hr[i].reshape(128, 128), cmap='gray')
            plt.title('High Res')
            plt.axis('off')
            plt.subplot(3, 5, i + 11)
            plt.imshow(output[i].reshape(128, 128), cmap='gray')
            plt.title('Output')
            plt.axis('off')
        plt.show()
        break
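Side-by-side panels make small differences hard to see. Plotting the per-pixel absolute error |SR - HR| shows where the reconstruction deviates most from the target. The helper names below are hypothetical; they assume each output and target image reshapes to 128x128, as in the cell above.

```python
import numpy as np
import matplotlib.pyplot as plt


def abs_error_map(sr, hr, size=(128, 128)):
    """Per-pixel absolute error |SR - HR| for a single image."""
    return np.abs(np.asarray(sr).reshape(size) - np.asarray(hr).reshape(size))


def show_error_maps(outputs, targets, n=5):
    """Plot |SR - HR| for the first `n` images of a batch on a shared color scale."""
    maps = [abs_error_map(outputs[i], targets[i]) for i in range(n)]
    vmax = max(float(m.max()) for m in maps)
    # squeeze=False keeps a 2D axes array even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(12, 3), squeeze=False)
    for ax, m in zip(axes[0], maps):
        im = ax.imshow(m, cmap="inferno", vmin=0.0, vmax=vmax)
        ax.axis("off")
    fig.colorbar(im, ax=list(axes[0]), shrink=0.8)
    plt.show()


# Usage right after the cell above, which leaves `output` and `hr` as NumPy arrays:
# show_error_maps(output, hr)
```

Using one color scale for all panels keeps the error magnitudes comparable across images.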

## Submission Guidelines

* You are required to submit a Google Colab Jupyter Notebook clearly showing your implementation along with the evaluation metrics (MSE, SSIM and PSNR) for the validation data.
* You must also submit the final trained model, including the model architecture and the trained weights (for example: HDF5 file, .pb file, .pt file, etc.).
* You can use this example notebook as a template for your work.
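In PyTorch, a common way to package trained weights is to save the model's `state_dict` to a `.pt` file and keep the architecture in code (for example the `SRCNN` class above). The sketch below uses a small stand-in module and a hypothetical file name `srcnn_weights.pt`; with the notebook model you would save `model.state_dict()` and load it into a fresh `SRCNN()` instead.

```python
import torch
import torch.nn as nn

# Stand-in for the trained network; in the notebook this would be `model` (an SRCNN instance)
trained = nn.Conv2d(1, 1, kernel_size=3, padding=1)

# Save only the learned parameters
torch.save(trained.state_dict(), "srcnn_weights.pt")

# Later: rebuild the same architecture and load the saved parameters into it
restored = nn.Conv2d(1, 1, kernel_size=3, padding=1)
restored.load_state_dict(torch.load("srcnn_weights.pt", map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating
```

`map_location="cpu"` lets weights trained on a GPU be loaded on a machine without one.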

> **_NOTE:_**  You are free to use any ML framework such as PyTorch, Keras, TensorFlow, etc.