# Setup Environment

In [None]:
# ======================================================================================
# PART 1: Kaggle Environment Initialization
# ======================================================================================
# --- Install Libraries ---
!pip install lpips scikit-image imagecodecs pytorch-msssim --quiet

# --- Clone SwinIR Repository into the working directory ---
import os
if not os.path.exists('/kaggle/working/SwinIR'):
    print("Cloning SwinIR repository...")
    # We clone into /kaggle/working/ which is a writable directory
    !git clone https://github.com/JingyunLiang/SwinIR.git /kaggle/working/SwinIR
else:
    print("SwinIR repository already exists.")

# Change the current directory to the cloned repository
os.chdir('/kaggle/working/SwinIR')

# --- Download Pre-trained Model ---
import requests
# This is the correct URL for the model you want to use
pretrained_model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth"
pretrained_model_path = "swinir_pretrained_x4.pth"
if not os.path.exists(pretrained_model_path):
    print("Downloading pre-trained SwinIR model...")
    r = requests.get(pretrained_model_url, allow_redirects=True)
    open(pretrained_model_path, 'wb').write(r.content)
    print("Download complete.")
else:
    print("Pre-trained model already downloaded.")

# --- Kaggle Path Configuration ---
import torch

# --- Dataset ---
dataset_name = 'grayscale-microscopy'
dataset_root = f'/kaggle/input/{dataset_name}/Split Dataset'

# Define paths for the training and validation sets
base_training_path = f'{dataset_root}/train'
base_validation_path = f'{dataset_root}/val'

# --- Global Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("\n--- Kaggle Paths ---")
print(f"Training data path: {base_training_path}")
print(f"Validation data path: {base_validation_path}")
print(f"Pre-trained model path: {os.getcwd()}/{pretrained_model_path}")
print("--------------------")
print(f"\nUsing device: {device}")
print("‚úÖ Kaggle setup complete.")

# Class Definitions

In [None]:
# ======================================================================================
# PART 2: CENTRAL DEFINITIONS
# ======================================================================================
# --- Library Imports ---
import os
import cv2
import random
import time
import tifffile
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import DeformConv2d
from models.network_swinir import SwinIR # Make sure SwinIR is cloned and accessible

# ======================================================================================
# 1. DATASET DEFINITION FOR SINGLE-FRAME SUPER-RESOLUTION (SFSR)
# ======================================================================================
class MicroscopicImageDataset(Dataset):
    """
    Dataset loader for Single-Frame Super-Resolution (SFSR).

    Every directory under ``base_dir`` that contains both a ``ground_truth`` and an
    ``lr_type`` sub-folder contributes one pair per HR file
    ``ground_truth/<stem>.tif`` / ``.tiff`` whose LR counterpart
    ``<lr_type>/<stem>_01.png`` exists. Each item is a randomly cropped and randomly
    flipped/rotated (LR, HR) patch pair.

    Args:
        base_dir: Root directory that is searched recursively.
        lr_type: Name of the LR sub-folder (e.g. 'lr_bicubic').
        hr_patch_size: Side length of the square HR crop, in pixels.
        scale: Upscaling factor; the LR crop side is ``hr_patch_size // scale``.
    """
    def __init__(self, base_dir, lr_type, hr_patch_size, scale):
        self.hr_patch_size = hr_patch_size
        self.scale = scale
        self.lr_patch_size = hr_patch_size // scale
        self.image_pairs = []  # list of (lr_path, hr_path)

        # Sorting makes the pair order (and therefore indexing) reproducible;
        # os.walk / os.listdir return entries in a filesystem-dependent order.
        for root, dirs, _ in sorted(os.walk(base_dir)):
            if 'ground_truth' not in dirs or lr_type not in dirs:
                continue
            hr_dir = os.path.join(root, 'ground_truth')
            lr_dir = os.path.join(root, lr_type)
            for hr_file in sorted(os.listdir(hr_dir)):
                if not hr_file.endswith((".tif", ".tiff")):
                    continue
                base_name = os.path.splitext(hr_file)[0]
                # SFSR uses only the first frame ("_01") of each LR burst.
                lr_path = os.path.join(lr_dir, f"{base_name}_01.png")
                if os.path.exists(lr_path):
                    self.image_pairs.append((lr_path, os.path.join(hr_dir, hr_file)))

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """
        Return ``(lr_tensor, hr_tensor)``, float tensors scaled to [0, 1] with a
        leading channel axis of size 1 (assuming the HR TIFF is a 2-D grayscale image).

        Raises:
            OSError: The LR image cannot be decoded.
            ValueError: The LR image is smaller than the LR patch, or the HR image
                does not cover the ``scale``-times-larger region of the LR crop.
        """
        lr_path, hr_path = self.image_pairs[idx]

        # cv2.imread signals an unreadable file by returning None rather than raising.
        lr_img = cv2.imread(lr_path, cv2.IMREAD_GRAYSCALE)
        if lr_img is None:
            raise OSError(f"Could not read LR image: {lr_path}")
        hr_img = tifffile.imread(hr_path)

        # Non-8-bit HR data is stretched per image (min-max) to 0-255.
        if hr_img.dtype != np.uint8:
            hr_img = cv2.normalize(hr_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

        # --- Randomly crop corresponding patches ---
        h, w = lr_img.shape
        if h < self.lr_patch_size or w < self.lr_patch_size:
            raise ValueError(
                f"LR image {lr_path} has shape {lr_img.shape}, smaller than the "
                f"{self.lr_patch_size}x{self.lr_patch_size} patch size."
            )
        rand_h = random.randint(0, h - self.lr_patch_size)
        rand_w = random.randint(0, w - self.lr_patch_size)

        lr_patch = lr_img[rand_h:rand_h + self.lr_patch_size, rand_w:rand_w + self.lr_patch_size]

        hr_h_start, hr_w_start = rand_h * self.scale, rand_w * self.scale
        hr_patch = hr_img[hr_h_start:hr_h_start + self.hr_patch_size, hr_w_start:hr_w_start + self.hr_patch_size]
        # A too-small HR image would otherwise yield a truncated patch that only fails later at collation.
        if hr_patch.shape[:2] != (self.hr_patch_size, self.hr_patch_size):
            raise ValueError(
                f"HR image {hr_path} (shape {hr_img.shape}) does not cover a x{self.scale} crop of "
                f"LR image {lr_path} (shape {lr_img.shape}) at offset ({rand_h}, {rand_w})."
            )

        # --- Data Augmentation (same transform applied to both patches) ---
        if random.random() > 0.5:  # horizontal flip
            lr_patch, hr_patch = cv2.flip(lr_patch, 1), cv2.flip(hr_patch, 1)
        if random.random() > 0.5:  # vertical flip
            lr_patch, hr_patch = cv2.flip(lr_patch, 0), cv2.flip(hr_patch, 0)
        if random.random() > 0.5:  # 90-degree clockwise rotation
            lr_patch = cv2.rotate(lr_patch, cv2.ROTATE_90_CLOCKWISE)
            hr_patch = cv2.rotate(hr_patch, cv2.ROTATE_90_CLOCKWISE)

        # --- Convert to Tensors and Normalize to [0, 1] ---
        # .copy() yields contiguous arrays after flips/rotations before torch.from_numpy.
        lr_tensor = torch.from_numpy(lr_patch.copy()).float().unsqueeze(0) / 255.0
        hr_tensor = torch.from_numpy(hr_patch.copy()).float().unsqueeze(0) / 255.0

        return lr_tensor, hr_tensor

# ======================================================================================
# 2. DATASET DEFINITION FOR MULTI-FRAME SUPER-RESOLUTION (MFSR)
# ======================================================================================
class MFSR_MicroscopicImageDataset(Dataset):
    """
    Dataset loader for Multi-Frame Super-Resolution (MFSR).

    Each item is a burst of ``num_frames`` LR patches ``<lr_type>/<stem>_01.png`` ...
    ``<stem>_NN.png`` plus the matching patch of the single HR ground truth
    ``ground_truth/<stem>.tif`` / ``.tiff``. HR images whose burst is incomplete are
    skipped at indexing time.

    Args:
        base_dir: Root directory searched recursively for folders containing both a
            ``ground_truth`` and an ``lr_type`` sub-folder.
        lr_type: Name of the LR sub-folder (e.g. 'lr_bicubic').
        hr_patch_size: Side length of the square HR crop, in pixels.
        scale: Upscaling factor; the LR crop side is ``hr_patch_size // scale``.
        num_frames: Number of LR frames per burst.
    """
    def __init__(self, base_dir, lr_type, hr_patch_size, scale, num_frames):
        self.hr_patch_size, self.scale, self.num_frames = hr_patch_size, scale, num_frames
        self.lr_patch_size = hr_patch_size // scale
        self.hr_image_paths, self.lr_image_roots = [], {}  # HR paths; HR path -> LR burst folder

        # Sorted traversal keeps the item order reproducible across filesystems.
        for root, dirs, _ in sorted(os.walk(base_dir)):
            if 'ground_truth' not in dirs or lr_type not in dirs:
                continue
            hr_dir, lr_dir = os.path.join(root, 'ground_truth'), os.path.join(root, lr_type)
            for hr_file in sorted(os.listdir(hr_dir)):
                if not hr_file.endswith((".tif", ".tiff")):
                    continue
                base_name = os.path.splitext(hr_file)[0]
                # Skip HR images without a complete LR burst; __getitem__ would
                # otherwise fail on a missing frame.
                if not all(os.path.exists(self._frame_path(lr_dir, base_name, i)) for i in range(num_frames)):
                    continue
                hr_path = os.path.join(hr_dir, hr_file)
                self.hr_image_paths.append(hr_path)
                self.lr_image_roots[hr_path] = lr_dir

    @staticmethod
    def _frame_path(lr_dir, base_name, frame_idx):
        """Return the path of the 0-based ``frame_idx``-th LR frame (files are numbered from _01)."""
        return os.path.join(lr_dir, f"{base_name}_{frame_idx + 1:02d}.png")

    def __len__(self):
        return len(self.hr_image_paths)

    def __getitem__(self, idx):
        """
        Return ``(lr_stack, hr_tensor)``: ``lr_stack`` has shape
        [num_frames, 1, lr_patch, lr_patch] and ``hr_tensor`` [1, hr_patch, hr_patch]
        (assuming a 2-D grayscale HR TIFF), both as floats in [0, 1].

        Raises:
            OSError: An LR frame cannot be decoded.
            ValueError: The HR image does not cover the ``scale``-times-larger region
                of the LR crop.
        """
        hr_path = self.hr_image_paths[idx]
        lr_root_dir = self.lr_image_roots[hr_path]

        # Load the HR image; non-8-bit data is stretched per image (min-max) to 0-255.
        hr_img = tifffile.imread(hr_path)
        if hr_img.dtype != np.uint8:
            hr_img = cv2.normalize(hr_img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

        # Load the burst of LR frames (cv2.imread returns None instead of raising).
        base_name = os.path.splitext(os.path.basename(hr_path))[0]
        lr_frames = []
        for i in range(self.num_frames):
            frame_path = self._frame_path(lr_root_dir, base_name, i)
            frame = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
            if frame is None:
                raise OSError(f"Could not read LR frame: {frame_path}")
            lr_frames.append(frame)

        # --- Randomly crop corresponding patches (same window in every frame) ---
        h, w = lr_frames[0].shape
        rand_h, rand_w = random.randint(0, h - self.lr_patch_size), random.randint(0, w - self.lr_patch_size)

        lr_patches = [frame[rand_h:rand_h + self.lr_patch_size, rand_w:rand_w + self.lr_patch_size] for frame in lr_frames]
        hr_h_start, hr_w_start = rand_h * self.scale, rand_w * self.scale
        hr_patch = hr_img[hr_h_start:hr_h_start + self.hr_patch_size, hr_w_start:hr_w_start + self.hr_patch_size]
        if hr_patch.shape[:2] != (self.hr_patch_size, self.hr_patch_size):
            raise ValueError(
                f"HR image {hr_path} (shape {hr_img.shape}) does not cover a x{self.scale} crop "
                f"of its LR frames (shape {lr_frames[0].shape}) at offset ({rand_h}, {rand_w})."
            )

        # --- Data Augmentation (applied consistently to all frames in the burst) ---
        # rot is the number of 90-degree clockwise turns. Previously only rot == 1 had
        # any effect, so 3 of the 4 sampled values silently meant "no rotation".
        h_flip, v_flip, rot = random.random() > 0.5, random.random() > 0.5, random.choice([0, 1, 2, 3])
        rotate_codes = {1: cv2.ROTATE_90_CLOCKWISE, 2: cv2.ROTATE_180, 3: cv2.ROTATE_90_COUNTERCLOCKWISE}

        def augment(img):
            if h_flip:
                img = cv2.flip(img, 1)
            if v_flip:
                img = cv2.flip(img, 0)
            if rot:
                img = cv2.rotate(img, rotate_codes[rot])
            return img

        lr_patches = [augment(patch) for patch in lr_patches]
        hr_patch = augment(hr_patch)

        # --- Convert to Tensors and Normalize ---
        lr_tensors = [torch.from_numpy(patch.copy()).float().unsqueeze(0) / 255.0 for patch in lr_patches]
        hr_tensor = torch.from_numpy(hr_patch.copy()).float().unsqueeze(0) / 255.0

        # Stack the list of LR tensors into a single tensor for the model [N, C, H, W]
        lr_stack = torch.stack(lr_tensors, dim=0)

        return lr_stack, hr_tensor

# ======================================================================================
# 3. MODEL DEFINITION FOR MULTI-FRAME SUPER-RESOLUTION (MFSR)
# ======================================================================================
class MFSR_SwinIR(nn.Module):
    """
    A Multi-Frame Super-Resolution model using a SwinIR backbone.

    Pipeline:
      1. a shared shallow convolution per frame;
      2. deformable-convolution alignment of every frame to the middle (reference) frame;
      3. 1x1-convolution fusion;
      4. the SwinIR deep-feature body and pixel-shuffle reconstruction taken from
         ``swinir_backbone``.

    Args:
        swinir_backbone: A constructed SwinIR. Its ``embed_dim``, ``window_size``,
            ``upscale``, body and upsampling layers are reused. Its own
            ``conv_first`` is not called by this wrapper.
        num_frames: Number of LR frames per input burst.
    """
    def __init__(self, swinir_backbone, num_frames=5):
        super(MFSR_SwinIR, self).__init__()
        self.backbone = swinir_backbone
        self.num_frames = num_frames
        embed_dim = self.backbone.embed_dim

        # --- Layers for Feature Alignment (using Deformable Convolution) ---
        # 1. Initial convolution to extract shallow features from each frame
        self.conv_first = nn.Conv2d(1, embed_dim, 3, 1, 1)
        # 2. Layers to predict the spatial offsets for alignment
        self.offset_conv1 = nn.Conv2d(embed_dim * 2, embed_dim, 3, 1, 1)  # Takes reference and neighbor features
        self.offset_conv2 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.offset_conv3 = nn.Conv2d(embed_dim, 18, 3, 1, 1)  # 18 = 2 * 3 * 3 offsets for a 3x3 DeformConv2d
        # 3. The deformable convolution layer that applies the predicted offsets
        self.dcn = DeformConv2d(embed_dim, embed_dim, 3, padding=1)

        # --- Layer for Feature Fusion ---
        # A 1x1 convolution to fuse the aligned features from all frames
        self.fusion_conv = nn.Conv2d(embed_dim * num_frames, embed_dim, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """
        Super-resolve a burst.

        Args:
            x: LR burst of shape [B, N, C, H, W] with N == num_frames. H and W need not
               be multiples of the backbone's window size; the input is reflect-padded
               internally and the output cropped back, mirroring SwinIR.check_image_size.

        Returns:
            Tensor of shape [B, C_out, H * upscale, W * upscale]; C_out is set by the
            backbone's ``conv_last``.

        Raises:
            ValueError: If N differs from ``num_frames`` (the fusion layer is sized for it).
        """
        b, n, c, h, w = x.size()
        if n != self.num_frames:
            raise ValueError(f"Expected bursts of {self.num_frames} frames, got {n}")
        ref_idx = n // 2  # Use the middle frame as the reference

        # Swin window attention needs H and W to be multiples of window_size.
        window_size = self.backbone.window_size
        pad_h = (window_size - h % window_size) % window_size
        pad_w = (window_size - w % window_size) % window_size
        frames = x.reshape(b * n, c, h, w)  # reshape also accepts non-contiguous inputs
        if pad_h or pad_w:
            frames = torch.nn.functional.pad(frames, (0, pad_w, 0, pad_h), mode='reflect')
        h_pad, w_pad = h + pad_h, w + pad_w

        # Extract initial features for all frames at once
        features = self.conv_first(frames).view(b, n, -1, h_pad, w_pad)

        ref_features = features[:, ref_idx]
        aligned_features = []

        # Align each neighbor frame to the reference frame
        for i in range(n):
            if i == ref_idx:
                aligned_features.append(ref_features)
                continue
            neighbor_features = features[:, i]
            offsets = self.lrelu(self.offset_conv1(torch.cat([ref_features, neighbor_features], dim=1)))
            offsets = self.lrelu(self.offset_conv2(offsets))
            # NOTE(review): the LeakyReLU on the final offsets scales negative offsets
            # by 0.1. It is presumably kept because existing checkpoints were trained
            # with it; removing it would change their outputs.
            offsets = self.lrelu(self.offset_conv3(offsets))
            aligned_features.append(self.dcn(neighbor_features, offsets))

        # Fuse the aligned features into a single feature map
        fused_features = self.fusion_conv(torch.cat(aligned_features, dim=1))

        # SwinIR deep-feature body with its global residual connection
        deep_features = self.backbone.conv_after_body(self.backbone.forward_features(fused_features)) + fused_features

        # Reduce embed_dim channels to the upsampler's feature width, upsample, then project to the output channels
        features_before_upsampling = self.backbone.conv_before_upsample(deep_features)
        out = self.backbone.upsample(features_before_upsampling)
        out = self.backbone.conv_last(out)

        # Drop the region produced from the padding (a no-op when nothing was padded)
        upscale = self.backbone.upscale
        return out[:, :, :h * upscale, :w * upscale]

def load_trained_model(name, path, scale=4, num_frames=5, device='cpu'):
    """
    Build the grayscale SwinIR-M (SFSR) or MFSR_SwinIR (MFSR) model and load a checkpoint.

    Args:
        name: Display name. Names containing 'SFSR' build a plain SwinIR; all others
            build an MFSR_SwinIR around a SwinIR backbone.
        path: Checkpoint file. Accepted layouts:
            - a train_model checkpoint ('model_state_dict');
            - a SwinIR release file ('params_ema' or 'params');
            - a bare state dict.
            Keys prefixed with 'module.' (saved from nn.DataParallel) are accepted.
            A 3-channel 'conv_first.weight' is averaged to a single grayscale channel.
        scale: Upscaling factor.
        num_frames: Burst length for MFSR models.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model on ``device`` in eval mode, or None if ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"⚠️ Warning: Checkpoint not found for {name} at {path}. Model is not loaded.")
        return None

    # Configuration for the SwinIR-M model
    LR_PATCH_SIZE = 48
    WINDOW_SIZE = 8
    EMBED_DIM = 180

    model_config = {
        'upscale': scale, 'in_chans': 1, 'img_size': LR_PATCH_SIZE, 'window_size': WINDOW_SIZE,
        'img_range': 1., 'depths': [6, 6, 6, 6, 6, 6], 'embed_dim': EMBED_DIM,
        'num_heads': [6, 6, 6, 6, 6, 6], 'mlp_ratio': 2, 'upsampler': 'pixelshuffle',
        'resi_connection': '1conv'
    }

    if 'SFSR' in name:
        model = SwinIR(**model_config)
    else:
        model = MFSR_SwinIR(SwinIR(**model_config), num_frames=num_frames)

    checkpoint = torch.load(path, map_location=device)

    # Pick the nested weights by priority; fall back to treating the file as a bare state dict.
    # ('params' is checked too, so SwinIR releases stored under that key are not silently skipped.)
    state_dict = checkpoint
    for key in ('model_state_dict', 'params_ema', 'params'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break

    # Strip the 'module.' prefix that nn.DataParallel adds to every key.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    # --- Handle the RGB-to-Grayscale Mismatch (only for the original pre-trained model) ---
    conv_first_weight = state_dict.get('conv_first.weight')
    if conv_first_weight is not None and conv_first_weight.shape[1] == 3:
        print("Adapting pre-trained RGB input layer to grayscale...")
        state_dict['conv_first.weight'] = conv_first_weight.mean(dim=1, keepdim=True)

    # strict=False tolerates partial matches (e.g. fine-tuning), but report what was
    # skipped so a wrong or mismatched checkpoint does not load silently as random weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ {name}: {len(incompatible.missing_keys)} model tensors missing from the checkpoint, "
              f"{len(incompatible.unexpected_keys)} checkpoint entries unused.")
    model.to(device)
    model.eval()  # Set model to evaluation mode
    print(f"✅ Successfully loaded and adapted model: {name}")
    return model

print("✅ Central definitions for all Datasets and Models are complete.")

# Training Initialization

In [None]:
# ======================================================================================
# PART 3: UNIVERSAL TRAINING FUNCTION (Corrected with Validation SSIM)
# ======================================================================================
# --- Library Imports ---
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
from pytorch_msssim import ssim 

def _build_dataloaders(model_type, base_training_path, base_validation_path, lr_type,
                       patch_size_hr, scale, batch_size, num_frames, num_workers=4):
    """
    Create the ``(train_loader, val_loader)`` pair for ``model_type``.

    MFSR items carry ``num_frames`` LR patches each, so the MFSR batch size is halved
    (minimum 1).
    """
    if model_type == 'SFSR':
        train_set = MicroscopicImageDataset(base_dir=base_training_path, lr_type=lr_type,
                                            hr_patch_size=patch_size_hr, scale=scale)
        val_set = MicroscopicImageDataset(base_dir=base_validation_path, lr_type=lr_type,
                                          hr_patch_size=patch_size_hr, scale=scale)
        loader_batch_size = batch_size
    else:
        train_set = MFSR_MicroscopicImageDataset(base_dir=base_training_path, lr_type=lr_type,
                                                 hr_patch_size=patch_size_hr, scale=scale,
                                                 num_frames=num_frames)
        val_set = MFSR_MicroscopicImageDataset(base_dir=base_validation_path, lr_type=lr_type,
                                               hr_patch_size=patch_size_hr, scale=scale,
                                               num_frames=num_frames)
        loader_batch_size = max(1, batch_size // 2)

    train_loader = DataLoader(train_set, batch_size=loader_batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=loader_batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader


def _build_sr_model(model_type, scale, patch_size_hr, num_frames):
    """Instantiate grayscale SwinIR-M (embed_dim=180, window 8), wrapped in MFSR_SwinIR for 'MFSR'."""
    model_config = {
        'upscale': scale, 'in_chans': 1, 'img_size': patch_size_hr // scale, 'window_size': 8,
        'img_range': 1., 'depths': [6, 6, 6, 6, 6, 6], 'embed_dim': 180,
        'num_heads': [6, 6, 6, 6, 6, 6], 'mlp_ratio': 2, 'upsampler': 'pixelshuffle', 'resi_connection': '1conv'
    }
    if model_type == 'SFSR':
        return SwinIR(**model_config)
    return MFSR_SwinIR(SwinIR(**model_config), num_frames=num_frames)


def _load_pretrained_weights(model, model_type, pretrained_model_path, device):
    """
    Initialise ``model`` from a SwinIR release file or a train_model checkpoint.

    A 3-channel input convolution is averaged to grayscale. For MFSR, the keys of a
    plain SwinIR checkpoint are mapped onto ``model.backbone``; its conv_first is also
    loaded into the MFSR front-end conv, which has the same shape.
    """
    checkpoint = torch.load(pretrained_model_path, map_location=device)
    state_dict = checkpoint
    for key in ('model_state_dict', 'params_ema', 'params'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    state_dict = dict(state_dict)  # shallow copy: the adaptation below must not mutate the loaded object

    conv_first_weight = state_dict.get('conv_first.weight')
    if conv_first_weight is not None and conv_first_weight.shape[1] == 3:
        print("Adapting pre-trained RGB input layer to grayscale for fine-tuning...")
        state_dict['conv_first.weight'] = conv_first_weight.mean(dim=1, keepdim=True)

    if model_type == 'MFSR' and not any(k.startswith('backbone.') for k in state_dict):
        # Bare SwinIR keys would not match MFSR_SwinIR's 'backbone.*' names at all.
        mapped = {f'backbone.{k}': v for k, v in state_dict.items()}
        mapped.update({k: v for k, v in state_dict.items() if k.startswith('conv_first.')})
        state_dict = mapped

    result = model.load_state_dict(state_dict, strict=False)
    total = len(model.state_dict())
    print(f"Matched {total - len(result.missing_keys)}/{total} model tensors "
          f"({len(result.unexpected_keys)} unused checkpoint entries).")


def train_model(
    model_type,
    base_training_path,
    base_validation_path,
    lr_type,
    checkpoint_path,
    pretrained_model_path,
    device,
    resume_from=None,
    scale=4,
    patch_size_hr=192,
    batch_size=8,
    epochs=50,
    lr_rate=1e-5,
    num_frames=5
    ):
    """
    Train, fine-tune, or resume an SFSR (SwinIR) or MFSR (MFSR_SwinIR) model.

    Each epoch minimises 0.85 * L1 + 0.15 * (1 - SSIM), with mixed precision on CUDA,
    and then reports validation L1 and SSIM. Checkpoints written to ``checkpoint_path``:
      - ``best_model.pth`` whenever validation L1 improves;
      - ``model_epoch_<k>.pth`` every 10 epochs.
    Both hold the model, optimizer, scheduler and grad-scaler state.

    Args:
        model_type: 'SFSR' or 'MFSR'.
        base_training_path: Root of the training split.
        base_validation_path: Root of the validation split.
        lr_type: LR sub-folder name, e.g. 'lr_bicubic'.
        checkpoint_path: Output directory for checkpoints (created if missing).
        pretrained_model_path: Optional SwinIR release or train_model checkpoint used to
            initialise the weights. Ignored if None or missing.
        device: Device to train on.
        resume_from: Optional train_model checkpoint to resume from. Training restarts
            at the saved epoch + 1, with the LR reset to ``lr_rate`` and cosine-decayed
            towards 1e-7 by epoch ``epochs``.
        scale: Upscaling factor.
        patch_size_hr: HR training crop size; the LR crop is patch_size_hr // scale.
        batch_size: SFSR batch size; MFSR uses max(1, batch_size // 2).
        epochs: Exclusive upper bound of the epoch loop. When resuming, this is the final
            epoch index, not a number of additional epochs.
        lr_rate: Initial learning rate (the restart learning rate when resuming).
        num_frames: Burst length for MFSR.

    Raises:
        ValueError: Unknown ``model_type``, or an empty training/validation set.
    """
    # Local import: a later notebook cell rebinds the global `ssim` to
    # skimage.metrics.structural_similarity, whose signature is incompatible.
    from pytorch_msssim import ssim as msssim_ssim

    if model_type not in ('SFSR', 'MFSR'):
        raise ValueError(f"model_type must be 'SFSR' or 'MFSR', got {model_type!r}")

    print(f"--- Starting Training Run: {model_type} ({lr_type}) ---")
    os.makedirs(checkpoint_path, exist_ok=True)

    # 1. Dataloaders
    train_loader, val_loader = _build_dataloaders(
        model_type, base_training_path, base_validation_path, lr_type,
        patch_size_hr, scale, batch_size, num_frames)
    print(f"Found {len(train_loader.dataset)} training items and {len(val_loader.dataset)} validation items.")
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Both the training and validation sets must contain at least one item; "
                         "check the dataset paths and lr_type.")

    # 2-3. Model architecture and optional pre-trained initialisation
    model = _build_sr_model(model_type, scale, patch_size_hr, num_frames)
    resuming = bool(resume_from) and os.path.exists(resume_from)
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Loading pre-trained weights from: {os.path.basename(pretrained_model_path)}")
        _load_pretrained_weights(model, model_type, pretrained_model_path, device)
        print("✅ Pre-trained weights loaded successfully.")
    elif not resuming:
        print("⚠️ Pre-trained model not found. Training from scratch.")
    model.to(device)

    # 4. Optimizer, scheduler and loss (AMP only has an effect on CUDA)
    use_amp = torch.device(device).type == 'cuda'
    optimizer = optim.Adam(model.parameters(), lr=lr_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)
    scaler = GradScaler(enabled=use_amp)
    criterion_l1 = nn.L1Loss()

    start_epoch = 0
    best_val_loss = float('inf')

    if resuming:
        print(f"Resuming training from checkpoint: {os.path.basename(resume_from)}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Older periodic checkpoints store only the weights and epoch, so each piece of
        # training state is restored only when present. An empty scaler state (saved
        # while AMP was disabled) cannot be loaded into an enabled GradScaler.
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        # Restart at lr_rate and re-anchor the cosine to it. load_state_dict restored
        # the old run's T_max and base_lrs. Left alone, the LR would bottom out at the
        # old T_max and then climb back towards the old base LR.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_rate
        scheduler.T_max = epochs
        scheduler.base_lrs = [lr_rate] * len(optimizer.param_groups)

        print(f"Resumed from epoch {start_epoch}. Best Val L1 so far: {best_val_loss:.6f}. New LR: {lr_rate:.6e}")
    elif resume_from:
        print(f"⚠️ Resume checkpoint not found at {resume_from}. Starting from epoch 1.")

    if torch.cuda.device_count() > 1:
        print(f"✅ Using {torch.cuda.device_count()} GPUs for training!")
        model = nn.DataParallel(model)

    # 5. Training and validation loop
    print(f"Starting model training from epoch {start_epoch + 1} to {epochs}...")
    for epoch in range(start_epoch, epochs):
        epoch_start_time = time.time()

        # --- Training Phase ---
        model.train()
        running_train_loss, finite_batches = 0.0, 0
        train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

        for lr_data, hr_patches in train_progress_bar:
            lr_data, hr_patches = lr_data.to(device), hr_patches.to(device)
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=use_amp):
                sr_patches = model(lr_data)
                loss_l1 = criterion_l1(sr_patches, hr_patches)
                loss_ssim_per_item = 1 - msssim_ssim(sr_patches, hr_patches, data_range=1.0, size_average=False)
                # Items whose SSIM is NaN are excluded; if none is finite, fall back to pure L1.
                valid_mask = ~torch.isnan(loss_ssim_per_item)
                if valid_mask.any():
                    loss = 0.15 * loss_ssim_per_item[valid_mask].mean() + 0.85 * loss_l1
                else:
                    loss = loss_l1

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if not torch.isnan(loss):
                running_train_loss += loss.item()
                finite_batches += 1
            train_progress_bar.set_postfix({'Total Loss': f'{loss.item():.6f}'})

        avg_train_loss = running_train_loss / max(1, finite_batches)

        # --- Validation Phase ---
        model.eval()
        running_val_l1_loss = 0.0
        ssim_sum, ssim_items = 0.0, 0
        val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validate]")

        with torch.no_grad():
            for lr_data, hr_patches in val_progress_bar:
                lr_data, hr_patches = lr_data.to(device), hr_patches.to(device)
                with autocast(enabled=use_amp):
                    sr_patches = model(lr_data)
                    val_loss_l1 = criterion_l1(sr_patches, hr_patches)

                running_val_l1_loss += val_loss_l1.item()
                # Summing per-item SSIM and dividing by the item count equals the SSIM
                # averaged over the whole validation set, without storing every patch.
                # Cast to float so the fp16 AMP output and the fp32 target match.
                ssim_sum += msssim_ssim(sr_patches.float(), hr_patches.float(),
                                        data_range=1.0, size_average=False).sum().item()
                ssim_items += sr_patches.shape[0]
                val_progress_bar.set_postfix({'Val L1 Loss': f'{val_loss_l1.item():.6f}'})

        avg_val_l1_loss = running_val_l1_loss / len(val_loader)
        val_ssim_score = ssim_sum / ssim_items

        # --- Log Epoch Results ---
        current_lr = optimizer.param_groups[0]['lr']
        epoch_duration = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.6f}, Val L1: {avg_val_l1_loss:.6f}, Val SSIM: {val_ssim_score:.6f}, LR: {current_lr:.6e}, Time: {epoch_duration:.2f}s")

        # --- Save Checkpoints (always unwrapped, so the keys carry no 'module.' prefix) ---
        model_state_to_save = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        if avg_val_l1_loss < best_val_loss:
            best_val_loss = avg_val_l1_loss
        training_state = {
            'epoch': epoch, 'model_state_dict': model_state_to_save,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss,
        }
        if best_val_loss == avg_val_l1_loss:
            torch.save(training_state, os.path.join(checkpoint_path, "best_model.pth"))
            print(f"✅ New best model saved! Validation L1 Loss: {best_val_loss:.6f}")

        if (epoch + 1) % 10 == 0:
            torch.save(training_state, os.path.join(checkpoint_path, f"model_epoch_{epoch+1}.pth"))
            print(f"Saved periodic checkpoint: model_epoch_{epoch+1}.pth")

        scheduler.step()

    print(f"\n✅ Training complete for: {os.path.basename(checkpoint_path)}")

print("✅ Training Initialization Done")

# Training and Validation

In [None]:
# Re-derive both split directories from dataset_root (set in the setup cell) so this
# cell stays self-consistent if dataset_root has been changed since.
base_training_path, base_validation_path = (f'{dataset_root}/{split}' for split in ('train', 'val'))
print(f"Training data from (Kaggle): {base_training_path}")
print(f"Validation data from (Kaggle): {base_validation_path}")

In [None]:
# Best checkpoints of the V1 runs (Kaggle model input), used as resume points below.
# Suffix A = bicubic LR data, B = realistic LR data.
_V1_DIR = "/kaggle/input/swinir-v1/pytorch/default/1/V1"
sfsr_a_path = f"{_V1_DIR}/sfsr_bicubic_best_model.pth"
mfsr_a_path = f"{_V1_DIR}/mfsr_bicubic_best_model.pth"
sfsr_b_path = f"{_V1_DIR}/sfsr_realistic_best_model.pth"
mfsr_b_path = f"{_V1_DIR}/mfsr_realistic_best_model.pth"

# Train SFSR SwinIR ON NORMAL DATASET

In [None]:
# ======================================================================================
# PART 4A: TRAIN SFSR ON KAGGLE
# ======================================================================================
print("\n--- Starting Refinement for SFSR (Bicubic) ---")
# Refinement run: continue from the V1 best checkpoint (not the original SwinIR
# release weights) with a small learning rate.
train_model(
    model_type='SFSR',
    lr_type='lr_bicubic',
    base_training_path=base_training_path,
    base_validation_path=base_validation_path,
    device=device,
    pretrained_model_path=None,
    resume_from=sfsr_a_path,
    checkpoint_path="/kaggle/working/V4_SFSR_A_refined",  # fresh output folder for this run
    lr_rate=2e-6,
    # Last epoch index: the loop runs from the resumed epoch up to 100.
    epochs=100,
    batch_size=16,
    patch_size_hr=192,
)

# TRAIN SFSR SwinIR ON REALISTIC DATASET

In [None]:
# ======================================================================================
# PART 4B: TRAIN SFSR (NOISY DATASET) ON KAGGLE
# ======================================================================================

print("\n--- Starting Refinement for SFSR (Noisy) ---")
# Refinement of the V1 SFSR model trained on the realistic (noisy) LR data.
train_model(
    model_type='SFSR',
    lr_type='lr_realistic',
    base_training_path=base_training_path,
    base_validation_path=base_validation_path,
    device=device,
    pretrained_model_path=None,
    resume_from=sfsr_b_path,
    checkpoint_path="/kaggle/working/V4_SFSR_B_refined",
    lr_rate=2e-6,
    epochs=100,  # last epoch index, counted from the resumed epoch
    batch_size=16,
    patch_size_hr=192,
)

# Train MFSR SwinIR (Frame Alignment, Fusion, and Reconstruction) on the Normal Dataset

In [None]:
# ======================================================================================
# PART 4C: TRAIN MFSR (NORMAL DATASET - lr_bicubic)
# ======================================================================================

print("\n--- Starting Refinement for MFSR (Bicubic) ---")
# Refinement of the V1 MFSR model on bicubic LR bursts.
train_model(
    model_type='MFSR',
    lr_type='lr_bicubic',
    base_training_path=base_training_path,
    base_validation_path=base_validation_path,
    device=device,
    pretrained_model_path=None,
    resume_from=mfsr_a_path,
    checkpoint_path="/kaggle/working/V4_MFSR_A_refined",
    lr_rate=2e-6,
    epochs=100,
    batch_size=8,  # train_model halves this for MFSR, i.e. 4 bursts per step
    patch_size_hr=192,
)

# Train MFSR SwinIR (Frame Alignment, Fusion, and Reconstruction) on the Realistic Dataset

In [None]:
# ======================================================================================
# PART 4D: TRAIN MFSR (REALISTIC DATASET)
# ======================================================================================

# --- 4. Refine MFSR on Noisy Data (MFSR B) ---
# Continues the V1 MFSR model trained on realistic (noisy) LR bursts.
print("\n--- Starting Refinement for MFSR (Noisy) ---")
train_model(
    model_type='MFSR',
    lr_type='lr_realistic',
    base_training_path=base_training_path,
    base_validation_path=base_validation_path,
    device=device,
    pretrained_model_path=None,
    resume_from=mfsr_b_path,
    checkpoint_path="/kaggle/working/V4_MFSR_B_refined",
    lr_rate=2e-6,
    epochs=100,
    batch_size=8,  # halved internally for MFSR
    patch_size_hr=192,
)

# Test Model From Test Dataset

In [None]:
# ======================================================================================
# 1. LOAD FINAL REFINED MODELS FROM KAGGLE INPUT
# ======================================================================================
print("--- Loading all 4 final refined models for evaluation ---")
models = {}
scale = 4
num_frames = 5

FINAL_MODELS_ROOT = '/kaggle/input/swinir-v4/pytorch/default/1/V4'

# Display name -> best checkpoint of the corresponding V4 model inside the input directory.
checkpoint_paths = {
    label: os.path.join(FINAL_MODELS_ROOT, filename)
    for label, filename in (
        ('SFSR (Bicubic)', 'sfsr_bicubic_best_model.pth'),
        ('MFSR (Bicubic)', 'mfsr_bicubic_best_model.pth'),
        ('SFSR (Noisy)', 'sfsr_realistic_best_model.pth'),
        ('MFSR (Realistic)', 'mfsr_realistic_best_model.pth'),
    )
}

# Load every checkpoint that is present; missing ones are reported and skipped.
for name, path in checkpoint_paths.items():
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è Warning: File not found for {name} at {path}")
        continue
    # `device` must already be defined by an earlier cell (e.g. device = torch.device('cuda')).
    model = load_trained_model(name, path, scale=scale, num_frames=num_frames, device=device)
    if model:
        models[name] = model

print("‚úÖ All available models loaded successfully from Input." if models
      else "‚ö†Ô∏è No models were loaded. Please check checkpoint paths.")

--- Loading all 4 final refined models for evaluation ---


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


✅ Successfully loaded and adapted model: SFSR (Bicubic)
✅ Successfully loaded and adapted model: MFSR (Bicubic)
✅ Successfully loaded and adapted model: SFSR (Realistic)
✅ Successfully loaded and adapted model: MFSR (Realistic)
✅ All available models loaded successfully from Input.


In [None]:
# ======================================================================================
# FINAL EVALUATION SCRIPT: METRICS & VISUAL COLLAGES
# ======================================================================================

# --- 1. IMPORTS & CONFIGURATION ---
import os
import cv2
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import tifffile
import lpips
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Global Config (read by the helpers and the execution block further down)
DIVISIBILITY_FACTOR = 8  # NOTE(review): not referenced in the code shown here; the evaluation code pads with a literal 8
scale = 4  # SR upscaling factor
dataset_root = '/kaggle/input/grayscale-microscopy/Split Dataset'

# Collage PNGs are written here; create the folder up front.
output_folder = "/kaggle/working/results_collages"
os.makedirs(output_folder, exist_ok=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"üîß Config: Scale x{scale} | Device: {device}")


# --- 2. HELPER FUNCTIONS ---

def _tile_starts(length, patch_size, stride):
    """Return sorted, de-duplicated tile offsets whose tiles of ``patch_size`` cover ``[0, length)``.

    Offsets advance by ``stride``; the last tile is shifted back so it ends exactly at
    ``length`` (or starts at 0 when the axis is shorter than one tile).
    """
    last = max(0, length - patch_size)
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def _blend_mask(h_px, w_px, fade_dist, device):
    """Build a separable soft-edge weight map of shape ``(1, 1, h_px, w_px)``.

    Each axis ramps linearly from the tile border up to 1 over ``fade_dist`` pixels.
    Distances are counted from 1, so border pixels keep a small positive weight
    instead of 0 (a zero weight would zero out the image border after normalisation).
    """
    def ramp(n):
        idx = torch.arange(n, device=device, dtype=torch.float32)
        dist = torch.minimum(idx + 1, n - idx)
        return torch.clamp(dist / fade_dist, max=1.0)

    return (ramp(h_px).view(-1, 1) * ramp(w_px).view(1, -1)).view(1, 1, h_px, w_px)


def super_resolve_tiled(model, input_tensor, scale_factor, patch_size=64, overlap=16):
    """
    Super-resolve ``input_tensor`` tile by tile, blending tiles with soft edge weights.

    Overlapping HR tiles are cross-faded (weighted sum divided by summed weights), which
    hides the seams a hard paste would leave. Tiles shorter than ``patch_size`` (images
    smaller than one tile along an axis) are supported.

    Args:
        model: Callable mapping an LR tile to an HR tile; assumed to upscale both spatial
            dims by exactly ``scale_factor`` and keep the channel count.
        input_tensor: ``(b, c, h, w)`` for single-frame models or ``(b, n, c, h, w)`` for
            multi-frame models (detected via ``dim() == 5``).
        scale_factor: Integer upscaling factor.
        patch_size: LR tile side length.
        overlap: LR overlap between neighbouring tiles.

    Returns:
        ``(b, c, h * scale_factor, w * scale_factor)`` tensor on ``input_tensor.device``.
    """
    is_mfsr = input_tensor.dim() == 5
    if is_mfsr:
        b, _, c, h, w = input_tensor.shape
    else:
        b, c, h, w = input_tensor.shape

    dev = input_tensor.device
    h_hr, w_hr = h * scale_factor, w * scale_factor
    output_tensor = torch.zeros((b, c, h_hr, w_hr), device=dev)
    weight_map = torch.zeros((b, c, h_hr, w_hr), device=dev)
    if h == 0 or w == 0:
        return output_tensor

    stride = max(1, patch_size - overlap)
    # Cross-fade width in HR pixels: half the HR overlap, at least one pixel.
    fade_dist = max(1, overlap * scale_factor / 2)
    masks = {}  # HR tile shape -> blend mask (edge tiles may be smaller than patch_size)
    x_starts = _tile_starts(w, patch_size, stride)

    for y_start in _tile_starts(h, patch_size, stride):
        y_end = min(y_start + patch_size, h)
        for x_start in x_starts:
            x_end = min(x_start + patch_size, w)

            if is_mfsr:
                patch_lr = input_tensor[:, :, :, y_start:y_end, x_start:x_end]
            else:
                patch_lr = input_tensor[:, :, y_start:y_end, x_start:x_end]

            with torch.no_grad():
                patch_hr = model(patch_lr)

            ys_hr, ye_hr = y_start * scale_factor, y_end * scale_factor
            xs_hr, xe_hr = x_start * scale_factor, x_end * scale_factor
            key = (ye_hr - ys_hr, xe_hr - xs_hr)
            if key not in masks:
                masks[key] = _blend_mask(key[0], key[1], fade_dist, dev)
            mask = masks[key]

            output_tensor[:, :, ys_hr:ye_hr, xs_hr:xe_hr] += patch_hr * mask
            weight_map[:, :, ys_hr:ye_hr, xs_hr:xe_hr] += mask
            if dev.type == "cuda":
                torch.cuda.empty_cache()

    return output_tensor / (weight_map + 1e-8)


def calculate_metrics(img_comp, img_gt, lpips_fn, device):
    """Calculate PSNR, SSIM and LPIPS of ``img_comp`` against the reference ``img_gt``.

    Both images are treated as single-channel (SSIM is called with ``channel_axis=None``)
    with a 0..255 data range. Non-uint8 inputs are clipped to [0, 255] before the cast so
    out-of-range values saturate instead of wrapping around.

    Args:
        img_comp: Image to evaluate.
        img_gt: Reference image with the same shape.
        lpips_fn: Callable ``(ref, comp) -> tensor`` such as ``lpips.LPIPS``; it receives
            ``(1, 1, H, W)`` float tensors scaled to [-1, 1] on ``device``.
        device: Device the LPIPS input tensors are moved to.

    Returns:
        Tuple ``(psnr, ssim, lpips)``; PSNR is ``inf`` for identical images.
    """
    if img_comp.dtype != np.uint8:
        img_comp = np.clip(img_comp, 0, 255).astype(np.uint8)
    if img_gt.dtype != np.uint8:
        img_gt = np.clip(img_gt, 0, 255).astype(np.uint8)

    psnr_val = psnr(img_gt, img_comp, data_range=255)
    # channel_axis=None: inputs are expected to be 2-D grayscale images.
    ssim_val = ssim(img_gt, img_comp, data_range=255, channel_axis=None)

    # Map 0..255 to [-1, 1] for LPIPS.
    t_comp = torch.from_numpy(img_comp).float().unsqueeze(0).unsqueeze(0).to(device) / 127.5 - 1
    t_gt = torch.from_numpy(img_gt).float().unsqueeze(0).unsqueeze(0).to(device) / 127.5 - 1
    # Evaluation only: skip building an autograd graph through the LPIPS network.
    with torch.no_grad():
        lpips_val = lpips_fn(t_gt, t_comp).item()

    return psnr_val, ssim_val, lpips_val


def _draw_panel(ax, img, title, metrics=None):
    """Show ``img`` in grayscale with a bold title and, if given, a PSNR/SSIM/LPIPS caption below."""
    ax.imshow(img, cmap='gray')
    ax.set_title(title, fontsize=16, fontweight='bold', pad=15)
    ax.axis('off')
    if not metrics:
        return
    p_val, s_val, l_val = metrics
    p_str = f"{p_val:.2f}" if p_val != float('inf') else "‚àû"
    caption = "\n".join((f"PSNR: {p_str} dB", f"SSIM: {s_val:.4f}", f"LPIPS: {l_val:.4f}"))
    ax.text(0.5, -0.1, caption, transform=ax.transAxes, ha='center', va='top',
            fontsize=14, color='black', weight='medium')


def save_batch_plot(batch_data, batch_idx):
    """Render one comparison row per entry (input | SFSR | MFSR | GT) and save the collage.

    Each entry of ``batch_data`` is a dict with image keys 'input_disp', 'sfsr', 'mfsr',
    'gt' and matching metric tuples 'metrics_in', 'metrics_sf', 'metrics_mf',
    'metrics_gt' as ``(psnr, ssim, lpips)``. The PNG is written to the module-level
    ``output_folder`` as ``collage_batch_{batch_idx:03d}.png``.
    """
    panel_spec = (
        ('input_disp', "Input (Bicubic)", 'metrics_in'),
        ('sfsr', "SFSR (Bicubic)", 'metrics_sf'),
        ('mfsr', "MFSR (Bicubic)", 'metrics_mf'),
        ('gt', "Ground Truth", 'metrics_gt'),
    )
    n_rows = len(batch_data)
    # squeeze=False keeps a 2-D (rows, 4) axes grid even for a single row.
    fig, grid = plt.subplots(n_rows, 4, figsize=(24, 8.0 * n_rows), squeeze=False)

    for row_axes, record in zip(grid, batch_data):
        for ax, (img_key, title, metric_key) in zip(row_axes, panel_spec):
            _draw_panel(ax, record[img_key], title, record[metric_key])

    plt.subplots_adjust(left=0.05, right=0.95, top=0.93, bottom=0.08, wspace=0.1, hspace=0.35)
    out_path = os.path.join(output_folder, f"collage_batch_{batch_idx:03d}.png")
    plt.savefig(out_path)
    plt.close(fig)


# --- 3. MAIN PROCESSING LOOP ---

def _read_hr_image(hr_path):
    """Load the ground-truth image at ``hr_path`` as uint8, or return None if OpenCV cannot read it.

    ``.tif``/``.tiff`` files are read with tifffile, anything else with
    ``cv2.imread(path, -1)`` (unchanged depth). Non-uint8 data is min-max stretched to 0..255.
    """
    if hr_path.endswith(('.tif', '.tiff')):
        img = tifffile.imread(hr_path)
    else:
        img = cv2.imread(hr_path, -1)
    if img is None:
        return None
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    return img


def _read_lr_frames(lr_dir, base_name, num_frames):
    """Load ``<base_name>_01.png`` .. ``<base_name>_NN.png`` from ``lr_dir`` as grayscale.

    Returns the frames in order, or None as soon as one is missing or unreadable.
    """
    frames = []
    for idx in range(1, num_frames + 1):
        frame = cv2.imread(os.path.join(lr_dir, f"{base_name}_{idx:02d}.png"), 0)
        if frame is None:
            return None
        frames.append(frame)
    return frames


def _pad_to_tensor(img, multiple):
    """Reflect-pad a 2-D image on the bottom/right to a multiple of ``multiple``.

    Returns a ``(1, 1, H', W')`` float tensor scaled to [0, 1].
    """
    pad_h = (multiple - img.shape[0] % multiple) % multiple
    pad_w = (multiple - img.shape[1] % multiple) % multiple
    padded = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT)
    return torch.from_numpy(padded).float().unsqueeze(0).unsqueeze(0) / 255.0


def _sr_to_uint8(sr_tensor, out_h, out_w):
    """Crop an SR output to ``(out_h, out_w)``, clamp to [0, 1] and convert to a uint8 array."""
    cropped = sr_tensor[:, :, :out_h, :out_w].squeeze().cpu().clamp(0, 1)
    return (cropped.numpy() * 255).astype(np.uint8)


def process_all_images(image_sets, models, scale, device, *, num_frames=5, pad_multiple=8,
                       tile_threshold=200, rows_per_collage=5):
    """Evaluate the bicubic SFSR/MFSR models on every image set; save collages and metrics.

    For each directory in ``image_sets`` that has ``ground_truth/`` and ``lr_bicubic/``
    sub-directories, the first ground-truth file (alphabetical) is paired with the LR
    frames ``<name>_01.png`` .. ``<name>_{num_frames:02d}.png``. Frame 1 feeds the SFSR
    model and all frames feed the MFSR model. PSNR/SSIM/LPIPS are computed for a plain
    bicubic upsample of frame 1 (the "Input" baseline), the SFSR output and the MFSR
    output. Collages are written to ``output_folder`` and per-image metrics to
    ``/kaggle/working/final_metrics.csv``. Sets with missing/unreadable files are skipped.

    Args:
        image_sets: Sized iterable of image-set directory paths.
        models: Dict that must contain ``'SFSR (Bicubic)'`` and ``'MFSR (Bicubic)'``.
        scale: Upscaling factor of the models.
        device: Torch device for inference and LPIPS.
        num_frames: LR frames per set fed to the MFSR model.
        pad_multiple: LR inputs are reflect-padded to a multiple of this before inference.
        tile_threshold: LR height/width above which tiled inference is used.
        rows_per_collage: Images per saved collage.
    """
    print(f"üîµ Starting processing for {len(image_sets)} images...")

    # Check for the required models before paying for the LPIPS network setup.
    if 'SFSR (Bicubic)' not in models or 'MFSR (Bicubic)' not in models:
        print("‚ö†Ô∏è Required Bicubic models not found in 'models' dictionary.")
        return

    loss_fn_lpips = lpips.LPIPS(net='alex').to(device)
    sfsr_model = models['SFSR (Bicubic)'].to(device).eval()
    mfsr_model = models['MFSR (Bicubic)'].to(device).eval()

    metrics_log = []
    batch_buffer = []
    batch_counter = 1

    for image_set_path in tqdm(image_sets):
        hr_dir = os.path.join(image_set_path, 'ground_truth')
        lr_dir = os.path.join(image_set_path, 'lr_bicubic')
        if not (os.path.isdir(hr_dir) and os.path.isdir(lr_dir)):
            continue

        # os.listdir order is arbitrary; sort so the chosen GT file is reproducible.
        hr_files = sorted(f for f in os.listdir(hr_dir) if f.endswith(('.tif', '.tiff', '.png')))
        if not hr_files:
            continue
        hr_img = _read_hr_image(os.path.join(hr_dir, hr_files[0]))
        if hr_img is None:
            continue
        base_name = os.path.splitext(hr_files[0])[0]

        lr_frames = _read_lr_frames(lr_dir, base_name, num_frames)
        if lr_frames is None:
            continue
        lr_sfsr_numpy = lr_frames[0]
        H, W = lr_sfsr_numpy.shape

        t_sfsr = _pad_to_tensor(lr_sfsr_numpy, pad_multiple).to(device)
        # Multi-frame input of shape (1, num_frames, 1, H', W').
        t_mfsr = torch.stack(
            [_pad_to_tensor(frame, pad_multiple).squeeze(0) for frame in lr_frames], dim=0
        ).unsqueeze(0).to(device)

        with torch.no_grad():
            # Tiled inference bounds peak memory for large inputs.
            if H > tile_threshold or W > tile_threshold:
                sr_sfsr_tensor = super_resolve_tiled(sfsr_model, t_sfsr, scale)
                sr_mfsr_tensor = super_resolve_tiled(mfsr_model, t_mfsr, scale)
            else:
                sr_sfsr_tensor = sfsr_model(t_sfsr)
                sr_mfsr_tensor = mfsr_model(t_mfsr)

        # Crop away the upscaled padding added by _pad_to_tensor.
        pad_h = (pad_multiple - H % pad_multiple) % pad_multiple
        pad_w = (pad_multiple - W % pad_multiple) % pad_multiple
        final_h = sr_sfsr_tensor.shape[2] - pad_h * scale
        final_w = sr_sfsr_tensor.shape[3] - pad_w * scale
        img_sfsr = _sr_to_uint8(sr_sfsr_tensor, final_h, final_w)
        img_mfsr = _sr_to_uint8(sr_mfsr_tensor, final_h, final_w)

        # Release the inference tensors before the metric pass.
        del t_sfsr, t_mfsr, sr_sfsr_tensor, sr_mfsr_tensor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if hr_img.shape != img_sfsr.shape:
            hr_img = cv2.resize(hr_img, (img_sfsr.shape[1], img_sfsr.shape[0]))

        # Baseline: score AND display the plain bicubic upsample of frame 1, so the
        # "Input (Bicubic)" column is a fair reference (no extra blur or gamma applied).
        input_bicubic = cv2.resize(lr_sfsr_numpy, (hr_img.shape[1], hr_img.shape[0]),
                                   interpolation=cv2.INTER_CUBIC)

        m_in = calculate_metrics(input_bicubic, hr_img, loss_fn_lpips, device)
        m_sf = calculate_metrics(img_sfsr, hr_img, loss_fn_lpips, device)
        m_mf = calculate_metrics(img_mfsr, hr_img, loss_fn_lpips, device)
        m_gt = (float('inf'), 1.0, 0.0)  # ground truth compared with itself

        metrics_log.append({
            'Name': base_name,
            'Input_PSNR': m_in[0], 'Input_SSIM': m_in[1], 'Input_LPIPS': m_in[2],
            'SFSR_PSNR': m_sf[0], 'SFSR_SSIM': m_sf[1], 'SFSR_LPIPS': m_sf[2],
            'MFSR_PSNR': m_mf[0], 'MFSR_SSIM': m_mf[1], 'MFSR_LPIPS': m_mf[2]
        })

        batch_buffer.append({
            'input_disp': input_bicubic,
            'sfsr': img_sfsr,
            'mfsr': img_mfsr,
            'gt': hr_img,
            'metrics_in': m_in,
            'metrics_sf': m_sf,
            'metrics_mf': m_mf,
            'metrics_gt': m_gt
        })

        if len(batch_buffer) == rows_per_collage:
            save_batch_plot(batch_buffer, batch_counter)
            batch_buffer = []
            batch_counter += 1

    # Flush the last, possibly partial, collage.
    if batch_buffer:
        save_batch_plot(batch_buffer, batch_counter)

    print(f"\n‚úÖ Processing complete! Collages saved to '{output_folder}/'")

    if metrics_log:
        df = pd.DataFrame(metrics_log)
        print("\n=== AVERAGE METRICS ===")
        print(f"Input : PSNR {df['Input_PSNR'].mean():.2f} | SSIM {df['Input_SSIM'].mean():.4f} | LPIPS {df['Input_LPIPS'].mean():.4f}")
        print(f"SFSR  : PSNR {df['SFSR_PSNR'].mean():.2f} | SSIM {df['SFSR_SSIM'].mean():.4f} | LPIPS {df['SFSR_LPIPS'].mean():.4f}")
        print(f"MFSR  : PSNR {df['MFSR_PSNR'].mean():.2f} | SSIM {df['MFSR_SSIM'].mean():.4f} | LPIPS {df['MFSR_LPIPS'].mean():.4f}")
        df.to_csv("/kaggle/working/final_metrics.csv", index=False)
        print("üìä Metrics saved to 'final_metrics.csv'")


# --- 4. EXECUTION ---
def _find_image_sets(split_dir):
    """Return, sorted, every directory under ``split_dir`` that has both 'ground_truth'
    and 'lr_bicubic' sub-directories."""
    return sorted(
        root for root, subdirs, _ in os.walk(split_dir)
        if 'ground_truth' in subdirs and 'lr_bicubic' in subdirs
    )


if __name__ == "__main__":
    final_image_sets = []

    # 1. TEST split: every discovered image set is evaluated.
    path_test = f'{dataset_root}/test'
    if os.path.exists(path_test):
        test_sets = _find_image_sets(path_test)
        final_image_sets += test_sets
        print(f"üìÑ Test Data: {len(test_sets)} folders")

    # 2. VALIDATION split: only the first 54 sorted image sets are added.
    path_val = f'{dataset_root}/val'
    if os.path.exists(path_val):
        val_sets = _find_image_sets(path_val)
        val_subset = val_sets[:54]
        final_image_sets += val_subset
        print(f"üìÑ Validation Data: {len(val_sets)} found, taking first {len(val_subset)}")

    print(f"üìÇ Total Processing: {len(final_image_sets)} images")

    if not final_image_sets:
        print("‚ö†Ô∏è No dataset found.")
    elif 'models' not in globals():
        # `models` is expected to come from the model-loading cell.
        print("‚ö†Ô∏è 'models' dictionary not found. Please load the models first.")
    else:
        process_all_images(final_image_sets, models, scale, device)

✅ Initialized image_sets_to_test. Found 60 test image sets.
