<a href="https://colab.research.google.com/github/aalvarez359/3dunet-vs-medam-brain-tumor/blob/main/metrics/unet_4mod_viz_metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the checkpoint and dataset paths configured
# further down in this notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install nibabel segmentation-models-pytorch-3d torch torchvision torchaudio tqdm

Collecting segmentation-models-pytorch-3d
  Downloading segmentation_models_pytorch_3d-1.0.2-py3-none-any.whl.metadata (724 bytes)
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch-3d)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch-3d)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.7 (from segmentation-models-pytorch-3d)
  Downloading timm-0.9.7-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting timm-3d==1.0.1 (from segmentation-models-pytorch-3d)
  Downloading timm_3d-1.0.1-py3-none-any.whl.metadata (588 bytes)
Col

New

In [None]:
#!/usr/bin/env python3
"""
Colab-ready script (no CLI):
- Set the CONFIG values below, then just "Run" the cell.
- Plots training curves from training_history.json in ckpt_dir.
- Optionally runs sliding-window inference on one case to save a slice PNG.
"""

# ======================
# ====== CONFIG ========
# ======================
# Folder that holds training_history.json and the *.pt checkpoints.
CKPT_DIR   = "/content/drive/MyDrive/BrainTumor_Checkpoints"  # folder with training_history.json and *.pt
# Output folder for every PNG written by this cell (created if missing).
OUT_DIR    = "./plots"                                        # where to save PNGs

# Training-curve plots (mean Dice and per-class Dice); produced only when this flag is True.
MAKE_GRAPHS = True

# Slice viz (set to True only if you want the qualitative slice).
# Requires the checkpoint CKPT_NAME and the image IMAGE_NII to exist.
MAKE_SLICE  = True  # change to False if you only want the curves

# Which checkpoint to use for the slice (exact filename in CKPT_DIR)
CKPT_NAME   = "best_epoch65_dice0.7330.pt"  # or "last.pt"

# Paths to one case (only needed if MAKE_SLICE=True).
# The label is optional: if the path is empty or missing, only the prediction overlay is drawn.
IMAGE_NII   = "/content/drive/MyDrive/Task01_BrainTumour_extracted/Task01_BrainTumour/imagesTr/BRATS_137.nii.gz"  # (H,W,D,4) for 4-mod MSD
LABEL_NII   = "/content/drive/MyDrive/Task01_BrainTumour_extracted/Task01_BrainTumour/labelsTr/BRATS_137.nii.gz"  # (H,W,D), optional

# Slice display options
# Axis of the (D,H,W) volume to slice along; any other value falls back to "H".
SLICE_AXIS  = "H"     # "D", "H", or "W"
SLICE_INDEX = None     # None = auto-pick slice with max tumor; otherwise explicit index

# Sliding window overlap (for inference): fraction of the window shared by neighbouring
# windows, i.e. the stride along each axis is window * (1 - OVERLAP).
OVERLAP     = 0.75

# Inference window in (D, H, W) order.
# If your training used a different patch, this will be overridden by the "patch_size"
# entry of the checkpoint's "config" dict if present.
PATCH_SIZE  = (128, 224, 224)

# Decoder channels used in your training model (kept slimmer for memory).
# The weights are loaded with strict=True, so this must match the trained architecture.
DECODER_CHANNELS = (192, 128, 64, 32, 16)

# ======================
# ===== END CONFIG =====
# ======================

import os, json
from typing import Any, Dict, List, Optional
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

try:
    import nibabel as nib
except Exception:
    nib = None

try:
    import segmentation_models_pytorch_3d as smp3d
except Exception:
    smp3d = None


# ============================
# ===== HISTORY UTILITIES ====
# ============================

def _load_history_json(hist_path: str):
    with open(hist_path, "r") as f:
        data = json.load(f)
    if isinstance(data, dict) and "history" in data:
        return data["history"]
    return data

def _extract_from_list_of_dicts(hist_list: List[Dict[str, Any]]):
    epochs, mean_dice, per_class = [], [], []
    for i, e in enumerate(hist_list, start=1):
        ep = int(e.get("epoch", i))
        md = (
            e.get("val_meanDice")
            or e.get("val_meanDice(no-bg)")
            or e.get("val_meanDice_no_bg")
            or e.get("meanDice")
            or e.get("val_mean_dice")
        )
        pc = (
            e.get("per_class_dice")
            or e.get("val_per_class_dice")
            or e.get("dice_per_class")
            or e.get("val_dice_per_class")
            or e.get("per_class")
        )
        epochs.append(ep)
        mean_dice.append(float(md) if md is not None else np.nan)

        # normalize per-class to [bg, edema, non_enh, enh]
        if isinstance(pc, dict):
            order = ["bg", "edema", "non_enh", "enh"]
            if all(k in pc for k in order):
                per_class.append([float(pc[k]) for k in order])
            elif all(str(k) in pc for k in range(4)):
                per_class.append([float(pc[str(k)]) for k in range(4)])
            else:
                try:
                    items = sorted(pc.items(), key=lambda x: x[0])
                    per_class.append([float(v) for _, v in items[:4]])
                except Exception:
                    per_class.append([np.nan]*4)
        elif isinstance(pc, (list, tuple)) and len(pc) >= 3:
            if len(pc) >= 4:
                per_class.append([float(pc[0]), float(pc[1]), float(pc[2]), float(pc[3])])
            else:
                # assume [edema, non_enh, enh]
                per_class.append([np.nan, float(pc[0]), float(pc[1]), float(pc[2])])
        else:
            per_class.append([np.nan]*4)
    return np.array(epochs), np.array(mean_dice), np.array(per_class)

def _extract_from_dict_of_lists(hist_dict: Dict[str, List[Any]]):
    # matches your trainer keys: epoch, val_mean_dice, dice_edema, dice_non_enh, dice_enh
    ep    = np.array(hist_dict.get("epoch", []), dtype=int)
    md    = np.array(hist_dict.get("val_mean_dice", []), dtype=float)
    edema = np.array(hist_dict["dice_edema"],    dtype=float) if "dice_edema"    in hist_dict else None
    non_enh = np.array(hist_dict["dice_non_enh"], dtype=float) if "dice_non_enh" in hist_dict else None
    enh   = np.array(hist_dict["dice_enh"],      dtype=float) if "dice_enh"      in hist_dict else None

    lengths = [
        len(ep),
        len(md),
        len(edema)   if edema   is not None else 0,
        len(non_enh) if non_enh is not None else 0,
        len(enh)     if enh     is not None else 0,
    ]
    max_len = max(lengths) if lengths else 0

    def pad(a):
        if a is None:
            return np.full((max_len,), np.nan, dtype=float)
        a = np.asarray(a, dtype=float)
        if len(a) < max_len:
            out = np.full((max_len,), np.nan, dtype=float)
            out[:len(a)] = a
            return out
        return a

    ep    = pad(ep)
    md    = pad(md)
    edema = pad(edema)
    non_enh = pad(non_enh)
    enh   = pad(enh)

    per_class = np.stack([np.full_like(md, np.nan), edema, non_enh, enh], axis=1)
    return ep, md, per_class

def load_and_extract_history(ckpt_dir: str):
    """Load ``training_history.json`` from ``ckpt_dir`` and extract plotting arrays.

    The parsed history is dispatched on its top-level type: a list goes to
    ``_extract_from_list_of_dicts`` and a dict to ``_extract_from_dict_of_lists``.

    Raises:
        FileNotFoundError: if ``training_history.json`` is not in ``ckpt_dir``.
        ValueError: if the parsed history is neither a list nor a dict.
    """
    hist_path = os.path.join(ckpt_dir, "training_history.json")
    if not os.path.exists(hist_path):
        raise FileNotFoundError(f"training_history.json not found in {ckpt_dir}")

    history = _load_history_json(hist_path)
    if isinstance(history, list):
        return _extract_from_list_of_dicts(history)
    if isinstance(history, dict):
        return _extract_from_dict_of_lists(history)
    raise ValueError("Unsupported history JSON format.")

def plot_mean_dice(epochs, mean_dice, out_path):
    """Plot validation mean Dice against epoch and save the figure as a PNG.

    Args:
        epochs: Sequence of epoch numbers (x-axis).
        mean_dice: Sequence of mean Dice values, same length as ``epochs``.
        out_path: Destination image path; its parent directory is created if needed.
    """
    fig, ax = plt.subplots()
    ax.plot(epochs, mean_dice, linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Validation Mean Dice (no-bg)")
    ax.set_title("Mean Dice vs Epoch")
    ax.grid(True, linestyle="--", alpha=0.4)
    # dirname() is "" for a bare filename and os.makedirs("") raises, so fall
    # back to the current directory (same convention as visualize_slice_png).
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    try:
        fig.savefig(out_path, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)

def plot_per_class(epochs, per_class, out_path):
    """Plot one Dice curve per class against epoch and save the figure as a PNG.

    Args:
        epochs: Sequence of epoch numbers (x-axis).
        per_class: 2-D array with one row per epoch. Four or more columns are
            read as [bg, edema, non-enh, enh] (extra columns are ignored);
            exactly three columns are read as [edema, non-enh, enh].
        out_path: Destination image path; its parent directory is created if needed.

    If ``per_class`` is not 2-D or has fewer than three columns, a message is
    printed and nothing is written.
    """
    if per_class.ndim != 2 or per_class.shape[1] < 3:
        print("Per-class dice not found; skipping per-class plot.")
        return

    labels = ["bg", "edema", "non-enh", "enh"]
    n_cols = min(4, per_class.shape[1])
    # A 3-column table has no background entry (the convention used when the
    # history is parsed), so the labels must start at "edema".
    names = labels[-n_cols:]

    fig, ax = plt.subplots()
    plotted = False
    for i in range(n_cols):
        series = np.asarray(per_class[:, i], dtype=float)
        if np.isnan(series).all():
            # Nothing recorded for this class; skip it so the legend does not
            # list a curve that is not drawn.
            continue
        ax.plot(epochs, series, linewidth=2, label=names[i])
        plotted = True
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Dice per Class")
    ax.set_title("Per-class Dice vs Epoch")
    ax.grid(True, linestyle="--", alpha=0.4)
    if plotted:
        ax.legend()
    # dirname() is "" for a bare filename and os.makedirs("") raises, so fall
    # back to the current directory (same convention as visualize_slice_png).
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    try:
        fig.savefig(out_path, bbox_inches="tight", dpi=150)
    finally:
        plt.close(fig)


# ============================
# ===== INFERENCE HELPERS ====
# ============================

def _to_dhw(a: np.ndarray) -> np.ndarray:
    """Ensure volume is (D,H,W). Handles (H,W,D) or (D,H,W)."""
    if a.ndim != 3:
        raise ValueError("Expected 3D volume")
    if a.shape[-1] < a.shape[0] and a.shape[-1] < a.shape[1]:
        a = np.moveaxis(a, -1, 0)  # (D,H,W)
    return a

def _window_starts(length: int, window: int, stride: int) -> List[int]:
    """Start offsets of windows of size ``window`` that together cover ``[0, length)``.

    Offsets advance by ``stride``; a final offset flush with the end is appended
    when the regular grid would otherwise leave the tail uncovered.
    """
    last = max(length - window, 0)
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


@torch.no_grad()
def sliding_window_inference(volume_cdhw: np.ndarray,
                             model: torch.nn.Module,
                             window: tuple,
                             overlap: float,
                             device: str = "cuda"):
    """
    Generic sliding-window inference for (C,D,H,W) volumes.
    Uses model.n_classes if present, else 4.

    Softmax probabilities of all windows covering a voxel are averaged and the
    arg-max class is returned.

    Args:
        volume_cdhw: Input volume of shape (C, D, H, W); its dtype should match
            the model parameters (e.g. float32).
        model: Network mapping a (1, C, d, h, w) tensor to (1, n_classes, d, h, w)
            logits. It is switched to eval mode; it is not moved to ``device``.
        window: Window size (wd, wh, ww).
        overlap: Fraction of the window shared by neighbouring windows; the
            stride per axis is ``window * (1 - overlap)``, clamped to
            ``[1, window]`` so consecutive windows never leave gaps.
        device: Device used for the input patches and the accumulators.

    Returns:
        Integer label volume of shape (D, H, W) as a NumPy array.
    """
    model.eval()
    _, D, H, W = volume_cdhw.shape
    wd, wh, ww = (int(w) for w in window)
    sd = min(wd, max(1, int(wd * (1 - overlap))))
    sh = min(wh, max(1, int(wh * (1 - overlap))))
    sw = min(ww, max(1, int(ww * (1 - overlap))))

    # A volume smaller than the window along some axis is zero-padded at the
    # end of that axis so every patch has the full window size; the padding is
    # cropped away again before returning.
    pad_d, pad_h, pad_w = max(wd - D, 0), max(wh - H, 0), max(ww - W, 0)
    if pad_d or pad_h or pad_w:
        volume = np.pad(volume_cdhw,
                        ((0, 0), (0, pad_d), (0, pad_h), (0, pad_w)),
                        mode="constant")
    else:
        volume = volume_cdhw
    Dp, Hp, Wp = volume.shape[1:]

    n_classes = int(getattr(model, "n_classes", 4))
    out_prob = torch.zeros((n_classes, Dp, Hp, Wp), dtype=torch.float32, device=device)
    out_norm = torch.zeros((1, Dp, Hp, Wp), dtype=torch.float32, device=device)

    for z0 in _window_starts(Dp, wd, sd):
        for y0 in _window_starts(Hp, wh, sh):
            for x0 in _window_starts(Wp, ww, sw):
                patch = np.ascontiguousarray(volume[:, z0:z0+wd, y0:y0+wh, x0:x0+ww])
                pt = torch.from_numpy(patch).unsqueeze(0).to(device)  # 1,C,D,H,W
                logits = model(pt)
                probs = torch.softmax(logits, dim=1)[0]
                out_prob[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += probs
                out_norm[:, z0:z0+wd, y0:y0+wh, x0:x0+ww] += 1.0

    # Every voxel is covered by at least one window, so out_norm >= 1 everywhere.
    out_prob /= out_norm
    pred = torch.argmax(out_prob[:, :D, :H, :W], dim=0)  # D H W
    return pred.cpu().numpy()

def _build_model_from_ckpt(ckpt: Dict[str, Any], device: str = "cuda"):
    """
    Construct an SMP-3D U-Net whose hyper-parameters come from the checkpoint.

    The ``"config"`` entry of ``ckpt`` (when ``ckpt`` is a dict) supplies
    ``encoder``, ``in_channels``, ``classes`` and ``patch_size``; missing
    entries fall back to "efficientnet-b7", 4, 4 and the global PATCH_SIZE.
    Decoder channels always come from the global DECODER_CHANNELS. No weights
    are loaded here.

    Returns the model, moved to ``device`` and carrying the helper attributes
    ``n_classes``, ``patch_size`` and ``in_channels_manual``.

    Raises SystemExit when segmentation-models-pytorch-3d is unavailable.
    """
    if smp3d is None:
        raise SystemExit("segmentation-models-pytorch-3d not installed. Run: pip install segmentation-models-pytorch-3d")

    cfg = ckpt.get("config", {}) if isinstance(ckpt, dict) else {}
    encoder_name = cfg.get("encoder", "efficientnet-b7")
    n_in = int(cfg.get("in_channels", 4))
    n_out = int(cfg.get("classes", 4))
    window = tuple(cfg.get("patch_size", PATCH_SIZE))

    unet_kwargs = dict(
        encoder_name=encoder_name,
        encoder_weights=None,
        in_channels=n_in,
        classes=n_out,
        decoder_channels=DECODER_CHANNELS,
    )
    model = smp3d.Unet(**unet_kwargs)
    model.to(device)

    # Stash the settings that later inference steps read back via getattr().
    model.n_classes = n_out
    model.patch_size = window
    model.in_channels_manual = n_in
    return model

def _load_image_chwd(image_nii_path: str, in_channels: int) -> np.ndarray:
    """
    Read a NIfTI image and return it as a float32 array laid out (C,D,H,W).

    Each channel is z-scored using the mean and standard deviation of its
    non-zero voxels (of all voxels when the channel is entirely zero); a
    standard deviation of at most 1e-8 is replaced by 1.0.

    Accepted inputs:
      - 4-D data (H,W,D,C) whose last axis equals ``in_channels``
      - 3-D data (H,W,D) when ``in_channels == 1``

    Raises SystemExit if nibabel is unavailable and ValueError for any other
    shape / channel combination.
    """
    if nib is None:
        raise SystemExit("nibabel not installed. Run: pip install nibabel")

    img = nib.load(image_nii_path).get_fdata(dtype=np.float32)

    def zscore(v: np.ndarray) -> np.ndarray:
        # Statistics come from the non-zero voxels when there are any.
        nonzero = (v != 0)
        sample = v[nonzero] if nonzero.any() else v
        mu = sample.mean()
        sigma = sample.std()
        if not sigma > 1e-8:
            sigma = 1.0
        return ((v - mu) / sigma).astype(np.float32)

    # multi-modality: (H,W,D,C) -> (C,D,H,W)
    if img.ndim == 4 and img.shape[-1] == in_channels:
        per_channel = [
            np.transpose(zscore(img[..., c]), (2, 0, 1))  # (H,W,D) -> (D,H,W)
            for c in range(in_channels)
        ]
        return np.stack(per_channel, axis=0)

    # single-modality: (H,W,D) -> (1,D,H,W)
    if img.ndim == 3 and in_channels == 1:
        return np.transpose(zscore(img), (2, 0, 1))[np.newaxis, ...]

    raise ValueError(
        f"Unexpected image shape {img.shape} for in_channels={in_channels}. "
        f"Expected (H,W,D,{in_channels}) or (H,W,D) if in_channels==1."
    )

def _best_slice_index(mask_dhw: np.ndarray, axis: int) -> int:
    """
    Choose the slice index with the largest number of True voxels along the given axis.
    If mask is empty, returns the middle slice.
    """
    if mask_dhw is None or mask_dhw.ndim != 3:
        return 0
    if axis == 0:   # D
        counts = mask_dhw.reshape(mask_dhw.shape[0], -1).sum(axis=1)
    elif axis == 1: # H
        counts = mask_dhw.transpose(1,0,2).reshape(mask_dhw.shape[1], -1).sum(axis=1)
    else:           # W
        counts = mask_dhw.transpose(2,0,1).reshape(mask_dhw.shape[2], -1).sum(axis=1)
    if counts.max() == 0:
        L = mask_dhw.shape[axis]
        return L // 2
    return int(np.argmax(counts))

def _overlay_multiclass(ax, base2d: np.ndarray, label2d: Optional[np.ndarray], title: str):
    """
    Show base2d in grayscale, then overlay tumor classes:
      1 (edema) -> yellow, 2 (non-enh) -> green, 3 (enh) -> red.
    """
    img = base2d.astype(np.float32)
    if np.isfinite(img).all():
        p2, p98 = np.percentile(img, [2, 98])
        if p98 > p2:
            ax.imshow(img, cmap="gray", vmin=p2, vmax=p98)
        else:
            ax.imshow(img, cmap="gray")
    else:
        ax.imshow(img, cmap="gray")

    if label2d is not None:
        H, W = label2d.shape
        class_colors = {
            1: (1.0, 1.0, 0.0, 0.40),  # edema: yellow
            2: (0.0, 1.0, 0.0, 0.40),  # non-enh: green
            3: (1.0, 0.0, 0.0, 0.40),  # enh: red
        }
        for cls, (r, g, b, a) in class_colors.items():
            mask = (label2d == cls)
            if np.any(mask):
                overlay = np.zeros((H, W, 4), dtype=np.float32)
                overlay[..., 0] = r
                overlay[..., 1] = g
                overlay[..., 2] = b
                overlay[..., 3] = a * mask.astype(np.float32)
                ax.imshow(overlay, interpolation="none")

    ax.set_title(title)
    ax.axis("off")

def visualize_slice_png(pred_dhw: np.ndarray,
                        label_dhw: Optional[np.ndarray],
                        image_dhw: Optional[np.ndarray],
                        slice_axis: str,
                        slice_index: Optional[int],
                        out_path: str):
    """
    Save a three-panel PNG for one slice of a (D,H,W) case.

    Panels, left to right:
       [0] grayscale image
       [1] label overlay (image only when no label is given)
       [2] prediction overlay

    ``slice_axis`` is "D", "H" or "W" (case-insensitive; anything else is
    treated as "H"). With ``slice_index`` set to None the slice holding the
    most tumor voxels is chosen, measured on the label when available and on
    the prediction otherwise. Without an image, an all-zero background is used.
    """
    axis_num = {"D": 0, "H": 1, "W": 2}.get(slice_axis.upper(), 1)

    # Pick the slice position.
    if slice_index is not None:
        idx = int(slice_index)
    else:
        reference = label_dhw if label_dhw is not None else pred_dhw
        idx = _best_slice_index((reference > 0), axis_num)

    # Index tuple selecting position `idx` along `axis_num`, e.g. (:, idx) for H.
    selector = (slice(None),) * axis_num + (idx,)

    def cut(vol):
        return None if vol is None else vol[selector]

    img2d, lab2d, pred2d = cut(image_dhw), cut(label_dhw), cut(pred_dhw)

    if img2d is not None:
        base2d = img2d
    else:
        # No image supplied: fall back to a blank canvas of matching shape.
        template = lab2d if lab2d is not None else pred2d
        base2d = np.zeros_like(template, dtype=np.float32)

    fig, (ax_img, ax_lab, ax_pred) = plt.subplots(1, 3, figsize=(14, 4))

    ax_img.imshow(base2d, cmap="gray")
    ax_img.set_title("Image")
    ax_img.axis("off")

    _overlay_multiclass(ax_lab, base2d, lab2d, "Label (overlay)")
    _overlay_multiclass(ax_pred, base2d, pred2d, "Prediction (overlay)")

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.savefig(out_path, bbox_inches="tight", dpi=150)
    plt.close(fig)
    print("Saved:", out_path)


# ============================
# ========== MAIN ============
# ============================

if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)

    # 1) Training curves read from training_history.json in CKPT_DIR
    if MAKE_GRAPHS:
        epochs, mean_dice, per_class = load_and_extract_history(CKPT_DIR)
        out_mean = os.path.join(OUT_DIR, "mean_dice_vs_epoch.png")
        out_pcls = os.path.join(OUT_DIR, "per_class_dice_vs_epoch.png")
        plot_mean_dice(epochs, mean_dice, out_mean)
        plot_per_class(epochs, per_class, out_pcls)
        print("Saved:", out_mean)
        print("Saved:", out_pcls)

    # 2) One qualitative slice (optional)
    if MAKE_SLICE:
        ckpt_path = os.path.join(CKPT_DIR, CKPT_NAME)
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load checkpoint & rebuild model from its config.
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code - only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model = _build_model_from_ckpt(ckpt, device=device)

        # Load weights (either a wrapped {"state_dict": ...} or a bare state dict)
        state = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
        model.load_state_dict(state, strict=True)

        # Prepare image volume (C,D,H,W)
        in_channels = int(getattr(model, "in_channels_manual", 4))
        vol_cdhw = _load_image_chwd(IMAGE_NII, in_channels=in_channels)

        # Patch size: prefer what was saved in config, else global PATCH_SIZE
        patch_size = tuple(getattr(model, "patch_size", PATCH_SIZE))

        pred_dhw = sliding_window_inference(vol_cdhw, model, patch_size, OVERLAP, device=device)

        # Optional label and an image channel for context.
        # Both are re-oriented with the same unconditional (H,W,D) -> (D,H,W)
        # transpose that _load_image_chwd applies to the network input, so they
        # stay voxel-aligned with the prediction whatever the relative axis
        # lengths are. Anything that still does not match the prediction shape
        # is skipped with a warning instead of being drawn misaligned.
        lab_dhw = None
        img_dhw = None

        if LABEL_NII and os.path.exists(LABEL_NII) and nib is not None:
            lab = nib.load(LABEL_NII).get_fdata().astype(np.int64)  # (H,W,D)
            if lab.ndim == 3 and np.moveaxis(lab, -1, 0).shape == pred_dhw.shape:
                lab_dhw = np.moveaxis(lab, -1, 0)
            else:
                print(f"Warning: label shape {lab.shape} does not match the prediction "
                      f"{pred_dhw.shape} (D,H,W); skipping label overlay.")

        if IMAGE_NII and os.path.exists(IMAGE_NII) and nib is not None:
            raw = nib.load(IMAGE_NII).get_fdata()
            if raw.ndim == 4:
                raw = raw[..., 0]  # show channel 0
            if raw.ndim == 3 and np.moveaxis(raw, -1, 0).shape == pred_dhw.shape:
                img_dhw = np.moveaxis(raw, -1, 0)
            else:
                print(f"Warning: image shape {raw.shape} does not match the prediction "
                      f"{pred_dhw.shape} (D,H,W); drawing on a blank background.")

        out_slice = os.path.join(OUT_DIR, "qualitative_slice.png")
        visualize_slice_png(pred_dhw, lab_dhw, img_dhw, SLICE_AXIS, SLICE_INDEX, out_slice)

    print("Done.")


Saved: ./plots/mean_dice_vs_epoch.png
Saved: ./plots/per_class_dice_vs_epoch.png
Saved: ./plots/qualitative_slice.png
Done.
