In [None]:
from pathlib import Path
import os
import pandas as pd
from utils.constants import DATA_DIR

from pathlib import Path
import json, h5py, numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import ndimage as ndi

from models.unet import UNet
from utils.train_utils import get_device
from utils.vis_utils import load_patient_metadata

In [None]:
# Paths to the trained checkpoint and its JSON metadata. Edit them here or set the
# CKPT_PATH / META_PATH environment variables. An empty value resolves to the
# current directory, which is not a file.
CKPT_PATH = Path(os.environ.get("CKPT_PATH", ""))
META_PATH = Path(os.environ.get("META_PATH", ""))


# Model Instantiation and Loading

In [None]:
DEVICE = torch.device("cpu")

# Fail fast with an actionable message. An unset path is Path(""), i.e. the
# current directory, which would otherwise surface as an OS-specific error
# (e.g. IsADirectoryError) from read_text / torch.load.
for _label, _path in (("CKPT_PATH", CKPT_PATH), ("META_PATH", META_PATH)):
    if not _path.is_file():
        raise FileNotFoundError(f"{_label} must point to an existing file; got {str(_path)!r}")

meta = json.loads(META_PATH.read_text())

# Input width: prefer the explicit list of channels recorded in the metadata,
# falling back to a scalar `in_channels` entry (default 37).
# NOTE(review): the inference cell keeps 36 of 38 channels; confirm the 37
# default is intended for metadata files lacking both keys.
channels_used = meta.get("input_channels_used")
if isinstance(channels_used, list) and channels_used:
    in_ch = len(channels_used)
else:
    in_ch = int(meta.get("in_channels", 37))

unet_depth = meta.get("depth", 5)
model = UNet(
    in_channels=in_ch,
    out_channels=meta.get("out_channels", 1),
    depth=unet_depth,
    bilinear=meta.get("bilinear", True),
    dropout_p=meta.get("dropout_p", 0.1),
)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state = torch.load(CKPT_PATH, map_location="cpu")
try:
    model.load_state_dict(state, strict=True)
except RuntimeError as e:
    # Non-strict loading leaves any parameter absent from `state` at its
    # initial value; the printed error lists which keys did not match.
    print(f"[warn] non-strict load due to: {e}")
    model.load_state_dict(state, strict=False)

model.eval().to(DEVICE)

print(f"Model ready on {DEVICE}: UNet(depth={unet_depth}, in_ch={in_ch})")


In [None]:
# Load the per-patient metadata table once. `patient_meta` is the working name;
# `df` is kept as an alias for cells that refer to it.
patient_meta = load_patient_metadata()
df = patient_meta
# Last expression is what Jupyter renders; a bare expression mid-cell shows nothing.
patient_meta.head()

# Select Test Cases Near the Median Number of Nerve Segments

In [None]:

from pathlib import Path
import h5py, numpy as np, pandas as pd

H5_PATH = Path("data.h5")  # update if needed

with h5py.File(H5_PATH, "r") as hf:
    # Case IDs may be stored as bytes or as str; decode only the bytes ones.
    case_ids = [
        cid.decode("utf-8") if isinstance(cid, bytes) else str(cid)
        for cid in hf["volumes"]["case_ids"][:]
    ]
    test_idx = hf["volumes"]["test_idx"][:]

test_cases = {case_ids[i] for i in test_idx}

df_test = patient_meta.copy()
df_test["TMA_CASE"] = df_test["TMA_CASE"].astype(str)
df_test = df_test[df_test["TMA_CASE"].isin(test_cases)].copy()
# Coerce to numeric so string-typed counts cannot break the median arithmetic
# below; unparseable values become NaN and are dropped with the missing ones.
df_test["nervesegments"] = pd.to_numeric(df_test["nervesegments"], errors="coerce")
# .copy() so adding `dist_to_median` below writes to an independent frame.
df_test = df_test.dropna(subset=["nervesegments"]).copy()

assert len(df_test) >= 3, "Fewer than 3 test rows available."

med = float(np.median(df_test["nervesegments"].values))
df_test["dist_to_median"] = (df_test["nervesegments"] - med).abs()

# Closest to the median first; ties broken by smaller count, then by case ID.
df_test_sorted = df_test.sort_values(
    by=["dist_to_median", "nervesegments", "TMA_CASE"],
    ascending=[True, True, True],
)

selected_df = df_test_sorted.head(3).drop(columns=["dist_to_median"]).reset_index(drop=True)
selected_cases = selected_df["TMA_CASE"].tolist()
# Keep a separate copy: a later selection cell may reassign `selected_cases`.
median_selected_cases = list(selected_cases)

print(f"Test rows: {len(df_test)} | median(nervesegments) = {med:.3f}")
print("Selected CASE IDs:", selected_cases)

selected_df


# Select Test Cases with High Nerve Segment Counts (overrides `selected_cases` from the median-based selection above)

In [None]:
from pathlib import Path
import h5py, numpy as np, pandas as pd

H5_PATH = Path("data.h5")

with h5py.File(H5_PATH, "r") as hf:
    # Case IDs may be stored as bytes or as str; decode only the bytes ones.
    case_ids = [
        cid.decode("utf-8") if isinstance(cid, bytes) else str(cid)
        for cid in hf["volumes"]["case_ids"][:]
    ]
    test_idx = hf["volumes"]["test_idx"][:]
test_cases = {case_ids[i] for i in test_idx}

# `patient_meta` and `df` refer to the same metadata frame.
df_test = patient_meta.copy()
df_test["TMA_CASE"] = df_test["TMA_CASE"].astype(str)
df_test = df_test[df_test["TMA_CASE"].isin(test_cases)].copy()
df_test["nervesegments"] = pd.to_numeric(df_test["nervesegments"], errors="coerce")
df_test = df_test.dropna(subset=["nervesegments"])

assert len(df_test) >= 3, "Fewer than 3 test rows available."

p90 = float(df_test["nervesegments"].quantile(0.90))

# Rank by count (descending), ties broken by case ID. When at least 3 rows reach
# the 90th percentile, the top 3 of this ranking are exactly the top 3 of that
# band; otherwise they are the 3 highest counts overall (the fallback).
sort_keys = ["nervesegments", "TMA_CASE"]
ranked = df_test.sort_values(by=sort_keys, ascending=[False, True])

selected_df = ranked.head(3).reset_index(drop=True)
selected_cases = selected_df["TMA_CASE"].tolist()

print(f"Test rows: {len(df_test)} | 90th pct = {p90:.3f}")
print("Selected CASE IDs (high-nerve):", selected_cases)

selected_df


# Extract and Prepare Raw Volumes for Selected Test Cases

In [None]:

from pathlib import Path
import h5py
import numpy as np

H5_PATH = Path("data.h5")

assert 'selected_cases' in globals(), "selected_cases not defined. Run the previous cell first."
assert len(selected_cases) >= 3, "Need at least 3 case IDs."

wanted_cases = selected_cases[:3]

with h5py.File(H5_PATH, "r") as hf:
    volumes_grp = hf["volumes"]

    # Decode bytes IDs; if the stored IDs have no .decode, stringify them all.
    stored_ids = volumes_grp["case_ids"][:]
    try:
        case_ids = [raw_id.decode("utf-8") for raw_id in stored_ids]
    except AttributeError:
        case_ids = [str(raw_id) for raw_id in stored_ids]
    case_to_idx = dict(zip(case_ids, range(len(case_ids))))

    # Read each requested volume while the file is still open; indexing the
    # dataset copies that slice into memory as a NumPy array.
    data_ds = volumes_grp["data"]  # assumed shape: (N, H, W, C)
    raw_volumes = []
    for cid in wanted_cases:
        idx = case_to_idx.get(cid)
        assert idx is not None, f"Case ID not found in HDF5: {cid}"
        raw_volumes.append(data_ds[idx])

# sanity check
for cid, vol in zip(wanted_cases, raw_volumes):
    print(f"{cid}: shape={vol.shape}, dtype={vol.dtype}")



# Prepare Peripherin and Model Input Crops from Raw Volumes


In [None]:
import numpy as np

PERIPHERIN_CHANNEL = 31
DROP_CHANNELS = [14, 31]

assert 'raw_volumes' in globals(), "`raw_volumes` not found. Run the previous cell first."
assert 'selected_cases' in globals(), "`selected_cases` not found."
assert len(raw_volumes) >= 3, "Need at least 3 volumes in `raw_volumes`."

# For each volume: trim a one-pixel border on both spatial axes, then split off
# the peripherin channel (display target) and the model inputs (all channels
# except DROP_CHANNELS, in ascending order), both as float32 copies.
prph_crops, input_crops = [], []
for vol in raw_volumes[:3]:
    assert vol.ndim == 3, f"Expected (H, W, C), got shape {vol.shape}"
    H, W, C = vol.shape
    assert PERIPHERIN_CHANNEL < C, f"PERIPHERIN_CHANNEL={PERIPHERIN_CHANNEL} out of range for C={C}"
    assert all(ch < C for ch in DROP_CHANNELS), f"DROP_CHANNELS out of range for C={C}"

    keep = [c for c in range(C) if c not in DROP_CHANNELS]
    inner = vol[1:H - 1, 1:W - 1, :]

    prph_crops.append(inner[:, :, PERIPHERIN_CHANNEL].astype(np.float32).copy())
    input_crops.append(inner[:, :, keep].astype(np.float32).copy())

for cid, prph_arr, inp_arr in zip(selected_cases[:3], prph_crops, input_crops):
    print(f"{cid}: peripherin {prph_arr.shape}, inputs {inp_arr.shape} (dropped {DROP_CHANNELS})")



# Run Model Inference and Generate Logit Maps for Selected Crops

In [None]:

import numpy as np, torch, h5py
from pathlib import Path

assert 'model' in globals(), "Run your model-instantiation cell first."
assert 'input_crops' in globals() and len(input_crops) >= 3, "Need 3 crops in `input_crops`."

H5_PATH = Path("data.h5")
DROP_CHANNELS = [14, 31]

# Per-channel normalization statistics for the full (undropped) channel stack.
with h5py.File(H5_PATH, "r") as hf:
    all_means = hf["statistics"]["means"][:]
    all_stds = hf["statistics"]["stds"][:]
assert len(all_means) == len(all_stds), "means/stds length mismatch in HDF5 statistics."

# Derive the channel count from the stored statistics instead of hard-coding it;
# each crop's channel count is checked against `keep` below.
TOTAL_CHANNELS = len(all_means)
keep = [i for i in range(TOTAL_CHANNELS) if i not in DROP_CHANNELS]  # ascending kept-channel order
means = all_means[keep].astype(np.float32)
stds = all_stds[keep].astype(np.float32)

# Loop-invariant broadcast shapes; the epsilon guards zero-variance channels.
mu = means.reshape(1, 1, -1)
sigma = stds.reshape(1, 1, -1) + 1e-8

logit_maps = []
model.eval()
device = next(model.parameters()).device

for arr in input_crops[:3]:           # each arr: (H, W, C_kept)
    assert arr.shape[-1] == len(keep), (
        f"Crop has {arr.shape[-1]} channels but the statistics imply {len(keep)} kept channels."
    )
    x = ((arr - mu) / sigma).astype(np.float32)
    # (H, W, C) -> (1, C, H, W) batch for the network.
    xb = torch.from_numpy(np.transpose(x, (2, 0, 1))).unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(xb)[0, 0].cpu().numpy().astype(np.float32)  # (H, W)
    logit_maps.append(logits)

print("Logit map shapes:", [m.shape for m in logit_maps])


# Visualize Peripherin Channel and Model Predictions for Selected Test Cases

In [None]:

from pathlib import Path
import json, numpy as np
import matplotlib.pyplot as plt

assert 'META_PATH' in globals(), "META_PATH not set (from model cell)."
assert 'selected_cases' in globals() and len(selected_cases) >= 3, "Need 3 case IDs in `selected_cases`."
assert 'prph_crops' in globals() and len(prph_crops) >= 3, "`prph_crops` missing or too small."
assert 'logit_maps' in globals() and len(logit_maps) >= 3, "`logit_maps` missing or too small."

# NOTE(review): hard-coded rather than read from META_PATH; confirm it matches
# the validation threshold recorded for this checkpoint.
thr = 0.99995
# `:g` keeps the significant digits; a fixed `.3f` prints 0.99995 as "1.000".
print(f"Using best validation threshold: {thr:g}")

def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), computed via logaddexp so large
    negative inputs do not overflow np.exp (and raise a RuntimeWarning)."""
    return np.exp(-np.logaddexp(0.0, -x))

run_through = list(zip(selected_cases[:3], prph_crops[:3], logit_maps[:3]))
# Cases from CHOSEN_IDX onward are plotted (just the last one for 3 cases);
# `chosen` is reused by the cells below.
CHOSEN_IDX = 2
chosen = run_through[CHOSEN_IDX]
for cid, prph, logits in run_through[CHOSEN_IDX:]:
    probs = sigmoid(logits)
    pred_bin = (probs >= thr).astype(np.uint8)

    # Cap the display range at the 99.5th percentile of peripherin intensity.
    vmax = np.percentile(prph, 99.5)

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

    ax0.imshow(prph, cmap="gray", vmin=0, vmax=vmax)
    ax0.set_title(f"Peripherin (raw) — {cid}")
    ax0.axis("off")

    ax1.imshow(pred_bin, vmin=0, vmax=1)
    ax1.set_title(f"Model ≥ {thr:g} (binary)")
    ax1.axis("off")

    plt.show()


In [None]:
# `chosen` is a (case_id, peripherin crop, logit map) tuple; show the ID and the
# array shapes rather than dumping both full arrays.
chosen[0], chosen[1].shape, chosen[2].shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_prph_and_predictions_grid(
    prph, logits, *,
    thresholds=None,
    intervals=None,
    case_id=None,
    prph_binary=False,
    prph_eps=None,
    origin="upper"
):
    """
    Show the peripherin channel beside three probability-band masks in a 2x2 grid.

    Top-left: peripherin (raw grayscale or binarized). Each of the other three
    panels shows the pixels whose sigmoid(logits) lies in one closed band
    [low, high].

    prph:   (H, W) array        — raw peripherin channel (float or integer dtype)
    logits: (H, W) float array  — model logits (before sigmoid)

    Provide exactly one of:
      - thresholds: sequence of 3 floats, each used as [t, 1.0]
      - intervals:  sequence of 3 (low, high) tuples, with 0<=low<high<=1

    case_id:     optional label appended to the peripherin panel title.
    prph_binary: if True, display PRPH as (prph > eps); else show raw grayscale
                 with the display range capped at the 99.5th percentile.
    prph_eps:    optional binarization threshold (default: machine epsilon for
                 floating dtypes, 0 for integer/bool dtypes).
    origin:      passed to imshow; 'upper' means y increases downward.

    Raises AssertionError when the threshold/interval arguments are invalid.
    Returns None; the figure is displayed with plt.show().
    """
    assert (thresholds is None) ^ (intervals is None), "Provide thresholds OR intervals (not both)."

    prph = np.asarray(prph)
    logits = np.asarray(logits)

    # sigmoid(x) = exp(-log(1 + exp(-x))). logaddexp stays finite for large |x|,
    # whereas 1 / (1 + np.exp(-x)) overflows (RuntimeWarning) for large negative x.
    probs = np.exp(-np.logaddexp(0.0, -logits))

    if intervals is None:
        assert len(thresholds) == 3, "Need exactly 3 thresholds."
        intervals = [(float(t), 1.0) for t in thresholds]
    else:
        assert len(intervals) == 3, "Need exactly 3 intervals."
        intervals = [(float(lo), float(hi)) for (lo, hi) in intervals]

    masks, labels = [], []
    for (lo, hi) in intervals:
        assert 0.0 <= lo < hi <= 1.0, f"Bad interval: {(lo,hi)}"
        masks.append(((probs >= lo) & (probs <= hi)).astype(np.uint8))
        labels.append(f"[{lo:.6f}, {hi:.6f}]")

    fig, axes = plt.subplots(2, 2, figsize=(10, 10), constrained_layout=True)
    ax00, ax01 = axes[0]
    ax10, ax11 = axes[1]

    if prph_binary:
        if prph_eps is not None:
            eps = float(prph_eps)
        elif np.issubdtype(prph.dtype, np.floating):
            eps = np.finfo(prph.dtype).eps
        else:
            # np.finfo rejects integer/bool dtypes; treat any nonzero value as signal.
            eps = 0
        prph_disp = (prph > eps).astype(np.uint8)
        ax00.imshow(prph_disp, cmap="gray", vmin=0, vmax=1, interpolation="nearest", origin=origin)
        title_left = "Peripherin (binary)"
    else:
        vmax = float(np.percentile(prph, 99.5))
        ax00.imshow(prph, cmap="gray", vmin=0, vmax=vmax, origin=origin)
        title_left = "Peripherin (raw)"
    # Previously both branches of this conditional produced the same string, so
    # case_id was silently ignored.
    ax00.set_title(f"{title_left} — {case_id}" if case_id else title_left)
    ax00.axis("off")

    for ax, m, lab in zip([ax01, ax10, ax11], masks, labels):
        ax.imshow(m, vmin=0, vmax=1, interpolation="nearest", origin=origin)
        ax.set_title(f"Prediction ∈ {lab}")
        ax.axis("off")

    plt.show()


case_id, prph, logits = chosen
GRID_THRESHOLDS = [0.9999, 0.99999, 0.999999]
# Same three probability bands twice: raw peripherin first, then peripherin
# binarized at > 1 (prph_eps only matters for the binary view).
for show_binary in (False, True):
    plot_prph_and_predictions_grid(
        prph,
        logits,
        thresholds=GRID_THRESHOLDS,
        case_id=case_id,
        prph_binary=show_binary,
        prph_eps=1,
    )


In [None]:
import numpy as np

def crop_pair(prph, logits, *, crop=None, center_size=None):
    """
    Cut the same rectangular window out of a peripherin image and its logit map.

    Exactly one selector must be given:
      - crop=(y0, y1, x0, x1)         half-open box: rows y0:y1, cols x0:x1
      - center_size=(cy, cx, h, w)    h x w box centred on (cy, cx)

    The box is clamped to the image bounds and always keeps at least one row
    and one column.

    Returns:
      (prph_c, logits_c) — the two inputs sliced identically.
    """
    assert (crop is None) ^ (center_size is None), "Provide crop OR center_size (not both)."
    H, W = prph.shape
    assert logits.shape == prph.shape, "prph/logits shape mismatch."

    if crop is None:
        cy, cx, h, w = (int(v) for v in center_size)
        top, left = cy - h // 2, cx - w // 2
        bottom, right = top + h, left + w
    else:
        top, bottom, left, right = (int(v) for v in crop)

    def _clamp_span(lo, hi, size):
        # lo into [0, size-1]; hi into [lo+1, size] so the span is never empty.
        lo = max(0, min(lo, size - 1))
        hi = max(lo + 1, min(hi, size))
        return lo, hi

    top, bottom = _clamp_span(top, bottom, H)
    left, right = _clamp_span(left, right, W)

    window = (slice(top, bottom), slice(left, right))
    return prph[window], logits[window]


# Zoomed region (rows 600:800, cols 150:350) of the case unpacked from `chosen`.
CROP_BOX = (600, 800, 150, 350)
prph_c, logits_c = crop_pair(prph, logits, crop=CROP_BOX)
# Label with the real case ID rather than a "CASE123" placeholder.
plot_prph_and_predictions_grid(prph_c, logits_c, thresholds=[0.3, 0.99999, 0.999999], case_id=case_id, prph_binary=False, prph_eps=1)
plot_prph_and_predictions_grid(prph_c, logits_c, thresholds=[0.9999, 0.99999, 0.999999], case_id=case_id, prph_binary=True, prph_eps=1)
