In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

class CustomMAE(nn.Module):
    def __init__(self, img_size=64, patch_size=8, embed_dim=512, num_heads=8, depth=6):
        super(CustomMAE, self).__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim

        # Patch embedding (Linear Projection + Positional Encoding)
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        # Transformer Encoder (Inspired by ViT)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=1024, dropout=0.1),
            num_layers=depth
        )

        # Fully connected layer to reshape encoded features for decoding
        # self.fc = nn.Linear(embed_dim, embed_dim * (patch_size // 2) * (patch_size // 2))

        # Decoder with Up-Convolutions
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=4, stride=2, padding=1),  # 8x8 → 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 16x16 → 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 32x32 → 64x64
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # Output: 64x64
        )

    def forward(self, x):
        B, C, H, W = x.shape  # Bx1x64x64
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)  # Convert to patches
        patches = patches + self.pos_embed  # Add positional embeddings
        
        # masked_patches = patches * mask.unsqueeze(-1)  # Apply mask

        encoded = self.encoder(patches)  # Transformer encoding
        # print(encoded.shape)

        # Reshape for decoder
        # encoded = self.fc(encoded)  # Fully connected to reshape embedding
        encoded = encoded.view(B, self.embed_dim, H // self.patch_size, W // self.patch_size)

        # Decode using up-convolutions
        reconstructed_img = self.decoder(encoded)

        return reconstructed_img


In [4]:
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()

my_secret = user_secrets.get_secret("wandb_api_key") 

wandb.login(key=my_secret)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpechetti-1[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import wandb
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader



In [None]:
import wandb

In [13]:
# Initialize WandB
wandb.init(project="super-resolution-task")

[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250327_230908-jxbm1zct[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mfirm-vortex-3[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/pechetti-1/super-resolution-task[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/pechetti-1/super-resolution-task/runs/jxbm1zct[0m


In [14]:
# Custom Dataset
import torchvision.transforms as transforms
import cv2
from torch.utils.data import Dataset
class SuperResDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_file_list = [os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith('.npy')]
        self.hr_file_list = [os.path.join("/".join(f.split('/')[:-2]),"HR",f.split('/')[-1]) for f in self.lr_file_list]

        # print(self.lr_file_list[0],self.hr_file_list[0])
        self.lr_transform = transforms.Compose([
            # transforms.ToTensor(),  # Convert to Tensor (C, H, W)
            transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BICUBIC),  # Resize LR images
        ])

        self.hr_transform = transforms.Compose([
            # transforms.ToTensor(),  # Convert HR images to Tensor (C, H, W)
        ])

    def __len__(self):
        return len(self.lr_file_list)

    def __getitem__(self, idx):
        lr = np.load(self.lr_file_list[idx]).astype(np.float32)  # (75,75)
        hr = np.load(self.hr_file_list[idx]).astype(np.float32)  # (150,150)

        # Ensure shape (1, H, W) for grayscale images
        if len(lr.shape) == 2:
            lr = np.expand_dims(lr, axis=0)  # (1, 75, 75)
        if len(hr.shape) == 2:
            hr = np.expand_dims(hr, axis=0)  # (1, 150, 150)''

        # Apply transforms
        lr = self.lr_transform(torch.from_numpy(lr))
        hr = self.hr_transform(torch.from_numpy(hr))

        # Normalize images to [-1,1]
        lr = (lr - lr.min()) / (lr.max() - lr.min())  # Normalize
        hr = (hr - hr.min()) / (hr.max() - hr.min())  # Normalize

       

        return lr, hr

In [15]:
from torch.utils.data import random_split, DataLoader


# Define transformation
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load dataset
dataset = SuperResDataset("/kaggle/input/foundational-model-task-ml4sci/task4-b/Dataset/LR","/kaggle/input/foundational-model-task-ml4sci/task4-b/Dataset/HR", transform=transform)


# Train/Validation Split (90:10)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [16]:
class SuperResMAE(CustomMAE):
    def __init__(self, img_size=64, patch_size=8, embed_dim=512, num_heads=8, depth=6):
        super(SuperResMAE, self).__init__(img_size, patch_size, embed_dim, num_heads, depth)

        # Remove the last Conv2d from the original decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=4, stride=2, padding=1),  # 8x8 → 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 16x16 → 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 32x32 → 64x64
            nn.ReLU(),
        )

        # Additional up-convolutions and upsampling
        self.extra_upconv = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # 64x64 → 128x128
        self.upsample = nn.Upsample(size=(150, 150), mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)  # 128x128 → 150x150 (smoothing)

    def forward(self, x):
        encoded = self.encoder(self.patch_embed(x).flatten(2).transpose(1, 2) + self.pos_embed)
        encoded = encoded.view(x.shape[0], self.embed_dim, self.img_size // self.patch_size, self.img_size // self.patch_size)
        
        x = self.decoder(encoded)  # Decode to 64x64
        x = F.relu(self.extra_upconv(x))  # Upscale to 128x128
        x = self.upsample(x)  # Upscale to 150x150
        x = self.final_conv(x)  # Smooth artifacts

        return x


In [None]:
# Load Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SuperResMAE().to(device)
model.load_state_dict(torch.load("/kaggle/input/pretrained-mae/pytorch/default/1/pretrainedmae_epoch10.pth"), strict=False)  # Load weights

  model.load_state_dict(torch.load("/kaggle/input/pretrained-mae/pytorch/default/1/pretrainedmae_epoch10.pth"), strict=False)  # Load weights


_IncompatibleKeys(missing_keys=['extra_upconv.weight', 'extra_upconv.bias', 'final_conv.weight', 'final_conv.bias'], unexpected_keys=['decoder.6.weight', 'decoder.6.bias'])

In [18]:
# Freeze encoder initially
for param in model.patch_embed.parameters():
    param.requires_grad = False
for param in model.encoder.parameters():
    param.requires_grad = False

optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = nn.MSELoss()

# Training Loop
epochs = 50
for epoch in range(epochs):
    if epoch == 5:  # Unfreeze encoder after 5 epochs
        for param in model.patch_embed.parameters():
            param.requires_grad = True
        for param in model.encoder.parameters():
            param.requires_grad = True
        print("Encoder Unfrozen!")
        
    model.train()
    train_loss = 0
    for lr_imgs, hr_imgs in train_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        
        optimizer.zero_grad()
        outputs = model(lr_imgs)
        
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)
            val_loss += loss.item()

    # Log to WandB
    wandb.log({"epoch": epoch, "train_loss": train_loss / len(train_loader), "val_loss": val_loss / len(val_loader)})

    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save best model
    if epoch % 2 == 0:
        torch.save(model.state_dict(), f"superres_model_epoch{epoch}.pth")

# Save final model
torch.save(model.state_dict(), "superres_model_final.pth")
wandb.finish()

Epoch [1/50], Train Loss: 1.9971, Val Loss: 0.0941
Epoch [2/50], Train Loss: 0.7509, Val Loss: 0.0585
Epoch [3/50], Train Loss: 0.5271, Val Loss: 0.0356
Epoch [4/50], Train Loss: 0.4341, Val Loss: 0.0341
Epoch [5/50], Train Loss: 0.3835, Val Loss: 0.0294
Encoder Unfrozen!
Epoch [6/50], Train Loss: 0.4161, Val Loss: 0.0218
Epoch [7/50], Train Loss: 0.1882, Val Loss: 0.0154
Epoch [8/50], Train Loss: 0.1530, Val Loss: 0.0164
Epoch [9/50], Train Loss: 0.1313, Val Loss: 0.0114
Epoch [10/50], Train Loss: 0.1197, Val Loss: 0.0139
Epoch [11/50], Train Loss: 0.1140, Val Loss: 0.0130
Epoch [12/50], Train Loss: 0.1089, Val Loss: 0.0118
Epoch [13/50], Train Loss: 0.1007, Val Loss: 0.0158
Epoch [14/50], Train Loss: 0.0987, Val Loss: 0.0117
Epoch [15/50], Train Loss: 0.0986, Val Loss: 0.0110
Epoch [16/50], Train Loss: 0.0900, Val Loss: 0.0099
Epoch [17/50], Train Loss: 0.0878, Val Loss: 0.0116
Epoch [18/50], Train Loss: 0.0892, Val Loss: 0.0092
Epoch [19/50], Train Loss: 0.0854, Val Loss: 0.0105
Epo

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:      epoch ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
[34m[1mwandb[0m: train_loss ▅▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▅▇████████████████████
[34m[1mwandb[0m:   val_loss ▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▇███████████████████
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:      epoch 49
[34m[1mwandb[0m: train_loss 0.00639
[34m[1mwandb[0m:   val_loss 0.00632
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mfirm-vortex-3[0m at: [34m[4mhttps://wandb.ai/pechetti-1/super-resolution-task/runs/jxbm1zct[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/pechetti-1/super-resolution-task[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20250327_230908-jxbm1zct/logs[0m
