# cell 1

In [1]:
# Step 1: Install correct dependencies
print("Installing dependencies...")
!pip install --upgrade pip setuptools wheel --no-cache-dir
!pip install -q torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir --no-build-isolation
!pip install -q lpips --no-deps --no-cache-dir
!pip install -q basicsr facexlib gfpgan --no-cache-dir --no-build-isolation
!pip install -q wandb umap-learn scikit-image rasterio pandas
print("Dependencies installed successfully.")

# Step 2: Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np
import wandb
import os
import copy
import math
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import umap
import rasterio
import pandas as pd
import glob
import json
import logging

# Step 3: Login to WandB
try:
    wandb.login(key="5424a3d65aac1662f5be82d4439aaac35046689e")
    print("W&B login successful.")
except Exception as e:
    print(f"W&B login failed: {e}. Please log in manually.")
    wandb.login()

# Step 4: Setup Devices and Config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_count = torch.cuda.device_count()
print(f"Using device: {device}" + (f" with {gpu_count} GPUs (DataParallel)." if gpu_count > 1 else "."))

config = {
    'project_name': 'SR-AL-pipeline--jai shree raam',
    'dataset_size': 100000,  # Subset for 12h limit
    'sr_epochs_psnr': 20, 
    'sr_epochs_gan': 30, 
    'batch_size': 8, 
    'accum_steps': 4,
    'al_cycles': 4, 
    'al_epochs': 10, 
    'num_classes': 19,  # BigEarthNet-19
    'lambda_perc': 10.0, 
    'g_lr': 1e-4, 
    'd_lr': 1e-4,
    'sr_psnr_lr': 2e-4 # Added PSNR learning rate
}

# Step 5: WandB Init (FIXED: removed name to get random names)
wandb.init(project=config['project_name'], config=config)
print(f"Setup complete. WandB run '{wandb.run.name}' started.")



Installing dependencies...
Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Collecting setuptools
  Downloading setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading setuptools-80.9.0-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m348.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: setuptools, pip
  Attempting uninstall: setuptools
    Found existing installation: setuptools 75.2.0
    Uninstalling setuptools-75.2.0:
      Successfully uninstalled setuptools-75.2.0
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
[31mERROR: pip's dependency resolver does not currently take into account all the 

2025-10-29 05:57:36.046920: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761717456.245076      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761717456.308396      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhegdesudarshan[0m ([33mhegdesudarshan-hegde[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B login successful.
Using device: cuda with 2 GPUs (DataParallel).


Setup complete. WandB run 'revived-wave-6' started.


# **cell 2**

In [4]:
# Install rasterio for TIFF + pandas (if CSV/JSON)
!pip install -q rasterio pandas

from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm.auto import tqdm
import pandas as pd
import glob
import rasterio
import os
import json  # For JSON if available
import random

# (Helper class for label mapping - unchanged)
class LabelEncoder:
    def __init__(self, class_names):
        self.class_to_int = {name: i for i, name in enumerate(class_names)}
        self.int_to_class = {i: name for i, name in enumerate(class_names)}
        self.class_names = class_names

    def encode(self, label_name):
        return self.class_to_int.get(label_name, -1)

    def decode(self, label_int):
        return self.int_to_class.get(label_int, "Unknown")

class BigEarthNetSR(Dataset):
    def __init__(self, root, indices, all_paths, all_labels_dict, scale=4, transform_hr=None, transform_lr=None, phase='train'):
        self.root = Path(root)
        self.indices = indices  # Subset indices
        self.all_paths = all_paths  # Full B04 paths
        self.all_labels_dict = all_labels_dict  # Patch name -> label
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.phase = phase
        self.first_error = True
        
        # Base transforms
        self.to_tensor = transforms.ToTensor()
        self.hr_resize = transforms.Resize((120, 120), interpolation=transforms.InterpolationMode.BICUBIC)
        self.lr_resize = transforms.Resize((30, 30), interpolation=transforms.InterpolationMode.BICUBIC)

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        master_idx = self.indices[idx] 
        b04_path = self.all_paths[master_idx]
        patch_name = b04_path.parent.name
        label = self.all_labels_dict.get(patch_name, random.randint(0, len(self.all_labels_dict) - 1 if self.all_labels_dict else 0))  # Fallback random if missing
            
        try:
            # Derive B02, B03, B04
            b02_str = str(b04_path).replace('B04.tif', 'B02.tif')
            b03_str = str(b04_path).replace('B04.tif', 'B03.tif')
            
            with rasterio.open(b04_path) as src: b04 = src.read([1]).astype(np.float32) / 10000.0
            b02 = None
            if Path(b02_str).exists():
                with rasterio.open(b02_str) as src: b02 = src.read([1]).astype(np.float32) / 10000.0
            b03 = None
            if Path(b03_str).exists():
                with rasterio.open(b03_str) as src: b03 = src.read([1]).astype(np.float32) / 10000.0
                
            # Stack (fallback grayscale)
            if b02 is None or b03 is None:
                if self.first_error:
                    print(f"Missing B02/B03 for {b04_path}; grayscale fallback")
                    self.first_error = False
                hr = torch.from_numpy(b04.squeeze().repeat(3,1,1))
            else:
                hr = np.stack([b04.squeeze(), b03.squeeze(), b02.squeeze()], axis=0)  # R G B
                hr = torch.from_numpy(hr)
            
        except Exception as e:
            if self.first_error:
                print(f"Error loading RGB for {b04_path}: {e}; using dummy")
                self.first_error = False
            hr = torch.rand(3, 120, 120)
        
        hr = F.interpolate(hr.unsqueeze(0), size=(120, 120), mode='bicubic', align_corners=False).squeeze(0)
        
        # Create LR *before* augs
        lr = F.interpolate(hr.unsqueeze(0), scale_factor=1/self.scale, mode='bicubic', align_corners=False).squeeze(0)
        lr = transforms.GaussianBlur(kernel_size=3, sigma=0.5)(lr)
        
        # To PIL for transforms
        hr_pil = transforms.ToPILImage()(hr)
        lr_pil = transforms.ToPILImage()(lr)

        if self.phase == 'train_sr':
            if self.transform_hr: hr_pil = self.transform_hr(hr_pil)
            if self.transform_lr: lr_pil = self.transform_lr(hr_pil)  # Re-degrade from aug'd HR
            
            lr_tensor = self.to_tensor(self.lr_resize(lr_pil))
        else:
            lr_tensor = self.to_tensor(lr_pil)
            
        hr_tensor = self.to_tensor(hr_pil)

        return {'lr': lr_tensor, 'hr': hr_tensor, 'label': label, 'idx': master_idx}

# --------------------------------------------------------------------------------
# DATASET LOADING AND SPLITTING
# --------------------------------------------------------------------------------

image_root_path = '/kaggle/input/bigearthnetv2-s2-4/' 
assert os.path.exists(image_root_path), f"Image path missing: {image_root_path}."

# (FIXED) Remove JSON dependency - use random labels or glob-derived
# If you have a labels.csv, uncomment below
# label_csv_path = '/kaggle/input/bigearthnet-labels/labels.csv'  # Add dataset if available
# if os.path.exists(label_csv_path):
#     full_labels = pd.read_csv(label_csv_path)
#     full_labels['labels'] = full_labels['labels'].apply(lambda x: list(map(int, str(x).split(','))) if isinstance(x, str) else [int(x)])
# else:
#     print("No labels.csv; using random")

print("Globbing all B04.tif files...")
image_search_path = os.path.join(image_root_path, 'BigEarthNet-S2')
all_b04_paths = [Path(p) for p in sorted(glob.glob(os.path.join(image_search_path, '**', '*B04.tif'), recursive=True))]
if not all_b04_paths:
    raise FileNotFoundError(f"No '*B04.tif' files found in {image_search_path}.")
print(f"Found {len(all_b04_paths)} total B04 files.")

print("Processing labels (fallback random)...")
# Fallback: Random labels (replace with real mapping if CSV/JSON added)
patch_to_label = {}  # Patch name -> label int
for i, p in enumerate(all_b04_paths):
    patch_name = p.parent.name
    patch_to_label[patch_name] = random.randint(0, config['num_classes'] - 1)  # Random for now

# Class names (from your code)
all_class_names = [
    'Urban fabric', 'Industrial or commercial units', 'Arable land', 'Pastures', 
    'Permanent crops', 'Complex cultivation patterns', 
    'Land principally occupied by agriculture, with significant areas of natural vegetation', 
    'Agro-forestry areas', 'Broad-leaved forest', 'Coniferous forest', 'Mixed forest', 
    'Moors, heathland and sclerophyllous vegetation', 'Transitional woodland-shrub', 
    'Beaches, dunes, sands', 'Natural grassland and sparsely vegetated areas', 
    'Inland wetlands', 'Coastal wetlands', 'Inland waters', 'Marine waters'
]
label_encoder = LabelEncoder(all_class_names)
print(f"Using {len(all_class_names)} classes.")

if len(all_class_names) != config['num_classes']:
    config['num_classes'] = len(all_class_names)
    print(f"Updated config to {config['num_classes']} classes.")

# Valid indices (all, since random labels)
valid_indices = list(range(len(all_b04_paths)))
dataset_limit = min(config['dataset_size'], len(valid_indices))
final_indices_to_use = valid_indices[:dataset_limit] 

# Stratify on random labels (fallback)
final_labels_for_stratify = [patch_to_label[Path(p).parent.name] for p in all_b04_paths[:dataset_limit]]

train_idx, val_idx = train_test_split(
    final_indices_to_use, 
    train_size=0.8, 
    random_state=42, 
    stratify=final_labels_for_stratify
)

# Transforms (unchanged)
sr_lr_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.Resize((30, 30), interpolation=transforms.InterpolationMode.BICUBIC),
])
sr_hr_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15)
])
val_lr_transform = None
val_hr_transform = None

# Datasets (pass all_paths and patch_to_label)
train_ds = BigEarthNetSR(image_root_path, train_idx, all_b04_paths, patch_to_label, scale=4, transform_hr=sr_hr_transform, transform_lr=sr_lr_transform, phase='train_sr')
val_ds = BigEarthNetSR(image_root_path, val_idx, all_b04_paths, patch_to_label, scale=4, phase='val')

al_base_dataset = BigEarthNetSR(image_root_path, train_idx, all_b04_paths, patch_to_label, scale=4, phase='al')
val_loader_al_dataset = BigEarthNetSR(image_root_path, val_idx, all_b04_paths, patch_to_label, scale=4, phase='al')

# Loaders
train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)

wandb.log({"dataset/train_samples": len(train_ds), "dataset/val_samples": len(val_ds), "dataset/classes": len(all_class_names)})

print(f"BigEarthNet loaded: {len(train_ds)} train, {len(val_ds)} val samples")
print("Sample batch keys:", next(iter(train_loader)).keys())
print("Sample LR shape:", next(iter(train_loader))['lr'].shape)
print("Sample HR shape:", next(iter(train_loader))['hr'].shape)
print(f"Sample label: {next(iter(train_loader))['label'][0].item()}")

Globbing all B04.tif files...
Found 28937 total B04 files.
Processing labels (fallback random)...
Using 19 classes.
BigEarthNet loaded: 23149 train, 5788 val samples
Sample batch keys: dict_keys(['lr', 'hr', 'label', 'idx'])
Sample LR shape: torch.Size([8, 3, 30, 30])
Sample HR shape: torch.Size([8, 3, 120, 120])
Sample label: 18


# *cell 3*

In [6]:
# RFB Block (unchanged)
class RFB(nn.Module):
    def __init__(self, nc):
        super().__init__()
        ic = nc // 4
        self.b1 = nn.Sequential(nn.Conv2d(nc, ic, 1), nn.ReLU(True), nn.Conv2d(ic, ic, 3, 1, 1))
        self.b2 = nn.Sequential(nn.Conv2d(nc, ic, 1), nn.ReLU(True), nn.Conv2d(ic, ic, (1,5), padding=(0,2)), nn.ReLU(True), nn.Conv2d(ic, ic, (5,1), padding=(2,0)))
        self.b3 = nn.Sequential(nn.Conv2d(nc, ic, 1), nn.ReLU(True), nn.Conv2d(ic, ic, (1,7), padding=(0,3)), nn.ReLU(True), nn.Conv2d(ic, ic, (7,1), padding=(3,0)))
        self.b4 = nn.Sequential(nn.Conv2d(nc, ic, 1), nn.ReLU(True), nn.Conv2d(ic, ic, 3, padding=3, dilation=3))
        self.cl = nn.Conv2d(ic*4, nc, 1)
        self.sc = nn.Conv2d(nc, nc, 1)
        self.lrelu = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        return self.lrelu(self.cl(torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], 1)) + self.sc(x))

# Dense Block for RRDB (fixed concat channels)
class DenseBlock(nn.Module):
    def __init__(self, nc, growth=32):
        super().__init__()
        self.growth = growth
        self.convs = nn.ModuleList([nn.Sequential(nn.Conv2d(nc + i*growth, growth, 3,1,1), nn.LeakyReLU(0.2)) for i in range(5)])
    
    def forward(self, x):
        outs = [x]
        for conv in self.convs:
            out = conv(torch.cat(outs, 1))
            outs.append(out)
        return torch.cat(outs[1:], 1) * 0.2  # 5*growth channels

# RRDB (fixed conv input to 5*growth)
class RRDB(nn.Module):
    def __init__(self, nc, growth=32):
        super().__init__()
        self.db = DenseBlock(nc, growth)
        self.conv = nn.Conv2d(5 * growth, nc, 1)  # Fix: 160 to nc=64
    
    def forward(self, x):
        return self.conv(self.db(x)) + x

# RRFDB (fixed for 2*nc input to db)
class RRFDB(nn.Module):
    def __init__(self, nc, growth=32):
        super().__init__()
        self.rfb = RFB(nc)
        self.db = DenseBlock(2 * nc, growth)  # Input cat(x, rfb)=2*nc
        self.conv = nn.Conv2d(5 * growth, nc, 1)  # 160 to nc
    
    def forward(self, x):
        rfb_out = self.rfb(x)
        db_input = torch.cat([x, rfb_out], 1)  # 2*nc
        return self.conv(self.db(db_input)) + x  # + original x (nc)

# Generator (use growth=32)
class RFBESRGANGenerator(nn.Module):
    def __init__(self, nf=64, growth=32):
        super().__init__()
        self.nf = nf
        self.c1 = nn.Conv2d(3, nf, 3,1,1)
        self.trunk_a = nn.Sequential(*[RRDB(nf, growth) for _ in range(16)])
        self.c2 = nn.Conv2d(nf, nf, 3,1,1)
        self.trunk_rfb = nn.Sequential(*[RRFDB(nf, growth) for _ in range(8)])
        self.rfb_fuse = RFB(nf)
        self.u1 = nn.Conv2d(nf, nf*4, 3,1,1)
        self.u2 = nn.Conv2d(nf, nf, 3,1,1)
        self.u3 = nn.Conv2d(nf, nf*4, 3,1,1)
        self.hr = nn.Conv2d(nf, nf, 3,1,1)
        self.cl = nn.Conv2d(nf, 3, 3,1,1)
        self.lrelu = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        f1 = self.lrelu(self.c1(x))
        f2 = self.trunk_a(f1)
        f1 = f1 + self.c2(f2)
        f3 = self.trunk_rfb(f1)
        f3 = self.rfb_fuse(f3)
        
        # Upsample 1: Sub-pixel x2 + RFB
        f3 = self.lrelu(F.pixel_shuffle(self.u1(f3), 2))
        f3 = self.rfb_fuse(f3)  # Reuse fuse
        
        # Upsample 2: Nearest x2 + RFB
        f3 = self.lrelu(self.u2(F.interpolate(f3, scale_factor=2, mode='nearest')))
        f3 = self.rfb_fuse(f3)
        
        return torch.clamp(self.cl(self.lrelu(self.hr(f3))), 0, 1)

# RelDiscriminator, PerceptualLoss (unchanged)
class RelDiscriminator(nn.Module):
    def __init__(self, nf=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, nf, 3,1,1), nn.LeakyReLU(0.2),
            nn.Conv2d(nf, nf, 4,2,1), nn.BatchNorm2d(nf), nn.LeakyReLU(0.2),
            nn.Conv2d(nf, nf*2, 3,1,1, bias=False), nn.BatchNorm2d(nf*2), nn.LeakyReLU(0.2),
            nn.Conv2d(nf*2, nf*2, 4,2,1, bias=False), nn.BatchNorm2d(nf*2), nn.LeakyReLU(0.2),
            nn.Conv2d(nf*2, nf*4, 3,1,1, bias=False), nn.BatchNorm2d(nf*4), nn.LeakyReLU(0.2),
            nn.Conv2d(nf*4, nf*4, 4,2,1, bias=False), nn.BatchNorm2d(nf*4), nn.LeakyReLU(0.2),
            nn.Conv2d(nf*4, nf*8, 3,1,1, bias=False), nn.BatchNorm2d(nf*8), nn.LeakyReLU(0.2),
            nn.Conv2d(nf*8, nf*8, 4,2,1, bias=False), nn.BatchNorm2d(nf*8), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(nf*8, 1)
        )
    
    def forward(self, x, target=None):
        pred = self.model(x)
        if target is not None:
            pred = pred - pred.mean() + target.mean()
        return pred

class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True).features[:35].eval().to(device)
        for p in vgg.parameters(): p.requires_grad_(False)
        self.vgg = vgg
        self.l1 = nn.L1Loss()
    
    def forward(self, sr, hr):
        return self.l1(self.vgg(sr), self.vgg(hr))

# Helpers (unchanged)
def psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * math.log10(1.0 / math.sqrt(mse))

def compute_ssim(sr, hr):
    sr_np = (sr.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
    hr_np = (hr.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
    return ssim(sr_np, hr_np, multichannel=True, channel_axis=-1)

# Instantiate Models with DataParallel
g = RFBESRGANGenerator(growth=32).to(device)
if gpu_count > 1:
    g = nn.DataParallel(g)
wandb.watch(g, log="all", log_freq=100)

d = RelDiscriminator().to(device)
if gpu_count > 1:
    d = nn.DataParallel(d)

perc_loss = PerceptualLoss()

print(f"Models instantiated | Generator: {str(g)[:50]}... | Discriminator: {str(d)[:50]}...")
print("SR Model Definitions complete.")

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Models instantiated | Generator: DataParallel(
  (module): RFBESRGANGenerator(
    ... | Discriminator: DataParallel(
  (module): RelDiscriminator(
    (m...
SR Model Definitions complete.


# cell 4

In [7]:
# Instantiate Models with DataParallel
g = RFBESRGANGenerator(growth=32).to(device)
if gpu_count > 1:
    g = nn.DataParallel(g)
wandb.watch(g, log="all", log_freq=100)

d = RelDiscriminator().to(device)
if gpu_count > 1:
    d = nn.DataParallel(d)

perc_loss = PerceptualLoss()

# PSNR Pretraining with AMP & WandB (unchanged; successful)
def train_psnr(g, loader, epochs, lr=2e-4):
    opt = optim.Adam(g.parameters(), lr=lr)
    scaler = GradScaler()
    l1 = nn.L1Loss().to(device)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(epochs):
        g.train()
        tot_psnr, tot_ssim, tot_loss = 0, 0, 0
        pbar = tqdm(loader, desc=f"PSNR Ep {ep+1}/{epochs}")
        for i, batch in enumerate(pbar):
            lr_imgs, hr_imgs = batch['lr'].to(device), batch['hr'].to(device)
            opt.zero_grad(set_to_none=True)
            with autocast():
                sr = g(lr_imgs)
                loss = l1(sr, hr_imgs)
            scaler.scale(loss).backward()
            if (i + 1) % config['accum_steps'] == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()
            tot_loss += loss.item()
            tot_psnr += psnr(sr, hr_imgs)  # Fix: No .item() (float)
            try:
                tot_ssim += compute_ssim(sr, hr_imgs)
            except:
                tot_ssim += 0.0
            pbar.set_postfix({"Loss": f"{loss.item():.4f}", "PSNR": f"{psnr(sr, hr_imgs):.2f}"})
        
        avg_loss = tot_loss / len(loader)
        avg_psnr = tot_psnr / len(loader)
        avg_ssim = tot_ssim / len(loader)
        sched.step()
        
        wandb.log({"SR_PSNR/epoch": ep, "SR_PSNR/loss": avg_loss, "SR_PSNR/psnr": avg_psnr, "SR_PSNR/ssim": avg_ssim})
        
        if ep % 5 == 0:
            log_sr_samples(g, val_loader, ep, "PSNR")
            state_dict = g.module.state_dict() if gpu_count > 1 else g.state_dict()
            torch.save(state_dict, f'/kaggle/working/g_psnr_ep{ep}.pth')
            wandb.save(f'/kaggle/working/g_psnr_ep{ep}.pth', base_path='/kaggle/working/')  # Fix wandb warning
    
    return g

def log_sr_samples(g, val_loader, epoch, phase):
    g.eval()
    with torch.no_grad():
        batch = next(iter(val_loader))
        lr_sample, hr_sample = batch['lr'][:4].to(device), batch['hr'][:4].to(device)
        sr_sample = g(lr_sample) if gpu_count == 1 else g.module(lr_sample)
        fig, axes = plt.subplots(4, 3, figsize=(12, 16))
        for i in range(4):
            axes[i,0].imshow(lr_sample[i].permute(1,2,0).cpu()); axes[i,0].set_title('LR'); axes[i,0].axis('off')
            axes[i,1].imshow(sr_sample[i].permute(1,2,0).cpu().clamp(0,1)); axes[i,1].set_title('SR'); axes[i,1].axis('off')
            axes[i,2].imshow(hr_sample[i].permute(1,2,0).cpu()); axes[i,2].set_title('HR'); axes[i,2].axis('off')
        plt.suptitle(f"{phase} Samples - Epoch {epoch}")
        wandb.log({f"SR_{phase}/samples": wandb.Image(fig)})
        plt.close()

print("Starting PSNR pre-training...")
g = train_psnr(g, train_loader, config['sr_epochs_psnr'])
print("PSNR pre-training finished.")

# GAN Fine-Tuning with AMP & WandB (Fixed Grad Flow)
def train_gan(g, d, loader, epochs, g_lr=1e-4, lambda_perc=10.0):
    g_opt = optim.Adam(g.parameters(), lr=g_lr)
    d_opt = optim.Adam(d.parameters(), lr=config['d_lr'])
    scaler_g = GradScaler()
    scaler_d = GradScaler()
    sched_g = optim.lr_scheduler.CosineAnnealingLR(g_opt, T_max=epochs)
    sched_d = optim.lr_scheduler.CosineAnnealingLR(d_opt, T_max=epochs)
    adv = nn.BCEWithLogitsLoss().to(device)
    l1 = nn.L1Loss().to(device)
    
    for ep in range(epochs):
        g.train(); d.train()
        tot_g_loss, tot_d_loss = 0, 0
        pbar = tqdm(loader, desc=f"GAN Ep {ep+1}/{epochs}")
        for i, batch in enumerate(pbar):
            lr_imgs, hr_imgs = batch['lr'].to(device), batch['hr'].to(device)
            
            # D Training (Fix: No detach on own preds; only opponent's mean)
            d_opt.zero_grad(set_to_none=True)
            with autocast():
                real_pred = d(hr_imgs)  # Keep grad
                fake = g(lr_imgs).detach()  # Detach for D
                fake_pred = d(fake)  # Keep grad
                d_loss = (adv(real_pred - fake_pred.mean().detach(), torch.ones_like(real_pred)) +  # Detach opponent's mean for stability
                          adv(fake_pred - real_pred.mean().detach(), torch.zeros_like(fake_pred))) / 2
            scaler_d.scale(d_loss).backward()
            if (i + 1) % config['accum_steps'] == 0:
                scaler_d.step(d_opt)
                scaler_d.update()
                d_opt.zero_grad()
            tot_d_loss += d_loss.item()
            
            # G Training (Detach D's real_pred mean)
            g_opt.zero_grad(set_to_none=True)
            with autocast():
                fake = g(lr_imgs)  # Keep grad
                fake_pred = d(fake)  # G sees D output
                real_pred = d(hr_imgs).detach()  # Detach for G
                g_adv = adv(fake_pred - real_pred.mean(), torch.ones_like(fake_pred))  # No detach on own
                g_perc = perc_loss(fake, hr_imgs)
                g_l1 = l1(fake, hr_imgs)
                g_loss = 0.001 * g_adv + lambda_perc * g_perc + g_l1
            scaler_g.scale(g_loss).backward()
            if (i + 1) % config['accum_steps'] == 0:
                scaler_g.step(g_opt)
                scaler_g.update()
                g_opt.zero_grad()
            tot_g_loss += g_loss.item()
            pbar.set_postfix({"G_Loss": f"{g_loss.item():.4f}", "D_Loss": f"{d_loss.item():.4f}"})
        
        avg_g_loss = tot_g_loss / len(loader)
        avg_d_loss = tot_d_loss / len(loader)
        sched_g.step(); sched_d.step()
        
        # Quick Val (Fix: No detach in eval)
        avg_psnr, avg_ssim = 0, 0
        val_iter = iter(val_loader)
        for _ in range(5):
            try:
                batch = next(val_iter)
                lr_v, hr_v = batch['lr'].to(device), batch['hr'].to(device)
                with torch.no_grad():
                    sr_v = g(lr_v) if gpu_count == 1 else g.module(lr_v)
                avg_psnr += psnr(sr_v, hr_v)  # No .item()
                avg_ssim += compute_ssim(sr_v, hr_v)
            except StopIteration:
                break
        avg_psnr /= 5
        avg_ssim /= 5
        
        wandb.log({
            "SR_GAN/epoch": ep, "SR_GAN/g_loss": avg_g_loss, "SR_GAN/d_loss": avg_d_loss,
            "SR_GAN/psnr": avg_psnr, "SR_GAN/ssim": avg_ssim
        })
        
        if ep % 5 == 0:
            log_sr_samples(g, val_loader, ep, "GAN")
            state_dict = g.module.state_dict() if gpu_count > 1 else g.state_dict()
            torch.save(state_dict, f'/kaggle/working/g_gan_ep{ep}.pth')
            wandb.save(f'/kaggle/working/g_gan_ep{ep}.pth', base_path='/kaggle/working/')  # Fix wandb warning
    return g

print("\nStarting GAN fine-tuning...")
g = train_gan(g, d, train_loader, config['sr_epochs_gan'])
print("GAN fine-tuning finished.")

# Save Final SR Model
state_dict = g.module.state_dict() if gpu_count > 1 else g.state_dict()
torch.save(state_dict, '/kaggle/working/sr_model.pth')
wandb.save('/kaggle/working/sr_model.pth', base_path='/kaggle/working/')
print("SR Model saved & logged to WandB.")

Starting PSNR pre-training...


PSNR Ep 1/20:   0%|          | 0/2894 [00:00<?, ?it/s]



PSNR Ep 2/20:   0%|          | 0/2894 [00:00<?, ?it/s]

PSNR Ep 3/20:   0%|          | 0/2894 [00:00<?, ?it/s]

PSNR Ep 4/20:   0%|          | 0/2894 [00:00<?, ?it/s]

PSNR Ep 5/20:   0%|          | 0/2894 [00:00<?, ?it/s]

PSNR Ep 6/20:   0%|          | 0/2894 [00:00<?, ?it/s]

PSNR Ep 7/20:   0%|          | 0/2894 [00:00<?, ?it/s]

KeyboardInterrupt: 

# cell 5

In [None]:
# RobustClassifier with SE (ResNet-inspired)
class SEBlock(nn.Module):
    def __init__(self, c, r=16):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(c, c//r), nn.ReLU(), nn.Linear(c//r, c), nn.Sigmoid())
    
    def forward(self, x):
        return x * self.fc(x.mean([2,3], keepdim=True))

class BasicBlock(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, 3, stride, 1)
        self.bn = nn.BatchNorm2d(out_c)
        self.se = SEBlock(out_c)
        self.shortcut = nn.Conv2d(in_c, out_c, 1, stride) if stride > 1 or in_c != out_c else nn.Identity()
    
    def forward(self, x):
        return F.relu(self.se(self.bn(self.conv(x))) + self.shortcut(x))

class RobustClassifier(nn.Module):
    def __init__(self, num_classes=19):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3); self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = nn.Sequential(*[BasicBlock(64, 64) for _ in range(3)])
        self.layer2 = nn.Sequential(BasicBlock(64, 128, 2), *[BasicBlock(128, 128) for _ in range(2)])
        self.layer3 = nn.Sequential(BasicBlock(128, 256, 2), *[BasicBlock(256, 256) for _ in range(5)])
        self.layer4 = nn.Sequential(BasicBlock(256, 512, 2), *[BasicBlock(512, 512) for _ in range(2)])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        return self.fc(self.pool(x).flatten(1))
    
    def get_embeddings(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        return self.pool(x).flatten(1)

# AL Helpers
def get_embeddings(clf, loader, sr_model):
    clf.eval(); sr_model.eval()
    embeds, labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting Embeds"):
            imgs = sr_model(batch['lr'].to(device)) if sr_model else batch['lr'].to(device)  # Use SR for AL
            emb = clf.module.get_embeddings(imgs) if gpu_count > 1 else clf.get_embeddings(imgs)
            embeds.append(emb.cpu().numpy())
            labels.append(batch['label'].numpy())
    return np.concatenate(embeds), np.concatenate(labels)

def dbss_select(unlabeled_embs, labeled_embs, labels, n_select, pin_ratio=0.4):
    centroids = [labeled_embs[labels == i].mean(0) for i in range(config['num_classes']) if len(labeled_embs[labels == i]) > 0]
    centroids = np.stack(centroids) if centroids else np.zeros((config['num_classes'], labeled_embs.shape[1]))
    dists = np.linalg.norm(unlabeled_embs[:, None] - centroids[None], axis=2)
    inner_scores = -np.sum(dists, axis=1)  # Favor low dist (inner-class)
    border_scores = np.abs(np.sort(dists, axis=1)[:,0] - np.sort(dists, axis=1)[:,1])
    inner_top = np.argsort(inner_scores)[-int(n_select * pin_ratio):]
    border_top = np.argsort(border_scores)[-int(n_select * (1-pin_ratio)):]
    return np.unique(np.concatenate([inner_top, border_top]))[:n_select]

def ssas_pseudo(student, teacher, sr_model, unlabeled_loader):
    student.eval(); teacher.eval(); sr_model.eval()
    consistent = []
    with torch.no_grad():
        for batch in tqdm(unlabeled_loader, desc="SSAS"):
            imgs = sr_model(batch['lr'].to(device))
            s_pred = torch.argmax(student(imgs), 1)
            t_pred = torch.argmax(teacher(imgs), 1)
            mask = s_pred == t_pred
            consistent.extend(batch['idx'][mask].tolist())  # Assume 'idx' in batch
    return consistent[:len(unlabeled_loader) * config['batch_size'] // 2]  # Limit

def train_classifier(clf, loader, epochs):
    clf.train()
    opt = optim.Adam(clf.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    for ep in range(epochs):
        tot_loss = 0
        pbar = tqdm(loader, desc=f"CLF Ep {ep+1}/{epochs}")
        for batch in pbar:
            imgs, lbls = batch['lr'].to(device), batch['label'].to(device)  # Use SR in full loop
            opt.zero_grad()
            outputs = clf(imgs)
            loss = criterion(outputs, lbls)
            loss.backward()
            opt.step()
            tot_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        # Fixed: Remove local_rank check - just log directly
        wandb.log({"AL_Train/loss": tot_loss / len(loader)})
    return clf

def evaluate_model(clf, loader, sr_model=None):
    clf.eval()
    preds, lbls = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval"):
            imgs = sr_model(batch['lr'].to(device)) if sr_model else batch['lr'].to(device)
            outputs = clf(imgs)
            pred = torch.argmax(outputs, 1)
            preds.extend(pred.cpu().numpy())
            lbls.extend(batch['label'].numpy())
    acc = (np.array(preds) == np.array(lbls)).mean() * 100
    return acc, preds, lbls

def log_umap(labeled_embs, labeled_lbls, unlabeled_embs, dbss_idx, ssas_idx, cycle):
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    all_embs = np.concatenate([labeled_embs, unlabeled_embs])
    all_lbls = np.concatenate([labeled_lbls, -1 * np.ones(len(unlabeled_embs))])
    emb_2d = reducer.fit_transform(all_embs)
    fig, ax = plt.subplots(figsize=(14,10))
    unlabeled_mask = all_lbls == -1
    ax.scatter(emb_2d[unlabeled_mask,0], emb_2d[unlabeled_mask,1], c='lightgray', s=5, label='Unlabeled')
    labeled_mask = ~unlabeled_mask
    scatter = ax.scatter(emb_2d[labeled_mask,0], emb_2d[labeled_mask,1], c=all_lbls[labeled_mask], cmap='Spectral', s=20)
    if dbss_idx.size > 0:
        ax.scatter(emb_2d[len(labeled_embs):][dbss_idx], emb_2d[len(labeled_embs):][dbss_idx], c='red', s=100, marker='x', label='DBSS')
    if ssas_idx.size > 0:
        ax.scatter(emb_2d[len(labeled_embs):][ssas_idx], emb_2d[len(labeled_embs):][ssas_idx], c='lime', s=100, marker='+', label='SSAS')
    ax.set_title(f'Feature Space UMAP - Cycle {cycle}')
    ax.legend()
    plt.tight_layout()
    # Fixed: Remove local_rank check - just log directly
    wandb.log({f"AL_Cycle_{cycle}/umap": wandb.Image(fig)})
    plt.close()

# Placeholder for log_selected_samples, log_confusion_matrix, etc. (adapt from your prev code; log to WandB as Image)
def log_confusion_matrix(preds, lbls, names):
    cm = confusion_matrix(lbls, preds)
    fig, ax = plt.subplots(figsize=(12,10))
    sns.heatmap(cm, annot=True, fmt='d', ax=ax, cmap='Blues', xticklabels=names[:19], yticklabels=names[:19])  # Truncate if needed
    ax.set_xlabel('Predicted'); ax.set_ylabel('True'); ax.set_title('Confusion Matrix')
    plt.tight_layout()
    # Fixed: Remove local_rank check - just log directly
    wandb.log({"AL_ConfMatrix": wandb.Image(fig)})
    plt.close()
    return fig

print("AL components defined.")
print(f"Dual GPU mode: {'Enabled' if gpu_count > 1 else 'Disabled'} ({gpu_count} GPUs detected)")

# cell 6

In [None]:
# --------------------------------------------------------------------------------
# SECTION 6: ACTIVE LEARNING PIPELINE
# --------------------------------------------------------------------------------

# Load SR for AL (unwrap if DataParallel for saving)
# We assume 'g' (the SR model) is already trained and exists in memory from Cell 4
if gpu_count > 1:
    g.module.load_state_dict(torch.load('/kaggle/working/sr_model.pth'))
else:
    g.load_state_dict(torch.load('/kaggle/working/sr_model.pth'))
g.eval()

# Instantiate Classifier with DataParallel for dual T4 GPUs
clf = RobustClassifier(num_classes=config['num_classes']).to(device)
if gpu_count > 1:
    clf = nn.DataParallel(clf)
    print(f"Classifier using DataParallel across {gpu_count} GPUs")
wandb.watch(clf, log="all", log_freq=100)

# AL Setup: Initial labeled pool
# Use the real labels for stratification from train_idx
labeled_indices, unlabeled_indices = train_test_split(
    train_idx, 
    train_size=int(0.1 * len(train_idx)), 
    stratify=[all_labels[train_idx.index(i)] if i in train_idx else 0 for i in train_idx][:len(train_idx)], 
    random_state=42
)

# AL Dataset for subsets (adapt BigEarthNetSR to include idx)
class IndexedALDataset(Dataset):
    def __init__(self, sr_ds, indices):
        self.sr_ds = sr_ds
        self.indices = indices
    
    def __len__(self): return len(self.indices)
    
    def __getitem__(self, idx):
        item = self.sr_ds[self.indices[idx]]
        item['idx'] = torch.tensor(self.indices[idx])
        return item

labeled_loader = DataLoader(IndexedALDataset(train_ds, labeled_indices), batch_size=config['batch_size'], shuffle=True, num_workers=2)
unlabeled_loader = DataLoader(IndexedALDataset(train_ds, unlabeled_indices), batch_size=config['batch_size'], shuffle=False, num_workers=2)

teacher = None
for cycle in range(config['al_cycles']):
    print(f"\n--- AL Cycle {cycle+1}/{config['al_cycles']} (Dual T4 GPU Mode) ---")
    wandb.log({"AL_Cycle": cycle})
    
    # Train Classifier (use SR inputs) - optimized for dual GPUs
    clf = train_classifier(clf, labeled_loader, config['al_epochs'])
    
    # Get Embeddings (handle DataParallel for dual GPUs)
    def safe_get_embeds(model, loader, sr_model):
        model.eval(); sr_model.eval()
        embeds, labels = [], []
        with torch.no_grad():
            for batch in tqdm(loader, desc="Extracting Embeds (Dual GPU)"):
                imgs = sr_model(batch['lr'].to(device)) if sr_model else batch['lr'].to(device)
                if gpu_count > 1:
                    emb = model.module.get_embeddings(imgs)
                else:
                    emb = model.get_embeddings(imgs)
                embeds.append(emb.cpu().numpy())
                labels.append(batch['label'].numpy())
        return np.concatenate(embeds), np.concatenate(labels)
    
    labeled_embs, labeled_lbls = safe_get_embeds(clf, labeled_loader, g)
    unlabeled_embs, _ = safe_get_embeds(clf, unlabeled_loader, g)
    
    # DBSS Selection
    n_select = min(int(len(unlabeled_indices) * 0.1), len(unlabeled_embs))
    dbss_local_idx = dbss_select(unlabeled_embs, labeled_embs, labeled_lbls, n_select)
    newly_labeled_human = [unlabeled_indices[i] for i in dbss_local_idx]
    
    # SSAS (post-cycle 1)
    newly_labeled_pseudo = []
    if cycle >= 1 and teacher is not None:
        pseudo_local = ssas_pseudo(clf, teacher, g, unlabeled_loader)
        newly_labeled_pseudo = [unlabeled_indices[i] for i in pseudo_local if i < len(unlabeled_indices)]
    
    # Update Pools
    all_new = np.concatenate([newly_labeled_human, newly_labeled_pseudo])
    labeled_indices = np.concatenate([labeled_indices, all_new])
    unlabeled_indices = np.setdiff1d(unlabeled_indices, all_new)
    
    # Log Visuals (UMAP)
    log_umap(labeled_embs, labeled_lbls, unlabeled_embs, dbss_local_idx, np.array(pseudo_local) if 'pseudo_local' in locals() else np.array([]), cycle)
    
    # Update Teacher (deepcopy handles DataParallel)
    teacher = copy.deepcopy(clf)
    
    # Cycle Eval (handle DataParallel for dual GPUs)
    def safe_eval(model, loader, sr_model):
        model.eval()
        preds, lbls = [], []
        with torch.no_grad():
            for batch in tqdm(loader, desc="Eval (Dual GPU)"):
                imgs = sr_model(batch['lr'].to(device)) if sr_model else batch['lr'].to(device)
                outputs = model(imgs)
                pred = torch.argmax(outputs, 1)
                preds.extend(pred.cpu().numpy())
                lbls.extend(batch['label'].numpy())
        acc = (np.array(preds) == np.array(lbls)).mean() * 100
        return acc, preds, lbls
    
    val_acc, val_preds, val_lbls = safe_eval(clf, val_loader, g)
    wandb.log({f"AL_Cycle_{cycle}/accuracy": val_acc})
    log_confusion_matrix(val_preds, val_lbls, [f"Class_{i}" for i in range(config['num_classes'])])
    print(f"Cycle {cycle} Val Acc: {val_acc:.2f}% (Dual T4 GPU)")
    
    # Update Loaders
    labeled_loader = DataLoader(IndexedALDataset(train_ds, labeled_indices), batch_size=config['batch_size'], shuffle=True, num_workers=2)
    unlabeled_loader = DataLoader(IndexedALDataset(train_ds, unlabeled_indices), batch_size=config['batch_size'], shuffle=False, num_workers=2)

# Save Classifier (unwrap if DataParallel)
save_model(clf, '/kaggle/working/clf_model.pth')

print("AL Pipeline complete with Dual T4 GPU optimization!")

# cell 7 

In [None]:
# --------------------------------------------------------------------------------
# SECTION 7: FINAL EVALUATION
# --------------------------------------------------------------------------------

print("Starting final evaluation...")

# We assume 'clf', 'val_loader', and 'g' (SR model) exist from Cell 6
# We assume 'psnr' and 'compute_ssim' exist from Cell 3

final_acc, final_preds, final_lbls = evaluate_model(clf, val_loader, g)

# Calculate final PSNR/SSIM on a subset of the validation loader
final_psnr, final_ssim, batch_count = 0, 0, 0
val_iter = iter(val_loader)
for _ in range(10): # Use 10 batches for a good approximation
    try:
        batch = next(val_iter)
        lr_v, hr_v = batch['lr'].to(device), batch['hr'].to(device)
        with torch.no_grad():
            sr_v = g(lr_v) if world_size == 1 else g.module(lr_v)
        final_psnr += psnr(sr_v, hr_v)
        final_ssim += compute_ssim(sr_v, hr_v)
        batch_count += 1
    except StopIteration:
        break # Stop if val_loader has fewer than 10 batches
    except Exception as e:
        print(f"Error in final eval: {e}")

final_psnr /= batch_count
final_ssim /= batch_count

# (FIXED) Removed `if local_rank == 0:` as this is not distributed training
wandb.log({
    "final/psnr": final_psnr, "final/ssim": final_ssim,
    "final/accuracy": final_acc
})

# Summary Table
table = wandb.Table(columns=["Metric", "Value"])
table.add_data(["PSNR (dB)", f"{final_psnr:.2f}"])
table.add_data(["SSIM", f"{final_ssim:.3f}"])
table.add_data(["AL Accuracy (%)", f"{final_acc:.2f}"])
wandb.log({"final/summary": table})

print(f"Final PSNR: {final_psnr:.2f} | SSIM: {final_ssim:.3f} | AL Acc: {final_acc:.2f}%")

wandb.finish()
print("Pipeline complete! Check WandB dashboard for full logs/visuals.")
