# NB6 — Biological Context: Age, Neuropathology & Irisin

This notebook covers age-stratified MES–GR associations, neuropathological burden analysis (Braak/CERAD), irisin-pathway scoring, GR × age interaction models, and threshold sensitivity analyses.

**Paper:** Zafar SA, Qin W. *Thymus-Derived Myeloid Education Signatures Predict Microglial Tolerance Positioning and Are Modulated by Glucocorticoid Stress-Axis Activity.* Neuroimmunomodulation (2026).

> **Note:** Update the path variables in section 0 to match your local directory structure before running. Raw data can be obtained from the public repositories listed in Supplementary Table S1.


In [None]:
from __future__ import annotations

import os, re, time, math, json, gc, glob, warnings, traceback
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from itertools import combinations

import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")

import matplotlib
matplotlib.rcParams.update({
    "font.family": "Arial",
    "font.size": 8,
    "axes.titlesize": 9,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.dpi": 150,
    "savefig.dpi": 1200,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.05,
    "axes.linewidth": 0.8,
    "xtick.major.width": 0.6,
    "ytick.major.width": 0.6,
    "xtick.major.size": 3,
    "ytick.major.size": 3,
    "pdf.fonttype": 42,       # editable text in PDF
    "ps.fonttype": 42,
})

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import FancyBboxPatch

try:
    import seaborn as sns
    sns.set_style("ticks")
    HAS_SNS = True
except ImportError:
    HAS_SNS = False
    print("[WARN] seaborn not installed; figures will use matplotlib only.")

import anndata as ad
import scanpy as sc
import scipy.sparse as sp
from scipy import stats as scipy_stats

try:
    import statsmodels.api as sm
    from statsmodels.stats.multitest import multipletests
    HAS_SM = True
except ImportError:
    HAS_SM = False
    print("[WARN] statsmodels not installed; partial correlations & OLS unavailable.")

# REPRODUCIBILITY
# Seed NumPy's global RNG. The statistical helpers further down build their
# own RandomState(42) instances, so this only matters for code that draws
# from the global generator.

np.random.seed(42)

# PATHS
# PROC_DIR is searched for scored AnnData files (see SCORED_GLOB);
# MANUS_DIR is the output root. Both are relative to the working directory.

PROC_DIR  = Path(".") / "data" / "processed" / "aim2_microglia"  # <-- SET PATH
MANUS_DIR = Path(".") / "outputs" / "manuscript"  # <-- SET PATH

# Output folders are created eagerly at import/run time.
FIG_DIR = MANUS_DIR / "Figures"
TAB_DIR = MANUS_DIR / "Tables"
FIG_DIR.mkdir(parents=True, exist_ok=True)
TAB_DIR.mkdir(parents=True, exist_ok=True)

# One file per cohort; the cohort name is the part before "__microglia_scored.h5ad".
SCORED_GLOB = str(PROC_DIR / "*__microglia_scored.h5ad")

# Analysis parameters
DPI = 1200           # resolution passed to savefig by save_fig()
N_BOOT = 2000        # bootstrap iterations
N_PERM = 5000        # permutation test iterations
ALPHA  = 0.05        # significance threshold
MIN_DONORS_PER_GROUP = 8   # minimum donors for a stratum/group to be analysed

# Run controls
# Boolean switches for the individual analysis parts (A-G).

RUN_PART_A = True
RUN_PART_B = True
RUN_PART_C = True
RUN_PART_D = False   # optional
RUN_PART_E = False   # optional
RUN_PART_F = True    # cross-cohort validation 
RUN_PART_G = True    # sensitivity analyses 
# Column conventions from NB5
# Names of the per-cell score columns expected in adata.obs.

MES_COLS = [
    "MES01_score", "MES02_score", "MES03_score", "MES04_score",
    "MES05_score", "MES06_score", "MES07_score", "MES08_score",
]
GR_COL = "GR_composite"

# Exact column names tried, in order, when looking for the tolerance score.
TOL_CANDS = [
    "tolerance_positioning", "tolerance", "Tolerance", "tolerance_score",
    "tolerance_pos", "positioning", "Tol_positioning",
]

# Gene sets

IRISIN_GENES   = ["FNDC5", "PPARGC1A", "PPRC1", "NRF1", "TFAM"]
# Alias -> canonical symbol, applied by canonicalize_gene().
IRISIN_ALIASES = {"PGC1A": "PPARGC1A"}

# Neurotransmitter signatures; scored as "NT_<key>" when RUN_PART_D is on.
NT_SIGS = {
    "dopamine":        ["TH","DDC","SLC6A3","DRD1","DRD2","DRD3","DRD4","DRD5"],
    "serotonin":       ["TPH1","TPH2","SLC6A4","HTR1A","HTR1B","HTR2A","HTR2C"],
    "acetylcholine":   ["CHAT","ACHE","SLC18A3","CHRNA7","CHRNB2"],
    "norepinephrine":  ["DBH","SLC6A2","ADRA1A","ADRA2A","ADRB1","ADRB2"],
}

# Scored as "ExerciseModuleScore" when RUN_PART_E is on.
EXERCISE_MODULE = [
    "PPARGC1A","PPRC1","NRF1","TFAM",
    "SOD1","SOD2","CAT","GPX1",
    "VEGFA","VEGFB",
    "BDNF","IGF1",
]

# Color palettes
# Keys mirror the labels produced by assign_status() and assign_age_group().

PALETTE_STATUS = {"Control": "#4393C3", "AD": "#D6604D",
                  "LowPath": "#4393C3", "HighPath": "#D6604D",
                  "Unknown": "#999999"}
PALETTE_AGE    = {"<65": "#66C2A5", "65-79": "#FC8D62", ">=80": "#8DA0CB", "NA": "#CCCCCC"}
# Eight colours (one per MES column); falls back to matplotlib's tab10 when
# seaborn is not installed.
PALETTE_MES    = sns.color_palette("husl", 8) if HAS_SNS else plt.cm.tab10(np.linspace(0, 1, 8))

# Utilities

def _now():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

def log(msg: str):
    """Print *msg* prefixed with a bracketed timestamp, flushing immediately."""
    stamp = _now()
    print("[{}] {}".format(stamp, msg), flush=True)

def save_fig(path: Path, fig=None):
    """Write a figure to *path* at the configured DPI and close it.

    When *fig* is None the current pyplot figure is saved and closed;
    otherwise the given figure is. The parent directory is created on demand.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if fig is None:
        # No explicit handle: operate on pyplot's current figure.
        plt.savefig(path, dpi=DPI, bbox_inches="tight")
        plt.close()
    else:
        fig.savefig(path, dpi=DPI, bbox_inches="tight")
        plt.close(fig)
    log(f"SAVED FIG {path}")

def save_xlsx(df: pd.DataFrame, path: Path, sheet_name="Sheet1"):
    """Write *df* to a single-sheet .xlsx file (no index column).

    A None or empty frame is not written; a skip message is logged instead.
    """
    nothing_to_write = df is None or df.empty
    if nothing_to_write:
        log(f"SKIP TABLE (empty) {path}")
        return

    path.parent.mkdir(parents=True, exist_ok=True)
    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name)
    log(f"SAVED TABLE {path}")

def save_xlsx_multi(dfs: Dict[str, pd.DataFrame], path: Path):
    """Write several DataFrames to one .xlsx workbook, one sheet per entry.

    Parameters
    ----------
    dfs : mapping of sheet name -> DataFrame. None or empty frames are dropped.
    path : destination file; its parent directory is created on demand.

    Sheet names are truncated to 31 characters (Excel's sheet-name limit).
    If no frame has any rows, nothing is written and a skip message is
    logged, mirroring ``save_xlsx``.
    """
    sheets = {name: df for name, df in dfs.items()
              if df is not None and not df.empty}
    if not sheets:
        # Opening the writer and closing it without adding a sheet would fail
        # when the workbook is saved, so bail out before touching the disk.
        log(f"SKIP TABLE (all sheets empty) {path}")
        return

    path.parent.mkdir(parents=True, exist_ok=True)
    with pd.ExcelWriter(path, engine="openpyxl") as w:
        for name, df in sheets.items():
            df.to_excel(w, index=False, sheet_name=name[:31])
    log(f"SAVED TABLE (multi-sheet) {path}")

def safe_numeric(x) -> np.ndarray:
    """Coerce *x* to a numeric NumPy array; unparseable entries become NaN."""
    as_series = pd.Series(x)
    coerced = pd.to_numeric(as_series, errors="coerce")
    return coerced.to_numpy()

def _align_finite(x, y, min_n=8):
    """Keep only positions where both inputs are finite numbers.

    Returns ``(x_kept, y_kept, n_kept)``. When fewer than *min_n* pairs
    survive, returns ``(None, None, 0)``.
    """
    xs = safe_numeric(x)
    ys = safe_numeric(y)
    keep = np.isfinite(xs) & np.isfinite(ys)
    n_kept = int(keep.sum())
    if n_kept < min_n:
        return None, None, 0
    return xs[keep], ys[keep], n_kept

# STATISTICAL TOOLKIT 

def corr_with_pvalue(x, y, method="spearman", min_n=8):
    """Correlate two vectors and attach a bootstrap 95% confidence interval.

    Parameters
    ----------
    x, y : array-likes of equal length; non-numeric / non-finite pairs are dropped.
    method : "spearman", "pearson" or "kendall". Any other value falls back to
        Spearman for BOTH the point estimate and the bootstrap (previously the
        bootstrap was silently skipped for unknown methods).
    min_n : minimum number of finite pairs required.

    Returns
    -------
    dict with keys ``r``, ``p``, ``n``, ``ci_lo``, ``ci_hi``, ``method``.
    ``method`` echoes the requested value. All statistics are NaN (and ``n`` is
    0) when fewer than *min_n* pairs are available.
    """
    xa, ya, n = _align_finite(x, y, min_n)
    out = {"r": np.nan, "p": np.nan, "n": n, "ci_lo": np.nan, "ci_hi": np.nan, "method": method}
    if xa is None:
        return out

    # Resolve the estimator once so the point estimate and every bootstrap
    # replicate are guaranteed to use the same statistic.
    estimators = {
        "spearman": scipy_stats.spearmanr,
        "pearson": scipy_stats.pearsonr,
        "kendall": scipy_stats.kendalltau,
    }
    estimator = estimators.get(method, scipy_stats.spearmanr)

    r, p = estimator(xa, ya)
    out["r"] = float(r)
    out["p"] = float(p)

    # Percentile bootstrap over paired resamples (fixed seed for repeatability).
    rng = np.random.RandomState(42)
    boots = np.full(N_BOOT, np.nan)
    for i in range(N_BOOT):
        idx = rng.randint(0, n, n)
        boots[i] = estimator(xa[idx], ya[idx])[0]

    # Replicates can be NaN when a resample is constant; require a handful of
    # usable ones before reporting an interval.
    finite_boots = boots[np.isfinite(boots)]
    if len(finite_boots) > 10:
        out["ci_lo"] = float(np.percentile(finite_boots, 2.5))
        out["ci_hi"] = float(np.percentile(finite_boots, 97.5))

    return out


def partial_corr(df: pd.DataFrame, x: str, y: str, covars: List[str],
                 method="spearman", min_n=10) -> dict:
    """Partial correlation of *x* and *y* controlling for *covars*.

    Both variables are regressed on the covariates (OLS with intercept) and the
    residuals are correlated with ``corr_with_pvalue``.

    Falls back to the plain (unadjusted) correlation on the full frame when
    statsmodels is missing, *covars* is empty, or fewer than *min_n* rows are
    complete across x, y and all covariates.

    Returns
    -------
    The ``corr_with_pvalue`` dict plus:
    ``partial``    - True if the covariate adjustment was applied, else False.
    ``covariates`` - comma-joined covariate names ("" for the fallback).
    The fallback previously returned neither key, so an unadjusted estimate
    could not be told apart from an adjusted one.
    """
    cols = [x, y] + covars
    d = df[cols].dropna()
    if d.shape[0] < min_n or not HAS_SM or len(covars) == 0:
        out = corr_with_pvalue(df[x], df[y], method=method, min_n=min_n)
        out["partial"] = False
        out["covariates"] = ""
        return out

    # Residualise both variables on the same design matrix.
    C = sm.add_constant(d[covars].values)
    resid_x = sm.OLS(d[x].values, C).fit().resid
    resid_y = sm.OLS(d[y].values, C).fit().resid

    out = corr_with_pvalue(resid_x, resid_y, method=method, min_n=min_n)
    out["partial"] = True
    out["covariates"] = ",".join(covars)
    return out


def _pooled_cohen_d(x, y):
    """Cohen's d = (mean(y) - mean(x)) / pooled SD; NaN when the pooled variance is not positive."""
    dof = max(len(x) + len(y) - 2, 1)
    pooled_var = ((len(x) - 1) * np.var(x, ddof=1)
                  + (len(y) - 1) * np.var(y, ddof=1)) / dof
    if pooled_var <= 0:
        return np.nan
    return float((np.mean(y) - np.mean(x)) / np.sqrt(pooled_var))


def cohen_d_with_ci(x, y, n_boot=N_BOOT):
    """Effect size between two groups with a bootstrap 95% CI.

    Cohen's d is computed as (mean(y) - mean(x)) / pooled_sd, so a positive
    value means *y* is larger. A two-sided Mann-Whitney U test is reported
    alongside.

    Parameters
    ----------
    x, y : array-likes; non-finite entries are dropped.
    n_boot : number of bootstrap resamples for the CI.

    Returns
    -------
    dict with ``d``, ``ci_lo``, ``ci_hi``, ``mw_U``, ``mw_p``, group means/SDs
    (``mean_x``, ``mean_y``, ``sd_x``, ``sd_y``) and sizes (``n_x``, ``n_y``).
    Statistics stay NaN when either group has fewer than 3 values; only the
    means and SDs are filled when the pooled variance is zero.
    """
    x = np.asarray(x, float); y = np.asarray(y, float)
    x = x[np.isfinite(x)]; y = y[np.isfinite(y)]
    out = {"d": np.nan, "ci_lo": np.nan, "ci_hi": np.nan,
           "mw_U": np.nan, "mw_p": np.nan,
           "mean_x": np.nan, "mean_y": np.nan, "sd_x": np.nan, "sd_y": np.nan,
           "n_x": len(x), "n_y": len(y)}

    if len(x) < 3 or len(y) < 3:
        return out

    out["mean_x"] = float(np.mean(x)); out["mean_y"] = float(np.mean(y))
    out["sd_x"] = float(np.std(x, ddof=1)); out["sd_y"] = float(np.std(y, ddof=1))

    d = _pooled_cohen_d(x, y)
    if not np.isfinite(d):
        return out
    out["d"] = d

    # Mann-Whitney: best effort, but report the failure instead of hiding it.
    try:
        U, p = scipy_stats.mannwhitneyu(x, y, alternative="two-sided")
        out["mw_U"] = float(U); out["mw_p"] = float(p)
    except Exception as exc:
        log(f"[WARN] Mann-Whitney U test failed: {exc}")

    # Percentile bootstrap on d; each group is resampled independently
    # (x first, then y, to keep the seeded draw sequence stable).
    rng = np.random.RandomState(42)
    boots = np.full(n_boot, np.nan)
    for i in range(n_boot):
        bx = x[rng.randint(0, len(x), len(x))]
        by = y[rng.randint(0, len(y), len(y))]
        boots[i] = _pooled_cohen_d(bx, by)
    fb = boots[np.isfinite(boots)]
    if len(fb) > 10:
        out["ci_lo"] = float(np.percentile(fb, 2.5))
        out["ci_hi"] = float(np.percentile(fb, 97.5))

    return out


def permutation_test_corr(x, y, method="spearman", n_perm=N_PERM, min_n=8):
    """Two-sided permutation p-value for a correlation coefficient.

    *y* is shuffled *n_perm* times (seeded) and the p-value is
    ``(count + 1) / (n_perm + 1)`` where ``count`` is the number of permuted
    |r| values at least as large as the observed |r|.

    Parameters
    ----------
    method : "spearman" or "kendall"; any other value uses Pearson.
    min_n : minimum number of finite pairs required.

    Returns
    -------
    float p-value, or NaN when there are too few pairs or the observed
    correlation is undefined (e.g. a constant input). Without the latter check
    every comparison against NaN is False and the function would report the
    smallest possible p-value for data that carry no signal at all.
    """
    xa, ya, n = _align_finite(x, y, min_n)
    if xa is None:
        return np.nan

    if method == "spearman":
        statistic = scipy_stats.spearmanr
    elif method == "kendall":
        statistic = scipy_stats.kendalltau
    else:
        statistic = scipy_stats.pearsonr

    obs_r = statistic(xa, ya)[0]
    if not np.isfinite(obs_r):
        return np.nan
    obs_abs = abs(obs_r)

    rng = np.random.RandomState(42)
    count = 0
    for _ in range(n_perm):
        pr = statistic(xa, rng.permutation(ya))[0]
        # NaN replicates compare False and are simply not counted.
        if abs(pr) >= obs_abs:
            count += 1
    return float((count + 1) / (n_perm + 1))


def bh_fdr(pvals) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values (q-values).

    Non-finite inputs are excluded from the correction and come back as NaN in
    the same positions. Uses statsmodels when available, otherwise a NumPy
    implementation of the step-up procedure.
    """
    p = np.asarray(pvals, float)
    q = np.full_like(p, np.nan)
    finite = np.isfinite(p)
    if not finite.any():
        return q

    if HAS_SM:
        q[finite] = multipletests(p[finite], method="fdr_bh")[1]
        return q

    tested = p[finite]
    m = tested.size
    order = np.argsort(tested)
    # p_(i) * m / i for ascending ranks i = 1..m
    scaled = tested[order] * m / (np.arange(m) + 1.0)
    # Enforce monotonicity from the largest p-value downwards.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(m, dtype=float)
    adjusted[order] = monotone      # scatter back to input order
    q[finite] = adjusted
    return q


def fit_interaction_ols(df: pd.DataFrame, y: str, x: str, z: str,
                        covars: Optional[List[str]] = None) -> Dict[str, Any]:
    """Fit ``y ~ x + z + x*z [+ covars]`` by ordinary least squares.

    Rows with a missing value in any model column are dropped. The model is
    only fitted when at least ``max(10, n_columns + 5)`` complete rows remain.

    Returns a dict with the coefficients (``beta_x``, ``beta_z``, ``beta_xz``),
    their p-values (``p_x``, ``p_z``, ``p_xz``), ``R2``, ``R2_adj``, ``n`` and a
    ``model`` tag. Without statsmodels only the coefficients are filled
    (``model == "numpy_lstsq"``).
    """
    covars = covars or []
    result = {"beta_x": np.nan, "beta_z": np.nan, "beta_xz": np.nan,
              "p_x": np.nan, "p_z": np.nan, "p_xz": np.nan,
              "R2": np.nan, "R2_adj": np.nan, "n": 0, "model": "failed"}

    needed = [y, x, z] + covars
    data = df[needed].dropna()
    n_obs = int(data.shape[0])
    result["n"] = n_obs
    if n_obs < max(10, len(needed) + 5):
        return result

    x_vals = data[x].values
    z_vals = data[z].values
    product = x_vals * z_vals

    if HAS_SM:
        design = pd.DataFrame({x: x_vals, z: z_vals, "xz": product})
        for cov in covars:
            design[cov] = data[cov].values
        design = sm.add_constant(design, has_constant="add")

        terms = (("x", x), ("z", z), ("xz", "xz"))
        try:
            fit = sm.OLS(data[y].values, design).fit()
            for tag, term in terms:
                result["beta_" + tag] = float(fit.params.get(term, np.nan))
            for tag, term in terms:
                result["p_" + tag] = float(fit.pvalues.get(term, np.nan))
            result["R2"] = float(fit.rsquared)
            result["R2_adj"] = float(fit.rsquared_adj)
            result["model"] = "OLS_statsmodels"
        except Exception as e:
            result["model"] = f"failed: {e}"
        return result

    # No statsmodels: plain least squares, coefficients only (no p-values).
    columns = [np.ones(n_obs), x_vals, z_vals, product]
    columns += [data[cov].values for cov in covars]
    design_matrix = np.column_stack(columns)
    try:
        coefs = np.linalg.lstsq(design_matrix, data[y].values, rcond=None)[0]
        result["beta_x"] = float(coefs[1])
        result["beta_z"] = float(coefs[2])
        result["beta_xz"] = float(coefs[3])
        result["model"] = "numpy_lstsq"
    except Exception:
        pass
    return result


def sig_stars(p):
    """Map a p-value to an annotation string: ``***``, ``**``, ``*`` or ``ns``.

    Missing or non-finite p-values are reported as ``ns``.
    """
    if pd.isna(p) or not np.isfinite(p):
        return "ns"
    # Strict upper bounds, checked from most to least stringent.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < cutoff:
            return stars
    return "ns"

# Column detection helpers 

def pick_first_existing(cols, candidates):
    """Return the first entry of *candidates* that is in *cols*, else None.

    Priority follows the order of *candidates*, not the order of *cols*.
    """
    return next((name for name in candidates if name in cols), None)

def find_col_by_keywords(obs_cols, keywords_any, keywords_not=None):
    """Return the first column whose lower-cased name contains a wanted keyword.

    A column qualifies when its name contains at least one of *keywords_any*
    and none of *keywords_not* (plain substring tests, not regular
    expressions). Returns None when nothing qualifies.
    """
    excluded = keywords_not or []
    for col in obs_cols:
        lowered = str(col).lower()
        if any(bad in lowered for bad in excluded):
            continue
        if any(good in lowered for good in keywords_any):
            return col
    return None

def detect_donor_col(adata):
    """Find the donor identifier column in ``adata.obs``.

    Known exact names are tried first (in priority order); otherwise the first
    column whose name contains "donor" is returned, or None.
    """
    names = [str(c) for c in adata.obs.columns]
    known = ("donor_id", "Donor ID", "donor", "DonorID",
             "patient_id", "PatientID", "subject_id", "individual")
    exact = next((k for k in known if k in names), None)
    if exact is not None:
        return exact
    return find_col_by_keywords(names, ["donor"])

def detect_age_col(adata):
    """Find the age column in ``adata.obs``.

    Known exact names win; the fallback is the first column containing "age"
    but neither "stage" nor "assay". Returns None if nothing matches.
    """
    names = [str(c) for c in adata.obs.columns]
    known = ("Age", "age", "Age at Death", "age_at_death",
             "age_at_death_years", "AgeAtDeath", "age_years")
    exact = next((k for k in known if k in names), None)
    if exact is not None:
        return exact
    return find_col_by_keywords(names, ["age"], keywords_not=["stage", "assay"])

def detect_diagnosis_col(adata):
    """Find the clinical diagnosis column in ``adata.obs``.

    Known exact names win; the fallback is the first column whose name contains
    one of several diagnosis-related keywords. Returns None if nothing matches.
    """
    names = [str(c) for c in adata.obs.columns]
    known = ("Diagnosis", "diagnosis", "dx", "Dx",
             "disease", "Disease", "clinical_diagnosis")
    exact = next((k for k in known if k in names), None)
    if exact is not None:
        return exact
    return find_col_by_keywords(names, ["diagnos","dement","alzheimer","adnc","case","control"])

def detect_severity_cols(adata):
    """Locate the neuropathology columns in ``adata.obs``.

    Returns a dict with keys ``braak``, ``cerad`` and ``adnc``; each value is
    the matching column name or None. Exact names are preferred, with a
    keyword search on the measure's own name as fallback.
    """
    names = [str(c) for c in adata.obs.columns]
    exact_names = {
        "braak": ["Braak","braak","braak_stage","Braak stage"],
        "cerad": ["CERAD","cerad","CERAD score","cerad_score"],
        "adnc":  ["ADNC","adnc","ADNC_level","adnc_level"],
    }
    found = {}
    for measure, candidates in exact_names.items():
        col = pick_first_existing(names, candidates)
        if col is None:
            col = find_col_by_keywords(names, [measure])
        found[measure] = col
    return found

def detect_tolerance_col(adata):
    """Find the tolerance-positioning score column in ``adata.obs``.

    Tries the names in ``TOL_CANDS`` first, then any column containing "toler"
    or "position" (but not "composition"). Returns None if nothing matches.
    """
    names = [str(c) for c in adata.obs.columns]
    exact = pick_first_existing(names, TOL_CANDS)
    if exact is None:
        return find_col_by_keywords(names, ["toler","position"], keywords_not=["composition"])
    return exact

def detect_sex_col(adata):
    """Find the sex/gender column in ``adata.obs``.

    Known exact names win; the fallback is the first column containing "sex"
    or "gender". Returns None if nothing matches.
    """
    names = [str(c) for c in adata.obs.columns]
    exact = next((k for k in ("Sex", "sex", "gender", "Gender") if k in names), None)
    if exact is not None:
        return exact
    return find_col_by_keywords(names, ["sex","gender"])

def detect_pmi_col(adata):
    """Find the post-mortem interval (PMI) column in ``adata.obs``.

    Known exact names are tried first (in priority order). Otherwise the first
    column whose lower-cased name contains "pmi" or "post mortem" (written
    together or separated by a space, dot, underscore or hyphen) is returned,
    or None when nothing matches.

    The fallback is a real regular-expression search: the previous version
    passed the pattern ``"post.?mortem"`` to a plain substring helper, so that
    keyword could never match any column.
    """
    cols = list(map(str, adata.obs.columns))
    for cand in ["PMI","pmi","post_mortem_interval","PostMortemInterval"]:
        if cand in cols:
            return cand

    pattern = re.compile(r"pmi|post[\s._-]?mortem")
    for c in cols:
        if pattern.search(c.lower()):
            return c
    return None

def detect_batch_col(adata):
    """Find the batch column in ``adata.obs``.

    Known exact names win; the fallback is the first column containing "batch"
    but not "size". Returns None if nothing matches.
    """
    names = [str(c) for c in adata.obs.columns]
    known = ("batch", "Batch", "library_prep_batch", "sequencing_batch")
    exact = next((k for k in known if k in names), None)
    if exact is not None:
        return exact
    return find_col_by_keywords(names, ["batch"], keywords_not=["size"])

# Braak / CERAD parsing (same robust approach)

_ROMAN = {"0":0, "I":1, "II":2, "III":3, "IV":4, "V":5, "VI":6}
_CERAD = {"ABSENT":0, "NONE":0, "SPARSE":1, "MODERATE":2, "FREQUENT":3}

# A stage token is a Roman numeral 0-VI or a single digit 0-6 that is NOT glued
# to other letters/digits, so the "V" in "AVAILABLE" or the "1" in "12" is
# ignored. Underscores, spaces, slashes and hyphens all count as separators.
_BRAAK_TOKEN = re.compile(r"(?<![A-Z0-9])(VI|IV|V|III|II|I|[0-6])(?![A-Z0-9])")

def parse_braak(v):
    """Convert a Braak stage label to a float in 0..6 (NaN if unparseable).

    Accepted inputs:
    * plain numbers or numeric strings (``4``, ``3.0``, ``"5"``); values
      outside 0..6 give NaN;
    * text containing stand-alone Roman or Arabic stage tokens
      (``"Braak IV"``, ``"Braak_III"``, ``"Stage 2"``); for ranges such as
      ``"V-VI"`` the highest stage is returned.

    Letters inside ordinary words are never read as numerals, so labels like
    ``"Not available"`` yield NaN.
    """
    if pd.isna(v): return np.nan
    s = str(v).strip().upper()

    # Whole-string numeric values first ("3.0" must be 3, not the token "0").
    try:
        num = float(s)
    except ValueError:
        num = np.nan
    if np.isfinite(num):
        return float(num) if 0 <= num <= 6 else np.nan

    stages = [int(tok) if tok.isdigit() else _ROMAN[tok]
              for tok in _BRAAK_TOKEN.findall(s)]
    return float(max(stages)) if stages else np.nan

def parse_cerad(v):
    """Convert a CERAD score label to a float (NaN if unparseable).

    Numeric values are returned as-is; otherwise the text is searched for one
    of the category words in ``_CERAD`` (Absent/None=0, Sparse=1, Moderate=2,
    Frequent=3).
    """
    if pd.isna(v): return np.nan
    s = str(v).strip().upper()
    vn = pd.to_numeric(s, errors="coerce")
    if np.isfinite(vn): return float(vn)
    for k, val in _CERAD.items():
        if k in s: return float(val)
    return np.nan

# Scoring helpers

def canonicalize_gene(g):
    """Strip whitespace from a gene symbol and map known aliases to the canonical symbol."""
    symbol = str(g).strip()
    return IRISIN_ALIASES.get(symbol, symbol)

def score_geneset_simple(adata, genes, score_name):
    """Score a gene set per cell with ``sc.tl.score_genes`` into ``adata.obs[score_name]``.

    Genes are canonicalised and restricted to those in ``adata.var_names``.
    If too few are present (fewer than ``max(2, min(5, len(genes)//2))``) the
    score column is filled with NaN instead. Returns the list of genes found.
    """
    canonical = [canonicalize_gene(g) for g in genes]
    found = [g for g in canonical if g in adata.var_names]
    required = max(2, min(5, len(canonical) // 2))

    if len(found) >= required:
        sc.tl.score_genes(adata, gene_list=found, score_name=score_name, use_raw=False)
        log(f"  score_geneset: {score_name} OK ({len(found)} genes)")
    else:
        adata.obs[score_name] = np.nan
        log(f"  score_geneset: {score_name} SKIPPED (present={len(found)}/{len(canonical)})")
    return found

def score_geneset_zscore(adata, genes, score_name):
    """Alternative gene-set score: per-cell mean of per-gene z-scored expression.

    Written to ``adata.obs[score_name]``; used as a robustness check next to the
    scanpy score. The column is set to NaN when fewer than two genes are found.
    """
    canonical = [canonicalize_gene(g) for g in genes]
    found = [g for g in canonical if g in adata.var_names]
    if len(found) < 2:
        adata.obs[score_name] = np.nan
        return

    all_names = list(adata.var_names)
    col_idx = [all_names.index(g) for g in found]
    block = adata.X[:, col_idx]
    if sp.issparse(adata.X):
        block = block.todense()
    expr = np.asarray(block)

    # Standardise each gene across cells; (near-)constant genes keep unit scale
    # so they contribute zeros rather than dividing by ~0.
    gene_mean = np.nanmean(expr, axis=0, keepdims=True)
    gene_sd = np.nanstd(expr, axis=0, keepdims=True)
    gene_sd = np.where(gene_sd < 1e-10, 1.0, gene_sd)
    zscores = (expr - gene_mean) / gene_sd

    adata.obs[score_name] = np.nanmean(zscores, axis=1)
    log(f"  score_geneset_zscore: {score_name} OK ({len(found)} genes)")
# Donor aggregation

def donor_aggregate(adata, donor_col, cols_to_mean, cols_to_keep_first=None):
    """Collapse per-cell ``adata.obs`` to one row per donor.

    Parameters
    ----------
    adata : object exposing a DataFrame as ``.obs``.
    donor_col : column identifying the donor (cast to str before grouping).
    cols_to_mean : columns averaged per donor; missing or non-numeric ones are
        dropped.
    cols_to_keep_first : metadata columns for which the first non-null value
        per donor is kept.

    Returns
    -------
    DataFrame with ``donor_id``, the averaged columns, ``n_cells`` (cells per
    donor) and the kept metadata columns, one row per donor.

    Each requested column appears at most once in the result. Repeated names,
    or a metadata column that was already averaged, are ignored; previously
    they triggered ``_x``/``_y`` suffixes in the merge and the original column
    name disappeared from the output.
    """
    obs = adata.obs.copy()
    obs[donor_col] = obs[donor_col].astype(str)
    grouped = obs.groupby(donor_col)

    # dict.fromkeys de-duplicates while preserving the requested order.
    mean_cols = list(dict.fromkeys(
        c for c in cols_to_mean if c in obs.columns and c != donor_col))
    out = grouped[mean_cols].mean(numeric_only=True)
    out = out.join(grouped.size().rename("n_cells"))

    taken = set(out.columns)
    first_cols = []
    for c in cols_to_keep_first or []:
        if c in obs.columns and c != donor_col and c not in taken:
            first_cols.append(c)
            taken.add(c)
    if first_cols:
        out = out.join(grouped[first_cols].first())

    return out.reset_index().rename(columns={donor_col: "donor_id"})

# Status assignment (Clinical or Neuropathological)

def assign_age_group(df, age_col):
    """Bin ``df[age_col]`` into "<65", "65-79", ">=80" or "NA".

    Values that cannot be parsed as finite numbers are labelled "NA". Returns a
    ``pd.Categorical`` with the four categories in that fixed order.
    """
    years = pd.to_numeric(df[age_col], errors="coerce").to_numpy(dtype=float)
    # First matching rule wins: non-finite is checked before the numeric bins.
    rules = [~np.isfinite(years), years < 65, years < 80]
    labels = np.select(rules, ["NA", "<65", "65-79"], default=">=80")
    return pd.Categorical(labels, categories=["<65","65-79",">=80","NA"])

def assign_status(df, dx_col, sev_cols, prefer_clinical=True,
                  braak_thresh=4, cerad_thresh=2):
    """Label each row with a disease status.

    Clinical mode (``prefer_clinical`` and *dx_col* present in *df*):
        "AD" when the diagnosis text mentions Alzheimer/dementia/AD, "Control"
        when it mentions control / no dementia / normal / non-demented, else
        "Unknown". Control phrases take precedence: "no dementia" and
        "non-demented" contain the AD keyword "dement", and the earlier
        assignment order relabelled such donors as AD.

    Neuropathology mode (otherwise), with thresholds configurable for
    sensitivity analysis:
        LowPath  = Braak < braak_thresh AND CERAD < cerad_thresh (both known);
        HighPath = Braak >= braak_thresh OR CERAD >= cerad_thresh;
        "Unknown" otherwise. With the defaults this is Braak<=3 & CERAD<=1
        versus Braak>=4 | CERAD>=2.

    Returns
    -------
    (pd.Categorical of labels, mode) where mode is "Clinical" or
    "Neuropathology".
    """
    # Clinical if available
    if prefer_clinical and dx_col and dx_col in df.columns:
        s = df[dx_col].astype(str).str.lower()
        is_ad = s.str.contains(r"alzheimer|dement|ad\b", na=False)
        is_control = s.str.contains(r"control|no dementia|normal|non[- ]dement", na=False)

        out = pd.Series("Unknown", index=df.index, dtype="object")
        out[is_ad] = "AD"
        out[is_control] = "Control"   # applied last so negated phrases win
        return pd.Categorical(out, categories=["Control","AD","Unknown"]), "Clinical"

    # Neuropathology strata
    def _numeric(col, parser):
        # Parsed scores as a float array; all-NaN when the column is absent.
        if col and col in df.columns:
            return df[col].map(parser).astype(float).to_numpy(dtype=float)
        return np.full(len(df.index), np.nan, dtype=float)

    b_vals = _numeric(sev_cols.get("braak"), parse_braak)
    c_vals = _numeric(sev_cols.get("cerad"), parse_cerad)

    low  = np.isfinite(b_vals) & np.isfinite(c_vals) & (b_vals < braak_thresh) & (c_vals < cerad_thresh)
    high = (np.isfinite(b_vals) & (b_vals >= braak_thresh)) | (np.isfinite(c_vals) & (c_vals >= cerad_thresh))

    out = pd.Series("Unknown", index=df.index, dtype="object")
    out[low]  = "LowPath"
    out[high] = "HighPath"

    return pd.Categorical(out, categories=["LowPath","HighPath","Unknown"]), "Neuropathology"

# Build covariate list

def build_covariate_list(df, sex_col, pmi_col, batch_col):
    """Add numeric covariate columns to *df* (in place) and list the usable ones.

    Columns created when the corresponding source column exists:
    ``_cov_sex_num``   - 0.0 for values starting with "m", 1.0 for "f", else NaN;
    ``_cov_pmi``       - the PMI column coerced to numeric;
    ``_cov_batch_num`` - integer category codes of the batch column (only when
                         it has between 2 and 19 distinct non-null values).

    A covariate is returned only if more than 10 rows have a non-missing value.
    """
    usable = []

    def _register(name):
        # Keep a covariate only when enough donors actually have a value.
        if df[name].notna().sum() > 10:
            usable.append(name)

    if sex_col and sex_col in df.columns:
        lowered = df[sex_col].astype(str).str.lower()
        encoded = np.full(len(df), np.nan)
        encoded[lowered.str.startswith("f").to_numpy()] = 1.0
        encoded[lowered.str.startswith("m").to_numpy()] = 0.0
        df["_cov_sex_num"] = encoded
        _register("_cov_sex_num")

    if pmi_col and pmi_col in df.columns:
        df["_cov_pmi"] = pd.to_numeric(df[pmi_col], errors="coerce")
        _register("_cov_pmi")

    if batch_col and batch_col in df.columns:
        n_levels = len(df[batch_col].dropna().unique())
        if 1 < n_levels < 20:
            codes = pd.Categorical(df[batch_col]).codes.astype(float)
            # Missing batches get code -1 from Categorical; turn those into NaN.
            df["_cov_batch_num"] = np.where(df[batch_col].isna(), np.nan, codes)
            _register("_cov_batch_num")

    return usable

# Build donor-level table

def make_donor_df(adata, cohort_name="SEA-AD",
                  braak_thresh=4, cerad_thresh=2):
    """Build donor-level DataFrame from a scored AnnData.

    Steps: auto-detect the metadata columns in ``adata.obs``, average the score
    columns per donor, carry over one value per donor for the metadata columns,
    then derive age group, status, numeric Braak/CERAD and numeric covariates.

    Parameters
    ----------
    adata : scored AnnData; only ``adata.obs`` is read here.
    cohort_name : label stored in the ``cohort`` column and in ``meta``.
    braak_thresh, cerad_thresh : forwarded to ``assign_status``.

    Returns
    -------
    (ddf, meta): the donor table and a dict recording which columns were
    detected plus the list of usable covariate columns (``meta["covariates"]``).
    """
    # Column auto-detection; any of these may be None when nothing matches.
    d_col = detect_donor_col(adata)
    a_col = detect_age_col(adata)
    dx    = detect_diagnosis_col(adata)
    sev   = detect_severity_cols(adata)
    tol   = detect_tolerance_col(adata)
    sx    = detect_sex_col(adata)
    pmi   = detect_pmi_col(adata)
    bat   = detect_batch_col(adata)

    meta = {"cohort": cohort_name,
            "donor_col": d_col, "age_col": a_col, "dx_col": dx,
            "tol_col": tol, "sex_col": sx, "pmi_col": pmi, "batch_col": bat,
            **{f"sev_{k}": v for k, v in sev.items()}}

    # Columns to aggregate
    # Only score columns that actually exist in obs are averaged.
    mean_cols = [c for c in MES_COLS + [GR_COL, "IrisinScore", "IrisinScore_Z"]
                 if c in adata.obs.columns]
    if tol and tol in adata.obs.columns:
        mean_cols.append(tol)
    for k in NT_SIGS:
        if f"NT_{k}" in adata.obs.columns:
            mean_cols.append(f"NT_{k}")
    if "ExerciseModuleScore" in adata.obs.columns:
        mean_cols.append("ExerciseModuleScore")

    # Donor-level metadata: one value per donor is taken as-is.
    # NOTE(review): the same column can be listed twice here if two detectors
    # resolve to it (e.g. the diagnosis keyword fallback includes "adnc").
    keep_first = []
    for c in [a_col, dx, sev.get("braak"), sev.get("cerad"), sev.get("adnc"), sx, pmi, bat]:
        if c and c in adata.obs.columns:
            keep_first.append(c)

    if d_col and d_col in adata.obs.columns:
        ddf = donor_aggregate(adata, d_col, mean_cols, keep_first)
    else:
        # No donor column: every obs row is treated as its own "donor".
        # NOTE(review): the rename only applies when reset_index() yields a
        # column literally called "index" - confirm obs.index is unnamed.
        ddf = adata.obs.copy().reset_index().rename(columns={"index": "donor_id"})
        ddf["n_cells"] = 1

    # Age
    if a_col and a_col in ddf.columns:
        ddf["age_group"] = assign_age_group(ddf, a_col)
        ddf["age_years"] = pd.to_numeric(ddf[a_col], errors="coerce")
    else:
        ddf["age_group"] = pd.Categorical(["NA"]*len(ddf), categories=["<65","65-79",">=80","NA"])
        ddf["age_years"] = np.nan

    # Status
    # Clinical diagnosis is always preferred here; the neuropathology
    # thresholds only take effect when no diagnosis column is available.
    status, label = assign_status(ddf, dx_col=dx, sev_cols=sev,
                                  prefer_clinical=True,
                                  braak_thresh=braak_thresh,
                                  cerad_thresh=cerad_thresh)
    ddf["status"] = status
    ddf["status_label"] = label

    # Numeric neuropath
    if sev.get("braak") and sev["braak"] in ddf.columns:
        ddf["Braak_num"] = ddf[sev["braak"]].map(parse_braak)
    else:
        ddf["Braak_num"] = np.nan
    if sev.get("cerad") and sev["cerad"] in ddf.columns:
        ddf["CERAD_num"] = ddf[sev["cerad"]].map(parse_cerad)
    else:
        ddf["CERAD_num"] = np.nan

    # Covariates
    # build_covariate_list adds the "_cov_*" columns to ddf in place.
    covars = build_covariate_list(ddf, sx, pmi, bat)
    meta["covariates"] = covars

    ddf["cohort"] = cohort_name
    return ddf, meta

# Load all cohorts

def load_scored_cohorts():
    """Read every scored cohort file matching ``SCORED_GLOB``.

    Returns a dict mapping cohort name (file name without the
    ``__microglia_scored.h5ad`` suffix) to its AnnData. Observation and
    variable names are made unique when they are not already.
    """
    paths = sorted(glob.glob(SCORED_GLOB))
    log(f"Found scored cohorts: {[Path(p).name for p in paths]}")

    loaded = {}
    for path in paths:
        adata = sc.read_h5ad(path)
        if not adata.obs_names.is_unique:
            adata.obs_names_make_unique()
        if not adata.var_names.is_unique:
            adata.var_names_make_unique()
        cohort = Path(path).name.replace("__microglia_scored.h5ad", "")
        loaded[cohort] = adata

    log(f"Loaded cohorts: {list(loaded.keys())}")
    return loaded

log("=" * 60)
log("NB6 START")
log("=" * 60)

# Load every scored cohort found under PROC_DIR; SEA-AD is mandatory because
# it is the primary cohort for the analyses below.
cohorts = load_scored_cohorts()

if "SEA-AD" not in cohorts:
    raise RuntimeError("SEA-AD cohort not found. Check PROC_DIR / SCORED_GLOB.")

# Score gene modules on ALL cohorts before donor aggregation
# (make_donor_df only averages score columns that already exist in obs).

for cname, adata in cohorts.items():
    log(f"Scoring modules on {cname} (n={adata.n_obs})")
    score_geneset_simple(adata, IRISIN_GENES, "IrisinScore")
    score_geneset_zscore(adata, IRISIN_GENES, "IrisinScore_Z")  # robustness
    if RUN_PART_D:
        for k, g in NT_SIGS.items():
            score_geneset_simple(adata, g, f"NT_{k}")
    if RUN_PART_E:
        score_geneset_simple(adata, EXERCISE_MODULE, "ExerciseModuleScore")

# Extract FNDC5 expression on ALL cohorts
# Stored per cell in obs["FNDC5_expr_X"], taken directly from adata.X
# (NaN for cohorts where the gene is absent).

for cname, adata in cohorts.items():
    if "FNDC5" in adata.var_names:
        idx = int(np.where(adata.var_names == "FNDC5")[0][0])
        if sp.issparse(adata.X):
            f = np.asarray(adata.X[:, idx].todense()).ravel()
        else:
            f = np.asarray(adata.X[:, idx]).ravel()
        adata.obs["FNDC5_expr_X"] = f
        log(f"  {cname}: FNDC5 mean={np.nanmean(f):.4f}, frac>0={np.nanmean(f>0):.4f}")
    else:
        adata.obs["FNDC5_expr_X"] = np.nan
        log(f"  {cname}: FNDC5 not found")

# Primary: SEA-AD donor table
# `covars` and `tol_col` are module-level names that later analysis functions
# rely on (partA_age reads `tol_col` as a global).

sea = cohorts["SEA-AD"]
sea_d, sea_meta = make_donor_df(sea, "SEA-AD")
covars = sea_meta["covariates"]
tol_col = sea_meta["tol_col"]

log(f"SEA-AD donor table: {sea_d.shape[0]} donors")
log(f"Status: {sea_d['status_label'].iloc[0] if len(sea_d) else 'NA'}")
log(f"Status counts:\n{sea_d['status'].value_counts(dropna=False)}")
log(f"Age group counts:\n{sea_d['age_group'].value_counts(dropna=False)}")
log(f"Covariates available: {covars}")

# Add FNDC5 donor mean
if sea_meta["donor_col"] and sea_meta["donor_col"] in sea.obs.columns:
    tmp = donor_aggregate(sea, sea_meta["donor_col"], ["FNDC5_expr_X"])
    tmp = tmp.rename(columns={"FNDC5_expr_X": "FNDC5_donor_mean"})
    sea_d = sea_d.merge(tmp[["donor_id","FNDC5_donor_mean"]], on="donor_id", how="left")
else:
    sea_d["FNDC5_donor_mean"] = np.nan

# Scoring robustness check
# Donor-level agreement between the scanpy score and the z-score alternative.
# NOTE(review): sea_d["IrisinScore"] raises KeyError if the column is missing,
# and the empty-Series default for IrisinScore_Z has a different length than
# sea_d - confirm both columns are always present at this point.
r_scoring = corr_with_pvalue(sea_d["IrisinScore"], sea_d.get("IrisinScore_Z", pd.Series(dtype=float)),
                             method="spearman")
log(f"IrisinScore vs IrisinScore_Z (robustness): r={r_scoring['r']:.3f}, p={r_scoring['p']:.2e}")

# PART A: AGE STRATIFICATION (SEA-AD)

def partA_age(sea_d, covars):
    """Part A: age-stratified MES associations on the donor table.

    A1 - within each age group with at least MIN_DONORS_PER_GROUP donors,
         Spearman and covariate-adjusted correlations of every MES score with
         the GR composite and with the tolerance score.
    A2 - per-MES interaction model MES ~ GR + age + GR*age [+ covars] on all
         donors.
    Figure 6A - heatmap of mean MES score per age group, saved under FIG_DIR.

    Parameters
    ----------
    sea_d : donor-level DataFrame (as produced by make_donor_df).
    covars : covariate column names used for the partial correlations and the
        interaction models.

    Returns
    -------
    (df_corr, df_inter): stratified correlation table and interaction table,
    each with Benjamini-Hochberg q-values. Two empty frames are returned when
    no age group is large enough.

    NOTE(review): the tolerance column is read from the module-level name
    ``tol_col`` rather than from a parameter.
    """
    log("=" * 50)
    log("PART A: Age Stratification")

    # Keep only the age strata that are large enough to analyse.
    age_groups_all = ["<65", "65-79", ">=80"]
    age_groups = [ag for ag in age_groups_all
                  if (sea_d["age_group"].astype(str)==ag).sum() >= MIN_DONORS_PER_GROUP]

    if len(age_groups) == 0:
        log("[A] No age groups with sufficient donors; skipping.")
        return pd.DataFrame(), pd.DataFrame()

    # A1: Stratified correlations with p-values + CIs
    rows = []
    for ag in age_groups:
        sub = sea_d[sea_d["age_group"].astype(str)==ag].copy()
        for mes in [c for c in MES_COLS if c in sub.columns]:
            # GR vs MES
            if GR_COL in sub.columns:
                raw = corr_with_pvalue(sub[GR_COL], sub[mes], method="spearman")
                pcr = partial_corr(sub, GR_COL, mes, covars, method="spearman")
                rows.append({
                    "age_group": ag, "target": "GR", "MES": mes.replace("_score",""),
                    "r_spearman": raw["r"], "p_spearman": raw["p"],
                    "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                    "r_partial": pcr["r"], "p_partial": pcr["p"],
                    "N": raw["n"]
                })
            # Tolerance vs MES
            if tol_col and tol_col in sub.columns:
                raw = corr_with_pvalue(sub[tol_col], sub[mes], method="spearman")
                pcr = partial_corr(sub, tol_col, mes, covars, method="spearman")
                rows.append({
                    "age_group": ag, "target": "tolerance", "MES": mes.replace("_score",""),
                    "r_spearman": raw["r"], "p_spearman": raw["p"],
                    "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                    "r_partial": pcr["r"], "p_partial": pcr["p"],
                    "N": raw["n"]
                })

    # FDR is applied once over all rows, i.e. pooled across age groups and
    # across both targets (GR and tolerance).
    df_corr = pd.DataFrame(rows)
    if len(df_corr):
        df_corr["q_BH"] = bh_fdr(df_corr["p_spearman"].values)
        df_corr["q_partial_BH"] = bh_fdr(df_corr["p_partial"].values)

    # A2: Interaction tests MES ~ GR + age + GR*age [+ covars]
    # Uses all donors (not the per-stratum subsets) with age as a continuous term.
    inter_rows = []
    if "age_years" in sea_d.columns and GR_COL in sea_d.columns:
        for mes in [c for c in MES_COLS if c in sea_d.columns]:
            out = fit_interaction_ols(sea_d, y=mes, x=GR_COL, z="age_years", covars=covars)
            inter_rows.append({"MES": mes.replace("_score",""), **out})
    df_inter = pd.DataFrame(inter_rows)
    if len(df_inter) and "p_xz" in df_inter.columns:
        df_inter["q_xz_BH"] = bh_fdr(df_inter["p_xz"].values)

    # Figure 6A: Heatmap
    # Rows = age groups (with donor counts), columns = MES scores, cells = mean.
    mes_names = [m for m in MES_COLS if m in sea_d.columns]
    ags_ok, mat = [], []
    for ag in age_groups:
        sub = sea_d[sea_d["age_group"].astype(str)==ag]
        if sub.shape[0] == 0:
            continue
        ags_ok.append(f"{ag}\n(n={sub.shape[0]})")
        mat.append([float(np.nanmean(pd.to_numeric(sub[m], errors="coerce"))) for m in mes_names])
    mat = np.array(mat, float)

    fig, ax = plt.subplots(figsize=(7.0, 2.2))
    if HAS_SNS:
        sns.heatmap(mat, ax=ax, annot=True, fmt=".3f", cmap="RdBu_r",
                    center=0, linewidths=0.5, linecolor="white",
                    xticklabels=[m.replace("_score","") for m in mes_names],
                    yticklabels=ags_ok, cbar_kws={"label": "Mean MES score", "shrink": 0.8})
    else:
        # Matplotlib-only fallback (no cell annotations, colour scale not centred).
        im = ax.imshow(mat, aspect="auto", cmap="RdBu_r")
        ax.set_yticks(np.arange(len(ags_ok))); ax.set_yticklabels(ags_ok)
        ax.set_xticks(np.arange(len(mes_names)))
        ax.set_xticklabels([m.replace("_score","") for m in mes_names], rotation=45, ha="right")
        plt.colorbar(im, ax=ax, label="Mean MES score", shrink=0.8)
    ax.set_title("SEA-AD: Age-stratified MES means (donor-level)")
    save_fig(FIG_DIR / "Main_Fig6A_AgeStrat_MES_Means.png", fig)

    return df_corr, df_inter


# PART B: NEUROPATHOLOGY / CLINICAL PROGRESSION

def partB_status(sea_d, covars):
    """Compare donor-level features between the two status groups (Figure 6B).

    The group pair is chosen from the first ``status_label`` value:
    ``Control``/``AD`` when it equals ``"Clinical"``, otherwise
    ``LowPath``/``HighPath``.  For every MES column present, plus the
    tolerance, GR, irisin, exercise-module and ``NT_*`` columns when present,
    ``cohen_d_with_ci`` is called with the first group as its first argument
    and the second group as its second.  The returned effect size, CI bounds
    and ``mw_*`` statistics are tabulated, ``mw_p`` is adjusted with
    ``bh_fdr``, and the table is drawn as a bar chart and written to Excel.

    Parameters
    ----------
    sea_d : pandas.DataFrame
        Donor-level table.  Nothing is computed unless it has a ``status``
        column.
    covars : list
        Not used by this part; kept so all parts share one call signature.

    Returns
    -------
    pandas.DataFrame
        One row per feature, sorted by ``d`` in descending order.  Empty when
        ``status`` is missing, when either group has fewer than
        ``MIN_DONORS_PER_GROUP`` donors, or when no feature column exists.
    """
    log("=" * 50)
    log("PART B: Status Progression")

    if "status" not in sea_d.columns:
        log("[B] No status column; skipping.")
        return pd.DataFrame()

    # A missing status_label column falls back to the pathology grouping
    # instead of raising KeyError.
    has_label = "status_label" in sea_d.columns and len(sea_d) > 0
    label = str(sea_d["status_label"].iloc[0]) if has_label else "NA"

    if label == "Clinical":
        g1, g2 = "Control", "AD"
        ylab = "Cohen's d (AD − Control)"
    else:
        g1, g2 = "LowPath", "HighPath"
        ylab = "Cohen's d (HighPath − LowPath)"

    sub = sea_d[sea_d["status"].astype(str).isin([g1, g2])].copy()
    # Group masks are computed once and reused for counting and slicing.
    status = sub["status"].astype(str)
    in_g1, in_g2 = status == g1, status == g2
    n1, n2 = int(in_g1.sum()), int(in_g2.sum())
    if n1 < MIN_DONORS_PER_GROUP or n2 < MIN_DONORS_PER_GROUP:
        log(f"[B] Too few donors ({g1}={n1}, {g2}={n2}); skipping.")
        return pd.DataFrame()

    # Features to test: MES scores first, then the optional extras.
    features = [c for c in MES_COLS if c in sub.columns]
    extra = []
    if tol_col and tol_col in sub.columns:
        extra.append(tol_col)
    if GR_COL in sub.columns:
        extra.append(GR_COL)
    if "IrisinScore" in sub.columns:
        extra.append("IrisinScore")
    if "ExerciseModuleScore" in sub.columns:
        extra.append("ExerciseModuleScore")
    extra.extend(f"NT_{k}" for k in NT_SIGS if f"NT_{k}" in sub.columns)
    features += extra

    rows = []
    for feat in features:
        a = pd.to_numeric(sub.loc[in_g1, feat], errors="coerce").to_numpy()
        b = pd.to_numeric(sub.loc[in_g2, feat], errors="coerce").to_numpy()
        res = cohen_d_with_ci(a, b)
        rows.append({
            "feature": feat.replace("_score",""),
            "d": res["d"], "d_ci_lo": res["ci_lo"], "d_ci_hi": res["ci_hi"],
            "mw_U": res["mw_U"], "mw_p": res["mw_p"],
            "mean_g1": res["mean_x"], "mean_g2": res["mean_y"],
            "sd_g1": res["sd_x"], "sd_g2": res["sd_y"],
            "n_g1": res["n_x"], "n_g2": res["n_y"],
            "group1": g1, "group2": g2
        })

    df = pd.DataFrame(rows)
    if df.empty:
        # Without any feature column there is no "d" column to sort or plot.
        log("[B] No feature columns available; skipping.")
        return df
    df["mw_q_BH"] = bh_fdr(df["mw_p"].values)
    df = df.sort_values("d", ascending=False)

    # Figure 6B: Cohen's d with CIs and significance
    fig, ax = plt.subplots(figsize=(9.0, 3.5))
    x_pos = np.arange(len(df))
    d_vals = df["d"].values
    # Positive d takes the colour of the second group, everything else the first.
    colors = [PALETTE_STATUS.get(g2, "#D6604D") if d > 0 else PALETTE_STATUS.get(g1, "#4393C3")
              for d in d_vals]

    ax.bar(x_pos, d_vals, color=colors, edgecolor="black", linewidth=0.5, alpha=0.85)

    # Error bars from the CI bounds returned by cohen_d_with_ci
    yerr_lo = d_vals - df["d_ci_lo"].values
    yerr_hi = df["d_ci_hi"].values - d_vals
    ax.errorbar(x_pos, d_vals,
                yerr=[yerr_lo, yerr_hi],
                fmt="none", ecolor="black", elinewidth=0.8, capsize=2.5, capthick=0.8)

    # Significance stars, placed just beyond the CI whisker on the bar's side of zero
    for i, (_, row) in enumerate(df.iterrows()):
        star = sig_stars(row.get("mw_q_BH", row.get("mw_p")))
        if star != "ns":
            y_off = row["d_ci_hi"] + 0.03 if row["d"] > 0 else row["d_ci_lo"] - 0.06
            ax.text(i, y_off, star, ha="center", va="bottom" if row["d"]>0 else "top",
                    fontsize=7, fontweight="bold")

    ax.set_xticks(x_pos)
    ax.set_xticklabels(df["feature"].values, rotation=55, ha="right")
    ax.axhline(0, color="black", linewidth=0.8)
    ax.set_ylabel(ylab)
    # Counts are listed in the same order as the group names (g2 first, then g1).
    ax.set_title(f"SEA-AD: {g2} vs {g1} (n={n2} vs {n1}) | bars=Cohen's d, whiskers=95% CI")
    if HAS_SNS: sns.despine(ax=ax)
    save_fig(FIG_DIR / f"Main_Fig6B_{'AD' if label=='Clinical' else 'Pathology'}_Effects_CohenD.png", fig)

    tab_name = "Supplementary_Table6_AD_Differential.xlsx" if label=="Clinical" else "Supplementary_Table6_Pathology_Differential.xlsx"
    save_xlsx(df, TAB_DIR / tab_name)

    return df


# PART C: FNDC5 / IRISIN DEEP DIVE

def _irisin_fit_line(ax, x, y):
    """Draw a dashed first-degree ``np.polyfit`` line on *ax* when more than five (x, y) pairs are finite."""
    ok = np.isfinite(x) & np.isfinite(y)
    if ok.sum() > 5:
        coef = np.polyfit(x[ok], y[ok], 1)
        xline = np.linspace(np.nanmin(x[ok]), np.nanmax(x[ok]), 100)
        ax.plot(xline, np.polyval(coef, xline), color="black", linewidth=1.0, linestyle="--")


def partC_irisin(sea, sea_d, covars):
    """Tabulate IrisinScore associations and draw Figure 6C.

    Rows are collected from four steps:

    * C1 - Spearman correlation of ``IrisinScore`` with each MES column, and
      with the tolerance column, inside each of the age groups ``<65``,
      ``65-79`` and ``>=80`` that has at least ``MIN_DONORS_PER_GROUP`` donors.
    * C2 - raw and covariate-adjusted (``partial_corr``) correlation with the
      GR and tolerance columns over all donors.
    * C3 - correlation of ``FNDC5_donor_mean`` with ``IrisinScore``.
    * C4 - ``permutation_test_corr`` p-value for IrisinScore vs tolerance.

    All ``p`` values in the table are then adjusted together with ``bh_fdr``.
    The donor table is always written to ``NB6_SEA_AD_DonorTable.csv``.

    Parameters
    ----------
    sea : object
        Not used by this part.
    sea_d : pandas.DataFrame
        Donor-level table.
    covars : list
        Covariate column names passed to ``partial_corr``.

    Returns
    -------
    pandas.DataFrame
        Association table; empty when ``IrisinScore`` is not a column.
    """
    log("=" * 50)
    log("PART C: FNDC5 / Irisin Deep Dive")

    donor_csv = TAB_DIR / "NB6_SEA_AD_DonorTable.csv"

    if "IrisinScore" not in sea_d.columns:
        # Every step below needs IrisinScore; the donor table export does not.
        log("[C] IrisinScore missing; skipping irisin analyses.")
        sea_d.to_csv(donor_csv, index=False)
        log(f"SAVED {donor_csv}")
        return pd.DataFrame()

    has_tol = bool(tol_col) and tol_col in sea_d.columns
    rows = []

    # C1: IrisinScore vs MES/tolerance by age group with stats
    if "age_group" in sea_d.columns:
        for ag in ["<65", "65-79", ">=80"]:
            sub = sea_d[sea_d["age_group"].astype(str)==ag]
            if sub.shape[0] < MIN_DONORS_PER_GROUP:
                continue
            for mes in [c for c in MES_COLS if c in sub.columns]:
                raw = corr_with_pvalue(sub["IrisinScore"], sub[mes], method="spearman")
                rows.append({
                    "stratum": "age_group", "group": ag, "target": mes.replace("_score",""),
                    "assoc": "Irisin×MES",
                    "r": raw["r"], "p": raw["p"],
                    "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                    "N": raw["n"]
                })
            if has_tol:
                raw = corr_with_pvalue(sub["IrisinScore"], sub[tol_col], method="spearman")
                rows.append({
                    "stratum": "age_group", "group": ag, "target": tol_col,
                    "assoc": "Irisin×tolerance",
                    "r": raw["r"], "p": raw["p"],
                    "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                    "N": raw["n"]
                })

    # C2: Overall correlations
    for target_name, target_col in [("GR", GR_COL), ("tolerance", tol_col)]:
        if target_col and target_col in sea_d.columns:
            raw = corr_with_pvalue(sea_d["IrisinScore"], sea_d[target_col], method="spearman")
            pcr = partial_corr(sea_d, "IrisinScore", target_col, covars, method="spearman")
            rows.append({
                "stratum": "all", "group": "all", "target": target_name,
                "assoc": f"Irisin×{target_name}",
                "r": raw["r"], "p": raw["p"],
                "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                "r_partial": pcr["r"], "p_partial": pcr["p"],
                "N": raw["n"]
            })

    # C3: FNDC5_donor_mean vs IrisinScore (skipped when the column is absent)
    if "FNDC5_donor_mean" in sea_d.columns:
        raw = corr_with_pvalue(sea_d["FNDC5_donor_mean"], sea_d["IrisinScore"], method="spearman")
        rows.append({
            "stratum": "all", "group": "all", "target": "FNDC5_donor_mean",
            "assoc": "FNDC5_expr×IrisinScore",
            "r": raw["r"], "p": raw["p"],
            "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
            "N": raw["n"]
        })
    else:
        log("[C] FNDC5_donor_mean missing; skipping C3.")

    # C4: Permutation test for key irisin-tolerance link
    if has_tol:
        perm_p = permutation_test_corr(sea_d["IrisinScore"], sea_d[tol_col],
                                       method="spearman", n_perm=N_PERM)
        rows.append({
            "stratum": "all", "group": "all", "target": "tolerance_permutation",
            "assoc": "Irisin×tolerance (permutation)",
            "r": np.nan, "p": perm_p,
            "ci_lo": np.nan, "ci_hi": np.nan,
            "N": int(sea_d.shape[0])
        })
        log(f"  Irisin×tolerance permutation p = {perm_p:.4f}")

    df_ir = pd.DataFrame(rows)
    if len(df_ir) and "p" in df_ir.columns:
        df_ir["q_BH"] = bh_fdr(df_ir["p"].values)

    # Figure 6C: IrisinScore vs tolerance by status
    has_label = "status_label" in sea_d.columns and len(sea_d) > 0
    label = str(sea_d["status_label"].iloc[0]) if has_label else "NA"
    g1 = "Control" if label=="Clinical" else "LowPath"
    g2 = "AD" if label=="Clinical" else "HighPath"
    # Without a status column the figure falls back to the single-panel version.
    has_groups = "status" in sea_d.columns and sea_d["status"].astype(str).isin([g1, g2]).any()

    if has_tol and has_groups:
        sub = sea_d[sea_d["status"].astype(str).isin([g1, g2])].copy()
        fig, axes = plt.subplots(1, 2, figsize=(8.5, 3.5), sharey=True)

        # Same fallback colours as the Figure 6B bars when a group is not in the palette.
        panel_colors = [PALETTE_STATUS.get(g1, "#4393C3"), PALETTE_STATUS.get(g2, "#D6604D")]
        for ax, grp, color in zip(axes, [g1, g2], panel_colors):
            ss = sub[sub["status"].astype(str)==grp]
            x = pd.to_numeric(ss["IrisinScore"], errors="coerce")
            y = pd.to_numeric(ss[tol_col], errors="coerce")
            ax.scatter(x, y, s=22, alpha=0.7, color=color, edgecolors="black", linewidth=0.3)

            _irisin_fit_line(ax, x, y)

            raw = corr_with_pvalue(x, y, method="spearman")
            star = sig_stars(raw["p"])
            ax.set_title(f"{grp} (n={ss.shape[0]})\nρ={raw['r']:.2f} [{raw['ci_lo']:.2f}, {raw['ci_hi']:.2f}] {star}",
                        fontsize=8)
            ax.set_xlabel("IrisinScore")
            ax.axhline(0, color="grey", linewidth=0.5, linestyle=":")
            if HAS_SNS: sns.despine(ax=ax)

        axes[0].set_ylabel(tol_col)
        fig.suptitle("SEA-AD: IrisinScore vs Tolerance (donor-level)", fontsize=9, y=1.02)
        fig.tight_layout()
        save_fig(FIG_DIR / f"Main_Fig6C_Irisin_vs_Tolerance_{g1}_vs_{g2}.png", fig)

    elif has_tol:
        # Single panel
        fig, ax = plt.subplots(figsize=(5.5, 4.5))
        x = pd.to_numeric(sea_d["IrisinScore"], errors="coerce")
        y = pd.to_numeric(sea_d[tol_col], errors="coerce")
        ax.scatter(x, y, s=30, alpha=0.7, color="#4393C3", edgecolors="black", linewidth=0.3)
        _irisin_fit_line(ax, x, y)
        raw = corr_with_pvalue(x, y, method="spearman")
        star = sig_stars(raw["p"])
        ax.set_title(f"All donors (n={raw['n']})\nρ={raw['r']:.2f} [{raw['ci_lo']:.2f}, {raw['ci_hi']:.2f}] {star}")
        ax.set_xlabel("IrisinScore"); ax.set_ylabel(tol_col)
        if HAS_SNS: sns.despine(ax=ax)
        save_fig(FIG_DIR / "Main_Fig6C_Irisin_vs_Tolerance_All.png", fig)

    # Save donor table
    sea_d.to_csv(donor_csv, index=False)
    log(f"SAVED {donor_csv}")

    return df_ir


# PART D: NEUROTRANSMITTER SIGNATURES 

def partD_neurotransmitters(sea_d, covars):
    """Correlate each ``NT_*`` signature column with each MES column (Supp. Fig. 6D).

    Spearman correlations come from ``corr_with_pvalue``; their ``p`` values
    are adjusted together with ``bh_fdr``.  The correlation matrix is drawn as
    a heatmap; with seaborn available, cells whose q-value maps to something
    other than ``"ns"`` in ``sig_stars`` get a marker.

    Parameters
    ----------
    sea_d : pandas.DataFrame
        Donor-level table.
    covars : list
        Not used by this part; kept so all parts share one call signature.

    Returns
    -------
    pandas.DataFrame
        One row per (NT, MES) pair; empty when no ``NT_*`` column or no MES
        column is present.
    """
    log("=" * 50)
    log("PART D: Neurotransmitter Signatures")

    nt_cols = [f"NT_{k}" for k in NT_SIGS if f"NT_{k}" in sea_d.columns]
    if not nt_cols:
        log("[D] No NT columns; skipping.")
        return pd.DataFrame()

    mes_cols = [c for c in MES_COLS if c in sea_d.columns]

    rows = []
    for nt in nt_cols:
        for mes in mes_cols:
            raw = corr_with_pvalue(sea_d[nt], sea_d[mes], method="spearman")
            rows.append({"NT": nt, "MES": mes.replace("_score",""),
                         "r": raw["r"], "p": raw["p"],
                         "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                         "N": raw["n"]})
    df = pd.DataFrame(rows)
    if df.empty:
        # No MES columns means there is nothing to pivot or plot.
        log("[D] No MES columns; skipping.")
        return df
    df["q_BH"] = bh_fdr(df["p"].values)

    # Heatmap with significance
    piv_r = df.pivot_table(index="NT", columns="MES", values="r", aggfunc="mean")
    piv_q = df.pivot_table(index="NT", columns="MES", values="q_BH", aggfunc="mean")

    fig, ax = plt.subplots(figsize=(8.0, 2.8))
    if HAS_SNS:
        sns.heatmap(piv_r, ax=ax, annot=True, fmt=".2f", cmap="RdBu_r", center=0,
                    linewidths=0.5, linecolor="white",
                    cbar_kws={"label": "Spearman ρ", "shrink": 0.8})
        # Add significance markers below the annotated value in each cell.
        for i, nt in enumerate(piv_r.index):
            for j, mes in enumerate(piv_r.columns):
                # piv_q can lack a label that piv_r has, so look it up defensively.
                q = piv_q.loc[nt, mes] if nt in piv_q.index and mes in piv_q.columns else np.nan
                s = sig_stars(q)
                if s != "ns":
                    ax.text(j+0.5, i+0.82, s, ha="center", va="center", fontsize=6, color="black")
    else:
        im = ax.imshow(piv_r.values, aspect="auto", cmap="RdBu_r")
        ax.set_yticks(range(piv_r.shape[0])); ax.set_yticklabels(piv_r.index)
        ax.set_xticks(range(piv_r.shape[1])); ax.set_xticklabels(piv_r.columns, rotation=45, ha="right")
        plt.colorbar(im, ax=ax, label="Spearman ρ", shrink=0.8)
    ax.set_title("SEA-AD: Neurotransmitter signatures vs MES (donor-level)")
    save_fig(FIG_DIR / "Supp_Fig6D_NT_Sigs.png", fig)

    return df

# PART E: EXERCISE MODULE 

def partE_exercise(sea_d, covars):
    """Correlate ``ExerciseModuleScore`` with each MES column (Supp. Fig. 6E).

    For every MES column the raw Spearman correlation (``corr_with_pvalue``)
    and the covariate-adjusted one (``partial_corr``) are recorded.  Raw
    ``p`` values are adjusted with ``bh_fdr``, and the correlations are drawn
    as a bar chart with CI whiskers and significance markers.

    Parameters
    ----------
    sea_d : pandas.DataFrame
        Donor-level table.
    covars : list
        Covariate column names passed to ``partial_corr``.

    Returns
    -------
    pandas.DataFrame
        One row per MES column sorted by ``r`` (descending); empty when
        ``ExerciseModuleScore`` or all MES columns are missing.
    """
    log("=" * 50)
    log("PART E: Exercise Module")

    if "ExerciseModuleScore" not in sea_d.columns:
        log("[E] ExerciseModuleScore missing; skipping.")
        return pd.DataFrame()

    rows = []
    for mes in [c for c in MES_COLS if c in sea_d.columns]:
        raw = corr_with_pvalue(sea_d["ExerciseModuleScore"], sea_d[mes], method="spearman")
        pcr = partial_corr(sea_d, "ExerciseModuleScore", mes, covars, method="spearman")
        rows.append({"MES": mes.replace("_score",""),
                     "r": raw["r"], "p": raw["p"],
                     "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                     "r_partial": pcr["r"], "p_partial": pcr["p"],
                     "N": raw["n"]})
    df = pd.DataFrame(rows)
    if df.empty:
        # No MES columns: there is no "r" column to sort or plot.
        log("[E] No MES columns; skipping.")
        return df
    df["q_BH"] = bh_fdr(df["p"].values)
    df = df.sort_values("r", ascending=False)

    fig, ax = plt.subplots(figsize=(7.0, 3.0))
    # Bars, whiskers and stars all use the same numeric positions, so the
    # tick labels can be pinned with set_xticks.
    x_pos = np.arange(len(df))
    r_vals = df["r"].values
    colors = ["#66C2A5" if r > 0 else "#FC8D62" for r in r_vals]
    ax.bar(x_pos, r_vals, color=colors, edgecolor="black", linewidth=0.5)
    yerr_lo = r_vals - df["ci_lo"].values
    yerr_hi = df["ci_hi"].values - r_vals
    ax.errorbar(x_pos, r_vals,
                yerr=[yerr_lo, yerr_hi],
                fmt="none", ecolor="black", elinewidth=0.8, capsize=2.5)
    for i, (_, row) in enumerate(df.iterrows()):
        s = sig_stars(row.get("q_BH", row["p"]))
        if s != "ns":
            y = row["ci_hi"] + 0.02 if row["r"] > 0 else row["ci_lo"] - 0.03
            ax.text(i, y, s, ha="center", va="bottom" if row["r"]>0 else "top", fontsize=7, fontweight="bold")
    ax.axhline(0, color="black", linewidth=0.8)
    ax.set_ylabel("Spearman ρ"); ax.set_title("SEA-AD: Exercise module × MES (donor-level)")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(df["MES"], rotation=45, ha="right")
    if HAS_SNS: sns.despine(ax=ax)
    save_fig(FIG_DIR / "Supp_Fig6E_Exercise_Module.png", fig)

    return df

# PART F: CROSS-COHORT VALIDATION

def partF_cross_cohort(cohorts, covars_sea):
    """Compute status-group effect sizes per cohort and draw the GR forest plot.

    Each cohort is turned into a donor table with ``make_donor_df``.  For
    every MES column present, plus ``GR_COL`` and ``IrisinScore`` when
    present, ``cohen_d_with_ci`` compares the two status groups if both have
    at least ``MIN_DONORS_PER_GROUP`` donors; otherwise the row carries NaN
    statistics.  Each row also stores the cohort's IrisinScore-vs-tolerance
    Spearman ``r``.  ``mw_p`` is adjusted with ``bh_fdr`` over all rows.

    Parameters
    ----------
    cohorts : dict
        Maps cohort name to the object accepted by ``make_donor_df``.
    covars_sea : list
        Not used by this part; kept for call compatibility.

    Returns
    -------
    pandas.DataFrame
        One row per (cohort, feature); empty when no rows were produced.
    """
    log("=" * 50)
    log("PART F: Cross-Cohort Validation")

    all_rows = []

    for cname, adata in cohorts.items():
        log(f"  Processing {cname} ...")
        ddf, meta = make_donor_df(adata, cname)
        local_tol = meta["tol_col"]
        label = str(ddf["status_label"].iloc[0]) if len(ddf) else "NA"

        if label == "Clinical":
            g1, g2 = "Control", "AD"
        else:
            g1, g2 = "LowPath", "HighPath"

        status = ddf["status"].astype(str)
        in_g1, in_g2 = status == g1, status == g2
        n1, n2 = int(in_g1.sum()), int(in_g2.sum())
        enough_donors = n1 >= MIN_DONORS_PER_GROUP and n2 >= MIN_DONORS_PER_GROUP

        # Irisin-tolerance correlation: it does not depend on the feature, so
        # it is computed once per cohort and copied into every row.
        ir_tol_r = np.nan
        if local_tol and local_tol in ddf.columns and "IrisinScore" in ddf.columns:
            ir_tol_r = corr_with_pvalue(ddf["IrisinScore"], ddf[local_tol], method="spearman")["r"]

        # Effect sizes for MES + key features
        features = [c for c in MES_COLS if c in ddf.columns]
        if GR_COL in ddf.columns:
            features.append(GR_COL)
        if "IrisinScore" in ddf.columns:
            features.append("IrisinScore")

        for feat in features:
            if enough_donors:
                a = pd.to_numeric(ddf.loc[in_g1, feat], errors="coerce").to_numpy()
                b = pd.to_numeric(ddf.loc[in_g2, feat], errors="coerce").to_numpy()
                res = cohen_d_with_ci(a, b)
            else:
                res = {"d": np.nan, "ci_lo": np.nan, "ci_hi": np.nan,
                       "mw_p": np.nan, "n_x": n1, "n_y": n2}

            all_rows.append({
                "cohort": cname, "feature": feat.replace("_score",""),
                "status_type": label, "g1": g1, "g2": g2,
                "d": res["d"], "d_ci_lo": res["ci_lo"], "d_ci_hi": res["ci_hi"],
                "mw_p": res.get("mw_p", np.nan),
                "n_g1": res.get("n_x", n1), "n_g2": res.get("n_y", n2),
                "n_donors": int(len(ddf)),
                "irisin_tol_r": ir_tol_r
            })

    df = pd.DataFrame(all_rows)
    if len(df):
        df["mw_q_BH"] = bh_fdr(df["mw_p"].values)

    # Forest plot: Cohen's d across cohorts for GR_composite.
    # An empty table has no "feature" column, so only filter when rows exist.
    gr_df = df[df["feature"]==GR_COL.replace("_score","")].copy() if len(df) else pd.DataFrame()
    if len(gr_df) > 1:
        gr_df = gr_df.sort_values("d")
        fig, ax = plt.subplots(figsize=(6.0, max(2.5, 0.4*len(gr_df))))
        y_pos = np.arange(len(gr_df))

        for i, (_, row) in enumerate(gr_df.iterrows()):
            color = "#D6604D" if row["d"] > 0 else "#4393C3"
            ax.errorbar(row["d"], i,
                       xerr=[[row["d"]-row["d_ci_lo"]], [row["d_ci_hi"]-row["d"]]],
                       fmt="o", color=color, ecolor="black", elinewidth=0.8,
                       capsize=3, markersize=6)
            star = sig_stars(row.get("mw_q_BH", row["mw_p"]))
            ax.text(row["d_ci_hi"]+0.05, i, f'{star}  n={row["n_g1"]}+{row["n_g2"]}',
                   va="center", fontsize=7)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(gr_df["cohort"].values)
        ax.axvline(0, color="black", linewidth=0.8, linestyle="--")
        ax.set_xlabel("Cohen's d (High − Low pathology)")
        ax.set_title("Cross-cohort: GR composite effect sizes")
        if HAS_SNS: sns.despine(ax=ax)
        save_fig(FIG_DIR / "Supp_Fig6F_CrossCohort_ForestPlot.png", fig)

    # Save
    save_xlsx(df, TAB_DIR / "Supplementary_Table6_CrossCohort.xlsx")
    df.to_csv(TAB_DIR / "NB6_CrossCohort_Summary.csv", index=False)
    log(f"SAVED {TAB_DIR / 'NB6_CrossCohort_Summary.csv'}")

    return df

# PART G: SENSITIVITY ANALYSES

def partG_sensitivity(sea, sea_d_orig, covars):
    """Run the three sensitivity analyses and write them to one workbook.

    * G1 - rebuilds the donor table with ``make_donor_df`` under four
      Braak/CERAD threshold schemes and recomputes ``cohen_d_with_ci``
      (``n_boot=500``) for the MES columns and ``GR_COL``.
    * G2 - leave-one-out: recomputes the GR-vs-tolerance Spearman ``r`` with
      each donor row removed in turn, and plots the distribution.
    * G3 - compares correlations of each MES column with ``IrisinScore`` and
      with ``IrisinScore_Z``.

    Parameters
    ----------
    sea : object
        Passed to ``make_donor_df`` for the threshold schemes.
    sea_d_orig : pandas.DataFrame
        Donor-level table used for G2 and G3.
    covars : list
        Not used by this part; kept so all parts share one call signature.

    Returns
    -------
    dict
        Keys ``"thresholds"``, ``"loo_GR_tol"`` and ``"scoring_robustness"``,
        each mapping to a DataFrame (possibly empty).
    """
    log("=" * 50)
    log("PART G: Sensitivity Analyses")

    results = {}

    # G1: Different Braak/CERAD thresholds
    threshold_schemes = [
        {"name": "Braak≥3/CERAD≥1 (liberal)",  "braak": 3, "cerad": 1},
        {"name": "Braak≥4/CERAD≥2 (default)",   "braak": 4, "cerad": 2},
        {"name": "Braak≥5/CERAD≥2 (strict)",    "braak": 5, "cerad": 2},
        {"name": "Braak≥5/CERAD≥3 (very strict)","braak": 5, "cerad": 3},
    ]

    thresh_rows = []
    for scheme in threshold_schemes:
        ddf, meta = make_donor_df(sea, "SEA-AD",
                                  braak_thresh=scheme["braak"],
                                  cerad_thresh=scheme["cerad"])
        label = str(ddf["status_label"].iloc[0]) if len(ddf) else "NA"
        g1 = "Control" if label=="Clinical" else "LowPath"
        g2 = "AD" if label=="Clinical" else "HighPath"

        status = ddf["status"].astype(str)
        in_g1, in_g2 = status == g1, status == g2
        n1, n2 = int(in_g1.sum()), int(in_g2.sum())
        enough_donors = n1 >= MIN_DONORS_PER_GROUP and n2 >= MIN_DONORS_PER_GROUP

        features = [c for c in MES_COLS if c in ddf.columns]
        if GR_COL in ddf.columns:
            features.append(GR_COL)

        for feat in features:
            if enough_donors:
                a = pd.to_numeric(ddf.loc[in_g1, feat], errors="coerce").to_numpy()
                b = pd.to_numeric(ddf.loc[in_g2, feat], errors="coerce").to_numpy()
                res = cohen_d_with_ci(a, b, n_boot=500)  # fewer boots for speed
            else:
                res = {"d": np.nan, "ci_lo": np.nan, "ci_hi": np.nan, "mw_p": np.nan}

            thresh_rows.append({
                "scheme": scheme["name"],
                "braak_thresh": scheme["braak"], "cerad_thresh": scheme["cerad"],
                "feature": feat.replace("_score",""),
                "d": res["d"], "d_ci_lo": res["ci_lo"], "d_ci_hi": res["ci_hi"],
                "mw_p": res.get("mw_p", np.nan),
                "n_low": n1, "n_high": n2
            })

    df_thresh = pd.DataFrame(thresh_rows)
    if len(df_thresh):
        df_thresh["mw_q_BH"] = bh_fdr(df_thresh["mw_p"].values)
    results["thresholds"] = df_thresh

    # G2: Leave-one-out donor stability
    n_donors = len(sea_d_orig)
    has_loo_cols = GR_COL in sea_d_orig.columns and bool(tol_col) and tol_col in sea_d_orig.columns
    if has_loo_cols and n_donors > 0:
        base_r = corr_with_pvalue(sea_d_orig[GR_COL], sea_d_orig[tol_col], method="spearman")["r"]
        if "donor_id" in sea_d_orig.columns:
            donor_ids = sea_d_orig["donor_id"].tolist()
        else:
            donor_ids = list(range(n_donors))
        positions = np.arange(n_donors)

        loo_rows = []
        for i, donor in enumerate(donor_ids):
            # Drop by position, not by index label: a repeated label would
            # otherwise remove more than one row per iteration.
            sub = sea_d_orig.iloc[positions != i]
            r_loo = corr_with_pvalue(sub[GR_COL], sub[tol_col], method="spearman")["r"]
            loo_rows.append({
                "dropped_donor": donor,
                "r_spearman_loo": r_loo,
                "delta_r": r_loo - base_r
            })
        df_loo = pd.DataFrame(loo_rows)
        results["loo_GR_tol"] = df_loo

        log(f"  LOO GR×tolerance: base r={base_r:.3f}, "
            f"range=[{df_loo['r_spearman_loo'].min():.3f}, {df_loo['r_spearman_loo'].max():.3f}]")

        # LOO plot (a histogram cannot be binned when every value is NaN)
        if df_loo["r_spearman_loo"].notna().any():
            fig, ax = plt.subplots(figsize=(6.0, 3.0))
            ax.hist(df_loo["r_spearman_loo"], bins=20, color="#4393C3", edgecolor="black", linewidth=0.5, alpha=0.8)
            ax.axvline(base_r, color="red", linewidth=1.2, linestyle="--", label=f"Full sample ρ={base_r:.3f}")
            ax.set_xlabel("Spearman ρ (LOO)")
            ax.set_ylabel("Count")
            ax.set_title("Leave-one-out stability: GR × Tolerance")
            ax.legend(fontsize=7)
            if HAS_SNS: sns.despine(ax=ax)
            save_fig(FIG_DIR / "Supp_Fig6G_LOO_Stability.png", fig)
        else:
            log("  LOO GR×tolerance: no finite correlations; skipping plot.")
    else:
        results["loo_GR_tol"] = pd.DataFrame()

    # G3: Scoring robustness (scanpy vs Z-score)
    rob_rows = []
    if "IrisinScore" in sea_d_orig.columns and "IrisinScore_Z" in sea_d_orig.columns:
        for mes in [c for c in MES_COLS if c in sea_d_orig.columns]:
            r1 = corr_with_pvalue(sea_d_orig["IrisinScore"], sea_d_orig[mes], method="spearman")
            r2 = corr_with_pvalue(sea_d_orig["IrisinScore_Z"], sea_d_orig[mes], method="spearman")
            both_finite = np.isfinite(r1["r"]) and np.isfinite(r2["r"])
            rob_rows.append({
                "MES": mes.replace("_score",""),
                "r_scanpy": r1["r"], "p_scanpy": r1["p"],
                "r_zscore": r2["r"], "p_zscore": r2["p"],
                "delta_r": abs(r1["r"] - r2["r"]) if both_finite else np.nan
            })
    results["scoring_robustness"] = pd.DataFrame(rob_rows)

    # Save all sensitivity results
    save_xlsx_multi({
        "Threshold_Sensitivity": df_thresh,
        "LOO_Stability": results["loo_GR_tol"],
        "Scoring_Robustness": results["scoring_robustness"],
    }, TAB_DIR / "Supplementary_Table6_Sensitivity.xlsx")

    return results


# MAIN EXECUTION

# Announce the start of the run in the notebook log.
log("=" * 60)
log("EXECUTING ALL PARTS")
log("=" * 60)

# Accumulator that push_summary() appends to; turned into a DataFrame and
# saved as Main_Table6 once all parts have run.
summary_rows = []

def push_summary(label, df):
    """Append a one-row overview of donor table *df* to ``summary_rows``.

    The row records the donor count, the status label (first
    ``status_label`` value, or ``"NA"``), the size of each of the two status
    groups implied by that label, and the covariate / column settings read
    from the module-level ``covars``, ``sea_meta`` and ``tol_col``.
    """
    status_kind = "NA"
    if "status_label" in df.columns and len(df):
        status_kind = str(df["status_label"].iloc[0])

    group_names = ("Control", "AD") if status_kind == "Clinical" else ("LowPath", "HighPath")

    # A missing status column counts as zero donors in both groups.
    status_text = df.get("status", pd.Series(dtype=object)).astype(str)

    record = {
        "label": label,
        "N_donors": int(df.shape[0]),
        "status_label": status_kind,
    }
    for name in group_names:
        record[f"N_{name}"] = int((status_text == name).sum())
    record["covariates"] = ",".join(covars)
    for out_key, meta_key in (("age_col", "age_col"),
                              ("dx_col", "dx_col"),
                              ("donor_col", "donor_col")):
        record[out_key] = sea_meta.get(meta_key)
    record["tol_col"] = tol_col
    record["braak_col"] = sea_meta.get("sev_braak")
    record["cerad_col"] = sea_meta.get("sev_cerad")

    summary_rows.append(record)

# Record the overview row for the SEA-AD donor table before any part runs.
push_summary("SEA-AD donor table", sea_d)

# Each part below is gated by its RUN_PART_* flag. The result variable is
# pre-initialised to an empty frame so it exists even when the part is skipped.

# Part A: age-stratified correlation table and GR x age interaction table
df_age_corr, df_age_inter = pd.DataFrame(), pd.DataFrame()
if RUN_PART_A:
    df_age_corr, df_age_inter = partA_age(sea_d, covars)
    save_xlsx(df_age_corr, TAB_DIR / "Supplementary_Table6_AgeStrat_Corr.xlsx")
    save_xlsx(df_age_inter, TAB_DIR / "Supplementary_Table6_Age_GR_MES_Interaction.xlsx")

# Part B: status-group effect sizes (partB_status saves its own table)
df_status = pd.DataFrame()
if RUN_PART_B:
    df_status = partB_status(sea_d, covars)

# Part C: irisin associations
df_irisin = pd.DataFrame()
if RUN_PART_C:
    df_irisin = partC_irisin(sea, sea_d, covars)
    save_xlsx(df_irisin, TAB_DIR / "Supplementary_Table6_Irisin.xlsx")

# Part D: neurotransmitter signatures vs MES
df_nt = pd.DataFrame()
if RUN_PART_D:
    df_nt = partD_neurotransmitters(sea_d, covars)
    save_xlsx(df_nt, TAB_DIR / "Supplementary_Table6_NT_Signatures.xlsx")

# Part E: exercise module vs MES
df_ex = pd.DataFrame()
if RUN_PART_E:
    df_ex = partE_exercise(sea_d, covars)
    save_xlsx(df_ex, TAB_DIR / "Supplementary_Table6_ExerciseModule.xlsx")

# Part F: cross-cohort effect sizes (partF_cross_cohort saves its own tables)
df_cross = pd.DataFrame()
if RUN_PART_F:
    df_cross = partF_cross_cohort(cohorts, covars)

# Part G: sensitivity analyses (partG_sensitivity saves its own workbook)
sensitivity_results = {}
if RUN_PART_G:
    sensitivity_results = partG_sensitivity(sea, sea_d, covars)

# Main summary table built from the rows collected by push_summary()
df_main6 = pd.DataFrame(summary_rows)
save_xlsx(df_main6, TAB_DIR / "Main_Table6_Age_Path_Summary.xlsx")

# Closing banner: output locations and the key figures produced above
log("=" * 60)
log("NB6 COMPLETED")
log("=" * 60)
log(f"Figures: {FIG_DIR}")
log(f"Tables:  {TAB_DIR}")
log("Key outputs:")
log("  Main_Fig6A — Age-stratified MES heatmap")
log("  Main_Fig6B — Cohen's d with CIs + significance")
log("  Main_Fig6C — Irisin × tolerance by status")
log("  Supp_Fig6F — Cross-cohort forest plot")
log("  Supp_Fig6G — LOO stability")
log("  All tables include: p-values, CIs, BH-FDR q-values, partial correlations")