## Cell 1: Mount Drive & Install Libraries

In [1]:
from google.colab import drive
import os

# Mount your Google Drive
drive.mount('/content/drive')

# Install required libraries
!pip install -q einops scikit-image pandas opencv-python

Mounted at /content/drive


In [2]:
import os
import time
import glob

# --- 1. SET THE CORRECT SOURCE PATH ---
# Make sure this path to your dataset folder on Google Drive is correct.
gdrive_dataset_path = "/content/drive/MyDrive/Datasets" # <-- UPDATE THIS PATH IF NEEDED

# --- 2. PERFORM THE COPY TO RAM ---
print(f"Copying dataset from '{gdrive_dataset_path}' to RAM disk (/dev/shm)...")
start_time = time.time()

# MODIFIED: The destination is now the RAM disk
local_dataset_path = "/dev/shm/datasets"

if os.path.exists(gdrive_dataset_path):
    !rsync -a --info=progress2 {gdrive_dataset_path} {local_dataset_path}
else:
    print(f"❌ ERROR: Source path not found: {gdrive_dataset_path}")

end_time = time.time()
print(f"\nCopy operation finished in {end_time - start_time:.2f} seconds.")

# --- 3. VERIFY THE COPY ---
print("\nVerifying copied files in RAM...")
local_base = os.path.join(local_dataset_path, os.path.basename(gdrive_dataset_path.rstrip('/')))

div2k_path = os.path.join(local_base, 'DIV2K/HR')
flickr2k_path = os.path.join(local_base, 'Flickr2K/HR')
set14_path = os.path.join(local_base, 'Set14/image_SRF_4')

try:
    num_div2k = len(glob.glob(os.path.join(div2k_path, '*.png')))
    num_flickr2k = len(glob.glob(os.path.join(flickr2k_path, '*.png')))
    num_set14 = len(glob.glob(os.path.join(set14_path, '*.png')))

    print(f"✅ Found {num_div2k} images in DIV2K.")
    print(f"✅ Found {num_flickr2k} images in Flickr2K.")
    print(f"✅ Found {num_set14} images in Set14.")

    if num_div2k == 0 or num_flickr2k == 0:
        print("\n⚠️ WARNING: Training dataset appears empty. Check your paths again.")
    else:
        print("\n👍 Verification successful! You can now proceed to training.")

except Exception as e:
    print(f"\n❌ VERIFICATION FAILED: {e}")

Copying dataset from '/content/drive/MyDrive/Datasets' to RAM disk (/dev/shm)...
              0 100%    0.00kB/s    0:00:00 (xfr#0, to-chk=0/1)

Copy operation finished in 0.49 seconds.

Verifying copied files in RAM...
✅ Found 800 images in DIV2K.
✅ Found 2650 images in Flickr2K.
✅ Found 28 images in Set14.

👍 Verification successful! You can now proceed to training.


## Cell 2: Project Setup and Configuration

In [3]:
import torch
import os

class Config:
    # --- Project and Directory Paths (Version 5) ---
    DRIVE_PREFIX = '/content/drive/MyDrive/'
    PROJECT_DIR = os.path.join(DRIVE_PREFIX, 'FusionSR_ClassicalSR_v5')
    CHECKPOINTS_DIR = os.path.join(PROJECT_DIR, 'checkpoints')
    BEST_MODEL_DIR = os.path.join(PROJECT_DIR, 'best_models')

    # --- Dataset Paths ---
    # Using local copy in RAM for max speed
    LOCAL_DATASET_BASE = '/dev/shm/datasets/Datasets/'
    TRAIN_HR_DIR_DIV2K = os.path.join(LOCAL_DATASET_BASE, 'DIV2K/HR')
    TRAIN_HR_DIR_FLICKR = os.path.join(LOCAL_DATASET_BASE, 'Flickr2K/HR')
    VAL_HR_DIR = os.path.join(LOCAL_DATASET_BASE, 'Set14/image_SRF_4')

    # --- Model Architecture ---
    UPSCALE_FACTOR = 4
    BASE_DIM = 180
    NUM_BLOCKS = 8
    NUM_HEADS = 6
    WINDOW_SIZE = 8

    # --- Training Hyperparameters (MODIFIED) ---
    NUM_EPOCHS = 50
    BATCH_SIZE = 96  # REVERTED: Back to original large batch size
    LEARNING_RATE = 2e-4
    VGG_LOSS_WEIGHT = 0.1
    HR_PATCH_SIZE = 256
    GRADIENT_CLIP_VAL = 1.0

    # --- Data Loading ---
    NUM_WORKERS = 2 # Keeping this at a safe, static number

# Create project directories for v5
os.makedirs(Config.CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(Config.BEST_MODEL_DIR, exist_ok=True)

# Set up device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

print(f"Project directory (on Drive): {Config.PROJECT_DIR}")
print(f"Dataset path (in RAM): {Config.LOCAL_DATASET_BASE}")
print(f"Using device: {DEVICE}")

Project directory (on Drive): /content/drive/MyDrive/FusionSR_ClassicalSR_v5
Dataset path (in RAM): /dev/shm/datasets/Datasets/
Using device: cuda


## Cell 3: Data Pipeline (for Classical SR)

In [4]:
import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random

class ClassicalSRDataset(Dataset):
    def __init__(self, hr_dirs, hr_patch_size, is_train=True):
        super(ClassicalSRDataset, self).__init__()
        self.hr_paths = []
        for d in hr_dirs:
            # Handle both DIV2K and Set14/other naming patterns
            self.hr_paths.extend(glob.glob(os.path.join(d, '*.png')))
            self.hr_paths.extend(glob.glob(os.path.join(d, '*_HR.png')))

        self.hr_patch_size = hr_patch_size
        self.is_train = is_train

        # Define the transformations
        self.hr_crop = transforms.RandomCrop(self.hr_patch_size) if is_train else transforms.CenterCrop(self.hr_patch_size)
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.hr_paths)

    def __getitem__(self, idx):
        hr_image = Image.open(self.hr_paths[idx]).convert('RGB')
        hr_patch = self.hr_crop(hr_image)

        # ADDED: Data Augmentation for training
        if self.is_train:
            if random.random() > 0.5: # Random horizontal flip
                hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            if random.random() > 0.5: # Random vertical flip
                hr_patch = hr_patch.transpose(Image.FLIP_TOP_BOTTOM)
            if random.random() > 0.5: # Random 90-degree rotation
                hr_patch = hr_patch.rotate(90)

        return self.to_tensor(hr_patch)

## Cell 4: 🏗️ FusionSR Model Architecture (Complete)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision import models

# --- Helper Modules ---
class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
        )
    def forward(self, x):
        return x + self.conv_block(x)

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

# --- Core SwinIR-style Components ---
class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # RE-ENABLED: Relative Position Bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # RE-ENABLED: Add relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            N, N, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)

# The rest of this cell (SwinBlock, FusionSR, etc.) is the same as the v4 cell you already have.
# I've included it here for completeness.
class SwinBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size, shift_size=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.window_size = window_size
        self.shift_size = shift_size
    def forward(self, x, H, W):
        B, L, C = x.shape; shortcut = x; x = self.norm1(x); x = x.view(B, H, W, C)
        if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else: shifted_x = x
        x_windows = window_partition(shifted_x, self.window_size); x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_mask = None
        if self.shift_size > 0:
            img_mask = torch.zeros((1, H, W, 1), device=x.device); h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)); w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)); cnt = 0
            for h in h_slices:
                for w in w_slices: img_mask[:, h, w, :] = cnt; cnt += 1
            mask_windows = window_partition(img_mask, self.window_size); mask_windows = mask_windows.view(-1, self.window_size * self.window_size); attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2); attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        attn_windows = self.attn(x_windows, mask=attn_mask); shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else: x = shifted_x
        x = x.view(B, L, C); x = shortcut + x; x = x + self.mlp(self.norm2(x)); return x
class FusionSR(nn.Module):
    def __init__(self, Cfg):
        super().__init__(); self.conv_first = nn.Sequential(nn.Conv2d(3, Cfg.BASE_DIM // 2, 3, 1, 1), nn.GELU(), ResidualBlock(Cfg.BASE_DIM // 2), ResidualBlock(Cfg.BASE_DIM // 2), nn.Conv2d(Cfg.BASE_DIM // 2, Cfg.BASE_DIM, 1, 1, 0)); self.body = nn.ModuleList([SwinBlock(dim=Cfg.BASE_DIM, num_heads=Cfg.NUM_HEADS, window_size=Cfg.WINDOW_SIZE, shift_size=0 if (i % 2 == 0) else Cfg.WINDOW_SIZE // 2) for i in range(Cfg.NUM_BLOCKS)]); self.feature_fusion = nn.Sequential(nn.Conv2d(Cfg.BASE_DIM * 2, Cfg.BASE_DIM, 1, 1, 0), nn.GELU(), nn.Conv2d(Cfg.BASE_DIM, Cfg.BASE_DIM, 3, 1, 1)); self.upsample = nn.Sequential(nn.Conv2d(Cfg.BASE_DIM, 3 * (Cfg.UPSCALE_FACTOR ** 2), 3, 1, 1), nn.PixelShuffle(Cfg.UPSCALE_FACTOR))
    def forward(self, x):
        B, C, H, W = x.shape; shallow_features = self.conv_first(x); deep_features_in = rearrange(shallow_features, 'b c h w -> b (h w) c');
        for block in self.body: deep_features_in = block(deep_features_in, H, W)
        deep_features = rearrange(deep_features_in, 'b (h w) c -> b c h w', h=H, w=W); fused_features = self.feature_fusion(torch.cat((shallow_features, deep_features), dim=1)); return self.upsample(fused_features)
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__(); vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:36].eval().to(DEVICE);
        for param in vgg.parameters(): param.requires_grad = False
        self.vgg = vgg; self.loss = nn.L1Loss()
    def forward(self, x, y): return self.loss(self.vgg(x), self.vgg(y))
class CharbonnierLoss(nn.Module):
    def __init__(self, eps=1e-3): super(CharbonnierLoss, self).__init__(); self.eps = eps
    def forward(self, x, y): diff = x - y; loss = torch.mean(torch.sqrt(diff * diff + self.eps**2)); return loss

In [6]:
# !pip install -q torch-ema


## Cell 5: 🚀 Model Training (Corrected and Complete)

In [8]:
import torch.optim as optim
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import numpy as np

def train_model():
    # --- Dataloaders ---
    train_dirs = [Config.TRAIN_HR_DIR_DIV2K, Config.TRAIN_HR_DIR_FLICKR]
    train_dataset = ClassicalSRDataset(hr_dirs=train_dirs, hr_patch_size=Config.HR_PATCH_SIZE, is_train=True)
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=Config.NUM_WORKERS, pin_memory=True)
    val_dataset = ClassicalSRDataset(hr_dirs=[Config.VAL_HR_DIR], hr_patch_size=Config.HR_PATCH_SIZE, is_train=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)
    gpu_downsampler = transforms.Resize(size=Config.HR_PATCH_SIZE // Config.UPSCALE_FACTOR, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True).to(DEVICE)

    # --- Model, Optimizer, Loss ---
    model = FusionSR(Config).to(DEVICE)
    model = torch.compile(model)

    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.NUM_EPOCHS, eta_min=1e-6)
    criterion_pix = CharbonnierLoss().to(DEVICE)
    criterion_vgg = VGGPerceptualLoss().to(DEVICE)
    scaler = GradScaler()

    # --- MODIFIED: Checkpoint Logic for v5 ---
    start_epoch = 0
    best_psnr = 0.0
    BEST_MODEL_PATH = os.path.join(Config.BEST_MODEL_DIR, "fusionsr_classical_best_v5.pth")

    # Re-enabled logic to load the best model if it exists
    if os.path.exists(BEST_MODEL_PATH):
        print(f"✅ Found existing model. Loading weights from: {BEST_MODEL_PATH}")
        # Use ._orig_mod to access the model inside the torch.compile wrapper
        model._orig_mod.load_state_dict(torch.load(BEST_MODEL_PATH))
        print("Resuming training...")
    else:
        print("Starting training from scratch for v5 model.")


    print(f"--- Starting training cycle (Total Epochs: {Config.NUM_EPOCHS}) ---")
    for epoch in range(start_epoch, Config.NUM_EPOCHS):
        model.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")

        for hr_imgs in progress_bar:
            hr_imgs = hr_imgs.to(DEVICE, non_blocking=True)
            lr_imgs = gpu_downsampler(hr_imgs)

            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                sr_imgs = model(lr_imgs)
                loss_pix = criterion_pix(sr_imgs, hr_imgs)
                loss_vgg = criterion_vgg(sr_imgs, hr_imgs)
                loss = loss_pix + Config.VGG_LOSS_WEIGHT * loss_vgg

            scaler.scale(loss).backward()

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRADIENT_CLIP_VAL)

            scaler.step(optimizer)
            scaler.update()

            progress_bar.set_postfix({'Loss': f"{loss.item():.4f}", 'Pix': f"{loss_pix.item():.4f}"})

        # --- Validation ---
        model.eval()
        current_psnr, current_ssim = 0.0, 0.0
        with torch.no_grad():
            for hr_tensor in val_loader:
                hr_tensor = hr_tensor.to(DEVICE, non_blocking=True)
                lr_tensor = gpu_downsampler(hr_tensor)
                with autocast(device_type='cuda', dtype=torch.bfloat16):
                    sr_tensor = model(lr_tensor)

                sr_tensor = sr_tensor.clamp(0, 1).cpu().float()
                sr_np = sr_tensor.squeeze(0).permute(1, 2, 0).numpy()
                hr_np = hr_tensor.cpu().squeeze(0).permute(1, 2, 0).numpy()

                current_psnr += peak_signal_noise_ratio(hr_np, sr_np, data_range=1.0)
                current_ssim += structural_similarity(hr_np, sr_np, channel_axis=2, data_range=1.0)

        avg_psnr = current_psnr / len(val_loader)
        avg_ssim = current_ssim / len(val_loader)

        scheduler.step()
        print(f"Epoch {epoch+1} | Val PSNR: {avg_psnr:.4f} | Val SSIM: {avg_ssim:.4f} | LR: {optimizer.param_groups[0]['lr']:.1e}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model._orig_mod.state_dict(), BEST_MODEL_PATH)
            print(f"✅ New best model saved with PSNR: {best_psnr:.4f}")

    print("--- Training Finished ---")

train_model()

✅ Found existing model. Loading weights from: /content/drive/MyDrive/FusionSR_ClassicalSR_v5/best_models/fusionsr_classical_best_v5.pth
Resuming training...
--- Starting training cycle (Total Epochs: 50) ---


Epoch 1/50: 100%|██████████| 36/36 [09:24<00:00, 15.68s/it, Loss=0.0881, Pix=0.0793]


Epoch 1 | Val PSNR: 19.0733 | Val SSIM: 0.3679 | LR: 2.0e-04
✅ New best model saved with PSNR: 19.0733


Epoch 2/50: 100%|██████████| 36/36 [03:46<00:00,  6.30s/it, Loss=0.0813, Pix=0.0721]


Epoch 2 | Val PSNR: 19.9845 | Val SSIM: 0.4954 | LR: 2.0e-04
✅ New best model saved with PSNR: 19.9845


Epoch 3/50: 100%|██████████| 36/36 [03:44<00:00,  6.23s/it, Loss=0.0767, Pix=0.0683]


Epoch 3 | Val PSNR: 20.9168 | Val SSIM: 0.5209 | LR: 2.0e-04
✅ New best model saved with PSNR: 20.9168


Epoch 4/50: 100%|██████████| 36/36 [03:34<00:00,  5.96s/it, Loss=0.0800, Pix=0.0711]


Epoch 4 | Val PSNR: 20.2046 | Val SSIM: 0.3541 | LR: 2.0e-04


Epoch 5/50: 100%|██████████| 36/36 [03:37<00:00,  6.05s/it, Loss=0.0694, Pix=0.0607]


Epoch 5 | Val PSNR: 21.6787 | Val SSIM: 0.5973 | LR: 2.0e-04
✅ New best model saved with PSNR: 21.6787


Epoch 6/50: 100%|██████████| 36/36 [03:47<00:00,  6.32s/it, Loss=0.0603, Pix=0.0528]


Epoch 6 | Val PSNR: 22.1951 | Val SSIM: 0.5430 | LR: 1.9e-04
✅ New best model saved with PSNR: 22.1951


Epoch 7/50: 100%|██████████| 36/36 [03:41<00:00,  6.16s/it, Loss=0.0596, Pix=0.0520]


Epoch 7 | Val PSNR: 22.1651 | Val SSIM: 0.5805 | LR: 1.9e-04


Epoch 8/50: 100%|██████████| 36/36 [03:50<00:00,  6.41s/it, Loss=0.0619, Pix=0.0539]


Epoch 8 | Val PSNR: 22.6605 | Val SSIM: 0.5240 | LR: 1.9e-04
✅ New best model saved with PSNR: 22.6605


Epoch 9/50: 100%|██████████| 36/36 [03:54<00:00,  6.52s/it, Loss=0.0573, Pix=0.0496]


Epoch 9 | Val PSNR: 23.2008 | Val SSIM: 0.6129 | LR: 1.8e-04
✅ New best model saved with PSNR: 23.2008


Epoch 10/50: 100%|██████████| 36/36 [03:47<00:00,  6.33s/it, Loss=0.0560, Pix=0.0482]


Epoch 10 | Val PSNR: 23.1786 | Val SSIM: 0.5112 | LR: 1.8e-04


Epoch 11/50: 100%|██████████| 36/36 [03:51<00:00,  6.44s/it, Loss=0.0502, Pix=0.0432]


Epoch 11 | Val PSNR: 23.4524 | Val SSIM: 0.5513 | LR: 1.8e-04
✅ New best model saved with PSNR: 23.4524


Epoch 12/50: 100%|██████████| 36/36 [03:53<00:00,  6.50s/it, Loss=0.0473, Pix=0.0404]


Epoch 12 | Val PSNR: 23.9029 | Val SSIM: 0.6705 | LR: 1.7e-04
✅ New best model saved with PSNR: 23.9029


Epoch 13/50: 100%|██████████| 36/36 [03:52<00:00,  6.47s/it, Loss=0.0452, Pix=0.0382]


Epoch 13 | Val PSNR: 24.1628 | Val SSIM: 0.7107 | LR: 1.7e-04
✅ New best model saved with PSNR: 24.1628


Epoch 14/50: 100%|██████████| 36/36 [03:51<00:00,  6.42s/it, Loss=0.0479, Pix=0.0410]


Epoch 14 | Val PSNR: 24.2125 | Val SSIM: 0.6480 | LR: 1.6e-04
✅ New best model saved with PSNR: 24.2125


Epoch 15/50: 100%|██████████| 36/36 [03:53<00:00,  6.48s/it, Loss=0.0497, Pix=0.0422]


Epoch 15 | Val PSNR: 23.7409 | Val SSIM: 0.5584 | LR: 1.6e-04


Epoch 16/50: 100%|██████████| 36/36 [03:48<00:00,  6.34s/it, Loss=0.0668, Pix=0.0598]


Epoch 16 | Val PSNR: 23.0998 | Val SSIM: 0.4846 | LR: 1.5e-04


Epoch 17/50: 100%|██████████| 36/36 [03:58<00:00,  6.63s/it, Loss=0.0538, Pix=0.0473]


Epoch 17 | Val PSNR: 24.0876 | Val SSIM: 0.6152 | LR: 1.5e-04


Epoch 18/50: 100%|██████████| 36/36 [03:54<00:00,  6.50s/it, Loss=0.0467, Pix=0.0396]


Epoch 18 | Val PSNR: 24.5718 | Val SSIM: 0.7264 | LR: 1.4e-04
✅ New best model saved with PSNR: 24.5718


Epoch 19/50: 100%|██████████| 36/36 [03:59<00:00,  6.66s/it, Loss=0.0405, Pix=0.0335]


Epoch 19 | Val PSNR: 24.6300 | Val SSIM: 0.7073 | LR: 1.4e-04
✅ New best model saved with PSNR: 24.6300


Epoch 20/50: 100%|██████████| 36/36 [03:51<00:00,  6.42s/it, Loss=0.0404, Pix=0.0336]


Epoch 20 | Val PSNR: 24.6955 | Val SSIM: 0.6620 | LR: 1.3e-04
✅ New best model saved with PSNR: 24.6955


Epoch 21/50: 100%|██████████| 36/36 [03:57<00:00,  6.59s/it, Loss=0.0413, Pix=0.0344]


Epoch 21 | Val PSNR: 24.8489 | Val SSIM: 0.7425 | LR: 1.3e-04
✅ New best model saved with PSNR: 24.8489


Epoch 22/50: 100%|██████████| 36/36 [03:58<00:00,  6.64s/it, Loss=0.0434, Pix=0.0372]


Epoch 22 | Val PSNR: 24.8435 | Val SSIM: 0.7373 | LR: 1.2e-04


Epoch 23/50: 100%|██████████| 36/36 [03:50<00:00,  6.40s/it, Loss=0.0391, Pix=0.0325]


Epoch 23 | Val PSNR: 24.9184 | Val SSIM: 0.7206 | LR: 1.1e-04
✅ New best model saved with PSNR: 24.9184


Epoch 24/50: 100%|██████████| 36/36 [03:51<00:00,  6.43s/it, Loss=0.0429, Pix=0.0364]


Epoch 24 | Val PSNR: 25.0268 | Val SSIM: 0.7335 | LR: 1.1e-04
✅ New best model saved with PSNR: 25.0268


Epoch 25/50: 100%|██████████| 36/36 [03:50<00:00,  6.40s/it, Loss=0.0388, Pix=0.0327]


Epoch 25 | Val PSNR: 25.0176 | Val SSIM: 0.6972 | LR: 1.0e-04


Epoch 26/50: 100%|██████████| 36/36 [04:01<00:00,  6.70s/it, Loss=0.0392, Pix=0.0328]


Epoch 26 | Val PSNR: 25.0786 | Val SSIM: 0.7438 | LR: 9.4e-05
✅ New best model saved with PSNR: 25.0786


Epoch 27/50: 100%|██████████| 36/36 [03:47<00:00,  6.32s/it, Loss=0.0424, Pix=0.0363]


Epoch 27 | Val PSNR: 24.9408 | Val SSIM: 0.7426 | LR: 8.8e-05


Epoch 28/50: 100%|██████████| 36/36 [03:46<00:00,  6.29s/it, Loss=0.0428, Pix=0.0363]


Epoch 28 | Val PSNR: 25.0334 | Val SSIM: 0.7430 | LR: 8.2e-05


Epoch 29/50: 100%|██████████| 36/36 [03:50<00:00,  6.41s/it, Loss=0.0403, Pix=0.0334]


Epoch 29 | Val PSNR: 25.1534 | Val SSIM: 0.7508 | LR: 7.6e-05
✅ New best model saved with PSNR: 25.1534


Epoch 30/50: 100%|██████████| 36/36 [03:43<00:00,  6.21s/it, Loss=0.0422, Pix=0.0358]


Epoch 30 | Val PSNR: 25.1381 | Val SSIM: 0.7199 | LR: 7.0e-05


Epoch 31/50: 100%|██████████| 36/36 [03:51<00:00,  6.43s/it, Loss=0.0405, Pix=0.0341]


Epoch 31 | Val PSNR: 25.1678 | Val SSIM: 0.7131 | LR: 6.4e-05
✅ New best model saved with PSNR: 25.1678


Epoch 32/50: 100%|██████████| 36/36 [03:45<00:00,  6.26s/it, Loss=0.0424, Pix=0.0357]


Epoch 32 | Val PSNR: 25.2053 | Val SSIM: 0.7324 | LR: 5.8e-05
✅ New best model saved with PSNR: 25.2053


Epoch 33/50: 100%|██████████| 36/36 [03:52<00:00,  6.45s/it, Loss=0.0398, Pix=0.0333]


Epoch 33 | Val PSNR: 25.2373 | Val SSIM: 0.7505 | LR: 5.3e-05
✅ New best model saved with PSNR: 25.2373


Epoch 34/50: 100%|██████████| 36/36 [03:52<00:00,  6.45s/it, Loss=0.0387, Pix=0.0322]


Epoch 34 | Val PSNR: 25.2491 | Val SSIM: 0.7506 | LR: 4.7e-05
✅ New best model saved with PSNR: 25.2491


Epoch 35/50: 100%|██████████| 36/36 [03:58<00:00,  6.63s/it, Loss=0.0382, Pix=0.0318]


Epoch 35 | Val PSNR: 25.2411 | Val SSIM: 0.7241 | LR: 4.2e-05


Epoch 36/50: 100%|██████████| 36/36 [03:50<00:00,  6.40s/it, Loss=0.0380, Pix=0.0316]


Epoch 36 | Val PSNR: 25.2653 | Val SSIM: 0.7538 | LR: 3.7e-05
✅ New best model saved with PSNR: 25.2653


Epoch 37/50: 100%|██████████| 36/36 [03:53<00:00,  6.48s/it, Loss=0.0407, Pix=0.0340]


Epoch 37 | Val PSNR: 25.2772 | Val SSIM: 0.7447 | LR: 3.2e-05
✅ New best model saved with PSNR: 25.2772


Epoch 38/50: 100%|██████████| 36/36 [04:02<00:00,  6.74s/it, Loss=0.0363, Pix=0.0300]


Epoch 38 | Val PSNR: 25.2861 | Val SSIM: 0.7540 | LR: 2.8e-05
✅ New best model saved with PSNR: 25.2861


Epoch 39/50: 100%|██████████| 36/36 [03:45<00:00,  6.28s/it, Loss=0.0404, Pix=0.0342]


Epoch 39 | Val PSNR: 25.3008 | Val SSIM: 0.7557 | LR: 2.4e-05
✅ New best model saved with PSNR: 25.3008


Epoch 40/50: 100%|██████████| 36/36 [03:50<00:00,  6.39s/it, Loss=0.0332, Pix=0.0268]


Epoch 40 | Val PSNR: 25.2977 | Val SSIM: 0.7518 | LR: 2.0e-05


Epoch 41/50: 100%|██████████| 36/36 [03:47<00:00,  6.32s/it, Loss=0.0375, Pix=0.0309]


Epoch 41 | Val PSNR: 25.3111 | Val SSIM: 0.7530 | LR: 1.6e-05
✅ New best model saved with PSNR: 25.3111


Epoch 42/50: 100%|██████████| 36/36 [03:53<00:00,  6.50s/it, Loss=0.0431, Pix=0.0361]


Epoch 42 | Val PSNR: 25.3165 | Val SSIM: 0.7550 | LR: 1.3e-05
✅ New best model saved with PSNR: 25.3165


Epoch 43/50: 100%|██████████| 36/36 [03:50<00:00,  6.40s/it, Loss=0.0365, Pix=0.0306]


Epoch 43 | Val PSNR: 25.3185 | Val SSIM: 0.7540 | LR: 1.0e-05
✅ New best model saved with PSNR: 25.3185


Epoch 44/50: 100%|██████████| 36/36 [03:45<00:00,  6.25s/it, Loss=0.0415, Pix=0.0353]


Epoch 44 | Val PSNR: 25.3223 | Val SSIM: 0.7541 | LR: 8.0e-06
✅ New best model saved with PSNR: 25.3223


Epoch 45/50: 100%|██████████| 36/36 [03:47<00:00,  6.32s/it, Loss=0.0349, Pix=0.0291]


Epoch 45 | Val PSNR: 25.3273 | Val SSIM: 0.7545 | LR: 5.9e-06
✅ New best model saved with PSNR: 25.3273


Epoch 46/50: 100%|██████████| 36/36 [03:48<00:00,  6.35s/it, Loss=0.0366, Pix=0.0300]


Epoch 46 | Val PSNR: 25.3299 | Val SSIM: 0.7539 | LR: 4.1e-06
✅ New best model saved with PSNR: 25.3299


Epoch 47/50: 100%|██████████| 36/36 [03:54<00:00,  6.52s/it, Loss=0.0349, Pix=0.0284]


Epoch 47 | Val PSNR: 25.3295 | Val SSIM: 0.7535 | LR: 2.8e-06


Epoch 48/50: 100%|██████████| 36/36 [03:44<00:00,  6.25s/it, Loss=0.0405, Pix=0.0337]


Epoch 48 | Val PSNR: 25.3303 | Val SSIM: 0.7509 | LR: 1.8e-06
✅ New best model saved with PSNR: 25.3303


Epoch 49/50: 100%|██████████| 36/36 [03:46<00:00,  6.29s/it, Loss=0.0367, Pix=0.0297]


Epoch 49 | Val PSNR: 25.3352 | Val SSIM: 0.7546 | LR: 1.2e-06
✅ New best model saved with PSNR: 25.3352


Epoch 50/50: 100%|██████████| 36/36 [03:42<00:00,  6.18s/it, Loss=0.0411, Pix=0.0341]


Epoch 50 | Val PSNR: 25.3351 | Val SSIM: 0.7540 | LR: 1.0e-06
--- Training Finished ---


## Cell 6: 📊 Model Benchmarking (Complete)

In [10]:
import pandas as pd
from tqdm import tqdm
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

# Re-define the ClassicalSRDataset here for a self-contained cell
class BenchmarkSRDataset(Dataset):
    def __init__(self, hr_dirs, upscale_factor=4):
        self.hr_paths = []
        for d in hr_dirs:
            self.hr_paths.extend(glob.glob(os.path.join(d, '*.png')))
            self.hr_paths.extend(glob.glob(os.path.join(d, '*_HR.png')))

        self.upscale_factor = upscale_factor
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.hr_paths)

    def __getitem__(self, idx):
        hr_image = Image.open(self.hr_paths[idx]).convert('RGB')

        w, h = hr_image.size
        crop_w = w - (w % self.upscale_factor)
        crop_h = h - (h % self.upscale_factor)
        hr_image = transforms.functional.center_crop(hr_image, (crop_h, crop_w))

        lr_image = hr_image.resize(
            (hr_image.width // self.upscale_factor, hr_image.height // self.upscale_factor),
            Image.BICUBIC
        )
        return self.to_tensor(lr_image), self.to_tensor(hr_image)

def benchmark_sr_model():
    BEST_MODEL_PATH = os.path.join(Config.BEST_MODEL_DIR, "fusionsr_classical_best_v5.pth")
    if not os.path.exists(BEST_MODEL_PATH):
        print(f"Best model not found at {BEST_MODEL_PATH}. Please train the model first.")
        return

    model = FusionSR(Config).to(DEVICE)
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
    model.eval()

    BENCHMARK_DATASETS = {
        "Set14": Config.VAL_HR_DIR,
    }

    results = []
    print("\n--- Starting Benchmarks on Best Model ---")

    for name, path in BENCHMARK_DATASETS.items():
        if not os.path.exists(path):
            print(f"\nPath not found for {name}: {path}. Skipping.")
            continue

        dataset = BenchmarkSRDataset(hr_dirs=[path], upscale_factor=Config.UPSCALE_FACTOR)
        if len(dataset) == 0:
            print(f"\nNo images found for {name} at {path}. Skipping.")
            continue

        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        psnr_scores, ssim_scores = [], []

        with torch.no_grad():
            for lr_tensor, hr_tensor in tqdm(loader, desc=f"Benchmarking on {name}"):
                lr_tensor = lr_tensor.to(DEVICE)
                sr_tensor = model(lr_tensor).clamp(0, 1).cpu()
                sr_np = sr_tensor.squeeze(0).permute(1, 2, 0).numpy()
                hr_np = hr_tensor.squeeze(0).permute(1, 2, 0).numpy()
                psnr_scores.append(peak_signal_noise_ratio(hr_np, sr_np, data_range=1.0))
                ssim_scores.append(structural_similarity(hr_np, sr_np, multichannel=True, data_range=1.0, channel_axis=2))

        results.append({
            "Dataset": name,
            "PSNR (dB)": f"{np.mean(psnr_scores):.2f}",
            "SSIM": f"{np.mean(ssim_scores):.4f}"
        })

    if results:
        results_df = pd.DataFrame(results)
        print("\n\n--- FINAL BENCHMARK RESULTS ---")
        print(results_df.to_string(index=False))
    else:
        print("\n--- No benchmarks were run. Please check your dataset paths. ---")

# UNCOMMENT THE LINE BELOW TO START BENCHMARKING
benchmark_sr_model()


--- Starting Benchmarks on Best Model ---


Benchmarking on Set14:   0%|          | 0/42 [00:00<?, ?it/s]


RuntimeError: shape '[1, 11, 8, 15, 8, 180]' is invalid for input of size 2025000

Cell 7: 🖼️ Inference on Custom Images

## Cell 7: 🖼️ Inference on Custom Images (Complete)

In [None]:
from google.colab import files
import matplotlib.pyplot as plt
import io

def infer_and_show():
    BEST_MODEL_PATH = os.path.join(Config.BEST_MODEL_DIR, "fusionsr_classical_best.pth")
    if not os.path.exists(BEST_MODEL_PATH):
        print(f"Best model not found at {BEST_MODEL_PATH}. Please train first.")
        return

    model = FusionSR(Config).to(DEVICE)
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
    model.eval()

    print("Please upload a LOW-RESOLUTION image to upscale:")
    uploaded = files.upload()

    if not uploaded:
        print("\nNo image uploaded. Aborting.")
        return

    file_name = list(uploaded.keys())[0]
    lr_img = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')

    lr_tensor = transforms.ToTensor()(lr_img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        sr_tensor = model(lr_tensor).clamp(0, 1).cpu()

    sr_img = transforms.ToPILImage()(sr_tensor.squeeze(0))

    fig, axes = plt.subplots(1, 2, figsize=(15, 8))
    axes[0].imshow(lr_img); axes[0].set_title("Original Low-Resolution"); axes[0].axis('off')
    axes[1].imshow(sr_img); axes[1].set_title(f"FusionSR Upscaled (x{Config.UPSCALE_FACTOR})"); axes[1].axis('off')
    plt.show()

    output_filename = f"upscaled_{file_name}"
    sr_img.save(output_filename)
    print(f"\nUpscaled image saved as {output_filename}. Downloading now...")
    files.download(output_filename)

# UNCOMMENT THE LINE BELOW TO RUN INFERENCE
infer_and_show()