# Initiation

1. Add the project folder to your VS Code workspace.  
2. Download the dataset:
   - Go to [this page](https://lhncbc.nlm.nih.gov/LHC-research/LHC-projects/image-processing/malaria-datasheet.html)
   - Download **"NLM-Falciparum&Uninfected-Thin-193Patients"**
   - Unzip the downloaded archive into a convenient location (for example, `C:\Projects\cell images`)

3. Open a terminal in VS Code (PowerShell recommended) and run the following commands in order:

> **Note:** Replace the path in the first command with the location where you unzipped `cell images`.

```powershell
# 1) Go to the project data folder
cd "C:\[path to where you unzipped 'cell images']"

# 2) Create a Python 3.11 virtual environment
py -3.11 -m venv .venv

# 3) Activate the environment (PowerShell)
.\.venv\Scripts\Activate.ps1

# 4) Upgrade pip
python -m pip install --upgrade pip

# 5) Install PyTorch (CUDA 12.1 build)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 6) Install the packages
pip install numpy pandas pillow scikit-learn matplotlib tqdm jupyter ipykernel fastapi uvicorn==0.30.* starlette python-multipart

# 7) Register the Jupyter kernel
python -m ipykernel install --user --name ddls-local --display-name "Python 3.11 (.venv) DDLS"
```

4. In VS Code, select the Jupyter kernel: **Python 3.11 (.venv) DDLS**.

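5. In the notebook, replace every instance of `[path_placeholder]` with the folder that directly contains the unzipped `cell_images` directory; the code builds the dataset path as `PROJECT_ROOT / "cell_images"`.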

# Environment and GPU Verification

This cell imports the core libraries (`sys`, `platform`, `torch`, `torchvision`) and prints the current Python executable path, Python version, and library versions for Torch and TorchVision.  
It then checks whether CUDA is available on the system.  
If a GPU is detected, the code prints the GPU device name and performs a small matrix multiplication on the GPU to confirm correct CUDA functionality.  
If no GPU is available, a message is printed advising that the correct virtual environment and GPU drivers should be verified.  
This cell ensures the runtime environment is properly configured before any computational steps.

In [None]:
import sys, platform, torch, torchvision
print("Python exe:", sys.executable)
print("Python ver:", platform.python_version())
print("Torch :", torch.__version__)
print("TV    :", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
    # tiny CUDA sanity op
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.mm(x, x.t())
    print("CUDA matmul OK, shape:", y.shape)
else:
    print("No CUDA visible. Make sure this kernel is the .venv and drivers are up to date.")

# Step 1 — Environment Setup and Global Seeding

This cell defines the local directory structure for the project and establishes reproducibility through deterministic random seeds.

- **Directory setup:**  
  The project root, dataset directory, and results directory are specified using `pathlib.Path`.  
  The results directory is created if it does not already exist.

- **Reproducibility:**  
  The function `seed_everything()` sets fixed seeds for the `random`, `numpy`, and `torch` libraries.  
  It also sets cuDNN to deterministic mode and disables its auto-tuner (`benchmark = False`) so that the chosen convolution algorithms do not vary between runs.  
  These settings may slightly reduce computational speed; they make data splits and model initialisations repeatable and greatly reduce run-to-run variation in training, although some GPU operations can remain non-deterministic.

- **Output:**  
  After setting the seed, the cell prints the global seed value and the absolute paths to the project root, dataset, and results directories.  
  This confirms that the working environment is correctly initialised before proceeding with data exploration or model training.
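
To illustrate what a fixed seed provides, the minimal sketch below re-seeds Python's built-in `random` module and checks that the same draws come back. It uses only the standard library; the helper `draw_sample` is hypothetical and not part of the notebook. The same principle applies to the NumPy and PyTorch generators seeded in `seed_everything()`.

```python
import random

def draw_sample(seed: int, n: int = 5) -> list:
    """Seed the stdlib RNG, then draw n pseudo-random integers."""
    random.seed(seed)
    return [random.randint(0, 999) for _ in range(n)]

first = draw_sample(42)
second = draw_sample(42)   # same seed, so the sequence repeats

print("Repeatable:", first == second)
```

Re-running a cell without re-seeding continues the existing random stream, so results can still differ between partial re-runs of the notebook.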

In [None]:
# Step 1 — Local environment & global seed
import os, random, numpy as np
from pathlib import Path

# ---- Local roots ----
PROJECT_ROOT = Path(r"[path_placeholder]")
DATA_DIR     = PROJECT_ROOT / "cell_images"
RESULTS_DIR  = PROJECT_ROOT / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# ---- Reproducibility ----
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        torch = None  # seed only `random` and NumPy if PyTorch is not installed
    if torch is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels (set with care; can reduce speed a bit)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"Global seed set to {seed}")

seed_everything(42)

print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset dir:  {DATA_DIR}")
print(f"Results dir:  {RESULTS_DIR}")

# Step 2 — Dataset Structure Check and Class Counts

This cell validates the expected on-disk layout and produces basic counts per class.

- **Expected layout:**  
  Two class subfolders under `DATA_DIR`: `Parasitized` and `Uninfected`.

- **Discovery and warnings:**  
  Lists present subfolders; prints a warning if any expected class folder is missing.

- **File indexing:**  
  Scans each class folder for image files with extensions `{.png, .jpg, .jpeg}`; builds an `(image_path, class)` index and tallies per-class counts.

- **Output:**  
  Prints counts for each class and the total number of images.  
  If no images are found, raises a `RuntimeError` to flag a misconfigured `DATA_DIR`.

This check verifies that the dataset is complete and correctly arranged before downstream processing.
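
The layout check can be tried without the real dataset. The sketch below builds a throwaway folder tree and counts image files per class with the same extension filter; the helper `count_images` and the file names are made up for illustration.

```python
import tempfile
from pathlib import Path

IMAGE_EXTS = {".png", ".jpg", ".jpeg"}

def count_images(data_dir: Path, classes: set) -> dict:
    """Return {class_name: number_of_image_files}; missing folders count as 0."""
    counts = {}
    for cls in sorted(classes):
        cls_dir = data_dir / cls
        if not cls_dir.is_dir():
            counts[cls] = 0
            continue
        counts[cls] = sum(
            1 for p in cls_dir.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        )
    return counts

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "Parasitized").mkdir()
    # Two image files and one non-image file that must be ignored
    for name in ("a.png", "b.JPG", "notes.txt"):
        (root / "Parasitized" / name).touch()
    # "Uninfected" is deliberately left out to mimic an incomplete download
    demo_counts = count_images(root, {"Parasitized", "Uninfected"})

print(demo_counts)
```

A count of zero for one class while the other is populated usually points to a wrong `DATA_DIR` or a partially extracted archive.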

In [None]:
# Step 2 — Sanity-check the dataset layout and counts

# Expected classes
expected_classes = {"Parasitized", "Uninfected"}

# Discover subfolders
present = {p.name for p in DATA_DIR.iterdir() if p.is_dir()}
print("Subfolders found:", present)

missing = expected_classes - present
if missing:
    print("WARNING: missing class folders:", missing)

# Collect files
def list_images(cls_dir: Path):
    exts = {".png", ".jpg", ".jpeg"}
    return sorted([p for p in cls_dir.iterdir() if p.suffix.lower() in exts and p.is_file()])

index = []
counts = {}

for cls in sorted(expected_classes):
    cls_dir = DATA_DIR / cls
    if not cls_dir.exists():
        counts[cls] = 0
        continue
    files = list_images(cls_dir)
    counts[cls] = len(files)
    index.extend([(str(p), cls) for p in files])

# Print counts
total = sum(counts.values())
print("\nImage counts:")
for k, v in counts.items():
    print(f"  {k:12s}: {v:6d}")
print(f"  {'TOTAL':12s}: {total:6d}")

# Fail fast if no images were found at all (DATA_DIR is probably wrong)
if total == 0:
    raise RuntimeError("No images found. Check DATA_DIR.")

# Step 3 — Build a File Index with Basic Image Metadata

This cell scans all discovered image paths and creates a tabular index with minimal metadata.

- **Process:**  
  Iterates over `(path, class)` pairs; opens each image with Pillow; records filename, extension, width, and height.  
  Any unreadable file is caught; its path and error message are stored for logging.

- **Outputs:**  
  - `results/index_raw.csv`: one row per image with columns `path, class, filename, ext, width, height`.  
  - `results/bad_files.txt` (only if needed): paths of images that failed to open, with the exception text.

- **Why this step:**  
  Provides a single source of truth for downstream exploration, splitting, and preprocessing; surfaces corrupt or incompatible files early.

- **Notes:**  
  A short progress bar (`tqdm`) shows indexing status; a final summary prints the number of indexed images and whether any files were skipped.

In [None]:
# Step 3 — Build a full index with metadata

import pandas as pd
from PIL import Image
from tqdm import tqdm

rows = []
bad = []

for path_str, cls in tqdm(index, desc="Indexing images"):
    p = Path(path_str)
    try:
        with Image.open(p) as im:
            w, h = im.size
            ext = p.suffix.lower()
    except Exception as e:
        bad.append((path_str, str(e)))
        continue

    rows.append({
        "path": path_str,
        "class": cls,
        "filename": p.name,
        "ext": ext,
        "width": w,
        "height": h,
    })

df = pd.DataFrame(rows)
csv_path = RESULTS_DIR / "index_raw.csv"
df.to_csv(csv_path, index=False)
print(f"Indexed {len(df)} images → {csv_path}")

if bad:
    bad_path = RESULTS_DIR / "bad_files.txt"
    with open(bad_path, "w", encoding="utf-8") as fh:
        for path_str, msg in bad:
            fh.write(f"{path_str}\t{msg}\n")
    print(f"Unreadable files: {len(bad)} (logged to {bad_path})")
else:
    print("All images opened successfully.")

# Step 4 — Image Size Summary and Outlier Inspection

This cell summarises image dimensions and lists extreme cases.

- **Statistics:**  
  Computes width and height percentiles (min, p10, median, p90, max) from the in-memory index `df` (the rows saved to `index_raw.csv`). This guides padding and resizing choices.

- **Outlier tables:**  
  Shows the ten smallest and ten largest images by width; repeats for height. IPython `display` renders compact tables with `filename, class, width, height`.

- **Why this matters:**  
  Size spread affects padding artefacts and interpolation; extreme aspect ratios or very small crops may require special handling or exclusion.

Review the extremes before finalising the canonical transform.
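
The `pct` helper in the cell relies on `np.percentile`, whose default method interpolates linearly between neighbouring ranks. The plain-Python sketch below mirrors that calculation on made-up widths; `percentile_linear` is a hypothetical name, and the notebook itself uses NumPy.

```python
def percentile_linear(values, q):
    """Percentile with linear interpolation between the two closest ranks."""
    data = sorted(values)
    if not data:
        raise ValueError("values must not be empty")
    pos = (len(data) - 1) * q / 100.0
    lo = int(pos)                      # index of the lower neighbour
    hi = min(lo + 1, len(data) - 1)    # index of the upper neighbour
    frac = pos - lo
    return data[lo] + (data[hi] - data[lo]) * frac

widths = [100, 120, 130, 140, 200]     # made-up image widths in pixels
p10 = percentile_linear(widths, 10)
p50 = percentile_linear(widths, 50)
print(p10, int(p10), p50)
```

Because the cell wraps each percentile in `int()`, fractional values are truncated towards zero in the printed summary.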

In [None]:
# Step 4 — Image size statistics and extremes

from IPython.display import display

def pct(s, q):
    return np.percentile(s.to_numpy(), q)

w = df["width"]
h = df["height"]

print("Width stats:")
print({
    "min": int(w.min()), "p10": int(pct(w,10)), "p50": int(pct(w,50)),
    "p90": int(pct(w,90)), "max": int(w.max())
})
print("Height stats:")
print({
    "min": int(h.min()), "p10": int(pct(h,10)), "p50": int(pct(h,50)),
    "p90": int(pct(h,90)), "max": int(h.max())
})

# Smallest widths
print("\nTop 10 smallest widths:")
display(df.sort_values("width").head(10)[["filename","class","width","height"]])

# Smallest heights
print("\nTop 10 smallest heights:")
display(df.sort_values("height").head(10)[["filename","class","width","height"]])

# Largest widths
print("\nTop 10 largest widths:")
display(df.sort_values("width", ascending=False).head(10)[["filename","class","width","height"]])

# Largest heights
print("\nTop 10 largest heights:")
display(df.sort_values("height", ascending=False).head(10)[["filename","class","width","height"]])

# Step 5 — Derive Group Keys to Prevent Split Leakage

This cell extracts a slide/session **group key** from each filename by removing the trailing pattern `_cell_<digits>.<ext>`. The key is used later to build **group-wise** train/val/test splits so that images from the same slide do not appear in multiple splits.

- **Group derivation:**  
  Applies a regex to `filename` to strip the per-cell suffix and produce `group`.

- **Group statistics:**  
  Reports the number of groups and the distribution of group sizes (min, p25, p50, p75, max).  
  Lists the largest groups overall, then the top groups per class to spot imbalances or anomalies.

- **Output:**  
  Saves the augmented index with a `group` column to `results/index_with_groups.csv`.

This step guards against information leakage by keeping slide-level batches intact during cross-validation and final splitting.
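
The effect of the regex is easiest to see on a few names. The filenames below are hypothetical and only follow the `_cell_<digits>.<ext>` pattern described above; the function body is the same as in the cell.

```python
import re
from collections import Counter

def derive_group_key(fname: str) -> str:
    """Strip a trailing `_cell_<digits>.<ext>` suffix; other names pass through unchanged."""
    return re.sub(r"_cell_\d+\.(png|jpg|jpeg)$", "", fname, flags=re.IGNORECASE)

# Hypothetical filenames, for illustration only
names = [
    "slideA_IMG_0001_cell_12.png",
    "slideA_IMG_0001_cell_13.PNG",
    "slideB_IMG_0007_cell_5.jpg",
    "unexpected_name.png",          # no suffix match: the key equals the filename
]
keys = [derive_group_key(n) for n in names]
group_sizes = Counter(keys)
print(group_sizes)
```

A filename that does not match the pattern becomes its own single-image group, so a large number of size-one groups suggests the pattern needs revisiting.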

In [None]:
# Step 5 — Derive a leakage-safe group key per image

from IPython.display import display
import re

def derive_group_key(fname: str) -> str:
    # Remove trailing `_cell_<digits>.<ext>` (case-insensitive ext)
    return re.sub(r"_cell_\d+\.(png|jpg|jpeg)$", "", fname, flags=re.IGNORECASE)

df["group"] = df["filename"].apply(derive_group_key)

# Group stats
gsize = df.groupby("group").size().sort_values(ascending=False)
print(f"Total groups: {gsize.shape[0]}")
print({
    "min": int(gsize.min()), "p25": int(pct(gsize,25)), "p50": int(pct(gsize,50)),
    "p75": int(pct(gsize,75)), "max": int(gsize.max())
})

print("\nLargest groups (all classes mixed):")
display(gsize.head(10).to_frame(name="n_images"))

# Group-by-class top examples
print("\nTop groups by class (head=5 per class):")
for cls in sorted(df["class"].unique()):
    tmp = df[df["class"] == cls].groupby("group").size().sort_values(ascending=False).head(5)
    print(f"\nClass: {cls}")
    display(tmp.to_frame(name="n_images"))

# Save
csv_groups = RESULTS_DIR / "index_with_groups.csv"
df.to_csv(csv_groups, index=False)
print(f"Saved with groups → {csv_groups}")

# Step 6 — Probe border background colours on a stratified sample

Purpose:
- Verify that outer-border pixels are flat and dark; this supports the plan to pad-to-square with a constant black background.
- Check this **per class** to rule out class-specific acquisition artefacts.

Method:
- Randomly sample N images per class (default N=2000).
- For each image, extract a 2-pixel ring at the four borders; compute the modal RGB triplet and its fraction of the ring.
- Classify a mode as “near black” if all RGB components ≤ 10 (tolerant threshold for JPEG/PNG quantisation).

Outputs:
- A table with filename, class, modal RGB, and fraction represented.
- A per-class summary: count and proportion of images with near-black modes.

Interpretation:
- If ≥ ~90% show near-black modes per class, constant-black padding is appropriate.
- If many images have non-black modes, consider a fallback (e.g., reflect padding) for those specific cases.
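
The ring probe can be checked by hand on a tiny synthetic image. The sketch below is a plain-Python stand-in (nested lists instead of NumPy arrays, ring width 1 instead of 2) and is not the notebook's implementation; like the notebook, it concatenates the four border strips, so corner pixels are counted twice.

```python
from collections import Counter

def ring_mode(pixels, ring=1):
    """Modal RGB of the outer ring and its share of ring pixels.

    `pixels` is a list of rows, each row a list of (r, g, b) tuples.
    """
    h, w = len(pixels), len(pixels[0])
    top = [px for row in pixels[:ring] for px in row]
    bottom = [px for row in pixels[h - ring:] for px in row]
    left = [px for row in pixels for px in row[:ring]]
    right = [px for row in pixels for px in row[w - ring:]]
    strip = top + bottom + left + right
    mode_rgb, n = Counter(strip).most_common(1)[0]
    return mode_rgb, n / len(strip)

def is_near_black(rgb, thr=10):
    """True if every channel is at or below the threshold."""
    return all(c <= thr for c in rgb)

# Synthetic 4x4 image: dark border, pink 2x2 centre, one bright corner pixel
BLACK, PINK, WHITE = (0, 0, 0), (200, 120, 150), (255, 255, 255)
img = [[BLACK] * 4 for _ in range(4)]
img[1][1] = img[1][2] = img[2][1] = img[2][2] = PINK
img[0][3] = WHITE  # top-right corner, appears in both the top and right strips

mode_rgb, frac = ring_mode(img, ring=1)
print(mode_rgb, round(frac, 3), is_near_black(mode_rgb))
```

Here 14 of the 16 strip entries are black, so the modal fraction is 0.875 and the image counts as having a near-black border.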

In [None]:
# Step 6 — Probe border background colours on a stratified sample

from PIL import Image
from collections import Counter

# Parameters
N_PER_CLASS = 2000
RING_PX = 2
NEAR_BLACK_THR = 10  # per-channel threshold for "near black"

rng = 123

# Stratified sample: up to N_PER_CLASS images per class (`rng` is the integer seed for random_state)
sample_rows = (
    df.groupby("class", group_keys=False)
      .apply(lambda g: g.sample(min(N_PER_CLASS, len(g)), random_state=rng))
      .reset_index(drop=True)
)

def border_ring_mode_rgb(path: Path, ring: int = RING_PX, max_pixels: int = 20000):
    """Return the modal RGB of the outer border ring and its share of ring pixels."""
    with Image.open(path) as im:
        im = im.convert("RGB")
        arr = np.array(im)
    h, w, _ = arr.shape
    top = arr[0:ring, :, :]
    bottom = arr[h-ring:h, :, :]
    left = arr[:, 0:ring, :]
    right = arr[:, w-ring:w, :]
    ring_pixels = np.concatenate([
        top.reshape(-1,3), bottom.reshape(-1,3),
        left.reshape(-1,3), right.reshape(-1,3)
    ], axis=0)
    if ring_pixels.shape[0] > max_pixels:
        idx = np.random.choice(ring_pixels.shape[0], size=max_pixels, replace=False)
        ring_pixels = ring_pixels[idx]
    tuples = [tuple(int(x) for x in v) for v in ring_pixels]
    cnt = Counter(tuples)
    mode_rgb, mode_n = cnt.most_common(1)[0]
    frac = mode_n / len(tuples)
    return mode_rgb, frac

# Run probe
records = []
for _, r in sample_rows.iterrows():
    rgb, frac = border_ring_mode_rgb(Path(r["path"]), ring=RING_PX)
    near_black = all(c <= NEAR_BLACK_THR for c in rgb)
    records.append({
        "filename": r["filename"],
        "class": r["class"],
        "mode_rgb": rgb,
        "fraction": round(frac, 3),
        "near_black": near_black
    })

probe_df = pd.DataFrame(records)

# Display head and per-class summary
display(probe_df.head(10))

summary = (probe_df.groupby("class")
           .agg(n=("filename","count"),
                n_near_black=("near_black","sum"),
                prop_near_black=("near_black","mean")))
display(summary)

# Optional: histogram of modal fraction per class
import matplotlib.pyplot as plt

for cls, sub in probe_df.groupby("class"):
    plt.figure()
    plt.hist(sub["fraction"], bins=20)
    plt.title(f"Border mode fraction — {cls}")
    plt.xlabel("Fraction of ring pixels at modal RGB")
    plt.ylabel("Count")
    plt.show()

# Optional: list any non–near-black modes
non_nb = probe_df[~probe_df["near_black"]]
if not non_nb.empty:
    print("Non–near-black border modes found:")
    display(non_nb.sort_values("fraction", ascending=False).head(20))
else:
    print("All sampled images have near-black border modes.")

# Canonical preprocessing (single source of truth)

This cell defines `pad_resize_canonical`, which **all** later steps must use:

- Convert to RGB; inspect a 2-pixel border ring to get the modal RGB and its fraction.  
- If the modal colour is near-black (each channel ≤ 10) **and** covers ≥ 60% of the ring, pad with that colour as a constant fill; otherwise, use reflect padding.  
- Centre-pad each image to a square; bicubic resize to 128×128.  

This matches the approved plan and keeps preprocessing identical for training, statistics, evaluation, TTA, OOD, Grad-CAM, and the app. Reuse it in later cells rather than defining new variants.
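
The two decisions in this transform, how much to pad on each side and which pad mode to use, reduce to a few lines of arithmetic. The sketch below isolates them in plain Python; `square_padding` and `choose_pad_mode` are hypothetical helper names, and the thresholds are the ones stated above.

```python
def square_padding(w: int, h: int):
    """Return (left, top, right, bottom) pads that centre a w x h image in a square."""
    side = max(w, h)
    left = (side - w) // 2
    right = side - w - left      # any odd pixel goes to the right
    top = (side - h) // 2
    bottom = side - h - top      # any odd pixel goes to the bottom
    return left, top, right, bottom

def choose_pad_mode(mode_rgb, frac, thr=10, min_frac=0.60):
    """Constant fill only when a near-black colour dominates the border ring."""
    near_black = all(c <= thr for c in mode_rgb)
    return "constant" if (near_black and frac >= min_frac) else "reflect"

print(square_padding(100, 131))                  # narrow image: pad left and right
print(choose_pad_mode((0, 0, 0), 0.95))          # dark, dominant border
print(choose_pad_mode((0, 0, 0), 0.40))          # dark but not dominant
print(choose_pad_mode((180, 140, 160), 0.95))    # dominant but not dark
```

A 100×131 image is padded to 131×131 with 15 pixels on the left and 16 on the right; only the first of the three border cases selects the constant fill.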

In [None]:
# Canonical preprocessing — single source of truth
# Pad to square using a near-black heuristic; then bicubic resize to 128×128.
# Use this everywhere: training, stats, evaluation, TTA, OOD, Grad-CAM, and the app.

from PIL import Image, ImageOps
import numpy as np
from collections import Counter

# Constants (match the approved plan)
NEAR_BLACK_THR = 10      # each RGB channel ≤ 10 → "near black"
MIN_MODE_FRAC  = 0.60    # modal ring colour must cover ≥60% of ring pixels

# Pillow compatibility (older/newer)
try:
    RESAMPLE_BICUBIC = Image.Resampling.BICUBIC  # Pillow ≥ 9.1
except AttributeError:
    RESAMPLE_BICUBIC = Image.BICUBIC             # older Pillow

def _ring_mode_from_array(arr: np.ndarray, ring: int = 2, max_pixels: int = 20000):
    """
    Modal RGB on the outer 'ring' (top/bottom/left/right strips).
    Uses deterministic sub-sampling by stride if too many pixels.
    """
    h, w, _ = arr.shape
    top    = arr[0:ring, :, :]
    bottom = arr[h-ring:h, :, :]
    left   = arr[:, 0:ring, :]
    right  = arr[:, w-ring:w, :]
    ring_pixels = np.concatenate(
        [top.reshape(-1,3), bottom.reshape(-1,3), left.reshape(-1,3), right.reshape(-1,3)],
        axis=0
    )
    if ring_pixels.shape[0] > max_pixels:
        step = int(np.ceil(ring_pixels.shape[0] / max_pixels))
        ring_pixels = ring_pixels[::step]
    tuples = [tuple(int(x) for x in v) for v in ring_pixels]
    cnt = Counter(tuples)
    mode_rgb, n = cnt.most_common(1)[0]
    frac = n / max(1, len(tuples))
    return mode_rgb, frac

def pad_resize_canonical(img: Image.Image, target_size: int = 128, ring: int = 2) -> Image.Image:
    """
    1) Convert to RGB.
    2) Inspect border ring to pick pad mode:
         - If modal ring colour is near-black (all channels ≤ NEAR_BLACK_THR)
           and covers ≥ MIN_MODE_FRAC of ring pixels → constant pad with that colour.
         - Otherwise → reflect pad.
    3) Centre-pad to square; bicubic resize to (target_size, target_size).
    """
    img = img.convert("RGB")
    arr = np.array(img)
    mode_rgb, frac = _ring_mode_from_array(arr, ring=ring)
    near_black = (mode_rgb[0] <= NEAR_BLACK_THR and
                  mode_rgb[1] <= NEAR_BLACK_THR and
                  mode_rgb[2] <= NEAR_BLACK_THR)
    use_constant = (near_black and frac >= MIN_MODE_FRAC)

    w, h = img.size
    if w != h:
        side = max(w, h)
        pad_left   = (side - w) // 2
        pad_right  = side - w - pad_left
        pad_top    = (side - h) // 2
        pad_bottom = side - h - pad_top
        if use_constant:
            padded = ImageOps.expand(img, border=(pad_left, pad_top, pad_right, pad_bottom), fill=mode_rgb)
        else:
            # reflect via numpy for exact control
            pad_width = ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0))
            padded = Image.fromarray(np.pad(arr, pad_width, mode="reflect"))
    else:
        padded = img

    return padded.resize((target_size, target_size), resample=RESAMPLE_BICUBIC)

# Step 7 — Canonical Transform (pad-to-square + 128×128 resize) and Sample Validation

This cell validates, on a small sample, the canonical preprocessing used in training and inference.

- **Inputs and guards:**  
  Reloads `index_with_groups.csv` if `df` is not present; prepares output paths.

- **Parameters:**  
  `TARGET_SIZE=128`; border ring width `RING_PX=2`; near-black threshold `NEAR_BLACK_THR=10` (per channel); minimum modal-colour coverage `MIN_MODE_FRAC=0.60`. Bicubic resampling is selected with a Pillow-version-safe fallback.

- **Border analysis:**  
  `_ring_mode_from_array()` (from the canonical cell) computes the modal RGB of the outer ring and its fraction.  
  The pad mode is reported as `"constant_black"` if the modal colour is near-black **and** covers ≥60% of ring pixels; otherwise it is `"reflect"`.

- **Transform:**  
  `pad_resize_canonical()` pads symmetrically to a square using either a constant dark fill or reflect padding, then resizes to 128×128 with bicubic resampling.  
  `preprocess_one()` applies the full pipeline and records metadata (original size, pad mode, modal colour and fraction, final size).

- **Validation on a sample:**  
  Draws `SAMPLES_PER_CLASS=12` per class; runs the transform; stores a per-image report to `results/preprocess_sample_report.csv`.  
  Prints pad-mode counts and the most frequent border colours; saves a side-by-side montage (original | processed) to `results/preprocess_sample_montage.png`.

- **Purpose:**  
  Locks the exact preprocessing that will be used everywhere; provides a quick visual and numeric check that padding choice and resizing behave as intended.

In [None]:
# Step 7 — Canonical transform implementation + sample validation

from pathlib import Path
import os
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm

# ---- Robustly re-establish paths / inputs ----
DATASET_DIR  = DATA_DIR
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

INDEX_WITH_GROUPS = RESULTS_DIR / "index_with_groups.csv"
if "df" not in globals():
    assert INDEX_WITH_GROUPS.exists(), f"Missing {INDEX_WITH_GROUPS}; please rerun earlier indexing steps."
    df = pd.read_csv(INDEX_WITH_GROUPS)

# ---- Parameters for transform ----
TARGET_SIZE = 128
RING_PX = 2
NEAR_BLACK_THR = 10     # per-channel max to consider "near-black"
MIN_MODE_FRAC = 0.60    # if the modal ring colour covers <60% of ring pixels, we treat as uncertain and fallback to reflect

# ---- Utilities ----
# Same Pillow-version-safe choice as in the canonical cell
try:
    RESAMPLE_BICUBIC = Image.Resampling.BICUBIC  # Pillow ≥ 9.1
except AttributeError:
    RESAMPLE_BICUBIC = Image.BICUBIC             # older Pillow

def preprocess_one(path: Path, target_size: int = TARGET_SIZE):
    """Apply the canonical transform to one file; return (image, metadata)."""
    with Image.open(path) as im:
        im = im.convert("RGB")
        arr = np.array(im)
    # Recomputed here only to report which pad mode pad_resize_canonical() will choose
    mode_rgb, frac = _ring_mode_from_array(arr, ring=RING_PX)
    near_black = all(c <= NEAR_BLACK_THR for c in mode_rgb)
    pad_mode = "constant_black" if (near_black and frac >= MIN_MODE_FRAC) else "reflect"

    im_out = pad_resize_canonical(Image.fromarray(arr), target_size=target_size, ring=RING_PX)
    meta = {
        "orig_w": arr.shape[1], "orig_h": arr.shape[0],
        "pad_mode": pad_mode, "mode_rgb": mode_rgb, "mode_frac": round(frac, 3),
        "final_w": target_size, "final_h": target_size
    }
    return im_out, meta

# ---- Sample a small set per class and run the transform ----
SAMPLES_PER_CLASS = 12
rng = 42
sample_df = (
    df.groupby("class", group_keys=False)
      .apply(lambda g: g.sample(min(SAMPLES_PER_CLASS, len(g)), random_state=rng))
      .reset_index(drop=True)
)

records = []
thumbs = []  # (orig, processed) pairs for montage
for _, r in sample_df.iterrows():
    p = Path(r["path"])
    with Image.open(p) as im_orig:
        im_orig = im_orig.convert("RGB")
    im_proc, meta = preprocess_one(p)
    records.append({
        "filename": r["filename"],
        "class": r["class"],
        **meta
    })
    # Keep the (original, processed) pair; make_montage() scales the original for display
    thumbs.append((im_orig, im_proc))

# ---- Save report ----
report_df = pd.DataFrame(records)
report_csv = RESULTS_DIR / "preprocess_sample_report.csv"
report_df.to_csv(report_csv, index=False)
print(f"Saved sample report → {report_csv}")

# ---- Print quick summary ----
print("\nPad mode counts:")
print(report_df["pad_mode"].value_counts())

print("\nBorder mode colour (top 5):")
print(report_df["mode_rgb"].value_counts().head())

# ---- Make a montage: original | processed, 6 rows x 4 cols (pairs) ----
def make_montage(pairs, cols=4, out_path=RESULTS_DIR/"preprocess_sample_montage.png"):
    # Each pair contributes two panels; we stack as [orig, processed] horizontally within a pair.
    n = len(pairs)
    rows = int(np.ceil(n / cols))
    panel_w = 128 * 2   # orig + processed
    panel_h = 128
    canvas = Image.new("RGB", (panel_w * cols, panel_h * rows), (255,255,255))
    for i, (im_orig, im_proc) in enumerate(pairs):
        # Fit original to 128x128 preview (keep aspect ratio, pad white if needed)
        w, h = im_orig.size
        scale = min(128/w, 128/h)
        # Round rather than truncate, and never let a side collapse to 0 px
        new_w, new_h = max(1, round(w*scale)), max(1, round(h*scale))
        im_orig_small = im_orig.resize((new_w, new_h), resample=RESAMPLE_BICUBIC)
        panel = Image.new("RGB", (panel_w, panel_h), (255,255,255))
        # Centre original in left 128×128
        left_pad_x = (128 - new_w)//2
        left_pad_y = (128 - new_h)//2
        panel.paste(im_orig_small, (left_pad_x, left_pad_y))
        # Paste processed on the right (exact 128×128)
        panel.paste(im_proc, (128, 0))
        # Place panel in grid
        r = i // cols
        c = i % cols
        canvas.paste(panel, (c*panel_w, r*panel_h))
    canvas.save(out_path)
    return out_path

montage_path = make_montage(thumbs, cols=4)
print(f"Saved montage → {montage_path}")

# Step 8 — Group-wise, Stratified Train/Val/Test Splits

This cell creates leakage-safe splits while keeping class balance.

- **Inputs.** Loads `index_with_groups.csv`; checks required columns. Encodes labels: `Parasitized→1`, `Uninfected→0`. Uses `group` as the slide/session key.
- **Method.** Builds five folds with `StratifiedGroupKFold` so groups do not mix across folds and class proportions stay stable at the group level. Assigns each sample a fold id; picks one fold for validation and one for test using a fixed RNG seed; the remaining folds form the training set.
- **Leakage check.** Verifies that every group maps to exactly one split; raises an error if any group crosses splits.
- **Summaries.** Prints counts per split; a class-by-split crosstab; and overall split proportions.
- **Output.** Saves the manifest with split labels to `results/splits.csv`.

This split strategy keeps slide-level correlation confined within a single split; reported counts confirm balance before training.
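
The leakage check is independent of how the folds were built. A minimal sketch, using made-up group names and a hypothetical helper `find_leaky_groups`, shows the idea on plain `(group, split)` pairs:

```python
from collections import defaultdict

def find_leaky_groups(rows):
    """Return {group: splits} for every group that appears in more than one split."""
    seen = defaultdict(set)
    for group, split in rows:
        seen[group].add(split)
    return {g: sorted(s) for g, s in seen.items() if len(s) > 1}

# Toy manifests of (group, split) pairs; group names are made up
clean = [("slideA", "train"), ("slideA", "train"), ("slideB", "val"), ("slideC", "test")]
leaky = clean + [("slideB", "train")]

print(find_leaky_groups(clean))
print(find_leaky_groups(leaky))
```

An empty result means every group is confined to one split; any entry names a group whose images would let the model see validation or test slides during training.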

In [None]:
# Step 8 — Build leakage-safe Train/Val/Test splits

# Imports
import numpy as np
import pandas as pd
from pathlib import Path

# Sklearn splitters
from sklearn.model_selection import StratifiedGroupKFold

# Paths
RESULTS = RESULTS_DIR

# Load the grouped index built earlier
df = pd.read_csv(RESULTS / "index_with_groups.csv")

# Basic checks
assert {"path","filename","class","group","width","height"}.issubset(df.columns), "Missing required columns."

# Encode labels
label_map = {"Parasitized": 1, "Uninfected": 0}
y = df["class"].map(label_map).values
groups = df["group"].values

# Build 5 stratified group folds
# (Stratifies on 'class' at group level; keeps groups intact.)
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

# Assign each sample a fold id via the test indices of each split
fold_id = np.full(len(df), -1, dtype=int)
for k, (_, test_idx) in enumerate(sgkf.split(X=np.zeros(len(df)), y=y, groups=groups)):
    fold_id[test_idx] = k

assert (fold_id >= 0).all(), "Unassigned samples detected."

# Choose folds: one for val, one for test (reproducible)
rng = np.random.RandomState(1337)
folds = np.arange(5)
rng.shuffle(folds)
val_fold, test_fold = folds[:2]
train_folds = {int(f) for f in folds[2:]}  # plain ints print cleanly in the summary below

def fold_to_split(fid):
    if fid == val_fold:
        return "val"
    if fid == test_fold:
        return "test"
    return "train"

df["split"] = [fold_to_split(fid) for fid in fold_id]

# Sanity: no group should cross splits
group_to_splits = df.groupby("group")["split"].nunique()
if (group_to_splits > 1).any():
    bad = group_to_splits[group_to_splits > 1]
    raise RuntimeError(f"Group leakage detected for {len(bad)} groups.")

# Summaries
print("Fold assignment:", {"val": int(val_fold), "test": int(test_fold), "train": sorted(list(train_folds))})

print("\nCounts per split:")
print(df["split"].value_counts())

print("\nClass counts per split:")
print(pd.crosstab(df["split"], df["class"]))

print("\nProportions per split (overall):")
print((df["split"].value_counts(normalize=True) * 100).round(2))

# Persist manifest
out_path = RESULTS / "splits.csv"
df.to_csv(out_path, index=False)
print(f"\nSaved manifest → {out_path}")

# Step 9 — Preprocessing Definition and Split-wise Visual Spot-check

This cell wraps the canonical preprocessing for path-based use and verifies its behaviour on each split.

- **Preprocessing.**  
  `preprocess_image()` opens an image and applies `pad_resize_canonical()`: it pads to a square (constant near-black fill when that colour dominates the border ring; reflect padding otherwise), then resizes to 128×128 with bicubic interpolation.

- **Visual QA.**  
  For each of `train`, `val`, and `test`, it samples up to six images per class, applies the transform, and assembles a compact montage via `montage_grid()`.

- **Outputs.**  
  Saves `montage_split_train.png`, `montage_split_val.png`, and `montage_split_test.png` under `results/`.

**Why this matters.** Confirms that padding and resizing are consistent across splits; provides an immediate check for unintended artefacts or class-specific differences.
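
The montage geometry follows directly from the cell size and the gutter width. The sketch below reproduces that arithmetic without Pillow, assuming the values used in the spot-check (three columns, 128-pixel cells, 6-pixel padding); `montage_layout` is a hypothetical helper.

```python
import math

def montage_layout(n_images, ncols=3, cell_size=128, pad=6):
    """Return the canvas size and the top-left paste position of each image."""
    nrows = math.ceil(n_images / ncols)
    width = ncols * cell_size + (ncols + 1) * pad
    height = nrows * cell_size + (nrows + 1) * pad
    positions = []
    for i in range(n_images):
        row, col = divmod(i, ncols)
        positions.append((pad + col * (cell_size + pad), pad + row * (cell_size + pad)))
    return (width, height), positions

# Six images per class, two classes
size, positions = montage_layout(12)
print(size, positions[0], positions[3], positions[-1])
```

With twelve images the canvas is 408×542 pixels, arranged as four rows of three with a 6-pixel gutter on every side.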

In [None]:
# Step 9 — Define preprocessing (pad-to-square + bicubic 128×128) and spot-check per split

import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image, ImageOps
from collections import Counter

RESULTS = RESULTS_DIR
df = pd.read_csv(RESULTS / "splits.csv")

# --- Canonical pad→resize: constant near-black if dominant; otherwise reflect ---
NEAR_BLACK_THR = 10
MIN_MODE_FRAC  = 0.60

def preprocess_image(path: Path, size: int = 128, ring: int = 2) -> Image.Image:
    with Image.open(path) as im:
        return pad_resize_canonical(im, target_size=size, ring=ring)

def montage_grid(imgs, ncols=3, cell_size=128, pad=4):
    """Create a simple montage for visual spot-checks."""
    n = len(imgs)
    nrows = int(np.ceil(n / ncols))
    W = ncols * cell_size + (ncols + 1) * pad
    H = nrows * cell_size + (nrows + 1) * pad
    canvas = Image.new("RGB", (W, H), (30, 30, 30))
    x = y = pad
    for i, im in enumerate(imgs):
        canvas.paste(im, (x, y))
        x += cell_size + pad
        if (i + 1) % ncols == 0:
            x = pad
            y += cell_size + pad
    return canvas

# --- Build small spot-check montages per split/class ---
rng = np.random.RandomState(123)
for split in ["train", "val", "test"]:
    imgs_to_show = []
    for cls in ["Parasitized", "Uninfected"]:
        sub = df[(df["split"] == split) & (df["class"] == cls)]
        if len(sub) == 0:
            continue
        sample = sub.sample(min(6, len(sub)), random_state=rng)
        for _, r in sample.iterrows():
            im = preprocess_image(Path(r["path"]), size=128, ring=2)
            imgs_to_show.append(im)
    if imgs_to_show:
        m = montage_grid(imgs_to_show, ncols=3, cell_size=128, pad=6)
        out_file = RESULTS / f"montage_split_{split}.png"
        m.save(out_file)
        print(f"Saved montage → {out_file}")

# Step 10 — Photometric Statistics on the Training Split

This cell quantifies basic intensity properties of the **training** images after the canonical preprocessing (pad-to-square with a near-black constant fill or reflect padding; bicubic resize to 128×128).

- **Scope.** Loads training paths from `splits.csv`; optionally draws a class-balanced sample of up to `N_MAX=8000` images. Each image passes through `pad_resize_canonical()` before measurement.

- **Per-image features (0–255 scale).**  
  - Channel means and standard deviations: `rgb_mean_{r,g,b}`, `rgb_std_{r,g,b}`.  
  - Grayscale proxies: `gray_mean` and `gray_std` computed as the per-pixel average over RGB.  
  - Global extrema: `vmin`, `vmax`.

- **Outputs.**  
  - `results/train_photometric_stats.csv`: one row per image with all features.  
  - `results/train_photometric_summary.csv`: class-wise aggregates (mean and p10/p90 for the gray metrics; mean of each RGB channel mean).  
  - Histograms saved to:
    - `hist_gray_mean_by_class.png` — grayscale mean by class,  
    - `hist_gray_std_by_class.png` — grayscale std by class,  
    - `hist_rgb_channel_means.png` — R/G/B means across the whole training set.

- **Use in the project.**  
  Guides normalisation choices; validates augmentation ranges for brightness/contrast; checks for class-specific photometric shifts.

- **Notes.**  
  Statistics are computed **after** padding and resizing; near-black borders may slightly affect means in very small crops. A fixed random seed (`RandomState(7)`) makes the sampling repeatable.
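
As a rough illustration of the per-image features listed above, the sketch below computes them for a handful of toy pixels using only the standard library. The helper name `photometric_features` and the pixel values are illustrative assumptions; the notebook cell itself operates on NumPy arrays of preprocessed 128×128 images.

```python
from statistics import fmean, pstdev

def photometric_features(pixels):
    """Compute the Step 10 feature set for a list of (r, g, b) tuples on the 0-255 scale."""
    channels = {name: [p[i] for p in pixels] for i, name in enumerate("rgb")}
    gray = [sum(p) / 3.0 for p in pixels]  # unweighted per-pixel RGB average
    feats = {}
    for name, values in channels.items():
        feats[f"rgb_mean_{name}"] = fmean(values)
        feats[f"rgb_std_{name}"] = pstdev(values)  # population std, like NumPy's default
    feats["gray_mean"] = fmean(gray)
    feats["gray_std"] = pstdev(gray)
    all_values = [v for p in pixels for v in p]
    feats["vmin"] = float(min(all_values))
    feats["vmax"] = float(max(all_values))
    return feats

# Hypothetical 4-pixel "image": black, white, and two mid-tones.
toy = [(0, 0, 0), (255, 255, 255), (30, 60, 90), (90, 60, 30)]
feats = photometric_features(toy)
print(feats["gray_mean"], feats["vmin"], feats["vmax"])
```

Note that `gray_std` is the spread of the per-pixel gray values, not an average of the three channel standard deviations.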

In [None]:
# Step 10 — Photometric statistics on the training split (after canonical preprocessing)

import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from collections import Counter

RESULTS = RESULTS_DIR
splits_df = pd.read_csv(RESULTS / "splits.csv")

# --- Canonical pad→resize (reuse in stats to match training data exactly) ---
NEAR_BLACK_THR = 10
MIN_MODE_FRAC  = 0.60

def preprocess_image(path: Path, size: int = 128, ring: int = 2) -> Image.Image:
    with Image.open(path) as im:
        return pad_resize_canonical(im, target_size=size, ring=ring)

# --- Configuration ---
N_MAX = 8000   # set to None to process all train images
RNG = np.random.RandomState(7)

train_df = splits_df[splits_df["split"] == "train"].copy()

# Stratified sampling to N_MAX
if N_MAX is not None and len(train_df) > N_MAX:
    per_class = int(N_MAX // 2)
    parts = []
    for cls in ["Parasitized", "Uninfected"]:
        sub = train_df[train_df["class"] == cls]
        parts.append(sub.sample(min(per_class, len(sub)), random_state=RNG))
    train_df = pd.concat(parts, axis=0).sample(frac=1.0, random_state=RNG).reset_index(drop=True)

# Per-image statistics after canonical preprocessing
records = []
for _, r in train_df.iterrows():
    im = preprocess_image(Path(r["path"]), size=128, ring=2)
    arr = np.asarray(im).astype(np.float32)  # [H,W,3], 0..255
    ch = arr.reshape(-1, 3)
    gray = arr.mean(axis=2)
    records.append({
        "filename": r["filename"],
        "class": r["class"],
        "rgb_mean_r": float(ch[:,0].mean()),
        "rgb_mean_g": float(ch[:,1].mean()),
        "rgb_mean_b": float(ch[:,2].mean()),
        "rgb_std_r":  float(ch[:,0].std()),
        "rgb_std_g":  float(ch[:,1].std()),
        "rgb_std_b":  float(ch[:,2].std()),
        "gray_mean":  float(gray.mean()),
        "gray_std":   float(gray.std()),
        "vmin":       float(arr.min()),
        "vmax":       float(arr.max())
    })

stats_df = pd.DataFrame(records)
stats_csv = RESULTS / "train_photometric_stats.csv"
stats_df.to_csv(stats_csv, index=False)
print(f"Saved per-image stats → {stats_csv}")

# Summary per class
def pct(s, q): return float(np.percentile(s, q))

summary = (
    stats_df
    .groupby("class")
    .agg(gray_mean_mean=("gray_mean","mean"),
         gray_mean_p10 =("gray_mean", lambda s: pct(s,10)),
         gray_mean_p90 =("gray_mean", lambda s: pct(s,90)),
         gray_std_mean =("gray_std","mean"),
         gray_std_p10  =("gray_std",  lambda s: pct(s,10)),
         gray_std_p90  =("gray_std",  lambda s: pct(s,90)),
         r_mean=("rgb_mean_r","mean"),
         g_mean=("rgb_mean_g","mean"),
         b_mean=("rgb_mean_b","mean"))
    .round(2)
)
summary_csv = RESULTS / "train_photometric_summary.csv"
summary.to_csv(summary_csv)
print(f"Saved summary → {summary_csv}\n")
print(summary)

# Histograms — grayscale mean and std by class
plt.figure(figsize=(6,4))
for cls in ["Parasitized", "Uninfected"]:
    plt.hist(stats_df.loc[stats_df["class"]==cls, "gray_mean"], bins=40, alpha=0.6, label=cls)
plt.xlabel("Grayscale mean (0–255)"); plt.ylabel("Count"); plt.title("Train — grayscale mean by class")
plt.legend(); plt.tight_layout()
out = RESULTS / "hist_gray_mean_by_class.png"
plt.savefig(out, dpi=150); plt.close()
print(f"Saved {out}")

plt.figure(figsize=(6,4))
for cls in ["Parasitized", "Uninfected"]:
    plt.hist(stats_df.loc[stats_df["class"]==cls, "gray_std"], bins=40, alpha=0.6, label=cls)
plt.xlabel("Grayscale std (contrast proxy)"); plt.ylabel("Count"); plt.title("Train — grayscale std by class")
plt.legend(); plt.tight_layout()
out = RESULTS / "hist_gray_std_by_class.png"
plt.savefig(out, dpi=150); plt.close()
print(f"Saved {out}")

# Histograms — per-channel means (both classes together)
plt.figure(figsize=(6,4))
plt.hist(stats_df["rgb_mean_r"], bins=40, alpha=0.6, label="R")
plt.hist(stats_df["rgb_mean_g"], bins=40, alpha=0.6, label="G")
plt.hist(stats_df["rgb_mean_b"], bins=40, alpha=0.6, label="B")
plt.xlabel("Channel mean (0–255)"); plt.ylabel("Count"); plt.title("Train — channel means (R,G,B)")
plt.legend(); plt.tight_layout()
out = RESULTS / "hist_rgb_channel_means.png"
plt.savefig(out, dpi=150); plt.close()
print(f"Saved {out}")

### Step 10 extension — split photometrics

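Run the same photometric analysis on **val** and **test** after the canonical preprocessing. The validation split is capped at 8000 images using the same class-balanced sampling; the test split is processed in full. Each split saves `{split}_photometric_stats.csv`, `{split}_photometric_summary.csv`, and three histograms: `{split}_hist_gray_mean_by_class.png`, `{split}_hist_gray_std_by_class.png`, and `{split}_hist_rgb_channel_means.png`.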
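
The class-balanced cap used by both photometric cells can be sketched as follows. This is a minimal stand-in that uses the standard-library `random` module instead of `DataFrame.sample`; the function name `balanced_cap` and the toy rows are assumptions made for illustration.

```python
import random

def balanced_cap(rows, n_max, classes=("Parasitized", "Uninfected"), seed=7):
    """Return all rows if under the cap, else up to n_max // len(classes) rows per class."""
    if n_max is None or len(rows) <= n_max:
        return list(rows)
    rng = random.Random(seed)
    per_class = n_max // len(classes)
    picked = []
    for cls in classes:
        pool = [r for r in rows if r["class"] == cls]
        picked.extend(rng.sample(pool, min(per_class, len(pool))))
    rng.shuffle(picked)  # mix the classes after sampling
    return picked

# Hypothetical, deliberately imbalanced split: 30 vs 10 images.
rows = ([{"filename": f"p{i}.png", "class": "Parasitized"} for i in range(30)]
        + [{"filename": f"u{i}.png", "class": "Uninfected"} for i in range(10)])
subset = balanced_cap(rows, n_max=24)
counts = {cls: sum(r["class"] == cls for r in subset)
          for cls in ("Parasitized", "Uninfected")}
print(counts)
```

When one class has fewer than `n_max // 2` images, the sample ends up smaller than `n_max` (22 rather than 24 in this toy case), which is also how the notebook's sampling behaves.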

In [None]:
# Step 10 — extension: Photometric statistics on the val and test splits

for SPLIT, N_MAX_SPLIT in [("val", 8000), ("test", None)]:
    sub_df = splits_df[splits_df["split"] == SPLIT].copy()

    # Stratified sampling to N_MAX_SPLIT
    if N_MAX_SPLIT is not None and len(sub_df) > N_MAX_SPLIT:
        per_class = int(N_MAX_SPLIT // 2)
        parts = []
        for cls in ["Parasitized", "Uninfected"]:
            g = sub_df[sub_df["class"] == cls]
            parts.append(g.sample(min(per_class, len(g)), random_state=RNG))
        sub_df = pd.concat(parts, axis=0).sample(frac=1.0, random_state=RNG).reset_index(drop=True)

    # Per-image statistics after canonical preprocessing
    records = []
    for _, r in sub_df.iterrows():
        im = preprocess_image(Path(r["path"]), size=128, ring=2)
        arr = np.asarray(im).astype(np.float32)  # [H,W,3], 0..255
        ch = arr.reshape(-1, 3)
        gray = arr.mean(axis=2)
        records.append({
            "filename": r["filename"],
            "class": r["class"],
            "rgb_mean_r": float(ch[:,0].mean()),
            "rgb_mean_g": float(ch[:,1].mean()),
            "rgb_mean_b": float(ch[:,2].mean()),
            "rgb_std_r":  float(ch[:,0].std()),
            "rgb_std_g":  float(ch[:,1].std()),
            "rgb_std_b":  float(ch[:,2].std()),
            "gray_mean":  float(gray.mean()),
            "gray_std":   float(gray.std()),
            "vmin":       float(arr.min()),
            "vmax":       float(arr.max())
        })

    stats_df = pd.DataFrame(records)
    stats_csv = RESULTS / f"{SPLIT}_photometric_stats.csv"
    stats_df.to_csv(stats_csv, index=False)
    print(f"[{SPLIT}] Saved per-image stats → {stats_csv}")

    # Summary per class
    summary = (
        stats_df
        .groupby("class")
        .agg(gray_mean_mean=("gray_mean","mean"),
             gray_mean_p10 =("gray_mean", lambda s: pct(s,10)),
             gray_mean_p90 =("gray_mean", lambda s: pct(s,90)),
             gray_std_mean =("gray_std","mean"),
             gray_std_p10  =("gray_std",  lambda s: pct(s,10)),
             gray_std_p90  =("gray_std",  lambda s: pct(s,90)),
             r_mean=("rgb_mean_r","mean"),
             g_mean=("rgb_mean_g","mean"),
             b_mean=("rgb_mean_b","mean"))
        .round(2)
    )
    summary_csv = RESULTS / f"{SPLIT}_photometric_summary.csv"
    summary.to_csv(summary_csv)
    print(f"[{SPLIT}] Saved summary → {summary_csv}\n")
    print(summary)

    # Histograms — grayscale mean and std by class
    plt.figure(figsize=(6,4))
    for cls in ["Parasitized", "Uninfected"]:
        plt.hist(stats_df.loc[stats_df["class"]==cls, "gray_mean"], bins=40, alpha=0.6, label=cls)
    plt.xlabel("Grayscale mean (0–255)"); plt.ylabel("Count"); plt.title(f"{SPLIT.capitalize()} — grayscale mean by class")
    plt.legend(); plt.tight_layout()
    out = RESULTS / f"{SPLIT}_hist_gray_mean_by_class.png"
    plt.savefig(out, dpi=150); plt.close()
    print(f"[{SPLIT}] Saved {out}")

    plt.figure(figsize=(6,4))
    for cls in ["Parasitized", "Uninfected"]:
        plt.hist(stats_df.loc[stats_df["class"]==cls, "gray_std"], bins=40, alpha=0.6, label=cls)
    plt.xlabel("Grayscale std (contrast proxy)"); plt.ylabel("Count"); plt.title(f"{SPLIT.capitalize()} — grayscale std by class")
    plt.legend(); plt.tight_layout()
    out = RESULTS / f"{SPLIT}_hist_gray_std_by_class.png"
    plt.savefig(out, dpi=150); plt.close()
    print(f"[{SPLIT}] Saved {out}")

    # Histograms — per-channel means (both classes together)
    plt.figure(figsize=(6,4))
    plt.hist(stats_df["rgb_mean_r"], bins=40, alpha=0.6, label="R")
    plt.hist(stats_df["rgb_mean_g"], bins=40, alpha=0.6, label="G")
    plt.hist(stats_df["rgb_mean_b"], bins=40, alpha=0.6, label="B")
    plt.xlabel("Channel mean (0–255)"); plt.ylabel("Count"); plt.title(f"{SPLIT.capitalize()} — channel means (R,G,B)")
    plt.legend(); plt.tight_layout()
    out = RESULTS / f"{SPLIT}_hist_rgb_channel_means.png"
    plt.savefig(out, dpi=150); plt.close()
    print(f"[{SPLIT}] Saved {out}")


# Step 11 — PyTorch Datasets and Dataloaders

This cell defines the training input pipeline and produces a quick visual check.

- **Preprocessing.**  
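  `_ring_mode_from_array()` estimates the modal border colour from a thin outer ring; `pad_resize_canonical()` (wrapped by `PadResizeCanonical`) pads each non-square image to a square, using that colour as a constant fill when it is near-black and covers at least 60% of the ring and reflect padding otherwise, then resizes to 128×128 using bicubic interpolation.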

- **Dataset class.**  
  `MalariaDataset` reads from a split-specific `DataFrame`, applies the pad+resize transform, and (optionally) mild augmentation: horizontal/vertical flips, ±10° rotation, and brightness/contrast jitter in ~[0.8, 1.2].  
  Labels are mapped to integers: `Parasitized→1`, `Uninfected→0`.  
  Images are returned as tensors in `[0,1]` (no normalisation here).

- **Splits and loaders.**  
  Builds `train/val/test` datasets from `splits.csv`, then `DataLoader`s with `batch_size=64`.  
  Training uses `shuffle=True` and `drop_last=True` (stable batch shapes); validation/test are sequential.

- **Preview.**  
  Grabs one training batch, makes a 32-image grid, and saves it to `results/loader_preview_train.png`. The preview reflects pad+resize and augmentations; tensors are unnormalised.

This establishes a consistent, reproducible input path for later model training.
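
The padding arithmetic and the constant-versus-reflect decision in `pad_resize_canonical` can be sketched without any imaging library. The helper names `square_padding` and `use_constant_fill` are hypothetical; the thresholds mirror `NEAR_BLACK_THR` and `MIN_MODE_FRAC` from the cell below.

```python
def square_padding(w, h):
    """Return the square side and (left, top, right, bottom) padding that centres a w x h image."""
    side = max(w, h)
    pl = (side - w) // 2
    pr = side - w - pl  # any odd pixel goes to the right
    pt = (side - h) // 2
    pb = side - h - pt  # any odd pixel goes to the bottom
    return side, (pl, pt, pr, pb)

def use_constant_fill(mode_rgb, frac, near_black_thr=10, min_mode_frac=0.60):
    """Constant fill only when the modal border colour is near-black and dominant."""
    near_black = all(c <= near_black_thr for c in mode_rgb)
    return near_black and frac >= min_mode_frac

print(square_padding(101, 128))
print(use_constant_fill((0, 0, 0), 0.9), use_constant_fill((200, 180, 190), 0.9))
```

Images whose border is not dominated by a near-black colour fall through to reflect padding in the notebook's implementation.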

In [None]:
# Step 11 — PyTorch datasets & dataloaders (canonical preprocessing; unnormalised tensors)

import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.utils as vutils
from PIL import Image, ImageOps
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter
import matplotlib.pyplot as plt

RESULTS = RESULTS_DIR
splits_df = pd.read_csv(RESULTS / "splits.csv")

# --- Canonical pad→resize used by the dataset ---
NEAR_BLACK_THR = 10
MIN_MODE_FRAC  = 0.60

def pad_resize_canonical(img: Image.Image, target_size: int = 128, ring: int = 2) -> Image.Image:
    img = img.convert("RGB")
    arr = np.array(img)
    mode_rgb, frac = _ring_mode_from_array(arr, ring=ring)
    near_black = all(c <= NEAR_BLACK_THR for c in mode_rgb)
    use_constant = (near_black and frac >= MIN_MODE_FRAC)

    w, h = img.size
    if w != h:
        side = max(w, h)
        pl = (side - w) // 2; pr = side - w - pl
        pt = (side - h) // 2; pb = side - h - pt
        if use_constant:
            padded = ImageOps.expand(img, border=(pl, pt, pr, pb), fill=mode_rgb)
        else:
            pad_width = ((pt, pb), (pl, pr), (0, 0))
            padded = Image.fromarray(np.pad(arr, pad_width, mode="reflect"))
    else:
        padded = img

    return padded.resize((target_size, target_size), resample=Image.BICUBIC)

class PadResizeCanonical:
    """PIL → canonical pad→resize to size×size; used inside the dataset."""
    def __init__(self, size: int = 128, ring: int = 2):
        self.size = size
        self.ring = ring
    def __call__(self, img: Image.Image) -> Image.Image:
        return pad_resize_canonical(img, target_size=self.size, ring=self.ring)

class MalariaDataset(Dataset):
    """Single-cell crops with light geometric/photometric augments for training."""
    def __init__(self, frame: pd.DataFrame, split: str, augment: bool):
        self.df = frame.reset_index(drop=True)
        self.split = split
        self.augment = augment
        self.pad_resize = PadResizeCanonical(size=128, ring=2)

        # Augmentations (keep mild to preserve morphology)
        self.pil_aug = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomRotation(degrees=10, fill=0),
            T.ColorJitter(brightness=0.2, contrast=0.2)  # ≈ scales [0.8, 1.2]
        ])

        self.to_tensor = T.ToTensor()  # maps to [0,1] float32 (C,H,W)
        self.label_map = {"Parasitized": 1, "Uninfected": 0}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        with Image.open(row["path"]) as img:
            img = self.pad_resize(img)  # canonical pad→resize configured in __init__
        if self.augment:
            img = self.pil_aug(img)
        x = self.to_tensor(img)
        y = torch.tensor(self.label_map[row["class"]], dtype=torch.long)
        return x, y

# Build splits
train_df = splits_df[splits_df["split"]=="train"]
val_df   = splits_df[splits_df["split"]=="val"]
test_df  = splits_df[splits_df["split"]=="test"]

train_set = MalariaDataset(train_df, split="train", augment=True)
val_set   = MalariaDataset(val_df,   split="val",   augment=False)
test_set  = MalariaDataset(test_df,  split="test",  augment=False)

BATCH_SIZE  = 64
NUM_WORKERS = 0
PIN_MEMORY  = False

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Train/Val/Test sizes: {len(train_set)} / {len(val_set)} / {len(test_set)}")

# Preview a batch from the train loader and save a grid
torch.manual_seed(123)
xb, yb = next(iter(train_loader))  # xb: [B,3,128,128]
grid = vutils.make_grid(xb[:32], nrow=8, padding=2)
out_path = RESULTS / "loader_preview_train.png"
plt.figure(figsize=(10,5))
plt.axis("off")
plt.title("Train preview (after canonical pad+resize; unnormalised)")
plt.imshow(np.transpose(grid.numpy(), (1,2,0)))
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved preview grid → {out_path}")

# Step 12 — Add ImageNet Normalisation and Rebuild Dataloaders

This cell introduces ImageNet normalisation to match the MobileNetV2 pretraining stats and rebuilds the loaders.

- **Normalisation.**  
  Uses `IMAGENET_MEAN=[0.485, 0.456, 0.406]` and `IMAGENET_STD=[0.229, 0.224, 0.225]`. Applied after `ToTensor()` in the dataset; keeps transforms consistent with the pretrained backbone.

- **Preprocessing and augments.**  
  Same pad-to-square + bicubic resize as before; light PIL-level augmentation on the training set only (H/V flip, ±10° rotation with `fill=0`, brightness and contrast jitter ~[0.8, 1.2]). Validation and test are deterministic.

- **Dataset class change.**  
  `MalariaDataset(..., normalise=True)` adds `T.Normalize(mean, std)` to the pipeline; set `normalise=False` to skip it if needed.

- **Rebuilt loaders.**  
  Constructs `train/val/test` dataloaders with `batch_size=64`; `shuffle=True` and `drop_last=True` for training; sequential evaluation for val/test.

- **Preview and display note.**  
  Saves a normalised-batch preview to `results/loader_preview_train_normalised.png`. The code denormalises the grid before plotting; otherwise many values would fall outside `[0, 1]` and be clipped by matplotlib, distorting colours and contrast.

This step aligns input distributions with the pretrained model; it also preserves the earlier augmentation and padding behaviour.
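
To make the effect of normalisation concrete, the sketch below applies the per-channel formula `(x - mean) / std` to single pixels and then inverts it. The helper names are illustrative assumptions; in the notebook the same arithmetic is done by `T.Normalize` and by the `grid * std + mean` line in the preview code.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalise(pixel):
    """Per-channel (x - mean) / std for one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)]

def denormalise(pixel):
    """Inverse transform: x * std + mean."""
    return [v * s + m for v, m, s in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)]

black = normalise([0.0, 0.0, 0.0])
white = normalise([1.0, 1.0, 1.0])
restored = denormalise(white)
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

Pure black maps to roughly -2 and pure white to a little above +2 in every channel, so a normalised tensor cannot be shown directly as an image; the round trip recovers the original `[0, 1]` values up to floating-point error.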

In [None]:
# Step 12 — Add ImageNet normalisation and rebuild dataloaders (canonical preprocessing)

import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.utils as vutils
from PIL import Image, ImageOps
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter
import matplotlib.pyplot as plt

RESULTS = RESULTS_DIR
splits_df = pd.read_csv(RESULTS / "splits.csv")

# ImageNet normalisation constants
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# --- Canonical pad→resize used by the dataset ---
NEAR_BLACK_THR = 10
MIN_MODE_FRAC  = 0.60

def _ring_mode_from_array(arr: np.ndarray, ring: int = 2, max_pixels: int = 20000):
    h, w, _ = arr.shape
    top = arr[0:ring]; bottom = arr[h-ring:h]
    left = arr[:, 0:ring]; right = arr[:, w-ring:w]
    ring_pixels = np.concatenate([top.reshape(-1,3), bottom.reshape(-1,3),
                                  left.reshape(-1,3), right.reshape(-1,3)], axis=0)
    if ring_pixels.shape[0] > max_pixels:
        idx = np.random.choice(ring_pixels.shape[0], size=max_pixels, replace=False)
        ring_pixels = ring_pixels[idx]
    cnt = Counter([tuple(int(x) for x in v) for v in ring_pixels])
    mode_rgb, n = cnt.most_common(1)[0]
    return mode_rgb, n / max(1, ring_pixels.shape[0])

class PadResizeCanonical:
    def __init__(self, size: int = 128, ring: int = 2):
        self.size = size
        self.ring = ring
    def __call__(self, img: Image.Image) -> Image.Image:
        return pad_resize_canonical(img, target_size=self.size, ring=self.ring)

class MalariaDataset(Dataset):
    """Single-cell crops with optional augmentation and ImageNet normalisation."""
    def __init__(self, frame: pd.DataFrame, split: str, augment: bool, normalise: bool = True):
        self.df = frame.reset_index(drop=True)
        self.split = split
        self.augment = augment
        self.pad_resize = PadResizeCanonical(size=128, ring=2)

        self.pil_aug = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomRotation(degrees=10, fill=0),
            T.ColorJitter(brightness=0.2, contrast=0.2)
        ])

        self.to_tensor = T.ToTensor()
        self.normalise = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) if normalise else None
        self.label_map = {"Parasitized": 1, "Uninfected": 0}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        with Image.open(row["path"]) as img:
            img = self.pad_resize(img)  # canonical pad→resize configured in __init__
        if self.augment:
            img = self.pil_aug(img)
        x = self.to_tensor(img)
        if self.normalise is not None:
            x = self.normalise(x)
        y = torch.tensor(self.label_map[row["class"]], dtype=torch.long)
        return x, y

# Rebuild loaders
train_df = splits_df[splits_df["split"]=="train"]
val_df   = splits_df[splits_df["split"]=="val"]
test_df  = splits_df[splits_df["split"]=="test"]

train_set = MalariaDataset(train_df, split="train", augment=True,  normalise=True)
val_set   = MalariaDataset(val_df,   split="val",   augment=False, normalise=True)
test_set  = MalariaDataset(test_df,  split="test",  augment=False, normalise=True)

BATCH_SIZE  = 64
NUM_WORKERS = 0
PIN_MEMORY  = False

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Rebuilt loaders — sizes: {len(train_set)} / {len(val_set)} / {len(test_set)}")

# Preview (denormalised for display)
xb, yb = next(iter(train_loader))
grid = vutils.make_grid(xb[:32].cpu(), nrow=8, padding=2)

mean = torch.tensor(IMAGENET_MEAN).view(3,1,1)
std  = torch.tensor(IMAGENET_STD).view(3,1,1)
grid_disp = (grid * std + mean).clamp(0, 1)  # undo normalisation; clamp float error so imshow does not clip

plt.figure(figsize=(10,5)); plt.axis("off")
plt.title("Train preview (canonical pad+resize; ImageNet-normalised tensors shown denormalised)")
plt.imshow(np.transpose(grid_disp.numpy(), (1,2,0)))
out_path = RESULTS / "loader_preview_train_normalised.png"
plt.savefig(out_path, dpi=150, bbox_inches='tight'); plt.close()
print(f"Saved normalised preview grid → {out_path}")

# Step 13 — Model Definition, Loss, Optimiser, and Quick Eval

This cell instantiates the baseline classifier, sets up training primitives, and verifies the pipeline with a single optimisation step.

- **Backbone and head.**  
  Builds MobileNetV2 with ImageNet weights; replaces the classifier's final layer with a single linear layer producing one **logit** (binary task). Moves the model to CPU/GPU; freezes all feature layers; keeps only the classifier trainable.

- **Loss with label smoothing.**  
  Uses `BCEWithLogitsLoss` on logits; applies manual label smoothing (`ε=0.05`) so targets become `t*(1−ε)+0.5ε`. This can stabilise early training and reduce overconfident predictions.

- **Optimiser.**  
  Adam with `lr=1e-3`, restricted to trainable parameters (the head).

- **Evaluation helper.**  
  `evaluate_auc()` switches to eval mode; gathers sigmoid probabilities; returns ROC–AUC and accuracy at a fixed 0.5 threshold.

- **Sanity pass.**  
  Runs a single forward/backward/step on one training batch to confirm the end-to-end path; prints the loss.

- **Baseline metrics and checkpoint.**  
  Reports validation metrics before any real training, i.e. after only the single sanity step (AUC often ~0.5–0.6 at this stage), and saves the resulting model state to `results/mobilenetv2_baseline_init.pt`.

This prepares a clean head-only warm-up in the next step while keeping the pretrained feature extractor frozen.
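
A small worked example of the smoothed loss may help. The sketch below reimplements the target smoothing `t*(1−ε)+0.5ε` and a numerically stable binary cross-entropy on a single logit in plain Python. The function names are assumptions for illustration; the notebook uses `nn.BCEWithLogitsLoss` for the same computation on tensors.

```python
import math

def smooth_target(t, eps=0.05):
    """Move a hard 0/1 target towards 0.5 by eps/2."""
    return t * (1.0 - eps) + 0.5 * eps

def bce_with_logits(z, t):
    """Stable binary cross-entropy on a logit z: max(z, 0) - z*t + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))

pos = smooth_target(1.0)  # about 0.975
neg = smooth_target(0.0)  # about 0.025

# A very confident, correct prediction for a positive example.
hard_loss = bce_with_logits(6.0, 1.0)
smooth_loss = bce_with_logits(6.0, pos)
print(round(pos, 3), round(neg, 3), round(hard_loss, 4), round(smooth_loss, 4))
```

With the smoothed target, a logit of 6 still incurs a noticeable loss, whereas the hard target drives it close to zero; this is the mechanism that discourages overconfident outputs.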

In [None]:
# Step 13 — Model, loss, optimiser, and eval utils
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
from pathlib import Path
import time

RESULTS = RESULTS_DIR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --- Build MobileNetV2 and replace classifier ---
def build_mobilenetv2_single_logit(pretrained: bool = True):
    try:
        # torchvision >= 0.13 style
        weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        net = models.mobilenet_v2(weights=weights)
    except AttributeError:
        # Fallback for torchvision < 0.13, which lacks the Weights enums
        net = models.mobilenet_v2(pretrained=pretrained)
    in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(in_features, 1)  # single logit
    return net

model = build_mobilenetv2_single_logit(pretrained=True)
model = model.to(device)

# Freeze backbone (all except classifier)
for name, p in model.features.named_parameters():
    p.requires_grad = False
for p in model.classifier.parameters():
    p.requires_grad = True

# --- Loss with manual label smoothing for binary ---
LABEL_SMOOTH = 0.05
bce_logits = nn.BCEWithLogitsLoss(reduction='mean')

def bce_with_logits_smooth(logits, targets, eps=LABEL_SMOOTH):
    # targets: LongTensor {0,1} -> float
    t = targets.float()
    t_smooth = t*(1.0 - eps) + 0.5*eps
    return bce_logits(logits.view(-1), t_smooth)

# --- Optimiser ---
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# --- Metrics ---
@torch.no_grad()
def evaluate_auc(loader, model):
    model.eval()
    probs = []
    ys = []
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        logits = model(xb).view(-1)
        p = torch.sigmoid(logits)
        probs.append(p.cpu().numpy())
        ys.append(yb.cpu().numpy())
    probs = np.concatenate(probs)
    ys = np.concatenate(ys)
    try:
        auc = roc_auc_score(ys, probs)
    except ValueError:
        auc = float('nan')
    acc = accuracy_score(ys, (probs >= 0.5).astype(np.int64))
    return {"auc": auc, "acc@0.5": acc}

# --- Sanity check: one mini-step ---
model.train()
xb, yb = next(iter(train_loader))
xb = xb.to(device); yb = yb.to(device)
logits = model(xb)
loss = bce_with_logits_smooth(logits, yb)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Sanity mini-step OK — loss={loss.item():.4f}")

# Pre-training validation AUC (expect ~0.5–0.6)
metrics = evaluate_auc(val_loader, model)
print("Pre-train val metrics:", metrics)

# Save initial weights (pretrained backbone; head after the single sanity step)
torch.save(model.state_dict(), RESULTS / "mobilenetv2_baseline_init.pt")
print("Saved initial model weights.")

# Step 14 — Head Warm-up (frozen backbone) with Early Stopping

This cell trains only the classifier head while the MobileNetV2 feature extractor stays frozen.

- **Device & seeding.**  
  Selects CUDA if available; fixes RNG seeds for Python, NumPy, and Torch; sets deterministic CuDNN (may slow training slightly).

- **Preconditions.**  
  Verifies that the model, dataloaders, and loss function were created in earlier steps.

- **Trainable parameters.**  
  Freezes all `model.features` parameters; leaves `model.classifier` trainable.

- **Optimiser.**  
  Adam with `lr=1e-4` and `weight_decay=1e-4`; scoped to trainable parameters only.

- **Mixed precision.**  
  Enables autocast + GradScaler on CUDA; reduces memory use; speeds up math; disabled on CPU.

- **Epoch routine (`train_one_epoch`).**  
  Loops over batches; runs forward under autocast; computes smoothed BCE-with-logits; scales and steps the optimiser; returns mean loss.

- **Validation metrics (`evaluate_auc_acc`).**  
  Switches to eval mode; collects sigmoid probabilities; reports ROC–AUC and accuracy at 0.5.

- **Early stopping on AUC.**  
  Trains up to 50 epochs; tracks the best validation AUC; stops after 5 epochs without improvement; keeps the best state dict in memory.

- **Outputs & artefacts.**  
  Prints timing and best val AUC; saves best head-only weights to  
  `results/mobilenetv2_head_warmup.pt`; writes a CSV log `results/trainlog_head.csv`; exports two PNGs with training loss and validation AUC curves:
  - `results/curves_head_warmup_loss.png`  
  - `results/curves_head_warmup_auc.png`

This produces a warmed-up classifier head that is ready for partial fine-tuning of the backbone.
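
The early-stopping rule can be traced on a made-up sequence of validation AUCs. The sketch below reproduces only the bookkeeping (best value, patience counter, stop epoch); the function name `early_stopping_trace` and the AUC values are illustrative assumptions, not results from this project.

```python
def early_stopping_trace(val_aucs, patience=5):
    """Return (best_epoch, best_auc, last_epoch_run) for a sequence of validation AUCs."""
    best_auc = -1.0
    best_epoch = None
    no_improve = 0
    for epoch, auc in enumerate(val_aucs, start=1):
        if auc > best_auc:  # strict improvement; a tie counts as no improvement
            best_auc, best_epoch, no_improve = auc, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                return best_epoch, best_auc, epoch
    return best_epoch, best_auc, len(val_aucs)

# Hypothetical AUC curve: peaks at epoch 3, then plateaus.
aucs = [0.90, 0.93, 0.95, 0.949, 0.95, 0.94, 0.948, 0.947, 0.96]
best_epoch, best_auc, stop_epoch = early_stopping_trace(aucs, patience=5)
print(best_epoch, best_auc, stop_epoch)
```

In this trace the run stops at epoch 8 and keeps the epoch-3 state, so the later value of 0.96 is never seen; a larger `patience` trades extra epochs for the chance to catch such late improvements.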

In [None]:
# Step 14 — Warm-up the classifier head (frozen backbone) with early stopping

import os, time, math, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt

# ---- Paths & device ----
RESULTS = RESULTS_DIR
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Seed (again, safe) ----
def seed_everything(seed=42):
    import numpy as np, random, torch
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(42)

# ---- Expect these from earlier steps ----
assert 'model' in globals(), "Model not found. Please run Step 13 first."
assert 'train_loader' in globals() and 'val_loader' in globals(), "Loaders not found. Please run Step 12 first."
assert 'bce_with_logits_smooth' in globals(), "Loss fn not found. Please run Step 13."

# Ensure backbone frozen, head trainable
for p in model.features.parameters():
    p.requires_grad = False
for p in model.classifier.parameters():
    p.requires_grad = True

# Fresh optimiser for head-only training
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-4, weight_decay=1e-4)

# Mixed precision if CUDA available
use_amp = (device.type == 'cuda')
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

def train_one_epoch(model, loader, optimizer, scaler, device):
    model.train()
    running = 0.0
    n = 0
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(xb).view(-1)
            loss = bce_with_logits_smooth(logits, yb)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        bs = xb.size(0)
        running += loss.item() * bs
        n += bs
    return running / max(1, n)

@torch.no_grad()
def evaluate_auc_acc(model, loader, device):
    model.eval()
    probs = []
    ys = []
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        logits = model(xb).view(-1)
        p = torch.sigmoid(logits)
        probs.append(p.cpu().numpy())
        ys.append(yb.cpu().numpy())
    probs = np.concatenate(probs)
    ys = np.concatenate(ys)
    auc = roc_auc_score(ys, probs) if (len(np.unique(ys)) > 1) else float('nan')
    acc = accuracy_score(ys, (probs >= 0.5).astype(np.int64))
    return auc, acc

# ---- Training loop with early stopping on val AUC ----
MAX_EPOCHS = 50
PATIENCE = 5

best_auc = -1.0
best_state = None
no_improve = 0

history = []

start = time.time()
for epoch in range(1, MAX_EPOCHS+1):
    t0 = time.time()
    train_loss = train_one_epoch(model, train_loader, optimizer, scaler, device)
    val_auc, val_acc = evaluate_auc_acc(model, val_loader, device)

    history.append({"epoch": epoch, "train_loss": train_loss,
                    "val_auc": float(val_auc), "val_acc": float(val_acc)})
    print(f"[Head {epoch:02d}/{MAX_EPOCHS}] "
          f"train_loss={train_loss:.4f}  val_auc={val_auc:.4f}  val_acc@0.5={val_acc:.4f}  "
          f"({time.time()-t0:.1f}s)")

    # Early stopping on val AUC
    if val_auc > best_auc:
        best_auc = val_auc
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping triggered (no val AUC improvement in {PATIENCE} epochs).")
            break

elapsed = time.time() - start
print(f"Head warm-up finished in {elapsed/60:.1f} min. Best val AUC = {best_auc:.4f}")

# ---- Save best checkpoint, log, and curves ----
if best_state is not None:
    ckpt_path = RESULTS / "mobilenetv2_head_warmup.pt"
    torch.save(best_state, ckpt_path)
    print(f"Saved best head-only weights → {ckpt_path}")

log_df = pd.DataFrame(history)
log_csv = RESULTS / "trainlog_head.csv"
log_df.to_csv(log_csv, index=False)
print(f"Saved training log → {log_csv}")

# Curves
plt.figure(figsize=(6,4))
plt.plot(log_df["epoch"], log_df["train_loss"], label="train loss")
plt.xlabel("epoch"); plt.ylabel("loss"); plt.title("Head warm-up — training loss"); plt.legend()
plt.tight_layout(); plt.savefig(RESULTS / "curves_head_warmup_loss.png", dpi=150); plt.close()

plt.figure(figsize=(6,4))
plt.plot(log_df["epoch"], log_df["val_auc"], label="val AUC")
plt.xlabel("epoch"); plt.ylabel("AUC"); plt.ylim(0.4, 1.0)
plt.title("Head warm-up — validation AUC"); plt.legend()
plt.tight_layout(); plt.savefig(RESULTS / "curves_head_warmup_auc.png", dpi=150); plt.close()

print("Saved curves →",
      RESULTS / "curves_head_warmup_loss.png",
      "and",
      RESULTS / "curves_head_warmup_auc.png")

# Step 15 — Fine-tune the last *N* MobileNetV2 blocks (+ head) with early stopping

This cell adapts the tail of MobileNetV2 while keeping earlier layers frozen, using mixed precision and differential learning rates.

- **Preconditions & device.**  
  Requires the model, dataloaders, and smoothed BCE loss from earlier steps. Selects CUDA if available and enables AMP on GPU.

- **Warm-up checkpoint load.**  
  Restores the best head-only weights from Step 14 (`results/mobilenetv2_head_warmup.pt`) to start fine-tuning from a stable classifier.

- **Unfreezing policy.**  
  Freezes all feature blocks, then unfreezes the last `N_TAIL=12` blocks (`features[start_idx:]`) plus the classifier:
  - `L = len(model.features)`; `start_idx = max(0, L - N_TAIL)`.
  - Prints indices and counts of trainable parameters for transparency.

- **Optimiser (differential LR).**  
  AdamW with `weight_decay=1e-4`, using a lower LR for the backbone tail and a higher LR for the head:
  - Tail (unfrozen feature blocks): `lr=1e-4`  
  - Head (classifier): `lr=5e-4`

- **Training loop (AMP).**  
  - `train_one_epoch_ft`: autocast + GradScaler; computes smoothed BCE-with-logits; aggregates mean loss.  
  - `eval_auc_acc`: collects sigmoid probabilities (not yet calibrated) and computes ROC–AUC and accuracy@0.5.

- **Early stopping on validation AUC.**  
  Trains up to `MAX_EPOCHS=24`; stops after `PATIENCE=3` epochs without AUC improvement. The best-performing state dict is kept in memory.

- **Outputs & artefacts.**  
  - Best fine-tuned weights: `results/mobilenetv2_finetune_tail.pt`  
  - CSV training log: `results/trainlog_finetune_tail.csv`  
  - Curves:  
    - `results/curves_finetune_tail_loss.png` (training loss)  
    - `results/curves_finetune_tail_auc.png` (validation AUC)

- **What to adjust later.**  
  - `N_TAIL` to widen/narrow the adaptation window if validation performance or calibration needs change.  
  - Learning rates if convergence is unstable or too slow.

This produces a partially fine-tuned backbone consistent with the project plan: a lightweight model adapted to the dataset while limiting overfitting.
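
The index arithmetic behind the unfreezing policy can be checked in isolation. The sketch below is illustrative only: `tail_block_indices` is a hypothetical helper, and the block count of 19 is an assumption that matches the stock torchvision MobileNetV2 `features` container (a customised model may differ).

```python
def tail_block_indices(n_blocks: int, n_tail: int) -> list[int]:
    """Indices of the last `n_tail` entries of a container holding `n_blocks` entries."""
    # Clamp at 0 so that n_tail > n_blocks simply selects every block.
    start_idx = max(0, n_blocks - n_tail)
    return list(range(start_idx, n_blocks))


# Assumption: 19 feature entries, as in the stock torchvision MobileNetV2.
print(tail_block_indices(19, 12))  # indices 7 through 18
print(tail_block_indices(19, 30))  # every index, because the start is clamped to 0
```

With these assumed numbers, `N_TAIL=12` leaves only the first seven feature entries frozen, so raising `N_TAIL` much further approaches full fine-tuning.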

In [None]:
# Step 15 — Fine-tune the last N MobileNetV2 blocks (+ head) with early stopping on val AUC

import time, numpy as np, torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from sklearn.metrics import roc_auc_score, accuracy_score
import pandas as pd
import matplotlib.pyplot as plt

# --- Paths & device ---
RESULTS = RESULTS_DIR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == 'cuda')

# --- Safety: expect these from earlier steps ---
assert 'model' in globals(), "Model not found. Please run Step 13 first."
assert 'train_loader' in globals() and 'val_loader' in globals(), "Loaders not found. Run Steps 11–12."
assert 'bce_with_logits_smooth' in globals(), "Loss fn not found. Run Step 13."

# --- Load best head-only checkpoint from Step 14 ---
ckpt_head = RESULTS / "mobilenetv2_head_warmup.pt"
assert ckpt_head.exists(), f"Missing {ckpt_head}; run Step 14 first."
model.load_state_dict(torch.load(ckpt_head, map_location="cpu"))
model = model.to(device)

# --- Freeze all features, then unfreeze the *tail* (last N blocks) + classifier ---
for p in model.features.parameters():
    p.requires_grad = False

N_TAIL = 12  # <-- try adjusting to a larger adaptation window
L = len(model.features)
start_idx = max(0, L - N_TAIL)
tail_indices = list(range(start_idx, L))
for i in tail_indices:
    # Some entries (e.g., final ConvBNReLU) are modules with parameters; this works for both IR blocks & final 1x1 conv
    for p in model.features[i].parameters():
        p.requires_grad = True

for p in model.classifier.parameters():
    p.requires_grad = True

# --- Introspection & transparency ---
def count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

print(f"MobileNetV2 features length: {L}")
print(f"Unfreezing feature blocks (0-based): {tail_indices}")
print(f"Trainable params (tail): {count_trainable(model.features[start_idx:])}")
print(f"Trainable params (head): {count_trainable(model.classifier)}")

# --- Optimizer: smaller LR for tail, larger LR for head ---
backbone_params = [p for p in model.features.parameters() if p.requires_grad]
head_params     = list(model.classifier.parameters())
optimizer = optim.AdamW([
    {"params": backbone_params, "lr": 1e-4},
    {"params": head_params,     "lr": 5e-4},
], weight_decay=1e-4)

scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

# --- Train / Eval helpers (AMP) ---
def train_one_epoch_ft(model, loader, optimizer, scaler, device):
    model.train()
    total, n = 0.0, 0
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(xb).view(-1)
            loss = bce_with_logits_smooth(logits, yb)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        bs = xb.size(0); total += loss.item() * bs; n += bs
    return total / max(1, n)

@torch.no_grad()
def eval_auc_acc(model, loader, device):
    model.eval()
    probs, ys = [], []
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)
        p = torch.sigmoid(model(xb).view(-1))
        probs.append(p.cpu().numpy()); ys.append(yb.cpu().numpy())
    probs = np.concatenate(probs); ys = np.concatenate(ys)
    auc = roc_auc_score(ys, probs) if (len(np.unique(ys)) > 1) else float('nan')
    acc = accuracy_score(ys, (probs >= 0.5).astype(np.int64))
    return auc, acc

# --- Training loop with early stopping on val AUC ---
MAX_EPOCHS = 24
PATIENCE = 3
history = []

best_auc = -1.0
best_state = None
no_improve = 0
t_start = time.time()

for epoch in range(1, MAX_EPOCHS + 1):
    t0 = time.time()
    train_loss = train_one_epoch_ft(model, train_loader, optimizer, scaler, device)
    val_auc, val_acc = eval_auc_acc(model, val_loader, device)
    history.append({"epoch": epoch, "train_loss": float(train_loss),
                    "val_auc": float(val_auc), "val_acc": float(val_acc)})
    print(f"[FT {epoch:02d}/{MAX_EPOCHS}] train_loss={train_loss:.4f}  val_auc={val_auc:.4f}  "
          f"val_acc@0.5={val_acc:.4f}  ({time.time()-t0:.1f}s)")

    if val_auc > best_auc:
        best_auc = val_auc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping (no val AUC improvement in {PATIENCE} epochs).")
            break

elapsed = time.time() - t_start
print(f"Fine-tune finished in {elapsed/60:.1f} min. Best val AUC = {best_auc:.4f}")

# --- Save best checkpoint, log, and curves ---
if best_state is not None:
    ckpt = RESULTS / "mobilenetv2_finetune_tail.pt"
    torch.save(best_state, ckpt)
    print(f"Saved fine-tuned weights → {ckpt}")

log_df = pd.DataFrame(history)
log_csv = RESULTS / "trainlog_finetune_tail.csv"
log_df.to_csv(log_csv, index=False)
print(f"Saved training log → {log_csv}")

plt.figure(figsize=(6,4))
plt.plot(log_df["epoch"], log_df["train_loss"], label="train loss")
plt.xlabel("epoch"); plt.ylabel("loss"); plt.title("Fine-tune (tail) — training loss")
plt.legend(); plt.tight_layout()
plt.savefig(RESULTS / "curves_finetune_tail_loss.png", dpi=150); plt.close()

plt.figure(figsize=(6,4))
plt.plot(log_df["epoch"], log_df["val_auc"], label="val AUC")
plt.xlabel("epoch"); plt.ylabel("AUC"); plt.ylim(0.9, 1.0)
plt.title("Fine-tune (tail) — validation AUC")
plt.legend(); plt.tight_layout()
plt.savefig(RESULTS / "curves_finetune_tail_auc.png", dpi=150); plt.close()

print("Saved curves →",
      RESULTS / "curves_finetune_tail_loss.png",
      "and",
      RESULTS / "curves_finetune_tail_auc.png")
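
The early-stopping rule used in the loop above can be replayed on any list of per-epoch validation AUCs, for example the `val_auc` column of the saved training log. This is a minimal sketch: `replay_early_stopping` is a hypothetical helper and the AUC trace is made up for illustration, not a measured result.

```python
def replay_early_stopping(val_aucs, patience):
    """Replay the stopping rule on per-epoch validation AUCs (epochs numbered from 1).

    Returns (best_epoch, best_auc, last_epoch_run).
    """
    best_auc, best_epoch, no_improve = -1.0, None, 0
    last_epoch = 0
    for epoch, auc in enumerate(val_aucs, start=1):
        last_epoch = epoch
        if auc > best_auc:
            # Strict improvement only: a tie (or NaN) counts as "no improvement".
            best_auc, best_epoch, no_improve = auc, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break
    return best_epoch, best_auc, last_epoch


# Hypothetical AUC trace, for illustration only.
trace = [0.90, 0.95, 0.94, 0.95, 0.93, 0.99]
print(replay_early_stopping(trace, patience=3))
```

Because the comparison is strict, the tie at epoch 4 does not reset the counter, so training stops at epoch 5 and never reaches the later, higher value. A larger `PATIENCE` trades extra epochs for a lower chance of stopping on a plateau.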

# Step 16 — Calibration (temperature scaling), reliability, and operating point

This cell calibrates predicted probabilities on the **validation** split, evaluates calibration quality, draws a reliability diagram, and selects a clinically oriented decision threshold.

- **Model state & device.**  
  Ensures the fine-tuned MobileNetV2 weights (`mobilenetv2_finetune_tail.pt`) are loaded and the model is in `eval()` on CPU/GPU.

- **Logit collection (validation).**  
  `collect_logits_labels` runs the model over `val_loader` to gather raw logits and ground-truth labels; uncalibrated probabilities are `sigmoid(logits)`.

- **Temperature scaling (Platt-style for logits).**  
  - Defines `TemperatureScaler(logT)` so that calibrated logits = `logits / T` with `T = exp(logT) > 0`.  
  - Fits `T` by minimizing **negative log-likelihood** (binary cross-entropy) on validation using **LBFGS**.  
  - Produces calibrated probabilities `sigmoid(logits / T)`.  
  - Note: **AUC is unchanged** by temperature scaling (rank-preserving).

- **Calibration metrics (before/after).**  
  Computes on validation:
  - **NLL (log loss)** — cross-entropy on probabilities.  
  - **Brier score** — mean squared error of probabilities.  
  - **ECE (Expected Calibration Error)** — via 10 equal-width probability bins; also emits a per-bin table (`ece_bins_val.csv`).

- **Reliability diagram.**  
  Plots bin-mean predicted probability vs empirical positive rate for uncalibrated and calibrated models and saves to `reliability_diagram_val.png`.

- **Operating point selection (validation, calibrated).**  
  Scans candidate thresholds to **maximize sensitivity** subject to **specificity ≥ 0.95** (fallback selects the closest specificity, then best sensitivity).  
  Returns the chosen `threshold` together with the validation `sensitivity` and `specificity` it achieves.

- **Persistence of artefacts.**  
  - `temperature.txt` — fitted scalar `T`.  
  - `operating_point_val.json` — JSON payload with `temperature_T`, selected operating point (mode, target specificity, threshold, achieved sens/spec), reference calibration metrics, and file paths to artefacts.  
  - `ece_bins_val.csv`, `reliability_diagram_val.png`.

- **Why this matters.**  
  Temperature scaling improves **probability calibration** without affecting ranking (AUC). Selecting a threshold at high specificity supports conservative screening settings; the same `T` and threshold will be reused on the **test** set and in downstream analyses.
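
To illustrate why temperature scaling changes calibration but not ranking, here is a minimal pure-Python sketch. It swaps the notebook's LBFGS fit for a plain grid search over `log T`, and the logits and labels are made-up toy values; `fit_temperature_grid` and the other names are hypothetical.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def nll(logits, labels, T=1.0):
    """Mean binary cross-entropy of sigmoid(logit / T) against 0/1 labels."""
    eps = 1e-12
    total = 0.0
    for z, y in zip(logits, labels):
        p = min(max(sigmoid(z / T), eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(logits)


def fit_temperature_grid(logits, labels, lo=-3.0, hi=3.0, steps=601):
    """Pick T = exp(logT) minimising NLL over an evenly spaced grid of logT values."""
    grid = [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    best_log_t = min(grid, key=lambda t: nll(logits, labels, math.exp(t)))
    return math.exp(best_log_t)


# Toy logits that are overconfident: large magnitudes, two of them wrong.
logits = [6.0, 5.0, -4.0, -6.0, 5.5, -5.0, 4.0, -4.5]
labels = [1, 1, 0, 0, 0, 1, 1, 0]

T = fit_temperature_grid(logits, labels)
order_before = sorted(range(len(logits)), key=lambda i: sigmoid(logits[i]))
order_after = sorted(range(len(logits)), key=lambda i: sigmoid(logits[i] / T))

print(f"fitted T = {T:.3f}")
print("NLL before / after:", nll(logits, labels), nll(logits, labels, T))
print("ranking unchanged:", order_before == order_after)
```

Dividing every logit by the same positive `T` is a monotonic transform, so the sort order of the probabilities (and hence the AUC) is preserved, while an overconfident model is pulled towards 0.5 by a fitted `T > 1`.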

In [None]:
# Step 16 — Calibration (temperature scaling) + reliability diagram + operating point
import os, json, math, numpy as np, pandas as pd
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import roc_auc_score, log_loss

# ---- Paths, device, and preconditions ----
RESULTS = RESULTS_DIR  # from earlier steps
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

assert 'model' in globals(), "Model not found – please run Steps 13–15."
assert 'val_loader' in globals(), "val_loader not found – please run Steps 11–12."

# If the current model isn't the fine-tuned one, (re)load the best FT checkpoint
ckpt_ft = RESULTS / "mobilenetv2_finetune_tail.pt"
if ckpt_ft.exists():
    try:
        state = torch.load(ckpt_ft, map_location="cpu", weights_only=True)
    except TypeError:  # older PyTorch
        state = torch.load(ckpt_ft, map_location="cpu")
    model.load_state_dict(state)
model = model.to(device).eval()

# ---- Collect validation logits & labels (uncalibrated) ----
@torch.no_grad()
def collect_logits_labels(loader, model, device):
    model.eval()
    logits_list, y_list = [], []
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        out = model(xb).view(-1)  # logits
        logits_list.append(out.detach().cpu())
        y_list.append(yb.detach().cpu())
    logits = torch.cat(logits_list).float()
    y      = torch.cat(y_list).long().numpy()
    return logits, y

val_logits, val_y = collect_logits_labels(val_loader, model, device)
val_p_uncal = torch.sigmoid(val_logits).numpy()

# ---- Temperature scaling (fit T on validation by NLL) ----
class TemperatureScaler(nn.Module):
    def __init__(self, init_T=1.0):
        super().__init__()
        # parameterized as logT for positivity
        self.logT = nn.Parameter(torch.tensor(math.log(init_T), dtype=torch.float32))
    def forward(self, logits):
        T = torch.exp(self.logT)
        return logits / T
    def T_value(self):
        return float(torch.exp(self.logT).detach().cpu().item())

def fit_temperature(logits, y, max_iter=100, lr=0.01):
    # logits: torch tensor on CPU, y: numpy array {0,1}
    y_t = torch.from_numpy(y.astype(np.float32))
    scaler = TemperatureScaler(init_T=1.0)
    scaler.train()
    opt = torch.optim.LBFGS(scaler.parameters(), lr=lr, max_iter=max_iter, line_search_fn="strong_wolfe")

    bce = nn.BCEWithLogitsLoss(reduction='mean')
    def closure():
        opt.zero_grad(set_to_none=True)
        logits_cal = scaler(logits)
        loss = bce(logits_cal.view(-1), y_t)
        loss.backward()
        return loss
    opt.step(closure)
    return scaler

scaler = fit_temperature(val_logits.clone(), val_y, max_iter=200, lr=0.01)
T_fitted = scaler.T_value()

with torch.no_grad():
    val_logits_cal = scaler(val_logits.clone()).view(-1)
val_p_cal = torch.sigmoid(val_logits_cal).numpy()

# ---- Metrics: AUC, NLL (log loss), Brier ----
def brier_score(probs, y):
    probs = np.clip(probs, 1e-6, 1-1e-6)
    return float(np.mean((probs - y)**2))

val_auc_ref = roc_auc_score(val_y, val_p_uncal)  # unchanged by temperature scaling
nll_uncal   = log_loss(val_y, np.clip(val_p_uncal, 1e-6, 1-1e-6))
nll_cal     = log_loss(val_y, np.clip(val_p_cal,   1e-6, 1-1e-6))
brier_uncal = brier_score(val_p_uncal, val_y)
brier_cal   = brier_score(val_p_cal,   val_y)

# ---- ECE & reliability (binary): prob vs empirical positive rate over p∈[0,1] ----
def bin_stats_prob_vs_posrate(probs, y, n_bins=10):
    probs = np.clip(probs.astype(np.float64), 1e-6, 1-1e-6)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ids   = np.digitize(probs, edges[1:-1])  # 0..n_bins-1
    rows  = []
    for b in range(n_bins):
        m = (ids == b)
        n = int(np.sum(m))
        lo, hi = float(edges[b]), float(edges[b+1])
        if n == 0:
            rows.append(dict(bin=b, n=0, bin_lower=lo, bin_upper=hi,
                             avg_prob=np.nan, pos_rate=np.nan))
        else:
            rows.append(dict(
                bin=b, n=n, bin_lower=lo, bin_upper=hi,
                avg_prob=float(np.mean(probs[m])),
                pos_rate=float(np.mean(y[m]))
            ))
    return pd.DataFrame(rows)

def ece_binary_probrate(probs, y, n_bins=10):
    df = bin_stats_prob_vs_posrate(probs, y, n_bins=n_bins)
    n_tot = df["n"].sum()
    if n_tot == 0:
        return float("nan"), df
    gap = np.abs(df["pos_rate"] - df["avg_prob"])
    w   = df["n"] / n_tot
    ece = float(np.nansum(gap * w))
    return ece, df

N_BINS = 10
ece_uncal, df_uncal = ece_binary_probrate(val_p_uncal, val_y, n_bins=N_BINS)
ece_cal,   df_cal   = ece_binary_probrate(val_p_cal,   val_y, n_bins=N_BINS)

# Save bins table (uncal + cal)
bins_out = df_uncal.rename(columns={"avg_prob":"uncal_avg_prob","pos_rate":"uncal_pos_rate"})
bins_out[["cal_avg_prob","cal_pos_rate"]] = df_cal[["avg_prob","pos_rate"]]
bins_csv = RESULTS / "ece_bins_val.csv"
bins_out.to_csv(bins_csv, index=False)

# ---- Reliability diagram (probability vs positive rate) ----
def plot_reliability_binary(df_uncal, df_cal, out_path):
    plt.figure(figsize=(6,4))
    xs = np.linspace(0,1,101)
    plt.plot(xs, xs, linestyle="--", linewidth=1, label="perfectly calibrated")
    for label, df in [("uncalibrated", df_uncal), ("calibrated", df_cal)]:
        d = df.dropna(subset=["avg_prob","pos_rate"])
        plt.plot(d["avg_prob"], d["pos_rate"], marker="o", linewidth=1, label=label)
    plt.xlabel("predicted probability (bin mean)")
    plt.ylabel("empirical positive rate (bin mean)")
    plt.title("Reliability diagram — validation")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()

rel_png = RESULTS / "reliability_diagram_val.png"
plot_reliability_binary(df_uncal, df_cal, rel_png)

# ---- Operating point on validation: sensitivity at specificity ≥ 0.95 ----
def sens_spec_at_threshold(p, y, thr):
    y_pred = (p >= thr).astype(np.int64)
    tp = np.sum((y_pred==1) & (y==1))
    tn = np.sum((y_pred==0) & (y==0))
    fp = np.sum((y_pred==1) & (y==0))
    fn = np.sum((y_pred==0) & (y==1))
    sens = tp / (tp + fn) if (tp + fn) > 0 else np.nan
    spec = tn / (tn + fp) if (tn + fp) > 0 else np.nan
    return sens, spec

def pick_threshold_sens_at_spec(p, y, target_spec=0.95):
    # scan unique probabilities plus a small grid to be robust
    cand = np.unique(np.concatenate([p, np.linspace(0,1,2001)]))
    best = dict(threshold=None, sensitivity=-1.0, specificity=float("nan"))
    for thr in cand:
        sens, spec = sens_spec_at_threshold(p, y, thr)
        if np.isnan(spec): 
            continue
        if spec >= target_spec and sens > best["sensitivity"]:
            best = dict(threshold=float(thr), sensitivity=float(sens), specificity=float(spec))
    # If none meet the spec constraint, fall back to closest by spec then best sens
    if best["threshold"] is None:
        gaps = []
        for thr in cand:
            sens, spec = sens_spec_at_threshold(p, y, thr)
            if np.isnan(spec): 
                continue
            gaps.append((abs(spec - target_spec), -sens, thr, sens, spec))
        if gaps:
            gaps.sort()
            _, _, thr, sens, spec = gaps[0]
            best = dict(threshold=float(thr), sensitivity=float(sens), specificity=float(spec))
    return best

op = pick_threshold_sens_at_spec(val_p_cal, val_y, target_spec=0.95)

# ---- Persist temperature & print a compact summary ----
with open(RESULTS / "temperature.txt", "w") as f:
    f.write(f"{T_fitted:.6f}\n")

print(f"Fitted temperature T = {T_fitted:.4f}")
print(f"Validation AUC (reference): {val_auc_ref:.4f}")
print(f"NLL  (uncal → cal): {nll_uncal:.4f} → {nll_cal:.4f}")
print(f"Brier(uncal → cal): {brier_uncal:.4f} → {brier_cal:.4f}")
print(f"ECE  (uncal → cal): {ece_uncal:.4f} → {ece_cal:.4f}")
print(f"Reliability diagram saved → {rel_png}")
print(f"ECE bins CSV saved       → {bins_csv}")

print("\nOperating point (val, calibrated):")
print("  mode: sens_at_spec")
print("  target_specificity:", 0.95)
print(f"  threshold: {op['threshold']}")
print(f"  val_sensitivity: {op['sensitivity']}")
print(f"  val_specificity: {op['specificity']}")

# ---- Save calibrated operating point & threshold to JSON (robust to 'T' name clash) ----
from datetime import datetime
import json
from numbers import Number

def _to_float(x):
    """Best-effort conversion of tensors/ndarrays/scalars to plain float; return None if impossible."""
    try:
        if x is None:
            return None
        if isinstance(x, Number):
            return float(x)
        # torch tensor?
        if hasattr(x, "item") and callable(getattr(x, "item")):
            return float(x.item())
        # numpy scalar/array?
        try:
            import numpy as np
            if isinstance(x, (np.floating, np.integer)):
                return float(x)
            if isinstance(x, np.ndarray) and x.size == 1:
                return float(x.reshape(()))
        except Exception:
            pass
        return float(x)
    except Exception:
        return None

def _pick_first_numeric(names):
    """Return the first global variable among 'names' that converts to a float, else None."""
    g = globals()
    for nm in names:
        val = _to_float(g.get(nm))
        if val is not None:
            return val
    return None

# 1) Temperature (try the safe alias first if you created it, then other likely names)
T_value = _pick_first_numeric(["T_fitted", "temperature_T", "T_value", "T_fit", "temp", "temperature", "T"])
if T_value is None:
    print("WARNING: Could not find a numeric temperature; defaulting to 1.0")
    T_value = 1.0

# 2) Operating-point fields (prefer dict 'op' if present)
op = globals().get("op", {})
op_mode = op.get("mode", globals().get("op_mode", "sens_at_spec"))

target_spec = _to_float(op.get("target_specificity"))
if target_spec is None:
    target_spec = _to_float(globals().get("target_specificity", 0.95))

thr = op.get("threshold", None)
if thr is None:
    thr = globals().get("op_threshold", globals().get("threshold", None))
thr = _to_float(thr)

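# pick_threshold_sens_at_spec stores these under 'sensitivity' / 'specificity';
# the 'val_*' names are kept as fallbacks for older runs.
val_sens = _to_float(op.get("sensitivity", op.get("val_sensitivity", globals().get("val_sensitivity", None))))
val_spec = _to_float(op.get("specificity", op.get("val_specificity", globals().get("val_specificity", None))))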

# 3) Optional reference metrics (if you kept them around)
val_auc_ref = _to_float(globals().get("val_auc_ref", None))
nll_uncal   = _to_float(globals().get("nll_uncal", None))
nll_cal     = _to_float(globals().get("nll_cal", None))
brier_uncal = _to_float(globals().get("brier_uncal", None))
brier_cal   = _to_float(globals().get("brier_cal", None))
ece_uncal   = _to_float(globals().get("ece_uncal", None))
ece_cal     = _to_float(globals().get("ece_cal", None))

summary_json = {
    "temperature_T": round(T_value, 6),
    "operating_point": {
        "mode": op_mode,
        "target_specificity": round(target_spec, 6) if target_spec is not None else None,
        "threshold": round(thr, 6) if thr is not None else None,
        "val_sensitivity": round(val_sens, 6) if val_sens is not None else None,
        "val_specificity": round(val_spec, 6) if val_spec is not None else None,
    },
    "validation_reference": {
        "auc": val_auc_ref,
        "nll_uncal": nll_uncal, "nll_cal": nll_cal,
        "brier_uncal": brier_uncal, "brier_cal": brier_cal,
        "ece_uncal": ece_uncal, "ece_cal": ece_cal,
    },
    "artifacts": {
        "reliability_diagram": str(RESULTS_DIR / "reliability_diagram_val.png"),
        "ece_bins_csv": str(RESULTS_DIR / "ece_bins_val.csv"),
    },
    "timestamp": datetime.now().isoformat(timespec="seconds")
}

out_json = RESULTS_DIR / "operating_point_val.json"
with open(out_json, "w", encoding="utf-8") as f:
    json.dump(summary_json, f, indent=2)
print(f"Saved operating-point JSON → {out_json}")
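
The ECE reported above is a count-weighted average of the gap between mean predicted probability and empirical positive rate in each bin. The following is a minimal pure-Python sketch of that definition with equal-width bins; `ece_equal_width` is a hypothetical helper, its handling of values that fall exactly on a bin edge may differ slightly from the `np.digitize` version in the cell, and the probabilities are toy values.

```python
def ece_equal_width(probs, labels, n_bins=10):
    """Expected Calibration Error for binary probabilities, equal-width bins.

    ECE = sum over bins b of (n_b / N) * |positive_rate_b - mean_prob_b|
    """
    n = len(probs)
    sum_p = [0.0] * n_bins
    sum_y = [0.0] * n_bins
    count = [0] * n_bins
    for p, y in zip(probs, labels):
        b = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes into the last bin
        sum_p[b] += p
        sum_y[b] += y
        count[b] += 1
    ece = 0.0
    for b in range(n_bins):
        if count[b] == 0:
            continue  # empty bins contribute nothing
        gap = abs(sum_y[b] / count[b] - sum_p[b] / count[b])
        ece += (count[b] / n) * gap
    return ece


# Toy example: the model says 0.95 / 0.05 but is right only 75% of the time.
probs = [0.95, 0.95, 0.95, 0.95, 0.05, 0.05, 0.05, 0.05]
labels = [1, 1, 1, 0, 0, 0, 0, 1]
print(f"ECE = {ece_equal_width(probs, labels):.3f}")
```

In this toy case both occupied bins have a gap of 0.2 between confidence and observed rate, so the weighted average is also 0.2.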

# Step 17 — Final test evaluation (calibrated)

This cell evaluates the **final model on the test split**, applying the **temperature** and **decision threshold** obtained on validation, and exports a structured report and plots.

- **Inputs & preconditions**
  - Expects a MobileNetV2 model in memory and `test_loader` (from Steps 11–15).
  - Loads fine-tuned weights `mobilenetv2_finetune_tail.pt` and switches to `eval()` on the selected device (CPU/GPU).
  - Loads calibration artifacts from Step 16:
    - `temperature_T` (scalar for logit scaling).
    - Operating-point metadata from `operating_point_val.json`: `mode`, `target_specificity`, and **validation-derived threshold**.

- **Inference & calibration**
  - Collects **raw logits** and labels on the test set (`collect_logits_labels`).
  - Applies **temperature scaling**: `logits / T` (no re-fitting on test).
  - Converts both uncalibrated and calibrated logits to probabilities via `sigmoid`.

- **Discrimination (ranking)**
  - Computes **AUC** for both uncalibrated and calibrated probabilities (`roc_auc_score`).  
    *Note:* Temperature scaling is monotonic, so it should leave the ranking, and therefore the AUC, unchanged; any small difference comes from floating-point rounding or ties.

- **Calibration quality**
  - **NLL (log loss)**, **Brier score**, and **ECE (10-bin)** are computed for uncalibrated and calibrated probabilities.  
  - A **reliability diagram** is created for the test set by plotting bin mean predicted probability vs. empirical positive rate.

- **Operating point on test (fixed threshold)**
  - Uses the **validation-selected threshold** (high-specificity target) and applies it to **calibrated** test probabilities.
  - Reports confusion-matrix counts (TP, TN, FP, FN), **sensitivity**, and **specificity** at that threshold.

- **Plots & artifacts written**
  - `roc_test.png` — ROC curves (uncalibrated vs calibrated).
  - `reliability_diagram_test.png` — Test reliability diagram.
  - `ece_bins_test.csv` — Bin-level calibration table (uncalibrated + calibrated).
  - `test_metrics.json` — Structured summary with: dataset size, temperature, operating-point metrics, discrimination scores, calibration scores, and file paths to artifacts.  
    The console also prints a compact summary (AUCs, NLL/Brier/ECE, threshold, sensitivity/specificity, and confusion matrix).

- **Why this step**
  - Ensures **honest generalization** by reporting all final metrics on held-out test data, using **frozen calibration** and a **fixed clinical operating point** chosen on validation.
  - The JSON report and plots facilitate downstream analysis, reproducibility, and inclusion in results sections.
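
The "frozen calibration, fixed threshold" idea can be sketched without any framework: take the temperature and threshold as given constants, apply them to held-out logits, and count outcomes. All values below (logits, labels, `T=2.0`, `threshold=0.6`) are made up for illustration, and `evaluate_fixed_operating_point` is a hypothetical helper rather than the notebook's own function.

```python
import math


def stable_sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def evaluate_fixed_operating_point(logits, labels, T, threshold):
    """Confusion counts, sensitivity and specificity at a fixed T and threshold."""
    counts = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
    for z, y in zip(logits, labels):
        pred = 1 if stable_sigmoid(z / T) >= threshold else 0  # calibrated decision
        if pred == 1 and y == 1:
            counts["tp"] += 1
        elif pred == 0 and y == 0:
            counts["tn"] += 1
        elif pred == 1 and y == 0:
            counts["fp"] += 1
        else:
            counts["fn"] += 1
    pos = counts["tp"] + counts["fn"]
    neg = counts["tn"] + counts["fp"]
    sens = counts["tp"] / pos if pos else float("nan")
    spec = counts["tn"] / neg if neg else float("nan")
    return {**counts, "sensitivity": sens, "specificity": spec}


# Hypothetical held-out logits/labels; T and threshold stand in for the
# values that would be read from the validation artefacts.
result = evaluate_fixed_operating_point(
    logits=[4.0, 1.0, -0.5, -3.0, 2.0, -2.0],
    labels=[1, 1, 1, 0, 0, 0],
    T=2.0,
    threshold=0.6,
)
print(result)
```

Nothing in this function looks at the test labels before the decision is made, which is the property that keeps the reported test sensitivity and specificity free of selection bias.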

In [None]:
# Step 17 — Final test evaluation (calibrated)
import os, json, math, numpy as np, pandas as pd
from pathlib import Path
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, auc, log_loss

# ---- Paths & device ----
PROJECT_ROOT = Path(r"[path_placeholder]")
RESULTS_DIR  = PROJECT_ROOT / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Preconditions: model, loaders, and weights from prior steps ----
assert 'model' in globals(), "Model not in memory — please run Steps 13–15."
assert 'test_loader' in globals(), "test_loader missing — please run Steps 11–12."

ckpt_ft = RESULTS_DIR / "mobilenetv2_finetune_tail.pt"
assert ckpt_ft.exists(), f"Missing {ckpt_ft} — run Step 15."

# Load best fine-tuned weights safely
try:
    state = torch.load(ckpt_ft, map_location="cpu", weights_only=True)
except TypeError:  # older PyTorch without the `weights_only` argument (added in 1.13)
    state = torch.load(ckpt_ft, map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

# ---- Load calibration artifacts (temperature + operating point) from Step 16 ----
op_json = RESULTS_DIR / "operating_point_val.json"
temp_txt = RESULTS_DIR / "temperature.txt"

if op_json.exists():
    with open(op_json, "r", encoding="utf-8") as f:
        op_payload = json.load(f)
    temperature_T = float(op_payload.get("temperature_T", 1.0))
    op_dict = op_payload.get("operating_point", {})
    op_mode = op_dict.get("mode", "sens_at_spec")
    target_specificity = float(op_dict.get("target_specificity", 0.95))
    op_threshold = float(op_dict.get("threshold"))
else:
    # Step 16 writes this file; without it there is no validated threshold to apply.
    raise RuntimeError("operating_point_val.json not found. Please re-run Step 16 to produce it.")

print(f"Loaded calibration: T={temperature_T:.4f}, mode={op_mode}, "
      f"target_spec={target_specificity:.2f}, threshold={op_threshold:.4f}")

# ---- Collect logits & labels on the TEST split ----
@torch.no_grad()
def collect_logits_labels(loader, model, device):
    model.eval()
    logits_list, y_list = [], []
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        out = model(xb).view(-1)  # logits
        logits_list.append(out.detach().cpu())
        y_list.append(yb.detach().cpu())
    logits = torch.cat(logits_list).float()
    y      = torch.cat(y_list).long().numpy()
    return logits, y

test_logits, test_y = collect_logits_labels(test_loader, model, device)

# ---- Apply temperature scaling ----
class TemperatureScaler(nn.Module):
    def __init__(self, T):
        super().__init__()
        # store as a detached tensor for clean, no-grad forward
        self.T = torch.tensor(float(T), dtype=torch.float32)
    def forward(self, logits):
        return logits / self.T

scaler = TemperatureScaler(temperature_T)
with torch.no_grad():
    test_logits_cal = scaler(test_logits.clone()).view(-1)

# ---- Convert to probabilities ----
test_p_uncal = torch.sigmoid(test_logits).numpy()
test_p_cal   = torch.sigmoid(test_logits_cal).numpy()

# ---- Metrics helpers ----
def brier_score(probs, y):
    probs = np.clip(probs, 1e-6, 1-1e-6)
    return float(np.mean((probs - y)**2))

def bin_stats_prob_vs_posrate(probs, y, n_bins=10):
    probs = np.clip(probs.astype(np.float64), 1e-6, 1-1e-6)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ids   = np.digitize(probs, edges[1:-1])  # 0..n_bins-1
    rows  = []
    for b in range(n_bins):
        m = (ids == b)
        n = int(np.sum(m))
        lo, hi = float(edges[b]), float(edges[b+1])
        if n == 0:
            rows.append(dict(bin=b, n=0, bin_lower=lo, bin_upper=hi,
                             avg_prob=np.nan, pos_rate=np.nan))
        else:
            rows.append(dict(
                bin=b, n=n, bin_lower=lo, bin_upper=hi,
                avg_prob=float(np.mean(probs[m])),
                pos_rate=float(np.mean(y[m]))
            ))
    return pd.DataFrame(rows)

def ece_binary_probrate(probs, y, n_bins=10):
    df = bin_stats_prob_vs_posrate(probs, y, n_bins=n_bins)
    n_tot = df["n"].sum()
    if n_tot == 0:
        return float("nan"), df
    gap = np.abs(df["pos_rate"] - df["avg_prob"])
    w   = df["n"] / n_tot
    ece = float(np.nansum(gap * w))
    return ece, df

def sens_spec_counts(p, y, thr):
    y_hat = (p >= thr).astype(np.int64)
    tp = int(np.sum((y_hat==1) & (y==1)))
    tn = int(np.sum((y_hat==0) & (y==0)))
    fp = int(np.sum((y_hat==1) & (y==0)))
    fn = int(np.sum((y_hat==0) & (y==1)))
    sens = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    spec = tn / (tn + fp) if (tn + fp) > 0 else float("nan")
    return dict(tp=tp, tn=tn, fp=fp, fn=fn, sensitivity=float(sens), specificity=float(spec))

# ---- Discrimination (AUC) ----
auc_uncal = roc_auc_score(test_y, test_p_uncal)
auc_cal   = roc_auc_score(test_y, test_p_cal)

# ---- Calibration metrics on TEST (report calibrated; keep uncal for reference) ----
nll_uncal = log_loss(test_y, np.clip(test_p_uncal, 1e-6, 1-1e-6))
nll_cal   = log_loss(test_y, np.clip(test_p_cal,   1e-6, 1-1e-6))
brier_uncal = brier_score(test_p_uncal, test_y)
brier_cal   = brier_score(test_p_cal,   test_y)
ece_uncal, df_uncal = ece_binary_probrate(test_p_uncal, test_y, n_bins=10)
ece_cal,   df_cal   = ece_binary_probrate(test_p_cal,   test_y, n_bins=10)

# ---- Operating-point on TEST using *val-derived* threshold ----
cm = sens_spec_counts(test_p_cal, test_y, op_threshold)

# ---- ROC plot (test) ----
fpr_u, tpr_u, _ = roc_curve(test_y, test_p_uncal)
fpr_c, tpr_c, _ = roc_curve(test_y, test_p_cal)
plt.figure(figsize=(6,4))
plt.plot([0,1],[0,1],'--', linewidth=1, label="chance")
plt.plot(fpr_u, tpr_u, label=f"uncalibrated (AUC={auc_uncal:.3f})")
plt.plot(fpr_c, tpr_c, label=f"calibrated (AUC={auc_cal:.3f})")
plt.xlabel("1 - specificity (FPR)"); plt.ylabel("sensitivity (TPR)")
plt.title("ROC — test")
plt.legend(); plt.tight_layout()
roc_png = RESULTS_DIR / "roc_test.png"
plt.savefig(roc_png, dpi=150); plt.close()

# ---- Reliability diagram (test) ----
plt.figure(figsize=(6,4))
xs = np.linspace(0,1,101)
plt.plot(xs, xs, linestyle="--", linewidth=1, label="perfectly calibrated")
for label, df in [("uncalibrated", df_uncal), ("calibrated", df_cal)]:
    d = df.dropna(subset=["avg_prob","pos_rate"])
    plt.plot(d["avg_prob"], d["pos_rate"], marker="o", linewidth=1, label=label)
plt.xlabel("predicted probability (bin mean)")
plt.ylabel("empirical positive rate (bin mean)")
plt.title("Reliability diagram — test")
plt.legend(); plt.tight_layout()
rel_png = RESULTS_DIR / "reliability_diagram_test.png"
plt.savefig(rel_png, dpi=150); plt.close()

# Save ECE bins for test
bins_out = df_uncal.rename(columns={"avg_prob":"uncal_avg_prob","pos_rate":"uncal_pos_rate"})
bins_out[["cal_avg_prob","cal_pos_rate"]] = df_cal[["avg_prob","pos_rate"]]
bins_csv = RESULTS_DIR / "ece_bins_test.csv"
bins_out.to_csv(bins_csv, index=False)

# ---- Assemble and save JSON report ----
report = {
    "dataset": "test",
    "n_examples": int(len(test_y)),
    "temperature_T": round(float(temperature_T), 6),
    "operating_point": {
        "source": "validation",
        "mode": op_mode,
        "target_specificity": round(float(target_specificity), 6),
        "threshold": round(float(op_threshold), 6),
        "test_sensitivity": round(cm["sensitivity"], 6),
        "test_specificity": round(cm["specificity"], 6),
        "confusion_matrix": {k:int(v) for k,v in cm.items() if k in ("tp","tn","fp","fn")},
    },
    "discrimination": {
        "auc_uncalibrated": round(float(auc_uncal), 6),
        "auc_calibrated":   round(float(auc_cal), 6)
    },
    "calibration": {
        "ece_uncalibrated": round(float(ece_uncal), 6),
        "ece_calibrated":   round(float(ece_cal), 6),
        "nll_uncalibrated": round(float(nll_uncal), 6),
        "nll_calibrated":   round(float(nll_cal), 6),
        "brier_uncalibrated": round(float(brier_uncal), 6),
        "brier_calibrated":   round(float(brier_cal), 6),
        "n_bins": 10
    },
    "artifacts": {
        "roc_png": str(roc_png),
        "reliability_diagram_png": str(rel_png),
        "ece_bins_csv": str(bins_csv),
        "operating_point_val_json": str(op_json)
    }
}

out_json = RESULTS_DIR / "test_metrics.json"
with open(out_json, "w", encoding="utf-8") as f:
    json.dump(report, f, indent=2)

# ---- Console summary ----
print("\n=== Test summary (calibrated) ===")
print(f"AUC (cal): {report['discrimination']['auc_calibrated']:.4f} "
      f"(uncal: {report['discrimination']['auc_uncalibrated']:.4f})")
print(f"NLL  (cal): {report['calibration']['nll_calibrated']:.4f}  "
      f"Brier (cal): {report['calibration']['brier_calibrated']:.4f}  "
      f"ECE (cal): {report['calibration']['ece_calibrated']:.4f}")
print(f"Threshold (from val): {report['operating_point']['threshold']:.4f}")
print(f"Sensitivity@thr: {report['operating_point']['test_sensitivity']:.4f}  "
      f"Specificity@thr: {report['operating_point']['test_specificity']:.4f}")
print("Confusion matrix (TP, TN, FP, FN):",
      report['operating_point']['confusion_matrix'])
print(f"Saved ROC → {roc_png}")
print(f"Saved reliability diagram → {rel_png}")
print(f"Saved ECE bins → {bins_csv}")
print(f"Saved JSON report → {out_json}")

# Step 18 — Uncertainty via TTA + abstention

This cell estimates **predictive uncertainty** with **test-time augmentation (TTA)** and defines a simple **abstention rule** to trade coverage for accuracy. It reuses the calibrated model (temperature **T**) and the **fixed decision threshold** from validation.

- **Prerequisites & artifacts**
  - Requires: `splits.csv` (manifests), `mobilenetv2_finetune_tail.pt` (fine-tuned weights), and `operating_point_val.json` (contains `temperature_T`, validation-derived `threshold`, and target specificity).
  - Loads the model on CPU/GPU, sets `eval()`, and applies **temperature scaling** at inference only (no retraining).

- **Preprocessing (consistent with training)**
  - Pads each image to square using the **border-ring modal color** (expected near-black), then **bicubic** resize to **128×128**.
  - Converts to tensor and applies **ImageNet normalization** (`mean=[0.485,0.456,0.406]`, `std=[0.229,0.224,0.225]`).

- **TTA configuration**
  - Uses **mild augmentations** aligned with training: horizontal/vertical flips (p=0.5), ±10° rotation (fill=0), brightness/contrast jitter (~[0.8, 1.2]).
  - For each image, samples **N=8** augmented variants, obtains logits, divides by **T** (from Step 16), applies `sigmoid`, and computes:
    - `mean_p`: mean probability across TTA samples.
    - `std_p`: standard deviation across TTA samples.
  - Runs TTA separately for **validation** and **test** subsets.

- **Baseline (no abstention)**
  - Computes the **baseline accuracy** by thresholding `mean_p` with the **fixed operating threshold** from validation (`OP_THR`).

- **Abstention rule (two-dimensional)**
  - Decision to **keep** a prediction:
    - Keep if `|mean_p − 0.5| ≥ δ` **AND** `std_p ≤ σ_thr`.
    - Otherwise **abstain** (defer/flag as uncertain).
  - Intuition:
    - `|mean_p − 0.5|` measures **confidence margin** away from ambiguity.
    - `std_p` captures **instability** across plausible views (data uncertainty).

- **Validation-driven selection of (δ, σ_thr)**
  - **Grid search** on validation for `δ ∈ [0.00, 0.30]` and `σ_thr ∈ [0.00, 0.10]`.
  - Objective: meet **target coverage** (default **0.90** kept) and **maximize accuracy among kept**.
  - If no pair reaches the target coverage, selects the candidate whose coverage is closest to the target, breaking ties by highest accuracy (deterministic tie-break).

- **Frontier plot & artifacts**
  - Builds a **coverage–accuracy frontier** on validation and highlights the selected point.
  - Saves:
    - `coverage_accuracy_curve.png` — frontier with selected (δ, σ_thr).
    - `abstention_rule.json` — reusable rule:
      - `"rule": "abstain if |mean_p - 0.5| < delta OR std_p > sigma_thr"`
      - Selected `delta`, `sigma_thr`, `tta` details (N and aug), `calibration_T`, fixed `classification_threshold`, validation `target_specificity`, and target coverage.

- **Summaries on val & test**
  - Applies the chosen (δ, σ_thr) to **both validation and test**:
    - Reports **coverage** (fraction kept), **n_kept / n_abstained**, **accuracy among kept**, **baseline accuracy** (no abstention), and **accuracy gain**.
  - Saves:
    - `abstention_val_summary.json`
    - `abstention_test_summary.json`
  - Prints a concise console summary (rule, parameters, coverage, accuracy-kept, baseline).

- **Why this step**
  - Adds an **operational safety lever**: the system can **abstain** on uncertain cases, improving reliability at controlled throughput (coverage), without modifying the classifier or threshold learned on validation.
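
The keep/abstain rule and the validation grid search described above can be sketched on toy numbers. This is a minimal illustration using only the standard library, not the notebook's implementation: the helper names (`coverage_and_accuracy`, `select_rule`), the toy values, and the small grids are assumptions made for the example.

```python
import math

def coverage_and_accuracy(y, mean_p, std_p, thr, delta, sigma_thr):
    """Coverage (fraction kept) and accuracy among kept predictions."""
    # Keep only predictions that are both confident (margin) and stable (low TTA std).
    keep = [abs(m - 0.5) >= delta and s <= sigma_thr for m, s in zip(mean_p, std_p)]
    n_kept = sum(keep)
    coverage = n_kept / len(y)
    if n_kept == 0:
        return coverage, math.nan
    correct = sum(int(m >= thr) == t for t, m, k in zip(y, mean_p, keep) if k)
    return coverage, correct / n_kept

def select_rule(y, mean_p, std_p, thr, deltas, sigmas, target_coverage):
    """Among pairs meeting the coverage target, pick the one with the highest accuracy."""
    best = None
    for d in deltas:
        for s in sigmas:
            cov, acc = coverage_and_accuracy(y, mean_p, std_p, thr, d, s)
            if cov >= target_coverage and (best is None or acc > best[0]):
                best = (acc, cov, d, s)
    return best

# Toy values (made up for illustration): labels, TTA mean and TTA std per image.
y      = [1, 1, 0, 0, 1, 0]
mean_p = [0.95, 0.55, 0.10, 0.48, 0.90, 0.58]
std_p  = [0.01, 0.08, 0.02, 0.03, 0.15, 0.02]

# delta=0 and a very loose sigma_thr keep everything, i.e. the no-abstention baseline.
baseline = coverage_and_accuracy(y, mean_p, std_p, thr=0.5, delta=0.0, sigma_thr=1.0)
best = select_rule(y, mean_p, std_p, thr=0.5,
                   deltas=[0.0, 0.1, 0.2], sigmas=[0.05, 0.10, 0.20],
                   target_coverage=0.5)
print("baseline (coverage, accuracy):", baseline)
print("selected (accuracy, coverage, delta, sigma_thr):", best)
```

On these toy values the selected pair keeps half of the images and removes the single misclassified, low-margin case, which is the trade-off the frontier plot visualizes at scale.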

In [None]:
# Step 18 — Uncertainty via TTA + abstention
import json, math, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image, ImageOps

# ---- Roots & prerequisites (consistent with earlier steps) ----
PROJECT_ROOT = Path(r"[path_placeholder]")
RESULTS_DIR  = PROJECT_ROOT / "results"
SPLITS_CSV   = RESULTS_DIR / "splits.csv"
OP_JSON      = RESULTS_DIR / "operating_point_val.json"
CKPT_FT      = RESULTS_DIR / "mobilenetv2_finetune_tail.pt"
assert SPLITS_CSV.exists() and OP_JSON.exists() and CKPT_FT.exists(), "Run Steps 15–17 first."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Load calibration artifacts (T and operating threshold from Step 16) ----
with open(OP_JSON, "r", encoding="utf-8") as f:
    op_payload = json.load(f)
T_cal      = float(op_payload["temperature_T"])
op_dict    = op_payload["operating_point"]
OP_THR     = float(op_dict["threshold"])            # same threshold as the rest of the pipeline
TARGET_SPEC = float(op_dict.get("target_specificity", 0.95))  # for metadata only

# ---- (Re)load best fine-tuned weights & eval mode ----
# Assumes 'model' from Step 15 exists; if not, rebuild MobileNetV2 as in Step 13 and load ft weights.
try:
    model
except NameError:
    from torchvision import models
    # ImageNet weights are not needed here: the fine-tuned state dict loaded below
    # overwrites every parameter and buffer, so skip the download.
    model = models.mobilenet_v2(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, 1)

state = torch.load(CKPT_FT, map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

# ---- Mild TTA: same ranges as training ----
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
to_tensor = T.ToTensor()
normalise = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
tta_aug   = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=10, fill=0),
    T.ColorJitter(brightness=0.2, contrast=0.2),  # ~[0.8,1.2]
])

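# Note: pad_resize_canonical is not defined in this cell; it is expected to come from an earlier step.
def make_tensor_from_path(path: Path, augment: bool):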
    with Image.open(path).convert("RGB") as img:
        img = pad_resize_canonical(img, target_size=128, ring=2)
    if augment:
        img = tta_aug(img)
    x = to_tensor(img)
    x = normalise(x)
    return x

# ---- Compute TTA (mean prob & std) for a split ----
def tta_split(df_split: pd.DataFrame, N=8, batch_device=device):
    """Run N-view TTA per image; returns dict with keys: y (0/1), mean_p, std_p."""
    ys, mean_ps, std_ps = [], [], []
    scaler_T = 1.0 / T_cal  # multiplying logits by 1/T is temperature scaling

    for _, row in df_split.iterrows():
        path = Path(row["path"])
        y    = 1 if row["class"] == "Parasitized" else 0

        # Build one TTA batch of size N for this image
        xb = torch.stack([ make_tensor_from_path(path, augment=True) for _ in range(N) ], dim=0)
        xb = xb.to(batch_device, non_blocking=True)

        with torch.no_grad():
            logits = model(xb).view(-1) * scaler_T   # temperature scaling on logits
            p = torch.sigmoid(logits).detach().cpu().numpy()

        ys.append(y)
        mean_ps.append(float(np.mean(p)))
        std_ps.append(float(np.std(p)))

    return {
        "y": np.array(ys, dtype=np.int64),
        "mean_p": np.array(mean_ps, dtype=np.float32),
        "std_p": np.array(std_ps, dtype=np.float32),
    }

# ---- Load splits ----
splits = pd.read_csv(SPLITS_CSV)
val_df  = splits[splits["split"]=="val"].reset_index(drop=True)
test_df = splits[splits["split"]=="test"].reset_index(drop=True)

print(f"TTA running… N=8  |  val={len(val_df)}  test={len(test_df)}")
t0 = time.time()
val_stats  = tta_split(val_df,  N=8)
test_stats = tta_split(test_df, N=8)
print(f"Done in {time.time()-t0:.1f}s")

# ---- Helper: evaluate abstention rule ----
def eval_rule(ys, mean_p, std_p, thr, delta, sigma_thr):
    keep = (np.abs(mean_p - 0.5) >= delta) & (std_p <= sigma_thr)
    cov  = float(np.mean(keep)) if len(keep) else 0.0
    if cov > 0:
        preds = (mean_p[keep] >= thr).astype(np.int64)
        acc   = float(np.mean(preds == ys[keep]))
    else:
        acc = float("nan")
    return cov, acc

# Baseline (no abstention) for reference
def baseline_acc(ys, mean_p, thr):
    return float(np.mean((mean_p >= thr).astype(np.int64) == ys))

val_baseline_acc  = baseline_acc(val_stats["y"],  val_stats["mean_p"],  OP_THR)
test_baseline_acc = baseline_acc(test_stats["y"], test_stats["mean_p"], OP_THR)

# ---- Grid search on validation for (delta, sigma_thr) at target coverage ----
TARGET_COVERAGE = 0.90   # "modest coverage reduction"; change to 0.85 if you want stronger abstention

delta_grid = np.linspace(0.00, 0.30, 31)   # |p-0.5| band
sigma_grid = np.linspace(0.00, 0.10, 21)   # std over N=8

grid_rows = []
best = None  # (acc, coverage, delta, sigma_thr)
for d in delta_grid:
    for s in sigma_grid:
        cov, acc = eval_rule(val_stats["y"], val_stats["mean_p"], val_stats["std_p"], OP_THR, d, s)
        grid_rows.append({"delta": float(d), "sigma_thr": float(s), "coverage": cov, "accuracy": acc})
        if cov >= TARGET_COVERAGE:
            if best is None or acc > best[0] or (math.isclose(acc, best[0]) and cov > best[1]):
                best = (acc, cov, d, s)

# Fallback if no (delta, sigma_thr) pair reaches the target coverage: take the pair whose
# coverage is closest to the target, breaking ties by higher accuracy.
# Rows with NaN accuracy (nothing kept) are skipped so the sort key is well defined.
if best is None:
    candidates = [r for r in grid_rows if not math.isnan(r["accuracy"])]
    assert candidates, "No grid point kept any sample; widen sigma_grid."
    ranked = sorted(candidates, key=lambda r: (abs(r["coverage"]-TARGET_COVERAGE), -r["accuracy"]))
    top = ranked[0]
    best = (top["accuracy"], top["coverage"], top["delta"], top["sigma_thr"])

best_acc, best_cov, BEST_DELTA, BEST_SIGMA = best

# ---- Build coverage–accuracy curve (frontier on val) & plot ----
grid_df = pd.DataFrame(grid_rows).dropna()
# Frontier: for each coverage level c, the best accuracy among grid points whose
# coverage is at least c (non-increasing in c).
covs = []; accs = []
for cov in np.linspace(grid_df["coverage"].min(), 1.0, 100):
    subset = grid_df[grid_df["coverage"]>=cov]
    if len(subset):
        covs.append(cov)
        accs.append(subset["accuracy"].max())
plt.figure(figsize=(6,4))
plt.plot(covs, accs, linewidth=2, label="validation frontier")
plt.scatter(grid_df["coverage"], grid_df["accuracy"], s=8, alpha=0.25, label="grid")
plt.scatter([best_cov], [best_acc], s=60, marker="o", label=f"selected (δ={BEST_DELTA:.3f}, σ={BEST_SIGMA:.3f})")
plt.axvline(TARGET_COVERAGE, linestyle="--", linewidth=1, label=f"target coverage={TARGET_COVERAGE:.2f}")
plt.xlabel("coverage (kept fraction)"); plt.ylabel("accuracy among kept")
plt.title("Coverage vs accuracy — validation")
plt.legend()
cov_curve_png = RESULTS_DIR / "coverage_accuracy_curve.png"
plt.tight_layout(); plt.savefig(cov_curve_png, dpi=150); plt.close()

# ---- Save rule for reuse ----
rule_json = {
    "rule": "abstain if |mean_p - 0.5| < delta OR std_p > sigma_thr",
    "delta": round(float(BEST_DELTA), 6),
    "sigma_thr": round(float(BEST_SIGMA), 6),
    "tta": {"N": 8, "augs": "H/V flip, ±10° rot, brightness/contrast [0.8,1.2]"},
    "calibration_T": round(T_cal, 6),
    "classification_threshold": round(OP_THR, 6),
    "target_specificity_from_val": round(TARGET_SPEC, 6),
    "target_coverage": TARGET_COVERAGE
}
with open(RESULTS_DIR / "abstention_rule.json", "w", encoding="utf-8") as f:
    json.dump(rule_json, f, indent=2)

# ---- Summaries (val & test) using the selected rule ----
def summarize_split(name, stats, thr, delta, sigma_thr, base_acc):
    """Coverage/accuracy summary for one split; base_acc is the no-abstention accuracy."""
    cov, acc = eval_rule(stats["y"], stats["mean_p"], stats["std_p"], thr, delta, sigma_thr)
    kept = ((np.abs(stats["mean_p"] - 0.5) >= delta) & (stats["std_p"] <= sigma_thr))
    n_all = int(stats["y"].shape[0]); n_kept = int(kept.sum()); n_abst = n_all - n_kept
    acc_gain = acc - base_acc if n_kept else 0.0
    return {
        "split": name,
        "n_total": n_all,
        "coverage": round(cov, 6),
        "n_kept": n_kept,
        "n_abstained": n_abst,
        "accuracy_kept": round(acc, 6),
        "baseline_accuracy_no_abstention": round(base_acc, 6),
        "accuracy_gain_vs_baseline": round(acc_gain, 6),
        "delta": round(float(delta), 6),
        "sigma_thr": round(float(sigma_thr), 6),
        "threshold_used": round(float(thr), 6)
    }

val_summary  = summarize_split("val",  val_stats,  OP_THR, BEST_DELTA, BEST_SIGMA, val_baseline_acc)
test_summary = summarize_split("test", test_stats, OP_THR, BEST_DELTA, BEST_SIGMA, test_baseline_acc)

with open(RESULTS_DIR / "abstention_val_summary.json", "w", encoding="utf-8") as f:
    json.dump(val_summary, f, indent=2)
with open(RESULTS_DIR / "abstention_test_summary.json", "w", encoding="utf-8") as f:
    json.dump(test_summary, f, indent=2)

print("\n=== Abstention (selected rule) ===")
print("Rule: abstain if |mean_p - 0.5| < δ OR std_p > σ_thr")
print(f"δ={BEST_DELTA:.3f}, σ_thr={BEST_SIGMA:.3f}, target coverage={TARGET_COVERAGE:.2f}")
print(f"Validation: coverage={val_summary['coverage']:.3f}, acc_kept={val_summary['accuracy_kept']:.4f} "
      f"(baseline={val_summary['baseline_accuracy_no_abstention']:.4f})")
print(f"Test:       coverage={test_summary['coverage']:.3f}, acc_kept={test_summary['accuracy_kept']:.4f} "
      f"(baseline={test_summary['baseline_accuracy_no_abstention']:.4f})")
print(f"Saved rule → {RESULTS_DIR/'abstention_rule.json'}")
print(f"Saved curve → {cov_curve_png}")
print(f"Saved summaries → {RESULTS_DIR/'abstention_val_summary.json'}, {RESULTS_DIR/'abstention_test_summary.json'}")

# Step 19 — Out-of-distribution (OOD) flag via Maximum Softmax Probability (MSP)

This cell adds an OOD detector on top of the calibrated classifier using **MSP** (for binary models: `max(p, 1−p)` from calibrated logits). It selects an MSP threshold on the **validation** set to achieve a target **in-distribution TPR** (95%), then reports OOD performance on **test** using that fixed threshold.

- **Purpose**
  - Identify inputs that are likely **out-of-distribution** relative to training/validation data.
  - Provide an **ID-vs-OOD score** (MSP ∈ [0.5, 1.0]) and a **decision threshold** `τ` tuned on validation.

- **Prerequisites & setup**
  - Uses: `splits.csv`, `mobilenetv2_finetune_tail.pt` (fine-tuned weights), `operating_point_val.json` (for temperature `T`).
  - Reproducibility: `seed_everything(19)` (Python/NumPy/Torch/CUDA) and fixed seeds for OOD generators.
  - Device: CPU or CUDA; model is put in `eval()`.

- **Model & calibration**
  - Rebuilds MobileNetV2 with a **single-logit head** and loads fine-tuned weights.
  - Applies **temperature scaling** at inference (`logits / T_cal`) to ensure probabilities are aligned with previous calibration.

- **Preprocessing (consistent with training)**
  - **Pad-to-square** using the **border-ring modal color** (typically near-black), then **bicubic** resize to **128×128**.
  - Convert to tensor and apply **ImageNet normalization**.

- **Synthetic OOD construction (for evaluation only)**
  - Creates **corrupted variants** of in-distribution images to approximate OOD:
    - Heavy Gaussian blur (radius ~6–9).
    - Brightness extremes (0.2×, 1.8×).
    - Contrast extremes (0.2×, 2.5×).
    - With 30% probability, **compose two** corruptions.
  - Processes the same validation/test file list twice: once as **ID** (clean) and once as **synthetic OOD** (corrupted).

- **Scoring: MSP (binary)**
  - From calibrated logits: `p = sigmoid(logits / T)`.
  - **MSP = max(p, 1−p)**; higher ⇒ more **ID-like**.
  - Range: **[0.5, 1.0]** (0.5 = maximally uncertain).

- **Threshold selection (validation)**
  - Target **TPR(ID) = 0.95** → set `τ` to the **5th percentile** of validation **ID** MSP (`quantile = 1 − TPR_TARGET`).
  - At `τ`, report **TNR(OOD)** on the synthetic validation OOD set.
  - Save: `ood_threshold.json` (score type, `threshold`, `tpr_target_id`, achieved `val_tpr_id`, `val_tnr_ood_at_tpr`, and `temperature_T`).
  - Plot: `ood_hist_msp_id_vs_ood.png` (MSP histograms for ID vs synthetic OOD with the vertical `τ` line).

- **Test-time evaluation (fixed τ from validation)**
  - Compute MSP for **test ID** and **test synthetic OOD**.
  - Metrics:
    - **AUROC (ID=1 vs OOD=0)** using MSP (ID should have **higher** scores).
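    - **TNR@95%TPR** on test using the **validation-selected τ**; because τ is fixed, the achieved test TPR can deviate from 95% and is reported alongside as `tpr_id_at_tau`.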
  - Save: `ood_metrics_test.json` (counts, AUROC, TPR/TNR, threshold, temperature, notes).

- **Artifacts**
  - `results/ood_threshold.json` — Selected `τ` at 95% ID TPR on validation.
  - `results/ood_hist_msp_id_vs_ood.png` — Validation MSP histograms with `τ`.
  - `results/ood_metrics_test.json` — Test AUROC and TNR@95%TPR (with `τ` fixed from validation).

- **Operational notes**
  - **Assumption**: synthetic corruptions approximate OOD; real-world OOD can differ.
  - **Decision semantics**: MSP ≥ `τ` ⇒ **flag as ID**; MSP < `τ` ⇒ **flag as OOD**.
  - **Separation of concerns**: OOD decision is **independent** from the disease classification threshold; it only gates whether the classifier should be trusted.
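
The scoring and threshold-selection steps above can be sketched with the standard library alone. This is an illustrative example under stated assumptions, not the notebook's code: the logits are made up, `T_cal = 1.0` is a placeholder temperature, and `quantile_linear` is a hypothetical helper that mirrors the linear-interpolation convention `numpy.quantile` uses by default.

```python
import math

def msp(logit, temperature=1.0):
    """Binary MSP from a single logit: max(p, 1 - p) with p = sigmoid(logit / T)."""
    p = 1.0 / (1.0 + math.exp(-logit / temperature))
    return max(p, 1.0 - p)

def quantile_linear(values, q):
    """Quantile with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

T_cal = 1.0        # placeholder; the notebook reads it from operating_point_val.json
TPR_TARGET = 0.95

# Made-up logits: ID images tend to have large |logit|, corrupted ones small |logit|.
id_logits  = [4.0, -3.5, 2.8, -5.0, 0.3, 3.9, -4.2, 2.2, -2.9, 3.1]
ood_logits = [0.2, -0.4, 1.5, -0.1, 2.5]

msp_id  = [msp(z, T_cal) for z in id_logits]
msp_ood = [msp(z, T_cal) for z in ood_logits]

# tau is the (1 - TPR_TARGET) quantile of the ID scores; MSP >= tau is treated as ID.
tau = quantile_linear(msp_id, 1.0 - TPR_TARGET)
tpr_id  = sum(m >= tau for m in msp_id) / len(msp_id)
tnr_ood = sum(m < tau for m in msp_ood) / len(msp_ood)
print(f"tau={tau:.4f}  TPR(ID)={tpr_id:.2f}  TNR(OOD)={tnr_ood:.2f}")
```

With only ten ID scores the achieved TPR lands on 0.90 rather than 0.95; the quantile matches the target closely only when the validation set is large. Note also that a confidently classified OOD input (large |logit|) passes the MSP test, which is the main limitation of this score.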

In [None]:
# Step 19 — OOD flag via MSP
import os, json, math, time, random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models, transforms as T
from torchvision.transforms import functional as TF
from PIL import Image, ImageOps, ImageFilter
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

# ---------- Paths & constants (keep consistent with earlier steps) ----------
PROJECT_ROOT = Path(r"[path_placeholder]")
DATA_DIR     = PROJECT_ROOT / "cell_images"
RESULTS_DIR  = PROJECT_ROOT / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

SPLITS_CSV   = RESULTS_DIR / "splits.csv"
CKPT_FT      = RESULTS_DIR / "mobilenetv2_finetune_tail.pt"
OP_JSON      = RESULTS_DIR / "operating_point_val.json"

assert SPLITS_CSV.exists(), f"Missing {SPLITS_CSV}"
assert CKPT_FT.exists(),    f"Missing {CKPT_FT}"
assert OP_JSON.exists(),    f"Missing {OP_JSON} (run Step 16)"

# ---------- Reproducibility ----------
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception:
        pass

seed_everything(19)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- Load T and model ----------
with open(OP_JSON, "r", encoding="utf-8") as f:
    op_payload = json.load(f)
T_cal = float(op_payload.get("temperature_T", 1.0))

# Build MobileNetV2 (single-logit head) and load fine-tuned weights
def build_mobilenetv2_single_logit(pretrained=False):
    try:
        weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        net = models.mobilenet_v2(weights=weights)
    except Exception:
        net = models.mobilenet_v2(pretrained=pretrained)
    in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(in_features, 1)
    return net

model = build_mobilenetv2_single_logit(pretrained=False)
try:
    state = torch.load(CKPT_FT, map_location="cpu", weights_only=True)  # weights_only is available since PyTorch 1.13
except TypeError:
    state = torch.load(CKPT_FT, map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

# ---------- Canonical preprocessing (same as training/inference) ----------
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def preprocess_to_tensor(path: Path):
    with Image.open(path).convert("RGB") as img:
        img = pad_resize_canonical(img, target_size=128, ring=2)
    x = to_tensor(img)
    x = normalize(x)
    return x

# ---------- Synthetic OOD generator ----------
def make_synthetic_ood(img: Image.Image, rng: np.random.RandomState):
    """Apply one (or, with 30% chance, two) strong corruptions to a 128×128 PIL image.

    The second corruption is drawn independently, so it may repeat the first.
    """
    # heavy blur
    def op_blur(i):
        r = float(rng.uniform(6.0, 9.0))
        return i.filter(ImageFilter.GaussianBlur(radius=r))
    # brightness extremes
    def op_bright_low(i):  return TF.adjust_brightness(i, 0.2)
    def op_bright_high(i): return TF.adjust_brightness(i, 1.8)
    # contrast extremes
    def op_contrast_low(i):  return TF.adjust_contrast(i, 0.2)
    def op_contrast_high(i): return TF.adjust_contrast(i, 2.5)

    choices = [op_blur, op_bright_low, op_bright_high, op_contrast_low, op_contrast_high]
    i1 = rng.choice(len(choices))
    img = choices[i1](img)
    # With 30% chance, compose a second corruption
    if rng.rand() < 0.30:
        i2 = rng.choice(len(choices))
        img = choices[i2](img)
    return img

def preprocess_to_tensor_ood(path: Path, rng: np.random.RandomState):
    with Image.open(path).convert("RGB") as img:
        img = pad_resize_canonical(img, target_size=128, ring=2)
        img = make_synthetic_ood(img, rng)
    x = to_tensor(img)
    x = normalize(x)
    return x

# ---------- Data (val/test) ----------
splits = pd.read_csv(SPLITS_CSV)
val_df  = splits[splits["split"]=="val"].reset_index(drop=True)
test_df = splits[splits["split"]=="test"].reset_index(drop=True)
print(f"val={len(val_df)}  test={len(test_df)}")

# ---------- Batched inference helpers ----------
@torch.no_grad()
def logits_for_paths(paths, batch=128):
    xs = []
    out = []
    for i, p in enumerate(paths, 1):
        xs.append(preprocess_to_tensor(Path(p)))
        if len(xs) == batch or i == len(paths):
            xb = torch.stack(xs, 0).to(device, non_blocking=True)
            logits = model(xb).view(-1)
            out.append(logits.detach().cpu())
            xs = []
    return torch.cat(out).float()

@torch.no_grad()
def logits_for_paths_ood(paths, seed=123, batch=128):
    rng = np.random.RandomState(seed)
    xs = []
    out = []
    for i, p in enumerate(paths, 1):
        xs.append(preprocess_to_tensor_ood(Path(p), rng))
        if len(xs) == batch or i == len(paths):
            xb = torch.stack(xs, 0).to(device, non_blocking=True)
            logits = model(xb).view(-1)
            out.append(logits.detach().cpu())
            xs = []
    return torch.cat(out).float()

def msp_from_logits(logits: torch.Tensor, T: float):
    """Return MSP for binary model given logits and temperature T."""
    logits_T = logits / float(T)
    p = torch.sigmoid(logits_T).numpy()
    msp = np.maximum(p, 1.0 - p)
    return msp

# ---------- Compute MSP on val (ID) and val-OOD, choose τ for TPR=0.95 ----------
print("Collecting logits… (validation)")
val_logits_id  = logits_for_paths(val_df["path"].tolist(), batch=256)
val_logits_ood = logits_for_paths_ood(val_df["path"].tolist(), seed=777, batch=256)

msp_val_id  = msp_from_logits(val_logits_id,  T_cal)
msp_val_ood = msp_from_logits(val_logits_ood, T_cal)

TPR_TARGET = 0.95
# Threshold τ so that 95% of ID have MSP ≥ τ  → τ is the 5th percentile of ID MSP
tau = float(np.quantile(msp_val_id, 1.0 - TPR_TARGET))

# Validation TNR at this τ
val_pred_id  = (msp_val_id  >= tau)
val_pred_ood = (msp_val_ood >= tau)  # predicted "ID" if True
val_tpr = float(np.mean(val_pred_id))                 # by construction ≈ 0.95
val_tnr = float(np.mean(~val_pred_ood))               # true negatives among OOD

# Save τ
thr_json = {
    "score": "MSP",
    "threshold": round(tau, 6),
    "tpr_target_id": TPR_TARGET,
    "val_tpr_id": round(val_tpr, 6),
    "val_tnr_ood_at_tpr": round(val_tnr, 6),
    "temperature_T": round(float(T_cal), 6)
}
with open(RESULTS_DIR / "ood_threshold.json", "w", encoding="utf-8") as f:
    json.dump(thr_json, f, indent=2)

# ---------- Histogram (val) ----------
plt.figure(figsize=(7,4.2))
bins = np.linspace(0.5, 1.0, 40)  # MSP ∈ [0.5, 1]
plt.hist(msp_val_id,  bins=bins, alpha=0.65, label="ID (val)")
plt.hist(msp_val_ood, bins=bins, alpha=0.65, label="synthetic OOD (val)")
plt.axvline(tau, linestyle="--", linewidth=1.5, label=f"τ @ 95% TPR (ID) = {tau:.3f}")
plt.xlabel("MSP"); plt.ylabel("count")
plt.title("MSP histogram — ID vs synthetic OOD (validation)")
plt.legend()
plt.tight_layout()
hist_png = RESULTS_DIR / "ood_hist_msp_id_vs_ood.png"
plt.savefig(hist_png, dpi=150); plt.close()
print(f"Saved histogram → {hist_png}")

# ---------- Test metrics using val-selected τ ----------
print("Collecting logits… (test and test-OOD)")
test_logits_id  = logits_for_paths(test_df["path"].tolist(), batch=256)
test_logits_ood = logits_for_paths_ood(test_df["path"].tolist(), seed=888, batch=256)

msp_test_id  = msp_from_logits(test_logits_id,  T_cal)
msp_test_ood = msp_from_logits(test_logits_ood, T_cal)

# AUROC (ID=1, OOD=0) with MSP (higher means more ID-like)
y_test = np.concatenate([np.ones_like(msp_test_id), np.zeros_like(msp_test_ood)])
s_test = np.concatenate([msp_test_id, msp_test_ood])
auroc = float(roc_auc_score(y_test, s_test))

# TNR@95%TPR on TEST using τ from validation
tpr_test = float(np.mean(msp_test_id >= tau))
tnr_test = float(np.mean(msp_test_ood <  tau))

metrics = {
    "temperature_T": round(float(T_cal), 6),
    "threshold": round(float(tau), 6),
    "n_test_id": int(len(msp_test_id)),
    "n_test_ood": int(len(msp_test_ood)),
    "auroc_id_vs_ood": round(auroc, 6),
    "tpr_id_at_tau": round(tpr_test, 6),
    "tnr_ood_at_95pct_tpr": round(tnr_test, 6),
    "notes": "ID positive class; score = MSP; τ chosen on validation as 5th percentile of ID MSP."
}
with open(RESULTS_DIR / "ood_metrics_test.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# ---------- Console summary ----------
print("\n=== OOD via MSP ===")
print(f"T (temperature): {T_cal:.4f}")
print(f"τ (MSP) @ 95% TPR (selected on val): {tau:.4f}")
print(f"Validation:  TPR(ID)={val_tpr:.3f}  TNR(OOD)={val_tnr:.3f}")
print(f"Test AUROC (ID vs OOD): {auroc:.4f}")
print(f"Test:        TPR(ID)={tpr_test:.3f}  TNR(OOD)={tnr_test:.3f}")
print(f"Saved → {RESULTS_DIR/'ood_threshold.json'}, {hist_png}, {RESULTS_DIR/'ood_metrics_test.json'}")

# Step 20 — Photometric robustness sweeps (test set)

This cell measures how sensitive the **calibrated** classifier is to controlled **brightness** and **contrast** shifts on the **test set**, keeping the decision threshold fixed from validation. It sweeps multiplicative factors around nominal (1.0), reports **AUC** and **Accuracy@threshold**, and saves summary artifacts.

- **Inputs & prerequisites**
  - Uses paths and artifacts from prior steps: `splits.csv` (test manifest), `mobilenetv2_finetune_tail.pt` (fine-tuned weights), and `operating_point_val.json` (calibration temperature `T` and classification threshold `OP_THR`).
  - Runs on CPU or CUDA (`eval()` mode).

- **Calibration & decision rule**
  - Loads **temperature** `T_cal` and the **operating threshold** `OP_THR` selected on validation.
  - At inference, applies **temperature scaling** to logits (`logits / T_cal`) before `sigmoid`.
  - **Accuracy** is computed by thresholding calibrated probabilities with **the same `OP_THR` across all perturbations** to ensure a fair comparison.

- **Model & preprocessing (aligned with training)**
  - Rebuilds MobileNetV2 with a **single-logit** classifier; loads fine-tuned weights.
  - Preprocesses each image by **pad-to-square** using the image’s border-ring modal color (typically near-black), **bicubic** resize to **128×128**, then **ImageNet normalization**.

- **Photometric perturbations**
  - Two one-parameter families applied **after** pad/resize and **before** tensor conversion:
    - **Brightness** factors: `[0.8, 0.9, 1.0, 1.1, 1.2]` via `TF.adjust_brightness`.
    - **Contrast** factors: `[0.8, 0.9, 1.0, 1.1, 1.2]` via `TF.adjust_contrast`.
  - `1.0` is nominal; `<1.0` darkens/lowers contrast, `>1.0` brightens/raises contrast.

- **Evaluation protocol**
  - For each factor in each family:
    - Run **batched** inference on all test images with temperature scaling.
    - Compute:
      - **AUC** (threshold-free discrimination).
      - **Accuracy@OP_THR** (performance at the fixed operating point).
  - Aggregates all results into a single table and prints per-factor summaries.

- **Saved artifacts**
  - `results/robustness_metrics.csv` — Per-factor metrics with the used `OP_THR` and `T_cal`.
  - `results/robustness_brightness_curve.png` — AUC and Accuracy vs. **brightness** factor (vertical line at 1.0).
  - `results/robustness_contrast_curve.png` — AUC and Accuracy vs. **contrast** factor (vertical line at 1.0).

- **Interpretation notes**
  - **AUC** reflects ranking robustness; **Accuracy@threshold** shows stability at the chosen clinical/operational point.
  - Drops at factors far from 1.0 indicate **sensitivity** to illumination or contrast shifts; symmetric/asymmetric patterns may reveal bias to over/under-exposure.
  - Because the **threshold is fixed from validation**, changes isolate **input shift effects** rather than retuning.
  - Consider pairing with **data augmentation**, histogram equalization, or calibration refresh if large degradations are observed.

- **Extensibility**
  - Additional sweeps (e.g., gamma, saturation, hue, JPEG quality) can be added by following the same pattern and appending rows to the metrics table.
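
The distinction drawn in the interpretation notes, AUC as a ranking measure versus accuracy at a fixed threshold, can be illustrated with a small standard-library sketch. The probabilities below are made up, and scaling every score by 0.6 is only a stand-in for a perturbation that lowers scores without changing their order; `auc_pairwise` and `accuracy_at` are hypothetical helper names.

```python
def auc_pairwise(y, p):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for t, s in zip(y, p) if t == 1]
    neg = [s for t, s in zip(y, p) if t == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))

def accuracy_at(y, p, thr):
    """Accuracy when predicting the positive class for p >= thr."""
    return sum(int(s >= thr) == t for t, s in zip(y, p)) / len(y)

OP_THR = 0.5  # placeholder; the notebook uses the validation-derived threshold

y         = [1, 1, 1, 0, 0, 0]
p_nominal = [0.9, 0.8, 0.6, 0.4, 0.2, 0.1]
# Stand-in for a photometric shift that lowers all scores but preserves their order.
p_shifted = [0.6 * s for s in p_nominal]

rows = []
for name, probs in [("nominal", p_nominal), ("shifted", p_shifted)]:
    rows.append({"condition": name,
                 "auc": auc_pairwise(y, probs),
                 "acc_at_thr": accuracy_at(y, probs, OP_THR)})
for r in rows:
    print(r)
```

Here the ranking is untouched, so AUC stays the same, while accuracy at the fixed threshold drops because positives slide below it. A pattern like this in the sweep results would point toward a calibration or threshold refresh rather than a loss of discriminative signal.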

In [None]:
# Step 20 — Photometric robustness sweeps (test set)
import json, math
from pathlib import Path
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from torchvision import models, transforms as T
from torchvision.transforms import functional as TF

# ---- Roots & prerequisites (reuse your paths) ----
PROJECT_ROOT = Path(r"[path_placeholder]")
DATA_DIR     = PROJECT_ROOT / "cell_images"
RESULTS_DIR  = PROJECT_ROOT / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

SPLITS_CSV   = RESULTS_DIR / "splits.csv"
CKPT_FT      = RESULTS_DIR / "mobilenetv2_finetune_tail.pt"
OP_JSON      = RESULTS_DIR / "operating_point_val.json"

assert SPLITS_CSV.exists(), "Missing splits.csv — run earlier steps."
assert CKPT_FT.exists(),    "Missing fine-tuned weights — run Step 15."
assert OP_JSON.exists(),    "Missing operating_point_val.json — run Step 16."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Load calibration: temperature T and decision threshold (from Step 16) ----
with open(OP_JSON, "r", encoding="utf-8") as f:
    op_payload = json.load(f)
T_cal   = float(op_payload["temperature_T"])
op_dict = op_payload["operating_point"]
OP_THR  = float(op_dict["threshold"])
print(f"Loaded calibration → T={T_cal:.4f}, threshold={OP_THR:.4f}")

# ---- Build model & load best weights ----
def build_mobilenetv2_single_logit(pretrained=False):
    try:
        weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        net = models.mobilenet_v2(weights=weights)
    except Exception:
        net = models.mobilenet_v2(pretrained=pretrained)
    in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(in_features, 1)
    return net

model = build_mobilenetv2_single_logit(pretrained=False)
try:
    state = torch.load(CKPT_FT, map_location="cpu", weights_only=True)
except TypeError:
    state = torch.load(CKPT_FT, map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

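# ---- Canonical preprocessing (pad-to-square using border mode + bicubic 128²) ----
# NOTE: pad_resize_canonical is not defined in this cell; it is assumed to be in scope from an earlier step.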
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def preprocess_with_adjust(path: Path, mode: str, factor: float):
    with Image.open(path).convert("RGB") as img:
        img = pad_resize_canonical(img, target_size=128, ring=2)
        if mode == "brightness":
            img = TF.adjust_brightness(img, float(factor))
        elif mode == "contrast":
            img = TF.adjust_contrast(img, float(factor))
        else:
            raise ValueError("mode must be 'brightness' or 'contrast'")
    x = to_tensor(img)
    x = normalize(x)
    return x

# ---- Batched inference helper (applies T for calibrated probs) ----
@torch.no_grad()
def probs_for_paths(paths, mode: str, factor: float, batch=256):
    xs, out = [], []
    invT = 1.0 / float(T_cal)  # divide logits by T
    for i, p in enumerate(paths, 1):
        xs.append(preprocess_with_adjust(Path(p), mode=mode, factor=factor))
        if len(xs) == batch or i == len(paths):
            xb = torch.stack(xs, 0).to(device, non_blocking=True)
            logits = model(xb).view(-1) * invT
            out.append(torch.sigmoid(logits).cpu().numpy())
            xs = []
    return np.concatenate(out)

# ---- Load test paths & labels ----
splits = pd.read_csv(SPLITS_CSV)
test_df = splits[splits["split"]=="test"].reset_index(drop=True)
label_map = {"Parasitized": 1, "Uninfected": 0}
y_test = test_df["class"].map(label_map).to_numpy().astype(np.int64)
paths   = test_df["path"].tolist()
print(f"Test set size: {len(paths)}")

# ---- Metric helpers ----
from sklearn.metrics import roc_auc_score, accuracy_score

def metrics_at_factor(mode: str, factor: float):
    p = probs_for_paths(paths, mode=mode, factor=factor, batch=256)
    auc = float(roc_auc_score(y_test, p))
    acc = float(accuracy_score(y_test, (p >= OP_THR).astype(np.int64)))
    return auc, acc

# ---- Sweeps ----
BR_FACTORS = [0.8, 0.9, 1.0, 1.1, 1.2]
CT_FACTORS = [0.8, 0.9, 1.0, 1.1, 1.2]

rows = []

print("\nBrightness sweep:")
for f in BR_FACTORS:
    auc, acc = metrics_at_factor("brightness", f)
    rows.append({"type":"brightness", "factor":f, "auc":auc, "acc_at_threshold":acc, "n":int(len(y_test)),
                 "threshold_used": float(OP_THR), "temperature_T": float(T_cal)})
    print(f"  factor={f:.2f}  AUC={auc:.4f}  Acc@thr={acc:.4f}")

print("\nContrast sweep:")
for f in CT_FACTORS:
    auc, acc = metrics_at_factor("contrast", f)
    rows.append({"type":"contrast", "factor":f, "auc":auc, "acc_at_threshold":acc, "n":int(len(y_test)),
                 "threshold_used": float(OP_THR), "temperature_T": float(T_cal)})
    print(f"  factor={f:.2f}  AUC={auc:.4f}  Acc@thr={acc:.4f}")

robust_df = pd.DataFrame(rows)
csv_out = RESULTS_DIR / "robustness_metrics.csv"
robust_df.to_csv(csv_out, index=False)
print(f"\nSaved metrics → {csv_out}")

# ---- Curves (single-axes plots with both lines, vertical guide at 1.0) ----
def plot_curve(df, kind: str, out_path: Path, title: str):
    sub = df[df["type"] == kind].sort_values("factor")
    x = sub["factor"].to_numpy()
    y_auc = sub["auc"].to_numpy()
    y_acc = sub["acc_at_threshold"].to_numpy()
    plt.figure(figsize=(6,4))
    plt.plot(x, y_auc, marker="o", label="AUC")
    plt.plot(x, y_acc, marker="s", label=f"Accuracy@thr={OP_THR:.3f}")
    plt.axvline(1.0, linestyle="--", linewidth=1, label="nominal (1.0)")
    plt.xlabel(f"{kind} factor"); plt.ylabel("score")
    plt.ylim(0.0, 1.0)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150); plt.close()

plot_curve(robust_df, "brightness",
           RESULTS_DIR / "robustness_brightness_curve.png",
           "Photometric robustness — brightness (test)")
plot_curve(robust_df, "contrast",
           RESULTS_DIR / "robustness_contrast_curve.png",
           "Photometric robustness — contrast (test)")

print("Saved →",
      RESULTS_DIR / "robustness_brightness_curve.png",
      "and",
      RESULTS_DIR / "robustness_contrast_curve.png")

# Step 21 — Interpretability (Grad-CAM)

This cell generates **Grad-CAM** explanations for the fine-tuned MobileNetV2 on the **test split**, targeting the last convolutional block (`features[18]`). It saves (i) a **4×4 panel** with up to four examples each for **TP / FP / TN / FN**, and (ii) **per-case overlays**.

---

## Purpose
- Provide **visual evidence** of what the model attends to when predicting *Parasitized* vs *Uninfected*.
- Inspect **correct** (TP/TN) vs **error** (FP/FN) cases under the **same calibrated operating point** used in evaluation.

---

## Inputs & prerequisites
- Artifacts from earlier steps:
  - `results/splits.csv` (test manifest with file paths and labels).
  - `results/mobilenetv2_finetune_tail.pt` (fine-tuned weights).
  - `results/operating_point_val.json` (temperature `T_CAL` and decision threshold `OP_THR`).
- Rebuilds MobileNetV2 with a **single-logit** head and loads the fine-tuned weights.
- Runs on CPU or CUDA (`eval()` mode).

---

## Preprocessing (consistent with training/inference)
- **Pad-to-square** using the **border-ring modal color** (typically near-black), then **bicubic** resize to **128×128**.
- Convert to tensor and apply **ImageNet normalization**.
- Two products per image:
  - **PIL 128×128** for visualization overlays.
  - **Normalized tensor** for the forward/backward pass.
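
`pad_resize_canonical` is called by the code cell for this step but is not defined in it, so it presumably comes from an earlier step. As a rough sketch of the padding idea only, the NumPy code below finds the most frequent color in a border ring of width `ring` and uses it to pad the shorter side to a square. The function names `border_modal_color` and `pad_to_square` are hypothetical, the centring choice is an assumption, and the bicubic resize to 128×128 is omitted.

```python
import numpy as np

def border_modal_color(arr, ring=2):
    """Most frequent RGB color among pixels within `ring` pixels of the image edge."""
    h, w, _ = arr.shape
    mask = np.zeros((h, w), dtype=bool)
    mask[:ring, :] = mask[-ring:, :] = True
    mask[:, :ring] = mask[:, -ring:] = True
    colors, counts = np.unique(arr[mask], axis=0, return_counts=True)
    return colors[counts.argmax()]

def pad_to_square(arr, ring=2):
    """Center the image on a square canvas filled with the border-ring modal color."""
    h, w, c = arr.shape
    side = max(h, w)
    canvas = np.empty((side, side, c), dtype=arr.dtype)
    canvas[...] = border_modal_color(arr, ring)
    top, left = (side - h) // 2, (side - w) // 2
    canvas[top:top + h, left:left + w] = arr
    return canvas

# Toy "cell" image: black background with a brighter blob, wider than tall.
img = np.zeros((60, 100, 3), dtype=np.uint8)
img[20:40, 30:70] = (180, 120, 160)

sq = pad_to_square(img, ring=2)
print(sq.shape, border_modal_color(img))
```

Padding with the border mode rather than a fixed color keeps the added margin visually continuous with the existing background, so the model does not see an artificial frame.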

---

## Grad-CAM details
- **Target layer:** `model.features[-1]` (last conv block before pooling/classifier).
- **Hooks:** forward hook stores activations; backward hook stores gradients.
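- **Signal for backprop:** the **temperature-scaled logit** (`logit / T_CAL`), the same quantity that feeds the probability at the operating point. For a positive `T_CAL` this only rescales the gradients by a constant, so after min-max normalization the CAM is effectively the same as for the raw logit.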
- **Weights:** global average of gradients over spatial dims → channel weights.
- **CAM construction:** ReLU of the weighted sum of feature maps; upsample to **128×128**; min-max normalize to **[0,1]**.
- **Overlay:** Blend the heatmap (e.g., `"jet"`) with the processed PIL image (α = 0.35).
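
The CAM arithmetic described above can be sketched in NumPy on synthetic activations and gradients. The function name `cam_from_activations`, the toy shapes, and the stand-in temperature `1.7` are assumptions for illustration; the hook mechanics and the bilinear upsampling are omitted. The final lines compare the map obtained when the gradients are scaled by a positive constant, as division by a temperature would do.

```python
import numpy as np

def cam_from_activations(A, dA, eps=1e-8):
    """A, dA: [C, h, w] activations and gradients of the target logit w.r.t. A."""
    weights = dA.mean(axis=(1, 2), keepdims=True)      # [C,1,1] channel weights
    cam = np.maximum((weights * A).sum(axis=0), 0.0)   # ReLU of weighted sum -> [h,w]
    cam = cam - cam.min()
    return cam / (cam.max() + eps)                     # min-max normalize to [0,1]

rng = np.random.default_rng(0)
A = rng.random((16, 4, 4))          # synthetic feature maps (C=16 on a 4x4 grid)
dA = rng.normal(size=(16, 4, 4))    # synthetic gradients

cam = cam_from_activations(A, dA)
cam_scaled = cam_from_activations(A, dA / 1.7)   # 1.7 stands in for a temperature
print(cam.shape, float(cam.min()), float(cam.max()))
print("max abs difference:", float(np.abs(cam - cam_scaled).max()))
```

Any difference between the two maps comes only from the small `eps` term in the denominator.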

---

## Case selection & panel assembly
- Computes **calibrated probability** `p = sigmoid(logit / T_CAL)` and **predicted class** via `p ≥ OP_THR`.
- Partitions test indices into **TP / FP / TN / FN** and samples up to **4** per category (deterministic seeds).
- For each sampled case:
  - Run Grad-CAM, create **overlay**, annotate with **true label**, **predicted label**, and **p**.
  - Save an individual PNG under `results/gradcam_cases/`.
- Assemble a **4×4 grid** (rows = TP/FP/TN/FN; up to 4 columns each). Empty categories are left blank.
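
The partition-and-sample logic can be illustrated on synthetic labels and probabilities. Everything here is made up for the example (the data, the `0.5` threshold, the helper names `confusion_buckets` and `sample_up_to`); only the rule `pred = p ≥ threshold` and the per-category seeded sampling mirror the description above.

```python
import numpy as np

def confusion_buckets(y_true, p, thr):
    """Return index arrays for TP / FP / TN / FN under the rule pred = (p >= thr)."""
    pred = (p >= thr).astype(int)
    return {
        "TP": np.flatnonzero((pred == 1) & (y_true == 1)),
        "FP": np.flatnonzero((pred == 1) & (y_true == 0)),
        "TN": np.flatnonzero((pred == 0) & (y_true == 0)),
        "FN": np.flatnonzero((pred == 0) & (y_true == 1)),
    }

def sample_up_to(idxs, k=4, seed=0):
    """Keep everything if there are at most k items, else draw k without replacement."""
    if len(idxs) <= k:
        return [int(i) for i in idxs]
    chosen = np.random.RandomState(seed).choice(idxs, size=k, replace=False)
    return sorted(int(i) for i in chosen)

y = np.array([1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1])
p = np.array([0.9, 0.8, 0.2, 0.1, 0.7, 0.3, 0.95, 0.05, 0.6, 0.85, 0.99, 0.4])

buckets = confusion_buckets(y, p, thr=0.5)
for seed, (name, idxs) in enumerate(buckets.items(), start=1):
    print(name, len(idxs), sample_up_to(idxs, k=4, seed=seed))
```

Fixing one seed per category makes the panel reproducible across runs, provided the test split and the predictions stay the same.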

---

## Outputs
- `results/gradcam_panel.png` — 4×4 panel of overlays.
- `results/gradcam_cases/*.png` — per-image overlays named with category, stem, and probability.

---

## How to read the overlays
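- **Warmer regions** mark areas that push the single output logit toward **Parasitized**, whatever the predicted label; the model has no separate Uninfected logit, so maps for TN/FN cases show residual evidence for infection rather than evidence for *Uninfected*.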
- **TP/TN:** Expect focus on **cell interiors/morphology** rather than borders or padding.
- **FP/FN:** Highlights on irrelevant background, staining artifacts, or borders may indicate **spurious cues**.
- Resolution (128²) and the chosen layer provide **coarse**, class-specific attributions; Grad-CAM is **qualitative**, not a proof of causality.

---

## Notes & limitations
- If a category has **< 4** instances, remaining slots are blank; if a category is **empty**, its row is blank.
- CAMs depend on the **chosen layer**; earlier layers can yield finer spatial detail at the cost of specificity.
- Color maps and α are fixed for consistency; modifying them changes visual salience but not the underlying CAM.

In [None]:
# Step 21 — Interpretability (Grad-CAM)
# Scope: last conv block of MobileNetV2 (features[18]).
# Outputs:
#   results/gradcam_panel.png    (4x4 grid: TP/FP/TN/FN, up to 4 each)
#   results/gradcam_cases/*.png  (per-case overlays)

import os, json, math, numpy as np, pandas as pd
from pathlib import Path
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import torch, torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms as T
from torchvision.transforms import functional as TF

# ---- Roots & artifacts from prior steps ----
PROJECT_ROOT = Path(r"[path_placeholder]")
DATA_DIR     = PROJECT_ROOT / "cell_images"
RESULTS_DIR  = PROJECT_ROOT / "results"
SPLITS_CSV   = RESULTS_DIR / "splits.csv"
CKPT_FT      = RESULTS_DIR / "mobilenetv2_finetune_tail.pt"
OP_JSON      = RESULTS_DIR / "operating_point_val.json"

assert SPLITS_CSV.exists() and CKPT_FT.exists() and OP_JSON.exists(), "Run Steps 15–17 first."
(RESULTS_DIR / "gradcam_cases").mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Load temperature + operating threshold ----
with open(OP_JSON, "r", encoding="utf-8") as f:
    payload = json.load(f)
T_CAL   = float(payload["temperature_T"])
OP_THR  = float(payload["operating_point"]["threshold"])

# ---- Build model, load weights ----
def build_mobilenetv2_single_logit(pretrained=False):
    try:
        weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        net = models.mobilenet_v2(weights=weights)
    except Exception:
        net = models.mobilenet_v2(pretrained=pretrained)
    in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(in_features, 1)
    return net

model = build_mobilenetv2_single_logit(pretrained=False)
try:
    state = torch.load(CKPT_FT, map_location="cpu", weights_only=True)  # weights_only was added in PyTorch 1.13; older versions raise TypeError
except TypeError:
    state = torch.load(CKPT_FT, map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

# ---- Preprocess (same as training/inference) ----
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def preprocess_both(path: Path):
    """Return (PIL_128x128 for overlay, tensor_normalised for model)."""
    with Image.open(path).convert("RGB") as img:
        img = pad_resize_canonical(img, target_size=128, ring=2)  # 128×128 PIL
    x = to_tensor(img)            # [0,1]
    x = normalize(x)              # ImageNet normalisation
    return img, x

# ---- Grad-CAM helper focused on features[18] ----
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activ = None
        self.grads = None
        self.fh = target_layer.register_forward_hook(self._fwd_hook)
        # full backward hook for modern PyTorch; fallback to deprecated if unavailable
        try:
            self.bh = target_layer.register_full_backward_hook(self._bwd_hook)
        except Exception:
            self.bh = target_layer.register_backward_hook(self._bwd_hook)

    def _fwd_hook(self, module, inp, out):
        self.activ = out.detach()

    def _bwd_hook(self, module, grad_in, grad_out):
        self.grads = grad_out[0].detach()

    def __call__(self, x_tensor, use_temperature=True):
        """
        x_tensor: [1,3,128,128] normalised
        returns: (cam[128,128] np.float32 in [0,1], prob, logit_cal, pred_label[0/1])
        """
        self.model.zero_grad(set_to_none=True)
        logits = self.model(x_tensor).view(-1)  # [1]
        # Apply temperature on logits for the *prediction/probability* path
        if use_temperature:
            logits_for_pred = logits / float(T_CAL)
        else:
            logits_for_pred = logits
        prob = torch.sigmoid(logits_for_pred)[0].item()

        # For Grad-CAM, backprop from the calibrated logit (keeps consistency with OP).
        # Each call runs a fresh forward pass, so the graph need not be retained.
        logits_for_pred.backward()

        A = self.activ           # [1,C,h,w]
        dA = self.grads          # [1,C,h,w]
        weights = dA.mean(dim=(2,3), keepdim=True)   # [1,C,1,1]
        cam = torch.relu((weights * A).sum(dim=1, keepdim=True))  # [1,1,h,w]
        cam = F.interpolate(cam, size=(128,128), mode="bilinear", align_corners=False)
        cam = cam[0,0]
        # normalise to [0,1]
        cam -= cam.min()
        denom = (cam.max() + 1e-8)
        cam = (cam / denom).clamp(0,1).cpu().numpy().astype(np.float32)
        pred = int(prob >= OP_THR)
        return cam, prob, logits_for_pred.item(), pred

    def close(self):
        self.fh.remove()
        self.bh.remove()

# Target: last conv block before GAP
target_layer = model.features[-1]   # features[18]
gcam = GradCAM(model, target_layer)

# ---- Collect test set predictions to pick TP/FP/TN/FN ----
splits = pd.read_csv(SPLITS_CSV)
test_df = splits[splits["split"]=="test"].reset_index(drop=True)
label_map = {"Parasitized": 1, "Uninfected": 0}

records = []
with torch.no_grad():
    for i, row in test_df.iterrows():
        path = Path(row["path"])
        y = label_map[row["class"]]
        _, x = preprocess_both(path)
        x = x.unsqueeze(0).to(device)
        logit = model(x).view(-1)
        p = torch.sigmoid(logit / float(T_CAL)).item()
        pred = int(p >= OP_THR)
        records.append({"idx": i, "path": str(path), "y": y, "p": p, "pred": pred})

pred_df = pd.DataFrame(records)

# Index by category
TP_idx = pred_df[(pred_df["pred"]==1) & (pred_df["y"]==1)].index.tolist()
FP_idx = pred_df[(pred_df["pred"]==1) & (pred_df["y"]==0)].index.tolist()
TN_idx = pred_df[(pred_df["pred"]==0) & (pred_df["y"]==0)].index.tolist()
FN_idx = pred_df[(pred_df["pred"]==0) & (pred_df["y"]==1)].index.tolist()

def sample_idxs(idxs, k=4, seed=123):
    if len(idxs) <= k:
        return idxs
    rng = np.random.RandomState(seed)
    return list(rng.choice(idxs, size=k, replace=False))

TP_s = sample_idxs(TP_idx, 4, 1)
FP_s = sample_idxs(FP_idx, 4, 2)
TN_s = sample_idxs(TN_idx, 4, 3)
FN_s = sample_idxs(FN_idx, 4, 4)

panel_order = [("TP", TP_s), ("FP", FP_s), ("TN", TN_s), ("FN", FN_s)]

# ---- Utilities to overlay CAM on the processed 128×128 image ----
def overlay_cam_on_image(pil_img_128, cam_128, alpha=0.35, cmap_name="jet"):
    """Return a PIL RGB image with heatmap overlay."""
    base = np.asarray(pil_img_128).astype(np.float32) / 255.0  # [H,W,3]
    cmap = plt.get_cmap(cmap_name)   # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    heat = cmap(cam_128)[:, :, :3]   # RGBA->RGB
    overlay = (1 - alpha) * base + alpha * heat
    overlay = np.clip(overlay, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))

# ---- Generate per-case overlays & assemble panel ----
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
plt.subplots_adjust(wspace=0.02, hspace=0.15)

case_dir = RESULTS_DIR / "gradcam_cases"
case_dir.mkdir(exist_ok=True)

for r, (label, idxs) in enumerate(panel_order):
    if len(idxs) == 0:
        # If a category is missing, fill with blanks
        for c in range(4):
            axes[r, c].axis("off")
        continue
    for c, df_idx in enumerate(idxs[:4]):
        row = pred_df.loc[df_idx]
        path = Path(row["path"])
        true_y = int(row["y"])
        pil_img, x = preprocess_both(path)
        x = x.unsqueeze(0).to(device).requires_grad_(True)

        cam, prob, logit_cal, pred = gcam(x, use_temperature=True)
        over = overlay_cam_on_image(pil_img, cam, alpha=0.35, cmap_name="jet")

        # Save per-case
        base_name = f"{label}_{path.stem}_p{prob:.3f}.png"
        over.save(case_dir / base_name)

        # Title with class/pred/prob
        gt_txt   = "Parasitized" if true_y==1 else "Uninfected"
        pred_txt = "Parasitized" if pred==1 else "Uninfected"
        axes[r, c].imshow(over)
        axes[r, c].set_title(f"{label} | gt={gt_txt} | pred={pred_txt}\np={prob:.3f}", fontsize=9)
        axes[r, c].axis("off")

# Hide any cells that received no image (rows with fewer than 4 sampled cases)
for r in range(4):
    for c in range(4):
        if not axes[r, c].images:
            axes[r, c].axis("off")

# Add left y-labels for rows
for r, (label, _) in enumerate(panel_order):
    axes[r, 0].text(-0.10, 0.5, label, transform=axes[r, 0].transAxes,
                    fontsize=12, fontweight="bold", va="center", ha="right", rotation=90)

out_png = RESULTS_DIR / "gradcam_panel.png"
plt.tight_layout()
plt.savefig(out_png, dpi=150, bbox_inches="tight")
plt.close()
gcam.close()

print(f"Saved panel → {out_png}")
print(f"Saved per-case overlays → {case_dir}")