# Tree-Over-Pavement Segmentation Experiments

This notebook trains and evaluates several models for detecting **tree canopy over pavement (TOP)** using high-resolution NAIP imagery.  

This notebook uses an existing dataset folder stored in [Google Drive](https://drive.google.com/drive/folders/1IohJ4p7b2hXmYYkassmejs-5YOzG1LSP?usp=drive_link). (To use it, add a shortcut to the folder in your own Google Drive: right-click the folder, select **Organize**, then click **Add shortcut**.)

We train and compare three models:

1. **Teacher UNet (6-band):**  
   Uses RGB, NIR, and two auxiliary masks (tree + paved) to produce high-quality predictions.

2. **Student UNet (NAIP-only) with Knowledge Distillation:**  
   Uses only RGB/NIR but learns from the teacher’s predictions to recover missing context.

3. **Baseline UNet (NAIP-only):**  
   Same architecture as the student but trained without teacher knowledge distillation.

The notebook also handles:
- class-imbalance weighting
- LR scheduling + early stopping
- full test metrics (mIoU, F1, per-class precision/recall)
- visual comparisons across all models

This establishes a controlled framework for studying how well NAIP-only models can recover the complex TOP signal compared to a fully informed 6-band teacher.


## 1. Environment Setup & Imports
   Install dependencies, mount Google Drive, and import all required libraries.

In [5]:
!pip -q install rasterio torchinfo torchmetrics segmentation_models_pytorch scikit-learn

import os
import glob
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchinfo import summary
import segmentation_models_pytorch as smp

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Global Configuration and Hyperparameters
Here we define dataset paths, class definitions, NAIP channels, and the global training hyperparameters shared by the teacher, student, and baseline models.  
These parameters control optimizer settings, learning rates, patience for early stopping, knowledge distillation strength, and loss weighting for class imbalance.


In [6]:
# GLOBAL Variables

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Paths
dataset_root = "/content/drive/MyDrive/PhD/Research/Research/TOP/dataset" # REPLACE with the path to your copy of the dataset folder in your Google Drive, or to a local dataset folder created by the preprocessing scripts

# Training sizes
# Upper bounds on epochs; the training loops may stop earlier via early stopping.
EPOCHS_TEACHER = 100
EPOCHS_STUDENT = 100
EPOCHS_BASELINE = 100

# Optimizers
# Initial AdamW learning rates (one per model) and the shared weight decay.
TEACHER_LR = 1e-3
STUDENT_LR = 1e-3
BASELINE_LR = 1e-3

WEIGHT_DECAY = 1e-4

# Early stopping / LR schedule controls
LR_MIN      = 1e-5     # minimum learning rate
LR_PATIENCE = 3        # epochs of no val improvement at LR_MIN before stop
IMPROVE_EPS = 1e-4     # minimal improvement threshold in mIoU

# Distillation parameters
# The student loss is KD_ALPHA * CE + (1 - KD_ALPHA) * KD.
KD_ALPHA = 0.8   # weight on CE
KD_T     = 2.0   # temperature for distillation

# Gradient clipping
# Maximum gradient norm passed to clip_grad_norm_ in the training loops.
GRAD_CLIP = 1.0

# Channels for NAIP-only model
# Indices into the 6-band stack selecting R, G, B, NIR (the auxiliary masks are dropped).
NAIP_CHANNELS = (0, 1, 2, 3)

# Classes
# Label values are 0 = background, 1 = TOP, 2 = TNP.
N_CLASSES = 3


Using device: cuda


## 3. Data Augmentation and Dataset Loading
We apply lightweight geometric augmentations (random flips and rotations) to increase model robustness.  
Augmentation is used **only for the training dataset**, helping the UNets generalize without altering the validation or test distributions.


In [7]:
class RandomFlipRotate:
    """Joint geometric augmentation for an image tile and its label mask.

    Each call independently applies a horizontal flip (p = 0.5), a vertical
    flip (p = 0.5) and a rotation by a random multiple of 90 degrees. The
    same transform is applied to the image and the mask so they stay aligned.
    """

    def __call__(self, img_t, mask_t):
        """Augment ``img_t`` of shape (C, H, W) and ``mask_t`` of shape (H, W).

        Returns the transformed ``(img_t, mask_t)`` pair.
        """
        # Negative dims address the trailing (H, W) axes of both the 3-D image
        # and the 2-D mask, so one set of dims serves both tensors.
        flip_horizontal = random.random() < 0.5
        flip_vertical = random.random() < 0.5
        quarter_turns = random.randint(0, 3)

        if flip_horizontal:
            img_t, mask_t = (torch.flip(t, dims=[-1]) for t in (img_t, mask_t))

        if flip_vertical:
            img_t, mask_t = (torch.flip(t, dims=[-2]) for t in (img_t, mask_t))

        if quarter_turns > 0:
            img_t, mask_t = (
                torch.rot90(t, quarter_turns, dims=[-2, -1])
                for t in (img_t, mask_t)
            )

        return img_t, mask_t


### Dataset Definition and Dataloaders
This section defines the PyTorch dataset class used to load 6-band input tiles and 3-class segmentation masks.  
We use the train/val/test splits created from the preprocessing scripts, and create DataLoaders for efficient batching during training and evaluation.


In [8]:
class NAIPSegDataset(Dataset):
    """Tile dataset pairing ``*_stack.tif`` images with ``*_label.tif`` masks.

    Image: 6-band stack (R, G, B, NIR, tree_binary, paved_binary).
    Label: single-band class map with values {0, 1, 2}.

    Files are expected under ``<split_root>/images`` and
    ``<split_root>/labels``; a label shares its image's base name.
    """

    def __init__(self, split_root, transform=None):
        """Index the image tiles of one split.

        Raises:
            RuntimeError: if the images folder holds no ``*_stack.tif`` file.
        """
        self.img_dir = os.path.join(split_root, "images")
        self.lab_dir = os.path.join(split_root, "labels")
        pattern = os.path.join(self.img_dir, "*_stack.tif")
        # Sorted so that indices map to the same tiles on every run.
        self.img_files = sorted(glob.glob(pattern))
        if not self.img_files:
            raise RuntimeError(f"No *_stack.tif files found in {self.img_dir}")
        self.transform = transform

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as a float32 (C, H, W) and an int64 (H, W) tensor."""
        img_path = self.img_files[idx]
        tile_id = os.path.basename(img_path).replace("_stack.tif", "")
        lab_path = os.path.join(self.lab_dir, f"{tile_id}_label.tif")

        # Read all bands of the stack: (C, H, W).
        with rasterio.open(img_path) as src:
            img = src.read().astype(np.float32)

        # Scale only the NAIP bands (0-3) by 1/255; the mask bands are untouched.
        img[:4] /= 255.0

        # Read the single label band: (H, W).
        with rasterio.open(lab_path) as src:
            mask = src.read(1).astype(np.int64)

        sample = (torch.from_numpy(img), torch.from_numpy(mask))
        if self.transform is not None:
            sample = self.transform(*sample)

        img_t, mask_t = sample
        return img_t, mask_t


# Create datasets & dataloaders
# Only the training split is augmented; val/test tiles are read unmodified.
train_ds = NAIPSegDataset(os.path.join(dataset_root, "train"), transform=RandomFlipRotate())
val_ds   = NAIPSegDataset(os.path.join(dataset_root, "val"),   transform=None)
test_ds  = NAIPSegDataset(os.path.join(dataset_root, "test"),  transform=None)


batch_size = 8
num_workers = 2

# Shuffle only the training loader; val/test keep a fixed iteration order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers)
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=num_workers)

print("Train:", len(train_ds), "Val:", len(val_ds), "Test:", len(test_ds))

# Sanity check: pull one batch and confirm the tensor shapes.
xb, yb = next(iter(train_dl))
print("Batch shapes:", xb.shape, yb.shape)  # (B,6,H,W) and (B,H,W)


Train: 1775 Val: 380 Test: 381
Batch shapes: torch.Size([8, 6, 224, 224]) torch.Size([8, 224, 224])


## 4. Class Weights for Imbalanced Labels
Due to the nature of the tree-over-pavement class, the dataset is highly imbalanced, meaning there are far fewer pixels belonging to that class.  
To mitigate this, we compute class weights and then slightly boost the TOP class’s weight in the model, since it is the target class.  
These weights are used by all UNet models during training.


In [9]:
def compute_class_weights(dataset, num_classes=N_CLASSES):
    """Derive normalized inverse-frequency class weights from a dataset's masks.

    Every sample of ``dataset`` is read once and the pixels of each class id in
    ``range(num_classes)`` are tallied; pixels with any other value are ignored.
    The weight of a class is ``total_pixels / class_pixels``, rescaled so that
    the weights average to 1.0.

    Prints the raw counts and the normalized weights, and returns the weights
    as a float32 tensor of shape ``(num_classes,)``.
    """
    pixel_counts = torch.zeros(num_classes, dtype=torch.float64)

    for sample_idx in range(len(dataset)):
        _, labels = dataset[sample_idx]   # labels: (H, W)
        flat = labels.reshape(-1)
        per_class = [int((flat == cls).sum()) for cls in range(num_classes)]
        pixel_counts += torch.tensor(per_class, dtype=torch.float64)

    n_pixels = pixel_counts.sum()

    # The small epsilon keeps the division finite for a class with no pixels.
    inv_freq = n_pixels / (pixel_counts + 1e-6)
    normalized = inv_freq / inv_freq.mean()

    print("Class pixel counts:", pixel_counts.tolist())
    print("Class weights (normalized):", normalized.tolist())
    return normalized.float()

w = compute_class_weights(train_ds, N_CLASSES)  # [w_bg, w_top, w_tnp]
# Bias toward TOP class
# Hand-tuned multipliers applied on top of the inverse-frequency weights.
w[0] = w[0] * 0.7    # BG
w[1] = w[1] * 1.6    # TOP
w[2] = w[2] * 1.1    # TNP (less boost than TOP)

# Re-normalize so the adjusted weights again average to 1.0.
w = w / w.mean()
# Kept on the training device so it can be passed directly to CrossEntropyLoss.
CLASS_WEIGHTS = w.to(device)

Class pixel counts: [70264054.0, 15329107.0, 3469239.0]
Class weights (normalized): [0.11611187596120413, 0.5322222046319419, 2.351665919406854]


## 5. Core Metrics - Confusion Matrix, IoU, and Accuracy
Segmentation performance is measured via a confusion matrix over all classes.  
From this we compute per-class precision, recall, F1, IoU, as well as mean IoU and overall accuracy.  
These metrics are used throughout training and for final test-set evaluation.


In [10]:
def update_confusion_matrix(cm, preds, targets, num_classes):
    """Accumulate a batch of predictions into a confusion matrix, in place.

    Args:
        cm: contiguous (C, C) integer torch tensor; rows are target classes,
            columns are predicted classes. It is modified in place.
        preds: integer tensor of predicted class ids, any shape.
        targets: integer tensor of ground-truth class ids, same shape as
            ``preds``. Entries outside ``[0, num_classes)`` are skipped.
        num_classes: number of classes C.

    Returns:
        The same ``cm`` tensor, for convenience.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, are accepted instead of raising.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    valid = (targets >= 0) & (targets < num_classes)
    # Encode each (target, pred) pair as one index into the flattened matrix.
    flat_index = num_classes * targets[valid] + preds[valid]

    # view (not reshape) guarantees cm_flat aliases cm, so += updates cm itself.
    cm_flat = cm.view(-1)
    cm_flat += torch.bincount(flat_index, minlength=num_classes**2).to(cm_flat.dtype)
    return cm

def iou_from_confusion(cm):
    """Compute per-class IoU and their mean from a (C, C) confusion matrix.

    Rows of ``cm`` are targets and columns are predictions. Returns
    ``(iou, miou)``: a (C,) tensor and a scalar tensor. A class with no
    target or predicted pixels gets IoU 0 and still counts toward the mean.
    """
    counts = cm.float()
    true_pos = torch.diag(counts)
    false_pos = counts.sum(dim=0) - true_pos   # column totals minus hits
    false_neg = counts.sum(dim=1) - true_pos   # row totals minus hits
    # Epsilon avoids 0/0 for classes that never occur.
    per_class = true_pos / (true_pos + false_pos + false_neg + 1e-6)
    return per_class, per_class.mean()

def accuracy_from_confusion(cm):
    """Return overall pixel accuracy (trace / total) of a (C, C) confusion matrix.

    The result is a scalar tensor; an all-zero matrix yields 0.
    """
    counts = cm.float()
    n_correct = torch.diag(counts).sum()
    # Epsilon keeps the ratio defined when the matrix is empty.
    return n_correct / (counts.sum() + 1e-6)


## 6. Model Training

### Teacher UNet Training (6 Band Input)
The teacher model receives all 6 input bands (RGB, NIR, tree mask, pavement mask).  
It is trained with weighted cross-entropy and acts as a supervision signal for the NAIP-only student model.  
We use ReduceLROnPlateau and an early-stopping rule that stops training once the learning rate has fully decayed and validation mIoU no longer improves.


In [11]:
def run_epoch_6band(model, dataloader, optimizer=None):
    """Run one pass over ``dataloader`` with the full 6-band input.

    When ``optimizer`` is given, the model is put in train mode and updated
    with class-weighted cross-entropy and gradient-norm clipping. When it is
    ``None``, the model is put in eval mode and the pass runs without
    gradient tracking.

    Returns:
        ``(avg_loss, accuracy, mean_iou, per_class_iou)`` where the first three
        are Python floats and ``per_class_iou`` is a numpy array of length
        ``N_CLASSES``. The loss is averaged over batches.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_fn = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)
    cm = torch.zeros((N_CLASSES, N_CLASSES), dtype=torch.int64, device="cpu")

    total_loss = 0.0
    batches = 0

    # Disable autograd for evaluation passes: no graph is needed, and building
    # one only costs memory and time.
    with torch.set_grad_enabled(is_train):
        for xb, yb in dataloader:
            xb = xb.to(device)
            yb = yb.to(device).long()

            if is_train:
                optimizer.zero_grad()

            logits = model(xb)
            loss = loss_fn(logits, yb)

            if is_train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP)
                optimizer.step()

            # The confusion matrix lives on the CPU, so move predictions there.
            preds = torch.argmax(logits.detach(), dim=1).cpu()
            targets = yb.detach().cpu()
            cm = update_confusion_matrix(cm, preds, targets, num_classes=N_CLASSES)

            total_loss += loss.item()
            batches += 1

    # max(..., 1) guards against an empty dataloader.
    avg_loss = total_loss / max(batches, 1)
    per_class_iou, miou = iou_from_confusion(cm)
    acc = accuracy_from_confusion(cm)

    return avg_loss, acc.item(), miou.item(), per_class_iou.cpu().numpy()


def train_teacher(model, train_dl, val_dl):
    """Train the 6-band teacher and restore its best validation checkpoint.

    Uses AdamW with ``ReduceLROnPlateau`` driven by validation mIoU. Training
    stops early once the learning rate has reached ``LR_MIN`` and validation
    mIoU has not improved by more than ``IMPROVE_EPS`` for ``LR_PATIENCE``
    consecutive epochs; otherwise it runs for ``EPOCHS_TEACHER`` epochs.

    Returns:
        ``(model, history, best_val_miou)``. ``model`` carries the weights of
        the best-validation epoch; ``history`` maps ``train_loss``,
        ``train_miou``, ``val_loss``, ``val_miou`` and ``lr`` to per-epoch lists.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=TEACHER_LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=2,
        min_lr=LR_MIN
    )

    history = {
        "train_loss": [],
        "train_miou": [],
        "val_loss": [],
        "val_miou": [],
        "lr": []
    }

    best_val_miou = -1.0
    best_state = None
    no_improve_epochs = 0

    for epoch in range(1, EPOCHS_TEACHER + 1):
        train_loss, train_acc, train_miou, _ = run_epoch_6band(model, train_dl, optimizer)
        val_loss, val_acc, val_miou, _ = run_epoch_6band(model, val_dl, optimizer=None)

        history["train_loss"].append(train_loss)
        history["train_miou"].append(train_miou)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        # Step LR scheduler on validation performance
        scheduler.step(val_miou)
        current_lr = optimizer.param_groups[0]["lr"]
        history["lr"].append(current_lr)

        # Check improvement
        if val_miou > best_val_miou + IMPROVE_EPS:
            best_val_miou = val_miou
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone to take a real snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        print(
            f"[Teacher] Epoch {epoch:02d} | "
            f"train_loss={train_loss:.4f}, train_mIoU={train_miou:.3f}, "
            f"val_loss={val_loss:.4f}, val_mIoU={val_miou:.3f}, "
            f"lr={current_lr:.2e}"
        )

        # Early stopping: LR hit minimum & no improvement for LR_PATIENCE epochs
        if current_lr <= LR_MIN + 1e-12 and no_improve_epochs >= LR_PATIENCE:
            print("Early stopping (Teacher): LR at min and val mIoU plateaued.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history, best_val_miou




## NAIP-Only Epoch Logic
This function defines a shared epoch loop for models that only use the 4 NAIP bands (RGB + NIR).  
The same structure is reused for both the KD-trained student model and the NAIP-only baseline.


In [12]:
def run_epoch_naip_only(model, dataloader, optimizer=None):
    """Run one pass over ``dataloader`` feeding the model only the NAIP bands.

    Batches arrive as 6-band stacks; the channels listed in ``NAIP_CHANNELS``
    are selected before the forward pass. When ``optimizer`` is given, the
    model is trained with class-weighted cross-entropy and gradient-norm
    clipping. When it is ``None``, the model is put in eval mode and the pass
    runs without gradient tracking.

    Returns:
        ``(avg_loss, accuracy, mean_iou, per_class_iou)`` where the first three
        are Python floats and ``per_class_iou`` is a numpy array of length
        ``N_CLASSES``. The loss is averaged over batches.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_fn = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)
    cm = torch.zeros((N_CLASSES, N_CLASSES), dtype=torch.int64, device="cpu")

    total_loss = 0.0
    batches = 0

    # Disable autograd for evaluation passes: no graph is needed, and building
    # one only costs memory and time.
    with torch.set_grad_enabled(is_train):
        for xb, yb in dataloader:
            xb = xb.to(device)
            yb = yb.to(device).long()

            # Drop the auxiliary mask bands; keep only the NAIP channels.
            xb_naip = xb[:, NAIP_CHANNELS, :, :]

            if is_train:
                optimizer.zero_grad()

            logits = model(xb_naip)
            loss = loss_fn(logits, yb)

            if is_train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP)
                optimizer.step()

            # The confusion matrix lives on the CPU, so move predictions there.
            preds = torch.argmax(logits.detach(), dim=1).cpu()
            targets = yb.detach().cpu()
            cm = update_confusion_matrix(cm, preds, targets, num_classes=N_CLASSES)

            total_loss += loss.item()
            batches += 1

    # max(..., 1) guards against an empty dataloader.
    avg_loss = total_loss / max(batches, 1)
    per_class_iou, miou = iou_from_confusion(cm)
    acc = accuracy_from_confusion(cm)

    return avg_loss, acc.item(), miou.item(), per_class_iou.cpu().numpy()


## Student UNet Training with Knowledge Distillation
The student receives only NAIP bands, so we guide it using results from the frozen teacher.  
Training combines:
- Weighted cross-entropy on ground-truth labels, and  
- KL-divergence to the teacher’s temperature-softened class probabilities (temperature T; α weights the cross-entropy term and 1 − α weights this term).  

This helps compensate for weaker inputs by transferring structure learned by the teacher.


In [13]:
def train_student_with_distillation(student, teacher, train_dl, val_dl):
    """Train the NAIP-only student against labels and a frozen teacher.

    The per-batch training loss is
    ``KD_ALPHA * CE(student, labels) + (1 - KD_ALPHA) * KD`` where ``KD`` is
    the per-pixel KL divergence between the teacher's and the student's
    softened (temperature ``KD_T``) class distributions, scaled by ``KD_T**2``.
    The teacher sees all 6 bands; the student sees only ``NAIP_CHANNELS``.

    Side effect: ``teacher`` is switched to eval mode and all of its parameters
    get ``requires_grad = False``; this is not undone on return.

    Note that ``history["train_loss"]`` is the combined CE + KD loss, while
    ``history["val_loss"]`` is the weighted CE returned by
    ``run_epoch_naip_only``, so the two curves are not on the same scale.

    Returns:
        ``(student, history, best_val_miou)`` with ``student`` carrying the
        weights of its best-validation epoch.
    """
    # freeze teacher
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False

    optimizer = torch.optim.AdamW(student.parameters(), lr=STUDENT_LR, weight_decay=WEIGHT_DECAY)

    ce_loss_fn = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=2,
        min_lr=LR_MIN
    )

    history = {
        "train_loss": [],
        "train_miou": [],
        "val_loss": [],
        "val_miou": [],
        "lr": []
    }

    best_val_miou = -1.0
    best_state = None
    no_improve_epochs = 0

    for epoch in range(1, EPOCHS_STUDENT + 1):
        student.train()
        total_train_loss = 0.0
        batches = 0
        cm_train = torch.zeros((N_CLASSES, N_CLASSES), dtype=torch.int64)

        for xb, yb in train_dl:
            xb = xb.to(device)
            yb = yb.to(device).long()

            xb_full = xb
            xb_naip = xb[:, NAIP_CHANNELS, :, :]

            with torch.no_grad():
                teacher_logits = teacher(xb_full)

            optimizer.zero_grad()

            student_logits = student(xb_naip)

            loss_ce = ce_loss_fn(student_logits, yb)

            p_t = F.softmax(teacher_logits / KD_T, dim=1)
            log_p_s = F.log_softmax(student_logits / KD_T, dim=1)
            # True per-pixel KL: sum over the class axis, then average over
            # pixels. (reduction="mean" would also divide by the class count.)
            kl_per_pixel = F.kl_div(log_p_s, p_t, reduction="none").sum(dim=1)
            # T^2 compensates for the 1/T^2 gradient scaling that softened
            # logits introduce, keeping the KD term comparable to the CE term.
            loss_kd = kl_per_pixel.mean() * (KD_T ** 2)

            loss = KD_ALPHA * loss_ce + (1 - KD_ALPHA) * loss_kd
            loss.backward()

            torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=GRAD_CLIP)
            optimizer.step()

            preds = torch.argmax(student_logits.detach(), dim=1).cpu()
            targets = yb.detach().cpu()
            cm_train = update_confusion_matrix(cm_train, preds, targets, N_CLASSES)

            total_train_loss += loss.item()
            batches += 1

        train_loss = total_train_loss / max(batches, 1)
        _, train_miou = iou_from_confusion(cm_train)
        train_miou = train_miou.item()

        val_loss, val_acc, val_miou, _ = run_epoch_naip_only(student, val_dl, optimizer=None)

        history["train_loss"].append(train_loss)
        history["train_miou"].append(train_miou)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        scheduler.step(val_miou)
        current_lr = optimizer.param_groups[0]["lr"]
        history["lr"].append(current_lr)

        if val_miou > best_val_miou + IMPROVE_EPS:
            best_val_miou = val_miou
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone to take a real snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in student.state_dict().items()
            }
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        print(
            f"[Student] Epoch {epoch:02d} | "
            f"train_loss={train_loss:.3f}, train_mIoU={train_miou:.3f}, "
            f"val_loss={val_loss:.3f}, val_mIoU={val_miou:.3f}, "
            f"lr={current_lr:.2e}"
        )

        # Stop only once the LR has bottomed out and val mIoU has plateaued.
        if current_lr <= LR_MIN + 1e-12 and no_improve_epochs >= LR_PATIENCE:
            print("Early stopping (Student): LR at min and val mIoU plateaued.")
            break

    if best_state is not None:
        student.load_state_dict(best_state)

    return student, history, best_val_miou


## Training the Teacher and Student Models
Here we instantiate the teacher (6-band) and student (4-band) UNets, then train them using the previously defined loops.  
The teacher is trained first and remains frozen while supervising the student.  
Both use the same architecture (ResNet-34 backbone) for a fair comparison.

In [None]:
# Teacher: 6-band UNet with a ResNet-34 encoder, trained from scratch
# (encoder_weights=None means no pretrained weights are loaded).
teacher_unet_6ch = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=6,
    classes=N_CLASSES
).to(device)

# train_teacher returns the model, the per-epoch history, and the best val mIoU.
teacher_unet_6ch, teacher_hist, teacher_best = train_teacher(
    teacher_unet_6ch, train_dl, val_dl
)

In [None]:
# Student: NAIP-only UNet (4 bands)
# Same ResNet-34 UNet as the teacher apart from the number of input channels.
student_unet_naip = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=len(NAIP_CHANNELS),  # 4 bands
    classes=N_CLASSES,
).to(device)

# The already-trained teacher is passed in and frozen inside the training function.
student_unet_naip, student_hist, student_best = train_student_with_distillation(
    student_unet_naip, teacher_unet_6ch, train_dl, val_dl
)


### Baseline UNet (NAIP-Only, No Distillation)

To isolate the benefit of distillation, we train a NAIP-only UNet with the same architecture as the student but **without** any teacher supervision.  


In [None]:
def train_naip_baseline(model, train_dl, val_dl):
    """Train the NAIP-only baseline UNet (4 bands, no teacher / no KD).

    Uses: BASELINE_LR, WEIGHT_DECAY, EPOCHS_BASELINE, LR_MIN, LR_PATIENCE,
    IMPROVE_EPS. Both the training and the validation pass are delegated to
    ``run_epoch_naip_only``, which applies the class-weighted cross-entropy,
    gradient clipping and NAIP channel selection.

    Training stops early once the learning rate has reached ``LR_MIN`` and
    validation mIoU has not improved by more than ``IMPROVE_EPS`` for
    ``LR_PATIENCE`` consecutive epochs.

    Returns:
        ``(model, history, best_val_miou)`` with ``model`` carrying the weights
        of its best-validation epoch and ``history`` mapping ``train_loss``,
        ``train_miou``, ``val_loss``, ``val_miou`` and ``lr`` to per-epoch lists.
    """

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=BASELINE_LR,
        weight_decay=WEIGHT_DECAY
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=2,
        min_lr=LR_MIN
    )

    history = {
        "train_loss": [],
        "train_miou": [],
        "val_loss": [],
        "val_miou": [],
        "lr": []
    }

    best_val_miou = -1.0
    best_state = None
    no_improve_epochs = 0

    for epoch in range(1, EPOCHS_BASELINE + 1):
        # ---- TRAIN ----
        # Passing the optimizer switches run_epoch_naip_only into training mode.
        train_loss, train_acc, train_miou, _ = run_epoch_naip_only(model, train_dl, optimizer)

        # ---- VALIDATION ----
        val_loss, val_acc, val_miou, _ = run_epoch_naip_only(model, val_dl, optimizer=None)

        history["train_loss"].append(train_loss)
        history["train_miou"].append(train_miou)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        scheduler.step(val_miou)
        current_lr = optimizer.param_groups[0]["lr"]
        history["lr"].append(current_lr)

        if val_miou > best_val_miou + IMPROVE_EPS:
            best_val_miou = val_miou
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone to take a real snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        print(
            f"[Baseline NAIP] Epoch {epoch:02d} | "
            f"train_loss={train_loss:.3f}, train_mIoU={train_miou:.3f}, "
            f"val_loss={val_loss:.3f}, val_mIoU={val_miou:.3f}, "
            f"lr={current_lr:.2e}"
        )

        # Stop only once the LR has bottomed out and val mIoU has plateaued.
        if current_lr <= LR_MIN + 1e-12 and no_improve_epochs >= LR_PATIENCE:
            print("Early stopping (Baseline): LR at min and val mIoU plateaued.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history, best_val_miou


In [None]:
# BASELINE: NAIP-only UNet model
# Built with the same arguments as the student so that the only difference
# between the two runs is the distillation term in the training loss.

baseline_unet_naip = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,          # NO pretrained weights
    in_channels=len(NAIP_CHANNELS),# 4 NAIP bands
    classes=N_CLASSES
).to(device)

print("Baseline UNet initialized.")


# Returns the model, the per-epoch history, and the best validation mIoU.
baseline_unet_naip, baseline_hist, baseline_best_val_miou = train_naip_baseline(
    baseline_unet_naip,
    train_dl,
    val_dl
)

print("Best baseline val mIoU:", baseline_best_val_miou)


## 7. Training and Model Results

### Training Diagnostics: Curves for Loss, mIoU, and Learning Rate

To understand model behavior over time, we define a helper to plot:

- Training and validation loss per epoch
- Training and validation mIoU per epoch
- Learning rate schedule per epoch  

We then apply this to the teacher, student, and baseline histories to inspect convergence and over/under-fitting.


In [None]:
def plot_training_curves(name, history):
    """Show training diagnostics for one model.

    Draws a two-panel figure (loss and mIoU, train vs. validation) and, when
    ``history`` has an ``'lr'`` entry, a second figure with the learning rate
    per epoch.

    ``history`` must provide per-epoch lists under 'train_loss', 'val_loss',
    'train_miou' and 'val_miou'; ``name`` is used as the title prefix.
    """
    n_epochs = len(history["train_loss"])
    epochs = np.arange(1, n_epochs + 1)

    # (train key, val key, train legend, val legend, y label, panel title)
    panels = (
        ("train_loss", "val_loss", "Train loss", "Val loss", "Loss", f"{name} – Loss"),
        ("train_miou", "val_miou", "Train mIoU", "Val mIoU", "mIoU", f"{name} – mIoU"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (train_key, val_key, train_lbl, val_lbl, y_lbl, panel_title) in zip(axes, panels):
        ax.plot(epochs, history[train_key], label=train_lbl)
        ax.plot(epochs, history[val_key], label=val_lbl)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_lbl)
        ax.set_title(panel_title)
        ax.legend()

    plt.tight_layout()
    plt.show()

    # The learning-rate figure is optional.
    if "lr" not in history:
        return

    plt.figure(figsize=(6, 4))
    plt.plot(epochs, history["lr"])
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.title(f"{name} – LR schedule")
    plt.show()


# --- Plot the diagnostics of every trained model, in training order ---
for _curve_title, _curve_history in (
    ("Teacher UNet (6-band)", teacher_hist),
    ("Student UNet (KD, NAIP-only)", student_hist),
    ("Baseline UNet (NAIP-only)", baseline_hist),
):
    plot_training_curves(_curve_title, _curve_history)


### Metrics from the Confusion Matrix

Beyond basic IoU and accuracy, we compute richer metrics from the confusion matrix:

- Per-class precision, recall, and F1  
- Per-class IoU  
- Overall accuracy, mean F1, and mean IoU  

These provide a more detailed view of how each model performs on background, TOP, and TNP classes.


In [None]:
def metrics_from_confusion(cm):
    """Derive classification metrics from a (C, C) confusion matrix.

    Rows of ``cm`` are target classes and columns are predicted classes.

    Returns, in this order:
        precision_per_class  (numpy array, length C)
        recall_per_class     (numpy array, length C)
        f1_per_class         (numpy array, length C)
        iou_per_class        (numpy array, length C)
        overall_accuracy     (float)
        mean_f1              (float, unweighted mean over classes)
        mean_iou             (float, unweighted mean over classes)

    A small epsilon in every denominator makes the metrics of a class that
    never occurs come out as 0 instead of NaN.
    """
    cm = cm.float()
    tp = torch.diag(cm)                     # (C,)
    fp = cm.sum(dim=0) - tp                 # predicted positive but wrong
    fn = cm.sum(dim=1) - tp                 # missed positive

    precision = tp / (tp + fp + 1e-6)
    recall    = tp / (tp + fn + 1e-6)
    f1        = 2 * precision * recall / (precision + recall + 1e-6)

    # IoU per class
    iou = tp / (tp + fp + fn + 1e-6)

    # mean metrics
    mean_f1  = f1.mean()
    mean_iou = iou.mean()

    # overall accuracy
    acc = (tp.sum() / (cm.sum() + 1e-6))

    return (
        precision.cpu().numpy(),
        recall.cpu().numpy(),
        f1.cpu().numpy(),
        iou.cpu().numpy(),
        acc.item(),
        mean_f1.item(),
        mean_iou.item()
    )


### Model Evaluation Helpers

Here we define two evaluation functions:

- `eval_teacher_6band_allmetrics` for the 6-band teacher, and  
- `eval_naip_4band_allmetrics` for the NAIP-only models.

Each runs through the test set, generates a confusion matrix, and computes a cross-entropy loss.


In [None]:
# ---------- EVAL FUNCTIONS ----------

def _eval_allmetrics(model, data_dl, channels=None):
    """Shared evaluation loop behind the 6-band and NAIP-only entry points.

    Puts ``model`` in eval mode, runs it over ``data_dl`` without gradients,
    accumulates a confusion matrix and an unweighted cross-entropy loss
    averaged over batches. When ``channels`` is given, only those input
    channels are fed to the model.
    """
    model.eval()
    # Unweighted CE: differs from the class-weighted loss used during training.
    ce = nn.CrossEntropyLoss()

    total_loss = 0.0
    batches = 0
    cm = torch.zeros((N_CLASSES, N_CLASSES), dtype=torch.int64)

    with torch.no_grad():
        for xb, yb in data_dl:
            xb = xb.to(device)
            yb = yb.to(device).long()

            if channels is not None:
                xb = xb[:, channels, :, :]

            logits = model(xb)
            loss   = ce(logits, yb)

            preds   = torch.argmax(logits, dim=1).cpu()
            targets = yb.cpu()
            cm      = update_confusion_matrix(cm, preds, targets, N_CLASSES)

            total_loss += loss.item()
            batches += 1

    # max(..., 1) guards against an empty dataloader.
    avg_loss = total_loss / max(batches, 1)

    (precision, recall, f1, iou, acc, mean_f1, mean_iou) = metrics_from_confusion(cm)

    return avg_loss, acc, mean_iou, mean_f1, precision, recall, f1, iou


def eval_teacher_6band_allmetrics(model, data_dl):
    """Evaluate a 6-band model on ``data_dl``.

    Returns ``(avg_loss, acc, mean_iou, mean_f1, precision, recall, f1, iou)``
    where the last four are per-class numpy arrays.
    """
    return _eval_allmetrics(model, data_dl, channels=None)


def eval_naip_4band_allmetrics(model, data_dl):
    """Evaluate a NAIP-only model on ``data_dl``, feeding it ``NAIP_CHANNELS``.

    Returns ``(avg_loss, acc, mean_iou, mean_f1, precision, recall, f1, iou)``
    where the last four are per-class numpy arrays.
    """
    return _eval_allmetrics(model, data_dl, channels=NAIP_CHANNELS)

# ---------- RUN EVAL ON TEST SET ----------

# Teacher (all 6 bands)
t_loss, t_acc, t_miou, t_mf1, t_prec, t_rec, t_f1, t_iou = eval_teacher_6band_allmetrics(
    teacher_unet_6ch, test_dl
)

# Student (KD, NAIP bands only)
s_loss, s_acc, s_miou, s_mf1, s_prec, s_rec, s_f1, s_iou = eval_naip_4band_allmetrics(
    student_unet_naip, test_dl
)

# Baseline (NAIP bands only, no KD)
b_loss, b_acc, b_miou, b_mf1, b_prec, b_rec, b_f1, b_iou = eval_naip_4band_allmetrics(
    baseline_unet_naip, test_dl
)


# ---------- BUILD TABLE ----------

def _summary_row(model_name, n_bands, loss, acc, miou, mf1, prec, rec, f1):
    """Build one results-table row: global metrics first, then per-class P/R/F1."""
    row = {
        "Model": model_name,
        "Bands": n_bands,
        "Test Loss": loss,
        "Accuracy": acc,
        "mIoU": miou,
        "mF1": mf1,
    }
    # Class index -> column suffix: 0 = BG, 1 = TOP, 2 = TNP.
    for class_idx, tag in enumerate(("BG", "TOP", "TNP")):
        row[f"Precision_{tag}"] = prec[class_idx]
        row[f"Recall_{tag}"] = rec[class_idx]
        row[f"F1_{tag}"] = f1[class_idx]
    return row


rows = [
    _summary_row("Teacher UNet (6-band)", "6",
                 t_loss, t_acc, t_miou, t_mf1, t_prec, t_rec, t_f1),
    _summary_row("Student UNet (KD, NAIP-only)", "4",
                 s_loss, s_acc, s_miou, s_mf1, s_prec, s_rec, s_f1),
    _summary_row("Baseline UNet (NAIP-only)", "4",
                 b_loss, b_acc, b_miou, b_mf1, b_prec, b_rec, b_f1),
]

df = pd.DataFrame(rows)
df_rounded = df.round(3)
# Last expression of the cell: the notebook renders the rounded table.
df_rounded



### Qualitative Comparison: Student vs Baseline

Finally, we visualize a few random test tiles to compare:

- NAIP RGB input  
- Ground-truth labels  
- Student (KD NAIP-only) predictions  
- Baseline (NAIP-only) predictions  

These side-by-side views highlight where knowledge distillation helps, where both models struggle, and how their errors differ in practice.


In [None]:
# colormap: 0 = background (black), 1 = TOP (yellow), 2 = TNP (green)
label_colors = ["black", "yellow", "green"]
label_cmap = ListedColormap(label_colors)
# Bin edges at half-integers so each integer class id falls in its own color bin.
label_norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], label_cmap.N)

def unnormalize_rgb_from_6band(img_tensor, rgb_indices=(0,1,2)):
    """Convert a normalized image tensor into a displayable uint8 RGB array.

    ``img_tensor`` is a (C, H, W) tensor whose selected bands are scaled to
    0-1. The bands at ``rgb_indices`` are scaled back to 0-255, clipped to
    that range, cast to uint8 and returned channel-last as (H, W, 3).
    """
    bands = img_tensor.cpu().numpy()[list(rgb_indices)]     # (3, H, W), 0-1
    scaled = np.clip(bands * 255.0, 0, 255)
    as_bytes = scaled.astype(np.uint8)
    return as_bytes.transpose(1, 2, 0)                      # (H, W, 3)

In [None]:
def visualize_student_baseline(
    student_model,
    baseline_model,
    dataset,
    n_samples=5,
    rgb_indices=(0,1,2),
    naip_channels=NAIP_CHANNELS,
    title="Student vs Baseline"
):
    """Plot student and baseline predictions next to the input and the label.

    For ``n_samples`` random tiles from ``dataset`` (capped at its length),
    one row is drawn per tile:
      [ NAIP RGB | Label | Student pred | Baseline pred ]

    Args:
        student_model, baseline_model: models taking the NAIP-only input;
            both are switched to eval mode.
        dataset: indexable dataset yielding ``(img_t, mask_t)`` with
            ``img_t`` of shape (6, H, W) and ``mask_t`` of shape (H, W).
        n_samples: number of tiles to show.
        rgb_indices: channels of ``img_t`` used for the RGB panel.
        naip_channels: channels of ``img_t`` fed to both models.
        title: figure title.
    """
    student_model.eval()
    baseline_model.eval()

    n_samples = min(n_samples, len(dataset))
    idxs = np.random.choice(len(dataset), size=n_samples, replace=False)

    # Exactly four panels per row; squeeze=False keeps `axes` two-dimensional
    # even when a single row is drawn.
    n_cols = 4
    fig, axes = plt.subplots(
        n_samples, n_cols, figsize=(4 * n_cols, 4 * n_samples), squeeze=False
    )

    with torch.no_grad():
        for row, idx in enumerate(idxs):
            img_t, mask_t = dataset[idx]  # img_t: (6,H,W), mask_t: (H,W)

            # NAIP-only tensor for the CNNs
            inp_naip = img_t[list(naip_channels), :, :].unsqueeze(0).to(device)  # (1,4,H,W)

            # Student prediction
            logits_student = student_model(inp_naip)
            pred_student = torch.argmax(logits_student, dim=1).squeeze(0).cpu().numpy()

            # Baseline prediction
            logits_baseline = baseline_model(inp_naip)
            pred_baseline = torch.argmax(logits_baseline, dim=1).squeeze(0).cpu().numpy()

            # RGB + label
            rgb_img = unnormalize_rgb_from_6band(img_t, rgb_indices=rgb_indices)
            mask_np = mask_t.cpu().numpy()

            ax0, ax1, ax2, ax3 = axes[row]

            # Column 1: NAIP RGB
            ax0.imshow(rgb_img)
            ax0.set_title(f"NAIP RGB (tile {idx})")
            ax0.axis("off")

            # Column 2: Ground truth label
            ax1.imshow(mask_np, cmap=label_cmap, norm=label_norm)
            ax1.set_title("Label (0=bg, 1=TOP, 2=TNP)")
            ax1.axis("off")

            # Column 3: Student prediction
            ax2.imshow(pred_student, cmap=label_cmap, norm=label_norm)
            ax2.set_title("Student UNet (NAIP+KD)")
            ax2.axis("off")

            # Column 4: Baseline prediction
            ax3.imshow(pred_baseline, cmap=label_cmap, norm=label_norm)
            ax3.set_title("Baseline UNet (NAIP-only)")
            ax3.axis("off")

    plt.suptitle(title, y=0.99, fontsize=16)
    # Leave headroom at the top of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()


# Call it on your test set:
# Both models are passed explicitly; the dataset is the third argument.
visualize_student_baseline(
    student_unet_naip,
    baseline_unet_naip,
    test_ds,
    n_samples=5,
    rgb_indices=(0,1,2),
    naip_channels=NAIP_CHANNELS,
    title="Student vs Baseline – Test tiles"
)
