# Task VI (b)
 Take the pre-trained model from Task VI.A and fine-tune it for a super-resolution task. The model should be fine-tuned to upscale low-resolution strong lensing images using the provided high-resolution samples as ground truths. Please implement your approach in PyTorch or Keras and discuss your strategy.


## Breakdown of the approach
* Smart upsampling
    * Transpose convolutions and PixelShuffle scale up the image while preserving fine details.
    * A final Tanh activation bounds the output to [-1, 1], keeping the predicted high-resolution (HR) pixel values in a fixed range.

* Extracting meaningful features from the pre-trained model
    * The MAE model has already learned useful patterns from lensing images.
    * Instead of retraining everything, we freeze the first half of the encoder's parameters (the early layers) and fine-tune the remaining layers, together with a new upsampling head, to adapt the model to super-resolution.

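* Better loss function and optimization for sharper images
    * L1 loss instead of MSE: L1 penalizes errors linearly rather than quadratically, so it does not favour the averaged, blurry predictions that MSE tends to produce, which helps keep fine details sharp.
    * AdamW optimizer: decoupled weight decay helps stabilize training and reduce overfitting, which is useful for scientific images.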




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
from PIL import Image


Model loading and preparation

In [None]:
# Load the pre-trained MAE classifier checkpoint into a ResNet-18 backbone (first attempt).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
from torchvision.models import resnet18  # used below but previously never imported

model_path = "mae_classifier.pth"
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Remove a leading 'mae.' prefix if present. Only the prefix is stripped; a 'mae.'
# appearing later inside a key is left untouched (str.replace would remove it too).
mae_prefix = "mae."
new_state_dict = {
    (key[len(mae_prefix):] if key.startswith(mae_prefix) else key): value
    for key, value in state_dict.items()
}

# Remove classifier layer keys ('fc.*') so only backbone weights are transferred.
new_state_dict = {k: v for k, v in new_state_dict.items() if not k.startswith("fc.")}

# resnet18() with no arguments builds a randomly initialised network (no downloaded weights).
pretrained_model = resnet18()
# strict=False tolerates mismatched keys; report them so a naming mismatch is visible
# instead of silently leaving layers at their random initialisation.
load_result = pretrained_model.load_state_dict(new_state_dict, strict=False)
print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys}")
print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys}")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Modify the model for super-resolution using MAE
class SuperResolutionMAE(nn.Module):
    """First-draft super-resolution network on top of a pretrained MAE encoder.

    Encoder features are bilinearly upsampled by ``scale_factor`` and then refined
    by a small CNN that maps ``embed_dim`` channels down to a single output
    channel, squashed into [-1, 1] by a final Tanh.

    Args:
        base_model: Object exposing an ``encoder`` module used as feature extractor.
            NOTE(review): torchvision's resnet18 has no ``encoder`` attribute, so this
            expects an MAE-style model — adjust to your MAE structure.
        embed_dim: Channel width of the encoder features (768 by default).
        scale_factor: Bilinear upsampling factor applied before the CNN refinement.
    """

    def __init__(self, base_model, embed_dim=768, scale_factor=4):
        super().__init__()

        # Extract feature encoder from MAE (usually the transformer backbone)
        self.feature_extractor = base_model.encoder
        self.scale_factor = scale_factor

        # Refinement CNN applied after upsampling; layer layout unchanged so
        # existing state_dict keys ("upsample.N.*") still match.
        self.upsample = nn.Sequential(
            nn.Conv2d(embed_dim, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    @staticmethod
    def _to_feature_map(x):
        """Return ``x`` as a (B, C, H, W) feature map.

        4-D inputs are returned unchanged. 3-D token sequences (B, N, C) are laid
        out row-major on a square sqrt(N) x sqrt(N) grid.

        Raises:
            ValueError: if the rank is not 3 or 4, or N is not a positive perfect square.
        """
        if x.dim() == 4:
            return x
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D token tensor or 4-D feature map, got shape {tuple(x.shape)}")
        batch, num_tokens, channels = x.shape
        side = int(round(num_tokens ** 0.5))
        if side == 0 or side * side != num_tokens:
            raise ValueError(f"cannot arrange {num_tokens} tokens on a square grid")
        return x.transpose(1, 2).reshape(batch, channels, side, side)

    def forward(self, x):
        """Map an input batch to a single-channel image ``scale_factor`` times the feature-grid size."""
        x = self._to_feature_map(self.feature_extractor(x))
        # bilinear interpolation needs a 4-D (B, C, H, W) tensor, hence the reshape above
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        return self.upsample(x)


# First-draft setup: wrap the pretrained backbone, then train it with MSE loss and Adam.
model = SuperResolutionMAE(pretrained_model)

criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999))

# Run on a CUDA GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os

In [None]:
import torch

# Load pre-trained MAE classifier model
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model_path = "mae_classifier.pth"
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Remove a leading 'mae.' prefix if present. Only the prefix is stripped; a 'mae.'
# appearing later inside a key is left untouched (str.replace would remove it too).
mae_prefix = "mae."
new_state_dict = {
    (key[len(mae_prefix):] if key.startswith(mae_prefix) else key): value
    for key, value in state_dict.items()
}

# Remove classifier layer keys if they exist
new_state_dict = {k: v for k, v in new_state_dict.items() if not k.startswith("classifier.")}

# Load modified MAE model
from mae_model import MaskedAutoencoder  # Import your MAE model class

pretrained_model = MaskedAutoencoder()  # Initialize your MAE model
# strict=False tolerates mismatched keys; report them so a prefix/naming mismatch is
# visible instead of silently leaving the encoder at its random initialisation.
load_result = pretrained_model.load_state_dict(new_state_dict, strict=False)
print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys}")
print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys}")


In [None]:
import torch
import torch.nn as nn

# Super-Resolution Model using Transformer-based MAE Features
class SuperResolutionMAE(nn.Module):
    """Super-resolution network built on a pretrained MAE encoder.

    Encoder tokens are arranged on a square grid and upscaled 8x in total:
    two stride-2 transpose convolutions (2x each) followed by PixelShuffle(2),
    then a 3x3 refinement conv and a Tanh that bounds outputs to [-1, 1].

    Args:
        base_model: Object exposing an ``encoder`` module (adjust to your MAE structure).
        embed_dim: Channel width of the encoder tokens (768 by default).
    """

    def __init__(self, base_model, embed_dim=768):
        super().__init__()

        # Extract the MAE encoder (transformer backbone)
        self.feature_extractor = base_model.encoder

        # Upsampling module using Transpose Convolutions and PixelShuffle.
        # Layer layout unchanged so existing state_dict keys ("upsample.N.*") still match.
        self.upsample = nn.Sequential(
            nn.Conv2d(embed_dim, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 2x
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 2x
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 4, kernel_size=3, stride=1, padding=1),  # 4 channels for PixelShuffle
            nn.PixelShuffle(2),  # 4 channels -> 1 channel at 2x resolution
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),  # Final refinement layer
            nn.Tanh()
        )

    @staticmethod
    def _to_feature_map(x):
        """Return ``x`` as a (B, C, H, W) feature map.

        4-D inputs are returned unchanged. 3-D token sequences (B, N, C) are laid
        out row-major on a square sqrt(N) x sqrt(N) grid (196 tokens -> 14 x 14).
        NOTE(review): an encoder that prepends a CLS token or drops masked tokens
        will not yield a square token count — strip/handle those before this point.

        Raises:
            ValueError: if the rank is not 3 or 4, or N is not a positive perfect square.
        """
        if x.dim() == 4:
            return x
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D token tensor or 4-D feature map, got shape {tuple(x.shape)}")
        batch, num_tokens, channels = x.shape
        side = int(round(num_tokens ** 0.5))
        if side == 0 or side * side != num_tokens:
            raise ValueError(f"cannot arrange {num_tokens} tokens on a square grid")
        return x.transpose(1, 2).reshape(batch, channels, side, side)

    def forward(self, x):
        """Map an input batch to a single-channel image 8x the token-grid size."""
        x = self.feature_extractor(x)  # Extract features using MAE encoder
        x = self._to_feature_map(x)
        return self.upsample(x)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Initialize MAE-based Super-Resolution Model on the training device.
# The training/evaluation loops move every batch to `device`, so the model
# must live there too (previously only the first-draft model was moved).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SuperResolutionMAE(pretrained_model).to(device)

# Freeze early layers: the first half of the encoder's parameter tensors, in
# registration order (assumes registration order follows depth — confirm for your MAE).
encoder_params = list(model.feature_extractor.parameters())
for param in encoder_params[: len(encoder_params) // 2]:
    param.requires_grad = False

# Define L1 Loss and AdamW optimizer over the trainable parameters only, so the
# frozen encoder weights are never handed to the optimizer.
criterion = nn.L1Loss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

# Custom Dataset Class
class SuperResolutionDataset(Dataset):
    """Paired low-/high-resolution images stored as NumPy files.

    Reads ``<dataset_path>/LR`` and ``<dataset_path>/HR``. Files are paired by
    their position in sorted filename order: the i-th LR file is matched with
    the i-th HR file. Hidden entries (e.g. ``.DS_Store``) and sub-directories
    are ignored.

    Args:
        dataset_path: Root directory containing the ``LR`` and ``HR`` folders.

    Raises:
        ValueError: if LR and HR contain a different number of files, which would
            make the positional pairing silently misalign samples.
    """

    def __init__(self, dataset_path):
        self.lr_path = os.path.join(dataset_path, "LR")
        self.hr_path = os.path.join(dataset_path, "HR")
        self.lr_files = self._list_files(self.lr_path)
        self.hr_files = self._list_files(self.hr_path)
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR file count mismatch: {len(self.lr_files)} LR vs "
                f"{len(self.hr_files)} HR files in {dataset_path!r}"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of non-hidden regular files in *directory*."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load(path):
        """Load one array as a float32 tensor, adding a leading channel axis to 2-D images."""
        image = np.load(path).astype(np.float32)
        if image.ndim == 2:
            image = image[np.newaxis]  # Add channel dim
        return torch.from_numpy(image)

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Return the ``(lr_image, hr_image)`` tensor pair at position *idx*."""
        lr_image = self._load(os.path.join(self.lr_path, self.lr_files[idx]))
        hr_image = self._load(os.path.join(self.hr_path, self.hr_files[idx]))
        return lr_image, hr_image

# Load Dataset and split it 80/20 into train/test subsets.
dataset_path = "Dataset2/Dataset"
dataset = SuperResolutionDataset(dataset_path)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Seeded generator: the same images land in the test split on every run, so the
# reported PSNR/SSIM/MSE are reproducible across sessions.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Training Loop
def train_model(model, train_loader, criterion, optimizer, num_epochs=10, *, device=None):
    """Train ``model`` on ``(lr, hr)`` batches and print the mean loss of each epoch.

    Model outputs are bilinearly resized to the HR target's spatial size before
    the loss is computed, so the network's native output resolution need not
    match the ground truth exactly.

    Args:
        model: Network mapping an LR batch to an image batch.
        train_loader: Sized iterable of ``(lr_images, hr_images)`` tensor batches.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizer: Optimizer over the model's trainable parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter (CPU if it has none), so data and weights always agree.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` has no batches (the mean would divide by zero).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for lr_images, hr_images in train_loader:
            lr_images, hr_images = lr_images.to(device), hr_images.to(device)
            optimizer.zero_grad()
            outputs = model(lr_images)

            # Ensure output size matches target size
            outputs = F.interpolate(outputs, size=hr_images.shape[2:], mode='bilinear', align_corners=False)

            loss = criterion(outputs, hr_images)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        mean_loss = epoch_loss / num_batches
        history.append(mean_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.6f}")
    return history

# Fine-tune for 20 epochs, overriding train_model's default of 10.
train_model(
    model,
    train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=20,
)


Epoch [1/20], Loss: 0.040184
Epoch [2/20], Loss: 0.025419
Epoch [3/20], Loss: 0.022257
Epoch [4/20], Loss: 0.019837
Epoch [5/20], Loss: 0.018492
Epoch [6/20], Loss: 0.018193
Epoch [7/20], Loss: 0.017289
Epoch [8/20], Loss: 0.016779
Epoch [9/20], Loss: 0.016173
Epoch [10/20], Loss: 0.016080
Epoch [11/20], Loss: 0.015976
Epoch [12/20], Loss: 0.015278
Epoch [13/20], Loss: 0.015062
Epoch [14/20], Loss: 0.014190
Epoch [15/20], Loss: 0.013255
Epoch [16/20], Loss: 0.012234
Epoch [17/20], Loss: 0.011827
Epoch [18/20], Loss: 0.011456
Epoch [19/20], Loss: 0.011448
Epoch [20/20], Loss: 0.011247


In [None]:
# Evaluation Function
def evaluate_model(model, test_loader, *, device=None, psnr_fn=None, ssim_fn=None):
    """Print and return the mean per-image PSNR, SSIM and MSE of ``model`` on ``test_loader``.

    Outputs are bilinearly resized to the HR size, and only channel 0 of each
    image is scored. PSNR and SSIM are computed with ``data_range=1.0``.
    NOTE(review): data_range=1.0 assumes HR images lie in [0, 1]; the model's final
    Tanh can emit values down to -1 — confirm the dataset's value range.

    Args:
        model: Network mapping an LR batch to an image batch; left in eval mode.
        test_loader: Iterable of ``(lr_images, hr_images)`` tensor batches.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter (CPU if it has none).
        psnr_fn, ssim_fn: Metric callables ``fn(reference, prediction, data_range=...)``.
            Default to the module-level ``psnr`` / ``ssim`` names, which this notebook
            never defines — presumably skimage.metrics functions; confirm they are imported.

    Returns:
        dict with keys ``"psnr"``, ``"ssim"`` and ``"mse"`` holding the averages.

    Raises:
        ValueError: if ``test_loader`` yields no images.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # Conditional expressions only evaluate the chosen branch, so the globals are
    # looked up only when no metric callable is supplied.
    psnr_fn = psnr if psnr_fn is None else psnr_fn
    ssim_fn = ssim if ssim_fn is None else ssim_fn

    model.eval()
    total_psnr, total_ssim, total_mse = 0.0, 0.0, 0.0
    num_images = 0
    with torch.no_grad():
        for lr_images, hr_images in test_loader:
            lr_images, hr_images = lr_images.to(device), hr_images.to(device)
            outputs = model(lr_images)
            outputs = F.interpolate(outputs, size=hr_images.shape[2:], mode='bilinear', align_corners=False)

            outputs = outputs.cpu().numpy()
            hr_images = hr_images.cpu().numpy()

            for target, prediction in zip(hr_images[:, 0], outputs[:, 0]):
                total_psnr += psnr_fn(target, prediction, data_range=1.0)
                total_ssim += ssim_fn(target, prediction, data_range=1.0)
                total_mse += np.mean((target - prediction) ** 2)
                num_images += 1

    if num_images == 0:
        raise ValueError("test_loader yielded no images to evaluate")
    avg_psnr = total_psnr / num_images
    avg_ssim = total_ssim / num_images
    avg_mse = total_mse / num_images
    print(f"Average PSNR: {avg_psnr:.2f} dB, Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.6f}")
    return {"psnr": avg_psnr, "ssim": avg_ssim, "mse": avg_mse}

# Score the fine-tuned model on the held-out test split.
evaluate_model(model=model, test_loader=test_loader)


Average PSNR: 31.47 dB, Average SSIM: 0.9310, Average MSE: 0.000736
