# Semantic Enhancement of Prithvi via Knowledge Distillation from RemoteCLIP

## Task 2: Semantic Segmentation on LoveDA Dataset

### Objective
Compare baseline Prithvi vs semantically-enhanced (RemoteCLIP-aligned) Prithvi on a **harder** downstream task: pixel-wise semantic segmentation.

**Dataset**: LoveDA (Land-cOVEr Domain Adaptive) — high-resolution remote sensing images with 7 land-cover classes.

**Previous Task (Classification)**: Model A: 35.78% vs Model B: 40.38% (+4.60 percentage points) on NWPU-RESISC45

**This Task (Segmentation)**: Compare the two models on dense per-pixel prediction.

## 1. Setup and Imports

In [None]:
import torch
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Prefer the first visible CUDA GPU; otherwise run everything on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Explore the LoveDA Dataset Structure

In [None]:
dataset_root = Path('LoveDA-dataset')


def count_pngs(directory):
    """Return how many ``*.png`` files sit directly inside ``directory``.

    A directory that does not exist yields 0, because ``Path.glob`` finds nothing there.
    """
    return sum(1 for _ in Path(directory).glob('*.png'))


# Count images and masks in each split/domain. Per-split image totals are
# accumulated in the same pass, so each directory is scanned only once.
split_image_totals = {}
for split in ['Train', 'Val']:
    split_image_totals[split] = 0
    for domain in ['Rural', 'Urban']:
        img_dir = dataset_root / split / domain / 'images_png'
        mask_dir = dataset_root / split / domain / 'masks_png'
        num_images = count_pngs(img_dir)
        num_masks = count_pngs(mask_dir)
        split_image_totals[split] += num_images
        print(f"{split}/{domain}: {num_images} images, {num_masks} masks")

print(f"\nTotal Training Images: {split_image_totals['Train']}")
print(f"Total Validation Images: {split_image_totals['Val']}")

## 3. Load and Inspect a Sample Image

In [None]:
# Pick one Train/Urban tile and open both the RGB image and its label mask
# (the mask lives in the sibling masks_png folder under the same file name).
sample_img_path = dataset_root / 'Train' / 'Urban' / 'images_png' / '1366.png'
sample_mask_path = sample_img_path.parent.parent / 'masks_png' / sample_img_path.name

sample_image = Image.open(sample_img_path)
sample_mask = Image.open(sample_mask_path)

rule = "=" * 60

# --- Image metadata and pixel statistics -----------------------------------
print(rule)
print("IMAGE PROPERTIES")
print(rule)
print(f"File Path:    {sample_img_path}")
print(f"Format:       {sample_image.format}")
print(f"Mode:         {sample_image.mode}")
img_w, img_h = sample_image.size
print(f"Size (WxH):   {img_w} x {img_h} pixels")
bands = sample_image.getbands()
print(f"Channels:     {len(bands)} ({', '.join(bands)})")

img_array = np.array(sample_image)
print(f"Array Shape:  {img_array.shape}")
print(f"Dtype:        {img_array.dtype}")
print(f"Value Range:  [{img_array.min()}, {img_array.max()}]")
r_mean, g_mean, b_mean = (img_array[:, :, ch].mean() for ch in range(3))
print(f"Mean (RGB):   R={r_mean:.1f}, G={g_mean:.1f}, B={b_mean:.1f}")

# --- Mask metadata and label statistics ------------------------------------
print(f"\n{rule}")
print("MASK PROPERTIES")
print(rule)
print(f"File Path:    {sample_mask_path}")
print(f"Mode:         {sample_mask.mode}")
mask_w, mask_h = sample_mask.size
print(f"Size (WxH):   {mask_w} x {mask_h} pixels")

mask_array = np.array(sample_mask)
print(f"Array Shape:  {mask_array.shape}")
print(f"Dtype:        {mask_array.dtype}")
print(f"Value Range:  [{mask_array.min()}, {mask_array.max()}]")
print(f"Unique Values: {np.unique(mask_array)}")

# Histogram of label values in this mask, with each value's share of all pixels
print("\nPixel Distribution:")
unique, counts = np.unique(mask_array, return_counts=True)
total_pixels = mask_array.size
for val, count in zip(unique, counts):
    share = 100 * count / total_pixels
    print(f"  Class {val}: {count:>8,} pixels ({share:5.2f}%)")

## 4. Visualize the Image and its Segmentation Mask

In [None]:
# LoveDA label encoding (official dataset release):
#   0: No-data      (not a land-cover class; ignore in training/evaluation)
#   1: Background
#   2: Building
#   3: Road
#   4: Water
#   5: Barren
#   6: Forest
#   7: Agriculture
# NOTE: the seven real classes occupy label values 1..7. A 7-way classifier
# therefore needs the labels shifted down by one (1..7 -> 0..6), with no-data
# mapped to the loss's ignore index.

LOVEDA_CLASSES = {
    0: 'No-data',
    1: 'Background',
    2: 'Building',
    3: 'Road',
    4: 'Water',
    5: 'Barren',
    6: 'Forest',
    7: 'Agriculture',
}

LOVEDA_COLORS = {
    0: [0, 0, 0],         # No-data - Black
    1: [255, 255, 255],   # Background - White
    2: [255, 0, 0],       # Building - Red
    3: [255, 255, 0],     # Road - Yellow
    4: [0, 0, 255],       # Water - Blue
    5: [159, 129, 183],   # Barren - Purple
    6: [0, 255, 0],       # Forest - Green
    7: [255, 195, 128],   # Agriculture - Orange
}


def colorize_mask(mask_array):
    """Convert a single-channel label mask to an RGB color image.

    Args:
        mask_array: 2-D integer label array of shape (H, W).

    Returns:
        uint8 array of shape (H, W, 3). Each label value listed in
        ``LOVEDA_COLORS`` is painted with its color; any other value stays black.
    """
    h, w = mask_array.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in LOVEDA_COLORS.items():
        color_mask[mask_array == class_id] = color
    return color_mask


# Colorize the mask
colored_mask = colorize_mask(mask_array)

# Create figure with 3 panels: Image, Colored Mask, Overlay
fig, axes = plt.subplots(1, 3, figsize=(20, 7))

# Original image
axes[0].imshow(img_array)
axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
axes[0].axis('off')

# Colored segmentation mask
axes[1].imshow(colored_mask)
axes[1].set_title('Segmentation Mask', fontsize=14, fontweight='bold')
axes[1].axis('off')

# Overlay: alpha-blend the colored mask onto the image
overlay = img_array.copy().astype(np.float32)
mask_float = colored_mask.astype(np.float32)
alpha = 0.45
blended = ((1 - alpha) * overlay + alpha * mask_float).astype(np.uint8)
axes[2].imshow(blended)
axes[2].set_title('Overlay', fontsize=14, fontweight='bold')
axes[2].axis('off')

# Legend: one entry per label value present in this mask. .get() makes an
# unexpected label value show up as "Unknown" instead of raising KeyError.
# A black edge keeps the white Background swatch visible.
present_classes = np.unique(mask_array)
legend_patches = []
for class_id in present_classes:
    class_key = int(class_id)
    color = [c / 255.0 for c in LOVEDA_COLORS.get(class_key, [0, 0, 0])]
    class_name = LOVEDA_CLASSES.get(class_key, 'Unknown')
    patch = mpatches.Patch(facecolor=color, edgecolor='black', label=f"{class_key}: {class_name}")
    legend_patches.append(patch)

fig.legend(handles=legend_patches, loc='lower center', ncol=len(present_classes),
           fontsize=11, frameon=True, fancybox=True, shadow=True,
           bbox_to_anchor=(0.5, -0.02))

plt.suptitle(f'LoveDA Sample — {sample_img_path.parent.parent.name} ({sample_img_path.stem}.png)',
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.subplots_adjust(bottom=0.08)
plt.show()

## 5. Load Foundation Models

Load both models:
- **RemoteCLIP** ViT-B/32 (Teacher — frozen) → provides semantic patch-level targets
- **Prithvi EO v2** (Student backbone — frozen) → provides patch-level features to align

Both models are frozen. Only the spatial projection head (defined later) will be trained.

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import torch.nn.functional as F

# ── Load RemoteCLIP (Teacher) ──────────────────────────────────────────
from huggingface_hub import hf_hub_download
import open_clip

model_name = 'ViT-B-32'
checkpoint_path = hf_hub_download(
    "chendelong/RemoteCLIP", f"RemoteCLIP-{model_name}.pt", cache_dir='checkpoints'
)
print(f"RemoteCLIP {model_name} downloaded to: {checkpoint_path}")

print("\nCreating model architecture...")
# Returns (model, train preprocess, eval preprocess); the eval transform is kept.
remoteclip_model, _, remoteclip_preprocess = open_clip.create_model_and_transforms(model_name)

print("Loading RemoteCLIP pretrained weights...")
# Load from the local file path that hf_hub_download returned. Rebuilding the
# path from a hard-coded snapshot hash breaks as soon as the Hub revision (and
# with it the snapshot directory name) changes.
path_to_checkpoints = str(Path(checkpoint_path).parent)  # kept for any later cell that references it
ckpt = torch.load(checkpoint_path, map_location="cpu")
load_result = remoteclip_model.load_state_dict(ckpt)
print(f"Load result: {load_result}")

remoteclip_model = remoteclip_model.to(device).eval()
for param in remoteclip_model.parameters():
    param.requires_grad = False
print("RemoteCLIP loaded and frozen. Output dim: 512")

# ── Load Prithvi EO v2 (Student Backbone) ─────────────────────────────
from terratorch.registry import BACKBONE_REGISTRY

prithvi_backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_300")
prithvi_backbone = prithvi_backbone.to(device).eval()
for param in prithvi_backbone.parameters():
    param.requires_grad = False
print("Prithvi EO v2 loaded and frozen. Output dim: 1024")

## 6. Patch Token Extraction & Resolution Analysis

Segmentation requires **spatial** features, not just a single CLS token.

| Model | Patch Size | Grid (224×224 input) | Token Dim |
|-------|-----------|----------------------|-----------|
| Prithvi EO v2 | 16×16 | 14×14 = 196 tokens | 1024 |
| RemoteCLIP ViT-B/32 | 32×32 | 7×7 = 49 tokens | 512 |

We extract patch tokens from both, then spatially interpolate RemoteCLIP's 7×7 grid → 14×14 to match Prithvi's resolution.

In [None]:
def extract_prithvi_patch_tokens(backbone, images):
    """
    Return Prithvi's last-layer patch tokens folded into a square spatial grid.

    The backbone runs under ``torch.no_grad()``. The last element of its output
    sequence is used, its leading (CLS) token is dropped, and the remaining N
    tokens are reshaped into a sqrt(N) x sqrt(N) map with channels first.

    Input:  images [B, 6, 224, 224]
    Output: patch_tokens [B, 1024, 14, 14]
    """
    with torch.no_grad():
        last_layer = backbone(images)[-1]      # [B, 1 + N, C]
        tokens = last_layer[:, 1:, :]          # drop CLS -> [B, N, C]
        batch, num_tokens, channels = tokens.shape
        side = int(num_tokens ** 0.5)          # 14 for a 224x224 input
        grid = tokens.transpose(1, 2).reshape(batch, channels, side, side)
    return grid


def extract_remoteclip_patch_tokens(model, images):
    """
    Extract spatial patch tokens from RemoteCLIP ViT-B/32 (excluding CLS token).

    Replays the open_clip vision-transformer forward pass by hand (patch embed,
    CLS + positional embedding, ln_pre, transformer, ln_post, output projection).
    Unlike the stock forward, it keeps every token instead of pooling to CLS.

    Input:  images [B, 3, 224, 224]
    Output: patch_tokens [B, 512, 7, 7], each position L2-normalised over channels
    """
    with torch.no_grad():
        visual = model.visual
        # Patch embedding: 32x32 patches on a 224x224 image -> 7x7 grid
        x = visual.conv1(images)                    # [B, 768, 7, 7]
        x = x.reshape(x.shape[0], x.shape[1], -1)   # [B, 768, 49]
        x = x.permute(0, 2, 1)                      # [B, 49, 768]

        # Prepend the class token and add positional embeddings
        cls_token = visual.class_embedding.to(x.dtype).unsqueeze(0).unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)        # [B, 50, 768]
        x = x + visual.positional_embedding.to(x.dtype).unsqueeze(0)

        # Pre-LN
        x = visual.ln_pre(x)

        # Transformer. Recent open_clip releases build it with batch_first=True
        # and take [B, L, D] directly; older releases expect sequence-first
        # [L, B, D]. Permuting unconditionally would make a batch-first
        # transformer attend across the images of the batch.
        batch_first = getattr(visual.transformer, 'batch_first', False)
        if not batch_first:
            x = x.permute(1, 0, 2)                  # [50, B, 768]
        x = visual.transformer(x)
        if not batch_first:
            x = x.permute(1, 0, 2)                  # [B, 50, 768]

        # Post-LN on all tokens
        x = visual.ln_post(x)

        # Project all tokens (not just CLS) through the output projection
        if visual.proj is not None:
            x = x @ visual.proj                     # [B, 50, 512]

        # Drop CLS (position 0) and L2-normalise each patch token
        patch_tokens = F.normalize(x[:, 1:, :], p=2, dim=-1)  # [B, 49, 512]

        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)                       # 7
        patch_tokens = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)  # [B, 512, 7, 7]
    return patch_tokens


# ── Verify on sample image ────────────────────────────────────────────
sample_img = Image.open(dataset_root / 'Train' / 'Urban' / 'images_png' / '1366.png').convert('RGB')

# Prithvi view: ImageNet-normalised RGB, then three all-zero channels appended (6 total)
prithvi_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
rgb_tensor = prithvi_tf(sample_img).unsqueeze(0).to(device)                   # [1, 3, 224, 224]
prithvi_input = torch.cat((rgb_tensor, torch.zeros_like(rgb_tensor)), dim=1)  # [1, 6, 224, 224]

# RemoteCLIP view: the model's own preprocessing pipeline
remoteclip_input = remoteclip_preprocess(sample_img).unsqueeze(0).to(device)  # [1, 3, 224, 224]

# Patch-token grids from both frozen encoders
prithvi_patches = extract_prithvi_patch_tokens(prithvi_backbone, prithvi_input)
remoteclip_patches = extract_remoteclip_patch_tokens(remoteclip_model, remoteclip_input)

banner = "=" * 60
print(banner)
print("PATCH TOKEN SHAPES")
print(banner)
print(f"Prithvi patch tokens:    {prithvi_patches.shape}")
p_dim, p_rows, p_cols = prithvi_patches.shape[1:]
print(f"  → {p_dim}-dim features on a {p_rows}×{p_cols} grid")
print(f"\nRemoteCLIP patch tokens: {remoteclip_patches.shape}")
r_dim, r_rows, r_cols = remoteclip_patches.shape[1:]
print(f"  → {r_dim}-dim features on a {r_rows}×{r_cols} grid")

# Bilinearly resize the teacher's 7x7 grid onto Prithvi's 14x14 grid
remoteclip_upsampled = F.interpolate(remoteclip_patches, size=(14, 14), mode='bilinear', align_corners=False)
print(f"\nRemoteCLIP after interpolation: {remoteclip_upsampled.shape}")
print("  → Now matches Prithvi's 14×14 spatial grid")

# One panel per feature map: channel-mean activation of the first (only) image
panels = [
    (prithvi_patches, 'viridis', 'Prithvi Patch Features\n(mean of 1024 channels, 14×14)'),
    (remoteclip_patches, 'magma', 'RemoteCLIP Patch Features\n(mean of 512 channels, 7×7)'),
    (remoteclip_upsampled, 'magma', 'RemoteCLIP Upsampled\n(interpolated to 14×14)'),
]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (feature_map, cmap, title) in zip(axes, panels):
    ax.imshow(feature_map[0].mean(dim=0).cpu().numpy(), cmap=cmap)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.axis('off')

plt.suptitle('Spatial Feature Maps from Both Models', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Define Spatial Projection Head

A 2-layer MLP applied **per spatial position** (1×1 convolutions) to map Prithvi's 1024-dim patch features to RemoteCLIP's 512-dim semantic space.

This is equivalent to applying the same linear transformation independently at every position in the 14×14 grid — no spatial mixing, purely channel-wise projection.

In [None]:
class SpatialProjectionHead(nn.Module):
    """
    Position-wise two-layer MLP built from 1x1 convolutions.

    Each (h, w) location is mapped independently from Prithvi's feature space
    into RemoteCLIP's embedding space; no information flows between positions.

    Input:  [B, input_dim, H, W]   (default input_dim=1024)
    Output: [B, output_dim, H, W]  (default output_dim=512)
    """
    def __init__(self, input_dim=1024, hidden_dim=768, output_dim=512):
        super().__init__()
        layers = [
            nn.Conv2d(input_dim, hidden_dim, kernel_size=1),   # net.0
            nn.GELU(),                                         # net.1
            nn.Dropout(0.1),                                   # net.2 (active only in train mode)
            nn.Conv2d(hidden_dim, output_dim, kernel_size=1),  # net.3
        ]
        # The Sequential must stay named `net` so saved state_dict keys
        # (net.0.*, net.3.*) remain loadable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

spatial_proj_head = SpatialProjectionHead(input_dim=1024, hidden_dim=768, output_dim=512).to(device)

# Verify: count all parameters and the subset that will receive gradients
head_params = list(spatial_proj_head.parameters())
total_params = sum(p.numel() for p in head_params)
trainable_params = sum(p.numel() for p in head_params if p.requires_grad)

banner = "=" * 60
print(banner)
print("SPATIAL PROJECTION HEAD")
print(banner)
print("Architecture: Conv1x1(1024→768) → GELU → Dropout(0.1) → Conv1x1(768→512)")
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Quick shape verification on a random two-image batch
dummy_input = torch.randn(2, 1024, 14, 14).to(device)
dummy_output = spatial_proj_head(dummy_input)
print(f"\nShape check: {dummy_input.shape} → {dummy_output.shape}")
print("Per-position: 1024-dim → 512-dim at each of the 14×14 positions")

## 8. Dataset for Phase 1: Unsupervised Spatial Alignment

For distillation we only need **images** (no masks). We use all LoveDA training images (Rural + Urban combined) and apply dual transforms — one for each model.

In [None]:
class LoveDA_Unsupervised(Dataset):
    """
    Image-only view of LoveDA for unsupervised spatial alignment (Phase 1).

    Gathers every ``*.png`` under ``<root>/<split>/{Rural,Urban}/images_png``:
    Rural first, then Urban, each domain sorted by path. Every item is the same
    RGB image passed once through the Prithvi transform and once through the
    RemoteCLIP transform, returned as ``(prithvi_img, remoteclip_img)``.
    """
    def __init__(self, root_dir, prithvi_transform, remoteclip_transform, split='Train'):
        self.root_dir = Path(root_dir)
        self.prithvi_transform = prithvi_transform
        self.remoteclip_transform = remoteclip_transform

        self.image_paths = [
            png_path
            for domain in ('Rural', 'Urban')
            for png_path in sorted((self.root_dir / split / domain / 'images_png').glob('*.png'))
        ]

        print(f"[Phase 1 Dataset] Found {len(self.image_paths)} images for unsupervised alignment")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        rgb = Image.open(self.image_paths[idx]).convert('RGB')
        # Tuple elements evaluate left to right: Prithvi view first, then RemoteCLIP view
        return self.prithvi_transform(rgb), self.remoteclip_transform(rgb)


# Prithvi transform (3-channel; the Phase 1 training loop zero-pads to 6 channels).
# No random flips here: the RemoteCLIP view (open_clip's eval preprocess) is never
# flipped, and the Phase 1 loss compares the two feature grids position by
# position. Flipping only the Prithvi view would pair each student patch with the
# teacher patch from the mirrored location. Any augmentation must be applied
# identically to both views.
prithvi_transform_phase1 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# RemoteCLIP uses its own preprocess pipeline
phase1_dataset = LoveDA_Unsupervised(
    'LoveDA-dataset',
    prithvi_transform=prithvi_transform_phase1,
    remoteclip_transform=remoteclip_preprocess,
    split='Train'
)

phase1_loader = DataLoader(
    phase1_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

print(f"Phase 1 batches per epoch: {len(phase1_loader)}")

## 9. PHASE 1: Spatial Knowledge Distillation Training

Train the Spatial Projection Head to align Prithvi's patch-level features with RemoteCLIP's.

**Pipeline per image:**
1. Extract Prithvi patch tokens → `[B, 1024, 14, 14]`
2. Extract RemoteCLIP patch tokens → `[B, 512, 7, 7]` → interpolate to `[B, 512, 14, 14]`
3. Project Prithvi tokens through Spatial Projection Head → `[B, 512, 14, 14]`
4. **Per-position cosine embedding loss** between projected Prithvi and interpolated RemoteCLIP features

In [None]:
def spatial_cosine_loss(student_features, teacher_features):
    """
    Mean per-position cosine distance between two feature maps.

    Both inputs are [B, C, H, W]. At every (b, h, w) the two C-dim vectors are
    L2-normalised and their dot product is taken. The result is the scalar
    mean of (1 - cosine similarity) over all B*H*W positions: 0 for identical
    directions, 1 for orthogonal, 2 for opposite.
    """
    s_unit = F.normalize(student_features, p=2, dim=1)
    t_unit = F.normalize(teacher_features, p=2, dim=1)
    cos_map = torch.sum(s_unit * t_unit, dim=1)  # [B, H, W]
    return torch.mean(1 - cos_map)


# ── Training configuration ──
phase1_epochs = 15
phase1_optimizer = optim.AdamW(
    spatial_proj_head.parameters(),
    lr=1e-3,
    weight_decay=0.01,
)
# Cosine decay from 1e-3 towards 1e-6 across phase1_epochs scheduler steps
phase1_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    phase1_optimizer, T_max=phase1_epochs, eta_min=1e-6
)

phase1_history = {'loss': [], 'cosine_sim': []}
best_phase1_loss = float('inf')

print("=" * 70)
print("PHASE 1: SPATIAL KNOWLEDGE DISTILLATION")
print("=" * 70)
for summary_line in (
    "Objective: Align Prithvi patch tokens with RemoteCLIP patch tokens",
    "Teacher:   RemoteCLIP ViT-B/32 (frozen) — 512-dim, 7×7 grid",
    "Student:   Prithvi + Spatial Projection Head — 1024→512, 14×14 grid",
    "Loss:      Per-position cosine embedding loss",
    f"Epochs:    {phase1_epochs}",
    "Optimizer: AdamW (lr=1e-3, wd=0.01)",
):
    print(summary_line)
print("=" * 70)

In [None]:
# Phase 1 Training Loop
for epoch in range(phase1_epochs):
    spatial_proj_head.train()
    epoch_loss = 0.0
    epoch_cosine_sim = 0.0
    # LR used for every step of this epoch. It must be read before
    # phase1_scheduler.step(); reading it afterwards reports next epoch's LR.
    epoch_lr = phase1_optimizer.param_groups[0]['lr']

    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{phase1_epochs}")
    print('='*70)

    with tqdm(phase1_loader, desc='Distillation', ncols=100) as pbar:
        for batch_idx, (prithvi_imgs, remoteclip_imgs) in enumerate(pbar):
            prithvi_imgs = prithvi_imgs.to(device)
            remoteclip_imgs = remoteclip_imgs.to(device)

            # Zero-pad Prithvi input to 6 channels (RGB + three all-zero bands)
            prithvi_6ch = torch.cat([prithvi_imgs, torch.zeros_like(prithvi_imgs)], dim=1)

            # Extract spatial features from both frozen models (run under no_grad inside the helpers)
            prithvi_patches = extract_prithvi_patch_tokens(prithvi_backbone, prithvi_6ch)           # [B, 1024, 14, 14]
            remoteclip_patches = extract_remoteclip_patch_tokens(remoteclip_model, remoteclip_imgs)  # [B, 512, 7, 7]

            # Upsample the teacher grid to the student's spatial resolution (7×7 → 14×14)
            teacher_targets = F.interpolate(
                remoteclip_patches, size=prithvi_patches.shape[-2:], mode='bilinear', align_corners=False
            )  # [B, 512, 14, 14]

            # Project Prithvi features (the only trainable part)
            student_projected = spatial_proj_head(prithvi_patches)  # [B, 512, 14, 14]

            # Compute per-position cosine loss
            loss = spatial_cosine_loss(student_projected, teacher_targets)

            # Backward
            phase1_optimizer.zero_grad()
            loss.backward()
            phase1_optimizer.step()

            # Track metrics (one host sync for the loss value)
            batch_loss = loss.item()
            epoch_loss += batch_loss

            # Monitoring-only metric: mean per-position cosine similarity
            with torch.no_grad():
                s_norm = F.normalize(student_projected, p=2, dim=1)
                t_norm = F.normalize(teacher_targets, p=2, dim=1)
                cosine_sim = (s_norm * t_norm).sum(dim=1).mean().item()
            epoch_cosine_sim += cosine_sim

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'CosSim': f'{cosine_sim:.4f}'
            })

    phase1_scheduler.step()

    num_batches = len(phase1_loader)
    avg_loss = epoch_loss / num_batches
    avg_cosine_sim = epoch_cosine_sim / num_batches

    phase1_history['loss'].append(avg_loss)
    phase1_history['cosine_sim'].append(avg_cosine_sim)

    print(f"\n{'─'*70}")
    print(f"EPOCH {epoch+1} SUMMARY:")
    print(f"  Avg Loss:              {avg_loss:.4f}")
    print(f"  Avg Cosine Similarity: {avg_cosine_sim:.4f}")
    print(f"  Learning Rate:         {epoch_lr:.6f}")
    print(f"{'─'*70}")

    # Save best model (selected on mean training loss; there is no validation split here)
    if avg_loss < best_phase1_loss:
        best_phase1_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'spatial_proj_head_state_dict': spatial_proj_head.state_dict(),
            'optimizer_state_dict': phase1_optimizer.state_dict(),
            'loss': best_phase1_loss,
            'cosine_sim': avg_cosine_sim
        }, 'best_spatial_projection_head.pth')
        print(f"Saved best spatial projection head (Loss: {best_phase1_loss:.4f})")

print("\n" + "=" * 70)
print(f"PHASE 1 COMPLETED! Best Loss: {best_phase1_loss:.4f}")
print("=" * 70)

## 10. Phase 1 Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, sim_ax = axes

loss_curve = phase1_history['loss']
sim_curve = phase1_history['cosine_sim']

# Left panel: mean distillation loss per epoch
loss_ax.plot(range(1, len(loss_curve) + 1), loss_curve,
             marker='o', color='crimson', linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('Loss (1 - cos_sim)', fontsize=12)
loss_ax.set_title('Phase 1: Spatial Distillation Loss', fontsize=14, fontweight='bold')
loss_ax.grid(True, alpha=0.3)

# Right panel: mean per-position cosine similarity, with a zero reference line
sim_ax.plot(range(1, len(sim_curve) + 1), sim_curve,
            marker='s', color='seagreen', linewidth=2)
sim_ax.set_xlabel('Epoch', fontsize=12)
sim_ax.set_ylabel('Cosine Similarity', fontsize=12)
sim_ax.set_title('Phase 1: Spatial Alignment Quality', fontsize=14, fontweight='bold')
sim_ax.grid(True, alpha=0.3)
sim_ax.axhline(y=0.0, color='gray', linestyle='--', alpha=0.5, label='Random')
sim_ax.legend()

plt.tight_layout()
plt.show()

## 11. Load Best Spatial Projection Head & Freeze

In [None]:
# Load best spatial projection head from Phase 1.
# map_location=device remaps the stored tensors onto the current device, so
# this cell also works if the checkpoint was written on a different device.
checkpoint = torch.load('best_spatial_projection_head.pth', map_location=device)
spatial_proj_head.load_state_dict(checkpoint['spatial_proj_head_state_dict'])
spatial_proj_head.eval()

# Freeze for Phase 2: the projection head becomes a fixed feature transform
for param in spatial_proj_head.parameters():
    param.requires_grad = False

print(f"Loaded best spatial projection head from epoch {checkpoint['epoch']+1}")
print(f"  Final Loss:              {checkpoint['loss']:.4f}")
print(f"  Final Cosine Similarity: {checkpoint['cosine_sim']:.4f}")

## 12. Define Segmentation Decoder

A lightweight convolutional decoder that takes a spatial feature grid (14×14) and upsamples it to the target segmentation resolution (512×512).

**Architecture**: Progressive upsampling with Conv → BatchNorm → ReLU blocks.

The **same decoder architecture** is used for both models — only the input channel count differs:
- **Model A** (Baseline): 1024 input channels (raw Prithvi features)
- **Model B** (Enhanced): 512 input channels (projected features)

In [None]:
class SegmentationDecoder(nn.Module):
    """
    Lightweight decoder: upsamples a 14×14 feature grid to a dense segmentation map.

    Upsampling stages (with the default ``output_size``):
        14×14 → 28×28 → 56×56 → 112×112 → 224×224 → 512×512

    Each stage: Upsample → Conv3x3 → BN → ReLU. The first four stages double the
    spatial size; the last one resizes straight to ``output_size``.
    Final: Conv1x1 → num_classes

    Args:
        in_channels: channel count of the input feature grid (1024 for raw
            Prithvi tokens, 512 for projected features).
        num_classes: number of output logit channels.
        output_size: (H, W) of the output logits, or a single int for a square
            map. Defaults to (512, 512), the mask resolution used in this notebook.
    """
    def __init__(self, in_channels, num_classes=7, output_size=(512, 512)):
        super().__init__()

        # Layer order/indices are unchanged so existing state_dicts still load.
        self.decoder = nn.Sequential(
            # 14×14 → 28×28
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 28×28 → 56×56
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # 56×56 → 112×112
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # 112×112 → 224×224
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # 224×224 → output_size (512×512 by default)
            nn.Upsample(size=output_size, mode='bilinear', align_corners=False),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        # Final classification head
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.decoder(x)
        x = self.classifier(x)
        return x  # [B, num_classes, *output_size]


# Verify decoder architecture for both input feature widths
decoder_a = SegmentationDecoder(in_channels=1024, num_classes=7)
decoder_b = SegmentationDecoder(in_channels=512, num_classes=7)

print("=" * 60)
print("SEGMENTATION DECODER")
print("=" * 60)
for heading, decoder in (("Model A (Baseline):   in_channels=1024", decoder_a),
                         ("Model B (Enhanced):   in_channels=512", decoder_b)):
    n_trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    print(f"\nFor {heading}")
    print(f"  Trainable params: {n_trainable:,}")

# Shape check: a single random 14×14 grid through each decoder
dummy_a = torch.randn(1, 1024, 14, 14)
dummy_b = torch.randn(1, 512, 14, 14)
out_a = decoder_a(dummy_a)
out_b = decoder_b(dummy_b)
print("\nShape checks:")
print(f"  Model A: {dummy_a.shape} → {out_a.shape}")
print(f"  Model B: {dummy_b.shape} → {out_b.shape}")

# Clean up verification decoders (including the loop alias, so memory is released)
del decoder_a, decoder_b, decoder, dummy_a, dummy_b, out_a, out_b

## 13. Define Model A (Baseline) and Model B (Semantic-Enhanced)

### Model A (Baseline)
`Image → Prithvi (frozen) → patch tokens [B, 1024, 14, 14] → Decoder → [B, 7, 512, 512]`

### Model B (Semantic-Enhanced)
`Image → Prithvi (frozen) → patch tokens [B, 1024, 14, 14] → Spatial PH (frozen) → [B, 512, 14, 14] → Decoder → [B, 7, 512, 512]`

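Only the decoder weights are trainable in both cases. Capacity is matched only approximately: Model A's first decoder convolution receives 1024 input channels versus 512 for Model B, so Model A's decoder has more trainable parameters.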

In [None]:
class SegModelA_Baseline(nn.Module):
    """
    Baseline segmentation: Prithvi (frozen) → Decoder (trainable)
    Feature channel dim: 1024

    The backbone is a fixed feature extractor. It runs under ``torch.no_grad()``
    and stays in eval mode even when the model is put in training mode (see
    ``train``). Only ``self.decoder`` is trained.
    """
    def __init__(self, backbone, num_classes=7):
        super().__init__()
        self.backbone = backbone
        self.decoder = SegmentationDecoder(in_channels=1024, num_classes=num_classes)

    def train(self, mode=True):
        """Set ``mode`` on the decoder while keeping the frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override,
        ``model.train()`` would also switch the frozen backbone into training
        mode, enabling any dropout or other stochastic layers it may contain.
        ``model.eval()`` also routes through here, as ``train(False)``.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """
        Args:
            x: image batch [B, 3, H, W] (zero-padded to 6 channels here) or an
               already 6-channel batch.

        Returns:
            Decoder logits [B, num_classes, 512, 512].
        """
        # Zero-pad RGB → 6 channels
        if x.shape[1] == 3:
            x = torch.cat([x, torch.zeros_like(x)], dim=1)

        # Frozen backbone → patch tokens
        with torch.no_grad():
            features = self.backbone(x)
            patch_tokens = features[-1][:, 1:, :]  # [B, 196, 1024] — drop CLS
            B, N, C = patch_tokens.shape
            h = w = int(N ** 0.5)
            patch_tokens = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)  # [B, 1024, 14, 14]

        # Tensors produced under no_grad carry no autograd history; detach() is a
        # harmless explicit marker that the features are constants for the decoder.
        patch_tokens = patch_tokens.detach()

        # Trainable decoder
        logits = self.decoder(patch_tokens)  # [B, 7, 512, 512]
        return logits


class SegModelB_SemanticEnhanced(nn.Module):
    """
    Semantic-enhanced segmentation: Prithvi (frozen) → Spatial PH (frozen) → Decoder (trainable)
    Feature channel dim: 512

    ``backbone`` and ``spatial_proj_head`` are fixed feature extractors. They run
    under ``torch.no_grad()`` and stay in eval mode even when the model is put in
    training mode (see ``train``). Only ``self.decoder`` is trained.
    """
    def __init__(self, backbone, spatial_proj_head, num_classes=7):
        super().__init__()
        self.backbone = backbone
        self.spatial_proj_head = spatial_proj_head
        self.decoder = SegmentationDecoder(in_channels=512, num_classes=num_classes)

    def train(self, mode=True):
        """Set ``mode`` on the decoder; keep the frozen backbone and projection head in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override,
        ``model.train()`` would re-enable stochastic layers in the frozen parts,
        e.g. the Dropout inside SpatialProjectionHead. Decoder training would
        then see noisier features than evaluation. ``model.eval()`` also routes
        through here, as ``train(False)``.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        self.spatial_proj_head.eval()
        return self

    def forward(self, x):
        """
        Args:
            x: image batch [B, 3, H, W] (zero-padded to 6 channels here) or an
               already 6-channel batch.

        Returns:
            Decoder logits [B, num_classes, 512, 512].
        """
        # Zero-pad RGB → 6 channels
        if x.shape[1] == 3:
            x = torch.cat([x, torch.zeros_like(x)], dim=1)

        # Frozen backbone → patch tokens
        with torch.no_grad():
            features = self.backbone(x)
            patch_tokens = features[-1][:, 1:, :]  # [B, 196, 1024] — drop CLS
            B, N, C = patch_tokens.shape
            h = w = int(N ** 0.5)
            patch_tokens = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)  # [B, 1024, 14, 14]

            # Frozen spatial projection
            projected = self.spatial_proj_head(patch_tokens)  # [B, 512, 14, 14]

        # Tensors produced under no_grad carry no autograd history; detach() is a
        # harmless explicit marker that the features are constants for the decoder.
        projected = projected.detach()

        # Trainable decoder
        logits = self.decoder(projected)  # [B, 7, 512, 512]
        return logits


# Instantiate both models
# NOTE(review): LoveDA masks use 0 for no-data and 1–7 for the seven classes;
# confirm the segmentation dataset remaps labels into 0..6 before using a 7-way head.
NUM_CLASSES = 7

seg_model_a = SegModelA_Baseline(prithvi_backbone, num_classes=NUM_CLASSES).to(device)
seg_model_b = SegModelB_SemanticEnhanced(prithvi_backbone, spatial_proj_head, num_classes=NUM_CLASSES).to(device)

print("=" * 70)
print("SEGMENTATION MODELS")
print("=" * 70)
for header, pipeline, seg_model in (
    ("\nModel A (Baseline):", "  Pipeline: Prithvi(frozen) → Decoder(1024→7)", seg_model_a),
    ("\nModel B (Semantic-Enhanced):", "  Pipeline: Prithvi(frozen) → SpatialPH(frozen) → Decoder(512→7)", seg_model_b),
):
    # Only parameters with requires_grad=True (the decoder's) count as trainable
    n_trainable = sum(p.numel() for p in seg_model.parameters() if p.requires_grad)
    print(header)
    print(pipeline)
    print(f"  Trainable params: {n_trainable:,}")
print("=" * 70)

## 14. Dataset & DataLoader for Phase 2: Segmentation

LoveDA images with their segmentation masks. Images are resized to 512×512 (both image and mask).

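**Class mapping**: In the official LoveDA masks, value 0 marks no-data pixels and values 1–7 are Background, Building, Road, Water, Barren, Forest and Agriculture. Value 0 is **ignored** during loss computation using `ignore_index=0`. Because the seven real classes occupy labels 1–7, a 7-way output head requires the labels to be shifted down by one (or the head widened to 8 channels).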

In [None]:
class LoveDA_Segmentation(Dataset):
    """
    LoveDA dataset for semantic segmentation.
    Returns (image_tensor, mask_tensor) pairs.

    Images: resized to img_size (bilinear), converted to a float tensor
            [3, img_size, img_size] and ImageNet-normalized.
    Masks:  resized to mask_size using nearest-neighbor (preserves class labels),
            returned as a long tensor [mask_size, mask_size] with values 0..6.

    Label encoding:
        Raw LoveDA masks use 0 for no-data and 1..7 for background, building,
        road, water, barren, forest, agriculture. Labels are shifted down by
        one, so the seven land-cover classes become 0..6 (the 7-way decoder's
        range). No-data (raw 0) is folded into index 0, which the training
        loss (ignore_index=0) and the metrics in this notebook skip.

    Args:
        root_dir: dataset root containing <split>/<Rural|Urban>/{images_png,masks_png}.
        split: split sub-directory name ('Train' or 'Val').
        img_size: side length of the square image fed to the backbone.
        mask_size: side length of the square target mask.
        augment: if True, apply random horizontal/vertical flips jointly to
            the image and its mask.
    """
    def __init__(self, root_dir, split='Train', img_size=224, mask_size=512, augment=False):
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.mask_size = mask_size
        self.augment = augment

        self.image_paths = []
        self.mask_paths = []

        for domain in ['Rural', 'Urban']:
            img_dir = self.root_dir / split / domain / 'images_png'
            mask_dir = self.root_dir / split / domain / 'masks_png'

            for img_path in sorted(img_dir.glob('*.png')):
                mask_path = mask_dir / img_path.name
                if mask_path.exists():  # keep only images that have a mask
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)

        # Stateless transforms: build them once instead of on every item.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        print(f"[{split}] Found {len(self.image_paths)} image-mask pairs")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load + resize inside `with` so the underlying files are closed
        # deterministically; resize() returns an independent, loaded image.
        with Image.open(self.image_paths[idx]) as img_file:
            image = img_file.convert('RGB').resize((self.img_size, self.img_size), Image.BILINEAR)
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.resize((self.mask_size, self.mask_size), Image.NEAREST)

        # Random augmentations (applied consistently to image and mask)
        if self.augment:
            # Random horizontal flip
            if torch.rand(1).item() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            # Random vertical flip
            if torch.rand(1).item() > 0.5:
                image = image.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

        # Convert to tensors
        image = self.normalize(self.to_tensor(image))  # [3, img_size, img_size]
        mask = torch.from_numpy(np.array(mask)).long()  # [mask_size, mask_size]

        # Raw labels: 0 = no-data, 1..7 = background, building, road, water,
        # barren, forest, agriculture (7 is agriculture, not no-data).
        # Shift to 0..6 and fold no-data (-1 after the shift) into 0, the index
        # ignored by the loss and metrics.
        mask = (mask - 1).clamp_(min=0)

        return image, mask


# Create datasets
# Single source of truth for the resolutions used below and in the printout.
SEG_IMG_SIZE = 224   # image side length fed to the backbone
SEG_MASK_SIZE = 512  # target mask side length (decoder output)

train_seg_dataset = LoveDA_Segmentation('LoveDA-dataset', split='Train', img_size=SEG_IMG_SIZE, mask_size=SEG_MASK_SIZE, augment=True)
val_seg_dataset = LoveDA_Segmentation('LoveDA-dataset', split='Val', img_size=SEG_IMG_SIZE, mask_size=SEG_MASK_SIZE, augment=False)

# Create dataloaders
batch_size_seg = 4
# Pinned host memory only speeds up host->GPU copies; recent PyTorch versions
# warn when pin_memory=True is requested without an accelerator.
pin_memory_seg = device.type == 'cuda'
train_seg_loader = DataLoader(train_seg_dataset, batch_size=batch_size_seg, shuffle=True, num_workers=0, pin_memory=pin_memory_seg)
val_seg_loader = DataLoader(val_seg_dataset, batch_size=batch_size_seg, shuffle=False, num_workers=0, pin_memory=pin_memory_seg)

print(f"\nPhase 2 Segmentation Data:")
print(f"  Train: {len(train_seg_dataset)} samples, {len(train_seg_loader)} batches")
print(f"  Val:   {len(val_seg_dataset)} samples, {len(val_seg_loader)} batches")
print(f"  Image size: {SEG_IMG_SIZE}×{SEG_IMG_SIZE} (input to backbone)")
print(f"  Mask size:  {SEG_MASK_SIZE}×{SEG_MASK_SIZE} (decoder output)")
print(f"  Batch size: {batch_size_seg}")

# Verify a sample
sample_img, sample_mask = train_seg_dataset[0]
print(f"\nSample shapes: image={sample_img.shape}, mask={sample_mask.shape}")
print(f"Mask unique values: {torch.unique(sample_mask).tolist()}")

## 15. Segmentation Training Function

Shared training loop for both models. Uses:
- **CrossEntropyLoss** with `ignore_index=0` (Background)
- **mIoU** (mean Intersection-over-Union) as the primary metric
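- **Best-checkpoint restore**: the model weights from the epoch with the highest validation mIoU are reloaded at the end. Per-class IoU is reported separately by `evaluate_segmentation` (Section 20).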

In [None]:
def compute_miou(preds, labels, num_classes=7, ignore_index=0):
    """
    Compute per-class IoU and their mean over the classes evaluable in a batch.

    For each class c in 0..num_classes-1 (except ``ignore_index``):
        IoU_c = |pred == c AND label == c| / |pred == c OR label == c|
    counted over every pixel of the inputs. Classes with an empty union
    (neither predicted nor present) are skipped.

    Args:
        preds:  [B, H, W] predicted class indices.
        labels: [B, H, W] ground-truth class indices.
        num_classes: number of class indices (0 .. num_classes-1).
        ignore_index: class id excluded from the metric (default 0,
            Background); ``None`` excludes nothing.

    Returns:
        (miou, ious): ``ious`` is a list of IoUs of the evaluated classes in
        ascending class order; ``miou`` is their mean. When no class is
        evaluable, returns ``(0.0, [])``.
    """
    ious = []
    for cls in range(num_classes):
        if cls == ignore_index:
            continue
        pred_mask = preds == cls
        label_mask = labels == cls

        union = (pred_mask | label_mask).sum().item()
        if union == 0:
            continue  # class neither predicted nor present in this batch

        intersection = (pred_mask & label_mask).sum().item()
        ious.append(intersection / union)

    if not ious:
        return 0.0, []
    return float(np.mean(ious)), ious


def _set_frozen_children_eval(model):
    """Switch every direct child whose parameters are all frozen to eval mode.

    ``model.train()`` recursively puts *all* sub-modules into training mode,
    including frozen feature extractors. Any Dropout there would stay active,
    and any BatchNorm there would use batch statistics and update its running
    stats even under ``torch.no_grad()``. Children that own trainable
    parameters (e.g. the decoder) are left in training mode.
    """
    for child in model.children():
        params = list(child.parameters())
        if params and not any(p.requires_grad for p in params):
            child.eval()


def _validate_segmentation(model, val_loader, criterion, num_classes, ignore_index):
    """Evaluate ``model`` on the whole validation set.

    Returns ``(val_loss, val_miou, val_acc)``:
      * val_loss: mean criterion value over batches containing at least one
        non-ignored pixel. With no non-ignored target, the mean reduction
        divides by zero and yields NaN.
      * val_miou: mean, over non-ignored classes with a non-empty union, of
        intersection / union, where both are summed over the *entire*
        validation set rather than averaged per batch. This matches the
        confusion-matrix IoU (tp / (tp + fp + fn)).
      * val_acc: pixel accuracy in % over non-ignored pixels.
    """
    model.eval()
    loss_sum = 0.0
    loss_batches = 0
    intersection = torch.zeros(num_classes, dtype=torch.long)
    union = torch.zeros(num_classes, dtype=torch.long)
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            logits = model(images)
            preds = logits.argmax(dim=1)  # [B, H, W]
            valid = masks != ignore_index
            n_valid = valid.sum().item()

            if n_valid > 0:
                loss_sum += criterion(logits, masks).item()
                loss_batches += 1

            for cls in range(num_classes):
                if cls == ignore_index:
                    continue
                pred_c = preds == cls
                label_c = masks == cls
                intersection[cls] += (pred_c & label_c).sum().item()
                union[cls] += (pred_c | label_c).sum().item()

            # Overall accuracy (excluding ignored pixels)
            correct += (preds[valid] == masks[valid]).sum().item()
            total += n_valid

    ious = [intersection[c].item() / union[c].item()
            for c in range(num_classes)
            if c != ignore_index and union[c].item() > 0]
    val_miou = float(np.mean(ious)) if ious else 0.0
    val_loss = loss_sum / loss_batches if loss_batches else 0.0
    val_acc = 100.0 * correct / total if total > 0 else 0.0
    return val_loss, val_miou, val_acc


def train_segmentation(model, train_loader, val_loader, num_epochs=25, lr=1e-3, model_name="Model",
                       num_classes=7):
    """
    Train a segmentation model and track validation mIoU.

    Only parameters with ``requires_grad=True`` are optimized (the decoder for
    the models in this notebook). Direct children whose parameters are all
    frozen are kept in eval mode while training. Class 0 (Background) is
    ignored by the loss and by the metrics. Batches with no non-ignored pixel
    are skipped during training, since their NaN loss would propagate NaN
    gradients into the weights.

    Args:
        model: module mapping images to logits [B, num_classes, H, W].
        train_loader: yields (images, masks) with masks [B, H, W].
        val_loader: same format as ``train_loader``.
        num_epochs: number of epochs (also the cosine-annealing period).
        lr: AdamW learning rate.
        model_name: label used in the log output.
        num_classes: number of output classes, including the ignored class 0.

    Returns:
        (history, best_val_miou). ``history`` holds per-epoch 'train_loss',
        'val_loss', 'val_miou' and 'val_overall_acc'. ``model`` is left loaded
        with the weights of the epoch that had the best validation mIoU.
    """
    ignore_index = 0  # Background

    # Only train parameters that require grad (decoder only)
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    history = {'train_loss': [], 'val_loss': [], 'val_miou': [], 'val_overall_acc': []}
    best_val_miou = 0.0
    best_state = None

    print(f"\n{'='*70}")
    print(f"Training {model_name}")
    print(f"{'='*70}")
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    print(f"Epochs: {num_epochs}, LR: {lr}")
    print(f"{'='*70}")

    for epoch in range(num_epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        _set_frozen_children_eval(model)
        train_loss_sum = 0.0
        train_batches = 0

        with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', ncols=100) as pbar:
            for images, masks in pbar:
                images = images.to(device)
                masks = masks.to(device)  # [B, H, W]

                # All pixels ignored -> NaN mean loss -> NaN gradients; skip.
                if not (masks != ignore_index).any():
                    continue

                optimizer.zero_grad()
                logits = model(images)  # [B, num_classes, H, W]
                loss = criterion(logits, masks)
                loss.backward()
                optimizer.step()

                train_loss_sum += loss.item()
                train_batches += 1
                pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        train_loss = train_loss_sum / train_batches if train_batches else 0.0

        # ── Validation ────────────────────────────────────────────────
        val_loss, val_miou, val_acc = _validate_segmentation(
            model, val_loader, criterion, num_classes, ignore_index)

        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_miou'].append(val_miou)
        history['val_overall_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val mIoU: {val_miou:.4f} | "
              f"Val Acc: {val_acc:.2f}% | "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Track best
        if val_miou > best_val_miou:
            best_val_miou = val_miou
            # state_dict() also contains the frozen backbone; keep the snapshot
            # on the CPU so it does not duplicate the model in GPU memory.
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            print(f"  >> New best mIoU: {best_val_miou:.4f}")

    # Load best model weights (load_state_dict copies into the existing tensors)
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"\nBest Validation mIoU: {best_val_miou:.4f}")
    print('='*70)

    return history, best_val_miou

## 16. Train Model A (Baseline Segmentation)

In [None]:
# Baseline: raw Prithvi patch features -> trainable decoder
history_seg_a, best_miou_a = train_segmentation(
    model=seg_model_a,
    train_loader=train_seg_loader,
    val_loader=val_seg_loader,
    num_epochs=25,
    lr=1e-3,
    model_name="Model A (Baseline Segmentation)",
)

## 17. Train Model B (Semantic-Enhanced Segmentation)

In [None]:
# Semantic-enhanced: Prithvi -> spatial projection head -> trainable decoder
history_seg_b, best_miou_b = train_segmentation(
    model=seg_model_b,
    train_loader=train_seg_loader,
    val_loader=val_seg_loader,
    num_epochs=25,
    lr=1e-3,
    model_name="Model B (Semantic-Enhanced Segmentation)",
)

## 18. Compare Results: Model A vs Model B

In [None]:
print("\n" + "="*70)
print("SEGMENTATION COMPARISON")
print("="*70)

# Highest per-epoch pixel accuracy of each run. This may come from a different
# epoch than the best mIoU (the checkpoint that was restored).
best_acc_a = max(history_seg_a['val_overall_acc'])
best_acc_b = max(history_seg_b['val_overall_acc'])

print(f"\nModel A (Baseline — Raw Prithvi Features):")
print(f"  Best Val mIoU:       {best_miou_a:.4f}")
print(f"  Best Val Accuracy:   {best_acc_a:.2f}%")

print(f"\nModel B (Semantic-Enhanced — RemoteCLIP Aligned):")
print(f"  Best Val mIoU:       {best_miou_b:.4f}")
print(f"  Best Val Accuracy:   {best_acc_b:.2f}%")

miou_improvement = best_miou_b - best_miou_a
acc_improvement = best_acc_b - best_acc_a

print(f"\nImprovement:")
if best_miou_a > 0:
    miou_relative = f"{miou_improvement/best_miou_a*100:+.2f}% relative"
else:
    # train_segmentation returns 0.0 when no epoch beat its initial best of 0.0;
    # a relative change is undefined then (and would raise ZeroDivisionError).
    miou_relative = "relative change undefined: baseline mIoU is 0"
print(f"  mIoU:     {miou_improvement:+.4f} ({miou_relative})")
print(f"  Accuracy: {acc_improvement:+.2f}%")

if best_miou_b > best_miou_a:
    print(f"\nSemantic enhancement improved segmentation performance!")
else:
    print(f"\nBaseline performed better on segmentation. See analysis below.")
print("="*70)

## 19. Training Curves Comparison

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

epochs = range(1, len(history_seg_a['train_loss']) + 1)

# One entry per panel: (axis, history key, y-label, title, (colour A, colour B)
# or None to use matplotlib's default colour cycle).
_curve_panels = [
    (axes[0, 0], 'train_loss', 'Loss', 'Training Loss', None),
    (axes[0, 1], 'val_loss', 'Loss', 'Validation Loss', None),
    (axes[1, 0], 'val_miou', 'mIoU', 'Validation mIoU', ('orange', 'green')),
    (axes[1, 1], 'val_overall_acc', 'Accuracy (%)', 'Validation Overall Accuracy', ('orange', 'green')),
]

for _ax, _key, _ylabel, _title, _colors in _curve_panels:
    _style_a = {'color': _colors[0]} if _colors else {}
    _style_b = {'color': _colors[1]} if _colors else {}
    _ax.plot(epochs, history_seg_a[_key], 'o-', label='Model A (Baseline)', linewidth=2, markersize=4, **_style_a)
    _ax.plot(epochs, history_seg_b[_key], 's-', label='Model B (Semantic)', linewidth=2, markersize=4, **_style_b)
    _ax.set_xlabel('Epoch', fontsize=12)
    _ax.set_ylabel(_ylabel, fontsize=12)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.legend(fontsize=10)
    _ax.grid(True, alpha=0.3)

# Dashed reference lines for each model's best mIoU (unlabelled, so the
# legend drawn above is unaffected)
axes[1, 0].axhline(y=best_miou_a, color='orange', linestyle='--', alpha=0.5)
axes[1, 0].axhline(y=best_miou_b, color='green', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.suptitle('Phase 2: Segmentation Training Comparison', fontsize=16, fontweight='bold', y=1.002)
plt.show()

## 20. Per-Class IoU Evaluation

In [None]:
def evaluate_segmentation(model, val_loader, num_classes=7, model_name="Model"):
    """
    Full evaluation: per-class IoU, mIoU, and pixel accuracy.

    Builds a confusion matrix (rows = ground truth, columns = prediction) over
    the entire validation set. Only (label, prediction) pairs with both values
    in [0, num_classes) are counted.

    Class 0 (Background) is excluded from the per-class IoU, the mIoU and the
    pixel accuracy. A class with tp + fp + fn == 0 gets IoU NaN and is left
    out of the mIoU.

    Args:
        model: module mapping images to logits [B, C, H, W].
        val_loader: yields (images, masks) with masks [B, H, W].
        num_classes: size of the confusion matrix.
        model_name: label used in the progress bar and the printout.

    Returns:
        (per_class_iou, miou, overall_acc): dict {class_id: IoU} for classes
        1..num_classes-1, the mean IoU, and pixel accuracy in %.
    """
    model.eval()

    # Confusion matrix: [num_classes, num_classes]
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f'Evaluating {model_name}'):
            images = images.to(device)
            masks = masks.to(device)

            preds = model(images).argmax(dim=1)  # [B, H, W]

            # Keep only pairs that map to a cell of the matrix.
            in_range = (masks >= 0) & (masks < num_classes) & (preds < num_classes)
            # Encode each (true, pred) pair as one flat index and histogram all
            # pixels in a single pass instead of num_classes**2 tensor scans.
            pair_index = masks[in_range] * num_classes + preds[in_range]
            counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
            confusion += counts.cpu().view(num_classes, num_classes)

    # Per-class IoU (skip class 0 = Background)
    per_class_iou = {}
    for cls in range(1, num_classes):
        tp = confusion[cls, cls].item()
        fp = confusion[:, cls].sum().item() - tp
        fn = confusion[cls, :].sum().item() - tp

        if tp + fp + fn == 0:
            per_class_iou[cls] = float('nan')
        else:
            per_class_iou[cls] = tp / (tp + fp + fn)

    # mIoU (excluding NaN classes)
    valid_ious = [v for v in per_class_iou.values() if not np.isnan(v)]
    miou = np.mean(valid_ious) if valid_ious else 0.0

    # Overall pixel accuracy (excluding background)
    correct = sum(confusion[c, c].item() for c in range(1, num_classes))
    total = sum(confusion[c, :].sum().item() for c in range(1, num_classes))
    overall_acc = 100.0 * correct / total if total > 0 else 0.0

    # Print results
    print(f"\n{'='*70}")
    print(f"{model_name} — Segmentation Evaluation")
    print(f"{'='*70}")
    print(f"{'Class':<15} {'IoU':>8}")
    print(f"{'-'*25}")
    for cls_id, iou in per_class_iou.items():
        cls_name = LOVEDA_CLASSES[cls_id]
        if np.isnan(iou):
            print(f"{cls_name:<15} {'N/A':>8}")
        else:
            print(f"{cls_name:<15} {iou:>8.4f}")
    print(f"{'-'*25}")
    print(f"{'mIoU':<15} {miou:>8.4f}")
    print(f"{'Pixel Acc':<15} {overall_acc:>7.2f}%")
    print(f"{'='*70}")

    return per_class_iou, miou, overall_acc


# Evaluate both models over the full validation set
per_class_iou_a, miou_a_final, acc_a_final = evaluate_segmentation(
    seg_model_a,
    val_seg_loader,
    model_name="Model A (Baseline)",
)
per_class_iou_b, miou_b_final, acc_b_final = evaluate_segmentation(
    seg_model_b,
    val_seg_loader,
    model_name="Model B (Semantic-Enhanced)",
)

## 21. Per-Class IoU Comparison Chart

In [None]:
# Per-class IoU comparison bar chart. Class 0 (Background) is excluded from
# the metric, so it is not plotted.
eval_class_ids = list(range(1, NUM_CLASSES))
class_names = [LOVEDA_CLASSES[i] for i in eval_class_ids]
iou_a_vals = [per_class_iou_a.get(i, 0) for i in eval_class_ids]
iou_b_vals = [per_class_iou_b.get(i, 0) for i in eval_class_ids]

# Classes with an empty union (neither present nor predicted) have NaN IoU;
# draw them as 0.
iou_a_vals = [0 if np.isnan(v) else v for v in iou_a_vals]
iou_b_vals = [0 if np.isnan(v) else v for v in iou_b_vals]

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(class_names))
width = 0.35

bars1 = ax.bar(x - width/2, iou_a_vals, width, label='Model A (Baseline)', alpha=0.85, color='orange')
bars2 = ax.bar(x + width/2, iou_b_vals, width, label='Model B (Semantic)', alpha=0.85, color='green')

# Value label just above every bar of both series
for bar in (*bars1, *bars2):
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 0.005, f'{h:.3f}',
            ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('IoU', fontsize=12)
ax.set_title('Per-Class IoU Comparison (Segmentation)', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(class_names, fontsize=11)
ax.grid(True, axis='y', alpha=0.3)

# mIoU reference lines; the single legend call below covers bars and lines.
ax.axhline(y=miou_a_final, color='orange', linestyle='--', alpha=0.6, label=f'mIoU A: {miou_a_final:.4f}')
ax.axhline(y=miou_b_final, color='green', linestyle='--', alpha=0.6, label=f'mIoU B: {miou_b_final:.4f}')
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

# Print improvement per class
print("\nPer-Class IoU Improvement:")
print(f"{'Class':<15} {'Model A':>8} {'Model B':>8} {'Delta':>8}")
print(f"{'-'*42}")
for name, iou_a, iou_b in zip(class_names, iou_a_vals, iou_b_vals):
    delta = iou_b - iou_a
    print(f"{name:<15} {iou_a:>8.4f} {iou_b:>8.4f} {delta:>+8.4f}")

## 22. Visual Results: Side-by-Side Predictions

Display sample validation images with:
1. Input image (the denormalized 224×224 network input, not the full-resolution original)
2. Ground truth mask
3. Model A prediction
4. Model B prediction

In [None]:
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Reverse per-channel normalization for display.

    Computes ``tensor * std + mean`` channel-wise on the CPU and clamps the
    result to [0, 1]. The defaults are the ImageNet statistics used by the
    dataset's ``Normalize`` transform.

    Args:
        tensor: normalized image [C, H, W] (or batch [B, C, H, W]) on any
            device, with C == len(mean) == len(std).
        mean: per-channel means (any sequence of floats).
        std: per-channel standard deviations (any sequence of floats).

    Returns:
        CPU tensor with the same shape as ``tensor``, values in [0, 1].
    """
    # Tuples as defaults avoid shared mutable defaults; view(-1, 1, 1)
    # broadcasts over H and W for any channel count.
    mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    return (tensor.cpu() * std + mean).clamp(0, 1)


# Select sample images from validation set
num_samples = 4
sample_indices = np.linspace(0, len(val_seg_dataset)-1, num_samples, dtype=int)

# squeeze=False keeps `axes` 2-D, so axes[row, col] also works for num_samples == 1.
fig, axes = plt.subplots(num_samples, 4, figsize=(20, 5 * num_samples), squeeze=False)

seg_model_a.eval()
seg_model_b.eval()

for row, idx in enumerate(sample_indices):
    image, gt_mask = val_seg_dataset[idx]
    image_input = image.unsqueeze(0).to(device)

    # Get predictions
    with torch.no_grad():
        pred_a = seg_model_a(image_input).argmax(dim=1)[0].cpu().numpy()
        pred_b = seg_model_b(image_input).argmax(dim=1)[0].cpu().numpy()

    gt_mask_np = gt_mask.numpy()

    # Input image, denormalized. It stays at the dataset's image resolution
    # while the masks are at mask resolution; each is drawn in its own panel.
    img_display = denormalize(image).permute(1, 2, 0).numpy()

    # Column 0: Input Image
    axes[row, 0].imshow(img_display)
    axes[row, 0].set_title('Input Image' if row == 0 else '', fontsize=12, fontweight='bold')
    axes[row, 0].axis('off')

    # Column 1: Ground Truth
    axes[row, 1].imshow(colorize_mask(gt_mask_np))
    axes[row, 1].set_title('Ground Truth' if row == 0 else '', fontsize=12, fontweight='bold')
    axes[row, 1].axis('off')

    # Column 2: Model A prediction
    axes[row, 2].imshow(colorize_mask(pred_a))
    axes[row, 2].set_title('Model A (Baseline)' if row == 0 else '', fontsize=12, fontweight='bold')
    axes[row, 2].axis('off')

    # Column 3: Model B prediction
    axes[row, 3].imshow(colorize_mask(pred_b))
    axes[row, 3].set_title('Model B (Semantic)' if row == 0 else '', fontsize=12, fontweight='bold')
    axes[row, 3].axis('off')

# Legend for the evaluated classes (class 0 / Background is ignored)
legend_patches = []
for cls_id in range(1, NUM_CLASSES):
    color = [c / 255.0 for c in LOVEDA_COLORS[cls_id]]
    patch = mpatches.Patch(color=color, label=f"{LOVEDA_CLASSES[cls_id]}")
    legend_patches.append(patch)

fig.legend(handles=legend_patches, loc='lower center', ncol=len(legend_patches),
           fontsize=11, frameon=True, fancybox=True, shadow=True,
           bbox_to_anchor=(0.5, -0.01))

plt.suptitle('Segmentation Predictions: Ground Truth vs Model A vs Model B',
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.subplots_adjust(bottom=0.05)
plt.show()

## 23. Summary & Cross-Task Comparison

In [None]:
print("\n" + "=" * 70)
print("EXPERIMENT SUMMARY")
print("=" * 70)

print("\n--- PHASE 1: Spatial Knowledge Distillation (Unsupervised) ---")
print(f"  Teacher: RemoteCLIP ViT-B/32 (frozen) — patch tokens [B, 512, 7×7]")
print(f"  Student: Prithvi EO v2 (frozen) + Spatial Projection Head (trainable)")
print(f"  Alignment: Per-position cosine embedding loss")
print(f"  Training Samples: {len(phase1_dataset):,}")
print(f"  Epochs: {phase1_epochs}")
print(f"  Best Loss: {best_phase1_loss:.4f}")
# Best (highest) per-epoch cosine similarity, not just the last epoch's value.
print(f"  Best Cosine Similarity: {max(phase1_history['cosine_sim']):.4f}")

print("\n--- PHASE 2: Semantic Segmentation on LoveDA ---")
print(f"  Classes: {NUM_CLASSES} (Background ignored in evaluation)")
print(f"  Training Samples: {len(train_seg_dataset):,}")
print(f"  Validation Samples: {len(val_seg_dataset):,}")
# Number of epochs actually recorded in the training history
print(f"  Epochs: {len(history_seg_a['train_loss'])}")

print(f"\n{'─'*70}")
print(f"{'RESULTS':^70}")
print(f"{'─'*70}")
print(f"  {'Metric':<25} {'Model A':>12} {'Model B':>12} {'Delta':>10}")
print(f"  {'-'*60}")
print(f"  {'mIoU':<25} {miou_a_final:>12.4f} {miou_b_final:>12.4f} {miou_b_final-miou_a_final:>+10.4f}")
print(f"  {'Pixel Accuracy':<25} {acc_a_final:>11.2f}% {acc_b_final:>11.2f}% {acc_b_final-acc_a_final:>+9.2f}%")

print(f"\n{'─'*70}")
print(f"{'CROSS-TASK COMPARISON':^70}")
print(f"{'─'*70}")
print(f"  {'Task':<30} {'Model A':>10} {'Model B':>10} {'Delta':>10}")
print(f"  {'-'*62}")
# Classification numbers come from the earlier NWPU-RESISC45 experiment.
print(f"  {'Classification (Accuracy)':<30} {'35.78%':>10} {'40.38%':>10} {'+4.60%':>10}")
print(f"  {'Segmentation (mIoU)':<30} {miou_a_final:>10.4f} {miou_b_final:>10.4f} {miou_b_final-miou_a_final:>+10.4f}")
print(f"  {'Segmentation (Pixel Acc)':<30} {acc_a_final:>9.2f}% {acc_b_final:>9.2f}% {acc_b_final-acc_a_final:>+9.2f}%")

print(f"\n{'─'*70}")
print(f"KEY INSIGHTS:")
print(f"{'─'*70}")
if miou_b_final > miou_a_final:
    print(f"  * Semantic enhancement via spatial RemoteCLIP alignment improved")
    print(f"    segmentation by {miou_b_final-miou_a_final:+.4f} mIoU on the harder LoveDA task.")
    print(f"  * The spatial projection head successfully transfers semantic knowledge")
    print(f"    to per-pixel predictions, not just global image classification.")
    print(f"  * This validates that RemoteCLIP's semantic understanding benefits")
    print(f"    dense prediction tasks, confirming the generality of the approach.")
else:
    print(f"  * The baseline Prithvi model performed comparably or better on segmentation.")
    print(f"  * Spatial feature compression (1024→512) may discard spatial detail")
    print(f"    that is important for dense per-pixel prediction.")
    print(f"  * The classification improvement (+4.60%) did not fully transfer")
    print(f"    to the harder spatial reasoning task.")

print("\n" + "=" * 70)
print("EXPERIMENT COMPLETE!")
print("=" * 70)