In [4]:
from __future__ import annotations

import os
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt

import sys
sys.path.append('/home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg')

from src.models.monai_unet_2d import build_monai_unet_2d
from src.data.dataset_cases import MicroUSCaseDataset
from src.data.transforms_2d import center_crop_or_pad_2d

In [None]:
# Project checkout that every run/data path below hangs off.
_PROJECT_ROOT = Path("/home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg")

RUN_DIR = _PROJECT_ROOT / "runs" / "20260216_173816"
CKPT_PATH = RUN_DIR / "checkpoint_best.pt"

DATA_ROOT = _PROJECT_ROOT / "dataset" / "processed" / "Dataset120_MicroUSProstate"
SPLITS_DIR = _PROJECT_ROOT / "dataset" / "splits"
CASE_STATS_PATH = DATA_ROOT / "case_stats.json"

# Match your 2D training target size (H, W); slices are center-cropped/padded to it.
TARGET_HW = (896, 1408)

# IMPORTANT: must match the transpose_hw setting used for 2D training.
# The first val case has [Y, X] slices of (962, 1372), which TARGET_HW (896, 1408)
# fits without transposing, and the slice-Dice visualisation cell further down
# uses TRANSPOSE_HW = False. Set True only if training used transpose_hw=True.
TRANSPOSE_HW = False

# How many slices to save: center +/- K
K_AROUND_CENTER = 1  # saves 3 slices total (center-1, center, center+1)


def _overlay(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Simple overlay (no fancy colors): brighten masked region."""
    img01 = (img - img.min()) / (img.max() - img.min() + 1e-8)
    out = img01.copy()
    out[mask > 0.5] = np.clip(out[mask > 0.5] * 0.6 + 0.4, 0, 1)
    return out


def _save_gray_png(arr: np.ndarray, title: str, path: str, unit_range: bool = False) -> None:
    """Save one 2D array as a titled, axis-free grayscale PNG.

    Args:
        arr: 2D array to render.
        title: figure title.
        path: output file path.
        unit_range: if True, pin the colour range to [0, 1] instead of
            auto-scaling. Use it for masks and [0, 1] overlays so a uniform
            all-foreground mask renders white rather than black.
    """
    limits = {"vmin": 0.0, "vmax": 1.0} if unit_range else {}
    fig, ax = plt.subplots()
    ax.imshow(arr, cmap="gray", **limits)
    ax.axis("off")
    ax.set_title(title)
    fig.savefig(path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


@torch.no_grad()
def main() -> None:
    """Run the best checkpoint on slices around the center of the first val case and save PNGs.

    For each slice in ``center +/- K_AROUND_CENTER`` four PNGs are written to
    ``RUN_DIR / "viz_2d"``: the preprocessed image, the label, the prediction
    thresholded at 0.5, and the image with the predicted region brightened.

    Raises:
        FileNotFoundError: if ``CKPT_PATH`` does not exist.
    """
    if not CKPT_PATH.exists():
        raise FileNotFoundError(f"Missing checkpoint: {CKPT_PATH}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    ckpt = torch.load(CKPT_PATH, map_location=device)

    # Try to auto-detect model variant from the checkpoint's saved training args; default "base".
    variant = "base"
    extra = ckpt.get("extra", {})
    if isinstance(extra, dict):
        args = extra.get("args", {})
        if isinstance(args, dict) and "model_variant" in args:
            variant = args["model_variant"]

    print("Model variant:", variant)

    model, _meta = build_monai_unet_2d(in_channels=1, out_channels=1, variant=variant)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    # Load one validation case (full volume)
    val_ds = MicroUSCaseDataset(
        dataset_root=str(DATA_ROOT),
        splits_dir=str(SPLITS_DIR),
        split="val",
        use_case_stats=True,
        case_stats_path=str(CASE_STATS_PATH),
    )

    sample = val_ds[0]
    case_id = sample["case_id"]
    img3d_np = sample["image"].squeeze(0).numpy()  # [1, Z, Y, X] -> [Z, Y, X]
    lbl3d_np = sample["label"].squeeze(0).numpy()  # [1, Z, Y, X] -> [Z, Y, X]

    Z = img3d_np.shape[0]
    center = Z // 2

    out_dir = RUN_DIR / "viz_2d"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Case: {case_id} | volume shape: {tuple(img3d_np.shape)} | saving to: {out_dir}")

    # Slices center-K .. center+K, dropping any that fall outside the volume.
    slice_ids = [s for s in range(center - K_AROUND_CENTER, center + K_AROUND_CENTER + 1) if 0 <= s < Z]

    for s in slice_ids:
        img2d = img3d_np[s]  # [Y, X]
        lbl2d = lbl3d_np[s]  # [Y, X]

        if TRANSPOSE_HW:
            img2d = img2d.T
            lbl2d = lbl2d.T

        # Match 2D validation preprocessing: deterministic center crop/pad to TARGET_HW.
        img2d = center_crop_or_pad_2d(img2d, TARGET_HW, pad_value=0.0)
        lbl2d = center_crop_or_pad_2d(lbl2d, TARGET_HW, pad_value=0.0)

        x = torch.from_numpy(img2d).unsqueeze(0).unsqueeze(0).float().to(device)  # [1, 1, H, W]
        logits = model(x)  # [1, 1, H, W]
        prob = torch.sigmoid(logits)[0, 0].detach().cpu().numpy()
        pred = (prob > 0.5).astype(np.float32)

        base = out_dir / f"{case_id}_slice{s:03d}"
        title = f"{case_id} slice {s}"

        _save_gray_png(img2d, f"{title} image", f"{base}_img.png")
        # Label is treated as a 0/1 mask (it is compared against 0.5 elsewhere in this notebook).
        _save_gray_png(lbl2d, f"{title} label", f"{base}_lbl.png", unit_range=True)
        _save_gray_png(pred, f"{title} pred@0.5", f"{base}_pred.png", unit_range=True)
        # _overlay output lies in [0, 1]; pinning the range keeps the brightening visible.
        ov = _overlay(img2d.astype(np.float32), pred)
        _save_gray_png(ov, f"{title} overlay", f"{base}_overlay.png", unit_range=True)

        print(f"Saved: {base}_*.png", flush=True)

    print("Done.", flush=True)


if __name__ == "__main__":
    main()


In [None]:
import sys
sys.path.append("/home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg")

from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

from src.models.monai_unet_2d import build_monai_unet_2d
from src.data.dataset_cases import MicroUSCaseDataset
from src.data.transforms_2d import center_crop_or_pad_2d  # adjust import if your path differs

# -------------------------
# EDIT THESE
# -------------------------
# Project checkout that every run/data path below hangs off.
_PROJECT_ROOT = Path("/home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg")

RUN_DIR = _PROJECT_ROOT / "runs" / "20260216_173816"
CKPT_PATH = RUN_DIR / "checkpoint_best.pt"

DATA_ROOT = _PROJECT_ROOT / "dataset" / "processed" / "Dataset120_MicroUSProstate"
SPLITS_DIR = _PROJECT_ROOT / "dataset" / "splits"
CASE_STATS_PATH = DATA_ROOT / "case_stats.json"

# (H, W) each 2D slice is center-cropped/padded to; must match 2D training.
TARGET_HW = (896, 1408)

# Flip to True only if the 2D model was trained with transpose_hw=True.
TRANSPOSE_HW = False

# Slice sampling: the center slice plus this many random ones, drawn with SEED.
N_RANDOM_SLICES = 3
SEED = 0

# Probability cut-off applied to the sigmoid output to get a binary prediction.
THR = 0.5


def _overlay(img: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """
    Overlay mask in yellow on grayscale image.
    """
    # Normalize image to [0,1]
    img01 = (img - img.min()) / (img.max() - img.min() + 1e-8)

    # Convert to RGB
    rgb = np.stack([img01, img01, img01], axis=-1)  # [H,W,3]

    # Yellow color
    yellow = np.array([1.0, 1.0, 0.0])

    # Apply overlay where mask > 0.5
    mask_bool = mask > 0.5
    rgb[mask_bool] = (
        (1 - alpha) * rgb[mask_bool] + alpha * yellow
    )

    return rgb


def _dice_binary(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """Dice for binary 2D masks (0/1)."""
    pred = (pred > 0.5).astype(np.uint8)
    gt = (gt > 0.5).astype(np.uint8)
    inter = int((pred & gt).sum())
    denom = int(pred.sum()) + int(gt.sum())
    return float((2.0 * inter + eps) / (denom + eps))


@torch.no_grad()
def main() -> None:
    """Plot label/prediction panels with per-slice Dice for the first val case.

    Loads ``CKPT_PATH``, selects the center slice plus ``N_RANDOM_SLICES`` random
    slices (seeded by ``SEED``), and for each saves a 2x2 figure to
    ``RUN_DIR / "viz_2d"``. The four panels are: label, prediction, image + label
    overlay, and image + prediction overlay. The per-slice Dice appears in both
    the title and the file name.

    Raises:
        FileNotFoundError: if ``CKPT_PATH`` does not exist.
    """
    if not CKPT_PATH.exists():
        raise FileNotFoundError(f"Missing checkpoint: {CKPT_PATH}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    ckpt = torch.load(CKPT_PATH, map_location=device)

    # Try to auto-detect model variant from the checkpoint's saved training args; default "base".
    variant = "base"
    extra = ckpt.get("extra", {})
    if isinstance(extra, dict):
        args = extra.get("args", {})
        if isinstance(args, dict) and "model_variant" in args:
            variant = args["model_variant"]

    print("Model variant:", variant)

    model, _meta = build_monai_unet_2d(in_channels=1, out_channels=1, variant=variant)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    # Load one validation case (full volume)
    val_ds = MicroUSCaseDataset(
        dataset_root=str(DATA_ROOT),
        splits_dir=str(SPLITS_DIR),
        split="val",
        use_case_stats=True,
        case_stats_path=str(CASE_STATS_PATH),
    )

    sample = val_ds[0]
    case_id = sample["case_id"]
    img3d_np = sample["image"].squeeze(0).numpy()  # [1, Z, Y, X] -> [Z, Y, X]
    lbl3d_np = sample["label"].squeeze(0).numpy()  # [1, Z, Y, X] -> [Z, Y, X]

    Z = img3d_np.shape[0]
    center = Z // 2

    out_dir = RUN_DIR / "viz_2d"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Case: {case_id} | volume shape: {tuple(img3d_np.shape)} | saving to: {out_dir}")

    # Center slice plus N_RANDOM_SLICES distinct random non-center slices
    # (all remaining slices if the volume is too short).
    rng = np.random.default_rng(SEED)
    all_ids = np.arange(Z)
    non_center = all_ids[all_ids != center]
    if len(non_center) >= N_RANDOM_SLICES:
        random_ids = rng.choice(non_center, size=N_RANDOM_SLICES, replace=False).tolist()
    else:
        random_ids = non_center.tolist()

    slice_ids = sorted(set([center] + random_ids))
    print("Slices:", slice_ids)

    # Label and prediction are treated as 0/1 masks (both are compared against 0.5
    # in _dice_binary). Pin their display range so imshow does not auto-scale them:
    # otherwise a uniform all-foreground mask would render solid black.
    mask_style = {"cmap": "gray", "vmin": 0.0, "vmax": 1.0}

    for s in slice_ids:
        img2d = img3d_np[s]  # [Y, X]
        lbl2d = lbl3d_np[s]  # [Y, X]

        if TRANSPOSE_HW:
            img2d = img2d.T
            lbl2d = lbl2d.T

        # Match 2D validation preprocessing: deterministic center crop/pad to TARGET_HW.
        img2d = center_crop_or_pad_2d(img2d, TARGET_HW, pad_value=0.0)
        lbl2d = center_crop_or_pad_2d(lbl2d, TARGET_HW, pad_value=0.0)

        x = torch.from_numpy(img2d).unsqueeze(0).unsqueeze(0).float().to(device)  # [1, 1, H, W]
        logits = model(x)  # [1, 1, H, W]
        prob = torch.sigmoid(logits)[0, 0].detach().cpu().numpy()
        pred = (prob > THR).astype(np.float32)

        dice = _dice_binary(pred, lbl2d)

        img_f32 = img2d.astype(np.float32)
        # _overlay returns an RGB array, for which imshow ignores cmap, so no style kwargs.
        panels = (
            (lbl2d, "Label", mask_style),
            (pred, "Prediction", mask_style),
            (_overlay(img_f32, lbl2d), "Image + Label overlay", {}),
            (_overlay(img_f32, pred), "Image + Pred overlay", {}),
        )

        fig, axes = plt.subplots(2, 2, figsize=(10, 8))
        fig.suptitle(f"{case_id} | slice {s} | Dice={dice:.3f} (thr={THR})", fontsize=12)
        # axes.flat walks row-major: [0,0], [0,1], [1,0], [1,1].
        for ax, (arr, title, style) in zip(axes.flat, panels):
            ax.imshow(arr, **style)
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()

        out_path = out_dir / f"{case_id}_slice{s:03d}_dice{dice:.3f}.png"
        fig.savefig(out_path, dpi=200, bbox_inches="tight")
        plt.close(fig)

        print(f"Saved: {out_path}", flush=True)

    print("Done.", flush=True)


if __name__ == "__main__":
    main()


Device: cpu
Model variant: base
Case: microUS_46 | volume shape: (39, 962, 1372) | saving to: /home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg/runs/20260216_173816/viz_2d
Slices: [19, 20, 24, 31]
Saved: /home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg/runs/20260216_173816/viz_2d/microUS_46_slice019_dice0.871.png
Saved: /home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg/runs/20260216_173816/viz_2d/microUS_46_slice020_dice0.850.png
Saved: /home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg/runs/20260216_173816/viz_2d/microUS_46_slice024_dice0.940.png
Saved: /home/pirie03/projects/aip-medilab/pirie03/ProstateMicroSeg/runs/20260216_173816/viz_2d/microUS_46_slice031_dice0.798.png
Done.
