# RFB-ESRGAN for BigEarthNet Super-Resolution

## Setup Instructions

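**IMPORTANT:** Run Cell 1 ONCE to install dependencies, then:
1. **Restart the kernel** (Kernel → Restart in Kaggle)
2. **Do not reinstall:** comment out the `!pip` lines in Cell 1 and re-run it so that the imports, the WandB setup, and `device` are defined, then continue with Cell 2 onwards

The restart is necessary because the running kernel keeps the previously loaded PyTorch/torchvision builds in memory; until it restarts, the newly installed versions are not picked up and the CUDA version mismatch error shown under Cell 1 persists.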

In [None]:
# Cell 1: Imports and Setup

# Install required dependencies (IPython shell magics; run once per session).
# Uninstall the preinstalled torch/torchvision/torchaudio builds, then install
# torch and torchvision together from the CUDA 12.1 wheel index so that both
# packages are built against the same CUDA version.
# NOTE(review): torchaudio is uninstalled here but not reinstalled.
!pip uninstall -y torch torchvision torchaudio
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
# Remaining third-party packages (logging, progress bars, image I/O, metrics).
!pip install -q wandb tqdm pillow numpy opencv-python scikit-image

print("✓ All dependencies installed successfully!")
# NOTE(review): the cells below rely on the imports, WandB setup and `device`
# defined later in this cell, so this cell's Python part must still be run
# after a kernel restart.
print("⚠️ If you see CUDA version mismatch errors, restart the kernel and run from Cell 2")

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
from tqdm import tqdm
import wandb
from collections import OrderedDict
import time

# WandB setup
# SECURITY: the API key must never be committed to the notebook. It is read
# from the WANDB_API_KEY environment variable; when that is unset, `key=None`
# lets wandb fall back to its own credential lookup / interactive prompt.
# Any key that was previously hard-coded here should be revoked and rotated.
wandb.login(key=os.environ.get('WANDB_API_KEY'))

# Start the run and record every hyper-parameter used by the cells below
# (they are read back through `wandb.config`).
wandb.init(
    project='RFB-ESRGAN-BigEarthNet',
    config={
        'upscale_factor': 16,
        'lr_size': 32,
        'hr_size': 512,
        'batch_size': 8,
        'stage1_epochs': 50,
        'stage2_iters': 100000,
        'stage1_lr': 2e-4,
        'stage2_lr': 1e-4,
        'lambda_pix': 10,
        'lambda_vgg': 1,
        'lambda_adv': 5e-3,
        'num_rrdb': 16,
        'num_rrfdb': 8,
        'ensemble_models': 10
    }
)

# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Report the runtime environment so version/CUDA mismatches are easy to spot.
print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')
print(f'Torchvision version: {torchvision.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'Number of GPUs available: {torch.cuda.device_count()}')
    # One line per visible GPU, indexed the same way torch addresses them.
    for i in range(torch.cuda.device_count()):
        print(f'GPU {i}: {torch.cuda.get_device_name(i)}')

# Set random seeds
# Fixed seeds for torch's and NumPy's global random number generators.
torch.manual_seed(42)
np.random.seed(42)

Found existing installation: torch 2.6.0+cu118
Uninstalling torch-2.6.0+cu118:
Uninstalling torch-2.6.0+cu118:
  Successfully uninstalled torch-2.6.0+cu118
  Successfully uninstalled torch-2.6.0+cu118
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
  Successfully uninstalled torchaudio-2.6.0+cu124
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m780.5/780.5 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/780.5 MB[0m [31m?[0m eta [36m-:--:--[0m
[

RuntimeError: Detected that PyTorch and torchvision were compiled with different CUDA major versions. PyTorch has CUDA Version=11.8 and torchvision has CUDA Version=12.4. Please reinstall the torchvision that matches your PyTorch install.

In [None]:
# Cell 2: Model Architecture - RFB, RRDB, RRFDB Blocks

class RFB(nn.Module):
    """Receptive Field Block: three pooled/dilated branches fused by a 1x1 conv.

    Each branch applies a stride-1 average pool, a 1x1 channel-reducing
    convolution and a dilated 3x3 convolution; padding is chosen so the
    spatial size is preserved. The branch outputs (16 + 24 + 24 = 64
    channels) are concatenated and projected back to ``in_channels``.

    Args:
        in_channels: Number of input (and output) feature channels.
    """

    def __init__(self, in_channels=64):
        super().__init__()
        # Pool kernel and dilation grow together from branch to branch.
        self.branch1 = self._make_branch(in_channels, 16, pool=3, dilation=1)
        self.branch2 = self._make_branch(in_channels, 24, pool=5, dilation=2)
        self.branch3 = self._make_branch(in_channels, 24, pool=7, dilation=3)

        # Fuse the 64 concatenated branch channels back to in_channels.
        self.conv_concat = nn.Sequential(
            nn.Conv2d(64, in_channels, 1, 1, 0),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @staticmethod
    def _make_branch(in_channels, width, pool, dilation):
        """Build one AvgPool -> 1x1 conv -> dilated 3x3 conv branch."""
        return nn.Sequential(
            nn.AvgPool2d(pool, stride=1, padding=pool // 2),
            nn.Conv2d(in_channels, width, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, 3, 1, padding=dilation, dilation=dilation),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all branches on ``x`` and return the fused feature map."""
        branches = (self.branch1, self.branch2, self.branch3)
        fused = torch.cat([branch(x) for branch in branches], dim=1)
        return self.conv_concat(fused)


class DenseBlock(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each of the first four 3x3 convolutions sees the block input concatenated
    with every earlier convolution's activated output and adds ``gc`` new
    channels. The fifth maps the full stack back to ``nf`` channels; its
    output is scaled by 0.2 and added to the input.

    Args:
        nf: Number of input/output feature channels.
        gc: Growth channels produced by each intermediate convolution.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return ``x + 0.2 * conv5(dense features)``."""
        # `growth` collects the activated outputs of conv1..conv4 in order.
        growth = [self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            stacked = torch.cat([x, *growth], dim=1)
            growth.append(self.lrelu(conv(stacked)))

        # The last convolution is not followed by an activation.
        residual = self.conv5(torch.cat([x, *growth], dim=1))
        return residual * 0.2 + x  # Residual scaling


class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three DenseBlocks plus an outer skip.

    Args:
        nf: Number of feature channels flowing through the block.
        gc: Growth channels passed to each DenseBlock.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.db1 = DenseBlock(nf, gc)
        self.db2 = DenseBlock(nf, gc)
        self.db3 = DenseBlock(nf, gc)

    def forward(self, x):
        """Chain the three dense blocks and add the scaled result to ``x``."""
        feat = x
        for dense in (self.db1, self.db2, self.db3):
            feat = dense(feat)
        return feat * 0.2 + x  # Residual scaling


class RRFDB(nn.Module):
    """Residual Receptive Field block: five RFBs in sequence plus a scaled skip.

    NOTE(review): the RFBs are applied as a plain chain; there are no dense
    (concatenated or summed) connections between them, only the outer
    residual connection around the whole chain.

    Args:
        nf: Number of feature channels used by every RFB.
    """

    def __init__(self, nf=64):
        super().__init__()
        self.rfb1 = RFB(nf)
        self.rfb2 = RFB(nf)
        self.rfb3 = RFB(nf)
        self.rfb4 = RFB(nf)
        self.rfb5 = RFB(nf)

    def forward(self, x):
        """Pass ``x`` through the five RFBs and add the scaled result to ``x``."""
        feat = x
        for rfb in (self.rfb1, self.rfb2, self.rfb3, self.rfb4, self.rfb5):
            feat = rfb(feat)
        return feat * 0.2 + x  # Residual scaling


class Generator(nn.Module):
    """RFB-ESRGAN generator producing a x16 super-resolved RGB image.

    Pipeline: 3x3 input conv -> ``num_rrdb`` RRDBs -> ``num_rrfdb`` RRFDBs ->
    one RFB -> four x2 upsampling steps (nearest, PixelShuffle, nearest,
    PixelShuffle) -> two 3x3 convs ending in Tanh, so outputs lie in [-1, 1].

    Args:
        num_rrdb: Number of RRDB blocks in the first trunk.
        num_rrfdb: Number of RRFDB blocks in the second trunk.
        nf: Number of feature channels used throughout.
    """

    def __init__(self, num_rrdb=16, num_rrfdb=8, nf=64):
        super().__init__()
        # Lift the 3-channel input to nf feature channels.
        self.conv_first = nn.Conv2d(3, nf, 3, 1, 1)

        # Trunk-A (RRDBs) followed by Trunk-RFB (RRFDBs).
        self.trunk_a = nn.Sequential(*(RRDB(nf) for _ in range(num_rrdb)))
        self.trunk_rfb = nn.Sequential(*(RRFDB(nf) for _ in range(num_rrfdb)))

        # Single RFB applied right before upsampling.
        self.rfb_up = RFB(nf)

        # Two identical stages; each is nearest x2 followed by a
        # conv + PixelShuffle x2, i.e. x4 per stage and x16 overall.
        up_layers = []
        for _ in range(2):
            up_layers.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(nf, nf * 4, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.upsample = nn.Sequential(*up_layers)

        # Reconstruction head mapping features back to 3 channels.
        self.conv_final = nn.Sequential(
            nn.Conv2d(nf, nf, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, 3, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a low-resolution batch ``x`` to its x16 super-resolved output."""
        feat = self.conv_first(x)
        feat = self.trunk_rfb(self.trunk_a(feat))
        feat = self.upsample(self.rfb_up(feat))
        return self.conv_final(feat)


class Discriminator(nn.Module):
    """Convolutional discriminator returning one raw logit per image.

    Eight 3x3 conv blocks (BatchNorm on all but the first, LeakyReLU after
    each) alternate stride 1 and stride 2 while widening from ``nf`` to
    ``8 * nf`` channels; global average pooling and a linear layer follow.

    NOTE(review): normalisation here is BatchNorm2d; no spectral norm is
    applied anywhere in this module.

    Args:
        in_channels: Channels of the input image.
        nf: Base number of feature channels.
    """

    def __init__(self, in_channels=3, nf=64):
        super().__init__()

        # (in_channels, out_channels, stride, use_batchnorm) per block
        plan = [
            (in_channels, nf, 1, False),
            (nf, nf, 2, True),
            (nf, nf * 2, 1, True),
            (nf * 2, nf * 2, 2, True),
            (nf * 2, nf * 4, 1, True),
            (nf * 4, nf * 4, 2, True),
            (nf * 4, nf * 8, 1, True),
            (nf * 8, nf * 8, 2, True),
        ]
        self.features = nn.Sequential(*(self._conv_block(*spec) for spec in plan))

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(nf * 8, 1),
        )

    @staticmethod
    def _conv_block(in_c, out_c, stride, norm):
        """Build conv -> (optional BatchNorm) -> LeakyReLU."""
        block = [nn.Conv2d(in_c, out_c, 3, stride, 1)]
        if norm:
            block.append(nn.BatchNorm2d(out_c))
        block.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*block)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of unnormalised realness scores."""
        return self.classifier(self.features(x))

print("Architecture defined successfully!")

In [None]:
# Cell 3: Loss Functions

class VGGPerceptualLoss(nn.Module):
    """VGG19 conv3_4 perceptual loss (L_VGG).

    Compares ImageNet-pretrained VGG19 feature maps of the super-resolved
    and ground-truth images with an L1 distance. Features are taken at the
    output of conv3_4, before its ReLU. Inputs are expected in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag.
        vgg = torchvision.models.vgg19(
            weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
        ).features
        # In torchvision's VGG19, features[16] is conv3_4 and features[15] is
        # the ReLU after conv3_3. The slice end is exclusive, so 17 is needed
        # to actually include conv3_4 (16 stopped one layer short).
        self.vgg_layers = nn.Sequential(*list(vgg.children())[:17])
        for param in self.vgg_layers.parameters():
            param.requires_grad = False
        self.vgg_layers.eval()

        # ImageNet normalization
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _prepare(self, img):
        """Map an image from [-1, 1] to ImageNet-normalised space."""
        img = (img + 1) / 2  # [0,1]
        return (img - self.mean) / self.std

    def forward(self, sr, hr):
        """Return the L1 distance between VGG features of ``sr`` and ``hr``."""
        sr_feat = self.vgg_layers(self._prepare(sr))
        # The target features are constants for the optimisation, so there is
        # no need to record an autograd graph for them.
        with torch.no_grad():
            hr_feat = self.vgg_layers(self._prepare(hr))
        return F.l1_loss(sr_feat, hr_feat)


class GANLoss(nn.Module):
    """Relativistic average GAN loss from ESRGAN, computed on raw logits.

    With ``Delta_Real = sigmoid(d_real - mean(d_fake))`` and
    ``Delta_Fake = sigmoid(d_fake - mean(d_real))``:

    * discriminator: ``-E[log(Delta_Real)] - E[log(1 - Delta_Fake)]``
    * generator:     ``-E[log(1 - Delta_Real)] - E[log(Delta_Fake)]``

    The terms are evaluated with ``softplus`` using the identities
    ``-log(sigmoid(z)) = softplus(-z)`` and ``-log(1 - sigmoid(z)) =
    softplus(z)``. Unlike ``log(sigmoid(z) + eps)``, this does not saturate
    for large ``|z|``, so the loss value and its gradient stay exact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, d_real, d_fake, is_disc=False):
        """Compute the relativistic loss.

        Args:
            d_real: Discriminator logits for real (HR) images.
            d_fake: Discriminator logits for generated (SR) images.
            is_disc: ``True`` for the discriminator objective, ``False`` for
                the generator's adversarial objective.

        Returns:
            Scalar loss tensor.
        """
        # Relativistic logits: each sample is judged against the average
        # score of the opposite class.
        real_logits = d_real - d_fake.mean()
        fake_logits = d_fake - d_real.mean()

        if is_disc:
            # Push real above the fake average and fake below the real average.
            loss_real = F.softplus(-real_logits).mean()
            loss_fake = F.softplus(fake_logits).mean()
        else:
            # The generator wants the opposite ordering.
            loss_real = F.softplus(real_logits).mean()
            loss_fake = F.softplus(-fake_logits).mean()
        return loss_real + loss_fake


def compute_generator_loss(sr, hr, d_real, d_fake, vgg_loss_fn, gan_loss_fn, lambda_pix=10, lambda_vgg=1, lambda_adv=5e-3):
    """Combine pixel, perceptual and adversarial terms into the generator loss.

    ``L_G = lambda_pix * L_pix + lambda_vgg * L_VGG + lambda_adv * L_adv``

    Args:
        sr: Generated images.
        hr: Ground-truth images.
        d_real: Discriminator outputs for ``hr``.
        d_fake: Discriminator outputs for ``sr``.
        vgg_loss_fn: Callable ``(sr, hr) -> loss`` for the perceptual term.
        gan_loss_fn: Callable ``(d_real, d_fake, is_disc=...) -> loss``.
        lambda_pix: Weight of the L1 pixel term.
        lambda_vgg: Weight of the perceptual term.
        lambda_adv: Weight of the adversarial term.

    Returns:
        Tuple ``(total, l_pix, l_vgg, l_adv)``; the last three are unweighted.
    """
    pixel_term = F.l1_loss(sr, hr)
    perceptual_term = vgg_loss_fn(sr, hr)
    adversarial_term = gan_loss_fn(d_real, d_fake, is_disc=False)

    weighted_pix = lambda_pix * pixel_term
    weighted_vgg = lambda_vgg * perceptual_term
    weighted_adv = lambda_adv * adversarial_term
    combined = weighted_pix + weighted_vgg + weighted_adv

    return combined, pixel_term, perceptual_term, adversarial_term

print("Loss functions defined successfully!")

In [None]:
# Cell 4: Dataset and DataLoader for BigEarthNet

class BigEarthNetDataset(Dataset):
    """Super-resolution dataset yielding ``(lr, hr)`` tensor pairs in [-1, 1].

    Every ``.jpg``/``.png``/``.tif`` file found recursively under
    ``data_dir`` is a sample. An HR patch of ``hr_size`` is randomly cropped
    (images smaller than ``hr_size`` are bicubically enlarged first),
    augmented, and bicubically downsampled to ``lr_size`` to form the LR
    input.

    Args:
        data_dir: Root directory scanned for images.
        hr_size: Side length of the square HR crop.
        lr_size: Side length of the square LR image.
        transform: Optional PIL-to-PIL augmentation; defaults to random
            flips plus a random multiple-of-90-degree rotation.
    """

    # How many consecutive samples are tried before falling back to noise.
    _MAX_LOAD_ATTEMPTS = 3

    def __init__(self, data_dir, hr_size=512, lr_size=32, transform=None):
        self.data_dir = data_dir
        self.hr_size = hr_size
        self.lr_size = lr_size

        # Find all image files (adjust pattern based on BigEarthNet structure)
        self.image_paths = [
            path
            for ext in ('jpg', 'png', 'tif')
            for path in glob.glob(os.path.join(data_dir, f'**/*.{ext}'), recursive=True)
        ]

        print(f"Found {len(self.image_paths)} images in {data_dir}")

        # Augmentation - Random flips and 90-degree rotations
        if transform is None:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomChoice([
                    transforms.RandomRotation([0, 0]),
                    transforms.RandomRotation([90, 90]),
                    transforms.RandomRotation([180, 180]),
                    transforms.RandomRotation([270, 270]),
                ])
            ])
        self.transform = transform

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def _load_pair(self, img_path):
        """Load one image and return its ``(lr_tensor, hr_tensor)`` pair."""
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        # Resize if image is too small for an hr_size crop
        w, h = img.size
        if w < self.hr_size or h < self.hr_size:
            img = img.resize((max(w, self.hr_size), max(h, self.hr_size)), Image.BICUBIC)
            w, h = img.size

        # Random crop to hr_size. torch's RNG is used instead of NumPy's
        # because DataLoader seeds torch differently in every worker, whereas
        # NumPy's global state is copied unchanged into each forked worker
        # and would yield the same offsets everywhere.
        left = int(torch.randint(0, w - self.hr_size + 1, (1,)))
        top = int(torch.randint(0, h - self.hr_size + 1, (1,)))
        hr_img = img.crop((left, top, left + self.hr_size, top + self.hr_size))

        # Apply augmentation
        hr_img = self.transform(hr_img)

        # Create LR via bicubic downsampling
        lr_img = hr_img.resize((self.lr_size, self.lr_size), Image.BICUBIC)

        # Convert to tensor and normalize to [-1, 1]
        hr_tensor = self.to_tensor(hr_img) * 2 - 1
        lr_tensor = self.to_tensor(lr_img) * 2 - 1
        return lr_tensor, hr_tensor

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` pair for sample ``idx``.

        If the image cannot be loaded, the following samples are tried
        instead (wrapping around) so that a corrupt file does not inject
        noise into training. Only if every attempt fails is a pair of random
        tensors returned, which keeps batch collation working.
        """
        paths = self.image_paths
        img_path = paths[idx]  # an out-of-range idx raises IndexError as usual

        for attempt in range(1, self._MAX_LOAD_ATTEMPTS + 1):
            try:
                return self._load_pair(img_path)
            except Exception as e:  # best-effort: any load failure is skipped
                print(f"Error loading {img_path}: {e}")
            img_path = paths[(idx + attempt) % len(paths)]

        # Last resort, after every attempt failed: return random tensors.
        return torch.randn(3, self.lr_size, self.lr_size), torch.randn(3, self.hr_size, self.hr_size)


# Setup datasets - BigEarthNet V2 S2 paths
# Kaggle input locations of the image data and of the split/label CSVs.
BIGEARTHNET_DIR = '/kaggle/input/bigearthnetv2-s2-4/BigEarthNet-S2'
LABEL_DIR = '/kaggle/input/label-indices'

# Train/Val/Test CSVs for reference
# NOTE(review): these paths are only defined here; nothing in the code shown
# reads them, so the official splits are not applied.
TRAIN_CSV = os.path.join(LABEL_DIR, 'train.csv')
VAL_CSV = os.path.join(LABEL_DIR, 'val.csv')
TEST_CSV = os.path.join(LABEL_DIR, 'test.csv')

# If running locally for testing, use a local path
if not os.path.exists(BIGEARTHNET_DIR):
    print("⚠️ BigEarthNet dataset not found. Please update paths for local testing.")
    BIGEARTHNET_DIR = './bigearthnet_sample'  # Fallback for local testing

# NOTE(review): both datasets scan the same directory with the same default
# (augmenting) transform, so "validation" runs on the training images with
# random crops/flips. Confirm whether a held-out split was intended.
train_dataset = BigEarthNetDataset(BIGEARTHNET_DIR, hr_size=512, lr_size=32)
val_dataset = BigEarthNetDataset(BIGEARTHNET_DIR, hr_size=512, lr_size=32)

# Training batches are shuffled and sized from the WandB config; validation
# uses a fixed batch size of 4 in dataset order.
train_loader = DataLoader(train_dataset, batch_size=wandb.config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
# Cell 5: Training Loop (Stage 1: PSNR + Stage 2: GAN)

def _stage1_mean_val_psnr(generator, val_loader):
    """Return the mean per-batch PSNR (dB) of ``generator`` over ``val_loader``.

    The generator is switched to eval mode for the pass and back to train
    mode afterwards.
    """
    generator.eval()
    psnr_sum = 0
    with torch.no_grad():
        for lr_batch, hr_batch in val_loader:
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)
            mse = F.mse_loss(generator(lr_batch), hr_batch)
            # Range [-1,1] → max=2, so the squared peak value is 4
            psnr_sum += (10 * torch.log10(4 / mse)).item()
    mean_psnr = psnr_sum / len(val_loader)
    generator.train()
    return mean_psnr


def train_stage1(generator, train_loader, val_loader, epochs=50, lr=2e-4):
    """Stage 1: PSNR-oriented pre-training of the generator with L1 loss.

    Uses Adam (betas 0.9/0.99) with the learning rate halved every 10
    epochs. After each epoch the validation PSNR is computed and logged to
    WandB; when all epochs are done the weights are written to
    ``generator_stage1.pth``.

    Args:
        generator: Generator module (optionally wrapped in ``nn.DataParallel``).
        train_loader: Iterable of ``(lr, hr)`` training batches.
        val_loader: Iterable of ``(lr, hr)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("STAGE 1: PSNR-ORIENTED TRAINING")
    print(banner)

    optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    generator.train()

    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch_num}/{epochs}")

        for lr_batch, hr_batch in progress:
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)

            optimizer.zero_grad()
            loss = F.l1_loss(generator(lr_batch), hr_batch)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            progress.set_postfix({'L1 Loss': f"{batch_loss:.4f}"})

        avg_loss = running_loss / len(train_loader)
        scheduler.step()

        # Validation
        val_psnr = _stage1_mean_val_psnr(generator, val_loader)

        wandb.log({
            'stage1/epoch': epoch_num,
            'stage1/train_loss': avg_loss,
            'stage1/val_psnr': val_psnr,
            'stage1/lr': optimizer.param_groups[0]['lr']
        })

        print(f"Epoch {epoch_num}: Train Loss={avg_loss:.4f}, Val PSNR={val_psnr:.2f} dB")

    # Save Stage 1 checkpoint - unwrap DataParallel so keys carry no 'module.' prefix
    if isinstance(generator, nn.DataParallel):
        model_state = generator.module.state_dict()
    else:
        model_state = generator.state_dict()
    torch.save(model_state, 'generator_stage1.pth')
    print("✓ Stage 1 complete. Model saved to generator_stage1.pth")


def train_stage2(generator, discriminator, train_loader, val_loader, iterations=100000, lr=1e-4):
    """Stage 2: adversarial training with pixel, perceptual and GAN losses.

    Each iteration first updates the discriminator on a detached generator
    output, then updates the generator with
    ``lambda_pix * L_pix + lambda_vgg * L_vgg + lambda_adv * L_adv`` (weights
    read from ``wandb.config``). Losses are logged every 100 iterations and a
    generator checkpoint is written every 5000 iterations.

    Args:
        generator: Generator module (optionally ``nn.DataParallel``-wrapped).
        discriminator: Discriminator module.
        train_loader: Iterable of ``(lr, hr)`` batches; re-iterated until
            ``iterations`` updates have been performed.
        val_loader: Accepted for signature symmetry with stage 1; not used.
        iterations: Total number of generator/discriminator updates.
        lr: Initial learning rate of both optimizers.

    Returns:
        List of checkpoint paths saved during training, oldest first.

    Raises:
        ValueError: If ``train_loader`` yields no batches (the loop could
            otherwise never make progress).
    """
    print("\n" + "="*50)
    print("STAGE 2: GAN TRAINING")
    print("="*50)

    optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.99))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.9, 0.99))

    # LR decay schedule (stepped once per iteration, not per epoch)
    milestones = [50000, 100000]
    scheduler_g = torch.optim.lr_scheduler.MultiStepLR(optimizer_g, milestones=milestones, gamma=0.5)
    scheduler_d = torch.optim.lr_scheduler.MultiStepLR(optimizer_d, milestones=milestones, gamma=0.5)

    vgg_loss_fn = VGGPerceptualLoss().to(device)
    gan_loss_fn = GANLoss()

    generator.train()
    discriminator.train()

    saved_models = []  # Periodic checkpoints, later averaged into the ensemble
    iter_count = 0

    while iter_count < iterations:
        iters_before_pass = iter_count

        for lr_img, hr_img in train_loader:
            if iter_count >= iterations:
                break

            lr_img = lr_img.to(device)
            hr_img = hr_img.to(device)

            # ========== Train Discriminator ==========
            optimizer_d.zero_grad()
            # Detach so the discriminator update does not backprop into G.
            sr_img = generator(lr_img).detach()
            d_real = discriminator(hr_img)
            d_fake = discriminator(sr_img)
            loss_d = gan_loss_fn(d_real, d_fake, is_disc=True)
            loss_d.backward()
            optimizer_d.step()

            # ========== Train Generator ==========
            optimizer_g.zero_grad()
            sr_img = generator(lr_img)
            # The real-image scores are a constant for the generator update,
            # so skip building an autograd graph for them altogether.
            with torch.no_grad():
                d_real = discriminator(hr_img)
            d_fake = discriminator(sr_img)

            # Total generator loss: L_G = λ_pix*L_pix + λ_vgg*L_vgg + λ_adv*L_adv
            l_pix = F.l1_loss(sr_img, hr_img)
            l_vgg = vgg_loss_fn(sr_img, hr_img)
            l_adv = gan_loss_fn(d_real, d_fake, is_disc=False)

            loss_g = (
                wandb.config.lambda_pix * l_pix +
                wandb.config.lambda_vgg * l_vgg +
                wandb.config.lambda_adv * l_adv
            )
            loss_g.backward()
            optimizer_g.step()

            scheduler_g.step()
            scheduler_d.step()
            iter_count += 1

            # Logging
            if iter_count % 100 == 0:
                wandb.log({
                    'stage2/iteration': iter_count,
                    'stage2/loss_g': loss_g.item(),
                    'stage2/loss_d': loss_d.item(),
                    'stage2/l_pix': l_pix.item(),
                    'stage2/l_vgg': l_vgg.item(),
                    'stage2/l_adv': l_adv.item(),
                })
                print(f"Iter {iter_count}: G={loss_g.item():.4f}, D={loss_d.item():.4f}, Pix={l_pix.item():.4f}")

            # Save model every 5k iterations for ensemble - handle DataParallel wrapper
            if iter_count % 5000 == 0:
                model_path = f'generator_iter_{iter_count}.pth'
                model_state = generator.module.state_dict() if isinstance(generator, nn.DataParallel) else generator.state_dict()
                torch.save(model_state, model_path)
                saved_models.append(model_path)
                print(f"✓ Saved checkpoint: {model_path}")

        # A full pass that performed no update means the loader is empty;
        # without this check the outer loop would spin forever.
        if iter_count == iters_before_pass:
            raise ValueError("train_loader yielded no batches; cannot run Stage 2 training")

    print(f"\n✓ Stage 2 complete. {len(saved_models)} checkpoints saved for ensemble.")
    return saved_models


def ensemble_models(generator, model_paths, top_k=10):
    """Average the parameters of the last ``top_k`` checkpoints into ``generator``.

    Floating-point tensors are averaged element-wise. Non-floating-point
    entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot be meaningfully averaged, so the value
    from the most recent selected checkpoint is kept. The averaged state is
    loaded into ``generator`` and saved to ``generator_ensemble.pth``.

    Args:
        generator: Module (optionally ``nn.DataParallel``-wrapped) that
            receives the averaged weights.
        model_paths: Checkpoint paths, oldest first.
        top_k: Maximum number of most recent checkpoints to average.

    Raises:
        ValueError: If ``model_paths`` is empty.
    """
    print(f"\nCreating ensemble from top-{top_k} models...")

    if not model_paths:
        raise ValueError("ensemble_models needs at least one checkpoint path")

    # Select top-k models (simple: use last k models, or evaluate each)
    selected_models = model_paths[-top_k:] if len(model_paths) >= top_k else model_paths

    # Accumulate on the CPU: this keeps several full state dicts out of GPU
    # memory, and load_state_dict copies onto the model's device anyway.
    ensemble_state = OrderedDict()
    for path in selected_models:
        state = torch.load(path, map_location='cpu')
        for key, tensor in state.items():
            if not torch.is_floating_point(tensor):
                # Integer/bool buffers: keep the latest checkpoint's value.
                ensemble_state[key] = tensor.clone()
            elif key in ensemble_state:
                ensemble_state[key] += tensor
            else:
                ensemble_state[key] = tensor.clone()

    num_models = len(selected_models)
    for tensor in ensemble_state.values():
        if torch.is_floating_point(tensor):
            tensor /= num_models

    # Load into generator - handle DataParallel wrapper
    target = generator.module if isinstance(generator, nn.DataParallel) else generator
    target.load_state_dict(ensemble_state)
    torch.save(ensemble_state, 'generator_ensemble.pth')
    print(f"✓ Ensemble model saved to generator_ensemble.pth")


# Initialize models
# Block counts come from the WandB config; both networks are moved to `device`.
generator = Generator(
    num_rrdb=wandb.config.num_rrdb,
    num_rrfdb=wandb.config.num_rrfdb
).to(device)

discriminator = Discriminator().to(device)

# Wrap models with DataParallel to use multiple GPUs
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel!")
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)
else:
    # NOTE(review): this branch is also taken when no GPU is present at all,
    # in which case the message below is misleading.
    print("Using single GPU")

# Report model sizes in millions of parameters.
print(f"Generator params: {sum(p.numel() for p in generator.parameters())/1e6:.2f}M")
print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters())/1e6:.2f}M")

# Execute training
# Wall-clock timer covering both stages and the ensembling step.
start_time = time.time()

# Stage 1: PSNR training
train_stage1(generator, train_loader, val_loader, 
             epochs=wandb.config.stage1_epochs, 
             lr=wandb.config.stage1_lr)

# Stage 2: GAN training
# Continues from the in-memory generator that stage 1 just trained; returns
# the list of periodic checkpoint paths.
saved_models = train_stage2(generator, discriminator, train_loader, val_loader,
                            iterations=wandb.config.stage2_iters,
                            lr=wandb.config.stage2_lr)

# Ensemble top models
# Averages the most recent checkpoints and loads the result into `generator`.
ensemble_models(generator, saved_models, top_k=wandb.config.ensemble_models)

# Elapsed time converted from seconds to hours.
total_time = (time.time() - start_time) / 3600
print(f"\n{'='*50}")
print(f"✓ Training complete! Total time: {total_time:.2f} hours")
print(f"{'='*50}")

wandb.log({'total_training_hours': total_time})

In [None]:
# Cell 6: Inference and Evaluation

def evaluate_model(generator, val_loader, num_samples=10):
    """Evaluate PSNR and inference speed, logging example images to WandB.

    Args:
        generator: Model to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(lr, hr)`` batches.
        num_samples: Maximum number of *batches* drawn from ``val_loader``.

    Returns:
        Mean of the per-batch PSNR values in dB (images in [-1, 1]).

    Raises:
        ValueError: If no batch was evaluated (empty loader or
            ``num_samples <= 0``).
    """
    generator.eval()

    total_psnr = 0
    sample_count = 0   # number of batches evaluated
    total_time = 0.0   # seconds spent in generator forward passes
    total_images = 0   # number of individual images processed

    print("\nEvaluating model...")

    with torch.no_grad():
        for i, (lr_img, hr_img) in enumerate(val_loader):
            if sample_count >= num_samples:
                break

            lr_img = lr_img.to(device)
            hr_img = hr_img.to(device)

            # CUDA kernels run asynchronously; synchronise before and after
            # the forward pass so the timer measures the actual computation.
            if lr_img.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            sr_img = generator(lr_img)
            if lr_img.is_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start
            total_images += lr_img.size(0)

            # PSNR (peak-to-peak range of [-1, 1] is 2, hence 4 = 2**2)
            mse = F.mse_loss(sr_img, hr_img)
            psnr = 10 * torch.log10(4 / mse)
            total_psnr += psnr.item()

            # SSIM is not computed here (would need skimage or pytorch-msssim)

            sample_count += 1

            # Log samples to wandb
            if i < 5:  # Log first 5 samples
                lr_grid = (lr_img[0].cpu() + 1) / 2  # Denormalize
                sr_grid = (sr_img[0].cpu() + 1) / 2
                hr_grid = (hr_img[0].cpu() + 1) / 2

                wandb.log({
                    f'samples/sample_{i}': [
                        wandb.Image(lr_grid, caption='LR Input'),
                        wandb.Image(sr_grid, caption='SR Output'),
                        wandb.Image(hr_grid, caption='HR Ground Truth')
                    ]
                })

    if sample_count == 0 or total_images == 0:
        raise ValueError("evaluate_model processed no batches; check val_loader and num_samples")

    avg_psnr = total_psnr / sample_count
    # Average over every image processed, not just the last batch.
    inference_time = total_time / total_images

    print(f"\nEvaluation Results:")
    print(f"  Average PSNR: {avg_psnr:.2f} dB")
    print(f"  Inference time: {inference_time*1000:.2f} ms/image")

    wandb.log({
        'eval/psnr': avg_psnr,
        'eval/inference_time_ms': inference_time * 1000
    })

    return avg_psnr


# Load best model and evaluate - handle DataParallel wrapper
# Reads the averaged checkpoint written by ensemble_models().
ensemble_state = torch.load('generator_ensemble.pth')
if isinstance(generator, nn.DataParallel):
    # The checkpoint keys have no 'module.' prefix, so load into the wrapped model.
    generator.module.load_state_dict(ensemble_state)
else:
    generator.load_state_dict(ensemble_state)

# Evaluate on up to 20 validation batches.
final_psnr = evaluate_model(generator, val_loader, num_samples=20)

print("\n" + "="*50)
print(f"FINAL RESULTS: PSNR = {final_psnr:.2f} dB")
print("="*50)
print("\n✓ All tasks complete! Check WandB dashboard for detailed metrics and visualizations.")
print(f"   WandB Project: {wandb.run.project}")
print(f"   Run URL: {wandb.run.url}")

# Close the WandB run so all logged data is flushed.
wandb.finish()