In [14]:
# Custom Dataset
import torchvision.transforms as transforms
import cv2
from torch.utils.data import Dataset
class SuperResDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_file_list = [os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith('.npy')]
        self.hr_file_list = [os.path.join("/".join(f.split('/')[:-2]),"HR",f.split('/')[-1]) for f in self.lr_file_list]

        # print(self.lr_file_list[0],self.hr_file_list[0])
        self.lr_transform = transforms.Compose([
            # transforms.ToTensor(),  # Convert to Tensor (C, H, W)
            transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BICUBIC),  # Resize LR images
        ])

        self.hr_transform = transforms.Compose([
            # transforms.ToTensor(),  # Convert HR images to Tensor (C, H, W)
        ])

    def __len__(self):
        return len(self.lr_file_list)

    def __getitem__(self, idx):
        lr = np.load(self.lr_file_list[idx]).astype(np.float32)  # (75,75)
        hr = np.load(self.hr_file_list[idx]).astype(np.float32)  # (150,150)

        # Ensure shape (1, H, W) for grayscale images
        if len(lr.shape) == 2:
            lr = np.expand_dims(lr, axis=0)  # (1, 75, 75)
        if len(hr.shape) == 2:
            hr = np.expand_dims(hr, axis=0)  # (1, 150, 150)''

        # Apply transforms
        lr = self.lr_transform(torch.from_numpy(lr))
        hr = self.hr_transform(torch.from_numpy(hr))

        # Normalize images to [-1,1]
        lr = (lr - lr.min()) / (lr.max() - lr.min())  # Normalize
        hr = (hr - hr.min()) / (hr.max() - hr.min())  # Normalize

       

        return lr, hr

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
from torch.utils.data import random_split, DataLoader


# Define transformation
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load dataset
dataset = SuperResDataset("/kaggle/input/foundational-model-task-ml4sci/task4-b/Dataset/LR","/kaggle/input/foundational-model-task-ml4sci/task4-b/Dataset/HR", transform=transform)


# Train/Validation Split (90:10)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

class CustomMAE(nn.Module):
    def __init__(self, img_size=64, patch_size=8, embed_dim=512, num_heads=8, depth=6):
        super(CustomMAE, self).__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim

        # Patch embedding (Linear Projection + Positional Encoding)
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        # Transformer Encoder (Inspired by ViT)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=1024, dropout=0.1),
            num_layers=depth
        )

        # Fully connected layer to reshape encoded features for decoding
        # self.fc = nn.Linear(embed_dim, embed_dim * (patch_size // 2) * (patch_size // 2))

        # Decoder with Up-Convolutions
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=4, stride=2, padding=1),  # 8x8 → 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 16x16 → 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 32x32 → 64x64
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # Output: 64x64
        )

    def forward(self, x):
        B, C, H, W = x.shape  # Bx1x64x64
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)  # Convert to patches
        patches = patches + self.pos_embed  # Add positional embeddings
        
        # masked_patches = patches * mask.unsqueeze(-1)  # Apply mask

        encoded = self.encoder(patches)  # Transformer encoding
        # print(encoded.shape)

        # Reshape for decoder
        # encoded = self.fc(encoded)  # Fully connected to reshape embedding
        encoded = encoded.view(B, self.embed_dim, H // self.patch_size, W // self.patch_size)

        # Decode using up-convolutions
        reconstructed_img = self.decoder(encoded)

        return reconstructed_img

In [6]:
# Load Pretrained MAE
pretrained_mae = CustomMAE()  # Load pre-trained weights if available
pretrained_mae.load_state_dict(torch.load("/kaggle/input/model-weights/pytorch/default/1/models/pretrainedmae_epoch10.pth"))

  pretrained_mae.load_state_dict(torch.load("/kaggle/input/model-weights/pytorch/default/1/models/pretrainedmae_epoch10.pth"))


<All keys matched successfully>

In [19]:
class SuperResMAE(CustomMAE):
    def __init__(self, img_size=64, patch_size=8, embed_dim=512, num_heads=8, depth=6):
        super(SuperResMAE, self).__init__(img_size, patch_size, embed_dim, num_heads, depth)

        # Remove the last Conv2d from the original decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=4, stride=2, padding=1),  # 8x8 → 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 16x16 → 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 32x32 → 64x64
            nn.ReLU(),
        )

        # Additional up-convolutions and upsampling
        self.extra_upconv = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # 64x64 → 128x128
        self.upsample = nn.Upsample(size=(150, 150), mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)  # 128x128 → 150x150 (smoothing)

    def forward(self, x):
        encoded = self.encoder(self.patch_embed(x).flatten(2).transpose(1, 2) + self.pos_embed)
        encoded = encoded.view(x.shape[0], self.embed_dim, self.img_size // self.patch_size, self.img_size // self.patch_size)
        
        x = self.decoder(encoded)  # Decode to 64x64
        x = F.relu(self.extra_upconv(x))  # Upscale to 128x128
        x = self.upsample(x)  # Upscale to 150x150
        x = self.final_conv(x)  # Smooth artifacts

        return x

In [20]:
model = SuperResMAE().to(device)
model.load_state_dict(torch.load("/kaggle/input/model-weights/pytorch/default/1/models/superres_model_epoch20.pth"))  # Load weights

  model.load_state_dict(torch.load("/kaggle/input/model-weights/pytorch/default/1/models/superres_model_epoch20.pth"))  # Load weights


<All keys matched successfully>

In [None]:
from skimage.metrics import structural_similarity as ssim
import math

def psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    return 20 * math.log10(1.0 / math.sqrt(mse))

model.eval()
total_psnr, total_ssim = 0, 0
test_loss=0
criterion = nn.MSELoss()

with torch.no_grad():
    for lr_imgs, hr_imgs in val_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)
        test_loss += loss.item()
        for i in range(outputs.shape[0]):
            output_img = outputs[i].cpu().numpy().squeeze()
            hr_img = hr_imgs[i].cpu().numpy().squeeze()

            total_psnr += psnr(output_img, hr_img)
            total_ssim += ssim(output_img, hr_img, data_range=1)
        to

print(f"Test PSNR: {total_psnr / len(val_loader):.2f}, SSIM: {total_ssim / len(val_loader):.4f}, MSE: {test_loss / len(val_loader)}")


Test PSNR: 618.11, SSIM: 15.3489, MSE: 0.00013130609884337034


In [None]:
import matplotlib.pyplot as plt

# Function to display images
def display_images(lr_img, hr_img, sr_img, psnr_value):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(lr_img.squeeze(), cmap='gray')
    axes[0].set_title("Low-Resolution (LR)")
    axes[0].axis('off')

    axes[1].imshow(hr_img.squeeze(), cmap='gray')
    axes[1].set_title("High-Resolution (HR)")
    axes[1].axis('off')

    axes[2].imshow(sr_img.squeeze(), cmap='gray')
    axes[2].set_title(f"Super-Resolved (SR)\nPSNR: {psnr_value:.2f}")
    axes[2].axis('off')

    plt.show()

# Display a few sample images
model.eval()
with torch.no_grad():
    for lr_imgs, hr_imgs in val_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        outputs = model(lr_imgs)

        for i in range(min(3, outputs.shape[0])):  # Display up to 3 samples
            lr_img = lr_imgs[i].cpu().numpy()
            hr_img = hr_imgs[i].cpu().numpy()
            sr_img = outputs[i].cpu().numpy()

            psnr_value = psnr(sr_img.squeeze(), hr_img.squeeze())
            display_images(lr_img, hr_img, sr_img, psnr_value)
        break  # Only display one batch