# In-Depth Analysis

In [86]:
%cd ~/Documents/DISSERTATION

/Users/t/Documents/DISSERTATION


## Rigorous paired significance testing and confidence intervals

In [30]:
# Rigorous paired significance testing and confidence intervals

import os
import sys
import json
import inspect
import numpy as np
import torch
import yaml
from collections import OrderedDict
from scipy import stats


PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Project modules
from datasets import create_data_loaders
from utils.metrics import dice_coefficient, iou_score
from models.unet import UNet
from models.attention_unet import AttentionUNet


def build_test_loader(cfg_path: str):
    """Read config; enforce a deterministic test_loader order and disable multiprocessing (num_workers=0)."""
    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)
    cfg = cfg or {}

    # — Disable multiprocessing to avoid failures when spawning child processes — #
    cfg.setdefault("data", {})
    cfg["data"]["shuffle"] = False
    cfg["data"]["num_workers"] = 0
    cfg["data"]["pin_memory"] = False       # Avoid pin_memory issues on Mac/CPU
    cfg["data"]["persistent_workers"] = False

    cfg.setdefault("training", {})
    cfg["training"]["num_workers"] = 0
    cfg["training"]["pin_memory"] = False
    cfg["training"]["persistent_workers"] = False

    _, _, test_loader, _ = create_data_loaders(cfg)
    return cfg, test_loader


def _filter_kwargs_for_ctor(model_cls, maybe_kwargs: dict):
    sig = inspect.signature(model_cls.__init__)
    return {k: v for k, v in (maybe_kwargs or {}).items() if k in sig.parameters}


def load_model(model_cls, cfg: dict, ckpt_path: str, device: str = "cpu"):
    ctor_kwargs = _filter_kwargs_for_ctor(model_cls, cfg.get("model", {}))
    model = model_cls(**ctor_kwargs).to(device)

    # Explicitly disable "weights-only" safe mode (PyTorch 2.6+ defaults to weights_only=True)
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Handle different checkpoint formats
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        state = ckpt

    # Remove "module." prefix if the checkpoint was saved with DataParallel/DistributedDataParallel
    new_state = OrderedDict()
    for k, v in state.items():
        if k.startswith("module."):
            new_state[k[len("module."):]] = v
        else:
            new_state[k] = v

    model.load_state_dict(new_state, strict=True)
    model.eval()
    return model


def paired_metrics(unet_ckpt: str,
                   attn_ckpt: str,
                   unet_cfg: str,
                   attn_cfg: str,
                   threshold: float = 0.5,
                   device: str = "cpu"):
    """Under identical test-set ordering, compute per-sample Dice/IoU for both models and return paired arrays."""
    cfg_u, test_loader = build_test_loader(unet_cfg)
    cfg_a, _ = build_test_loader(attn_cfg)  # Only to obtain Attention U-Net ctor args; test order follows the first loader

    m_u = load_model(UNet, cfg_u, unet_ckpt, device)
    m_a = load_model(AttentionUNet, cfg_a, attn_ckpt, device)

    dice_u, dice_a, iou_u, iou_a = [], [], [], []

    with torch.no_grad():
        for batch in test_loader:
            x = batch["image"].to(device)
            y = batch["mask"].to(device)

            # Forward pass + sigmoid to probability
            pu = torch.sigmoid(m_u(x)).cpu().numpy()
            pa = torch.sigmoid(m_a(x)).cpu().numpy()
            y_ = y.cpu().numpy()

            B = x.size(0)
            for i in range(B):
                # Binarize using threshold
                dice_u.append(dice_coefficient(pu[i:i+1], y_[i:i+1], threshold=threshold))
                dice_a.append(dice_coefficient(pa[i:i+1], y_[i:i+1], threshold=threshold))
                iou_u.append(iou_score(pu[i:i+1], y_[i:i+1], threshold=threshold))
                iou_a.append(iou_score(pa[i:i+1], y_[i:i+1], threshold=threshold))

    return (np.array(dice_u), np.array(dice_a)), (np.array(iou_u), np.array(iou_a))


def bootstrap_CI(diff: np.ndarray, B: int = 10000, seed: int = 0):
    """Percentile bootstrap CI for the mean of paired differences."""
    rng = np.random.default_rng(seed)
    boot = rng.choice(diff, size=(B, diff.size), replace=True).mean(axis=1)
    return np.percentile(boot, [2.5, 97.5])


def permutation_pvalue(diff: np.ndarray, B: int = 20000, seed: int = 0):
    """Sign-flip permutation test under H0: mean(diff) = 0; returns right-tailed p-value."""
    rng = np.random.default_rng(seed)
    signs = rng.choice([-1, 1], size=(B, diff.size))
    perm_means = (signs * diff).mean(axis=1)
    p = ((perm_means >= diff.mean()).sum() + 1) / (B + 1)
    return p


if __name__ == "__main__":

    UNET_CFG = "configs/unet.yaml"
    ATTN_CFG = "configs/attention_unet.yaml"
    UNET_CKPT = "checkpoints/unet/best_model.pt"
    ATTN_CKPT = "checkpoints/attention_unet/best_model.pt"

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    THRESH = 0.5

    # Compute per-sample metrics
    (dice_u, dice_a), (iou_u, iou_a) = paired_metrics(
        UNET_CKPT, ATTN_CKPT, UNET_CFG, ATTN_CFG, threshold=THRESH, device=DEVICE
    )

    # Differences
    d_dice = dice_a - dice_u
    d_iou = iou_a - iou_u

    # Statistical tests (focus on Dice)
    t_stat, p_t = stats.ttest_rel(dice_a, dice_u, alternative="two-sided")
    w_res = stats.wilcoxon(d_dice, zero_method="wilcox", alternative="greater")
    p_perm = permutation_pvalue(d_dice, B=20000, seed=0)

    # Confidence interval & effect size (paired Cohen's dz)
    ci_low, ci_high = bootstrap_CI(d_dice, B=10000, seed=0)
    dz = d_dice.mean() / (d_dice.std(ddof=1) / np.sqrt(d_dice.size))

    # Output
    print(f"[Dice] Δmean = {d_dice.mean():.6f}, 95% CI [{ci_low:.6f}, {ci_high:.6f}]")
    print(f"  paired t-test: t = {t_stat:.3f}, p = {p_t:.4f}")
    print(f"  Wilcoxon (H1: Attention > U-Net): p = {w_res.pvalue:.4f}")
    print(f"  Permutation p ≈ {p_perm:.4f}")
    print(f"  Effect size (Cohen's dz) = {dz:.3f}")

    print(f"[IoU ] Δmean = {d_iou.mean():.6f} (std = {d_iou.std(ddof=1):.6f}, n = {d_iou.size})")



Loaded 246 samples for train split
Loaded 246 samples for val split
Loaded 212 samples for test split
Loaded 246 samples for train split
Loaded 246 samples for val split
Loaded 212 samples for test split
[Dice] Δmean = 0.002237, 95% CI [0.001729, 0.002747]
  paired t-test: t = 8.553, p = 0.0000
  Wilcoxon (H1: Attention > U-Net): p = 0.0000
  Permutation p ≈ 0.0000
  Effect size (Cohen's dz) = 8.553
[IoU ] Δmean = 0.004199 (std = 0.006921, n = 212)


## Bucketed performance analysis

In [36]:
# Bucketed performance analysis (5 bins for lung area ratio + 3 bins for image size)
import os, sys, inspect
import numpy as np
import pandas as pd
import torch, yaml
from collections import OrderedDict
from scipy import stats

# === Add project root to sys.path (to avoid import issues in multiprocessing) ===
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Internal project modules
from datasets import create_data_loaders
from utils.metrics import dice_coefficient, iou_score
from models.unet import UNet
from models.attention_unet import AttentionUNet

# ----------------- Utility functions -----------------
def load_cfg(p):
    with open(p, "r") as f:
        return yaml.safe_load(f) or {}

def _filter_kwargs_for_ctor(model_cls, maybe_kwargs):
    sig = inspect.signature(model_cls.__init__)
    return {k: v for k, v in (maybe_kwargs or {}).items() if k in sig.parameters}

def load_model(model_cls, cfg, ckpt_path, device="cpu"):
    ctor = _filter_kwargs_for_ctor(model_cls, cfg.get("model", {}))
    m = model_cls(**ctor).to(device)

    # Explicitly disable weights_only
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        state = ckpt

    new_state = OrderedDict((k[len("module."):] if k.startswith("module.") else k, v)
                            for k, v in state.items())
    m.load_state_dict(new_state, strict=True)
    m.eval()
    return m

def build_test_loader(cfg):
    cfg.setdefault("data", {})
    cfg["data"]["shuffle"] = False
    cfg["data"]["num_workers"] = 0
    cfg["data"]["pin_memory"] = False
    cfg["data"]["persistent_workers"] = False
    _, _, test_loader, _ = create_data_loaders(cfg)
    return test_loader

# -------- Safely extract batch-level image names --------
def get_batch_names(batch, B, global_start_idx):
    """
    Returns a list of length B; attempts to extract from keys like name / image_name / path / etc.
    Falls back to index-based naming if unavailable.
    """
    candidate_keys = ["name", "image_name", "path", "image_path", "filename", "id"]
    for key in candidate_keys:
        if key in batch:
            v = batch[key]
            if isinstance(v, (list, tuple)):
                return [str(x) for x in v]
            if isinstance(v, np.ndarray):
                return [str(x) for x in v.tolist()]
            if torch.is_tensor(v):
                vv = v.detach().cpu().numpy().tolist()
                if isinstance(vv, (int, float, str)):
                    return [str(vv)] * B
                return [str(x) for x in vv]
            return [str(v)] * B
    return [f"{global_start_idx + i}" for i in range(B)]

# ----------------- Per-image evaluation -----------------
def run_per_image_metrics(unet_cfg, attn_cfg, unet_ckpt, attn_ckpt, threshold=0.5, device="cpu"):
    cu = load_cfg(unet_cfg); ca = load_cfg(attn_cfg)
    loader = build_test_loader(cu)  
    mu = load_model(UNet, cu, unet_ckpt, device)
    ma = load_model(AttentionUNet, ca, attn_ckpt, device)

    rows = []
    seen = 0
    with torch.no_grad():
        for batch in loader:
            x = batch["image"].to(device); y = batch["mask"].to(device)
            B = x.size(0)
            names = get_batch_names(batch, B, seen)
            pu = torch.sigmoid(mu(x)).cpu().numpy()
            pa = torch.sigmoid(ma(x)).cpu().numpy()
            y_ = y.cpu().numpy()
            for i in range(B):
                nm = str(names[i])
                rows.append({
                    "image_name": nm,
                    "dice_unet": float(dice_coefficient(pu[i:i+1], y_[i:i+1], threshold)),
                    "dice_attn": float(dice_coefficient(pa[i:i+1], y_[i:i+1], threshold)),
                    "iou_unet":  float(iou_score(pu[i:i+1], y_[i:i+1], threshold)),
                    "iou_attn":  float(iou_score(pa[i:i+1], y_[i:i+1], threshold)),
                })
            seen += B
    df = pd.DataFrame(rows)
    df["delta_dice"] = df["dice_attn"] - df["dice_unet"]
    df["delta_iou"]  = df["iou_attn"]  - df["iou_unet"]
    df["row_id"] = np.arange(len(df))
    return df

# ----------------- Attach metadata and bin -----------------
def attach_metadata(per_image_df, metadata_path, split_path):
    meta  = pd.read_csv(metadata_path)
    split = pd.read_csv(split_path)

    # Filter to test set only, unify keys to string
    test_names = split.loc[split["split"] == "test", "image_name"].copy().astype(str)
    meta["image_name"] = meta["image_name"].astype(str)
    per_image_df["image_name"] = per_image_df["image_name"].astype(str)

    meta_t = meta.merge(test_names.to_frame(), on="image_name", how="inner")

    if "positive_ratio" in meta_t.columns:
        meta_t = meta_t.rename(columns={"positive_ratio": "lung_area_ratio"})
    else:
        raise KeyError("Missing 'positive_ratio' column in metadata.csv (expected lung area ratio).")

    # Area and binning
    meta_t["area"] = meta_t["width"] * meta_t["height"]
    meta_t["lung_ratio_bin"] = pd.qcut(meta_t["lung_area_ratio"], q=5, labels=False, duplicates="drop")
    meta_t["area_bin"] = pd.qcut(meta_t["area"], q=3, labels=False, duplicates="drop")

    merged = per_image_df.merge(meta_t, on="image_name", how="left")

    # Fallback if merge failed (e.g. >20% NaN), align by row_id
    miss_ratio = merged["lung_area_ratio"].isna().mean()
    if miss_ratio > 0.2:
        meta_t_seq = meta_t.reset_index(drop=True).copy()
        meta_t_seq["row_id"] = np.arange(len(meta_t_seq))
        merged = per_image_df.merge(meta_t_seq.drop(columns=["image_name"]), on="row_id", how="left")

    return merged

# ----------------- Statistical summary -----------------
def bootstrap_ci_mean(x, B=10000, seed=42):
    rng = np.random.default_rng(seed)
    boot = rng.choice(x, size=(B, len(x)), replace=True).mean(axis=1)
    return np.percentile(boot, [2.5, 97.5])

def summarize_by_bucket(df, bucket_col, metric_delta_col="delta_dice"):
    out = []
    for b, g in df.groupby(bucket_col):
        x = g[metric_delta_col].dropna().values
        if len(x) == 0:
            continue
        ci = bootstrap_ci_mean(x)
        w = stats.wilcoxon(x, zero_method="wilcox", alternative="greater")  # H1: Attention > U-Net
        out.append({
            "bucket": int(b) if isinstance(b, (int, np.integer)) else str(b),
            "n": int(len(x)),
            "delta_mean": float(np.mean(x)),
            "delta_ci_low": float(ci[0]),
            "delta_ci_high": float(ci[1]),
            "wilcoxon_p": float(w.pvalue)
        })
    return pd.DataFrame(out).sort_values("bucket")

# ----------------- Entry point -----------------
def main():
    # Adjust paths as needed
    UNET_CFG  = "configs/unet.yaml"
    ATTN_CFG  = "configs/attention_unet.yaml"
    UNET_CKPT = "checkpoints/unet/best_model.pt"
    ATTN_CKPT = "checkpoints/attention_unet/best_model.pt"

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    THRESH = 0.5

    # Per-image metrics
    df_img = run_per_image_metrics(UNET_CFG, ATTN_CFG, UNET_CKPT, ATTN_CKPT, threshold=THRESH, device=DEVICE)

    # Attach metadata
    with open(UNET_CFG, "r") as f:
        cfg = yaml.safe_load(f)
    meta_path  = cfg["data"]["metadata_path"]
    split_path = cfg["data"]["split_path"]
    df_all = attach_metadata(df_img, meta_path, split_path)

    # Output
    os.makedirs("outputs/bucket_eval", exist_ok=True)
    df_all.to_csv("outputs/bucket_eval/per_image_with_meta.csv", index=False)

    # Bucket summaries: lung area ratio (5 bins) + image size (3 bins)
    sum_lung = summarize_by_bucket(df_all, "lung_ratio_bin", "delta_dice")
    sum_area = summarize_by_bucket(df_all, "area_bin", "delta_dice")

    sum_lung.to_csv("outputs/bucket_eval/summary_by_lung_ratio.csv", index=False)
    sum_area.to_csv("outputs/bucket_eval/summary_by_area.csv", index=False)

    print("=== ΔDice by lung area ratio bin (0 = lowest) ===")
    print(sum_lung.to_string(index=False))
    print("\n=== ΔDice by area bin (0 = smallest) ===")
    print(sum_area.to_string(index=False))

if __name__ == "__main__":
    main()


Loaded 246 samples for train split
Loaded 246 samples for val split
Loaded 212 samples for test split
=== ΔDice by lung area ratio bin (0 = lowest) ===
 bucket  n  delta_mean  delta_ci_low  delta_ci_high  wilcoxon_p
      0 43    0.003597      0.002351       0.004777    0.000003
      1 42    0.002341      0.001279       0.003334    0.000007
      2 42    0.001969      0.000556       0.003444    0.000489
      3 42    0.001931      0.001006       0.002932    0.000541
      4 43    0.001340      0.000541       0.002170    0.000379

=== ΔDice by area bin (0 = smallest) ===
 bucket  n  delta_mean  delta_ci_low  delta_ci_high   wilcoxon_p
      0 71    0.001658      0.000818       0.002475 2.613209e-05
      1 70    0.002421      0.001511       0.003263 2.687844e-07
      2 71    0.002637      0.001713       0.003594 8.679081e-08


## Bucketed ΔDice bar plots with 95% CI & box plots

In [38]:
# Bucketed ΔDice bar plots with 95% CI & box plots
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

IN_DIR  = "outputs/bucket_eval"
OUT_DIR = "outputs/bucket_eval/figs"
os.makedirs(OUT_DIR, exist_ok=True)

# ---- Bin labels (Q1..Q5) ----
def pretty_bin_labels(series, prefix="Bin"):
    # series could be int or string; sort by appearance order
    vals = sorted(series.unique(), key=lambda x: int(x) if str(x).isdigit() else x)
    mapping = {}
    for i, v in enumerate(vals):
        # If bins are integers 0..k, map to Q1..Qk; otherwise use the string directly
        if str(v).isdigit():
            mapping[v] = f"Q{i+1}"
        else:
            mapping[v] = str(v)
    return series.map(mapping), [mapping[v] for v in vals]

# ---- 1) Bar + 95% CI: lung area ratio (5 bins) ----
def plot_bar_ci(csv_path, title, outfile):
    df = pd.read_csv(csv_path)
    if "bucket" not in df.columns:
        raise ValueError(f"Missing 'bucket' column in {csv_path}")
    xlabels_series, order_labels = pretty_bin_labels(df["bucket"])
    x = np.arange(len(df))
    y = df["delta_mean"].values
    yerr = np.vstack([y - df["delta_ci_low"].values,
                      df["delta_ci_high"].values - y])

    plt.figure(figsize=(6.0, 3.8), dpi=300)
    plt.bar(x, y)
    plt.errorbar(x, y, yerr=yerr, fmt="none", capsize=3, linewidth=1)
    plt.xticks(x, xlabels_series)
    plt.ylabel("ΔDice (Attention − U-Net)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(outfile, bbox_inches="tight")
    plt.close()

# ---- 2) Bar + 95% CI: image area (3 bins) ----
def plot_bar_ci_area():
    plot_bar_ci(
        os.path.join(IN_DIR, "summary_by_area.csv"),
        "ΔDice by Image Area (Tertiles)",
        os.path.join(OUT_DIR, "bar_ci_area.png"),
    )

def plot_bar_ci_lung():
    plot_bar_ci(
        os.path.join(IN_DIR, "summary_by_lung_ratio.csv"),
        "ΔDice by Lung Area Ratio (Quintiles)",
        os.path.join(OUT_DIR, "bar_ci_lung.png"),
    )

# ---- 3) Box plot: per-bin ΔDice distribution (from per_image_with_meta.csv) ----
def plot_box_per_bucket(per_image_csv, bucket_col, title, outfile):
    df = pd.read_csv(per_image_csv)
    if bucket_col not in df.columns:
        raise ValueError(f"Missing column in {per_image_csv}: {bucket_col}")
    # Keep rows that have valid bucket assignment
    df = df[pd.notnull(df[bucket_col])]
    # Order & x-axis labels
    if np.issubdtype(df[bucket_col].dtype, np.number):
        order = sorted(df[bucket_col].unique())



Saved figures to: outputs/bucket_eval/figs


## Four-layer attention overlay panel figure

In [70]:
# Four-layer attention overlay panel figure
import os, sys, inspect
import numpy as np
import pandas as pd
import torch, yaml
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict

# ===== Ensure project root is in sys.path =====
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from datasets import create_data_loaders
from utils.metrics import dice_coefficient, iou_score
from models.unet import UNet
from models.attention_unet import AttentionUNet   

# ------------------------ Configuration ------------------------ #
UNET_CFG_DEFAULT  = "configs/unet.yaml"
ATTN_CFG_DEFAULT  = "configs/attention_unet.yaml"
UNET_CKPT_DEFAULT = "checkpoints/unet/best_model.pt"
ATTN_CKPT_DEFAULT = "checkpoints/attention_unet/best_model.pt"

# ------------------------ Utilities ------------------------ #
def load_cfg(p):
    with open(p, "r") as f:
        return yaml.safe_load(f) or {}

def _filter_kwargs_for_ctor(model_cls, maybe_kwargs):
    sig = inspect.signature(model_cls.__init__)
    return {k: v for k, v in (maybe_kwargs or {}).items() if k in sig.parameters}

def strict_load(model, ckpt_path, device="cpu"):
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    sd = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
    clean = OrderedDict((k[7:], v) if k.startswith("module.") else (k, v) for k, v in sd.items())
    missing, unexpected = model.load_state_dict(clean, strict=False)
    if getattr(missing, "keys", None):  
        missing = list(missing.keys())
    if missing or unexpected:
        raise RuntimeError(
            f"[StrictLoadError] {ckpt_path}\n"
            f"  Missing: {missing}\n"
            f"  Unexpected: {unexpected}"
        )
    model.eval()
    return model

def load_model(model_cls, cfg, ckpt_path, device="cpu"):
    ctor = _filter_kwargs_for_ctor(model_cls, cfg.get("model", {}))
    m = model_cls(**ctor).to(device)
    m = strict_load(m, ckpt_path, device=device)
    return m

def build_test_loader(cfg):
    cfg.setdefault("data", {})
    cfg["data"]["shuffle"] = False
    cfg["data"]["num_workers"] = 0
    cfg["data"]["pin_memory"] = False
    cfg["data"]["persistent_workers"] = False
    _, _, test_loader, _ = create_data_loaders(cfg)
    return test_loader

# -------- Safely fetch batch names (avoid tensor truth-value ambiguity) --------
def get_batch_names(batch, B, global_start_idx):
    """
    Return a list of length B. Prefer keys: name / image_name / path / image_path / filename / id.
    If none exist or have mismatched types, fall back to sequential indices (aligned with split[test]).
    """
    candidate_keys = ["name", "image_name", "path", "image_path", "filename", "id"]
    for key in candidate_keys:
        if key in batch:
            v = batch[key]
            if isinstance(v, (list, tuple)):
                return [str(x) for x in v]
            if isinstance(v, np.ndarray):
                return [str(x) for x in v.tolist()]
            if torch.is_tensor(v):
                vv = v.detach().cpu().numpy().tolist()
                if isinstance(vv, (int, float, str)):
                    return [str(vv)] * B
                return [str(x) for x in vv]
            return [str(v)] * B
    return [f"{global_start_idx + i}" for i in range(B)]

# ------------------------ Hooks: capture attention maps (no model code changes) ------------------------ #
def _is_att_block(mod_name, mod):
    """Loosely identify attention blocks: name contains 'attention' and there exists a 'psi' submodule."""
    if "attention" in mod_name.lower():
        for n, m in mod.named_modules():
            if n.endswith("psi") or n.endswith("psi.0"):
                return True
    return False

def attach_attention_hooks(att_model):
    """
    Register forward hooks on the 'psi' submodules of attention blocks to capture their outputs.
    """
    handles = []
    att_model._captured_attn = []  # transient cache

    def _hook(_module, _inp, out):
        if isinstance(out, torch.Tensor):
            t = out.detach()
            if t.dim() == 3:  # [B,h,w] -> [B,1,h,w]
                t = t.unsqueeze(1)
            att_model._captured_attn.append(t)

    for name, mod in att_model.named_modules():
        if _is_att_block(name, mod):
            # Prefer hooking on 'psi'
            psi = None
            for n, m in mod.named_modules():
                if n.endswith("psi") or n.endswith("psi.0"):
                    psi = m
                    break
            target = psi if psi is not None else mod
            handles.append(target.register_forward_hook(_hook))

    return handles

def get_attention_maps(att_model, x_single):
    """
    Prefer reading in-model cache `self.attention_weights`; fallback to hooks otherwise.
    Returns list[Tensor], each ~ [1,1,h,w].
    """
    # Path 1: in-model cache exists
    maps = getattr(att_model, "attention_weights", None)
    if isinstance(maps, (list, tuple)) and len(maps) > 0:
        return maps

    # Path 2: capture via hooks
    handles = attach_attention_hooks(att_model)
    try:
        with torch.no_grad():
            _ = att_model(x_single)  # trigger forward
        captured = getattr(att_model, "_captured_attn", [])
        if not captured:
            raise RuntimeError(
                "Hook did not capture any attention maps. "
                "Check attention block naming or consider enabling in-model caching."
            )
        # Sort by resolution (small -> large), roughly “deep -> shallow”
        captured = sorted(captured, key=lambda t: (t.shape[-2], t.shape[-1]))
        return captured
    finally:
        for h in handles:
            h.remove()
        if hasattr(att_model, "_captured_attn"):
            delattr(att_model, "_captured_attn")

# ------------------------ Per-image metrics & selection ------------------------ #
def per_image_metrics_and_select(unet_cfg, attn_cfg, unet_ckpt, attn_ckpt, k=6, threshold=0.5, device="cpu"):
    cu = load_cfg(unet_cfg); ca = load_cfg(attn_cfg)
    loader = build_test_loader(cu)  
    mu = load_model(UNet, cu, unet_ckpt, device)
    ma = load_model(AttentionUNet, ca, attn_ckpt, device)

    rows = []
    cache = []  # list of dict per sample

    seen = 0
    with torch.no_grad():
        for batch in loader:
            x = batch["image"].to(device)
            y = batch["mask"].to(device)
            B = x.size(0)
            names = get_batch_names(batch, B, seen)

            pu = torch.sigmoid(mu(x))
            pa = torch.sigmoid(ma(x))  # forward also triggers attention hooks if attached

            for i in range(B):
                xi = x[i:i+1].cpu()
                yi = y[i:i+1].cpu()
                pui = pu[i:i+1].cpu()
                pai = pa[i:i+1].cpu()

                dice_u = float(dice_coefficient(pui.numpy(), yi.numpy(), threshold=threshold))
                dice_a = float(dice_coefficient(pai.numpy(), yi.numpy(), threshold=threshold))

                name = str(names[i])
                rows.append({
                    "image_name": name,
                    "dice_unet": dice_u,
                    "dice_attn": dice_a,
                    "delta_dice": dice_a - dice_u,
                })
                cache.append({
                    "image_name": name,
                    "image": xi,  # [1,1,H,W] or [1,C,H,W]
                    "mask":  yi,  # [1,1,H,W]
                })
            seen += B

    df = pd.DataFrame(rows).sort_values("delta_dice", ascending=False).reset_index(drop=True)
    topk = df.head(k)["image_name"].tolist()
    bottomk = df.tail(k)["image_name"].tolist()
    idx = {c["image_name"]: c for c in cache}
    return df, topk, bottomk, idx, mu, ma

# ------------------------ Visualization helpers ------------------------ #
def to_display_img(x_tensor):
    """[1,C,H,W] or [C,H,W] -> (H,W) numpy normalized to [0,1] for imshow."""
    x = x_tensor.squeeze().float().cpu().numpy()
    if x.ndim == 3:  # C,H,W
        x = x[0]     # single-channel
    x = x - x.min()
    if x.max() > 1e-8:
        x = x / x.max()
    return x

def prob_to_mask(p_tensor, thr=0.5):
    p = p_tensor.squeeze().cpu().numpy()
    if p.ndim == 3:
        p = p[0]
    return (p >= thr).astype(np.uint8)

def upsample_to(img_like, att_map):
    """Upsample attention map (1,1,h,w) -> (H,W) to align with the image; return [0,1] numpy."""
    am = att_map.squeeze().detach().cpu().numpy()
    if am.ndim == 3:
        am = am[0]
    am = (am - am.min()) / (am.max() - am.min() + 1e-8)
    H, W = img_like.shape[-2], img_like.shape[-1]
    am_img = Image.fromarray((am*255).astype("uint8")).resize((W, H), resample=Image.BILINEAR)
    return np.array(am_img) / 255.0

def render_panel(save_path, img, gt, pred_u, pred_a, attn_maps, title):
    """
    attn_maps: list of 4 numpy arrays in [0,1], shape (H,W), already upsampled.
    Panel columns: Image | GT | U-Net | AttU-Net | Att-1 | Att-2 | Att-3 | Att-4
    """
    cols = 8
    plt.figure(figsize=(cols*2.1, 2.6), dpi=250)

    # 1 Image
    plt.subplot(1, cols, 1); plt.imshow(img, cmap="gray"); plt.axis("off"); plt.title("Image")
    # 2 GT
    plt.subplot(1, cols, 2); plt.imshow(gt, cmap="gray"); plt.axis("off"); plt.title("GT")
    # 3 U-Net
    plt.subplot(1, cols, 3); plt.imshow(pred_u, cmap="gray"); plt.axis("off"); plt.title("U-Net")
    # 4 AttU-Net
    plt.subplot(1, cols, 4); plt.imshow(pred_a, cmap="gray"); plt.axis("off"); plt.title("AttU-Net")
    # 5–8 Attention overlays
    for i, am in enumerate(attn_maps, start=5):
        plt.subplot(1, cols, i)
        plt.imshow(img, cmap="gray")
        plt.imshow(am, alpha=0.45)
        plt.axis("off"); plt.title(f"Att-{i-4}")
    plt.suptitle(title, y=1.02, fontsize=11)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

# ------------------------ Main: export panel figures & captions ------------------------ #
def export_panels(
    unet_cfg=UNET_CFG_DEFAULT,
    attn_cfg=ATTN_CFG_DEFAULT,
    unet_ckpt=UNET_CKPT_DEFAULT,
    attn_ckpt=ATTN_CKPT_DEFAULT,
    out_dir="outputs/attention_panels",
    k=6, threshold=0.5, device="cpu"
):
    os.makedirs(out_dir, exist_ok=True)

    df, topk, bottomk, cache_idx, mu, ma = per_image_metrics_and_select(
        unet_cfg, attn_cfg, unet_ckpt, attn_ckpt, k=k, threshold=threshold, device=device
    )
    # Save full per-image results
    df.to_csv(os.path.join(out_dir, "per_image_dice.csv"), index=False)

    selected = [("top", topk), ("bottom", bottomk)]
    cap_lines = []

    for tag, name_list in selected:
        tag_dir = os.path.join(out_dir, tag)
        os.makedirs(tag_dir, exist_ok=True)

        for name in name_list:
            item = cache_idx[name]
            x = item["image"].to(device)  # [1,1,H,W]
            y = item["mask"]
            with torch.no_grad():
                pu = torch.sigmoid(mu(x))
                pa = torch.sigmoid(ma(x))  

            # Read attention maps and upsample to the image size
            att_maps = get_attention_maps(ma, x)
            up_maps = [upsample_to(x, m) for m in att_maps]

            # Numpy for display
            img_np = to_display_img(x.cpu())
            gt_np  = prob_to_mask(y, 0.5)      # GT is a 0/1 mask; display directly
            un_np  = prob_to_mask(pu, threshold)
            at_np  = prob_to_mask(pa, threshold)

            # Title & export
            row = df[df["image_name"] == name].iloc[0]
            title = f"{name} | Dice_u={row['dice_unet']:.4f}  Dice_attn={row['dice_attn']:.4f}  Δ={row['delta_dice']:.4f}"
            save_path = os.path.join(tag_dir, f"{name}.png")
            render_panel(save_path, img_np, gt_np, un_np, at_np, up_maps, title)

            # Caption (.txt)
            cap = (
                f"{tag.upper()} | {name}\n"
                f"Dice (U-Net)={row['dice_unet']:.4f}, Dice (AttU-Net)={row['dice_attn']:.4f}, Δ={row['delta_dice']:.4f}\n"
                f"Notes: Inspect attention hotspots near thin/ambiguous boundaries (apices, costophrenic angles). "
                f"For failure cases, check non-lung distractors (cardiac shadow, diaphragm edges, rib artifacts)."
            )
            cap_lines.append(cap)

    with open(os.path.join(out_dir, "captions.txt"), "w") as f:
        f.write("\n\n".join(cap_lines))

    print(f"Saved panels & captions under: {out_dir}")

# ------------------------ Run ------------------------ #
if __name__ == "__main__":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    THRESH = 0.5
    TOPK = 6

    export_panels(
        unet_cfg=UNET_CFG_DEFAULT,
        attn_cfg=ATTN_CFG_DEFAULT,
        unet_ckpt=UNET_CKPT_DEFAULT,
        attn_ckpt=ATTN_CKPT_DEFAULT,
        out_dir="outputs/attention_panels",
        k=TOPK, threshold=THRESH, device=DEVICE
    )


Loaded 246 samples for train split
Loaded 246 samples for val split
Loaded 212 samples for test split
Saved panels & captions under: outputs/attention_panels


In [76]:
import pandas as pd
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

df = pd.read_csv("/Users/t/Documents/DISSERTATION/outputs/bucket_eval/per_image_with_meta.csv")

area_map = {0:'Q1 (small)', 1:'Q2 (medium)', 2:'Q3 (large)'}
lung_map = {0:'Q1 (small lungs)',1:'Q2',2:'Q3',3:'Q4',4:'Q5 (large lungs)'}
df['area_label'] = df['area_bin'].map(area_map)
df['lung_label'] = df['lung_ratio_bin'].map(lung_map)

rows = []
for k, sub in df.groupby('area_label'):
    stat = wilcoxon(sub['delta_dice'], alternative='greater', zero_method='wilcox')
    rows.append(['Image area', k, len(sub), sub['delta_dice'].mean(), stat.pvalue])

for k, sub in df.groupby('lung_label'):
    stat = wilcoxon(sub['delta_dice'], alternative='greater', zero_method='wilcox')
    rows.append(['Lung area ratio', k, len(sub), sub['delta_dice'].mean(), stat.pvalue])

res = pd.DataFrame(rows, columns=['Stratum','Level','N','Mean ΔDice','Wilcoxon p (raw)'])

res['BH q (adj)'] = multipletests(res['Wilcoxon p (raw)'], method='fdr_bh')[1]
res['Bonferroni p (adj)'] = multipletests(res['Wilcoxon p (raw)'], method='bonferroni')[1]

res


Unnamed: 0,Stratum,Level,N,Mean ΔDice,Wilcoxon p (raw),BH q (adj),Bonferroni p (adj)
0,Image area,Q1 (small),71,0.001658,2.613209e-05,4.181134e-05,0.0002090567
1,Image area,Q2 (medium),70,0.002421,2.687844e-07,1.075137e-06,2.150275e-06
2,Image area,Q3 (large),71,0.002637,8.679081e-08,6.943265e-07,6.943265e-07
3,Lung area ratio,Q1 (small lungs),43,0.003597,2.603353e-06,6.942273e-06,2.082682e-05
4,Lung area ratio,Q2,42,0.002341,6.794283e-06,1.358857e-05,5.435426e-05
5,Lung area ratio,Q3,42,0.001969,0.0004894998,0.0005407854,0.003915998
6,Lung area ratio,Q4,42,0.001931,0.0005407854,0.0005407854,0.004326284
7,Lung area ratio,Q5 (large lungs),43,0.00134,0.0003794029,0.0005058705,0.003035223
