# cell 1

In [4]:
# Step 1: Install correct dependencies
print("Installing dependencies...")
!pip install --upgrade pip setuptools wheel --no-cache-dir
!pip install -q torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir --no-build-isolation
!pip install -q lpips --no-deps --no-cache-dir
!pip install -q basicsr facexlib gfpgan --no-cache-dir --no-build-isolation
!pip install -q wandb umap-learn scikit-image rasterio pandas
print("Dependencies installed successfully.")

# Step 2: Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np
import wandb
import os
import copy
import math
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import umap
import rasterio
import pandas as pd
import glob
import json
import logging
import time  # Added for timing
import gc  # Added for memory
from torch.nn.utils import clip_grad_norm_  # Added for clipping

# Step 3: Login to WandB
try:
    wandb.login(key="5424a3d65aac1662f5be82d4439aaac35046689e")
    print("W&B login successful.")
except Exception as e:
    print(f"W&B login failed: {e}. Please log in manually.")
    wandb.login()

# Step 4: Setup Devices and Config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_count = torch.cuda.device_count()
print(f"Using device: {device}" + (f" with {gpu_count} GPUs (DataParallel)." if gpu_count > 1 else "."))

config = {
    'project_name': 'SR-AL-pipeline-train-0003',
    'dataset_size': 75000,  # Subset for 12h limit
    'sr_epochs_psnr': 8,
    'sr_epochs_gan': 10,  # Reduced for stability
    'batch_size': 8,  # OOM-safe
    'accum_steps': 8,  # Higher for fewer steps
    'al_cycles': 4,
    'al_epochs': 10,
    'num_classes': 19,  # BigEarthNet-19
    'lambda_perc': 5.0,  # Lower for less memory
    'g_lr': 1e-4,
    'd_lr': 1e-4,
    'sr_psnr_lr': 2e-4
}

# Step 5: WandB Init
wandb.init(project=config['project_name'], config=config)
print(f"Setup complete. WandB run '{wandb.run.name}' started.")

Installing dependencies...




Dependencies installed successfully.
W&B login successful.
Using device: cuda with 2 GPUs (DataParallel).


Setup complete. WandB run 'flowing-wave-10' started.


# **cell 2**

In [6]:
# Dataset: BigEarthNet v2 S2-4 (Single-band TIFFs per patch; load RGB as B02,B03,B04 + Multi-Labels)
import numpy as np
from PIL import Image  # For PIL augs
from scipy.ndimage import zoom  # For NumPy bicubic downsample
import pandas as pd  # For metadata
from skimage import exposure  # For histogram eq (contrast boost)

image_root_path = '/kaggle/input/bigearthnetv2-s2-4/'  # Your path
ALL_TIF_PATHS = glob.glob(os.path.join(image_root_path, '**/*.tif'), recursive=True)

print(f"Found {len(ALL_TIF_PATHS)} band files in {image_root_path}.")

# Extract unique patch IDs (e.g., "S2B_MSIL2A_20180421T100029_N9999_R122_T33TWM_00_00" from filename)
patch_to_bands = {}
for path in ALL_TIF_PATHS:
    fname = os.path.basename(path)
    if fname.endswith('.tif'):
        # Parse: everything before _Bxx.tif
        if '_B' in fname:
            patch_id = '_'.join(fname.split('_B')[:-1])
            band = fname.split('_B')[-1].split('.')[0]  # '02', '03', etc.
            if patch_id not in patch_to_bands:
                patch_to_bands[patch_id] = {}
            patch_to_bands[patch_id][band] = path

# Sample 75k patches with all 3 RGB bands (B02, B03, B04)
valid_patches = [pid for pid, bands in patch_to_bands.items() 
                 if all(b in bands for b in ['02', '03', '04'])]
valid_patches = valid_patches[:config['dataset_size']]  # Subset if more

print(f"Valid patches with RGB bands: {len(valid_patches)} (target {config['dataset_size']})")

if len(valid_patches) < config['dataset_size'] // 2:
    print("Warning: Fewer valid patches than expected; check dataset for missing bands.")

# Load Metadata for Multi-Labels (19 classes)
metadata_path = os.path.join(image_root_path, 'metadata.parquet')  # Standard for BigEarthNet v2
if os.path.exists(metadata_path):
    df = pd.read_parquet(metadata_path)
    # Assume 'patch_id' and 'labels' columns (list of int 0-18)
    patch_to_label = {}
    for _, row in df.iterrows():
        pid = row['patch_id']  # Adjust column if different (e.g., 'patch_filename')
        labels_list = row['labels'] if isinstance(row['labels'], list) else []  # List of class IDs
        multi_hot = np.zeros(config['num_classes'])
        for lbl in labels_list:
            if 0 <= lbl < config['num_classes']:
                multi_hot[lbl] = 1.0
        if pid in valid_patches:
            patch_to_label[pid] = torch.tensor(multi_hot, dtype=torch.float32)
    print(f"Loaded labels for {len(patch_to_label)} patches from metadata.")
else:
    print("Warning: metadata.parquet not found; using dummy multi-labels.")
    patch_to_label = {pid: torch.full((config['num_classes'],), 1.0 / config['num_classes'], dtype=torch.float32) for pid in valid_patches}  # Uniform dummy

# Spatial Aug Transforms (for PIL only)
spatial_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),  # Multiples of 90Â° for orthoimages
])

# Tensor Transforms (post-spatial, for HWC NumPy)
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # To [-1,1] for tanh output
])

# SR Dataset Class (PIL for spatial augs on HR â†’ NumPy downsample to LR â†’ Tensor + Label)
class SRDataset(Dataset):
    def __init__(self, patch_ids, scale=4, patch_to_label=None):
        self.patch_ids = patch_ids
        self.scale = scale
        self.patch_to_label = patch_to_label or {}
    
    def __len__(self):
        return len(self.patch_ids)
    
    def __getitem__(self, idx):
        patch_id = self.patch_ids[idx]
        bands = {
            'B02': patch_to_bands[patch_id].get('02'),  # Blue
            'B03': patch_to_bands[patch_id].get('03'),  # Green
            'B04': patch_to_bands[patch_id].get('04')   # Red
        }
        
        # Label (multi-hot)
        label = self.patch_to_label.get(patch_id, torch.zeros(config['num_classes'], dtype=torch.float32))
        
        try:
            if not all(bands.values()):
                raise ValueError("Missing RGB band")
            
            # Read bands (120x120 uint16 â†’ float [0,1])
            b02 = rasterio.open(bands['B02']).read(1, masked=True).astype(np.float32) / 10000.0
            b03 = rasterio.open(bands['B03']).read(1, masked=True).astype(np.float32) / 10000.0
            b04 = rasterio.open(bands['B04']).read(1, masked=True).astype(np.float32) / 10000.0
            
            # Stack to (H,W,3) RGB; fill masked (NaN/clouds) with 0
            hr_np = np.stack([b04, b03, b02], axis=-1)  # RGB order (B04=Red, B03=Green, B02=Blue)
            if np.ma.is_masked(hr_np):
                hr_np = hr_np.filled(0.0)
            hr_np = np.clip(hr_np, 0, 1)  # [0,1]
            
            # FIX: Boost contrast with adaptive histogram eq (per-channel for dark Sentinel-2)
            hr_np = exposure.equalize_adapthist(hr_np)
            hr_np = np.clip(hr_np, 0, 1)
            
            # Convert HR NumPy to PIL Image (HWC uint8)
            hr_pil = Image.fromarray((hr_np * 255).astype(np.uint8))
            
            # Apply spatial augs to PIL HR (random state per sample)
            hr_aug_pil = spatial_transforms(hr_pil)
            
            # Back to NumPy [0,1] HWC
            hr_aug_np = np.array(hr_aug_pil).astype(np.float32) / 255.0
            hr_aug_np = np.clip(hr_aug_np, 0, 1)
            
            # LR: Downsample augmented HR NumPy to x4 smaller (bicubic via zoom)
            # Transpose to CHW for zoom (operates on spatial dims)
            hr_chw = hr_aug_np.transpose(2, 0, 1)  # (3,120,120)
            lr_chw = zoom(hr_chw, (1, 1/self.scale, 1/self.scale), order=3)  # (3,30,30)
            lr_np = lr_chw.transpose(1, 2, 0)  # Back to HWC (30,30,3)
            lr_np = np.clip(lr_np, 0, 1)
            
            # To Tensor + Normalize: Pass HWC NumPy directly (ToTensor handles transpose)
            hr = tensor_transforms(hr_aug_np)
            lr = tensor_transforms(lr_np)
            
            return {'lr': lr, 'hr': hr, 'label': label}
        
        except Exception as e:
            # Skip bad patch; return black dummy as normalized Tensor + zero label
            print(f"Error loading patch {patch_id}: {e}")
            dummy_size_hr, dummy_size_lr = 120, 30
            # HWC [0,1] black
            dummy_black_hr = np.zeros((dummy_size_hr, dummy_size_hr, 3))
            dummy_black_lr = np.zeros((dummy_size_lr, dummy_size_lr, 3))
            # Direct to transforms (no transpose)
            dummy_hr = tensor_transforms(dummy_black_hr)
            dummy_lr = tensor_transforms(dummy_black_lr)
            return {'lr': dummy_lr, 'hr': dummy_hr, 'label': torch.zeros(config['num_classes'], dtype=torch.float32)}

# Split and Loaders (80/20, random for multi-label)
train_ids, val_ids = train_test_split(valid_patches, test_size=0.2, random_state=42)
train_ds = SRDataset(train_ids, patch_to_label=patch_to_label)
val_ds = SRDataset(val_ids, patch_to_label=patch_to_label)
train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, num_workers=2, pin_memory=True)

print(f"Train patches: {len(train_ds)}, Val: {len(val_ds)}")
print("Dataset readyâ€”PIL-based fixed transforms + Multi-Labels from metadata + Contrast Boost.")

Found 347244 band files in /kaggle/input/bigearthnetv2-s2-4/.
Valid patches with RGB bands: 28937 (target 75000)
Train patches: 23149, Val: 5788
Dataset readyâ€”PIL-based fixed transforms + Multi-Labels from metadata + Contrast Boost.


In [7]:
# Test single item
sample = train_ds[0]
print(f"Shapes: LR {sample['lr'].shape}, HR {sample['hr'].shape}, Range LR: [{sample['lr'].min():.2f}, {sample['lr'].max():.2f}]")
# Test batch
batch = next(iter(train_loader))
print(f"Batch shapes: LR {batch['lr'].shape}, HR {batch['hr'].shape}")

Shapes: LR torch.Size([3, 30, 30]), HR torch.Size([3, 120, 120]), Range LR: [-1.00, 0.90]


  out[idx, 0] = 2.0 + (arr[idx, 2] - arr[idx, 0]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]


Batch shapes: LR torch.Size([8, 3, 30, 30]), HR torch.Size([8, 3, 120, 120])


# *cell 3*

In [8]:
# Models: RFB-ESRGAN Generator, RelDiscriminator, PerceptualLoss
# (Based on ESRGAN; assume growth=32 for lightweight)

class ResidualDenseBlock(nn.Module):
    def __init__(self, nc=64, growth=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nc, growth, 3, 1, 1)
        self.conv2 = nn.Conv2d(nc + growth, growth, 3, 1, 1)
        self.conv3 = nn.Conv2d(nc + 2*growth, growth, 3, 1, 1)
        self.conv4 = nn.Conv2d(nc + 3*growth, growth, 3, 1, 1)
        self.conv5 = nn.Conv2d(nc + 4*growth, nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat([x, x1], 1)))
        x3 = self.lrelu(self.conv3(torch.cat([x, x1, x2], 1)))
        x4 = self.lrelu(self.conv4(torch.cat([x, x1, x2, x3], 1)))
        x5 = self.conv5(torch.cat([x, x1, x2, x3, x4], 1))
        # FIX: Light anti-aliasing blur to reduce checkerboard in residuals
        x5 = F.avg_pool2d(x5, kernel_size=3, stride=1, padding=1)
        return x5 * 0.2 + x  # Dense residual

class RFBESRGANGenerator(nn.Module):
    def __init__(self, growth=32, num_blocks=23, nc=64, upscale=4):
        super().__init__()
        self.entry = nn.Sequential(nn.Conv2d(3, nc, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.ModuleList([ResidualDenseBlock(nc, growth) for _ in range(num_blocks)])
        self.conv_tail = nn.Conv2d(nc, nc, 3, 1, 1)
        # FIX: Bilinear upsample + Conv (replaces TransposedConv to avoid checkerboard artifacts)
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(nc, nc//2, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(nc//2, 3, 3, 1, 1)
        )

    def forward(self, x):
        x = self.entry(x)
        res = x
        for block in self.body:
            x = block(x)
        x = self.conv_tail(x)
        x += res  # Global skip
        x = self.up(x)
        return torch.tanh(x)  # [-1,1] range

class RelDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, 4, 1, 0)  # PatchGAN output
        )

    def forward(self, x):
        return self.net(x)

class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True).features[:35].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.l1 = nn.L1Loss()
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, sr, hr):
        # Ensure inputs on same device as module
        sr = sr.to(self.vgg.device)
        hr = hr.to(self.vgg.device)
        sr_vgg = (self.vgg((sr + 1)/2 * self.std + self.mean) + 1)/2
        hr_vgg = (self.vgg((hr + 1)/2 * self.std + self.mean) + 1)/2
        return self.l1(sr_vgg, hr_vgg)

# PSNR/SSIM Helpers (FIX: Smaller win_size for small images)
def psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    return 20 * math.log10(max_val / math.sqrt(mse))

def compute_ssim(sr, hr):
    # FIX: win_size=3 for 30x30 LR; channel_axis=-1 for HWC RGB
    return ssim(sr.permute(1,2,0).cpu().numpy(), hr.permute(1,2,0).cpu().numpy(), 
                multichannel=True, channel_axis=-1, data_range=1.0, win_size=3)

print("Models defined.")

Models defined.


# cell 4

In [9]:
# --------------------------------------------------------------------------------
# SECTION 4: SR MODEL TRAINING
# (Two-stage: PSNR + GAN, with fixes: clipping, mem mgmt, error handling)
# --------------------------------------------------------------------------------
print(f"Cell 4 using: {device} with {gpu_count} GPUs.")

# Instantiate Models
g = RFBESRGANGenerator(growth=32).to(device)
if gpu_count > 1:
    g = nn.DataParallel(g)
wandb.watch(g, log="all", log_freq=100)

d = RelDiscriminator().to(device)
if gpu_count > 1:
    d = nn.DataParallel(d)
wandb.watch(d, log="all", log_freq=100)

perc_loss = PerceptualLoss().to(device)  # FIX: Move to device to resolve CPU/GPU mismatch
print("Models instantiated.")

# --------------------------------------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------------------------------------
def save_model(model, path):
    """Saves model state_dict, handling DataParallel."""
    try:
        state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        torch.save(state_dict, path)
        wandb.save(path, policy="now")
        print(f"Saved {path}")
    except Exception as e:
        print(f"Save error: {e}"); wandb.log({"error": str(e)})

def log_sr_samples(g, val_loader, epoch, phase):
    g.eval()
    with torch.no_grad():
        try:
            batch = next(iter(val_loader))
        except StopIteration:
            print("Could not get validation batch for logging.")
            return
           
        lr_sample, hr_sample = batch['lr'][:4].to(device), batch['hr'][:4].to(device)
       
        model_to_run = g.module if gpu_count > 1 else g
        with autocast():
            sr_sample = model_to_run(lr_sample)
       
        fig, axes = plt.subplots(4, 3, figsize=(12, 16))
        for i in range(4):
            # FIX: Denorm [-1,1] to [0,1] before permute/clamp
            lr_img = (lr_sample[i] + 1) / 2
            sr_img = (sr_sample[i] + 1) / 2
            hr_img = (hr_sample[i] + 1) / 2
            
            lr_img_np = lr_img.permute(1,2,0).cpu().clamp(0,1).float().numpy()
            sr_img_np = sr_img.permute(1,2,0).cpu().clamp(0,1).float().numpy()
            hr_img_np = hr_img.permute(1,2,0).cpu().clamp(0,1).float().numpy()
           
            axes[i,0].imshow(lr_img_np); axes[i,0].set_title('LR'); axes[i,0].axis('off')
            axes[i,1].imshow(sr_img_np); axes[i,1].set_title('SR'); axes[i,1].axis('off')
            axes[i,2].imshow(hr_img_np); axes[i,2].set_title('HR'); axes[i,2].axis('off')
           
        plt.suptitle(f"{phase} Samples - Epoch {epoch}")
        wandb.log({f"SR_{phase}/samples": wandb.Image(fig)})
        plt.close(fig)

# PSNR Pretraining (unchanged, stable)
def train_psnr(g, loader, epochs, lr=2e-4):
    opt = optim.Adam(g.parameters(), lr=lr)
    scaler = GradScaler()
    l1 = nn.L1Loss().to(device)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(epochs):
        start_time = time.time()
        g.train()
        tot_psnr, tot_ssim, tot_loss, num_samples = 0, 0, 0, 0
        pbar = tqdm(loader, desc=f"PSNR Ep {ep+1}/{epochs}")
        for i, batch in enumerate(pbar):
            lr_imgs, hr_imgs = batch['lr'].to(device), batch['hr'].to(device)
            batch_size = lr_imgs.shape[0]
            opt.zero_grad(set_to_none=True)
            with autocast():
                sr = g(lr_imgs)
                loss = l1(sr, hr_imgs)
            scaler.scale(loss).backward()
            if (i + 1) % config['accum_steps'] == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()
               
            tot_loss += loss.item() * batch_size
            current_psnr = psnr(sr, hr_imgs)
            tot_psnr += current_psnr * batch_size
           
            try:
                for b in range(batch_size):
                    tot_ssim += compute_ssim(sr[b], hr_imgs[b])
            except Exception as e:
                pass
           
            num_samples += batch_size
            pbar.set_postfix({"Loss": f"{loss.item():.4f}", "PSNR": f"{current_psnr:.2f}"})
            
            # Mem cleanup
            if i % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()
       
        avg_loss = tot_loss / num_samples
        avg_psnr = tot_psnr / num_samples
        avg_ssim = tot_ssim / num_samples if num_samples > 0 else 0
        sched.step()
       
        epoch_time = time.time() - start_time
        wandb.log({"SR_PSNR/epoch": ep, "SR_PSNR/loss": avg_loss, "SR_PSNR/psnr": avg_psnr, 
                   "SR_PSNR/ssim": avg_ssim, "SR_PSNR/epoch_time": epoch_time})
       
        if (ep + 1) % 5 == 0 or ep == epochs - 1:
            log_sr_samples(g, val_loader, ep, "PSNR")
            save_model(g, f'/kaggle/working/g_psnr_ep{ep}.pth')
   
    return g

# GAN Fine-Tuning (with fixes: clipping, mem, error handling, reduced val)
def train_gan(g, d, loader, epochs, g_lr=1e-4, lambda_perc=10.0):
    g_opt = optim.Adam(g.parameters(), lr=g_lr)
    d_opt = optim.Adam(d.parameters(), lr=config['d_lr'])
    scaler_g = GradScaler()
    scaler_d = GradScaler()
    sched_g = optim.lr_scheduler.CosineAnnealingLR(g_opt, T_max=epochs)
    sched_d = optim.lr_scheduler.CosineAnnealingLR(d_opt, T_max=epochs)
    adv = nn.BCEWithLogitsLoss().to(device)
    l1 = nn.L1Loss().to(device)
   
    for ep in range(epochs):
        start_time = time.time()
        g.train(); d.train()
        tot_g_loss, tot_d_loss, num_samples_g, num_samples_d = 0, 0, 0, 0
        pbar = tqdm(loader, desc=f"GAN Ep {ep+1}/{epochs}")
        
        # Mem log
        if torch.cuda.is_available():
            print(f"Epoch {ep}: GPU mem {torch.cuda.memory_allocated()/1e9:.1f}GB")
        
        for i, batch in enumerate(pbar):
            lr_imgs, hr_imgs = batch['lr'].to(device), batch['hr'].to(device)
            batch_size = lr_imgs.shape[0]
           
            # --- Train Discriminator ---
            try:
                d_opt.zero_grad(set_to_none=True)
                with autocast():
                    real_pred = d(hr_imgs)
                    fake = g(lr_imgs)
                    fake_pred = d(fake.detach())
               
                    d_loss_real = adv(real_pred - fake_pred.mean(), torch.ones_like(real_pred))
                    d_loss_fake = adv(fake_pred - real_pred.mean(), torch.zeros_like(fake_pred))
                    d_loss = (d_loss_real + d_loss_fake) / 2
                    d_loss = d_loss / config['accum_steps']
               
                scaler_d.scale(d_loss).backward()
                clip_grad_norm_(d.parameters(), max_norm=1.0)  # FIX: Clip
                tot_d_loss += d_loss.item() * config['accum_steps'] * batch_size
                num_samples_d += batch_size
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"OOM in D at batch {i}"); torch.cuda.empty_cache(); continue
                raise
           
            # --- Train Generator ---
            try:
                g_opt.zero_grad(set_to_none=True)
                with autocast():
                    fake = g(lr_imgs)
                    real_pred_g = d(hr_imgs).detach()
                    fake_pred_g = d(fake)
               
                    g_adv = (adv(fake_pred_g - real_pred_g.mean(), torch.ones_like(fake_pred_g)) +
                             adv(real_pred_g.mean() - fake_pred_g, torch.zeros_like(real_pred_g))) / 2
               
                    g_perc = perc_loss(fake, hr_imgs)
                    g_l1 = l1(fake, hr_imgs)
                    g_loss = 0.001 * g_adv + lambda_perc * g_perc + g_l1
                    g_loss = g_loss / config['accum_steps']
                scaler_g.scale(g_loss).backward()
                clip_grad_norm_(g.parameters(), max_norm=1.0)  # FIX: Clip
                tot_g_loss += g_loss.item() * config['accum_steps'] * batch_size
                num_samples_g += batch_size
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"OOM in G at batch {i}"); torch.cuda.empty_cache(); continue
                raise
           
            # --- Optimizer Step ---
            if (i + 1) % config['accum_steps'] == 0:
                scaler_d.step(d_opt); scaler_d.update(); d_opt.zero_grad()
                scaler_g.step(g_opt); scaler_g.update(); g_opt.zero_grad()
           
            pbar.set_postfix({"G_Loss": f"{g_loss.item()*config['accum_steps']:.4f}", "D_Loss": f"{d_loss.item()*config['accum_steps']:.4f}"})
            
            # Mem cleanup every 5 batches
            if i % 5 == 0:
                torch.cuda.empty_cache()
                gc.collect()
       
        avg_g_loss = tot_g_loss / num_samples_g if num_samples_g > 0 else 0
        avg_d_loss = tot_d_loss / num_samples_d if num_samples_d > 0 else 0
        sched_g.step(); sched_d.step()
        
        epoch_time = time.time() - start_time
       
        # --- Quick Validation (reduced: 2 batches, skip SSIM if error) ---
        avg_psnr, avg_ssim, val_images = 0, 0, 0
        with torch.no_grad():
            torch.cuda.empty_cache()
            val_iter = iter(val_loader)
            num_batches = 2  # FIX: Reduced
            for _ in range(num_batches):
                try:
                    batch = next(val_iter)
                    lr_v, hr_v = batch['lr'].to(device), batch['hr'].to(device)
                    val_batch_size = lr_v.shape[0]
                    with autocast():
                        sr_v = g(lr_v) if gpu_count == 1 else g.module(lr_v)
                   
                    avg_psnr += psnr(sr_v, hr_v) * val_batch_size
                   
                    # SSIM with error skip
                    try:
                        for b in range(val_batch_size):
                            sr_b = sr_v[b]
                            hr_b = hr_v[b]
                            avg_ssim += compute_ssim(sr_b, hr_b)
                    except Exception as e:
                        print(f"Val SSIM error: {e}"); avg_ssim += 0
                   
                    val_images += val_batch_size
                except StopIteration:
                    break
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print("OOM in val"); break
                    raise
       
        if val_images > 0:
            avg_psnr /= val_images
            avg_ssim /= val_images
       
        wandb.log({
            "SR_GAN/epoch": ep, "SR_GAN/g_loss": avg_g_loss, "SR_GAN/d_loss": avg_d_loss,
            "SR_GAN/psnr": avg_psnr, "SR_GAN/ssim": avg_ssim, "SR_GAN/epoch_time": epoch_time
        })
       
        if (ep + 1) % 5 == 0 or ep == epochs - 1:
            log_sr_samples(g, val_loader, ep, "GAN")
            save_model(g, f'/kaggle/working/g_gan_ep{ep}.pth')
   
    return g

# --------------------------------------------------------------------------------
# EXECUTION
# --------------------------------------------------------------------------------
print("Starting PSNR pre-training...")
g = train_psnr(g, train_loader, config['sr_epochs_psnr'])
print("PSNR pre-training finished.")


print("\nStarting GAN fine-tuning...")
g = train_gan(g, d, train_loader, config['sr_epochs_gan'], g_lr=config['g_lr'], lambda_perc=config['lambda_perc'])
print("GAN fine-tuning finished.")

Cell 4 using: cuda with 2 GPUs.


Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 548M/548M [00:02<00:00, 203MB/s]  


Models instantiated.
Starting PSNR pre-training...


PSNR Ep 1/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 2/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 3/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 4/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 5/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


Saved /kaggle/working/g_psnr_ep4.pth


PSNR Ep 6/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 7/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


PSNR Ep 8/8:   0%|          | 0/2894 [00:00<?, ?it/s]

  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  out[idx, 0] = 4.0 + (arr[idx, 0] - arr[idx, 1]) / delta[idx]


KeyboardInterrupt: 

[1;34mwandb[0m: 
[1;34mwandb[0m: ðŸš€ View run [33mwandering-sky-1[0m at: [34mhttps://wandb.ai/hegdesudarshan-hegde/SR-AL-pipeline-train-0003/runs/8jy2jd5m[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20251114_132829-8jy2jd5m/logs[0m


# cell 5

In [None]:
# Save Final SR Model
save_model(g, '/kaggle/working/sr_model.pth')
print("SR Model saved & logged to WandB.")

# Quick Eval: PSNR/SSIM on full val set
g.eval()
tot_psnr, tot_ssim, num = 0, 0, 0
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Final Eval"):
        lr, hr = batch['lr'].to(device), batch['hr'].to(device)
        with autocast():
            sr = g(lr) if gpu_count == 1 else g.module(lr)
        tot_psnr += psnr(sr, hr) * lr.shape[0]
        try:
            for b in range(lr.shape[0]):
                tot_ssim += compute_ssim(sr[b], hr[b])
        except:
            pass
        num += lr.shape[0]

print(f"Final Val: PSNR={tot_psnr/num:.2f}, SSIM={tot_ssim/num:.4f}")
wandb.log({"final_psnr": tot_psnr/num, "final_ssim": tot_ssim/num})
wandb.finish()
print("Run complete!")

Starting final evaluation...


NameError: name 'evaluate_model' is not defined

# data Prep

In [2]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm.auto import tqdm
import os
import torch
from torchvision.datasets import ImageFolder
from PIL import Image

# --------------------------------------------------------------------------------
# 1. DATASET CLASS FOR EUROSAT
# --------------------------------------------------------------------------------
class EuroSAT_SR_Dataset(Dataset):
    def __init__(self, full_dataset, indices, scale=4, transform_hr=None, transform_lr=None, phase='train'):
        self.full_dataset = full_dataset
        self.indices = indices # The list of master indices this dataset should use
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.phase = phase
        
        # Base transforms
        self.to_tensor = transforms.ToTensor()
        self.hr_resize = transforms.Resize((120, 120), interpolation=transforms.InterpolationMode.BICUBIC)
        self.lr_resize = transforms.Resize((30, 30), interpolation=transforms.InterpolationMode.BICUBIC)
        self.blur = transforms.GaussianBlur(kernel_size=3, sigma=0.5)

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        # Get the master index (e.g., 42_000) from our subset list
        master_idx = self.indices[idx]
        
        # Get the PIL image and integer label from the base ImageFolder dataset
        hr_img_pil, label = self.full_dataset[master_idx]
        
        # Apply HR augmentation (flips/rotations) if in SR training
        if self.phase == 'train_sr' and self.transform_hr:
            hr_img_pil = self.transform_hr(hr_img_pil)
        
        # Create LR image
        if self.phase == 'train_sr' and self.transform_lr:
            # For training, apply color jitter *to the HR image* before downscaling
            lr_img_pil = self.transform_lr(hr_img_pil)
        else:
            # For validation/AL, just use the original PIL image
            lr_img_pil = hr_img_pil
        
        # Apply final transforms
        hr_tensor = self.to_tensor(self.hr_resize(hr_img_pil))
        lr_tensor = self.to_tensor(self.blur(self.lr_resize(lr_img_pil)))

        return {'lr': lr_tensor, 'hr': hr_tensor, 'label': torch.tensor(label, dtype=torch.long), 'idx': master_idx}

# --------------------------------------------------------------------------------
# 2. DATASET LOADING AND SPLITTING
# --------------------------------------------------------------------------------

# (FIXED) This is the correct path for the Kaggle EuroSAT dataset
# MAKE SURE YOU HAVE ADDED THE `eurosat` DATASET TO YOUR KAGGLE NOTEBOOK
root_path = config['dataset_root']
assert os.path.exists(root_path), f"Root path missing: {root_path}. Please add the EuroSAT dataset via '+ Add data'."

# Load the base dataset using ImageFolder. This *automatically* gets the labels.
base_dataset = ImageFolder(root=root_path)
all_class_names = base_dataset.classes # e.g., ['AnnualCrop', 'Forest', ...]
all_labels = base_dataset.targets     # e.g., [0, 5, 2, 1, 0, 9, ...]
print(f"Found {len(base_dataset)} images in {len(all_class_names)} classes.")
print(f"Classes: {all_class_names}")

# Update config with the *actual* number of classes
if len(all_class_names) != config['num_classes']:
    print(f"WARNING: Config expected {config['num_classes']} classes, but found {len(all_class_names)}. Updating config.")
    config['num_classes'] = len(all_class_names)

all_indices = list(range(len(all_labels)))

# (CRITICAL FIX) Stratify on the *real* labels from the dataset
train_idx, val_idx = train_test_split(
    all_indices, 
    train_size=0.8, 
    random_state=42, 
    stratify=all_labels # This fixes all your label bugs
)

# Define transforms
sr_lr_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    # Note: Resize and Blur are handled *inside* the dataset class
])
sr_hr_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15)
])
val_lr_transform = None
val_hr_transform = None

# Create Datasets
train_ds = EuroSAT_SR_Dataset(base_dataset, train_idx, scale=4, lr_transform=sr_lr_transform, hr_transform=sr_hr_transform, phase='train_sr')
val_ds = EuroSAT_SR_Dataset(base_dataset, val_idx, scale=4, phase='val')

# This dataset will be used by the Active Learning loop (no augs)
al_base_dataset = EuroSAT_SR_Dataset(base_dataset, train_idx, scale=4, phase='al')
val_loader_al_dataset = EuroSAT_SR_Dataset(base_dataset, val_idx, scale=4, phase='al')


# Loaders
train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)

wandb.log({"dataset/train_samples": len(train_ds), "dataset/val_samples": len(val_ds), "dataset/classes": len(all_class_names)})

print(f"EuroSAT loaded: {len(train_ds)} train, {len(val_ds)} val samples")
print("Sample batch keys:", next(iter(train_loader)).keys())
print("Sample LR shape:", next(iter(train_loader))['lr'].shape)
print("Sample HR shape:", next(iter(train_loader))['hr'].shape)

# Store global variables for other cells
# We need these for the AL loop in Cell 6
al_train_indices = train_idx
al_val_indices = val_idx
al_all_labels = all_labels
al_class_names = all_class_names

KeyError: 'dataset_root'

In [None]:
# --------------------------------------------------------------------------------
# SECTION 5: CLASSIFIER & ACTIVE LEARNING DEFINITIONS
# --------------------------------------------------------------------------------

# (NEW) RobustClassifier with SE (ResNet-inspired)
class SEBlock(nn.Module):
    def __init__(self, c, r=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), # Squeeze
            nn.Flatten(),
            nn.Linear(c, c // r, bias=False), 
            nn.ReLU(inplace=True), 
            nn.Linear(c // r, c, bias=False), 
            nn.Sigmoid()
        )
    def forward(self, x):
        y = self.fc(x)
        return x * y.view(x.shape[0], x.shape[1], 1, 1) # Excitation

class BasicBlock(nn.Module):
    """ResNet-style block with SE"""
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.se = SEBlock(out_c)
        
        self.shortcut = nn.Sequential()
        if stride > 1 or in_c != out_c:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, stride, bias=False),
                nn.BatchNorm2d(out_c)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out) # Apply attention
        out += self.shortcut(x)
        return F.relu(out)

class RobustClassifier(nn.Module):
    """
    The main classifier model.
    Takes 120x120 SR images as input (as trained in Cell 4).
    """
    def __init__(self, num_classes, num_blocks=[3,4,6,3]):
        super().__init__()
        self.in_c = 64
        # Initial 7x7 conv to reduce dimensions
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False); self.bn1 = nn.BatchNorm2d(64) # 120x120 -> 60x60
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1) # 60x60
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2) # 30x30
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2) # 15x15
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2) # 8x8
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)
    
    def _make_layer(self, out_c, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_c, out_c, s))
            self.in_c = out_c
        return nn.Sequential(*layers)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        x = self.pool(x)
        return self.fc(x.flatten(1))
    
    def get_embeddings(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        return self.pool(x).flatten(1)

# AL Helpers
def get_embeddings(clf, loader, sr_model):
    clf.eval(); sr_model.eval()
    embeds, labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting Embeds"):
            with autocast():
                imgs = sr_model(batch['lr'].to(device))
            # (FIXED) Use gpu_count, not world_size
            emb = clf.module.get_embeddings(imgs) if gpu_count > 1 else clf.get_embeddings(imgs)
            embeds.append(emb.cpu().numpy())
            labels.append(batch['label'].numpy())
    return np.concatenate(embeds), np.concatenate(labels)

def dbss_select(unlabeled_embs, labeled_embs, labels, n_select, pin_ratio=0.4):
    num_classes = config['num_classes']
    centroids = np.zeros((num_classes, labeled_embs.shape[1]))
    # Calculate centroids, handling classes that might not be in the labeled set yet
    for i in range(num_classes):
        class_samples = labeled_embs[labels == i]
        if len(class_samples) > 0:
            centroids[i] = class_samples.mean(axis=0)
        else:
            centroids[i] = labeled_embs.mean(axis=0) # Fallback
            
    dists = np.linalg.norm(unlabeled_embs[:, None] - centroids[None], axis=2)
    
    # Inner-class samples (high score = far from all, uncertain)
    inner_scores = np.sum(np.log(dists + 1e-6), axis=1)
    inner_rank = np.argsort(inner_scores)[::-1] # High score is good
    
    # Border samples (low score = on border)
    dists.sort(axis=1)
    border_scores = np.abs(dists[:, 0] - dists[:, 1])
    border_rank = np.argsort(border_scores) # Low score is good
    
    selected = set(inner_rank[:int(n_select * pin_ratio)])
    for idx in border_rank:
        if len(selected) >= n_select:
            break
        selected.add(idx)
    return list(selected)

def ssas_pseudo(student, teacher, sr_model, unlabeled_loader):
    student.eval(); teacher.eval(); sr_model.eval()
    consistent_indices = []
    with torch.no_grad():
        for batch in tqdm(unlabeled_loader, desc="SSAS"):
            with autocast():
                imgs = sr_model(batch['lr'].to(device))
                s_pred = torch.argmax(student(imgs), 1)
                t_pred = torch.argmax(teacher(imgs), 1)
            mask = s_pred == t_pred
            consistent_indices.extend(batch['idx'][mask].tolist()) # Get original master indices
    return consistent_indices

def train_classifier(clf, loader, epochs):
    clf.train()
    opt = optim.Adam(clf.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    scaler = GradScaler()
    
    for ep in range(epochs):
        tot_loss = 0
        pbar = tqdm(loader, desc=f"CLF Ep {ep+1}/{epochs}")
        for batch in pbar:
            # We train the classifier on the SR-enhanced images
            imgs, lbls = g(batch['lr'].to(device)), batch['label'].to(device)
            opt.zero_grad(set_to_none=True)
            with autocast():
                outputs = clf(imgs)
                loss = criterion(outputs, lbls)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            tot_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        # (FIXED) Removed local_rank error
        wandb.log({"AL_Train/loss": tot_loss / len(loader)})
    return clf

def evaluate_model(clf, loader, sr_model=None):
    clf.eval()
    preds, lbls = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval"):
            with autocast():
                imgs = sr_model(batch['lr'].to(device)) if sr_model else batch['lr'].to(device)
                outputs = clf(imgs)
            # (FIXED) Use argmax for single-label, matching CrossEntropyLoss
            pred = torch.argmax(outputs, 1)
            preds.extend(pred.cpu().numpy())
            lbls.extend(batch['label'].numpy())
    acc = (np.array(preds) == np.array(lbls)).mean() * 100
    return acc, preds, lbls

def log_umap(labeled_embs, labeled_lbls, unlabeled_embs, dbss_idx, ssas_idx, cycle):
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42, n_jobs=1)
    
    if len(unlabeled_embs) == 0:
        print("No unlabeled samples left to plot in UMAP.")
        return

    all_embs = np.concatenate([labeled_embs, unlabeled_embs])
    all_lbls = np.concatenate([labeled_lbls, -1 * np.ones(len(unlabeled_embs))])
    
    print("Fitting UMAP...")
    emb_2d = reducer.fit_transform(all_embs)
    print("UMAP fit complete.")
    
    fig, ax = plt.subplots(figsize=(14,10))
    unlabeled_mask = all_lbls == -1
    ax.scatter(emb_2d[unlabeled_mask,0], emb_2d[unlabeled_mask,1], c='lightgray', s=5, label='Unlabeled')
    labeled_mask = ~unlabeled_mask
    scatter = ax.scatter(emb_2d[labeled_mask,0], emb_2d[labeled_mask,1], c=all_lbls[labeled_mask], cmap='Spectral', s=20)
    
    if len(dbss_idx) > 0:
        ax.scatter(emb_2d[len(labeled_embs):][dbss_idx], emb_2d[len(labeled_embs):][dbss_idx], c='red', s=100, marker='x', label='DBSS')
    if len(ssas_idx) > 0:
        ax.scatter(emb_2d[len(labeled_embs):][ssas_idx], emb_2d[len(labeled_embs):][ssas_idx], c='lime', s=100, marker='+', label='SSAS')
    
    ax.set_title(f'Feature Space UMAP - Cycle {cycle}')
    class_legend = ax.legend(handles=scatter.legend_elements()[0], labels=all_class_names, title="Classes")
    ax.add_artist(class_legend)
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles=[h for h in handles if h.get_label() in ['Unlabeled', 'DBSS', 'SSAS']])
    
    plt.tight_layout()
    # (FIXED) Removed local_rank error
    wandb.log({f"AL_Cycle_{cycle}/umap": wandb.Image(fig)})
    plt.close(fig)

def log_confusion_matrix(preds, lbls, names):
    cm = confusion_matrix(lbls, preds, labels=range(len(names))) 
    fig, ax = plt.subplots(figsize=(12,10))
    sns.heatmap(cm, annot=True, fmt='d', ax=ax, cmap='Blues', xticklabels=names, yticklabels=names)
    ax.set_xlabel('Predicted'); ax.set_ylabel('True'); ax.set_title('Confusion Matrix')
    plt.tight_layout()
    # (FIXED) Removed local_rank error
    wandb.log({"AL_ConfMatrix": wandb.Image(fig)})
    plt.close(fig)

print("AL components defined.")

In [None]:
# Models: Full RFB-ESRGAN Generator (to match saved state_dict), RelDiscriminator, PerceptualLoss
# (Based on ESRGAN; assume growth=32 for lightweight)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.hub
import math
from skimage.metrics import structural_similarity as ssim

class ResidualDenseBlock(nn.Module):
    def __init__(self, nc=64, growth=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nc, growth, 3, 1, 1)
        self.conv2 = nn.Conv2d(nc + growth, growth, 3, 1, 1)
        self.conv3 = nn.Conv2d(nc + 2*growth, growth, 3, 1, 1)
        self.conv4 = nn.Conv2d(nc + 3*growth, growth, 3, 1, 1)
        self.conv5 = nn.Conv2d(nc + 4*growth, nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat([x, x1], 1)))
        x3 = self.lrelu(self.conv3(torch.cat([x, x1, x2], 1)))
        x4 = self.lrelu(self.conv4(torch.cat([x, x1, x2, x3], 1)))
        x5 = self.conv5(torch.cat([x, x1, x2, x3, x4], 1))
        # FIX: Light anti-aliasing blur to reduce checkerboard in residuals
        x5 = F.avg_pool2d(x5, kernel_size=3, stride=1, padding=1)
        return x5 * 0.2 + x  # Dense residual

class RFBESRGANGenerator(nn.Module):
    def __init__(self, growth=32, num_blocks=23, nc=64, upscale=4):
        super().__init__()
        self.entry = nn.Sequential(nn.Conv2d(3, nc, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.ModuleList([ResidualDenseBlock(nc, growth) for _ in range(num_blocks)])
        self.conv_tail = nn.Conv2d(nc, nc, 3, 1, 1)
        # FIX: Bilinear upsample + Conv (replaces TransposedConv to avoid checkerboard artifacts)
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(nc, nc//2, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(nc//2, 3, 3, 1, 1)
        )

    def forward(self, x):
        x = self.entry(x)
        res = x
        for block in self.body:
            x = block(x)
        x = self.conv_tail(x)
        x += res  # Global skip
        x = self.up(x)
        return torch.tanh(x)  # [-1,1] range

class RelDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, 4, 1, 0)  # PatchGAN output
        )

    def forward(self, x):
        return self.net(x)

class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True).features[:35].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.l1 = nn.L1Loss()
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, sr, hr):
        # FIX: Use next(param).device for Sequential (vgg has no direct .device)
        vgg_device = next(self.vgg.parameters()).device
        sr = sr.to(vgg_device)
        hr = hr.to(vgg_device)
        sr_vgg = (self.vgg((sr + 1)/2 * self.std + self.mean) + 1)/2
        hr_vgg = (self.vgg((hr + 1)/2 * self.std + self.mean) + 1)/2
        return self.l1(sr_vgg, hr_vgg)

# PSNR/SSIM Helpers (FIX: Smaller win_size for small images; skip zero-var)
def psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    if mse == 0: return 100.0
    return 20 * math.log10(max_val / math.sqrt(mse.item()))

def compute_ssim(sr, hr):
    sr_np = sr.permute(1,2,0).cpu().numpy()
    hr_np = hr.permute(1,2,0).cpu().numpy()
    # FIX: Skip if zero-variance (causes divide-by-zero in HSV)
    if sr_np.std() < 1e-6 or hr_np.std() < 1e-6:
        return 1.0  # Perfect if both black
    # win_size=3 for 30x30 LR; channel_axis=-1 for HWC RGB
    return ssim(sr_np, hr_np, multichannel=True, channel_axis=-1, data_range=1.0, win_size=3)

print("SR Models defined.")

# --------------------------------------------------------------------------------
# SECTION 3 (CONTINUED): CLASSIFIER MODEL DEFINITIONS
# (Moved from Cell 5 to solve NameError)
# --------------------------------------------------------------------------------

class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block"""
    def __init__(self, c, r=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), # Squeeze
            nn.Flatten(),
            nn.Linear(c, c // r, bias=False), 
            nn.ReLU(inplace=True), 
            nn.Linear(c // r, c, bias=False), 
            nn.Sigmoid()
        )
    def forward(self, x):
        y = self.fc(x)
        return x * y.view(x.shape[0], x.shape[1], 1, 1) # Excitation

class BasicBlock(nn.Module):
    """ResNet-style block with SE"""
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.se = SEBlock(out_c)
        
        self.shortcut = nn.Sequential()
        if stride > 1 or in_c != out_c:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, stride, bias=False),
                nn.BatchNorm2d(out_c)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out) # Apply attention
        out += self.shortcut(x)
        return F.relu(out)

class RobustClassifier(nn.Module):
    """
    The main classifier model.
    Takes 64x64 SR images as input.
    """
    def __init__(self, num_classes, num_blocks=[3,4,6,3]):
        super().__init__()
        self.in_c = 64
        # Initial 7x7 conv to reduce dimensions
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False); self.bn1 = nn.BatchNorm2d(64) # 64x64 -> 32x32
        # Stack layers
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1) # 32x32
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2) # 16x16
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2) # 8x8
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2) # 4x4
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)
    
    def _make_layer(self, out_c, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_c, out_c, s))
            self.in_c = out_c
        return nn.Sequential(*layers)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        x = self.pool(x)
        return self.fc(x.flatten(1))
    
    def get_embeddings(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        return self.pool(x).flatten(1)

print("Classifier Models defined.")

# cell 6

In [None]:
# --------------------------------------------------------------------------------
# SECTION 6: ACTIVE LEARNING PIPELINE
# (FIXED: This cell now assumes all models and helpers are defined
# in previous cells and uses the CORRECT, real labels from Cell 2)
# --------------------------------------------------------------------------------

# We assume 'g' (the SR model) is already trained and exists in memory from Cell 4
# We load the *final* saved model for consistency and to ensure it's in eval mode.
SR_MODEL_PATH = '/kaggle/working/sr_model_final.pth' # <-- This must match the save path from Cell 4
if not os.path.exists(SR_MODEL_PATH):
    print(f"WARNING: {SR_MODEL_PATH} not found! Using model from memory. Run Cell 4 first.")
else:
    try:
        print(f"Loading trained SR model from {SR_MODEL_PATH}...")
        # Re-instantiate the model to load the state dict
        g = RFBESRGANGenerator(growth=32).to(device)
        if gpu_count > 1:
            g = nn.DataParallel(g)
            g.module.load_state_dict(torch.load(SR_MODEL_PATH, map_location=device))
        else:
            g.load_state_dict(torch.load(SR_MODEL_PATH, map_location=device))
        print("Trained SR model loaded successfully.")
    except Exception as e:
        print(f"Error loading SR model, using model from memory: {e}")
g.eval() # Set to evaluation mode

# Instantiate Classifier (from Cell 3)
# (FIXED) Pass the correct num_classes and num_blocks
clf = RobustClassifier(num_classes=config['num_classes'], num_blocks=[3,4,6,3]).to(device)
if gpu_count > 1:
    clf = nn.DataParallel(clf)
    print(f"Classifier using DataParallel across {gpu_count} GPUs")
wandb.watch(clf, log="all", log_freq=100)

# --------------------------------------------------------------------------------
# (CRITICAL FIX) AL Setup: Use REAL labels for stratification
# We assume 'train_idx', 'val_idx', 'al_base_dataset', 'val_loader_al_dataset',
# and 'all_class_names' exist from Cell 2.
# --------------------------------------------------------------------------------

# Get the labels corresponding to our training indices
# train_idx is a list of *master indices* from our valid, limited dataset
# We need to get the labels for *just* these indices
train_labels_subset = [patch_to_label.get(all_b04_paths[i].parent.name) for i in train_idx]
# Filter out any None labels just in case
valid_train_indices_for_al = [idx for i, idx in enumerate(train_idx) if train_labels_subset[i] is not None and train_labels_subset[i] != -1]
valid_train_labels_for_al = [lbl for lbl in train_labels_subset if lbl is not None and lbl != -1]

initial_pool_size = int(0.1 * len(valid_train_indices_for_al))

# (FIXED) Create the split using the *actual* labels
labeled_indices, unlabeled_indices = train_test_split(
    valid_train_indices_for_al, 
    train_size=initial_pool_size, 
    random_state=42,
    stratify=valid_train_labels_for_al # <-- This is the critical fix
)
print(f"AL Setup: Labeled pool {len(labeled_indices)}, Unlabeled pool {len(unlabeled_indices)}")
# --------------------------------------------------------------------------------


# AL Dataset for subsets
class IndexedALDataset(Dataset):
    """Wraps the BigEarthNetSR dataset to use a specific list of indices."""
    def __init__(self, base_ds, indices):
        self.base_ds = base_ds
        self.indices = indices # These are the *master* indices
    
    def __len__(self): return len(self.indices)
    
    def __getitem__(self, idx):
        # self.indices[idx] gives the *original* master index
        original_idx = self.indices[idx]
        # Get the item from the *original* list of all_b04_paths
        # The base_ds __getitem__ needs the index relative to *its* internal list
        # Let's fix the base_ds to accept master indices directly
        
        # We need to find the item in al_base_dataset that corresponds to original_idx
        # This is complex. Let's redefine IndexedALDataset to use the *original* dataset object
        item = self.base_ds[original_idx] 
        return item

# We assume 'al_base_dataset' and 'val_loader_al_dataset' were created in Cell 2
# Let's create them here to be 100% sure they are correct
al_base_dataset = BigEarthNetSR(image_root_path, list(range(len(all_b04_paths))), all_b04_paths, patch_to_label, scale=4, phase='al')
val_al_loader_dataset = BigEarthNetSR(image_root_path, val_idx, all_b04_paths, patch_to_label, scale=4, phase='al')


labeled_loader = DataLoader(IndexedALDataset(al_base_dataset, labeled_indices), batch_size=config['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
unlabeled_loader = DataLoader(IndexedALDataset(al_base_dataset, unlabeled_indices), batch_size=config['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
val_al_loader = DataLoader(val_al_loader_dataset, batch_size=16, shuffle=False, num_workers=2) # Use the correct validation set


teacher = None
for cycle in range(config['al_cycles']):
    print(f"\n--- AL Cycle {cycle+1}/{config['al_cycles']} ---")
    wandb.log({"AL_Cycle": cycle})
    
    # Train Classifier (using functions from Cell 5)
    clf = train_classifier(clf, labeled_loader, config['al_epochs'])
    
    # Get Embeddings (using functions from Cell 5)
    print("Extracting embeddings for Labeled pool...")
    labeled_embs, labeled_lbls = get_embeddings(clf, labeled_loader, g)
    print("Extracting embeddings for Unlabeled pool...")
    if not unlabeled_indices:
        print("No more unlabeled data to select from.")
        break
    unlabeled_embs, _ = get_embeddings(clf, unlabeled_loader, g)
    
    # DBSS Selection (using functions from Cell 5)
    n_select = min(int(len(train_idx) * 0.1), len(unlabeled_indices))
    dbss_local_idx = dbss_select(unlabeled_embs, labeled_embs, labeled_lbls, n_select)
    
    # Map local indices (0 to len(unlabeled_embs)) back to original dataset indices
    newly_labeled_human_indices = [unlabeled_indices[i] for i in dbss_local_idx]
    
    # SSAS (post-cycle 1)
    newly_labeled_pseudo_indices = []
    ssas_local_viz_idx = np.array([])
    if cycle >= 1 and teacher is not None:
        print("Running SSAS...")
        pseudo_original_indices = ssas_pseudo(clf, teacher, g, unlabeled_loader)
        newly_labeled_pseudo_indices = [idx for idx in pseudo_original_indices if idx not in newly_labeled_human_indices]
        
        # Get local indices for SSAS for visualization
        unlabeled_map = {original_idx: local_idx for local_idx, original_idx in enumerate(unlabeled_indices)}
        ssas_local_viz_idx = [unlabeled_map.get(pi) for pi in newly_labeled_pseudo_indices if pi in unlabeled_map]
        ssas_local_viz_idx = np.array([idx for idx in ssas_local_viz_idx if idx is not None and idx < len(unlabeled_embs)])

    
    # Update Pools
    all_new_indices = np.unique(np.concatenate([newly_labeled_human_indices, newly_labeled_pseudo_indices])).tolist()
    labeled_indices = np.unique(np.concatenate([labeled_indices, all_new_indices])).tolist()
    unlabeled_indices = np.setdiff1d(unlabeled_indices, all_new_indices).tolist()
    
    print(f"Cycle {cycle+1} complete. Added {len(newly_labeled_human_indices)} DBSS samples and {len(newly_labeled_pseudo_indices)} SSAS samples.")
    print(f"New pool sizes: Labeled={len(labeled_indices)}, Unlabeled={len(unlabeled_indices)}")

    # Log Visuals (UMAP)
    log_umap(labeled_embs, labeled_lbls, unlabeled_embs, dbss_local_idx, ssas_local_viz_idx, cycle)
    
    # Update Teacher (deepcopy handles DataParallel)
    teacher = copy.deepcopy(clf)
    
    # Cycle Eval (FIXED: Using correct single-label evaluation)
    val_acc, val_preds, val_lbls = evaluate_model(clf, val_al_loader, g) 
    wandb.log({f"AL_Cycle_{cycle}/accuracy": val_acc})
    log_confusion_matrix(val_preds, val_lbls, all_class_names) # (FIXED) Use correct class_names
    print(f"Cycle {cycle} Val Acc (Standard): {val_acc:.2f}%")
    
    if not unlabeled_indices:
        print("All samples labeled. Halting.")
        break
        
    # Update Loaders
    labeled_loader = DataLoader(IndexedALDataset(al_base_dataset, labeled_indices), batch_size=config['batch_size'], shuffle=True, num_workers=2)
    unlabeled_loader = DataLoader(IndexedALDataset(al_base_dataset, unlabeled_indices), batch_size=config['batch_size'], shuffle=False, num_workers=2)

# Save Classifier (unwrap if DataParallel)
def save_clf_model(model, path): # (FIXED) Define save_model locally in cell
    state_dict = model.module.state_dict() if gpu_count > 1 else model.state_dict()
    torch.save(state_dict, path)
    wandb.save(path, policy="now")
    print(f"Saved {path}")
    
save_clf_model(clf, '/kaggle/working/clf_model.pth')

print("AL Pipeline complete!")

# cell 7 

In [None]:
# --------------------------------------------------------------------------------
# SECTION 7: FINAL EVALUATION
# --------------------------------------------------------------------------------

print("Starting final evaluation...")

# We assume 'clf', 'val_loader', and 'g' (SR model from input path) exist from Cell 6
# We assume 'psnr' and 'compute_ssim' exist from Cell 3

final_acc, final_preds, final_lbls = evaluate_model(clf, val_loader, g)

# Calculate final PSNR/SSIM on a subset of the validation loader
final_psnr, final_ssim, batch_count = 0, 0, 0
val_iter = iter(val_loader)
for _ in range(10): # Use 10 batches for a good approximation
    try:
        batch = next(val_iter)
        lr_v, hr_v = batch['lr'].to(device), batch['hr'].to(device)
        with torch.no_grad():
            sr_v = g(lr_v) if gpu_count == 1 else g.module(lr_v)
        final_psnr += psnr(sr_v, hr_v)  # Batch PSNR
        # SSIM: Loop over batch (per-image)
        batch_ssim = 0
        for b in range(lr_v.shape[0]):
            batch_ssim += compute_ssim(sr_v[b], hr_v[b])
        final_ssim += batch_ssim / lr_v.shape[0]
        batch_count += 1
    except StopIteration:
        break # Stop if val_loader has fewer than 10 batches
    except Exception as e:
        print(f"Error in final eval: {e}")

if batch_count > 0:
    final_psnr /= batch_count
    final_ssim /= batch_count

wandb.log({
    "final/psnr": final_psnr, "final/ssim": final_ssim,
    "final/accuracy": final_acc  # Hamming
})

# Summary Table
table = wandb.Table(columns=["Metric", "Value"])
table.add_data(["PSNR (dB)", f"{final_psnr:.2f}"])
table.add_data(["SSIM", f"{final_ssim:.3f}"])
table.add_data(["AL Accuracy (Hamming %)", f"{final_acc:.2f}"])
wandb.log({"final/summary": table})

print(f"Final PSNR: {final_psnr:.2f} | SSIM: {final_ssim:.3f} | AL Acc: {final_acc:.2f}%")

wandb.finish()
print("Pipeline complete! Check WandB dashboard for full logs/visuals.")