In [None]:
# ======================
# IMPORTS & CONFIGURATION
# ======================

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

# This notebook is in src/, so the project root is the parent directory.
# Resolve it once to an absolute path: every path derived from ROOT then stays
# valid even if the working directory changes later in the session, and printed
# paths are unambiguous when diagnosing missing files.
ROOT = Path("..").resolve()

# Local imports
from dataset import SatelliteSegDataset
from unet import UNet
from metrics import confusion_matrix, iou_from_cm

# Training Metrics Visualization Notebook

This notebook evaluates a trained UNet checkpoint on one dataset split (`train` or `test`, chosen via `SPLIT`):
- Compute confusion matrices
- Calculate per-class metrics (IoU, Dice, Precision, Recall, F1)
- Visualize confusion matrices and per-class performance
- Compare predictions with ground truth labels

In [None]:
# ======================
# CONFIGURATION
# ======================

# --- What to evaluate ---
CKPT_PATH = ROOT.joinpath("outputs", "best_unet.pth")  # trained model weights
SPLIT = "test"  # dataset split to evaluate: "train" or "test"

# --- DataLoader ---
BATCH_SIZE = 32    # modest batch size suited to CPU inference
NUM_WORKERS = 6    # background loader workers (set to 0 on Windows if workers cause issues)

# --- Labels ---
NUM_CLASSES = 10
IGNORE_INDEX = 255         # target value skipped when building the confusion matrix
IGNORE_LABELS = (0, 1)     # classes to ignore in metrics (passed to the dataset as ignore_labels)

# Human-readable class names, keyed by class index 0..9
CLASS_NAMES = dict(enumerate((
    "no_data",
    "clouds",
    "artificial",
    "cultivated",
    "broadleaf",
    "coniferous",
    "herbaceous",
    "natural_soil",
    "permanent_snow",
    "water",
)))

## Dataset Setup

Build the evaluation dataset and DataLoader for the selected split (the model itself is loaded in the next section).

In [None]:
# Resolve the split directories once so the existence check and the dataset
# are guaranteed to use the same paths.
images_dir = ROOT / "dataset" / SPLIT / "images"
masks_dir = ROOT / "dataset" / SPLIT / "masks"

# Fail fast with an explicit path in the message, before the dataset tries to
# read from a directory that is not there.
for required_dir in (images_dir, masks_dir):
    if not required_dir.is_dir():
        raise FileNotFoundError(f"Expected directory not found: {required_dir}")

# Create dataset and dataloader
ds = SatelliteSegDataset(
    images_dir=images_dir,
    masks_dir=masks_dir,
    ignore_labels=IGNORE_LABELS,
    ignore_index=IGNORE_INDEX,
)

# An empty split would silently produce an all-zero confusion matrix later on.
if len(ds) == 0:
    raise ValueError(f"No samples found for split '{SPLIT}' in {images_dir}")

# shuffle=False keeps evaluation order deterministic.
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

print("======================")
print("Dataset configuration:")
print("======================")
print(f"Split: {SPLIT}")
print(f"Num samples: {len(ds)}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Checkpoint: {CKPT_PATH}")
# Both directories were verified above, so the check marks are truthful here.
print(f"✓ Images dir: {images_dir}")
print(f"✓ Masks dir: {masks_dir}")

## Model Loading

Load trained UNet checkpoint.

In [None]:
# ======================
# LOAD MODEL
# ======================

device = torch.device("cpu")

# Architecture settings (4 input channels, base width 32) have to agree with the
# checkpoint, otherwise load_state_dict below rejects the weights.
model = UNet(in_channels=4, num_classes=NUM_CLASSES, base_channels=32)
model = model.to(device)

# Map tensors straight onto the evaluation device while loading the state dict.
state = torch.load(CKPT_PATH, map_location=device)
model.load_state_dict(state)

# Switch layers such as dropout / batch-norm to inference behaviour.
model.eval()

print("✓ Model loaded successfully")

## Metrics Calculation

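Compute per-class metrics from the confusion matrix (rows = ground truth, columns = predictions). With $TP$, $FP$, $FN$ counted per class: $\text{IoU} = \frac{TP}{TP+FP+FN}$, $\text{Dice} = \frac{2TP}{2TP+FP+FN}$, $\text{Precision} = \frac{TP}{TP+FP}$, $\text{Recall} = \frac{TP}{TP+FN}$, and F1 is the harmonic mean of precision and recall (numerically equal to Dice).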

In [None]:
def metrics_from_cm(cm: torch.Tensor) -> dict:
    """
    Compute segmentation metrics from a confusion matrix.

    Parameters:
      - cm: (C, C) confusion matrix of pixel counts; rows are ground-truth
        classes and columns are predicted classes.

    Returns:
      - Dictionary of tensors:
        - "iou": per-class Intersection over Union, TP / (TP + FP + FN)
        - "miou": mean IoU over classes present in ground truth or predictions
        - "acc_global": global pixel accuracy, trace(cm) / cm.sum()
        - "precision", "recall", "f1": per-class metrics
        - "dice": per-class Dice coefficient, 2TP / (2TP + FP + FN)
        - "support": ground-truth pixel count per class (row sums)
      Any metric whose denominator is zero is reported as 0 rather than NaN.
    """
    # float64 holds integer counts exactly up to 2**53 (float32 only up to 2**24),
    # so large pixel totals keep full precision before the ratios are taken.
    cm = cm.to(torch.float64)

    def safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
        # Element-wise num / den, with 0 wherever den == 0 (the NaN from 0/0 is discarded).
        return torch.where(den > 0, num / den, torch.zeros_like(num))

    tp = torch.diag(cm)            # True positives
    fp = cm.sum(0) - tp            # predicted as the class, but GT differs
    fn = cm.sum(1) - tp            # GT is the class, but predicted otherwise

    # -------- IoU --------
    denom_iou = tp + fp + fn
    iou = safe_div(tp, denom_iou)
    # Classes absent from both GT and predictions are excluded from the mean.
    present = denom_iou > 0
    miou = iou[present].mean() if present.any() else iou.new_zeros(())

    # -------- Global Accuracy --------
    acc_global = safe_div(tp.sum(), cm.sum())

    # -------- Precision, Recall, F1 --------
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = safe_div(2 * precision * recall, precision + recall)

    # -------- Dice --------
    dice = safe_div(2 * tp, 2 * tp + fp + fn)

    # -------- Support --------
    support = cm.sum(1)  # Ground truth pixels per class

    return {
        "iou": iou,
        "miou": miou,
        "acc_global": acc_global,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "dice": dice,
        "support": support,
    }

## Compute Confusion Matrix

Run inference on entire dataset to build confusion matrix.

In [None]:
# ======================
# COMPUTE CONFUSION MATRIX
# ======================

cm_total = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.int64)
n_seen = 0  # samples actually processed; the last batch can be smaller than BATCH_SIZE

print(f"Computing confusion matrix on {SPLIT} split...")
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        logits = model(x)                           # (B, C, H, W)
        pred = torch.argmax(logits, dim=1)          # (B, H, W)

        cm_total += confusion_matrix(
            pred, y,
            num_classes=NUM_CLASSES,
            ignore_index=IGNORE_INDEX
        )
        n_seen += x.shape[0]

        if (batch_idx + 1) % 10 == 0:
            print(f"  Processed {n_seen} samples...")

# The loader uses shuffle=False and the default drop_last=False, so every sample
# should have been visited exactly once.
assert n_seen == len(ds), f"processed {n_seen} samples, dataset has {len(ds)}"

# An all-zero matrix presumably means every pixel was ignored; all metrics below would be 0.
if int(cm_total.sum()) == 0:
    print("⚠ Confusion matrix is empty: no non-ignored pixels were counted")

print(f"✓ Confusion matrix computed: {cm_total.shape}")

## Per-Class Metrics Report

Display metrics for all classes in a table.

In [None]:
# ======================
# CALCULATE METRICS
# ======================

m = metrics_from_cm(cm_total)

# Column-oriented construction: one list per column, one entry per class index.
class_ids = range(NUM_CLASSES)
df = pd.DataFrame({
    "class": list(class_ids),
    "name": [CLASS_NAMES.get(c, f"class_{c}") for c in class_ids],
    "support_pixels": [int(m["support"][c].item()) for c in class_ids],
    "IoU": [float(m["iou"][c].item()) for c in class_ids],
    "Dice": [float(m["dice"][c].item()) for c in class_ids],
    "Precision": [float(m["precision"][c].item()) for c in class_ids],
    "Recall": [float(m["recall"][c].item()) for c in class_ids],
    "F1": [float(m["f1"][c].item()) for c in class_ids],
})

# Global summary first, then the per-class table ordered by GT pixel count.
banner = "======================"
pixel_acc = float(m["acc_global"].item())
mean_iou = float(m["miou"].item())

print(banner)
print("GLOBAL METRICS")
print(banner)
print(f"Pixel Accuracy: {pixel_acc:.4f}")
print(f"Mean IoU (mIoU): {mean_iou:.4f}")

print("\nPer-class metrics (sorted by support):")
by_support = df.sort_values("support_pixels", ascending=False)
print(by_support.to_string(index=False))

## Filtered Metrics

Recompute mIoU excluding the ignored labels 0 (`no_data`) and 1 (`clouds`), averaging only over classes that have ground-truth pixels in this split.

In [None]:
# ======================
# FILTERED METRICS (exclude ignored classes)
# ======================

# Reuse the configured ignore set so this report can never disagree with the
# labels the dataset was told to ignore.
ignore_report = set(IGNORE_LABELS)

valid_classes = [c for c in range(NUM_CLASSES) if c not in ignore_report]
iou_vals = m["iou"][valid_classes]
support_vals = m["support"][valid_classes]

# Average only over classes that actually have ground-truth pixels in this split;
# a class that is merely predicted (support 0) would otherwise contribute IoU 0.
valid_mask = support_vals > 0
miou_no_ignored = iou_vals[valid_mask].mean() if valid_mask.any() else torch.tensor(0.0)

print("======================")
print("FILTERED METRICS")
print("======================")
print(f"mIoU (excluding classes {ignore_report}, only present classes):")
print(f"  {float(miou_no_ignored.item()):.4f}")

## Confusion Matrix Visualization

Display confusion matrices with different normalization views.

Plotting helper: renders the confusion matrix as an annotated heatmap, styled like the slide version.

In [None]:
def _cm_display_view(cm, class_names, normalize=None, max_classes=None):
    """Pick the classes to display and normalize the matrix; returns (matrix, names, fmt, subtitle)."""
    if hasattr(cm, "cpu"):
        cm = cm.cpu().numpy()
    cm = np.asarray(cm, dtype=np.float64)

    n_classes = cm.shape[0]
    if max_classes is not None and max_classes < n_classes:
        # Top-N classes by ground-truth support (row sums); stable sort keeps ties in index order.
        idx = np.argsort(-cm.sum(axis=1), kind="stable")[:max_classes]
        names = [class_names[i] for i in idx]
    else:
        idx = np.arange(n_classes)
        names = list(class_names)
    shown = cm[np.ix_(idx, idx)]

    if normalize is None:
        return shown, names, "{:.0f}", " (counts)"
    # Denominators come from the FULL matrix: when only the top classes are shown,
    # pixels predicted as (or belonging to) hidden classes still count toward the
    # totals, so the values remain true recall / precision.
    if normalize == "recall":
        denom = cm.sum(axis=1, keepdims=True)[idx, :]   # GT totals per displayed row
        subtitle = " (normalized by GT: recall view)"
    elif normalize == "precision":
        denom = cm.sum(axis=0, keepdims=True)[:, idx]   # prediction totals per displayed column
        subtitle = " (normalized by Pred: precision view)"
    else:
        raise ValueError(f"normalize must be None, 'recall' or 'precision', got {normalize!r}")
    normalized = np.divide(shown, denom, out=np.zeros_like(shown), where=denom != 0)
    return normalized, names, "{:.2f}", subtitle


def plot_confusion_matrix(cm, class_names, normalize=None, title="Confusion Matrix", max_classes=None):
    """
    Visualize confusion matrix with optional normalization.
    
    Parameters:
      - cm: (C, C) confusion matrix (torch.Tensor or np.ndarray), rows = GT, columns = predictions
      - class_names: List of class name strings, indexed by class id
      - normalize: None | "recall" | "precision"
        - None: raw counts
        - "recall": divide each row by its ground-truth total
        - "precision": divide each column by its predicted total
      - title: figure title (a subtitle describing the normalization is appended)
      - max_classes: Display top N classes by support (for readability); normalization
        still uses totals over all classes

    Raises:
      - ValueError: if normalize is not one of the values above
    """
    cm_show, names, fmt, subtitle = _cm_display_view(cm, class_names, normalize, max_classes)

    # Create figure
    fig_w = max(7, 0.7 * len(names))
    fig_h = max(6, 0.7 * len(names))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))

    im = ax.imshow(cm_show, cmap="Blues")
    ax.set_title(title + subtitle, fontsize=12)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Labels
    ticks = np.arange(len(names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(names)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Ground Truth")

    # Annotate cells; switch to white text on the darker half of the colormap so it stays legible.
    threshold = cm_show.max() / 2.0 if cm_show.size else 0.0
    for i in range(cm_show.shape[0]):
        for j in range(cm_show.shape[1]):
            val = cm_show[i, j]
            ax.text(j, i, fmt.format(val), ha="center", va="center", fontsize=8,
                    color="white" if val > threshold else "black")

    fig.tight_layout()
    plt.show()

## Confusion Matrix Plots

Three views of the confusion matrix: raw counts, recall-normalized, and precision-normalized.

In [None]:
# Display labels in class-index order; unnamed indices fall back to "class_<i>".
class_names = [CLASS_NAMES.get(i, f"class_{i}") for i in range(NUM_CLASSES)]

# One matrix, three views: raw counts, row-normalized (recall), column-normalized (precision).
for view in (None, "recall", "precision"):
    plot_confusion_matrix(
        cm_total,
        class_names,
        normalize=view,
        title=f"Confusion Matrix ({SPLIT})",
    )

## Metrics Heatmap

Visualize per-class metrics (IoU, Dice, etc.) as a heatmap.

In [None]:
def plot_metrics_heatmap(df_metrics, class_col="name", 
                         metrics=("IoU", "Dice", "Precision", "Recall", "F1"),
                         title="Per-class metrics", sort_by="support_pixels"):
    """
    Draw a class-by-metric heatmap with every cell annotated (colour scale fixed to [0, 1]).

    Parameters:
      - df_metrics: DataFrame with one row per class, holding `class_col` and each column in `metrics`
      - class_col: column used for the y-axis labels
      - metrics: metric columns to show, left to right
      - title: figure title
      - sort_by: when this column exists, rows are ordered by it, largest first
    """
    if sort_by in df_metrics.columns:
        table = df_metrics.sort_values(sort_by, ascending=False)
    else:
        table = df_metrics.copy()

    labels = table[class_col].tolist()
    values = table[list(metrics)].to_numpy(dtype=float)

    # Width grows with the number of metrics, height with the number of classes.
    fig, ax = plt.subplots(figsize=(1.2 * len(metrics) + 6, 0.45 * len(labels) + 2))

    image = ax.imshow(values, cmap="RdYlGn", aspect="auto", vmin=0, vmax=1)
    ax.set_title(title)
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks(range(len(metrics)))
    ax.set_xticklabels(metrics, rotation=0)
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels)

    # Row-major walk over the grid: (row, col) -> metric value.
    for (row, col), value in np.ndenumerate(values):
        ax.text(col, row, f"{value:.2f}", ha="center", va="center", fontsize=9)

    fig.tight_layout()
    plt.show()

# Heatmap for every class of the evaluated split
plot_metrics_heatmap(df, title=f"Per-class metrics ({SPLIT})")

## Filtered Class Analysis

Analyze metrics for well-represented classes (with sufficient support).

In [None]:
# Keep only classes with enough ground-truth pixels for their metrics to be meaningful.
MIN_SUPPORT = 50_000  # Adjust this threshold as needed
well_supported = df["support_pixels"] >= MIN_SUPPORT
df_filt = df.loc[well_supported]

print(f"\nClasses with support >= {MIN_SUPPORT} pixels:")
print(df_filt.sort_values("support_pixels", ascending=False).to_string(index=False))

# Only draw the heatmap when at least one class passes the threshold.
if not df_filt.empty:
    plot_metrics_heatmap(df_filt, title=f"Per-class metrics (support ≥ {MIN_SUPPORT}, {SPLIT})")

## Save Results

Save the confusion matrix (`.pt`) and the per-class metrics table (`.csv`) to `outputs/` for later analysis.

In [None]:
# ======================
# SAVE RESULTS
# ======================

out_dir = ROOT / "outputs"
out_dir.mkdir(exist_ok=True)

# File names carry the split so train and test results live side by side.
cm_path = out_dir / f"confusion_matrix_{SPLIT}.pt"
csv_path = out_dir / f"metrics_{SPLIT}.csv"

# Integer count matrix, stored from CPU memory.
torch.save(cm_total.cpu(), cm_path)
print(f"✓ Saved confusion matrix: {cm_path}")

# Per-class metrics table, without the DataFrame index column.
df.to_csv(csv_path, index=False)
print(f"✓ Saved metrics report: {csv_path}")

## Load Previously Computed Results

Load confusion matrix from a previous run (if available).

In [None]:
# Reload the confusion matrix written by the "Save Results" cell. The path is built
# from ROOT and SPLIT (like the save cell) instead of a hardcoded "../outputs/..._test.pt".
saved_cm_path = ROOT / "outputs" / f"confusion_matrix_{SPLIT}.pt"

if saved_cm_path.exists():
    cm_loaded = torch.load(saved_cm_path, map_location="cpu")
    print(f"✓ Loaded saved confusion matrix: {saved_cm_path} (shape {tuple(cm_loaded.shape)})")
    # Round-trip sanity check against the matrix computed in this session.
    print(f"  Matches cm_total: {torch.equal(cm_loaded, cm_total.cpu())}")
else:
    print(f"No saved confusion matrix found at {saved_cm_path}")

## Epoch-by-Epoch Confusion Matrices

Load and analyze confusion matrices saved during training (if available).

In [None]:
# Confusion matrices saved during training, e.g. "val_cm_epoch_05.pt", keyed by file stem.
# Sorting by name gives chronological order only if epoch numbers are zero-padded like
# the "_05" example — TODO confirm against the training script.
epoch_dir = ROOT / "outputs"
epoch_cms = {
    path.stem: torch.load(path, map_location="cpu")
    for path in sorted(epoch_dir.glob("val_cm_epoch_*.pt"))
}

if epoch_cms:
    for epoch_name, cm_epoch in epoch_cms.items():
        print(f"{epoch_name}: shape {tuple(cm_epoch.shape)}")
else:
    print(f"No epoch confusion matrices (val_cm_epoch_*.pt) found in {epoch_dir}")

In [None]:
def print_iou_report(cm, label):
    """
    Print mIoU and per-class IoU for one confusion matrix using metrics.iou_from_cm.

    Parameters:
      - cm: (C, C) confusion matrix (e.g. cm_total, or one loaded from disk)
      - label: prefix identifying the matrix in the printed report
    """
    # NOTE(review): iou_from_cm's mIoU convention (which classes it averages over)
    # may differ from metrics_from_cm above — confirm before comparing numbers.
    iou_per_class, miou = iou_from_cm(cm)
    print(f"{label} mIoU: {float(miou):.4f}")
    for c, v in enumerate(iou_per_class):
        print(f"  Class {c}: IoU = {float(v):.4f}")


# Cross-check on the matrix computed in this notebook; apply print_iou_report the
# same way to any matrix loaded from disk, e.g. print_iou_report(cm_epoch, "epoch 05").
print_iou_report(cm_total, f"{SPLIT} (current checkpoint)")

## Reference: Example Results

Example metrics from training runs (for reference):

**Training 2 (4 epochs):**
- mIoU: 0.4757
- Class 2 (artificial): 0.5818
- Class 3 (cultivated): 0.6628
- Class 4 (broadleaf): 0.7099
- Class 9 (water): 0.6942

**Training 1 (2 epochs):**
- mIoU: 0.4461
- Class 9 (water): 0.7682
- Class 4 (broadleaf): 0.6776

In [None]:
# End-of-notebook marker: under Run All, it is only reached if every earlier cell ran without error.
print("✓ Metrics visualization notebook complete")