# In [15]:
# STED 0.375% analysis script
# Reads pore masks, measures segmentation quality per method, runs stats, and exports figures/CSVs.

import os
import re
import glob
import csv
import itertools
import numpy as np
import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.stats import kruskal, wilcoxon, binomtest
from scipy.stats import t as student_t  # for Bland–Altman CIs

# paths and basic config

# Folder that holds the gold-standard mask and one mask per segmentation method.
base_dir = r"C:\Users\walsh\Downloads\STED Accuracy INTERNAL\Internal 0.375%"

# Root folder that receives every CSV and figure written by this script.
results_root = r"C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML"

gold_path = base_dir + r"\GOLD STANDARD.tif"

# One TIFF per method; the file stem is reused as the method name further down.
method_paths = [
    base_dir + "\\" + stem + ".tif"
    for stem in (
        "60%",
        "FREEHAND",
        "OVAL",
        "ILASTIK",
        "OTSU",
        "PLANKSTER",
        "PORED2",
        "SAMJ",
        "SEMI",
        "UNET",
    )
]

# Label added to overlay file names, figure titles and console messages.
display_tag = "[STED 0.375%]"

# Output folders, all located below results_root.
out_dir = os.path.join(results_root, "Accuracy")
ov_dir = os.path.join(out_dir, "Overlays_TIFF")
sens_dir = os.path.join(out_dir, "Sensitivity")
fig_dir = os.path.join(results_root, "Figures")
stats_dir = os.path.join(results_root, "Stats")
ba_dir = os.path.join(fig_dir, "Bland-Altman")

# Tiling: default tile edge, minimum tile-area ratio, and the tile sizes
# compared in the sensitivity analysis.
patch_size = 64
min_area_ratio = 0.0
sens_sizes = [64, 48]

# Bootstrap switches and the seeds of the three independent random streams.
WRITE_BOOTSTRAP_CI = True
RNG_SEED_CI = 42
RNG_SEED_PAIR = 7
RNG_SEED_BIAS = 11

# Create every output folder up front (parents before children).
for _folder in (out_dir, ov_dir, sens_dir, fig_dir, stats_dir, ba_dir):
    os.makedirs(_folder, exist_ok=True)

# image reading and binarization (pores are black -> value 1)

try:
    import tifffile as tiff
    _HAS_TIFF = True
except Exception:
    _HAS_TIFF = False

try:
    from PIL import Image
    _HAS_PIL = True
except Exception:
    _HAS_PIL = False

def _read_tiff_any(path):
    """Load an image from *path*, trying OpenCV, then tifffile, then Pillow.

    Returns the pixel array produced by the first reader that succeeds, or
    None when no available reader can load the file.
    """
    via_cv2 = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
    if via_cv2 is not None:
        return via_cv2

    # Both fallback readers are only attempted for files that exist.
    if not os.path.exists(path):
        return None

    if _HAS_TIFF:
        try:
            return tiff.imread(path)
        except Exception:
            pass  # best effort: fall through to Pillow

    if _HAS_PIL:
        try:
            with Image.open(path) as handle:
                # Keep 16-bit data as 16-bit; everything else becomes 8-bit gray.
                target_mode = "I;16" if "I;16" in handle.mode else "L"
                return np.array(handle.convert(target_mode))
        except Exception:
            pass  # best effort: report failure through the None return

    return None

def _to_gray(arr):
    if arr is None:
        return None
    if arr.ndim == 2:
        return arr
    if arr.ndim == 3 and arr.shape[-1] in (3, 4):
        a = arr
        if a.dtype != np.uint8:
            a_min = float(a.min())
            a_max = float(a.max())
            a = ((a - a_min) / (a_max - a_min + 1e-12) * 255.0).astype(np.uint8)
        if a.shape[-1] == 4:
            a = cv2.cvtColor(a, cv2.COLOR_BGRA2BGR)
        return cv2.cvtColor(a, cv2.COLOR_BGR2GRAY)
    if arr.ndim > 2:
        return _to_gray(arr[..., 0])
    return arr

def _ensure_uint(img):
    if img.dtype in (np.uint8, np.uint16):
        return img
    if np.issubdtype(img.dtype, np.floating):
        a_min = float(img.min())
        a_max = float(img.max())
        if a_max > a_min:
            scaled = (img - a_min) / (a_max - a_min)
        else:
            scaled = np.zeros_like(img, dtype=np.float32)
        return (scaled * 255.0 + 0.5).astype(np.uint8)
    if img.max() > 255:
        return img.astype(np.uint16)
    return img.astype(np.uint8)

def binarize_pores_black(img_gray):
    """Turn a grayscale image into a 0/1 uint8 mask where 1 marks pores.

    When the image holds exactly two distinct values, the lower value is
    taken as the pore class. Otherwise the image is reduced to 8 bits and
    split with Otsu's threshold; pixels at or below the threshold (the dark
    side) become pores.
    """
    gray = _ensure_uint(img_gray)
    levels = np.unique(gray)
    if levels.size == 2:
        darkest = int(levels[0])
        return (gray == darkest).astype(np.uint8)

    # Map the 16-bit range onto 0..255 (65535 / 257 == 255).
    if gray.dtype == np.uint16:
        gray8 = (gray / 257).astype(np.uint8)
    else:
        gray8 = gray.astype(np.uint8)

    _, bright = cv2.threshold(gray8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return (bright == 0).astype(np.uint8)

def read_mask(path):
    """Read the image at *path* and return its binary pore mask.

    Surrounding whitespace in *path* is ignored.

    Raises:
        FileNotFoundError: the file is missing, or no reader could load it.
        ValueError: the loaded image could not be reduced to two dimensions.
    """
    target = path.strip()
    if not os.path.exists(target):
        raise FileNotFoundError("file does not exist: " + target)

    loaded = _read_tiff_any(target)
    if loaded is None:
        raise FileNotFoundError("cannot read image (unsupported/corrupt TIFF): " + target)

    single = _to_gray(loaded)
    if single is not None and single.ndim == 2:
        return binarize_pores_black(single)
    raise ValueError("not a single-channel image: " + target)

# Gold-standard pore mask, loaded once and compared against every method below.
gt_full = read_mask(gold_path)

# overlay rendering (TP: green edges, FP: red fill, FN: blue fill)

try:
    import tifffile as _tiff
    _HAS_TIFF_WRITE = True
except Exception:
    _HAS_TIFF_WRITE = False

try:
    from PIL import Image as _PILImage
    _HAS_PIL_WRITE = True
except Exception:
    _HAS_PIL_WRITE = False

def save_overlay_tif(gt, pr, save_path_tif):
    """Render an agreement overlay of two binary masks and save it as a TIFF.

    The gold mask is drawn as a grayscale base (pores black). False positives
    are tinted red, false negatives are tinted blue, and the outline of the
    true-positive region is drawn in green. Colours are assembled in OpenCV's
    BGR channel order.

    Args:
        gt: gold-standard mask with values 0/1 (1 = pore).
        pr: predicted mask with values 0/1, same shape as *gt*.
        save_path_tif: destination path; ".tif" is appended when the path does
            not already end in ".tif" or ".tiff".

    Raises:
        IOError: none of the available writers (OpenCV, tifffile, Pillow)
            could write the file.
    """
    base_gray = ((1 - gt) * 255).astype(np.uint8)
    overlay = cv2.cvtColor(base_gray, cv2.COLOR_GRAY2BGR)

    tp = (gt == 1) & (pr == 1)
    fp = (gt == 0) & (pr == 1)
    fn = (gt == 1) & (pr == 0)

    # Blend a red layer over false positives (BGR order: red is (0, 0, 255)).
    fp_layer = overlay.copy()
    fp_layer[fp] = (0, 0, 255)
    overlay = cv2.addWeighted(fp_layer, 0.45, overlay, 0.55, 0)

    # Blend a blue layer over false negatives.
    fn_layer = overlay.copy()
    fn_layer[fn] = (255, 0, 0)
    overlay = cv2.addWeighted(fn_layer, 0.45, overlay, 0.55, 0)

    # Outline the true-positive region in green.
    edges = cv2.Canny((tp.astype(np.uint8) * 255), 50, 150)
    overlay[edges > 0] = (0, 255, 0)

    # A bare file name has an empty dirname; os.makedirs("") would raise.
    target_dir = os.path.dirname(save_path_tif)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    if not save_path_tif.lower().endswith((".tif", ".tiff")):
        save_path_tif = save_path_tif + ".tif"

    # The fallback writers expect RGB, so hand them a channel-reversed copy;
    # cv2.imwrite takes the BGR array directly.
    overlay_rgb = np.ascontiguousarray(overlay[:, :, ::-1])

    ok = False
    try:
        ok = cv2.imwrite(save_path_tif, overlay)
    except Exception:
        ok = False
    if (not ok) and _HAS_TIFF_WRITE:
        try:
            _tiff.imwrite(save_path_tif, overlay_rgb, photometric="rgb")
            ok = True
        except Exception:
            ok = False
    if (not ok) and _HAS_PIL_WRITE:
        try:
            _PILImage.fromarray(overlay_rgb).save(save_path_tif, format="TIFF")
            ok = True
        except Exception:
            ok = False
    if not ok:
        raise IOError("failed to write TIFF: " + save_path_tif)

# metric helpers

def _safe_div(n, d):
    if d:
        return float(n) / float(d)
    return 0.0

def _counts(gt, pred):
    gt_u8 = gt.astype(np.uint8)
    pred_u8 = pred.astype(np.uint8)
    tp = int(np.sum((gt_u8 == 1) & (pred_u8 == 1)))
    tn = int(np.sum((gt_u8 == 0) & (pred_u8 == 0)))
    fp = int(np.sum((gt_u8 == 0) & (pred_u8 == 1)))
    fn = int(np.sum((gt_u8 == 1) & (pred_u8 == 0)))
    return tp, fp, tn, fn

def compute_metrics(gt, pred):
    """Compute pixel-wise agreement metrics between two 0/1 masks.

    Args:
        gt: gold-standard mask (1 = pore).
        pred: predicted mask of the same shape.

    Returns:
        dict with accuracy, precision, recall, specificity, balanced_accuracy,
        f1_dice, iou_jaccard and mcc (floats), plus the raw TP/FP/TN/FN counts.
        Every ratio whose denominator is zero is reported as 0.0.

    NOTE(review): a region with no pores in either mask therefore scores
    Dice = IoU = 0.0 even though the masks agree — confirm this is intended.
    """
    tp, fp, tn, fn = _counts(gt, pred)
    acc  = _safe_div(tp + tn, tp + tn + fp + fn)
    prec = _safe_div(tp, tp + fp)
    rec  = _safe_div(tp, tp + fn)
    spec = _safe_div(tn, tn + fp)
    ba   = 0.5 * (rec + spec)
    dice = _safe_div(2 * tp, 2 * tp + fp + fn)
    iou  = _safe_div(tp, tp + fp + fn)

    # Multiply as floats: the integer product of the four marginals can exceed
    # the 64-bit range on large regions, and np.sqrt cannot handle the
    # arbitrary-precision Python int that results.
    den = float(np.sqrt(
        float(tp + fp) * float(tp + fn) * float(tn + fp) * float(tn + fn)
    ))
    mcc = _safe_div(tp * tn - fp * fn, den) if den else 0.0

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "specificity": spec,
        "balanced_accuracy": ba,
        "f1_dice": dice,
        "iou_jaccard": iou,
        "mcc": mcc,
        "TP": tp,
        "FP": fp,
        "TN": tn,
        "FN": fn,
    }

# break images into tiles and compute metrics per tile

def patch_metrics_list(gt, pred, size=patch_size, min_ratio=min_area_ratio):
    """Tile both masks into size x size patches and score each patch.

    Tiles at the right and bottom borders are cropped to the image. A tile is
    skipped when it has fewer pixels than ``size * size * min_ratio``.

    Returns:
        list of metric dicts (one per kept tile), in row-major tile order.
    """
    height, width = gt.shape
    fewest_pixels = int((size * size) * min_ratio)

    collected = []
    for top in range(0, height, size):
        bottom = min(top + size, height)
        for left in range(0, width, size):
            right = min(left + size, width)
            gt_tile = gt[top:bottom, left:right]
            if gt_tile.size < fewest_pixels:
                continue
            pred_tile = pred[top:bottom, left:right]
            collected.append(compute_metrics(gt_tile, pred_tile))
    return collected

# loop over each method and save per-method metric CSV and overlay

# Per-method pass: score every method against the gold mask tile by tile,
# write one summary CSV and one overlay TIFF per method, and keep the per-tile
# rows in method_rows for the statistics further down.
os.makedirs(out_dir, exist_ok=True)
os.makedirs(ov_dir,  exist_ok=True)

method_names = []
method_rows  = {}

for pred_path in method_paths:
    # The file stem (e.g. "OTSU") is used as the method name everywhere.
    method_name = os.path.splitext(os.path.basename(pred_path.strip()))[0]
    method_names.append(method_name)

    pr_full = read_mask(pred_path)

    if gt_full.shape != pr_full.shape:
        raise ValueError("shape mismatch for " + method_name + ": " + str(gt_full.shape) + " vs " + str(pr_full.shape))

    rows = patch_metrics_list(gt_full, pr_full, size=patch_size, min_ratio=min_area_ratio)
    if not rows:
        # No tile survived: fall back to one whole-image measurement.
        rows = [compute_metrics(gt_full, pr_full)]
    method_rows[method_name] = rows

    # Split the dict keys into ratio metrics (averaged) and raw counts (summed).
    keys = list(rows[0].keys())
    count_keys = ["TP", "FP", "TN", "FN"]
    metric_keys = [k for k in keys if k not in count_keys]

    # Mean and sample SD (ddof=1) of each metric across tiles.
    summary_metrics = []
    for k in metric_keys:
        vals = np.array([r[k] for r in rows], dtype=float)
        mean_v = float(np.mean(vals))
        if len(vals) > 1:
            sd_v   = float(np.std(vals, ddof=1))
        else:
            sd_v   = 0.0
        summary_metrics.append((k, mean_v, sd_v))

    # Confusion-matrix counts summed over all tiles.
    totals = []
    for k in count_keys:
        vals = np.array([r[k] for r in rows], dtype=float)
        totals.append((k + "_total", float(np.sum(vals))))

    # CSV layout: metric table, a blank row, then the count totals.
    csv_path = os.path.join(out_dir, "Metric Results [" + method_name + "].csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["metric", "mean", "sd"])
        for k, mean_v, sd_v in sorted(summary_metrics):
            w.writerow([k, mean_v, sd_v])
        w.writerow([])
        w.writerow(["metric", "total"])
        for k, total_v in sorted(totals):
            w.writerow([k, total_v])

    ov_path = os.path.join(ov_dir, method_name + " " + display_tag + " overlay.tif")
    save_overlay_tif(gt_full, pr_full, ov_path)

    # Pull the Dice mean/SD back out for the console summary line.
    dmean_lookup = dict((k, m) for (k, m, s) in summary_metrics)
    dsd_lookup   = dict((k, s) for (k, m, s) in summary_metrics)
    dmean = dmean_lookup["f1_dice"]
    dsd   = dsd_lookup["f1_dice"]

    print(method_name + " " + display_tag + ": tiles=" + str(len(rows)) +
          " | Dice " + ("%.3f" % dmean) + " ± " + ("%.3f" % dsd) +
          " -> " + csv_path + " | " + ov_path)

# tile size sensitivity analysis (64 vs 48 etc)

os.makedirs(sens_dir, exist_ok=True)

def run_size(gt, pr, size):
    """Score *pr* against *gt* on tiles of the given edge length.

    No tile is dropped for being small (min_ratio is 0.0). When tiling yields
    nothing, a single whole-image measurement is returned instead.
    """
    tiles = patch_metrics_list(gt, pr, size=size, min_ratio=0.0)
    return tiles if tiles else [compute_metrics(gt, pr)]

# Tile-size sensitivity: recompute the tile metrics for every size listed in
# sens_sizes and write one CSV per method.

# Metrics summarised for each tile size (constant, so built once).
keys_interest = [
    "f1_dice",
    "iou_jaccard",
    "precision",
    "recall",
    "accuracy",
    "balanced_accuracy",
    "mcc",
    "specificity"
]

for pred_path in method_paths:
    method_name = os.path.splitext(os.path.basename(pred_path.strip()))[0]
    pr_full = read_mask(pred_path)

    # One (metric, tile_size, n_tiles, mean, sd) row per metric and tile size.
    summary = []
    for s in sens_sizes:
        rows_s = run_size(gt_full, pr_full, s)
        n = len(rows_s)
        for k in keys_interest:
            vals = np.array([r[k] for r in rows_s], dtype=float)
            mean_v = float(np.mean(vals))
            if n > 1:
                sd_v   = float(np.std(vals, ddof=1))
            else:
                sd_v   = 0.0
            summary.append((k, s, n, mean_v, sd_v))

    csv_path = os.path.join(sens_dir, "TileSize_Sensitivity [" + method_name + "].csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["metric","tile_size","n_tiles","mean","sd"])
        for row in summary:
            w.writerow(row)

    # Console summary: mean Dice at each tile size, in sens_sizes order.
    # Built from the whole list so any number of sizes works, not just two.
    dice_by_size = {}
    for (k, s2, n_tiles, m_val, sd_val) in summary:
        if k == "f1_dice":
            dice_by_size[s2] = m_val
    dice_trend = "->".join(("%.3f" % dice_by_size[size_now]) for size_now in sens_sizes)

    print("[SENS] " + method_name + ": Dice " +
          dice_trend +
          " -> " + csv_path)

# bootstrap confidence intervals per method

# Dedicated, seeded generator for the per-method bootstrap CIs; it is shared by
# every bootstrap_mean_ci call, so results depend on the order of those calls.
rng_ci = np.random.default_rng(RNG_SEED_CI)

def bootstrap_mean_ci(values, B=2000, alpha=0.05):
    """Percentile-bootstrap confidence interval for the mean of *values*.

    Args:
        values: sequence of numbers (one per tile).
        B: number of bootstrap resamples.
        alpha: two-sided error level; the interval spans the alpha/2 and
            1 - alpha/2 quantiles of the resampled means.

    Returns:
        (mean, ci_low, ci_high). A single observation is returned as a
        degenerate interval without resampling.
    """
    sample = np.asarray(values, dtype=float)
    count = len(sample)
    if count == 1:
        only = float(sample[0])
        return only, only, only

    # Draw from the module-level generator, one resample per iteration.
    boot_means = np.empty(B, dtype=float)
    for draw in range(B):
        picks = rng_ci.integers(0, count, size=count)
        boot_means[draw] = float(np.mean(sample[picks]))

    lower_q = alpha / 2.0
    upper_q = 1.0 - alpha / 2.0
    return (
        float(np.mean(sample)),
        float(np.quantile(boot_means, lower_q)),
        float(np.quantile(boot_means, upper_q)),
    )

# Optional export: one CSV per method with the bootstrap mean and 95% CI of
# every metric, computed over that method's per-tile rows.
if WRITE_BOOTSTRAP_CI:
    metric_keys_for_ci = [
        "f1_dice",
        "iou_jaccard",
        "precision",
        "recall",
        "specificity",
        "balanced_accuracy",
        "accuracy",
        "mcc"
    ]
    for method_name, rows_now in method_rows.items():
        out_rows = []
        for k in metric_keys_for_ci:
            vals = [r[k] for r in rows_now]
            # alpha=0.05 gives the 95% interval named in the CSV header.
            mean_v, lo, hi = bootstrap_mean_ci(vals, B=2000, alpha=0.05)
            out_rows.append((k, mean_v, lo, hi, len(vals)))
        csv_path = os.path.join(out_dir, "Bootstrap_CI [" + method_name + "].csv")
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["metric","mean","ci95_low","ci95_high","n_tiles"])
            for r in out_rows:
                w.writerow(r)
        print("[CI] " + method_name + ": wrote " + csv_path)

# paired bootstrap against the best Dice method, Holm corrected p-values

# Per-tile rows keyed by method (alias of method_rows, not a copy).
patch_bags = method_rows

# Mean tile-level Dice of every method.
mean_dice = {
    m: np.mean([r["f1_dice"] for r in rows_m])
    for m, rows_m in patch_bags.items()
}

# The method with the highest mean Dice is the reference for all comparisons.
ref = max(mean_dice, key=mean_dice.get)
print("Reference (highest mean Dice): " + ref)

# Dedicated, seeded generator for the paired bootstrap.
rng_pair = np.random.default_rng(RNG_SEED_PAIR)

def paired_bootstrap_diff_ci(a_vals, b_vals, B=5000, alpha=0.05):
    """Paired bootstrap of the mean difference a - b.

    The two sequences are paired by position; the longer one is truncated to
    the length of the shorter.

    Returns:
        (mean_diff, ci_low, ci_high, p_one, p_two) where p_one is the share of
        resampled mean differences that are <= 0 and p_two is
        2 * min(p_one, 1 - p_one), capped at 1. A single pair yields a
        degenerate interval with both p-values set to 1.0.
    """
    first = np.asarray(a_vals, dtype=float)
    second = np.asarray(b_vals, dtype=float)
    pairs = min(len(first), len(second))
    deltas = first[:pairs] - second[:pairs]

    if pairs == 1:
        only = float(deltas[0])
        return only, only, only, 1.0, 1.0

    # Resample tile indices so that each pair stays together.
    boot = np.empty(B, dtype=float)
    for draw in range(B):
        picks = rng_pair.integers(0, pairs, size=pairs)
        boot[draw] = float(np.mean(deltas[picks]))

    mean_diff = float(np.mean(deltas))
    lo = float(np.quantile(boot, alpha / 2.0))
    hi = float(np.quantile(boot, 1.0 - alpha / 2.0))

    p_one = float(np.mean(boot <= 0))
    p_two = float(2.0 * min(p_one, 1.0 - p_one))
    if p_two > 1.0:
        p_two = 1.0
    return mean_diff, lo, hi, p_one, p_two

# Compare every other method with the reference on tile-level Dice using the
# paired bootstrap, then Holm-adjust the two-sided p-values.
ref_vals = [r["f1_dice"] for r in patch_bags[ref]]
comp_rows = []

for m, rows_m in patch_bags.items():
    if m == ref:
        continue
    other_vals = [r["f1_dice"] for r in rows_m]
    md, lo, hi, p1, p2 = paired_bootstrap_diff_ci(ref_vals, other_vals, B=5000, alpha=0.05)
    comp_rows.append((ref, m, len(ref_vals), md, lo, hi, p1, p2))

# Holm step-down: the i-th smallest p-value is multiplied by (K - i + 1) and
# capped at 1. The running maximum keeps the adjusted values non-decreasing in
# rank; without it a larger raw p-value could receive a smaller adjusted one.
p_sorted = sorted(comp_rows, key=lambda x: x[-1])
K = len(p_sorted)
holm = {}
holm_running_max = 0.0
for i_counter, row in enumerate(p_sorted, start=1):
    mname = row[1]
    p2_now = row[-1]
    holm_val = (K - i_counter + 1) * p2_now
    if holm_val > 1.0:
        holm_val = 1.0
    if holm_val > holm_running_max:
        holm_running_max = holm_val
    holm[mname] = holm_running_max

csv_path = os.path.join(out_dir, "Bootstrap_Comparisons_vs_[" + ref + "].csv")
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow([
        "ref",
        "method",
        "n_tiles",
        "diff_mean_Dice",
        "ci95_low",
        "ci95_high",
        "p_one_sided(ref>method)",
        "p_two_sided",
        "p_holm_two_sided"
    ])
    # Rows are written in ascending order of the raw two-sided p-value.
    for row in p_sorted:
        ref_m, m, n, md, lo, hi, p1, p2 = row
        w.writerow([ref_m, m, n, md, lo, hi, p1, p2, holm[m]])

print("[COMPARE] reference=" + ref + " -> " + csv_path)

# porosity / pore fraction bias (does a method over-call pores?)

# Dedicated, seeded generator for the porosity-bias bootstrap (bootstrap_bias_ci).
rng_bias = np.random.default_rng(RNG_SEED_BIAS)

def pore_fraction(mask):
    """Return the mean of *mask* as a float (the share of 1-valued pixels in a 0/1 mask)."""
    share = np.mean(mask)
    return float(share)

def patch_pore_fractions(gt, pred, size=patch_size, min_ratio=min_area_ratio):
    """Pore fraction of every tile in the gold and predicted masks.

    Uses the same tiling rule as patch_metrics_list: border tiles are cropped
    and tiles smaller than ``size * size * min_ratio`` pixels are skipped.
    When no tile is kept, the whole-image fractions are returned instead.

    Returns:
        (gold_fractions, pred_fractions) as two float arrays of equal length.
    """
    n_rows, n_cols = gt.shape
    fewest_pixels = int((size*size)*min_ratio)

    paired = []
    for top in range(0, n_rows, size):
        bottom = min(top + size, n_rows)
        for left in range(0, n_cols, size):
            right = min(left + size, n_cols)
            gt_tile = gt[top:bottom, left:right]
            if gt_tile.size < fewest_pixels:
                continue
            pred_tile = pred[top:bottom, left:right]
            paired.append((pore_fraction(gt_tile), pore_fraction(pred_tile)))

    if not paired:
        paired = [(pore_fraction(gt), pore_fraction(pred))]

    gold_fracs = [gold for gold, _ in paired]
    pred_fracs = [guess for _, guess in paired]
    return np.array(gold_fracs, float), np.array(pred_fracs, float)

def bootstrap_bias_ci(g_list, p_list, B=2000, alpha=0.05):
    """Percentile-bootstrap CI for the mean pore-fraction bias (pred - gold).

    Args:
        g_list: gold-standard pore fractions, one per tile.
        p_list: predicted pore fractions, paired with *g_list* by position.
        B: number of bootstrap resamples.
        alpha: two-sided error level of the interval.

    Returns:
        (mean_bias, ci_low, ci_high). A single tile yields a degenerate
        interval without resampling.
    """
    # Coerce to arrays so plain lists work too (list - list raises TypeError).
    d_vals = np.asarray(p_list, dtype=float) - np.asarray(g_list, dtype=float)
    n_vals = len(d_vals)
    if n_vals == 1:
        only = float(d_vals[0])
        return only, only, only

    # Resample the per-tile differences with the module-level generator.
    means = np.empty(B, float)
    for b in range(B):
        idx = rng_bias.integers(0, n_vals, size=n_vals)
        means[b] = float(np.mean(d_vals[idx]))

    mean_bias = float(np.mean(d_vals))
    lo_b = float(np.quantile(means, alpha/2.0))
    hi_b = float(np.quantile(means, 1.0 - alpha/2.0))
    return mean_bias, lo_b, hi_b

# Porosity bias per method: compare tile-level pore fractions of the prediction
# with the gold standard and write a one-row CSV per method.
for pred_path in method_paths:
    method_name = os.path.splitext(os.path.basename(pred_path.strip()))[0]
    pr_full = read_mask(pred_path)

    g_list, p_list = patch_pore_fractions(gt_full, pr_full, size=patch_size, min_ratio=min_area_ratio)
    # Positive bias means the method marks more pore area than the gold mask.
    bias = p_list - g_list
    b_mean = float(np.mean(bias))
    if len(bias) > 1:
        b_sd   = float(np.std(bias, ddof=1))
    else:
        b_sd   = 0.0

    # The two branches differ only in whether the bootstrap CI columns are
    # included; boot_mean is the plain mean of the same differences as b_mean.
    if WRITE_BOOTSTRAP_CI:
        boot_mean, b_lo, b_hi = bootstrap_bias_ci(g_list, p_list, B=2000, alpha=0.05)
        cols = [
            "method",
            "n_tiles",
            "gold_pore_frac_mean",
            "pred_pore_frac_mean",
            "bias_mean(pred-gold)",
            "ci95_low",
            "ci95_high",
            "bias_sd"
        ]
        vals = [
            method_name,
            len(bias),
            float(np.mean(g_list)),
            float(np.mean(p_list)),
            boot_mean,
            b_lo,
            b_hi,
            b_sd
        ]
    else:
        cols = [
            "method",
            "n_tiles",
            "gold_pore_frac_mean",
            "pred_pore_frac_mean",
            "bias_mean(pred-gold)",
            "bias_sd"
        ]
        vals = [
            method_name,
            len(bias),
            float(np.mean(g_list)),
            float(np.mean(p_list)),
            b_mean,
            b_sd
        ]

    # Header row followed by a single data row.
    csv_path = os.path.join(out_dir, "Porosity_Bias [" + method_name + "].csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(cols)
        w.writerow(vals)

    print("[POROSITY] " + method_name + ": wrote " + csv_path)

# figure style and plotting setup

# Seeds NumPy's legacy global random state. The Generator objects created with
# np.random.default_rng in this file carry their own seeds and are unaffected.
np.random.seed(2025)

# Global matplotlib style applied to every figure created after this point.
mpl.rcParams.update({
    "font.size": 9,
    "axes.linewidth": 0.8,
    # no top/right frame lines
    "axes.spines.top": False,
    "axes.spines.right": False,
    # ticks point into the axes
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.major.width": 0.8,
    "ytick.major.width": 0.8,
    # font type 42 embeds TrueType fonts in PDF/PS output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "savefig.dpi": 600,
})

# Keywords that assign a method name to one of the three method families.
TRADITIONAL = {"FREEHAND", "OVAL"}
SEMI_AUTO   = {"SEMI", "SAMJ", "ILASTIK", "60%"}
FULL_AUTO   = {"PORED2", "UNET", "OTSU", "PLANKSTER"}

# Fill colours per family, plus a neutral colour for unmatched names.
C_TRAD  = "#9ecae1"
C_SEMI  = "#d0b7ff"
C_AUTO  = "#f7b6b6"
C_OTHER = "#dddddd"

def method_color_and_group(name):
    """Return (colour, family label) for a method name.

    Matching is a case-insensitive substring test against the keyword sets.
    Families are checked in the order Traditional, Semi-automated, Fully
    automated; the first match wins. Unmatched names get (C_OTHER, "Other").
    """
    upper_name = name.upper()
    families = (
        (TRADITIONAL, C_TRAD, "Traditional"),
        (SEMI_AUTO, C_SEMI, "Semi-automated"),
        (FULL_AUTO, C_AUTO, "Fully automated"),
    )
    for keywords, colour, label in families:
        if any(word in upper_name for word in keywords):
            return colour, label
    return C_OTHER, "Other"

# Legend entries shared by the figures: one filled patch per method family.
legend_handles = [
    Patch(facecolor=family_colour, edgecolor='k', label=family_label)
    for family_colour, family_label in (
        (C_TRAD, 'Traditional'),
        (C_SEMI, 'Semi-automated'),
        (C_AUTO, 'Fully automated'),
    )
]

# Display names for the metric keys produced by compute_metrics.
NICE = dict(
    f1_dice="Dice (F1)",
    iou_jaccard="IoU (Jaccard)",
    mcc="Matthews CC",
    precision="Precision",
    recall="Recall (Sensitivity)",
    specificity="Specificity",
    balanced_accuracy="Balanced Accuracy",
    accuracy="Accuracy",
)

# Metric selections per figure type; the heat-map list repeats the full list
# and the box-plot list is its first three entries.
METRICS_ALL = [
    "f1_dice",
    "iou_jaccard",
    "mcc",
    "precision",
    "recall",
    "specificity",
    "balanced_accuracy",
    "accuracy",
]
BOX_METRICS = METRICS_ALL[:3]
HEAT_METRICS = list(METRICS_ALL)

def sanitize(s):
    """Make *s* usable as a file name.

    Each run of the characters < > : " / \\ | ? * becomes one underscore, each
    run of whitespace becomes one space, and the result is stripped.
    """
    without_reserved = re.sub(r'[<>:"/\\|?*]+', "_", s)
    return re.sub(r"\s+", " ", without_reserved).strip()

def save_fig(fig, name_base, dirpath=fig_dir):
    """Save *fig* as both TIFF (600 dpi) and PDF, then close it.

    The file stem is the sanitized *name_base*; both files are written into
    *dirpath* and their paths are printed afterwards.
    """
    stem = sanitize(name_base)
    tif_path = os.path.join(dirpath, stem + ".tif")
    pdf_path = os.path.join(dirpath, stem + ".pdf")

    fig.tight_layout()
    fig.savefig(tif_path, dpi=600)
    fig.savefig(pdf_path)
    plt.close(fig)

    for written in (tif_path, pdf_path):
        print("Saved: " + written)

# aggregate summaries for plotting and stats

# summary_rows: one dict per (method, metric) with the mean, sample SD, tile
# count and a normal-approximation 95% interval (mean ± 1.96 * sd / sqrt(n)).
summary_rows = []
for mname, rows_here in method_rows.items():
    for metric_name in METRICS_ALL:
        vals = np.array([r[metric_name] for r in rows_here], float)
        mean_val = float(np.mean(vals))
        if len(vals) > 1:
            sd_val   = float(np.std(vals, ddof=1))
        else:
            sd_val   = 0.0
        n_val    = int(len(vals))
        if n_val >= 2:
            half = 1.96 * sd_val / np.sqrt(n_val)
        else:
            half = 0.0
        summary_rows.append({
            "method": mname,
            "metric": metric_name,
            "mean": mean_val,
            "sd": sd_val,
            "n": n_val,
            # With fewer than two tiles no interval is reported (NaN).
            "lo": mean_val - half if n_val >= 2 else np.nan,
            "hi": mean_val + half if n_val >= 2 else np.nan
        })

# perrep_rows: one dict per (method, tile). "replicate" is the 1-based tile
# position as a string, so the same tile carries the same id in every method.
perrep_rows = []
for mname, rows_here in method_rows.items():
    counter_i = 1
    for r_dict in rows_here:
        row_out = {"method": mname, "replicate": str(counter_i)}
        for k_metric in METRICS_ALL:
            row_out[k_metric] = r_dict[k_metric]
        perrep_rows.append(row_out)
        counter_i = counter_i + 1

def table_for_metric(summary_rows_list, metric_name):
    """Return the summary rows of *metric_name* whose mean is not NaN, in input order."""
    selected = []
    for entry in summary_rows_list:
        if entry["metric"] != metric_name:
            continue
        if np.isnan(entry["mean"]):
            continue
        selected.append(entry)
    return selected

def ci_or_fallback_lo_hi(mean_val, lo_val, hi_val, sd_val, n_val):
    """Return interval bounds, preferring the supplied ones.

    When both *lo_val* and *hi_val* are non-NaN they are returned unchanged.
    Otherwise a normal-approximation interval (mean ± 1.96 * sd / sqrt(n)) is
    built if *sd_val* is usable and *n_val* is at least 2; failing that,
    (nan, nan) is returned.
    """
    if not (np.isnan(lo_val) or np.isnan(hi_val)):
        return lo_val, hi_val

    if sd_val is None or np.isnan(sd_val):
        return np.nan, np.nan
    if n_val is None or not (n_val >= 2):
        return np.nan, np.nan

    half_width = 1.96 * sd_val / np.sqrt(n_val)
    return mean_val - half_width, mean_val + half_width

# helpers for stats and significance

def build_series_by_replicate(metric_name):
    """Collect the values of one metric as {method: {replicate: value}}.

    Reads the module-level perrep_rows. Rows whose value is NaN, or whose
    replicate id is None or empty, are left out.
    """
    by_method = {}
    for record in perrep_rows:
        method = record["method"]
        replicate = record["replicate"]
        value = record[metric_name]
        unusable = np.isnan(value) or replicate is None or replicate == ""
        if unusable:
            continue
        by_method.setdefault(method, {})[replicate] = float(value)
    return by_method

def holm_correction(pairs_pvals):
    """Holm step-down adjustment of a family of p-values.

    Args:
        pairs_pvals: iterable of (key, p_value) items; the key is any hashable
            label (here a pair of method names).

    Returns:
        dict mapping each key to its adjusted p-value. The i-th smallest raw
        p-value (i starting at 1) is multiplied by (m - i + 1), capped at 1.0,
        and then raised to the largest adjusted value seen so far, so the
        adjusted p-values never decrease as the raw p-values increase.
    """
    ranked = sorted(pairs_pvals, key=lambda x: x[1])
    m_len = len(ranked)
    adj = {}
    running_max = 0.0
    for rank, pair_and_p in enumerate(ranked):
        pair_now = pair_and_p[0]
        p_now = pair_and_p[1]
        adj_val = p_now * (m_len - rank)
        if adj_val > 1.0:
            adj_val = 1.0
        # Enforce monotonicity: without this, a hypothesis with a larger raw
        # p-value could end up with a smaller adjusted p-value.
        if adj_val > running_max:
            running_max = adj_val
        adj[pair_now] = running_max
    return adj

def p_to_stars(p_val):
    """Map a p-value to significance stars.

    Returns "***" below 0.001, "**" below 0.01, "*" below 0.05, and an empty
    string otherwise.
    """
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_val < cutoff:
            return stars
    return ""

def wilcoxon_exact_or_pratt(a, b):
    """Two-sided Wilcoxon signed-rank test with a choice of p-value method.

    Returns:
        (statistic, p_value, n_eff, mode) where n_eff is the number of
        non-zero paired differences and mode is:
        - "NA" when fewer than two differences are non-zero (statistic and
          p-value are NaN);
        - "exact" when no difference is zero and n_eff <= 25;
        - "approx(pratt)" otherwise (normal approximation with continuity
          correction, zeros handled by Pratt's method).
    """
    x = np.asarray(a, float)
    y = np.asarray(b, float)
    nonzero = (x - y) != 0
    n_eff = int(np.count_nonzero(nonzero))

    if n_eff < 2:
        return (np.nan, np.nan, n_eff, "NA")

    use_exact = bool(np.all(nonzero)) and n_eff <= 25
    if use_exact:
        label = "exact"
        options = dict(zero_method='wilcox', correction=False, method='exact')
    else:
        label = "approx(pratt)"
        options = dict(zero_method='pratt', correction=True, method='approx')

    outcome = wilcoxon(x, y, alternative='two-sided', **options)
    return (float(outcome.statistic), float(outcome.pvalue), n_eff, label)

def sign_test_two_sided(a, b):
    """Two-sided sign test on the paired differences a - b.

    Zero differences are ignored.

    Returns:
        (p_value, n, n_positive) where n counts the non-zero differences.
        When every difference is zero the result is (nan, 0, 0).
    """
    diff = np.asarray(a, float) - np.asarray(b, float)
    wins = int(np.count_nonzero(diff > 0))
    losses = int(np.count_nonzero(diff < 0))
    informative = wins + losses
    if informative == 0:
        return (np.nan, 0, 0)

    outcome = binomtest(k=wins, n=informative, p=0.5, alternative='two-sided')
    return (float(outcome.pvalue), informative, wins)

def hodges_lehmann(a, b):
    """Hodges–Lehmann estimate of the paired location shift between a and b.

    For paired data the estimator is the median of the Walsh averages
    (d_i + d_j) / 2 over all i <= j, where d = a - b. This is the location
    estimate associated with the Wilcoxon signed-rank test. It is not the
    plain median of the differences, which the caller computes separately.

    Returns:
        The estimate as a float, or NaN when there are no pairs.
    """
    d = (np.asarray(a, float) - np.asarray(b, float)).ravel()
    if d.size == 0:
        return float("nan")
    # Index pairs of the upper triangle including the diagonal, i.e. all i <= j.
    i_idx, j_idx = np.triu_indices(d.size)
    walsh_averages = 0.5 * (d[i_idx] + d[j_idx])
    return float(np.median(walsh_averages))

def compute_ref_and_holm(metric_name):
    """Run all pairwise paired tests for one metric and Holm-adjust them.

    The reference method is the one with the highest mean in summary_rows.
    Every unordered pair of methods is compared on the replicates they share
    (at least two are required) with the Wilcoxon signed-rank test and the
    sign test. Each family of p-values is Holm-adjusted over all tested
    pairs, not only over the pairs that involve the reference.

    Returns:
        (ref_method, holm_wilcoxon, holm_sign, wilcoxon_cache, sign_cache).
        The dicts are keyed by (method_a, method_b) with names in sorted
        order. wilcoxon_cache values are (n_eff, W, p, mode, HL, median_diff)
        and sign_cache values are (n, n_positive, p). When the metric has no
        summary rows the result is (None, {}, {}, {}, {}).
    """
    tbl = table_for_metric(summary_rows, metric_name)
    if not tbl:
        return None, {}, {}, {}, {}

    best_entry = max(tbl, key=lambda entry: entry["mean"])
    series = build_series_by_replicate(metric_name)

    wilcoxon_pvals, sign_pvals = [], []
    wilcoxon_details, sign_details = {}, {}

    for pair in itertools.combinations(sorted(series.keys()), 2):
        first, second = pair
        reps_first = series.get(first, {})
        reps_second = series.get(second, {})

        # Only replicates present for both methods can be paired.
        shared = sorted(set(reps_first.keys()) & set(reps_second.keys()))
        if len(shared) < 2:
            continue
        x = np.array([reps_first[rep] for rep in shared], float)
        y = np.array([reps_second[rep] for rep in shared], float)

        stat, p_wx, n_eff, mode = wilcoxon_exact_or_pratt(x, y)
        p_sign, n_sign, n_pos = sign_test_two_sided(x, y)
        shift = hodges_lehmann(x, y)
        median_diff = float(np.median(x - y))

        # Undefined p-values are cached but kept out of the Holm families.
        if not np.isnan(p_wx):
            wilcoxon_pvals.append((pair, p_wx))
        if not np.isnan(p_sign):
            sign_pvals.append((pair, p_sign))

        wilcoxon_details[pair] = (n_eff, stat, p_wx, mode, shift, median_diff)
        sign_details[pair] = (n_sign, n_pos, p_sign)

    holm_wilcoxon = holm_correction(wilcoxon_pvals) if wilcoxon_pvals else {}
    holm_sign = holm_correction(sign_pvals) if sign_pvals else {}
    return best_entry["method"], holm_wilcoxon, holm_sign, wilcoxon_details, sign_details

def lookup_pair(dct, a_name, b_name, default=np.nan):
    """Look up a value stored under an unordered pair of names.

    Tries the key (a_name, b_name) first, then (b_name, a_name), and returns
    *default* when neither is present.
    """
    for key in ((a_name, b_name), (b_name, a_name)):
        if key in dct:
            return dct[key]
    return default

# bar charts

# Bar charts: one figure per metric, methods ordered by descending mean, with
# 95% CI error bars and a star above methods that differ from the reference.

import warnings
# Silences every RuntimeWarning from here on, for the rest of the process.
warnings.filterwarnings("ignore", category=RuntimeWarning)

for metric_name in METRICS_ALL:
    tbl = table_for_metric(summary_rows, metric_name)
    if not tbl:
        continue

    # The sign-test and cache results are returned but only the Holm-adjusted
    # p-values are used below.
    ref_method, p_holm_wx, p_holm_sign, wilcox_cache, sign_cache = compute_ref_and_holm(metric_name)

    tbl.sort(key=lambda d: d["mean"], reverse=True)
    methods_plot = [t["method"] for t in tbl]
    means_plot   = np.array([t["mean"] for t in tbl], float)

    # Interval bounds per method (stored bounds, or the SD-based fallback).
    los = np.zeros_like(means_plot)
    his = np.zeros_like(means_plot)
    idx_i = 0
    for t in tbl:
        lo_val, hi_val = ci_or_fallback_lo_hi(t["mean"], t["lo"], t["hi"], t["sd"], t["n"])
        los[idx_i] = lo_val
        his[idx_i] = hi_val
        idx_i = idx_i + 1

    # matplotlib wants error-bar lengths, not bounds; missing bounds give 0.
    lower = np.where(np.isnan(los), 0.0, means_plot - los)
    upper = np.where(np.isnan(his), 0.0, his - means_plot)
    yerr  = np.vstack((lower, upper))

    bar_colors = []
    for m in methods_plot:
        col, group = method_color_and_group(m)
        bar_colors.append(col)

    fig, ax = plt.subplots(figsize=(7.5, 4.2))
    x_vals = np.arange(len(methods_plot))
    ax.bar(x_vals, means_plot, color=bar_colors, edgecolor="black", linewidth=0.8, zorder=2)
    ax.errorbar(x_vals, means_plot, yerr=yerr, fmt='none', ecolor='black', elinewidth=1.0, capsize=3, zorder=3)

    ax.set_xticks(x_vals)
    ax.set_xticklabels(methods_plot, rotation=45, ha='right')
    ax.set_title(display_tag + " [" + NICE.get(metric_name, metric_name) + "]")
    ax.set_ylabel(NICE.get(metric_name, metric_name) + " (mean ± 95% CI)")

    # Upper axis limit: at least 1.0, at most 1.15, with headroom for stars.
    # NOTE(review): the lower limit is fixed at 0, so a negative mean (MCC can
    # be negative) would be drawn below the visible range — confirm intended.
    y_tops = means_plot + np.where(np.isnan(upper), 0.0, upper)
    ylim_top = float(np.nanmax(y_tops) + 0.06)
    ax.set_ylim(0, min(1.15, max(1.0, ylim_top)))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.grid(axis='y', color=str(0.9), linestyle='-', linewidth=0.5, zorder=1)

    # Star every method whose Holm-adjusted p-value against the reference is
    # below 0.05; the sign-test value is used only when no Wilcoxon value exists.
    if ref_method is not None:
        for i_plot, m in enumerate(methods_plot):
            if m == ref_method:
                continue
            p_h = lookup_pair(p_holm_wx, ref_method, m, np.nan)
            if np.isnan(p_h):
                p_h = lookup_pair(p_holm_sign, ref_method, m, np.nan)
            cond_star = (not np.isnan(p_h)) and (p_h < 0.05)
            if cond_star:
                ax.text(
                    x_vals[i_plot],
                    y_tops[i_plot] + 0.02,
                    "*",
                    ha='center',
                    va='bottom',
                    fontsize=14,
                    fontweight='bold',
                    clip_on=False
                )

    # Opaque, framed legend listing the three method families.
    leg = ax.legend(handles=legend_handles, loc='upper right', frameon=True)
    leg_frame = leg.get_frame()
    leg_frame.set_edgecolor('black')
    leg_frame.set_linewidth(0.8)
    leg_frame.set_alpha(1.0)
    leg_frame.set_facecolor('white')

    save_fig(fig, display_tag + " - Bar " + NICE.get(metric_name, metric_name))

# boxplots

def build_series(metric_name):
    """Group the non-NaN per-replicate values of one metric by method.

    Reads the module-level ``perrep_rows``; every row is indexed by
    ``"method"`` and by ``metric_name``.

    Returns a dict mapping method name -> 1-D float array.  Keys are
    inserted in sorted method order, values keep the row order of
    ``perrep_rows``, and methods without any non-NaN value are omitted.
    """
    per_method = {}
    for row in perrep_rows:
        bucket = per_method.setdefault(row["method"], [])
        value = row[metric_name]
        if not np.isnan(value):
            bucket.append(value)

    # sorted() fixes the key order; empty buckets (all-NaN methods) are dropped
    return {
        name: np.array(per_method[name], float)
        for name in sorted(per_method)
        if per_method[name]
    }

# One figure per metric in BOX_METRICS: boxes sorted by descending median,
# with every individual value overlaid as a jittered open circle.
for metric_name in BOX_METRICS:
    data_box = build_series(metric_name)
    if not data_box:
        # no method has a usable value for this metric
        continue

    # highest median first; sorted() is stable, so ties keep the dict order
    order_box = sorted(data_box, key=lambda name: np.median(data_box[name]), reverse=True)
    vectors = [data_box[name] for name in order_box]
    # method_color_and_group yields (colour, group); only the colour is used here
    box_colors = [method_color_and_group(name)[0] for name in order_box]

    fig, ax = plt.subplots(figsize=(7.5, 4.2))
    box_options = dict(
        vert=True,
        patch_artist=True,
        tick_labels=order_box,
        whis=1.5,
        widths=0.65,
        showmeans=False,
        manage_ticks=True,
    )
    bp = ax.boxplot(vectors, **box_options)

    # the fill carries the method colour; every outline is plain black
    for box_patch, fill in zip(bp['boxes'], box_colors):
        box_patch.set(facecolor=fill, edgecolor="black", linewidth=0.8)
    for part_name, line_width in (('whiskers', 0.8), ('caps', 0.8), ('medians', 1.2)):
        for line_obj in bp[part_name]:
            line_obj.set(color="black", linewidth=line_width)
    for flier_obj in bp.get('fliers', []):
        flier_obj.set(marker='o', markerfacecolor='none', markeredgecolor='black',
                      markersize=3, alpha=0.7)

    # fixed seed so the horizontal jitter is identical on every run
    rng = np.random.default_rng(2025)
    for position, values in enumerate(vectors, start=1):
        if len(values) == 0:
            continue
        jitter = (rng.random(size=len(values)) - 0.5) * 0.25
        ax.plot(
            position + jitter,
            values,
            'o',
            markerfacecolor='none',
            markeredgecolor='black',
            markersize=3,
            alpha=0.6,
            linewidth=0.8
        )

    # the axis always reaches 1.0, grows with the data, and is capped at 1.15
    y_top = max([1.0] + [np.max(values) for values in vectors if len(values) > 0])
    ax.set_ylim(0, min(1.15, y_top + 0.05))

    metric_label = NICE.get(metric_name, metric_name)
    ax.set_ylabel(metric_label)
    ax.set_title(display_tag + " [" + metric_label + "]")
    ax.set_xticklabels(order_box, rotation=45, ha='right')
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.grid(axis='y', color='0.92', linestyle='-', linewidth=0.5)

    leg = ax.legend(handles=legend_handles, loc='upper right', frameon=True)
    box_legend_frame = leg.get_frame()
    box_legend_frame.set_edgecolor('black')
    box_legend_frame.set_linewidth(0.8)
    box_legend_frame.set_alpha(1.0)
    box_legend_frame.set_facecolor('white')

    save_fig(fig, display_tag + " - Boxplot " + metric_label)

# heatmap of mean metric per method across metrics

# Rows are methods (best first), columns are HEAT_METRICS, cells are the
# summary mean of that metric for that method (NaN when missing).

# First summary row seen for a (method, metric) pair wins.
mean_lookup = {}
for r in summary_rows:
    mean_lookup.setdefault((r["method"], r["metric"]), r["mean"])

methods_all = sorted({r["method"] for r in summary_rows})
M = np.array(
    [
        [mean_lookup.get((m_name, h_metric), np.nan) for h_metric in HEAT_METRICS]
        for m_name in methods_all
    ],
    float,
)

# Methods are ranked by the average of these three columns.
primary_idx = [
    HEAT_METRICS.index("f1_dice"),
    HEAT_METRICS.index("iou_jaccard"),
    HEAT_METRICS.index("mcc")
]

# Negating the score makes argsort descending; NaN scores sort last.
order_idx = np.argsort(-np.nanmean(M[:, primary_idx], axis=1))
methods_sorted = [methods_all[i_now] for i_now in order_idx]
M_sorted = M[order_idx]

fig, ax = plt.subplots(figsize=(6.6, 0.45*len(methods_sorted) + 1.2))
im = ax.imshow(M_sorted, aspect="auto", cmap="Greys", vmin=0, vmax=1.0)

for i_row in range(M_sorted.shape[0]):
    for j_col in range(M_sorted.shape[1]):
        v_val = M_sorted[i_row, j_col]
        if np.isnan(v_val):
            continue
        # "Greys" maps values near 1.0 to near-black, where black text is
        # unreadable.  Pick the label colour from the luminance of the colour
        # the cell is actually drawn in, so light cells get black text and
        # dark cells get white text.
        cell_r, cell_g, cell_b, _ = im.cmap(float(im.norm(v_val)))
        cell_luminance = 0.299*cell_r + 0.587*cell_g + 0.114*cell_b
        if cell_luminance < 0.5:
            text_color = 'white'
        else:
            text_color = 'black'
        ax.text(j_col, i_row, "%.2f" % v_val,
                ha='center', va='center', color=text_color)

ax.set_xticks(np.arange(len(HEAT_METRICS)))
ax.set_xticklabels([NICE.get(m, m) for m in HEAT_METRICS], rotation=45, ha='right')
ax.set_yticks(np.arange(len(methods_sorted)))
ax.set_yticklabels(methods_sorted)
ax.set_title(display_tag + " — Mean metrics (methods × metrics)")
cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.03)
cbar.set_label("Mean (0–1)")

save_fig(fig, display_tag + " - Heatmap mean metrics")

# write statistical test CSVs

def write_stats_csvs():
    """Export the Kruskal–Wallis, all-pairs and vs-reference tables as CSV.

    Three files are written into ``stats_dir``; each file name carries
    ``display_tag``:

    * ``KruskalWallis``  - one row per metric that has at least two methods
      with >= 2 replicates each.
    * ``Wilcoxon_Pairs`` - one row per unordered method pair sharing >= 2
      replicates: Wilcoxon and sign-test results, their ``holm_correction``
      outputs, the Hodges–Lehmann estimate and the median of A - B.
    * ``Wilcoxon_vsRef`` - one row per non-reference method, built from the
      reference and caches returned by ``compute_ref_and_holm``.

    ``holm_correction`` is applied per metric, separately for the Wilcoxon
    and the sign test, to the pairs whose raw p-value is not NaN.

    In the vs-reference file ``n_common`` is the number of replicates the
    reference and the method share, matching the meaning of that column in
    the pairwise file; the Wilcoxon effective n has its own column.

    Returns None.  Prints the three paths after the files are closed.
    """
    kw_path = os.path.join(stats_dir, "KruskalWallis [" + display_tag + "].csv")
    wx_path = os.path.join(stats_dir, "Wilcoxon_Pairs [" + display_tag + "].csv")
    vr_path = os.path.join(stats_dir, "Wilcoxon_vsRef [" + display_tag + "].csv")

    def pair_value(table, first, second):
        """Return the entry stored for the pair in either order, else NaN."""
        if (first, second) in table:
            return table[(first, second)]
        if (second, first) in table:
            return table[(second, first)]
        return np.nan

    with open(kw_path, "w", newline="", encoding="utf-8") as fkw, \
         open(wx_path, "w", newline="", encoding="utf-8") as fwx, \
         open(vr_path, "w", newline="", encoding="utf-8") as fvr:

        kw_writer = csv.writer(fkw)
        kw_writer.writerow(["metric","k_groups","H","p_value"])

        wx_writer = csv.writer(fwx)
        wx_writer.writerow([
            "metric","method_A","method_B","n_common",
            "wilcoxon_n_eff","wilcoxon_W","wilcoxon_p_raw","wilcoxon_mode","wilcoxon_p_holm",
            "sign_n","sign_pos","sign_p_raw","sign_p_holm",
            "HL_estimate(A-B)","median_diff(A-B)"
        ])

        vr_writer = csv.writer(fvr)
        vr_writer.writerow([
            "metric","reference","method","n_common",
            "wilcoxon_n_eff","wilcoxon_W","wilcoxon_p_raw","wilcoxon_mode","wilcoxon_p_holm","stars",
            "sign_n","sign_pos","sign_p_raw","sign_p_holm",
            "HL_estimate(ref - method)","median_diff(ref - method)"
        ])

        for metric_name in METRICS_ALL:
            # method -> {replicate: value}; built once and shared by all three tables
            ser = build_series_by_replicate(metric_name)
            methods_local = sorted(ser.keys())

            # --- Kruskal–Wallis across methods with at least two replicates ---
            groups = [
                np.array(list(d.values()), float)
                for d in ser.values()
                if len(d) >= 2
            ]
            if len(groups) >= 2:
                H, p_val = kruskal(*groups, nan_policy='omit')
                kw_writer.writerow([metric_name, len(groups), float(H), float(p_val)])

            # --- every unordered pair of methods ---
            wx_p_raw_pairs = []
            sign_p_raw_pairs = []
            temp_rows = []

            for a_name, b_name in itertools.combinations(methods_local, 2):
                ra = ser[a_name]
                rb = ser[b_name]
                common = sorted(set(ra.keys()) & set(rb.keys()))
                if len(common) < 2:
                    continue
                a_vals = np.array([ra[r_now] for r_now in common], float)
                b_vals = np.array([rb[r_now] for r_now in common], float)

                W, p_wx, n_eff, mode = wilcoxon_exact_or_pratt(a_vals, b_vals)
                p_sign, n_sign, pos = sign_test_two_sided(a_vals, b_vals)
                HL = hodges_lehmann(a_vals, b_vals)
                med = float(np.median(a_vals - b_vals))

                # only defined p-values enter the multiplicity correction
                if not np.isnan(p_wx):
                    wx_p_raw_pairs.append(((a_name, b_name), p_wx))
                if not np.isnan(p_sign):
                    sign_p_raw_pairs.append(((a_name, b_name), p_sign))

                temp_rows.append((
                    a_name, b_name, len(common),
                    n_eff, W, p_wx, mode,
                    n_sign, pos, p_sign,
                    HL, med
                ))

            holm_wx = holm_correction(wx_p_raw_pairs) if wx_p_raw_pairs else {}
            holm_sign = holm_correction(sign_p_raw_pairs) if sign_p_raw_pairs else {}

            for (a_name, b_name, n_common,
                 n_eff, W, p_wx, mode,
                 n_sign, pos, p_sign,
                 HL, med) in temp_rows:
                wx_writer.writerow([
                    metric_name,
                    a_name,
                    b_name,
                    n_common,
                    n_eff,
                    W,
                    p_wx,
                    mode,
                    pair_value(holm_wx, a_name, b_name),
                    n_sign,
                    pos,
                    p_sign,
                    pair_value(holm_sign, a_name, b_name),
                    HL,
                    med
                ])

            # --- each method against the reference method ---
            ref_m, p_holm_wx_m, p_holm_sign_m, wilcox_cache_m, sign_cache_m = compute_ref_and_holm(metric_name)
            if ref_m is None:
                continue

            ref_replicates = set(ser.get(ref_m, {}))
            for m_local in methods_local:
                if m_local == ref_m:
                    continue

                # cache entries are read as (n_eff, W, p_raw, mode, HL, median diff)
                pair_stats = lookup_pair(
                    wilcox_cache_m,
                    ref_m,
                    m_local,
                    (np.nan, np.nan, np.nan, np.nan, np.nan, np.nan)
                )
                n_eff_ref, W_ref, p_wx_raw, mode_ref, HL_ref, med_ref = pair_stats[:6]
                p_wx_holm = lookup_pair(p_holm_wx_m, ref_m, m_local, np.nan)

                # sign-test cache entries are read as (n, positives, p_raw)
                sstats = lookup_pair(
                    sign_cache_m,
                    ref_m,
                    m_local,
                    (np.nan, np.nan, np.nan)
                )
                n_sign_ref, pos_ref, p_sign_raw = sstats[:3]
                p_sign_holm = lookup_pair(p_holm_sign_m, ref_m, m_local, np.nan)

                # stars follow the Wilcoxon Holm p; the sign test is the fallback
                p_for_star = p_wx_holm
                if np.isnan(p_for_star):
                    p_for_star = p_sign_holm
                if (not np.isnan(p_for_star)) and (p_for_star < 0.05):
                    stars_now = p_to_stars(p_for_star)
                else:
                    stars_now = ""

                # count the shared replicates directly instead of echoing n_eff,
                # so this column means the same thing as in the pairwise file
                n_common_ref = len(ref_replicates & set(ser[m_local]))

                vr_writer.writerow([
                    metric_name,
                    ref_m,
                    m_local,
                    n_common_ref,
                    n_eff_ref,
                    W_ref,
                    p_wx_raw,
                    mode_ref,
                    p_wx_holm,
                    stars_now,
                    n_sign_ref,
                    pos_ref,
                    p_sign_raw,
                    p_sign_holm,
                    HL_ref,
                    med_ref
                ])

    print("Saved: " + kw_path)
    print("Saved: " + wx_path)
    print("Saved: " + vr_path)

write_stats_csvs()

# Bland–Altman plots

def bland_altman_plot(a_vals, b_vals, title, outfile_base, dirpath=ba_dir):
    """Draw a Bland–Altman plot for two paired series and save it.

    Parameters
    ----------
    a_vals, b_vals : array-like of float
        Paired measurements; differences are computed as ``a_vals - b_vals``.
    title : str
        Axes title.
    outfile_base : str
        Passed to ``save_fig`` as the figure name.
    dirpath : str
        Passed to ``save_fig`` as ``dirpath``.

    The limits of agreement are ``bias ± q * SD`` with ``q`` the 97.5 %
    Student-t quantile on ``n - 1`` degrees of freedom (1.96 when n <= 1)
    and ``SD`` the sample standard deviation of the differences (0.0 when
    n <= 1).  Bands for the bias and for each limit are drawn only when
    their standard error is finite.

    Returns None.
    """
    first = np.asarray(a_vals, float)
    second = np.asarray(b_vals, float)
    pair_means = (first + second) / 2.0
    pair_diffs = first - second
    count = len(pair_diffs)

    bias = float(np.mean(pair_diffs))
    if count > 1:
        spread = float(np.std(pair_diffs, ddof=1))
        quantile = float(student_t.ppf(0.975, count - 1))
        loa_se = spread * np.sqrt((1.0/float(count)) + (quantile*quantile)/(2.0*(count-1)))
    else:
        # a single pair has no spread; the LoA standard error is undefined
        spread, quantile, loa_se = 0.0, 1.96, np.nan
    bias_se = spread / np.sqrt(count) if count > 0 else np.nan

    upper_loa = bias + quantile * spread
    lower_loa = bias - quantile * spread

    def ci_band(center, std_err):
        """Return (low, high) = center -/+ quantile * std_err, or NaNs."""
        if np.isfinite(std_err):
            return center - quantile*std_err, center + quantile*std_err
        return np.nan, np.nan

    bias_band = ci_band(bias, bias_se)
    upper_band = ci_band(upper_loa, loa_se)
    lower_band = ci_band(lower_loa, loa_se)

    fig, ax = plt.subplots(figsize=(5.0, 3.6))

    # shaded confidence bands go in first so they sit underneath the points
    band_styles = (
        (bias_band, 'blue', 0.12),
        (upper_band, 'red', 0.15),
        (lower_band, 'red', 0.15),
    )
    for (band_lo, band_hi), shade, opacity in band_styles:
        if np.isfinite(band_lo) and np.isfinite(band_hi):
            ax.axhspan(band_lo, band_hi, alpha=opacity, color=shade, linewidth=0, zorder=1)

    ax.plot(pair_means, pair_diffs, 'o',
            markerfacecolor='none',
            markeredgecolor='black',
            markersize=4,
            alpha=0.85,
            zorder=2)

    ax.axhline(bias, linestyle='-', color='blue', linewidth=1.1, label='Bias', zorder=3)
    ax.axhline(upper_loa, linestyle='--', color='red', linewidth=1.0, label='LoA', zorder=3)
    ax.axhline(lower_loa, linestyle='--', color='red', linewidth=1.0, zorder=3)

    ax.set_xlabel("Mean of pair")
    ax.set_ylabel("Difference (A − B)")
    ax.set_title(title)
    ax.grid(axis='both', color='0.92', linestyle='-', linewidth=0.5)

    # the legend uses explicit proxy handles rather than the axhline labels
    proxy_handles = [
        Line2D([0], [0], color='blue', lw=1.2, label='Bias (±95% CI)'),
        Line2D([0], [0], color='red',  lw=1.0, linestyle='--', label='LoA (±95% CI)')
    ]
    leg = ax.legend(handles=proxy_handles, loc='lower right', frameon=True)
    legend_frame = leg.get_frame()
    legend_frame.set_edgecolor('black')
    legend_frame.set_linewidth(0.8)

    def describe(label, value, band):
        """Format 'label=value', adding the CI when its lower bound is finite."""
        text = f"{label}{value:.3f}"
        if np.isfinite(band[0]):
            text += f" (95% CI {band[0]:.3f} to {band[1]:.3f})"
        return text

    summary_text = "\n".join([
        describe("Bias=", bias, bias_band),
        describe("LoA+=", upper_loa, upper_band),
        describe("LoA−=", lower_loa, lower_band),
    ])
    ax.text(
        0.02,
        0.98,
        summary_text,
        transform=ax.transAxes,
        ha='left',
        va='top',
        fontsize=8,
        bbox=dict(boxstyle="round,pad=0.25", facecolor="white", edgecolor="black", linewidth=0.6)
    )

    save_fig(fig, outfile_base, dirpath=dirpath)

# Master switch for the Bland–Altman export below.
ENABLE_BLAND_ALTMAN = True
# NOTE(review): this flag is not read anywhere in the visible code; the loop
# below always plots reference-vs-method pairs only.
BA_PAIRS_VS_REF_ONLY = True
# Minimum number of replicates two methods must share to get a plot.
BA_MIN_COMMON = 2

if ENABLE_BLAND_ALTMAN:
    for metric_name in METRICS_ALL:
        ser = build_series_by_replicate(metric_name)
        methods_local = sorted(ser.keys())
        if len(methods_local) < 2:
            continue
        tbl = table_for_metric(summary_rows, metric_name)
        if not tbl:
            continue

        # The reference is the method with the highest mean.  NaN compares
        # False against everything, so a NaN mean would make max() depend on
        # row order; rank only the rows with a defined mean.
        ranked = [d for d in tbl if not np.isnan(d["mean"])]
        if not ranked:
            continue
        ref_ba = max(ranked, key=lambda d: d["mean"])["method"]

        # the reference series is the same for every comparison of this metric
        ra = ser.get(ref_ba, {})
        metric_label = NICE.get(metric_name, metric_name)

        for m in methods_local:
            if m == ref_ba:
                continue
            rb = ser.get(m, {})
            common = sorted(set(ra.keys()) & set(rb.keys()))
            if len(common) < BA_MIN_COMMON:
                continue
            a_vals = [ra[r_now] for r_now in common]
            b_vals = [rb[r_now] for r_now in common]

            title_str = display_tag + " — Bland–Altman [" + metric_label + "]\n" + "A=" + ref_ba + " vs B=" + m + " (n=" + str(len(common)) + ")"
            base_str  = display_tag + " - Bland-Altman [" + metric_label + "] A=" + ref_ba + " vs B=" + m

            bland_altman_plot(a_vals, b_vals, title_str, base_str)

    print("Bland–Altman: EXPORTED to " + ba_dir)
else:
    print("Bland–Altman: SKIPPED (set ENABLE_BLAND_ALTMAN=True to export)")

60% [STED 0.375%]: tiles=81 | Dice 0.650 ± 0.172 -> C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Metric Results [60%].csv | C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Overlays_TIFF\60% [STED 0.375%] overlay.tif
FREEHAND [STED 0.375%]: tiles=81 | Dice 0.167 ± 0.161 -> C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Metric Results [FREEHAND].csv | C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Overlays_TIFF\FREEHAND [STED 0.375%] overlay.tif
OVAL [STED 0.375%]: tiles=81 | Dice 0.138 ± 0.177 -> C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Metric Results [OVAL].csv | C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Overlays_TIFF\OVAL [STED 0.375%] overlay.tif
ILASTIK [STED 0.375%]: tiles=81 | Dice 0.541 ± 0.258 -> C:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\Accuracy\Metric Results [IL