# 03 · Uncertainty Visualisation — Monte Carlo Dropout

**Project:** SAM2 Lung Nodule Segmentation  
**Date:** April 2025 (Phase 4)

This notebook visualises the MC Dropout uncertainty maps and demonstrates their clinical value:

1. MC Dropout inference on a CT slice
2. Prediction spread across N stochastic passes
3. Uncertainty heatmap (variance) overlay
4. Uncertainty vs. prediction error correlation
5. Reliability diagram (calibration)
6. Threshold sensitivity analysis

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path("..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

import warnings

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import Normalize

warnings.filterwarnings("ignore")

from data.dataset import build_dataset
from evaluation.uncertainty_calibration import (
    CalibrationAnalyzer,
    brier_score,
    entropy_auc,
    expected_calibration_error,
    reliability_diagram,
)
from models.mc_dropout import (
    compute_uncertainty_stats,
    entropy_from_samples,
    mc_dropout_mode,
    mc_predict,
)
from models.registry import get_model

# Seed every RNG this notebook draws from so that "Restart Kernel -> Run All"
# reproduces the same MC Dropout masks and the same NumPy sub-samples.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and, when present, the CUDA generators
np.random.seed(SEED)  # legacy global NumPy RNG, used by np.random.* calls

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

## 1 · Load Model and Data

In [None]:
# Build the network, then restore trained weights when a checkpoint is on disk.
# Without a checkpoint the model keeps its freshly initialised (random) weights.
CKPT = PROJECT_ROOT.joinpath("runs", "sam2_lung_seg_v1", "checkpoints", "best_model.pt")

model_kwargs = dict(
    embed_dim=256,
    num_heads=8,
    attn_dropout=0.10,
    proj_dropout=0.10,
    encoder_frozen=False,
)
model = get_model("sam2_lung_seg", **model_kwargs)

if not CKPT.exists():
    print(
        "No checkpoint found — using random weights (uncertainty patterns will be noise)"
    )
else:
    ckpt = torch.load(CKPT, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    # Both fields are optional in the checkpoint; "?" is shown when they are absent.
    epoch = ckpt.get("epoch", "?")
    vdice = ckpt.get("metrics", {}).get("val_dice", "?")
    print(f"Loaded checkpoint: epoch={epoch}  val_dice={vdice}")

# Move to the target device and switch to evaluation mode.
model = model.to(device)
model.eval()

# Take the first test sample and add a leading batch dimension to image and mask.
ds = build_dataset("SYNTHETIC", split="test", mode="slice", augment=False)
sample = ds[0]
img_t = sample["image"][None].to(device)  # (1, 1, H, W) per the original annotation
msk_t = sample["mask"][None].to(device)  # (1, 1, H, W) per the original annotation

H, W = img_t.shape[-2:]
n_mask_px = msk_t.sum().item()
print(f"Image shape: {tuple(img_t.shape)}  |  Mask pixels: {n_mask_px:.0f}")

## 2 · Prediction Spread across MC Samples

We run N=20 stochastic forward passes and visualise the spread of predictions.
High pixel-wise variance across the passes marks regions where the model is uncertain, typically along the nodule boundary.

In [None]:
N_SAMPLES = 20  # stochastic forward passes
N_SHOWN = 6  # how many of those passes are drawn in the grid below

# Collect one sigmoid probability map per pass. mc_dropout_mode is a project
# context manager (presumably re-enables dropout on the eval-mode model —
# confirm in models.mc_dropout); gradients are not needed for inference.
all_preds = []
with mc_dropout_mode(model), torch.no_grad():
    for _ in range(N_SAMPLES):
        logit = model(img_t)
        all_preds.append(torch.sigmoid(logit).cpu().squeeze().numpy())

all_preds_arr = np.stack(all_preds, axis=0)  # (N, H, W)
mean_pred = all_preds_arr.mean(axis=0)
std_pred = all_preds_arr.std(axis=0)

img_np = img_t.cpu().squeeze().numpy()
msk_np = msk_t.cpu().squeeze().numpy()

# Show N_SHOWN individual predictions: overlay on the CT (top row) and the raw
# probability map (bottom row).
fig, axes = plt.subplots(2, N_SHOWN, figsize=(18, 6))
fig.suptitle(
    f"MC Dropout — Individual Predictions (N={N_SAMPLES}, showing {N_SHOWN})",
    fontweight="bold",
)

# Evenly spaced pass indices from the first to the last pass.
show_idx = np.linspace(0, N_SAMPLES - 1, N_SHOWN, dtype=int)
for col, idx in enumerate(show_idx):
    top, bottom = axes[0, col], axes[1, col]

    top.imshow(img_np, cmap="gray")
    top.imshow(all_preds_arr[idx], cmap="hot", alpha=0.6, vmin=0, vmax=1)
    top.set_title(f"Pass #{idx+1}", fontsize=9)

    bottom.imshow(all_preds_arr[idx], cmap="gray", vmin=0, vmax=1)
    bottom.set_title(f"P={all_preds_arr[idx].mean():.3f}", fontsize=9)

    # Hide ticks and frame but leave the axis switched on: axis("off") would
    # also suppress the row labels that are set on the first column below.
    for panel in (top, bottom):
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

axes[0, 0].set_ylabel("Overlay", fontsize=9)
axes[1, 0].set_ylabel("Prob map", fontsize=9)
plt.tight_layout()
plt.savefig("mc_individual_passes.png", dpi=120, bbox_inches="tight")
plt.show()

## 3 · Mean Prediction and Uncertainty Heatmap

In [None]:
# Mean / variance come from the project's mc_predict helper; the pass count is
# named once so the call and the figure title cannot drift apart.
N_MC_SAMPLES = 25
mean_t, var_t = mc_predict(
    model, img_t, n_samples=N_MC_SAMPLES, mc_batch_size=5, sigmoid=True
)
mean_np = mean_t.cpu().squeeze().numpy()
var_np = var_t.cpu().squeeze().numpy()

# Entropy is computed from the passes collected in the previous cell
# (all_preds_arr), not from the mc_predict run above, so its pass count can
# differ from N_MC_SAMPLES and is labelled separately on its panel.
entropy_np = entropy_from_samples(all_preds_arr)  # shape (H, W)
n_entropy_passes = all_preds_arr.shape[0]

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
fig.suptitle(
    f"MC Dropout Uncertainty Decomposition (N={N_MC_SAMPLES})",
    fontsize=13,
    fontweight="bold",
)

titles = [
    "CT Slice",
    "Ground Truth",
    "Mean Prediction",
    "Variance (Uncertainty)",
    f"Entropy (N={n_entropy_passes})",
]
maps = [img_np, msk_np, mean_np, var_np, entropy_np]
cmaps = ["gray", "gray", "gray", "hot", "plasma"]

for ax, title, data, cmap in zip(axes, titles, maps, cmaps):
    # Gray panels are clamped to [0, 1]; the uncertainty panels auto-scale their
    # upper limit so small variances / entropies remain visible.
    im = ax.imshow(data, cmap=cmap, vmin=0, vmax=(1 if cmap == "gray" else None))
    ax.set_title(title, fontsize=10, fontweight="bold")
    ax.axis("off")
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("uncertainty_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()

# Summary statistics of the variance map, given the thresholded mean prediction.
unc_stats = compute_uncertainty_stats(var_t, (mean_t >= 0.5).float())
print("Uncertainty Statistics")
for k, v in unc_stats.items():
    print(f"  {k:<25}: {v:.6f}")

## 4 · Uncertainty vs. Prediction Error Correlation

A well-calibrated model should concentrate high uncertainty near prediction mistakes.

In [None]:
# Per-pixel error map: 1 where the thresholded mean prediction disagrees with the mask.
binary_pred = (mean_np >= 0.5).astype(float)
errors = (binary_pred != msk_np).astype(float)  # 1 where wrong

# Flattened per-pixel uncertainty and error indicators.
unc_flat = var_np.flatten()
err_flat = errors.flatten()

# Reproducible sub-sample of at most 5000 pixels. Kept for downstream use; the
# plots in this cell use the full arrays.
subsample_rng = np.random.default_rng(0)
idx_sub = subsample_rng.choice(
    unc_flat.size, size=min(5000, unc_flat.size), replace=False
)
unc_sub = unc_flat[idx_sub]
err_sub = err_flat[idx_sub]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle("Uncertainty vs. Error Analysis", fontsize=13, fontweight="bold")

# Box plot: variance of correctly vs. incorrectly classified pixels.
axes[0].boxplot(
    [unc_flat[err_flat == 0], unc_flat[err_flat == 1]],
    patch_artist=True,
    boxprops=dict(facecolor="#4C72B0", alpha=0.7),
)
# Tick labels are set on the axes rather than through boxplot's `labels`
# keyword, which Matplotlib 3.9 deprecated in favour of `tick_labels`; this
# form works on both older and newer releases.
axes[0].set_xticklabels(["Correct", "Error"])
axes[0].set_ylabel("MC Variance")
axes[0].set_title("Uncertainty by Correctness")
axes[0].grid(True, axis="y", alpha=0.3)

# Histogram split: same two groups as density histograms.
axes[1].hist(
    unc_flat[err_flat == 0],
    bins=50,
    alpha=0.7,
    density=True,
    label="Correct",
    color="#4C72B0",
)
axes[1].hist(
    unc_flat[err_flat == 1],
    bins=50,
    alpha=0.7,
    density=True,
    label="Error",
    color="#C44E52",
)
axes[1].set_xlabel("MC Variance")
axes[1].set_ylabel("Density")
axes[1].set_title("Variance Distribution per Class")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Error map with the 0.5 level of the min-max normalised variance drawn on top.
axes[2].imshow(errors, cmap="Reds", vmin=0, vmax=1, alpha=0.8)
# The epsilon keeps the division finite when the variance map is constant.
unc_norm = (var_np - var_np.min()) / (var_np.max() - var_np.min() + 1e-9)
axes[2].contour(unc_norm, levels=[0.5], colors=["blue"], linewidths=1.5)
axes[2].set_title("Errors (red) + Unc contour (blue)", fontsize=9)
axes[2].axis("off")

plt.tight_layout()
plt.savefig("uncertainty_error_correlation.png", dpi=120, bbox_inches="tight")
plt.show()

# AUROC of uncertainty as error detector
auc = entropy_auc(unc_flat, err_flat)
print(f"Uncertainty AUROC (error detection): {auc:.4f}")
print("(0.5=random, 1.0=perfect — target ≥ 0.70)")

## 5 · Calibration — Reliability Diagram

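Expected Calibration Error (ECE) measures how closely the model's confidence matches its observed accuracy: predictions are grouped into confidence bins and the absolute gap is averaged, weighted by bin size, $\mathrm{ECE} = \sum_{b} \frac{n_b}{N}\,\lvert \mathrm{acc}(b) - \mathrm{conf}(b) \rvert$.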

In [None]:
probs_flat = mean_np.flatten()  # foreground probability per pixel
targets_flat = msk_np.flatten()

# Top-label correctness: does the thresholded prediction match the mask?
pred_fg = probs_flat >= 0.5
correct_flat = (pred_fg == (targets_flat > 0.5)).astype(float)

# Confidence must refer to the *predicted* class to be comparable with
# correctness: p for predicted foreground, 1 - p for predicted background.
# Feeding the raw foreground probability instead would score a confident,
# correct background pixel (p ≈ 0) as "low confidence, high accuracy".
# NOTE(review): expected_calibration_error is not visible from this notebook;
# this assumes it bins its first argument and compares the bin mean with the
# mean of the second — confirm against evaluation.uncertainty_calibration.
conf_flat = np.where(pred_fg, probs_flat, 1.0 - probs_flat)

ece, bin_accs, bin_confs, bin_freqs = expected_calibration_error(
    conf_flat, correct_flat, n_bins=15
)
# The Brier score compares the foreground probability with the binary target directly.
bs = brier_score(probs_flat, targets_flat)

print(f"Expected Calibration Error (ECE): {ece:.4f}")
print(f"Brier Score                     : {bs:.4f}")
print("(Target ECE < 0.03 for clinical deployment)")

reliability_diagram(
    bin_accs,
    bin_confs,
    bin_freqs,
    ece,
    title="SAM2 Lung Nodule — Calibration (single slice)",
    save_path="reliability_diagram.png",
)
plt.show()

## 6 · Threshold Sensitivity Analysis

Different probability thresholds yield different Dice / precision / recall trade-offs.

In [None]:
from evaluation.dice_metric import compute_dice, compute_precision_recall

DEFAULT_THRESHOLD = 0.5

# Sweep the binarisation threshold and record Dice / precision / recall at each step.
thresholds = np.linspace(0.1, 0.9, 40)
dices, precs, recs = [], [], []

for t in thresholds:
    d = compute_dice(mean_t, msk_t, threshold=float(t)).item()
    p, r = compute_precision_recall(mean_t, msk_t, threshold=float(t))
    dices.append(d)
    precs.append(p.item())
    recs.append(r.item())

best_idx = int(np.argmax(dices))
best_t = thresholds[best_idx]

# The 40-point grid does not contain 0.5 (its nearest points are about 0.49 and
# 0.51), so the default-threshold Dice is evaluated at exactly 0.5 rather than
# read from the nearest grid point.
default_dice = compute_dice(mean_t, msk_t, threshold=DEFAULT_THRESHOLD).item()

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(thresholds, dices, "#4C72B0", lw=2, label="Dice")
ax.plot(thresholds, precs, "#55A868", lw=2, ls="--", label="Precision")
ax.plot(thresholds, recs, "#C44E52", lw=2, ls=":", label="Recall")
ax.axvline(DEFAULT_THRESHOLD, color="gray", ls=":", lw=1, label="Default (0.5)")
ax.axvline(best_t, color="#DD8452", ls="--", lw=2, label=f"Best Dice t={best_t:.2f}")
ax.set_xlabel("Threshold")
ax.set_ylabel("Score")
ax.set_title("Threshold Sensitivity Analysis")
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
ax.set_xlim(0.1, 0.9)
plt.tight_layout()
plt.savefig("threshold_sensitivity.png", dpi=120, bbox_inches="tight")
plt.show()

print(f"Best threshold : {best_t:.2f}  →  Dice={dices[best_idx]:.4f}")
print(f"Default (0.50) →  Dice={default_dice:.4f}")