In [None]:
# ===============================================================================
# CELL 1: Dependencies & WandB Setup
# ===============================================================================
print("Installing dependencies...")
!pip install -q torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
!pip install -q wandb scikit-learn matplotlib seaborn tqdm Pillow rasterio pandas

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import rasterio
import wandb
import os
import gc
import time
import warnings
warnings.filterwarnings('ignore')

# WandB Login
# SECURITY: never hard-code the API key in the notebook -- anyone who can read
# the file (or its git history) can use the account. The key that used to be
# embedded here must be treated as leaked and revoked in the W&B settings.
# Provide the key through the WANDB_API_KEY environment variable (e.g. exported
# from the platform's secret store before this cell runs). When the variable is
# unset, `key=None` lets wandb fall back to its own lookup (netrc / prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))
print("‚úì WandB authenticated")

# Device Setup
# Use CUDA when it is available, otherwise run on the CPU, and report every
# visible GPU by name.
gpu_count = torch.cuda.device_count()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nüöÄ Device: {device} | GPUs: {gpu_count}")
# range(0) is empty, so no separate "gpu_count > 0" guard is required.
for gpu_index in range(gpu_count):
    print(f"   GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

# Run configuration. The whole dict is passed to wandb.init(config=...) and is
# read as a module-level global by the dataset, training and evaluation code.
# NOTE(review): several keys below are not read by any cell visible up to the
# training functions (hr_size, lr_size, gradient_clip, ema_decay, al_*); confirm
# that the training pipeline actually consumes them.
config = {
    # Dataset
    'dataset_size': 60000,  # Upper bound on patches kept (valid_patches[:dataset_size]); the recorded run found only 28937 valid patches
    'num_classes': 19,      # Length of the multi-hot label vector / classifier output width
    
    # SR Model
    'sr_model_path': '/kaggle/input/sr-model/pytorch/default/3/generator_ensemble.pth',  # state_dict loaded into Generator
    'lr_size': 32,          # LR input size (the dataset code hard-codes 32 rather than reading this key)
    'hr_size': 128,         # NOTE(review): the Generator has three PixelShuffle(2) stages (8x), so 32px -> 256px, not 128 -- confirm intended value
    
    # Classifier Training
    'clf_epochs': 30,       # Number of classifier epochs
    'batch_size': 48,       # Train batch size (the validation loader uses 2x this value)
    'lr': 2e-4,             # Learning rate passed to the optimizer
    'weight_decay': 5e-5,   # Weight decay passed to the optimizer
    'warmup_epochs': 3,     # Length of the linear LR warmup, in epochs
    'label_smoothing': 0.1, # label_smoothing argument of CrossEntropyLoss
    
    # Active Learning
    'al_cycles': 4,
    'al_epochs_per_cycle': 10,
    'initial_labeled_ratio': 0.1,
    'query_size_ratio': 0.1,
    
    # Training Enhancements
    'num_workers': 4,       # DataLoader worker processes
    'pin_memory': True,     # DataLoader pin_memory flag
    'mixed_precision': True,  # train_epoch/evaluate run the forward pass under autocast when True
    'gradient_clip': 1.0,   # Intended max gradient norm; train_epoch does not read this key, the caller must apply it
    'ema_decay': 0.999,     # Intended decay for the EMA helper class
}

# Initialize WandB
# The run name carries a timestamp so that repeated executions stay distinguishable.
run_name = f"SR-ResNet-AL-{time.strftime('%Y%m%d-%H%M%S')}"
wandb.init(project="SR-ResNet-AL-Classification", config=config, name=run_name)

print("\n‚úì Setup complete!")
print(f"Config: {config}")

Installing dependencies...
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.2/2.2 GB[0m [31m481.7 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m7.0/7.0 MB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m89.2/89.2 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.1.0+cu121 which is incompatible.[0m[31m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhegdesudarshan[0m ([33mhegdesudarshan-hegde[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


‚úì WandB authenticated

üöÄ Device: cuda | GPUs: 2
   GPU 0: Tesla T4
   GPU 1: Tesla T4



‚úì Setup complete!
Config: {'dataset_size': 50000, 'num_classes': 19, 'sr_model_path': '/kaggle/input/sr-model/pytorch/default/3/generator_ensemble.pth', 'lr_size': 32, 'hr_size': 128, 'clf_epochs': 20, 'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0.0001, 'al_cycles': 4, 'al_epochs_per_cycle': 10, 'initial_labeled_ratio': 0.1, 'query_size_ratio': 0.1, 'num_workers': 4, 'pin_memory': True, 'mixed_precision': True}


In [2]:
# ===============================================================================
# CELL 2: Load Pre-trained SR Model (EXACT ARCHITECTURE FROM CHECKPOINT)
# ===============================================================================

class RFB(nn.Module):
    """Receptive Field Block with a scaled residual connection.

    Three parallel branches (average pooling with kernel 3/5/7, a 1x1 conv,
    ReLU, a 3x3 conv with dilation 1/2/3, ReLU) produce 16 + 24 + 24 = 64
    feature maps. They are concatenated, fused by a 1x1 conv back to
    ``in_channels`` and added to the input scaled by 0.2.

    The attribute names and the positions inside each ``nn.Sequential``
    define the ``state_dict`` keys; they must stay as they are because the
    pre-trained checkpoint is loaded into this architecture by key.

    Args:
        in_channels: number of channels of the input (and of the output).
            With the default of 64 the layer shapes are identical to the
            previous hard-coded version.
    """
    def __init__(self, in_channels=64):
        super().__init__()
        # Branch 1: AvgPool(3) + 1x1 Conv + ReLU + 3x3 Conv (dilation 1) + ReLU
        self.branch1 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, 16, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, 3, 1, padding=1, dilation=1),
            nn.ReLU(inplace=True)
        )
        # Branch 2: AvgPool(5) + 1x1 Conv + ReLU + 3x3 Conv (dilation 2) + ReLU
        self.branch2 = nn.Sequential(
            nn.AvgPool2d(5, stride=1, padding=2),
            nn.Conv2d(in_channels, 24, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(24, 24, 3, 1, padding=2, dilation=2),
            nn.ReLU(inplace=True)
        )
        # Branch 3: AvgPool(7) + 1x1 Conv + ReLU + 3x3 Conv (dilation 3) + ReLU
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(7, stride=1, padding=3),
            nn.Conv2d(in_channels, 24, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(24, 24, 3, 1, padding=3, dilation=3),
            nn.ReLU(inplace=True)
        )
        # Fusion conv: its input width is the concat of the branches (16+24+24)
        # and its output width must equal in_channels, otherwise the residual
        # add in forward() fails. Kept wrapped in a Sequential so the key stays
        # "conv_concat.0.*".
        self.conv_concat = nn.Sequential(
            nn.Conv2d(16 + 24 + 24, in_channels, 1, 1, 0)
        )
        
    def forward(self, x):
        """Return ``conv_concat(cat(branches)) * 0.2 + x`` (same shape as ``x``)."""
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        out = torch.cat([b1, b2, b3], 1)
        return self.conv_concat(out) * 0.2 + x

class DenseBlock(nn.Module):
    """Five-layer densely connected conv block with a scaled residual.

    Every conv receives the block input concatenated with the outputs of all
    previous convs. The first four convs emit 32 channels each (followed by
    ReLU); the fifth maps back to ``nf`` channels without activation. The
    result is scaled by 0.2 and added to the input.
    """
    def __init__(self, nf=64):
        super().__init__()
        growth = 32  # channels produced by each of conv1..conv4
        # Registered as conv1..conv4 (in this order) so the parameter names
        # stay "conv<k>.weight" / "conv<k>.bias".
        for layer_no in range(1, 5):
            setattr(self, f"conv{layer_no}",
                    nn.Conv2d(nf + growth * (layer_no - 1), growth, 3, 1, 1))
        self.conv5 = nn.Conv2d(nf + growth * 4, nf, 3, 1, 1)

    def forward(self, x):
        """Return ``conv5(cat(x, x1..x4)) * 0.2 + x``."""
        collected = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            collected.append(F.relu(conv(torch.cat(collected, 1))))
        fused = self.conv5(torch.cat(collected, 1))
        return fused * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three chained DenseBlocks.

    The output of the chain is scaled by 0.2 and added to the block input.
    """
    def __init__(self, nf=64):
        super().__init__()
        # Built and registered left to right as db1, db2, db3.
        self.db1, self.db2, self.db3 = DenseBlock(nf), DenseBlock(nf), DenseBlock(nf)

    def forward(self, x):
        """Return ``db3(db2(db1(x))) * 0.2 + x``."""
        hidden = x
        for dense in (self.db1, self.db2, self.db3):
            hidden = dense(hidden)
        return hidden * 0.2 + x

class RRFDB(nn.Module):
    """Residual block made of five chained RFB modules.

    The RFBs are stored as the named attributes ``rfb1`` .. ``rfb5`` (not in a
    ModuleList) so that the parameter names match the checkpoint keys. The
    chain output is scaled by 0.2 and added to the block input.
    """
    def __init__(self, nf=64):
        super().__init__()
        for position in range(1, 6):
            setattr(self, f"rfb{position}", RFB(nf))

    def forward(self, x):
        """Return ``rfb5(...rfb1(x)) * 0.2 + x``."""
        hidden = x
        for rfb in (self.rfb1, self.rfb2, self.rfb3, self.rfb4, self.rfb5):
            hidden = rfb(hidden)
        return hidden * 0.2 + x

class Generator(nn.Module):
    """Super-resolution generator: RRDB trunk, RRFDB trunk, RFB, 8x upsampling.

    Pipeline: ``conv_first`` -> ``trunk_a`` (``num_rrdb`` RRDBs) ->
    ``trunk_rfb`` (``num_rrfdb`` RRFDBs) -> ``rfb_up`` -> skip connection with
    the ``conv_first`` features -> three (conv, PixelShuffle(2), ReLU) stages
    (2 * 2 * 2 = 8x spatial upscale) -> ``conv_final`` -> tanh.

    Attribute names and Sequential positions define the state_dict keys and
    must not change, since a checkpoint is loaded into this module by key.
    """
    def __init__(self, num_rrdb=12, num_rrfdb=6, nf=64):
        super().__init__()
        self.conv_first = nn.Conv2d(3, nf, 3, 1, 1)

        # Deep feature extraction: RRDB trunk followed by the RRFDB trunk.
        self.trunk_a = nn.Sequential(*(RRDB(nf) for _ in range(num_rrdb)))
        self.trunk_rfb = nn.Sequential(*(RRFDB(nf) for _ in range(num_rrfdb)))

        # Single RFB applied right before upsampling.
        self.rfb_up = RFB(nf)

        # Three identical 2x stages; the convs sit at Sequential indices 0, 3, 6.
        upsample_layers = []
        for _ in range(3):
            upsample_layers += [
                nn.Conv2d(nf, nf*4, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.ReLU(inplace=True),
            ]
        self.upsample = nn.Sequential(*upsample_layers)

        # Output head: conv -> LeakyReLU(0.2) -> conv to 3 channels.
        self.conv_final = nn.Sequential(
            nn.Conv2d(nf, nf, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, 3, 3, 1, 1)
        )

    def forward(self, x):
        """Map a 3-channel image batch to its 8x upscaled version in [-1, 1] (tanh)."""
        shallow = self.conv_first(x)
        deep = self.rfb_up(self.trunk_rfb(self.trunk_a(shallow)))
        upscaled = self.upsample(deep + shallow)
        return torch.tanh(self.conv_final(upscaled))

# Load Pre-trained SR Model
# Builds the generator, loads the checkpoint weights by key, runs one dummy
# forward pass as a smoke test and finally freezes every parameter.
print("Loading pre-trained SR model...")
sr_model = Generator(num_rrdb=12, num_rrfdb=6, nf=64).to(device)

try:
    state_dict = torch.load(config['sr_model_path'], map_location=device)
    sr_model.load_state_dict(state_dict)
    sr_model.eval()

    # Smoke test on a random 32x32 input to confirm the output shape.
    test_input = torch.randn(1, 3, 32, 32).to(device)
    with torch.no_grad():
        test_output = sr_model(test_input)
    print(f"‚úì SR Model loaded: {test_input.shape} ‚Üí {test_output.shape}")

    # Parameter count, reported in millions.
    total_params = sum(weight.numel() for weight in sr_model.parameters())
    print(f"  Parameters: {total_params/1e6:.2f}M")

except Exception as e:
    # Report the failure with a full traceback, then abort the cell.
    print(f"‚ùå Error loading SR model: {e}")
    import traceback
    traceback.print_exc()
    raise

# Freeze SR model (no training needed): clears requires_grad on all parameters.
sr_model.requires_grad_(False)

print("\n‚úì SR model ready for inference!")


Loading pre-trained SR model...
‚úì SR Model loaded: torch.Size([1, 3, 32, 32]) ‚Üí torch.Size([1, 3, 256, 256])
  Parameters: 9.77M

‚úì SR model ready for inference!


In [3]:
# ===============================================================================
# CELL 3: Dataset Loading (BigEarthNet)
# ===============================================================================

class BigEarthNetDataset(Dataset):
    """BigEarthNet RGB patches served as normalized 32x32 low-resolution tensors.

    Each item is a dict with:
      * ``'lr'``       - float tensor of shape (3, 32, 32), normalized with
                         mean = std = 0.5 per channel
      * ``'label'``    - scalar long tensor, the argmax of the patch's
                         multi-hot label vector
      * ``'patch_id'`` - the patch identifier

    No super-resolution is performed here; ``sr_model`` is only stored on the
    instance.

    Args:
        root_path: dataset root directory (stored, not read by this class).
        patch_ids: sequence of patch identifiers served by this dataset.
        patch_to_bands: mapping patch_id -> {band code -> GeoTIFF path}; the
            codes '02', '03' and '04' are required.
        patch_to_label: mapping patch_id -> multi-hot label tensor.
        sr_model: optional SR model, stored as ``self.sr_model``.
        phase: ``'train'`` enables random flips/rotation; anything else
            disables augmentation.

    Caveats (behavior kept as-is on purpose):
      * A patch missing from ``patch_to_label`` (or with an all-zero vector)
        gets label 0, and a multi-label patch is reduced to its first active
        class, because ``torch.argmax`` returns the first maximal index.
      * Any exception while loading a patch is printed and replaced by an
        all-zero image with label 0, so broken files silently become class-0
        samples.
    """
    def __init__(self, root_path, patch_ids, patch_to_bands, patch_to_label, 
                 sr_model=None, phase='train'):
        self.root_path = root_path
        self.patch_ids = patch_ids
        self.patch_to_bands = patch_to_bands
        self.patch_to_label = patch_to_label
        self.sr_model = sr_model
        self.phase = phase
        
        # Spatial augmentation is applied to the PIL image in training only.
        if phase == 'train':
            self.spatial_aug = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                # NOTE(review): a scalar argument means "any angle in
                # [-90, +90] degrees", not "multiples of 90" -- confirm intent.
                transforms.RandomRotation(90)
            ])
        else:
            self.spatial_aug = None
            
        # PIL -> float tensor in [0, 1] -> (x - 0.5) / 0.5 per channel.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    
    def __len__(self):
        """Number of patches served by this dataset."""
        return len(self.patch_ids)

    @staticmethod
    def _read_band(path):
        """Read raster band 1 as float32 scaled by 1/10000, closing the file afterwards."""
        # The context manager releases the file handle; leaving datasets open
        # leaks one descriptor per band read in every DataLoader worker.
        with rasterio.open(path) as src:
            return src.read(1).astype(np.float32) / 10000.0
    
    def __getitem__(self, idx):
        """Load, augment and downsample one patch; see the class docstring for the item layout."""
        patch_id = self.patch_ids[idx]
        bands = self.patch_to_bands[patch_id]
        
        try:
            # Load RGB bands (B04=Red, B03=Green, B02=Blue)
            b02 = self._read_band(bands['02'])
            b03 = self._read_band(bands['03'])
            b04 = self._read_band(bands['04'])
            
            # Stack to an (H, W, 3) RGB array and clamp to [0, 1]
            hr_np = np.stack([b04, b03, b02], axis=-1)
            hr_np = np.clip(hr_np, 0, 1)
            
            # Convert to an 8-bit PIL image for augmentation
            hr_pil = Image.fromarray((hr_np * 255).astype(np.uint8))
            
            # Apply spatial augmentation (training phase only)
            if self.spatial_aug:
                hr_pil = self.spatial_aug(hr_pil)
            
            # Resize to 32x32 for LR input
            lr_pil = hr_pil.resize((32, 32), Image.BICUBIC)
            
            # To normalized tensor
            lr_tensor = self.to_tensor(lr_pil)
            
            # Multi-hot -> single label via argmax (all-zero vector -> class 0)
            label_multihot = self.patch_to_label.get(patch_id, torch.zeros(config['num_classes']))
            label = torch.argmax(label_multihot).long()
            
            return {
                'lr': lr_tensor,
                'label': label,
                'patch_id': patch_id
            }
            
        except Exception as e:
            # Deliberate best effort: keep the batch shape intact by returning
            # a black dummy sample with label 0 instead of failing the loader.
            print(f"Error loading {patch_id}: {e}")
            return {
                'lr': torch.zeros(3, 32, 32),
                'label': torch.tensor(0, dtype=torch.long),
                'patch_id': patch_id
            }

# Load BigEarthNet metadata
print("Loading BigEarthNet dataset...")
image_root_path = '/kaggle/input/bigearthnetv2-s2-4/'

# Find all TIF files
import glob
# glob returns paths in arbitrary (filesystem-dependent) order; sorting makes
# the truncation below and the seeded train/val split reproducible.
all_tif_paths = sorted(glob.glob(os.path.join(image_root_path, '**/*.tif'), recursive=True))
print(f"Found {len(all_tif_paths)} band files")

# Group by patch ID: "<patch_id>_B<band>.tif" -> patch_to_bands[patch_id][band]
patch_to_bands = {}
for path in all_tif_paths:
    fname = os.path.basename(path)
    if '_B' in fname:
        # Split on the LAST "_B" only, so a patch id that itself contains
        # "_B" is kept intact instead of being re-joined without the "B".
        patch_id, band_part = fname.rsplit('_B', 1)
        band = band_part.split('.')[0]
        patch_to_bands.setdefault(patch_id, {})[band] = path

# Filter patches with RGB bands
valid_patches = [pid for pid, bands in patch_to_bands.items() 
                 if all(b in bands for b in ['02', '03', '04'])]
valid_patches = valid_patches[:config['dataset_size']]
print(f"Valid RGB patches: {len(valid_patches)}")

# Load labels from metadata
metadata_path = os.path.join(image_root_path, 'metadata.parquet')
if os.path.exists(metadata_path):
    df = pd.read_parquet(metadata_path)
    # Set membership is O(1); testing against the list made this loop
    # quadratic in the number of patches.
    valid_patch_set = set(valid_patches)
    patch_to_label = {}
    unusable_label_entries = 0
    for pid, raw_labels in zip(df['patch_id'], df['labels']):
        if pid not in valid_patch_set:
            continue
        # List-typed parquet columns typically come back from pandas as numpy
        # arrays, so accept any of the common sequence containers.
        if isinstance(raw_labels, (list, tuple, np.ndarray)):
            labels_list = raw_labels
        else:
            labels_list = []
        multi_hot = torch.zeros(config['num_classes'])
        for lbl in labels_list:
            # Only integer class indices inside [0, num_classes) can be encoded.
            if isinstance(lbl, (int, np.integer)) and 0 <= lbl < config['num_classes']:
                multi_hot[int(lbl)] = 1.0
            else:
                unusable_label_entries += 1
        patch_to_label[pid] = multi_hot
    print(f"Loaded labels for {len(patch_to_label)} patches")
    if unusable_label_entries:
        # Entries that are not integer indices (e.g. class-name strings) or are
        # out of range are skipped; such patches end up with all-zero vectors.
        print(f"Warning: skipped {unusable_label_entries} label entries that are not "
              f"integer class indices in [0, {config['num_classes']})")
else:
    print("Warning: metadata.parquet not found, using dummy labels")
    patch_to_label = {pid: torch.zeros(config['num_classes']) for pid in valid_patches}

# Train/Val split (80/20, fixed seed)
train_ids, val_ids = train_test_split(valid_patches, test_size=0.2, random_state=42)
print(f"Split: {len(train_ids)} train, {len(val_ids)} val")

# Create datasets
train_dataset = BigEarthNetDataset(image_root_path, train_ids, patch_to_bands, 
                                   patch_to_label, sr_model, phase='train')
val_dataset = BigEarthNetDataset(image_root_path, val_ids, patch_to_bands, 
                                 patch_to_label, sr_model, phase='val')

# Dataloaders (validation uses twice the training batch size)
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], 
                         shuffle=True, num_workers=config['num_workers'], 
                         pin_memory=config['pin_memory'])
val_loader = DataLoader(val_dataset, batch_size=config['batch_size']*2, 
                       shuffle=False, num_workers=config['num_workers'], 
                       pin_memory=config['pin_memory'])

print("\n‚úì Dataset loaded!")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

# Test batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch shapes:")
print(f"  LR: {sample_batch['lr'].shape}")
print(f"  Label: {sample_batch['label'].shape}")

Loading BigEarthNet dataset...
Found 347244 band files
Valid RGB patches: 28937
Split: 23149 train, 5788 val

‚úì Dataset loaded!
  Train batches: 724
  Val batches: 91

Sample batch shapes:
  LR: torch.Size([32, 3, 32, 32])
  Label: torch.Size([32])


In [4]:
# ===============================================================================
# CELL 3.5: GPU Memory Cleanup (Run this if you encounter CUDA errors)
# ===============================================================================

print("Cleaning up GPU memory...")
# Order matters: run the garbage collector first so that unreachable tensors
# hand their memory back to PyTorch's caching allocator, THEN ask the
# allocator to release its cached blocks. In the opposite order the memory
# freed by the collector would stay cached.
gc.collect()
torch.cuda.empty_cache()

# Report per-GPU usage (synchronize only waits for pending work; it does not
# reset the device).
if torch.cuda.is_available():
    torch.cuda.synchronize()
    print(f"‚úì GPU memory cleaned")
    for i in range(torch.cuda.device_count()):
        mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
        mem_reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"  GPU {i}: {mem_allocated:.2f}GB allocated, {mem_reserved:.2f}GB reserved")


Cleaning up GPU memory...
‚úì GPU memory cleaned
  GPU 0: 0.07GB allocated, 0.08GB reserved
  GPU 1: 0.00GB allocated, 0.00GB reserved


In [None]:
# ===============================================================================
# CELL 4: ResNet Classifier Definition
# ===============================================================================

class SREnhancedClassifier(nn.Module):
    """ResNet50 classifier operating on super-resolved versions of LR inputs.

    The low-resolution batch is first passed through ``sr_model`` under
    ``torch.no_grad()``, bilinearly resized to 224x224 and then fed to a
    ResNet50 backbone whose ``fc`` layer is replaced by a two-layer head
    (Dropout 0.4 -> Linear(in, 512) -> ReLU -> Dropout 0.3 -> Linear(512, C)).

    Args:
        num_classes: width of the output logits.
        sr_model: super-resolution module; registered as a sub-module, so its
            parameters appear in ``self.parameters()``.
        pretrained: load ``ResNet50_Weights.IMAGENET1K_V2`` when True,
            otherwise start from random weights.

    Notes:
        * NOTE(review): the SR output is fed to the backbone as-is; no
          ImageNet mean/std normalization is applied before the pretrained
          ResNet -- confirm this is intended.
        * When the model is wrapped in ``nn.DataParallel``, ``get_features``
          is reachable only through ``wrapper.module.get_features``.
    """
    def __init__(self, num_classes, sr_model, pretrained=True):
        super().__init__()
        self.sr_model = sr_model  # Frozen SR model
        
        # Load pretrained ResNet50
        if pretrained:
            weights = ResNet50_Weights.IMAGENET1K_V2
            self.backbone = resnet50(weights=weights)
        else:
            self.backbone = resnet50(weights=None)
        
        # Replace final FC layer with enhanced classifier head
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def _super_resolve(self, lr_images):
        """Run the SR model without gradients and resize its output to 224x224."""
        with torch.no_grad():
            sr_images = self.sr_model(lr_images)
            return F.interpolate(sr_images, size=(224, 224), mode='bilinear', align_corners=False)

    def _pooled_features(self, images):
        """Run the ResNet stem and stages up to avgpool; return flattened features."""
        backbone = self.backbone
        x = backbone.conv1(images)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)

        x = backbone.avgpool(x)
        return torch.flatten(x, 1)
    
    def forward(self, lr_images):
        """Return class logits of shape (batch, num_classes) for a batch of LR images."""
        # Step 1: SR enhancement (no gradients) + resize for ResNet50
        sr_images = self._super_resolve(lr_images)
        # Step 2: ResNet feature extraction and classification head
        features = self._pooled_features(sr_images)
        return self.backbone.fc(features)
    
    def get_features(self, lr_images):
        """Return the pooled backbone features (input of the head), computed without gradients."""
        with torch.no_grad():
            return self._pooled_features(self._super_resolve(lr_images))

# Create classifier with error handling
print("Creating SR-Enhanced ResNet50 Classifier...")

try:
    # Clean GPU memory before creating the model. The garbage collector runs
    # first so that the memory it frees can actually be released by
    # empty_cache() afterwards.
    gc.collect()
    torch.cuda.empty_cache()
    
    # Build the classifier, then move it to the target device.
    classifier = SREnhancedClassifier(config['num_classes'], sr_model, pretrained=True)
    classifier = classifier.to(device)
    print("‚úì Classifier moved to GPU")
    
except RuntimeError as e:
    # CUDA failures get a recovery hint; every RuntimeError is re-raised.
    if "CUDA" in str(e):
        print(f"‚ö† CUDA Error: {e}")
        print("Attempting recovery: Restarting kernel may help")
        print("Run the memory cleanup cell (Cell 3.5) and try again")
    raise

# Replicate across GPUs when more than one is visible. After wrapping, custom
# methods of the classifier are reachable through `classifier.module`.
if gpu_count > 1:
    classifier = nn.DataParallel(classifier)
    print(f"  Using DataParallel across {gpu_count} GPUs")

# Count parameters (the total includes the frozen SR model's parameters)
trainable_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in classifier.parameters())
print(f"\n‚úì Classifier created")
print(f"  Total parameters: {total_params/1e6:.2f}M")
print(f"  Trainable parameters: {trainable_params/1e6:.2f}M")

# Test forward pass (best effort: a failure is reported but does not stop the cell)
try:
    with torch.no_grad():
        test_lr = torch.randn(2, 3, 32, 32).to(device)
        test_output = classifier(test_lr)
        print(f"\nTest forward pass: {test_lr.shape} ‚Üí {test_output.shape}")
        print(f"Output range: [{test_output.min():.3f}, {test_output.max():.3f}]")
except RuntimeError as e:
    print(f"‚ö† Test forward pass failed: {e}")
    print("This may indicate GPU memory issues. Try restarting the kernel.")


Creating SR-Enhanced ResNet50 Classifier...


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 151MB/s] 


‚úì Classifier moved to GPU
  Using DataParallel across 2 GPUs

‚úì Classifier created
  Total parameters: 33.32M
  Trainable parameters: 23.55M

Test forward pass: torch.Size([2, 3, 32, 32]) ‚Üí torch.Size([2, 19])
Output range: [-0.144, 0.135]


In [6]:
# ===============================================================================
# CELL 5: Training & Evaluation Functions
# ===============================================================================

def train_epoch(model, loader, criterion, optimizer, scaler, device, grad_clip=None):
    """Train ``model`` for one pass over ``loader``.

    Reads the global ``config['mixed_precision']`` to decide whether the
    forward pass runs under ``autocast`` with gradient scaling.

    Args:
        model: module called as ``model(batch['lr'])`` returning class logits.
        loader: iterable of dict batches with the keys ``'lr'`` and ``'label'``.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer stepping the model parameters.
        scaler: ``GradScaler`` used only when mixed precision is enabled.
        device: device the batch tensors are moved to.
        grad_clip: optional max gradient norm. ``None`` (default) disables
            clipping and keeps the previous behavior. Under mixed precision
            the gradients are unscaled before clipping so the threshold
            applies to the true gradient magnitudes.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen; ``(0.0, 0.0)``
        when the loader yields no samples (instead of dividing by zero).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc="Training")
    for batch in pbar:
        lr_imgs = batch['lr'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        if config['mixed_precision']:
            with autocast():
                outputs = model(lr_imgs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            if grad_clip is not None:
                # Gradients are still multiplied by the loss scale here;
                # unscale them first, otherwise the clip threshold is meaningless.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(lr_imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        
        # Weight the batch-mean loss by the batch size for a per-sample average.
        total_loss += loss.item() * lr_imgs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100.*correct/total:.2f}%"
        })
    
    if total == 0:
        # Empty loader (e.g. an empty labeled pool): nothing was trained.
        return 0.0, 0.0
    return total_loss / total, 100. * correct / total

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The forward pass runs under ``autocast`` when the global
    ``config['mixed_precision']`` is truthy.

    Returns:
        ``(avg_loss, accuracy_percent, macro_f1_percent, predictions, labels)``
        where the last two are numpy arrays covering every evaluated sample.
    """
    model.eval()
    use_amp = config['mixed_precision']
    running_loss = 0
    pred_buffer, label_buffer = [], []

    def _forward(images, targets):
        # Logits plus the batch loss for one batch.
        logits = model(images)
        return logits, criterion(logits, targets)

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            images = batch['lr'].to(device)
            targets = batch['label'].to(device)

            if use_amp:
                with autocast():
                    logits, batch_loss = _forward(images, targets)
            else:
                logits, batch_loss = _forward(images, targets)

            # Batch-mean loss weighted by the batch size.
            running_loss += batch_loss.item() * images.size(0)
            batch_pred = logits.max(1)[1]

            pred_buffer.extend(batch_pred.cpu().numpy())
            label_buffer.extend(targets.cpu().numpy())

    y_pred = np.array(pred_buffer)
    y_true = np.array(label_buffer)

    accuracy = accuracy_score(y_true, y_pred) * 100
    f1 = f1_score(y_true, y_pred, average='macro') * 100
    avg_loss = running_loss / len(y_true)

    return avg_loss, accuracy, f1, y_pred, y_true

def plot_confusion_matrix(y_true, y_pred, epoch):
    """Render the confusion matrix as a heatmap and log it to wandb.

    The figure is logged under the key ``"confusion_matrix"`` and closed
    afterwards; nothing is returned.
    """
    matrix = confusion_matrix(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(matrix, annot=False, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f'Confusion Matrix - Epoch {epoch}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()

    # Log to wandb, then release the figure.
    wandb.log({"confusion_matrix": wandb.Image(fig)})
    plt.close(fig)

# Status message: the training/evaluation helpers of this cell are now defined.
print("‚úì Training functions defined")

‚úì Training functions defined


In [None]:
# ===============================================================================
# CELL 5.5: Comprehensive Evaluation Metrics & Visualization Functions
# ===============================================================================

def compute_per_class_metrics(y_true, y_pred, num_classes=19):
    """Return per-class ``(precision, recall, f1, support)`` arrays.

    Metrics are reported for every class index in ``range(num_classes)``;
    undefined ratios are reported as 0 (``zero_division=0``).
    """
    from sklearn.metrics import precision_recall_fscore_support

    class_ids = list(range(num_classes))
    per_class = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, zero_division=0
    )
    precision, recall, f1, support = per_class
    return precision, recall, f1, support

def plot_training_curves(history):
    """Plot loss, accuracy, validation F1 and learning-rate curves on a 2x2 grid.

    ``history`` must provide the keys ``train_loss``, ``val_loss``,
    ``train_acc``, ``val_acc``, ``val_f1`` and ``lr``. The figure is logged to
    wandb as ``"training_curves"``, saved to
    ``/kaggle/working/training_curves.png`` and closed.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def _decorate(ax, ylabel, title, *, legend, log_y=False):
        # Shared axis cosmetics for one panel.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if legend:
            ax.legend()
        if log_y:
            ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

    # Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(history['train_loss'], label='Train', linewidth=2)
    loss_ax.plot(history['val_loss'], label='Val', linewidth=2)
    _decorate(loss_ax, 'Loss', 'Loss Curves', legend=True)

    # Accuracy
    acc_ax = axes[0, 1]
    acc_ax.plot(history['train_acc'], label='Train', linewidth=2)
    acc_ax.plot(history['val_acc'], label='Val', linewidth=2)
    _decorate(acc_ax, 'Accuracy (%)', 'Accuracy Curves', legend=True)

    # F1 Score
    f1_ax = axes[1, 0]
    f1_ax.plot(history['val_f1'], label='Val F1 (Macro)', linewidth=2, color='green')
    _decorate(f1_ax, 'F1 Score (%)', 'F1 Score Curve', legend=True)

    # Learning Rate (log-scaled y axis, no legend)
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'], linewidth=2, color='red')
    _decorate(lr_ax, 'Learning Rate', 'Learning Rate Schedule', legend=False, log_y=True)

    fig.tight_layout()
    wandb.log({"training_curves": wandb.Image(fig)})
    fig.savefig('/kaggle/working/training_curves.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_per_class_performance(precision, recall, f1, support, class_names=None):
    """Draw per-class precision, recall, F1 and sample-count bar charts (2x2 grid).

    ``class_names`` defaults to ``'Class 0'``, ``'Class 1'``, ... The three
    metric panels show percentages with a dashed red line at the mean; the
    support panel shows raw counts. The figure is logged to wandb as
    ``"per_class_metrics"``, saved to
    ``/kaggle/working/per_class_metrics.png`` and closed.
    """
    if class_names is None:
        class_names = [f'Class {i}' for i in range(len(precision))]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    positions = np.arange(len(class_names))

    # (axis, bar heights, face color, edge color, y label, title, mean line or None)
    panels = [
        (axes[0, 0], precision * 100, 'skyblue', 'navy',
         'Precision (%)', 'Per-Class Precision', precision.mean()*100),
        (axes[0, 1], recall * 100, 'lightgreen', 'darkgreen',
         'Recall (%)', 'Per-Class Recall', recall.mean()*100),
        (axes[1, 0], f1 * 100, 'lightcoral', 'darkred',
         'F1 Score (%)', 'Per-Class F1 Score', f1.mean()*100),
        (axes[1, 1], support, 'plum', 'purple',
         'Sample Count', 'Per-Class Sample Distribution', None),
    ]

    for ax, heights, face, edge, ylabel, title, mean_level in panels:
        ax.bar(positions, heights, color=face, edgecolor=edge, alpha=0.7)
        ax.set_xlabel('Class')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
        if mean_level is not None:
            # Mean marker and legend exist only on the three metric panels.
            ax.axhline(y=mean_level, color='red', linestyle='--', linewidth=2, label='Mean')
            ax.legend()

    fig.tight_layout()
    wandb.log({"per_class_metrics": wandb.Image(fig)})
    fig.savefig('/kaggle/working/per_class_metrics.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

class EMA:
    """Exponential Moving Average for model weights.

    Maintains a "shadow" copy of every parameter of ``model`` that has
    ``requires_grad=True`` and blends the live weights into it on each
    ``update()`` call:

        shadow = decay * shadow + (1 - decay) * param

    ``apply_shadow()`` temporarily swaps the averaged weights into the model
    (e.g. for evaluation or checkpointing) and ``restore()`` puts the live
    training weights back.

    Args:
        model: Module whose trainable parameters are tracked.
        decay: Weight given to the previous shadow value on each update.
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}   # parameter name -> averaged tensor
        self.backup = {}   # parameter name -> live weights saved by apply_shadow()
        self.register()

    def register(self):
        """Seed the shadow dict with a copy of each trainable parameter."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current parameter values into the shadow weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # The arithmetic already allocates a fresh tensor, so no
                # extra clone is needed before storing it.
                self.shadow[name] = (
                    (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                )

    def apply_shadow(self):
        """Copy the shadow weights into the model, backing up the live weights.

        The values are copied into the existing parameter storage rather than
        rebinding ``param.data`` to the shadow tensor; rebinding would make the
        parameter alias the shadow, so any in-place write to the parameter
        would silently corrupt the running average.

        Calling this again before ``restore()`` keeps the first backup, so the
        live training weights are never overwritten by shadow values.
        """
        already_applied = bool(self.backup)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if not already_applied:
                    self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Put back the weights saved by ``apply_shadow()``.

        A call with nothing backed up is a no-op, which makes it safe to use
        in ``finally`` blocks.
        """
        if not self.backup:
            return
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}

# Confirmation message: reaching this line means the metric/plot helpers and
# the EMA class defined earlier in this cell executed without error.
print("‚úì Comprehensive evaluation metrics & EMA defined")

In [None]:
# ===============================================================================
# CELL 6: Full Training Pipeline with Active Learning
# ===============================================================================

# --- Loss and optimizer ------------------------------------------------------
# Cross-entropy with label smoothing; AdamW with decoupled weight decay.
criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
optimizer = optim.AdamW(
    classifier.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
)

# --- LR schedule ---------------------------------------------------------------
# Linear warmup from 10% of the base LR for `warmup_epochs`, then cosine decay
# down to 1e-6 over the remaining epochs. SequentialLR switches between the two
# at the warmup milestone.
warmup_scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    total_iters=config['warmup_epochs'],
)
main_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config['clf_epochs'] - config['warmup_epochs'],
    eta_min=1e-6,
)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, main_scheduler],
    milestones=[config['warmup_epochs']],
)

# --- Mixed precision -----------------------------------------------------------
# A gradient scaler is only created when mixed precision is enabled.
if config['mixed_precision']:
    scaler = GradScaler()
else:
    scaler = None

# --- Weight averaging ----------------------------------------------------------
ema = EMA(classifier, decay=config['ema_decay'])

# --- Per-epoch history (one list per tracked series) ---------------------------
history = {
    series: []
    for series in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'val_f1', 'lr')
}

# --- Start banner --------------------------------------------------------------
print("\n".join([
    "\n" + "="*80,
    "STARTING OPTIMIZED TRAINING PIPELINE",
    "="*80,
    "üìä Enhancements: Label Smoothing, Warmup LR, EMA, Enhanced Metrics",
    "‚è±Ô∏è  Estimated Time: ~2.5-3 hours for 30 epochs",
    "="*80,
]))

# Initial supervised training with enhanced evaluation.
# Per epoch: train -> update EMA -> validate with raw and EMA weights ->
# advance the LR schedule -> log and checkpoint whichever weight set scored
# the higher validation accuracy.
best_val_acc = 0
best_val_f1 = 0

for epoch in range(config['clf_epochs']):
    epoch_start = time.time()
    print(f"\nEpoch {epoch+1}/{config['clf_epochs']}")
    print("-" * 80)

    # Record the LR this epoch actually trains with. It must be read before
    # scheduler.step(), which advances the optimizer to the *next* epoch's LR.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_epoch(classifier, train_loader, criterion,
                                        optimizer, scaler, device)

    # Update EMA
    # NOTE(review): update() runs once per epoch here, not once per optimizer
    # step. If config['ema_decay'] is close to 1 (its value is not visible in
    # this cell), the shadow weights will barely move away from their initial
    # values over config['clf_epochs'] updates - confirm this cadence is
    # intended, or move the update into the training step.
    ema.update()

    # Validate with original weights
    val_loss, val_acc, val_f1, val_preds, val_labels = evaluate(classifier, val_loader,
                                                                 criterion, device)

    # Validate with EMA weights. try/finally guarantees the training weights
    # are put back even if evaluation is interrupted (e.g. KeyboardInterrupt
    # in the notebook), so the kernel is never left holding EMA weights.
    ema.apply_shadow()
    try:
        val_loss_ema, val_acc_ema, val_f1_ema, val_preds_ema, val_labels_ema = evaluate(
            classifier, val_loader, criterion, device)
    finally:
        ema.restore()

    scheduler.step()

    # Use EMA results if better
    use_ema = val_acc_ema > val_acc
    final_val_acc = val_acc_ema if use_ema else val_acc
    final_val_f1 = val_f1_ema if use_ema else val_f1
    final_val_loss = val_loss_ema if use_ema else val_loss
    final_preds = val_preds_ema if use_ema else val_preds
    final_labels = val_labels_ema if use_ema else val_labels

    # Store history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(final_val_loss)
    history['val_acc'].append(final_val_acc)
    history['val_f1'].append(final_val_f1)
    history['lr'].append(epoch_lr)

    # Compute per-class metrics
    precision, recall, f1_per_class, support = compute_per_class_metrics(
        final_labels, final_preds, config['num_classes']
    )

    # Log metrics
    wandb.log({
        'epoch': epoch + 1,
        'train/loss': train_loss,
        'train/accuracy': train_acc,
        'val/loss': final_val_loss,
        'val/accuracy': final_val_acc,
        'val/f1_macro': final_val_f1,
        'val/precision_macro': precision.mean() * 100,
        'val/recall_macro': recall.mean() * 100,
        'val_ema/accuracy': val_acc_ema,
        'val_ema/f1_macro': val_f1_ema,
        'lr': epoch_lr,
        'epoch_time': time.time() - epoch_start
    })

    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"Val   - Loss: {final_val_loss:.4f}, Acc: {final_val_acc:.2f}%, F1: {final_val_f1:.2f}%")
    if use_ema:
        print(f"        (Using EMA weights - {val_acc_ema:.2f}% vs {val_acc:.2f}%)")
    print(f"        Precision: {precision.mean()*100:.2f}%, Recall: {recall.mean()*100:.2f}%")
    print(f"        LR: {epoch_lr:.2e}, Time: {time.time()-epoch_start:.1f}s")

    # Save best model (both standard and EMA)
    if final_val_acc > best_val_acc:
        best_val_acc = final_val_acc
        best_val_f1 = final_val_f1
        if use_ema:
            # Swap the averaged weights in just long enough to serialize them;
            # always swap back, even if saving fails.
            ema.apply_shadow()
            try:
                torch.save(classifier.state_dict(), '/kaggle/working/best_classifier.pth')
            finally:
                ema.restore()
        else:
            torch.save(classifier.state_dict(), '/kaggle/working/best_classifier.pth')
        print(f"‚úì Saved best model (acc: {best_val_acc:.2f}%, F1: {best_val_f1:.2f}%)")

    # Comprehensive visualizations every 5 epochs
    if (epoch + 1) % 5 == 0:
        plot_confusion_matrix(final_labels, final_preds, epoch + 1)
        plot_per_class_performance(precision, recall, f1_per_class, support)
        plot_training_curves(history)

    # Memory cleanup
    if (epoch + 1) % 3 == 0:
        torch.cuda.empty_cache()
        gc.collect()

print("\n" + "="*80)
print(f"TRAINING COMPLETE")
print(f"Best Val Accuracy: {best_val_acc:.2f}%")
print(f"Best Val F1 Score: {best_val_f1:.2f}%")
print("="*80)

# Final comprehensive evaluation: reload the best checkpoint written during
# training, re-score it on the validation loader, and export the reports.
print("\n" + "="*80)
print("FINAL EVALUATION WITH BEST MODEL")
print("="*80)

# Load best model. map_location=device lets a checkpoint saved from GPU
# tensors be loaded when this cell is re-run on a different device.
classifier.load_state_dict(
    torch.load('/kaggle/working/best_classifier.pth', map_location=device)
)
val_loss, val_acc, val_f1, val_preds, val_labels = evaluate(classifier, val_loader,
                                                             criterion, device)

# Compute all metrics
precision, recall, f1_per_class, support = compute_per_class_metrics(
    val_labels, val_preds, config['num_classes']
)

print(f"\nüìä Overall Metrics:")
print(f"  Accuracy:  {val_acc:.2f}%")
print(f"  F1 (Macro): {val_f1:.2f}%")
print(f"  Precision: {precision.mean()*100:.2f}%")
print(f"  Recall:    {recall.mean()*100:.2f}%")

# Best/worst classes are reported by index into the per-class F1 array.
print(f"\nüìà Per-Class Statistics:")
print(f"  Best F1 Class:   Class {f1_per_class.argmax()} ({f1_per_class.max()*100:.2f}%)")
print(f"  Worst F1 Class:  Class {f1_per_class.argmin()} ({f1_per_class.min()*100:.2f}%)")
print(f"  F1 Std Dev:      {f1_per_class.std()*100:.2f}%")

# Generate final visualizations
plot_confusion_matrix(val_labels, val_preds, 'FINAL')
plot_per_class_performance(precision, recall, f1_per_class, support)
plot_training_curves(history)

# Classification report (human-readable text form)
print(f"\nüìã Detailed Classification Report:")
print(classification_report(val_labels, val_preds, digits=4))

# Save classification report (dict form -> one CSV row per class/average)
report = classification_report(val_labels, val_preds, output_dict=True)
report_df = pd.DataFrame(report).transpose()
report_df.to_csv('/kaggle/working/classification_report.csv')
print("\n‚úì Saved classification report to classification_report.csv")

# Close the Weights & Biases run after all logging is done.
wandb.finish()
print("\n‚úì Training pipeline complete!")


STARTING FULL TRAINING PIPELINE

Epoch 1/20


Training:   0%|          | 0/724 [00:00<?, ?it/s]

KeyboardInterrupt: 