# Project: Video Super-Resolution with Progressive GANs

This notebook explores a generative adversarial network (GAN) approach to video super-resolution. My goal is to upscale low-resolution video frames to high resolution, leveraging both the spatial detail learned by GANs and temporal information across neighbouring frames. Due to computational constraints, I use a progressive training strategy (first 2x, then 4x) and an architecture optimized for efficiency, aiming for a clear proof of concept on a limited dataset.

In [None]:
!pip install gputil

# 1) Data sampling

In [None]:
import os
import shutil

# Source (full Kaggle dataset) and destination (working-copy subsample) paths
src_root = '/kaggle/input/rescale-reds-dataset'
dst_root = '/kaggle/working/reds-subsample'

# Number of video folders copied into the subsample (the first N in sorted order).
NUM_SAMPLE_FOLDERS = 30

if not os.path.isdir(src_root):
    raise FileNotFoundError(f"Source dataset not found at '{src_root}'.")

# Make destination directory if it doesn't exist
os.makedirs(dst_root, exist_ok=True)

# Sort so the same folders are selected on every run.
all_folders = sorted(
    f for f in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, f))
)
selected_folders = all_folders[:NUM_SAMPLE_FOLDERS]

for folder in selected_folders:
    # dirs_exist_ok=True (Python 3.8+) lets this cell be re-run without
    # FileExistsError for folders copied by a previous run.
    shutil.copytree(
        os.path.join(src_root, folder),
        os.path.join(dst_root, folder),
        dirs_exist_ok=True,
    )

print(f"Copied {len(selected_folders)} folders to {dst_root}")

# 2) Data downscaling

In [None]:
# --- Data Preprocessing Script (run this ONCE to prepare the dataset) ---
# Creates high/medium/low-resolution copies of every frame in the subsample
# made by the previous cell, laid out as
#   <multiscale_dst_root>/{high_res,med_res,low_res}/<video>/<frame>
# which is the layout `prepare_data_loaders` reads.
import os
import cv2
from PIL import Image # Import PIL for Image.Resampling

print("--- Starting Data Preprocessing ---")

# Read the sampled originals and write to the folder the data loaders expect.
# (Reading `src_root` / writing `dst_root` processed the whole Kaggle input and
# wrote into 'reds-subsample', so the loaders never found 'reds-multiscale'.)
multiscale_src_root = '/kaggle/working/reds-subsample'
multiscale_dst_root = '/kaggle/working/reds-multiscale'
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Resolutions to generate: name -> scale factor relative to the source frame
res_scales = {
    'high_res': 1.0,
    'med_res': 0.5,   # downsample by 2x
    'low_res': 0.25,  # downsample by 4x
}

if not os.path.isdir(multiscale_src_root):
    print(f"Warning: Source directory '{multiscale_src_root}' not found. Skipping preprocessing.")
    print("Please ensure your 'reds-subsample' data is correctly placed.")
else:
    for res in res_scales:
        os.makedirs(os.path.join(multiscale_dst_root, res), exist_ok=True)

    for folder in sorted(os.listdir(multiscale_src_root)):
        src_folder = os.path.join(multiscale_src_root, folder)
        if not os.path.isdir(src_folder):
            continue

        # One sub-folder per video in every resolution
        for res in res_scales:
            os.makedirs(os.path.join(multiscale_dst_root, res, folder), exist_ok=True)

        for img_file in sorted(os.listdir(src_folder)):
            img_path = os.path.join(src_folder, img_file)
            if not os.path.isfile(img_path) or not img_file.lower().endswith(IMAGE_EXTENSIONS):
                continue

            img = cv2.imread(img_path)
            if img is None:
                print(f"Warning: Could not read image {img_path}. Skipping.")
                continue

            h, w = img.shape[:2]
            for res, scale in res_scales.items():
                dst_img_path = os.path.join(multiscale_dst_root, res, folder, img_file)
                if scale == 1.0:
                    resized = img  # keep original
                else:
                    new_w, new_h = int(w * scale), int(h * scale)
                    if new_w <= 0 or new_h <= 0:
                        print(f"Warning: Skipping resize for {img_file} at scale {scale} due to zero or negative dimensions ({new_w}x{new_h}).")
                        continue
                    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
                # cv2.imwrite reports failure by returning False rather than raising.
                if not cv2.imwrite(dst_img_path, resized):
                    print(f"Warning: Could not write image {dst_img_path}.")

    print("Multi-resolution dataset created at:", multiscale_dst_root)
print("--- Data Preprocessing Complete ---")

# --- End of Data Preprocessing Script ---

#-----------------------------------------------------------------------------------------------------------------------------------------------------------


# --- Main Training Code ---


In [None]:
import math
import os
import warnings

import cv2
import GPUtil  # For checking GPU utilization
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from matplotlib.gridspec import GridSpec
from PIL import Image, ImageOps
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

# --- Optimization: Mixed Precision Training ---
from torch.cuda.amp import autocast, GradScaler

# 16:9 (width, height) sizes with 2x steps: LR -> MR -> HR.
# These should match the sizes the model expects and the preprocessing creates.
LR_SIZE = (128, 72)
MR_SIZE = (256, 144)
HR_SIZE = (512, 288)


class VideoSuperResDataset(Dataset):
    """Dataset of (LR frame window, MR target, HR target) samples.

    The three path lists are index-aligned: entry ``i`` of each list is the same
    frame at three resolutions. Sample ``idx`` consists of the ``sequence_length``
    LR frames centred on frame ``idx`` (stacked to ``(seq_len, C, H, W)``) plus
    the MR and HR versions of frame ``idx``. All images are resized to the module
    sizes and normalised from [0, 1] to [-1, 1].

    Window neighbours are taken only from the same directory (video) as the
    centre frame. Window positions beyond that video's first/last frame (or the
    ends of the list) repeat the nearest frame of the same video. This assumes
    each video's frames are contiguous and in temporal order in the lists.
    """

    def __init__(self, lr_frame_paths, mr_frame_paths, hr_frame_paths, sequence_length=3):
        """
        Args:
            lr_frame_paths: List of paths to low-res video frames
            mr_frame_paths: List of paths to medium-res video frames
            hr_frame_paths: List of paths to high-res video frames
            sequence_length: Number of consecutive frames to use (odd number)

        Raises:
            ValueError: if the three lists differ in length.
        """
        if not (len(lr_frame_paths) == len(mr_frame_paths) == len(hr_frame_paths)):
            raise ValueError("All frame path lists must have the same length.")

        self.lr_frames = lr_frame_paths
        self.mr_frames = mr_frame_paths
        self.hr_frames = hr_frame_paths
        self.sequence_length = sequence_length

        # An even length cannot be centred; __getitem__ then yields length + 1 frames.
        if self.sequence_length % 2 == 0:
            warnings.warn("Sequence length should ideally be an odd number for centered frame selection.")

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self._verify_aspect_ratios()

    def _verify_aspect_ratios(self):
        """Check that all configured resolutions are (within 0.01) 16:9."""
        for w, h in [LR_SIZE, MR_SIZE, HR_SIZE]:
            assert abs((w / h) - (16 / 9)) < 0.01, f"Size {w}x{h} is not 16:9 or close enough."

    def _load_and_resize(self, path, target_size):
        """Load `path` as RGB and resize it to `target_size` = (width, height).

        The image is resized directly with bicubic resampling, so a source that
        is not exactly 16:9 is slightly stretched (no padding is added). If the
        file cannot be opened, an error is printed and a black image of
        `target_size` is returned so one bad frame does not abort training.
        """
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy that remains valid afterwards.
            with Image.open(path) as src:
                img = src.convert("RGB")
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            return Image.new("RGB", target_size, (0, 0, 0))

        # Image.Resampling requires Pillow >= 9.1
        return img.resize(target_size, Image.Resampling.BICUBIC)

    def _neighbor_index(self, center, offset):
        """Return the index of the frame `offset` steps from `center`, kept within the centre's video."""
        idx = min(max(center + offset, 0), len(self.lr_frames) - 1)
        center_dir = os.path.dirname(self.lr_frames[center])
        step = 1 if idx < center else -1
        # Walk back toward the centre until the index is inside the same video
        # folder, which replicates the video's first/last frame at its boundary.
        while idx != center and os.path.dirname(self.lr_frames[idx]) != center_dir:
            idx += step
        return idx

    def __len__(self):
        # One sample per frame; window positions past a boundary are padded in __getitem__.
        return len(self.hr_frames)

    def __getitem__(self, idx):
        half_seq = self.sequence_length // 2

        lr_sequence = []
        for offset in range(-half_seq, half_seq + 1):
            path = self.lr_frames[self._neighbor_index(idx, offset)]
            lr_sequence.append(self.transform(self._load_and_resize(path, LR_SIZE)))

        # MR and HR targets are always the centre frame.
        mr_target = self.transform(self._load_and_resize(self.mr_frames[idx], MR_SIZE))
        hr_target = self.transform(self._load_and_resize(self.hr_frames[idx], HR_SIZE))

        lr_sequence = torch.stack(lr_sequence, dim=0)  # Shape: (seq_len, C, H, W)

        return lr_sequence, mr_target, hr_target

# 3) Data loaders

In [None]:


def prepare_data_loaders(batch_size=16, num_workers=4, dst_root='/kaggle/working/reds-multiscale',
                         val_split_ratio=0.1, seed=42):
    """Build train/validation DataLoaders from the multi-scale frame folders.

    Expects ``<dst_root>/{high_res,med_res,low_res}/<video>/<frame>``. Frames
    without matching LR/MR files are skipped with a warning.

    The split holds out whole *videos* rather than individual frames. Each
    split therefore keeps complete, temporally ordered videos, so the dataset's
    LR window contains real neighbouring frames. It also means near-identical
    adjacent frames cannot appear in both train and validation. With a single
    video, the last ``val_split_ratio`` of its frames form the validation set.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Number of DataLoader worker processes.
        dst_root: Root of the multi-scale dataset.
        val_split_ratio: Fraction of videos held out for validation. When
            positive, at least one video is held out and one kept for training.
        seed: Seed of the local RNG that chooses the validation videos.

    Returns:
        (train_loader, val_loader)

    Raises:
        RuntimeError: if the dataset directory is missing or contains no usable frames.
    """
    video_subfolders_path = os.path.join(dst_root, 'high_res')
    if not os.path.exists(video_subfolders_path):
        raise RuntimeError(f"Multi-resolution dataset not found at '{dst_root}'. Please run the preprocessing script first.")

    videos = _collect_video_frames(dst_root)
    num_total_frames = sum(len(frames) for _, frames in videos)
    print(f"Found {num_total_frames} HR frames across all videos.")
    if num_total_frames == 0:
        raise RuntimeError(f"No image files found in {video_subfolders_path}. Make sure preprocessing was run and paths are correct.")

    train_videos, val_videos = _split_videos(videos, val_split_ratio, seed)
    train_lr_frames, train_mr_frames, train_hr_frames = _flatten_frames(train_videos)
    val_lr_frames, val_mr_frames, val_hr_frames = _flatten_frames(val_videos)

    print(f"Train frames: {len(train_hr_frames)}, Validation frames: {len(val_hr_frames)}")

    train_dataset = VideoSuperResDataset(
        lr_frame_paths=train_lr_frames,
        mr_frame_paths=train_mr_frames,
        hr_frame_paths=train_hr_frames,
        sequence_length=3  # matches the generator's 3-frame temporal window
    )
    val_dataset = VideoSuperResDataset(
        lr_frame_paths=val_lr_frames,
        mr_frame_paths=val_mr_frames,
        hr_frame_paths=val_hr_frames,
        sequence_length=3
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # shuffles samples; each sample's LR window is still contiguous
        num_workers=num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    print("Real data loaders created.")
    return train_loader, val_loader


def _collect_video_frames(dst_root):
    """Return ``[(video_name, [(lr_path, mr_path, hr_path), ...]), ...]`` sorted by name and frame."""
    hr_root = os.path.join(dst_root, 'high_res')
    videos = []
    missing = 0
    for subfolder in sorted(os.listdir(hr_root)):
        hr_dir = os.path.join(hr_root, subfolder)
        if not os.path.isdir(hr_dir):
            continue  # Skip stray files
        mr_dir = os.path.join(dst_root, 'med_res', subfolder)
        lr_dir = os.path.join(dst_root, 'low_res', subfolder)

        triples = []
        frame_files = sorted(f for f in os.listdir(hr_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
        for frame_file in frame_files:
            lr_path = os.path.join(lr_dir, frame_file)
            mr_path = os.path.join(mr_dir, frame_file)
            if not (os.path.isfile(lr_path) and os.path.isfile(mr_path)):
                missing += 1
                continue
            triples.append((lr_path, mr_path, os.path.join(hr_dir, frame_file)))
        if triples:
            videos.append((subfolder, triples))

    if missing:
        print(f"Warning: skipped {missing} HR frames without matching LR/MR files.")
    return videos


def _split_videos(videos, val_split_ratio, seed):
    """Split ``[(name, frames), ...]`` into (train, val) lists, holding out whole videos."""
    if len(videos) == 1:
        # A single video cannot be split by video; hold out its temporal tail.
        name, frames = videos[0]
        num_val = int(len(frames) * val_split_ratio)
        cut = len(frames) - num_val
        train = [(name, frames[:cut])] if cut else []
        val = [(name, frames[cut:])] if num_val else []
        return train, val

    # Local RNG: reproducible without mutating NumPy's global random state.
    order = np.random.default_rng(seed).permutation(len(videos))
    num_val = int(len(videos) * val_split_ratio)
    if val_split_ratio > 0:
        num_val = min(max(num_val, 1), len(videos) - 1)
    val_ids = set(order[:num_val].tolist())
    train = [v for i, v in enumerate(videos) if i not in val_ids]
    val = [v for i, v in enumerate(videos) if i in val_ids]
    return train, val


def _flatten_frames(videos):
    """Return parallel (lr, mr, hr) path lists with each video's frames kept contiguous and ordered."""
    triples = [triple for _, frames in videos for triple in frames]
    return ([t[0] for t in triples], [t[1] for t in triples], [t[2] for t in triples])



# 4)  Generator Architecture


In [None]:

class VideoSuperResGenerator(nn.Module):
    """Spatio-temporal super-resolution generator.

    Input is a window of frames shaped ``(batch, seq_len, 3, H, W)``.
    - Two 3-D convolutions mix information across time and space.
    - The centre time step is kept.
    - It is refined by ``num_residual_blocks`` residual blocks.
    - It is upsampled by ``scale_factor`` with a sub-pixel convolution.
    The output is a 3-channel image passed through ``tanh``.

    Progressive training: ``up2`` and ``conv3`` start as ``None``. Once both are
    assigned (stage 2), the features pass through ``up2`` and ``conv3`` becomes
    the output head instead of ``conv2``.
    """

    def __init__(self, scale_factor=2, num_residual_blocks=8):
        super().__init__()
        width = 64  # feature channels used throughout the 2-D trunk

        # 3-D convolutions over (time, height, width)
        self.temporal_conv1 = nn.Conv3d(3, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.temporal_conv2 = nn.Conv3d(32, width, kernel_size=(3, 3, 3), padding=(1, 1, 1))

        # 2-D trunk applied to the selected time step
        self.conv1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        trunk = [ResidualBlock(width) for _ in range(num_residual_blocks)]
        self.residual_blocks = nn.Sequential(*trunk)

        # Conv to width * r^2 channels; PixelShuffle(r) folds them back to `width`.
        self.up1 = UpsampleBlock(width, width * (scale_factor ** 2), scale_factor=scale_factor)
        self.conv2 = nn.Conv2d(width, 3, kernel_size=3, padding=1)

        # Stage-2 head, attached from outside when training progresses
        self.up2 = None
        self.conv3 = None

    def forward(self, x):
        # Conv3d expects (batch, C, time, H, W)
        feats = x.permute(0, 2, 1, 3, 4)
        for temporal in (self.temporal_conv1, self.temporal_conv2):
            feats = F.relu(temporal(feats))

        centre = feats.shape[2] // 2
        feats = feats[:, :, centre, :, :]  # (batch, C, H, W)

        feats = F.relu(self.conv1(feats))
        feats = self.up1(self.residual_blocks(feats))

        stage2_ready = self.up2 is not None and self.conv3 is not None
        if not stage2_ready:
            return torch.tanh(self.conv2(feats))
        return torch.tanh(self.conv3(self.up2(feats)))

class ResidualBlock(nn.Module):
    """Residual unit of two 3x3 conv + BatchNorm layers with an identity skip.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``; channel count and
    spatial size are unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        return F.relu(branch + x)

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: 3x3 conv, ``PixelShuffle(scale_factor)``, then ReLU.

    ``out_channels`` should be ``C * scale_factor ** 2``. The result then has
    ``C`` channels and is ``scale_factor`` times larger in each spatial dimension.
    """

    def __init__(self, in_channels, out_channels, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        return F.relu(self.shuffle(self.conv(x)))


#  Discriminator 


In [None]:

class VideoSuperResDiscriminator(nn.Module):
    """Convolutional real/fake classifier for single RGB frames.

    Returns one raw **logit** per image, shaped ``(batch, 1)``. There is
    deliberately no final Sigmoid, for two reasons:
    - The training loop uses ``nn.BCEWithLogitsLoss``, which applies the
      sigmoid itself and, unlike ``BCELoss``, is safe under autocast.
    - A Sigmoid here as well squashed the loss inputs into (0, 1). The
      effective probabilities could then only span roughly (0.5, 0.73),
      which distorted the adversarial gradients.

    Removing the parameter-free Sigmoid leaves ``state_dict`` keys unchanged.

    Args:
        input_height, input_width: Accepted for API compatibility; unused,
            because ``AdaptiveAvgPool2d(1)`` makes the network size-agnostic.
    """

    def __init__(self, input_height, input_width):
        super(VideoSuperResDiscriminator, self).__init__()

        self.model = nn.Sequential(
            # Input: 3 x H x W
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # Global pooling -> 1x1 "dense" head producing a single logit
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(256, 512, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=1),
        )

    def forward(self, x):
        # (batch, 1, 1, 1) -> (batch, 1) logits
        return self.model(x).view(-1, 1)

# VGG Perceptual Loss


In [None]:


class VGGPerceptualLoss(nn.Module):
    """Frozen VGG19 feature extractor used for perceptual (feature-space) losses.

    ``forward`` returns the activations after relu1_1, relu2_1, relu3_1 and
    relu4_1 of ``vgg19().features``, i.e. the outputs of the four slices below.

    Inputs are expected in the dataset's [-1, 1] range (``Normalize(0.5, 0.5)``
    targets, ``tanh`` generator outputs). When ``normalize_input`` is true,
    they are mapped to [0, 1] and standardised with the ImageNet mean/std that
    the pretrained torchvision weights expect.

    Args:
        requires_grad: If False (default) the VGG weights are frozen.
        normalize_input: Apply the [-1, 1] -> ImageNet normalisation (default True).
    """

    def __init__(self, requires_grad=False, normalize_input=True):
        super(VGGPerceptualLoss, self).__init__()
        vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()

        # Slice ends: relu1_1 (index 1), relu2_1 (6), relu3_1 (11), relu4_1 (20).
        for idx in range(2):
            self.slice1.add_module(str(idx), vgg[idx])
        for idx in range(2, 7):
            self.slice2.add_module(str(idx), vgg[idx])
        for idx in range(7, 12):
            self.slice3.add_module(str(idx), vgg[idx])
        for idx in range(12, 21):
            self.slice4.add_module(str(idx), vgg[idx])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        self.normalize_input = normalize_input
        # Non-persistent buffers follow .to(device) but stay out of state_dict().
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def _prepare_input(self, x):
        """Map a [-1, 1] RGB batch to ImageNet-normalised VGG input (identity if disabled)."""
        if not self.normalize_input:
            return x
        return ((x + 1) / 2 - self.mean) / self.std

    def forward(self, x):
        x = self._prepare_input(x)
        with autocast():
            # Each slice starts with a conv, so the in-place ReLUs inside the
            # next slice never overwrite an activation already returned.
            h_relu1_1 = self.slice1(x)
            h_relu2_1 = self.slice2(h_relu1_1)
            h_relu3_1 = self.slice3(h_relu2_1)
            h_relu4_1 = self.slice4(h_relu3_1)
        return h_relu1_1, h_relu2_1, h_relu3_1, h_relu4_1

def perceptual_loss(vgg, real, fake):
    """Sum of L1 distances between matching feature maps of `real` and `fake`.

    `vgg` is any callable returning a sequence of feature tensors (for example
    ``VGGPerceptualLoss``). Features are paired positionally. Returns the
    integer 0 when the extractor yields no features.
    """
    real_feats = vgg(real)
    fake_feats = vgg(fake)
    return sum(F.l1_loss(r_map, f_map) for r_map, f_map in zip(real_feats, fake_feats))


#  Training Loop


In [None]:
def train_stage_with_tracking(stage, generator, discriminator, train_loader, val_loader,
                              device, epochs=50, lr=1e-4, save_dir="checkpoints",
                              validate_every=2):
    """Adversarially train `generator` and `discriminator` for one progressive stage.

    Stage 1 trains against the MR targets; any other stage uses the HR targets.
    Each batch performs two updates with mixed precision and one shared GradScaler:
    - one discriminator update (real vs. detached fake);
    - one generator update (adversarial + pixel L1, plus the VGG perceptual
      term when the VGG weights can be loaded).

    Every `validate_every` epochs, and always after the last epoch:
    - the generator's per-sample mean L1 on `val_loader` is recorded;
    - one sample output is written to `save_dir`;
    - a checkpoint is saved there.

    Returns:
        dict with 'g_losses' / 'd_losses' (mean per epoch), 'val_losses'
        (one per validation run) and 'val_epochs' (1-based epoch of each run).

    Raises:
        RuntimeError: if `train_loader` yields no batches.
    """
    os.makedirs(save_dir, exist_ok=True)
    validate_every = max(1, validate_every)

    g_losses = []
    d_losses = []
    val_losses = []
    val_epochs = []

    g_optim = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_optim = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    scaler = GradScaler()

    try:
        vgg = VGGPerceptualLoss().to(device)
        vgg.eval()
    except Exception as e:
        # Best effort: without VGG, fall back to adversarial + pixel loss only.
        print(f"Warning: Could not load VGG model: {e}")
        vgg = None

    criterion_adv = nn.BCEWithLogitsLoss()  # expects raw discriminator logits
    criterion_pix = nn.L1Loss()

    print(f"Starting training for Stage {stage}...")
    for epoch in range(epochs):
        avg_g_loss, avg_d_loss = _train_one_epoch(
            stage, epoch, epochs, generator, discriminator, train_loader, device,
            g_optim, d_optim, scaler, vgg, criterion_adv, criterion_pix,
        )
        g_losses.append(avg_g_loss)
        d_losses.append(avg_d_loss)

        if (epoch + 1) % validate_every != 0 and epoch + 1 != epochs:
            continue

        avg_val_pix_loss = _validate(stage, epoch, generator, val_loader, device, criterion_pix, save_dir)
        if avg_val_pix_loss is None:
            print(f"Stage {stage} | Epoch {epoch+1}/{epochs} | G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f} | Val Loss: n/a (empty val_loader)")
        else:
            val_losses.append(avg_val_pix_loss)
            val_epochs.append(epoch + 1)
            print(f"Stage {stage} | Epoch {epoch+1}/{epochs} | G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f} | Val Loss: {avg_val_pix_loss:.4f}")

        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optim_state_dict': g_optim.state_dict(),
            'd_optim_state_dict': d_optim.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'g_losses': g_losses,
            'd_losses': d_losses,
            'val_losses': val_losses,
            'val_epochs': val_epochs,
        }, f"{save_dir}/stage{stage}_checkpoint_epoch{epoch+1}.pth")

    return {'g_losses': g_losses, 'd_losses': d_losses, 'val_losses': val_losses, 'val_epochs': val_epochs}


def _train_one_epoch(stage, epoch, epochs, generator, discriminator, train_loader, device,
                     g_optim, d_optim, scaler, vgg, criterion_adv, criterion_pix):
    """Run one epoch of alternating D/G updates and return (mean G loss, mean D loss)."""
    generator.train()
    discriminator.train()

    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    num_batches = 0

    for i, (lr_seq, mr_target, hr_target) in enumerate(train_loader):
        lr_seq = lr_seq.to(device)
        target = (mr_target if stage == 1 else hr_target).to(device)

        # --- Discriminator update: real -> 1, fake -> 0 ---
        d_optim.zero_grad()
        with autocast():
            # Single generator forward per batch: used detached here and, with
            # its graph, in the generator update (generator weights do not
            # change in between).
            fake = generator(lr_seq)
            real_pred = discriminator(target)
            real_loss = criterion_adv(real_pred, torch.ones_like(real_pred))
            fake_pred = discriminator(fake.detach())
            fake_loss = criterion_adv(fake_pred, torch.zeros_like(fake_pred))
            d_loss = (real_loss + fake_loss) / 2
        scaler.scale(d_loss).backward()
        scaler.step(d_optim)

        # --- Generator update against the freshly updated discriminator ---
        g_optim.zero_grad()
        with autocast():
            fake_pred = discriminator(fake)
            adv_loss = criterion_adv(fake_pred, torch.ones_like(fake_pred))
            pix_loss = criterion_pix(fake, target)
            if vgg is not None:
                perc_loss = perceptual_loss(vgg, target, fake)
                g_loss = 0.1 * adv_loss + 0.8 * perc_loss + 0.1 * pix_loss
            else:
                g_loss = 0.5 * adv_loss + 0.5 * pix_loss
        scaler.scale(g_loss).backward()
        scaler.step(g_optim)

        # One scaler update per iteration, after every optimizer has stepped.
        scaler.update()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        num_batches += 1

        if (i + 1) % 100 == 0:
            print(f"Stage {stage} | Epoch {epoch+1}/{epochs} | Batch {i+1}/{len(train_loader)} | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; cannot compute epoch losses.")
    return epoch_g_loss / num_batches, epoch_d_loss / num_batches


def _validate(stage, epoch, generator, val_loader, device, criterion_pix, save_dir):
    """Return the per-sample mean L1 over `val_loader` (None if empty) and save one sample image."""
    generator.eval()
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for val_i, (val_lr, val_mr, val_hr) in enumerate(val_loader):
            val_lr = val_lr.to(device)
            val_target = (val_mr if stage == 1 else val_hr).to(device)

            with autocast():
                val_fake = generator(val_lr)
                val_pix_loss = criterion_pix(val_fake, val_target)

            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += val_pix_loss.item() * val_lr.size(0)
            num_samples += val_lr.size(0)

            if val_i == 0:
                # Map tanh output [-1, 1] to [0, 1]; float() avoids saving float16.
                save_image((val_fake[0].float() + 1) / 2, f"{save_dir}/stage{stage}_epoch{epoch+1}.png")

    return loss_sum / num_samples if num_samples else None



# Plotting

In [None]:


def plot_training_curves(stage1_losses, stage2_losses, save_path="training_curves.png"):
    """Plot per-epoch G/D losses and validation L1 for both training stages.

    Each ``*_losses`` dict is the result of ``train_stage_with_tracking``:
    ``g_losses`` and ``d_losses`` hold one value per epoch, and ``val_losses``
    holds one value per validation run. Validation values are plotted at the
    epochs in the optional ``val_epochs`` list. Without that key they are
    assumed to come from epochs 2, 4, ... (every second epoch). Previously
    they were drawn at 1, 2, 3, ..., which misplaced every point.

    The figure is saved to `save_path`, shown, and then closed.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    _plot_stage_curves(axes[0, 0], axes[0, 1], stage1_losses,
                       'Stage 1: LR→MR Training Losses', 'Stage 1: Validation Loss')
    _plot_stage_curves(axes[1, 0], axes[1, 1], stage2_losses,
                       'Stage 2: MR→HR Training Losses', 'Stage 2: Validation Loss')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # Close figure to free memory


def _plot_stage_curves(loss_ax, val_ax, losses, loss_title, val_title):
    """Draw one stage's G/D losses on `loss_ax` and its validation L1 on `val_ax`."""
    for key, label, color in (('g_losses', 'Generator', 'blue'), ('d_losses', 'Discriminator', 'red')):
        values = losses[key]
        loss_ax.plot(range(1, len(values) + 1), values, label=label, color=color)
    loss_ax.set_title(loss_title)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    val_losses = losses.get('val_losses')
    if not val_losses:
        return
    # Validation does not run every epoch: prefer the recorded epochs, else
    # assume the default every-second-epoch schedule.
    val_epochs = losses.get('val_epochs') or [2 * (k + 1) for k in range(len(val_losses))]
    val_ax.plot(val_epochs, val_losses, label='Validation L1', color='green')
    val_ax.set_title(val_title)
    val_ax.set_xlabel('Epoch')
    val_ax.set_ylabel('L1 Loss')
    val_ax.legend()
    val_ax.grid(True)

def create_comparison_grid(generator, val_loader, device, num_samples=4, save_path="final_results.png"):
    """Save and show a grid of LR input | MR target | generated output | HR target.

    Row ``i`` uses the first sample of the ``i``-th validation batch, and the LR
    column shows the centre frame of the input window. Every image is mapped
    from [-1, 1] to [0, 1] and clipped for display. If the loader runs out
    early, a message is printed and the remaining rows stay empty.
    """
    generator.eval()

    def to_display(tensor):
        # (C, H, W) in [-1, 1] -> (H, W, C) float32 in [0, 1]
        arr = tensor.cpu().numpy().transpose(1, 2, 0)
        return np.clip(((arr + 1) / 2).astype(np.float32), 0, 1)

    fig = plt.figure(figsize=(20, 5 * num_samples))
    gs = GridSpec(num_samples, 4, figure=fig)

    batches = iter(val_loader)
    with torch.no_grad():
        for row in range(num_samples):
            try:
                lr_seq, mr_target, hr_target = next(batches)
            except StopIteration:
                print(f"Not enough samples in val_loader to create {num_samples} comparison grids. Created {row}.")
                break

            lr_seq = lr_seq.to(device)
            with autocast():
                generated = generator(lr_seq)

            panels = (
                ('LR Input', lr_seq[0, lr_seq.shape[1] // 2]),
                ('MR Target', mr_target[0]),
                ('Generated', generated[0]),
                ('HR Target', hr_target[0]),
            )
            for col, (label, tensor) in enumerate(panels):
                img = to_display(tensor)
                ax = fig.add_subplot(gs[row, col])
                ax.imshow(img)
                ax.set_title(f'{label} {img.shape[1]}x{img.shape[0]}')
                ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # Close figure to free memory

def calculate_metrics(generator, val_loader, device, stage=2):
    """Return the mean (PSNR in dB, SSIM) of generator outputs over `val_loader`.

    Outputs and targets are mapped from [-1, 1] to [0, 1] and clipped. PSNR
    therefore uses a peak value of 1.0, and SSIM uses ``data_range=1.0``
    computed over the colour channels (``channel_axis=2``).

    Args:
        stage: 1 to score against the MR targets, or 2 (default) against the HR
            targets. This must match the generator's output resolution.

    Returns:
        (mean_psnr, mean_ssim) as floats. Both are NaN if the loader yields no
        samples. A pixel-perfect output contributes ``inf`` PSNR.
    """
    generator.eval()
    psnr_scores = []
    ssim_scores = []

    with torch.no_grad():
        for lr_seq, mr_target, hr_target in val_loader:
            lr_seq = lr_seq.to(device)
            target = mr_target if stage == 1 else hr_target

            with autocast():
                generated = generator(lr_seq)

            # Score in float32 on the CPU (autocast outputs may be float16);
            # layout (batch, H, W, C) in [0, 1].
            gen_batch = np.clip((generated.float().cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)
            tgt_batch = np.clip((target.float().cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)

            for gen_img, target_img in zip(gen_batch, tgt_batch):
                mse = float(np.mean((gen_img - target_img) ** 2))
                psnr_scores.append(float('inf') if mse == 0 else 20 * math.log10(1.0 / math.sqrt(mse)))
                # channel_axis replaces the deprecated (and later removed) `multichannel` flag.
                ssim_scores.append(ssim(target_img, gen_img, data_range=1.0, channel_axis=2))

    if not psnr_scores:
        return float('nan'), float('nan')
    return float(np.mean(psnr_scores)), float(np.mean(ssim_scores))



# Updated main function with plotting


In [None]:
def main():
    """Run both progressive training stages, then plot, evaluate and save the models."""
    GPUtil.showUtilization()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Training parameters (kept small for quick testing)
    stage1_epochs = 2
    stage2_epochs = 2

    BATCH_SIZE = 16
    NUM_WORKERS = 4

    train_loader, val_loader = prepare_data_loaders(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # Verify data loading
    try:
        lr_seq, mr_t, hr_t = next(iter(train_loader))
        print("\nData verification:")
        print(f"LR sequence shape: {lr_seq.shape}")
        print(f"MR target shape: {mr_t.shape}")
        print(f"HR target shape: {hr_t.shape}")
        print(f"Batch size: {lr_seq.shape[0]}")
    except StopIteration:
        print("Error: train_loader is empty. Check your dataset and data loading.")
        return

    # --- Stage 1: LR to MR ---
    print("=== Starting Stage 1: LR to MR ===")
    gen_stage1 = VideoSuperResGenerator(scale_factor=2).to(device)
    disc_stage1 = VideoSuperResDiscriminator(MR_SIZE[1], MR_SIZE[0]).to(device)

    stage1_losses = train_stage_with_tracking(
        stage=1,
        generator=gen_stage1,
        discriminator=disc_stage1,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=stage1_epochs,
        lr=1e-4,
        save_dir="stage1_checkpoints"
    )

    # --- Stage 2: MR to HR ---
    print("\n=== Starting Stage 2: MR to HR ===")
    gen_stage2 = VideoSuperResGenerator(scale_factor=2).to(device)

    # The stage-2 head (up2/conv3) is still None here, so both generators have
    # identical parameter sets. A strict load copies every stage-1 weight and
    # fails loudly on any mismatch instead of silently skipping keys.
    gen_stage2.load_state_dict(gen_stage1.state_dict())

    # Attach the second 2x upsampling stage and its output head.
    gen_stage2.up2 = UpsampleBlock(64, 64 * 4, scale_factor=2).to(device)
    gen_stage2.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1).to(device)

    disc_stage2 = VideoSuperResDiscriminator(HR_SIZE[1], HR_SIZE[0]).to(device)

    stage2_losses = train_stage_with_tracking(
        stage=2,
        generator=gen_stage2,
        discriminator=disc_stage2,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=stage2_epochs,
        lr=1e-4,
        save_dir="stage2_checkpoints"
    )

    # === FINAL RESULTS AND PLOTTING ===
    print("\n" + "="*50)
    print("GENERATING FINAL RESULTS")
    print("="*50)

    print("Plotting training curves...")
    plot_training_curves(stage1_losses, stage2_losses)

    print("Creating comparison grid...")
    create_comparison_grid(gen_stage2, val_loader, device)

    print("Calculating final metrics...")
    try:
        psnr, ssim_score = calculate_metrics(gen_stage2, val_loader, device)
        print(f"Final PSNR: {psnr:.2f} dB")
        print(f"Final SSIM: {ssim_score:.4f}")
    except Exception as e:
        # Metrics are a reporting step; a failure here should not lose the trained models.
        print(f"Could not calculate metrics: {e}")

    print("Saving final models...")
    torch.save(gen_stage2.state_dict(), "final_generator.pth")
    torch.save(disc_stage2.state_dict(), "final_discriminator.pth")

    print("\n" + "="*50)
    print("TRAINING COMPLETE!")
    print("Check the following files:")
    print("- training_curves.png (loss plots)")
    print("- final_results.png (LR→MR→HR comparisons)")
    print("- final_generator.pth (trained model)")
    print("="*50)

if __name__ == '__main__':
    main()