# NB4 — Microglial Transfer, Spatial Validation & Stress-Axis Calibration

Projection of thymus-trained myeloid education signature (MES) modules onto four independent microglial cohorts (SEA-AD, Olah, Tuddenham, MS GSE180759), Visium spatial validation, glucocorticoid-receptor (GR) stress-axis calibration (GSE219208), and LPS-tolerance model scoring.

**Paper:** Zafar SA, Qin W. *Thymus-Derived Myeloid Education Signatures Predict Microglial Tolerance Positioning and Are Modulated by Glucocorticoid Stress-Axis Activity.* Neuroimmunomodulation (2026).

> **Note:** Update the path variables in section 1 (`BASE_DIR` and the directories derived from it) to match your local directory structure before running. Raw data can be obtained from the public repositories listed in Supplementary Table S1.


In [None]:
from __future__ import annotations

import csv
import os, re, warnings, gc, time, traceback
from pathlib import Path
import numpy as np
import pandas as pd

import scanpy as sc
import anndata as ad
import scipy.sparse as sp
import scipy.stats as stats
import scipy.io
import matplotlib
import matplotlib.pyplot as plt

# Global runtime configuration for the notebook.
# NOTE(review): this hides *every* warning for the rest of the run (library deprecations
# and data-integrity warnings alike); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")
# Seeds NumPy's legacy global RNG only; the bootstrap/permutation helpers below build
# their own np.random.default_rng(seed) generators and do not depend on this call.
np.random.seed(42)
sc.settings.verbosity = 2

# seaborn is optional: HAS_SNS selects sns.heatmap for Fig 4B, otherwise an imshow fallback.
try:
    import seaborn as sns
    sns.set_style("ticks")
    HAS_SNS = True
except ImportError:
    HAS_SNS = False

# 0) GLOBAL STYLE
# Publication-style Matplotlib defaults: small Arial text, thin axes/ticks, TrueType
# (Type 42) font embedding for PDF/PS, and 1200-dpi saved figures with tight bounding boxes.
# NOTE(review): "Arial" must be installed; otherwise Matplotlib substitutes another font
# (its warning is hidden by the global warnings filter).
matplotlib.rcParams.update({
    "font.family": "Arial",
    "font.size": 8,
    "axes.labelsize": 8,
    "axes.titlesize": 9,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.titlesize": 9,
    "axes.linewidth": 0.8,
    "xtick.major.width": 0.6,
    "ytick.major.width": 0.6,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "figure.dpi": 150,
    "savefig.dpi": 1200,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.05,
})

# Resolution passed explicitly by save_fig.
FIG_DPI = 1200

# Color system
# CMAP_HEAT is the Fig 4B heatmap colormap; CMAP_CONT / CMAP_DIV are not used in this
# part of the notebook (presumably later cells).
CMAP_CONT = "cividis"
CMAP_HEAT = "RdBu_r"
CMAP_DIV  = "RdBu_r"

# Box fills for Fig 4C (Control / Treated / Washout) and the light grid colour of _beautify_axes.
C_BAR_MAIN = "#4A4A4A"
C_BAR_ACC  = "#D1495B"
C_BAR_ACC2 = "#2A9D8F"
C_GRID     = "#B0B0B0"
# Volcano point colours: significantly up, significantly down, and everything else.
C_UP  = "#C0392B"
C_DN  = "#2E7D32"
C_NEU = "#7F8C8D"

def _beautify_axes(ax):
    """Apply the house axis style: a faint horizontal grid and no top/right spines."""
    ax.grid(axis="y", alpha=0.22, color=C_GRID, linewidth=0.6)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

# 0b) STATISTICAL PARAMETERS
# Resampling defaults used by corr_with_ci_p, cohens_d_with_ci and auc_with_ci below.
N_BOOT = 2000
N_PERM = 5000
SEED   = 42
ALPHA  = 0.05

# 1) PATHS
# Every path is resolved relative to BASE_DIR; set it to the project root before running.
BASE_DIR = Path(".")  # <-- SET TO YOUR PROJECT ROOT
RAW_DIR  = BASE_DIR / "Raw Data"
PROC_DIR = BASE_DIR / "Process Data"

# Manuscript outputs: save_fig writes PNGs to FIG_DIR, save_excel writes .xlsx files to TAB_DIR.
MANUSCRIPT_DIR = BASE_DIR / "outputs" / "manuscript"
FIG_DIR = MANUSCRIPT_DIR / "Figures"
TAB_DIR = MANUSCRIPT_DIR / "Tables"

AIM2_DIR = PROC_DIR / "aim2_microglia"
AIM3_DIR = PROC_DIR / "aim3_stress"
NB4_EXPORT_DIR = PROC_DIR / "nb4_exports"

# Output folders are created (idempotently) before the RAW_DIR check, so a wrong
# BASE_DIR still leaves empty output folders behind.
for d in [FIG_DIR, TAB_DIR, AIM2_DIR, AIM3_DIR, NB4_EXPORT_DIR]:
    d.mkdir(parents=True, exist_ok=True)

assert RAW_DIR.exists(), f"RAW_DIR missing: {RAW_DIR}"

# Export folders for GSE233208 matrices (Visium and snRNA, judging by the folder names).
# mkdir without parents=True is safe because NB4_EXPORT_DIR was created above.
GSE233208_EXPORT = NB4_EXPORT_DIR / "GSE233208"
GSE233208_VIS_MTX = GSE233208_EXPORT / "visium_mtx"
GSE233208_SN_MTX  = GSE233208_EXPORT / "snrna_mtx"
GSE233208_EXPORT.mkdir(exist_ok=True)
GSE233208_VIS_MTX.mkdir(exist_ok=True)
GSE233208_SN_MTX.mkdir(exist_ok=True)

# 2) SAVE HELPERS
def save_fig(fig, fname: str, kind: str = "Supplementary"):
    """Save *fig* as ``FIG_DIR/<kind>_<fname>.png`` at FIG_DPI, then close it.

    Parameters
    ----------
    fig : Matplotlib figure to write.
    fname : file stem placed after the ``<kind>_`` prefix.
    kind : "Main" or "Supplementary"; used as the filename prefix.

    Raises
    ------
    ValueError
        If *kind* is not "Main" or "Supplementary". This is an explicit check because
        an ``assert`` disappears when Python runs with ``-O``.
    """
    if kind not in ("Main", "Supplementary"):
        raise ValueError(f"kind must be 'Main' or 'Supplementary', got {kind!r}")
    out = FIG_DIR / f"{kind}_{fname}.png"
    try:
        fig.savefig(out, dpi=FIG_DPI, bbox_inches="tight")
    finally:
        # Close even when saving fails, so a long notebook run does not accumulate open figures.
        plt.close(fig)
    print(f"[SAVED FIG] {out}")

def save_excel(sheets: dict, fname: str, kind: str = "Supplementary"):
    """Write ``{sheet_name: DataFrame}`` to ``TAB_DIR/<kind>_<fname>.xlsx`` via openpyxl.

    ``None`` values are written as empty sheets, and sheet names are truncated to
    Excel's 31-character limit. An empty *sheets* mapping produces a workbook with one
    empty sheet, because openpyxl cannot save a workbook that has no sheets.

    Raises
    ------
    ValueError
        If *kind* is not "Main" or "Supplementary". This is an explicit check, unlike
        ``assert``, which is skipped under ``python -O``.
    """
    if kind not in ("Main", "Supplementary"):
        raise ValueError(f"kind must be 'Main' or 'Supplementary', got {kind!r}")
    out = TAB_DIR / f"{kind}_{fname}.xlsx"
    if not sheets:
        sheets = {"Sheet1": None}
    with pd.ExcelWriter(out, engine="openpyxl") as w:
        for sn, df in sheets.items():
            if df is None:
                df = pd.DataFrame()
            df.to_excel(w, index=False, sheet_name=str(sn)[:31])
    print(f"[SAVED TABLE] {out}")

def sanitize_for_write(adata: ad.AnnData) -> ad.AnnData:
    """Prepare *adata* for h5ad writing by removing index/column name clashes.

    For ``obs`` and then ``var``: when the index is named and a column carries the same
    name, the column is kept if it equals the index and is otherwise renamed to
    ``<name>_col``; the index name is then cleared. Column labels, ``obs_names`` and
    ``var_names`` are cast to ``str``. Works in place and returns the same object.
    """
    for attr in ("obs", "var"):
        frame = getattr(adata, attr)
        index_name = frame.index.name
        if index_name is not None:
            if index_name in frame.columns:
                index_as_series = pd.Series(frame.index, index=frame.index)
                if not index_as_series.equals(frame[index_name]):
                    frame.rename(columns={index_name: f"{index_name}_col"}, inplace=True)
            frame.index.name = None
        frame.columns = frame.columns.astype(str)
    adata.obs_names = adata.obs_names.astype(str)
    adata.var_names = adata.var_names.astype(str)
    return adata

def sig_stars(p):
    """Map a p-value to a star label: ``***`` (<0.001), ``**`` (<0.01), ``*`` (<0.05), else ``ns``.

    Missing (None/NaN) or infinite p-values are labelled ``ns``.
    """
    if pd.isna(p) or not np.isfinite(p):
        return "ns"
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < cutoff:
            return label
    return "ns"

# 3) CORE HELPERS 
def ensure_counts_layer(adata):
    """Guarantee that ``adata.layers["counts"]`` exists and is a scipy sparse matrix.

    When the layer is missing it is filled with a copy of the *current* ``adata.X``
    (whatever that holds); a dense layer is converted to CSR. Works in place.
    """
    if "counts" not in adata.layers:
        adata.layers["counts"] = adata.X.copy()
    stored = adata.layers["counts"]
    if not sp.issparse(stored):
        adata.layers["counts"] = sp.csr_matrix(stored)

def looks_log1p(adata, n=2000):
    """Heuristically decide whether ``adata.X`` already holds log1p-normalised values.

    Looks at up to *n* non-zero entries of ``X`` and returns True when their maximum is
    below 25 and more than 20% of them are non-integers. For sparse input the stored
    values are used. For dense input, zeros are skipped before sampling. Zeros are
    integers under either hypothesis, so a leading block of zero entries (common in
    expression matrices) would otherwise dilute the non-integer fraction and make
    log-normalised data look like raw counts. Empty objects, and matrices with no
    non-zero values, return False.
    """
    if adata.n_obs == 0 or adata.n_vars == 0:
        return False
    X = adata.X
    if sp.issparse(X):
        v = X.data
    else:
        v = np.asarray(X).ravel()
        v = v[v != 0]
    if v.size == 0:
        return False
    v = v[:n]
    frac_nonint = np.mean(np.abs(v - np.round(v)) > 1e-6)
    return bool((np.nanmax(v) < 25) and (frac_nonint > 0.2))

def prep_for_scoring(adata):
    """Make ``adata.X`` log1p(CP10k) for gene-set scoring unless it already looks log-normalised.

    Empty objects are left untouched. A sparse "counts" layer is ensured first (copied
    from ``X`` when absent). If ``looks_log1p`` rejects ``X``, ``X`` is rebuilt from that
    layer with ``normalize_total(target_sum=1e4)`` followed by ``log1p``. Works in place.
    NOTE(review): if ``X`` is already log1p and no "counts" layer existed, the layer
    created here holds log values rather than raw counts.
    """
    if not (adata.n_obs > 0 and adata.n_vars > 0):
        return
    ensure_counts_layer(adata)
    if not looks_log1p(adata):
        adata.X = adata.layers["counts"].copy()
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

def compute_basic_qc(adata):
    """Add scanpy per-cell QC metrics (including the mitochondrial percentage) to *adata* in place.

    A boolean ``var["mt"]`` flag (symbol starts with "MT-", case-insensitive) is created
    only if that column is absent. ``sc.pp.calculate_qc_metrics`` then writes obs columns
    such as ``total_counts``, ``n_genes_by_counts`` and ``pct_counts_mt``.
    NOTE(review): metrics are computed from ``adata.X`` (scanpy's default), which is
    log-normalised if ``prep_for_scoring`` already ran.
    """
    if "mt" not in adata.var.columns:
        upper_symbols = adata.var_names.astype(str).str.upper()
        adata.var["mt"] = upper_symbols.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
    )

def safe_corr(x, y):
    """Pearson r over positions where both inputs are finite; NaN when fewer than 3 such pairs exist."""
    xa, ya = np.asarray(x), np.asarray(y)
    both_finite = np.isfinite(xa) & np.isfinite(ya)
    if both_finite.sum() < 3:
        return np.nan
    r_matrix = np.corrcoef(xa[both_finite], ya[both_finite])
    return float(r_matrix[0, 1])

def bh_fdr(pvals):
    """Return Benjamini–Hochberg adjusted p-values, clipped to [0, 1].

    Non-finite inputs are excluded from the number of tests and come back as NaN in
    their original positions; the output has the same shape as the input.
    """
    p = np.asarray(pvals, float)
    adjusted_all = np.full(p.shape, np.nan, dtype=float)
    finite = np.isfinite(p)
    n_tests = int(finite.sum())
    if n_tests == 0:
        return adjusted_all
    vals = p[finite]
    order = np.argsort(vals)
    stepped = vals[order] * n_tests / np.arange(1, n_tests + 1)
    # Step-up monotonicity: each q is the minimum over itself and all larger p-values.
    monotone = np.minimum.accumulate(stepped[::-1])[::-1]
    adjusted = np.empty(n_tests, dtype=float)
    adjusted[order] = np.clip(monotone, 0, 1)
    adjusted_all[finite] = adjusted
    return adjusted_all

def corr_with_ci_p(x, y, n_boot=N_BOOT, n_perm=N_PERM, seed=SEED):
    """Pearson correlation with a bootstrap 95% CI and a two-sided permutation p-value.

    Pairs where either value is non-finite are dropped first. The CI is the 2.5/97.5
    percentile of *n_boot* paired-bootstrap correlations (NaN unless more than 10 are
    finite). The p-value is ``(k + 1) / (n_perm + 1)``, where *k* counts permutations
    of *y* whose |r| is at least the observed |r| (ties up to float round-off count).

    Returns a dict with keys ``n``, ``r``, ``ci_low``, ``ci_high``, ``p_perm``. All
    statistics are NaN when fewer than 10 complete pairs remain, or when the observed
    correlation is undefined (e.g. constant input). In the undefined case every
    comparison against NaN is False, which would otherwise report the smallest
    possible p-value and make the result look significant.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, float); y = np.asarray(y, float)
    m = np.isfinite(x) & np.isfinite(y)
    x = x[m]; y = y[m]; n = x.size
    undefined = {"n": n, "r": np.nan, "ci_low": np.nan, "ci_high": np.nan, "p_perm": np.nan}
    if n < 10:
        return undefined
    r_obs = float(np.corrcoef(x, y)[0, 1])
    if not np.isfinite(r_obs):
        return undefined

    # Bootstrap CI (RNG draws happen before the permutations, as before, for reproducibility).
    boot = np.full(n_boot, np.nan)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)
        boot[i] = np.corrcoef(x[idx], y[idx])[0, 1]
    fb = boot[np.isfinite(boot)]
    ci_low, ci_high = (np.nanpercentile(fb, [2.5, 97.5]) if len(fb) > 10 else (np.nan, np.nan))

    # Small absolute tolerance so permuted |r| equal to |r_obs| up to round-off count as ties.
    threshold = abs(r_obs) - 1e-12
    perm_count = 0
    for _ in range(n_perm):
        yp = rng.permutation(y)
        if abs(np.corrcoef(x, yp)[0, 1]) >= threshold:
            perm_count += 1
    p = float((perm_count + 1) / (n_perm + 1))
    return {"n": n, "r": r_obs, "ci_low": float(ci_low), "ci_high": float(ci_high), "p_perm": p}

def cohens_d_with_ci(a, b, n_boot=N_BOOT, seed=SEED):
    """Cohen's d, (mean(a) - mean(b)) / pooled SD, with a percentile bootstrap 95% CI.

    Non-finite values are dropped from each group. Each bootstrap replicate resamples
    both groups independently with replacement. d is NaN when the pooled variance is
    ~0, and everything is NaN when either group has fewer than 2 values. The CI is NaN
    unless more than 10 replicates give a finite d. Returns keys ``d``, ``ci_lo``,
    ``ci_hi``, ``n_a``, ``n_b``.
    """
    rng = np.random.default_rng(seed)
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    a = a[np.isfinite(a)]
    b = b[np.isfinite(b)]
    if min(len(a), len(b)) < 2:
        return {"d": np.nan, "ci_lo": np.nan, "ci_hi": np.nan, "n_a": len(a), "n_b": len(b)}

    def pooled_d(g1, g2):
        n1, n2 = len(g1), len(g2)
        var1 = np.var(g1, ddof=1)
        var2 = np.var(g2, ddof=1)
        pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / max(n1 + n2 - 2, 1)
        if pooled_var <= 1e-12:
            return np.nan
        return float((np.mean(g1) - np.mean(g2)) / np.sqrt(pooled_var))

    def one_replicate():
        resampled_a = rng.choice(a, size=len(a), replace=True)
        resampled_b = rng.choice(b, size=len(b), replace=True)
        return pooled_d(resampled_a, resampled_b)

    d_obs = pooled_d(a, b)
    replicates = np.array([one_replicate() for _ in range(n_boot)], dtype=float)
    finite_reps = replicates[np.isfinite(replicates)]
    if len(finite_reps) > 10:
        ci_lo, ci_hi = np.nanpercentile(finite_reps, [2.5, 97.5])
    else:
        ci_lo, ci_hi = np.nan, np.nan
    return {"d": d_obs, "ci_lo": float(ci_lo), "ci_hi": float(ci_hi), "n_a": len(a), "n_b": len(b)}

def auc_with_ci(y01, scores, n_boot=N_BOOT, seed=SEED):
    """ROC AUC (Mann–Whitney U form, average ranks for ties) with a percentile bootstrap 95% CI.

    *y01* holds 0/1 labels, and rows with non-finite *scores* are dropped. Everything
    is NaN when fewer than 6 rows or only one class remain. Bootstrap replicates
    resample rows with replacement, and one-class replicates are discarded. The CI is
    NaN unless more than 10 replicates are usable. Returns keys ``auc``, ``ci_lo``,
    ``ci_hi``, ``n``.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(y01).astype(int)
    vals = np.asarray(scores).astype(float)
    keep = np.isfinite(vals)
    labels = labels[keep]
    vals = vals[keep]
    n_rows = labels.size
    if len(np.unique(labels)) < 2 or n_rows < 6:
        return {"auc": np.nan, "ci_lo": np.nan, "ci_hi": np.nan, "n": n_rows}

    def mann_whitney_auc(lab, sv):
        n_pos = (lab == 1).sum()
        n_neg = (lab == 0).sum()
        if n_pos == 0 or n_neg == 0:
            return np.nan
        ranks = stats.rankdata(sv)
        u_stat = ranks[lab == 1].sum() - n_pos * (n_pos + 1) / 2
        return float(u_stat / (n_pos * n_neg))

    auc_obs = mann_whitney_auc(labels, vals)
    replicates = np.full(n_boot, np.nan)
    for k in range(n_boot):
        pick = rng.integers(0, n_rows, n_rows)
        replicates[k] = mann_whitney_auc(labels[pick], vals[pick])
    usable = replicates[np.isfinite(replicates)]
    if len(usable) > 10:
        ci_lo, ci_hi = np.nanpercentile(usable, [2.5, 97.5])
    else:
        ci_lo, ci_hi = np.nan, np.nan
    return {"auc": auc_obs, "ci_lo": float(ci_lo), "ci_hi": float(ci_hi), "n": n_rows}

def build_confound_matrix(df, conf_cols):
    """Build a float covariate matrix from the columns of *df* named in *conf_cols*.

    Columns missing from *df* are skipped. Numeric columns are coerced to numbers, and
    other columns are one-hot encoded as strings (``<col>_<level>``, including a
    ``_nan`` level). ±inf becomes NaN. A column is kept only if it has at least 3
    finite values and non-zero spread. Returns None when nothing usable remains.
    """
    blocks = []
    for col in (c for c in conf_cols if c in df.columns):
        series = df[col]
        if pd.api.types.is_numeric_dtype(series):
            blocks.append(pd.DataFrame({col: pd.to_numeric(series, errors="coerce")}))
        else:
            blocks.append(pd.get_dummies(series.astype(str), prefix=col, dummy_na=True))
    if not blocks:
        return None
    design = pd.concat(blocks, axis=1).replace([np.inf, -np.inf], np.nan)

    def informative(name):
        values = design[name].values
        return np.isfinite(values).sum() >= 3 and np.nanstd(values) > 1e-12

    kept = [name for name in design.columns if informative(name)]
    if not kept:
        return None
    return design[kept].astype(float)

def residualize(y, X):
    """Return OLS residuals of *y* on an intercept plus the columns of DataFrame *X*.

    If *X* is None or has no columns, *y* is returned unchanged as a float array.
    Rows with a non-finite *y* or covariate get NaN residuals. When fewer than 5
    complete rows exist, every residual is NaN. Uses least squares, so collinear
    covariates are tolerated.
    """
    target = np.asarray(y, dtype=float)
    if X is None or X.shape[1] == 0:
        return target
    covars = X.values.astype(float)
    complete = np.isfinite(target) & np.all(np.isfinite(covars), axis=1)
    n_complete = complete.sum()
    if n_complete < 5:
        return target * np.nan
    design = np.column_stack([np.ones((n_complete, 1)), covars[complete]])
    observed = target[complete]
    coef = np.linalg.lstsq(design, observed, rcond=None)[0]
    residuals = np.full_like(target, np.nan)
    residuals[complete] = observed - (design @ coef)
    return residuals

def pick_donor_col(df):
    """Guess the column of *df* that identifies the donor/subject, or return None.

    An exact match from a priority list of common names wins first. Failing that, the
    first column whose lower-cased name contains one of a few keywords ("donor",
    "subject", ..., tried in that order) is returned.
    """
    exact_names = (
        "donor_id", "donor", "Donor", "subject", "Subject", "individual", "Individual",
        "patient", "Patient", "participant", "Participant", "case_id", "Case",
        "participant_id", "donor_id_clean", "donorID", "DonorID", "sample_id",
        "Sample", "sample", "Donor ID", "orig.ident",
    )
    columns = list(df.columns)
    exact = next((name for name in exact_names if name in columns), None)
    if exact is not None:
        return exact
    by_lower = {c.lower(): c for c in columns}
    for keyword in ("donor", "subject", "patient", "individual", "participant", "case", "orig.ident"):
        fuzzy = next((orig for low, orig in by_lower.items() if keyword in low), None)
        if fuzzy is not None:
            return fuzzy
    return None

def donor_aggregate(df_obs, donor_col, cols_needed):
    """Collapse per-cell rows of *df_obs* to one row per donor.

    Groups by *donor_col* (NaN donors form their own group). The output has the donor
    label as a string column, the per-donor mean of each numeric column in
    *cols_needed*, and the most frequent non-missing value (as a string) of each
    remaining column. Columns absent from *df_obs* are ignored. Returns a DataFrame
    with a fresh RangeIndex.
    """
    grouped = df_obs.groupby(donor_col, dropna=False)
    result = pd.DataFrame(index=grouped.size().index)
    result[donor_col] = result.index.astype(str)
    available = [c for c in cols_needed if c in df_obs.columns]
    numeric = [c for c in available if pd.api.types.is_numeric_dtype(df_obs[c])]
    if numeric:
        result[numeric] = grouped[numeric].mean(numeric_only=True)

    def most_common(values):
        values = values.dropna()
        if values.empty:
            return np.nan
        return values.astype(str).value_counts().index[0]

    for col in (c for c in available if c not in numeric):
        result[col] = grouped[col].apply(most_common)
    return result.reset_index(drop=True)

# 4) Gene symbol handling
MARKER_PROBE = ["P2RY12", "CX3CR1", "TMEM119", "CSF1R", "NR3C1", "FKBP5", "LGALS3", "AIF1", "LST1", "TYROBP"]

def _upper_map(var_names):
    return {str(v).upper(): str(v) for v in var_names}

def _overlap_symbols(adata, genes):
    m = _upper_map(adata.var_names)
    return sum(1 for g in genes if str(g).upper() in m)

def ensure_gene_symbols(adata, dataset_name):
    """Make ``adata.var_names`` gene symbols when they do not already look like symbols.

    If at least 3 MARKER_PROBE genes already match ``var_names`` (case-insensitively),
    *adata* is returned untouched. Otherwise the candidate ``adata.var`` column that
    looks most like symbols replaces ``var_names``. Candidates are scored as the
    fraction of values not starting with "ENSG" plus half the fraction containing a
    letter. Entries that are missing (NaN/None/pd.NA) or blank/"nan"/"None" in that
    column keep their original name. The values used are stored in
    ``var["gene_symbol_used"]`` and names are made unique. Works in place and returns
    *adata*; prints a warning when no candidate column exists.
    """
    ov = _overlap_symbols(adata, MARKER_PROBE)
    if ov >= 3: return adata
    candidates = ["gene_symbol", "gene_symbols", "symbol", "symbols", "feature_name",
                  "name", "gene", "gene_name", "hgnc_symbol"]
    found = [c for c in candidates if c in adata.var.columns]
    if not found:
        print(f"[WARN] {dataset_name}: low marker overlap ({ov}) and no symbol columns.")
        return adata
    best_col, best_score = None, -1
    for c in found:
        s = adata.var[c].astype(str)
        non_ens = (~s.str.upper().str.startswith("ENSG")).mean()
        has_letters = s.str.contains(r"[A-Za-z]", regex=True).mean()
        score = float(non_ens + 0.5 * has_letters)
        if score > best_score:
            best_score, best_col = score, c
    raw = adata.var[best_col]
    as_str = raw.astype(str)
    # Detect missing values on the raw column: after astype(str) NaN/None/pd.NA turn into
    # "nan"/"None"/"<NA>" strings, which a post-cast pd.isna check cannot see.
    missing = (raw.isna().to_numpy()
               | as_str.str.strip().isin(["", "nan", "None", "<NA>"]).to_numpy())
    new_names = np.where(missing, adata.var_names.values, as_str.values)
    adata.var["gene_symbol_used"] = new_names
    adata.var_names = pd.Index(new_names).astype(str)
    adata.var_names_make_unique()
    ov2 = _overlap_symbols(adata, MARKER_PROBE)
    print(f"[FIX] {dataset_name}: remapped using .var['{best_col}'] | overlap {ov} -> {ov2}")
    return adata

def _genes_present_case_insensitive(adata, genes):
    m = {g.upper(): g for g in adata.var_names.astype(str)}
    seen = set(); out = []
    for g in genes:
        gg = str(g).upper()
        if gg in m and m[gg] not in seen:
            out.append(m[gg]); seen.add(m[gg])
    return out

def score_geneset(adata, genes, score_name, min_genes=3):
    """Write a scanpy gene-set score for *genes* to ``adata.obs[score_name]``.

    Gene matching is case-insensitive. The column is set to NaN instead when *adata*
    is empty or fewer than *min_genes* genes are found. Scoring uses ``adata.X``
    (``use_raw=False``).
    """
    def mark_unscored():
        adata.obs[score_name] = np.nan

    if adata.n_obs == 0 or adata.n_vars == 0:
        mark_unscored()
        return
    found = _genes_present_case_insensitive(adata, genes)
    if len(found) < min_genes:
        mark_unscored()
        return
    sc.tl.score_genes(adata, found, score_name=score_name, use_raw=False)

def subset_microglia_by_markers(adata, min_hits=2, extra=None):
    """Keep cells expressing (>0) at least *min_hits* microglial marker genes.

    The panel is P2RY12, CX3CR1, TMEM119 and CSF1R plus any *extra* genes, matched
    case-insensitively. The input object itself is returned unfiltered when it is
    empty or fewer than 2 panel genes are present. Otherwise a filtered copy is returned.
    """
    if adata.n_obs == 0:
        return adata
    panel = ["P2RY12", "CX3CR1", "TMEM119", "CSF1R"] + (list(extra) if extra else [])
    found = _genes_present_case_insensitive(adata, panel)
    if len(found) < 2:
        return adata
    expr = adata[:, found].X
    if sp.issparse(expr):
        expr = expr.toarray()
    n_positive = (expr > 0).sum(axis=1)
    return adata[n_positive >= min_hits].copy()

def add_common_cols(adata, dataset, cohort):
    """Label every cell with *dataset* / *cohort*, cast obs names to str, and de-duplicate var names (in place)."""
    for column, label in (("dataset", dataset), ("cohort", cohort)):
        adata.obs[column] = label
    adata.obs_names = adata.obs_names.astype(str)
    adata.var_names_make_unique()

def summarize_dataset_confounds(adata):
    """Compute QC metrics on *adata* and list the confound columns available in its obs.

    Returns QC metric columns first, then recognised metadata columns (batch, region,
    sex, diagnosis, age, PMI, ...) in a fixed priority order, without duplicates.
    """
    compute_basic_qc(adata)
    qc_metrics = ["total_counts", "pct_counts_mt", "n_genes_by_counts"]
    metadata_keys = ["batch", "Batch", "platform", "Platform", "assay", "Assay", "region", "Region",
                     "brain_region", "Brain region", "library", "Library", "Sex", "sex", "Diagnosis",
                     "diagnosis", "Age at Death", "age", "PMI", "pmi", "Brain pH", "orig.ident"]
    obs_columns = adata.obs.columns
    available = [c for c in qc_metrics + metadata_keys if c in obs_columns]
    return list(dict.fromkeys(available))

# 5) LOAD MES MODULES (from NB3: Main_Table1.xlsx)
# Reads the "GeneWeights" sheet written by NB3. Each column whose name starts with "MES"
# is one module, and exactly 8 such columns are expected.
main_table1 = TAB_DIR / "Main_Table1.xlsx"
assert main_table1.exists(), f"Missing: {main_table1}"
df_weights = pd.read_excel(main_table1, sheet_name="GeneWeights")
mes_cols = [c for c in df_weights.columns if str(c).startswith("MES")]
assert len(mes_cols) == 8

# Gene set per module = the TOP_N_MES genes with the largest weight in that module's
# column (rows with a missing gene or weight are dropped first).
TOP_N_MES = 50
mes_gene_sets = {}
for mes in mes_cols:
    sub = df_weights[["gene", mes]].dropna().sort_values(mes, ascending=False).head(TOP_N_MES)
    mes_gene_sets[mes] = sub["gene"].astype(str).tolist()

print("=" * 70)
print("NB4: UPGRADE — LOADING MES MODULES")
print("=" * 70)
print(f"Loaded {len(mes_cols)} MES modules with top {TOP_N_MES} genes each.")

# NOTE(review): at notebook top level these names are already globals, so the explicit
# globals()[...] assignments below are redundant no-ops.
MES_MODULES = mes_gene_sets
globals()["MES_MODULES"] = MES_MODULES
tab = df_weights.copy()
globals()["tab"] = tab

# 6) HK SENSITIVITY
def is_housekeeping(g):
    """True for MALAT1 or symbols starting with MT-, RPL or RPS (case-insensitive)."""
    symbol = str(g).upper()
    return symbol == "MALAT1" or symbol.startswith(("MT-", "RPL", "RPS"))

# Per-module housekeeping load: how many of each module's top genes are MALAT1 /
# MT- / RPL / RPS genes (is_housekeeping). A stripped copy of every gene set is kept in
# mes_gene_sets_hk_stripped (presumably for a later sensitivity analysis; not used in this section).
hk_report = []
mes_gene_sets_hk_stripped = {}
for mes, genes in mes_gene_sets.items():
    hk = [g for g in genes if is_housekeeping(g)]
    mes_gene_sets_hk_stripped[mes] = [g for g in genes if not is_housekeeping(g)]
    hk_report.append({"MES": mes, "topN": len(genes), "hk_n": len(hk),
                       "hk_frac": len(hk) / max(len(genes), 1),
                       "hk_examples": ", ".join(hk[:10])})
df_hk = pd.DataFrame(hk_report).sort_values("hk_frac", ascending=False)

# Supplementary Fig 2F: bars coloured red (>10%), orange (>5%) or green, with a dotted 10% guide.
# NOTE(review): the y-label hard-codes "top 50" rather than using TOP_N_MES.
fig, ax = plt.subplots(figsize=(6.0, 3.0))
colors = ["#D6604D" if f > 0.1 else "#FDAE6B" if f > 0.05 else "#66C2A5" for f in df_hk["hk_frac"]]
ax.bar(df_hk["MES"], df_hk["hk_frac"], color=colors, edgecolor="black", linewidth=0.3)
ax.axhline(0.1, color="grey", linewidth=0.5, linestyle=":")
ax.set_ylabel("Housekeeping fraction (top 50)")
ax.set_title("Supplementary Fig 2F  MES housekeeping load check")
_beautify_axes(ax)
save_fig(fig, "Fig2F_HK_Load", kind="Supplementary")

# 7) STRESS-AXIS CALIBRATION (GSE219208)
# Curated marker panels keyed by programme name (homeostatic / activation microglia,
# oxidative phosphorylation, glycolysis).
SIG = {
    "microglia_homeostatic": ["P2RY12", "CX3CR1", "TMEM119", "GPR34", "SALL1", "CSF1R", "OLFM3"],
    "microglia_activation": ["APOE", "SPP1", "LPL", "TREM2", "CST7", "CTSD", "TYROBP", "FCER1G", "LGALS3", "CD68"],
    "oxphos": ["NDUFA1", "NDUFB8", "NDUFS1", "COX4I1", "COX5A", "ATP5F1A", "ATP5F1B", "UQCRC1", "UQCRC2"],
    "glycolysis": ["HK1", "HK2", "PFKM", "ALDOA", "GAPDH", "ENO1", "PKM", "LDHA", "SLC2A1"],
}
# Curated GR-response genes: labelled on the volcano and used as the GR_UP signature if
# the DE-derived one has fewer than 5 genes.
GR_CORE_CURATED = ["NR3C1", "FKBP5", "TSC22D3", "DDIT4", "KLF9"]
EXPR_G = ["LGALS3", "FNDC5", "NR3C1"]

# Raw read-count table (genes as rows, samples as columns; first CSV column is the gene index).
# If several files match anywhere under RAW_DIR, the first rglob hit is used.
stress_hits = list(RAW_DIR.rglob("GSE219208_Non-Normalized_read_counts_combined_lanes_.csv"))
assert len(stress_hits) >= 1, "Missing GSE219208 counts CSV"
stress_csv = stress_hits[0]
df_stress_raw = pd.read_csv(stress_csv, index_col=0)

def parse_gse219208_prefix(col):
    """Parse a GSE219208 sample column name into treatment metadata.

    The lower-cased text before the first "_" is the prefix. "ctl*" and "cv*"/"dv*" are
    Control. "cort*"/"dex*" are Treated, or Washout when the prefix contains "wo", with
    dose "high"/"low" from a trailing "h"/"l" after the drug name. Anything else is
    Unknown. Returns a dict with keys sample, prefix, group, drug, dose, washout.
    """
    prefix = str(col).split("_")[0].lower()
    is_washout = "wo" in prefix
    exposed_group = "Washout" if is_washout else "Treated"

    def record(group, drug, dose=None):
        return {"sample": col, "prefix": prefix, "group": group, "drug": drug,
                "dose": dose, "washout": is_washout}

    if prefix.startswith("ctl"):
        return record("Control", "ctl")
    for drug in ("cort", "dex"):
        if not prefix.startswith(drug):
            continue
        if prefix.startswith(drug + "h"):
            dose = "high"
        elif prefix.startswith(drug + "l"):
            dose = "low"
        else:
            dose = None
        return record(exposed_group, drug, dose)
    if prefix.startswith(("cv", "dv")):
        return record("Control", prefix[:2])
    return record("Unknown", "unknown")

# Sample annotation: one metadata row per count-table column, parsed from its name.
df_meta_stress = pd.DataFrame([parse_gse219208_prefix(c) for c in df_stress_raw.columns])
print("GSE219208 label breakdown:", df_meta_stress["group"].value_counts().to_dict())

ctrl_cols    = df_meta_stress[df_meta_stress["group"] == "Control"]["sample"].tolist()
treated_cols = df_meta_stress[df_meta_stress["group"] == "Treated"]["sample"].tolist()
washout_cols = df_meta_stress[df_meta_stress["group"] == "Washout"]["sample"].tolist()
# Explicit check; an `assert` would be skipped under `python -O`.
if len(ctrl_cols) < 2 or len(treated_cols) < 2:
    raise ValueError(
        f"GSE219208: need >=2 Control and >=2 Treated samples, "
        f"got {len(ctrl_cols)} Control / {len(treated_cols)} Treated"
    )

# Drug-type breakdown (Treated, i.e. non-washout, samples only)
dex_treated = df_meta_stress[(df_meta_stress["group"] == "Treated") & (df_meta_stress["drug"] == "dex")]["sample"].tolist()
cort_treated = df_meta_stress[(df_meta_stress["group"] == "Treated") & (df_meta_stress["drug"] == "cort")]["sample"].tolist()

# Keep genes with >= 10 total reads, then log2(CPM + 1) per sample.
gene_sums = df_stress_raw.sum(axis=1)
df_filt = df_stress_raw.loc[gene_sums >= 10].copy()
df_cpm = df_filt.div(df_filt.sum(axis=0), axis=1) * 1e6
df_log = np.log2(df_cpm + 1.0)

# Per-gene Welch t-test (Treated vs Control) on log2CPM; log2FC = difference of group means.
# Genes with zero variance in both groups get a NaN p-value from scipy and stay NaN after BH.
de_rows = []
for gene in df_log.index:
    a = df_log.loc[gene, treated_cols].values.astype(float)
    b = df_log.loc[gene, ctrl_cols].values.astype(float)
    a = a[np.isfinite(a)]; b = b[np.isfinite(b)]
    if len(a) < 2 or len(b) < 2: continue
    log2fc = float(np.mean(a) - np.mean(b))
    try:
        _, p = stats.ttest_ind(a, b, equal_var=False)
        p = float(p)
    except Exception:
        # Best effort: a gene whose test fails is reported as non-significant. This is
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        p = 1.0
    de_rows.append({"gene": gene, "log2FC": log2fc, "pval": p})

df_de = pd.DataFrame(de_rows).sort_values("pval")
df_de["padj"] = bh_fdr(df_de["pval"].values)

# Signature selection: strict thresholds first. When fewer than MIN_SIG_GENES genes pass,
# fall back to |log2FC| > FALLBACK_LOG2FC_T without any padj filter, and say so.
LOG2FC_T = 0.5
PADJ_T = 0.05
MIN_SIG_GENES = 10
FALLBACK_LOG2FC_T = 0.3
sig_up = df_de[(df_de["log2FC"] > LOG2FC_T) & (df_de["padj"] < PADJ_T)].sort_values("log2FC", ascending=False)
sig_dn = df_de[(df_de["log2FC"] < -LOG2FC_T) & (df_de["padj"] < PADJ_T)].sort_values("log2FC", ascending=True)
if len(sig_up) < MIN_SIG_GENES:
    print(f"[WARN] GSE219208: only {len(sig_up)} UP genes pass log2FC>{LOG2FC_T} & padj<{PADJ_T}; "
          f"falling back to log2FC>{FALLBACK_LOG2FC_T} with no padj filter.")
    sig_up = df_de[df_de["log2FC"] > FALLBACK_LOG2FC_T].sort_values("log2FC", ascending=False)
if len(sig_dn) < MIN_SIG_GENES:
    print(f"[WARN] GSE219208: only {len(sig_dn)} DOWN genes pass log2FC<-{LOG2FC_T} & padj<{PADJ_T}; "
          f"falling back to log2FC<-{FALLBACK_LOG2FC_T} with no padj filter.")
    sig_dn = df_de[df_de["log2FC"] < -FALLBACK_LOG2FC_T].sort_values("log2FC", ascending=True)

# GR_UP / GR_DN: the TOP_N_SIG strongest genes in each direction (by |log2FC|).
TOP_N_SIG = 50
GR_UP = sig_up["gene"].head(TOP_N_SIG).astype(str).tolist()
GR_DN = sig_dn["gene"].head(TOP_N_SIG).astype(str).tolist()
if len(GR_UP) < 5:
    print(f"[WARN] GSE219208: only {len(GR_UP)} UP genes available; using the curated GR core "
          f"set {GR_CORE_CURATED} and no DOWN genes.")
    GR_UP = GR_CORE_CURATED.copy(); GR_DN = []

def sample_score_from_log(df_log_gxS, up_genes, dn_genes):
    """Per-sample signature score: mean of the up genes minus mean of the down genes.

    *df_log_gxS* is a genes x samples log-expression table. Genes not in its index are
    ignored. With fewer than 3 up genes every score is NaN. With fewer than 3 down
    genes no down term is subtracted. Returns a float Series indexed by sample.
    """
    available = df_log_gxS.index
    up = [g for g in up_genes if g in available]
    dn = [g for g in dn_genes if g in available]
    if len(up) < 3:
        return pd.Series(index=df_log_gxS.columns, data=np.nan)
    up_mean = df_log_gxS.loc[up].mean(axis=0)
    if len(dn) >= 3:
        down_term = df_log_gxS.loc[dn].mean(axis=0)
    else:
        down_term = 0.0
    return (up_mean - down_term).astype(float)

# Score every GSE219208 sample with the GR signature and attach it to the sample metadata.
stress_score = sample_score_from_log(df_log, GR_UP, GR_DN)
df_meta_stress["GR_score"] = df_meta_stress["sample"].map(stress_score.to_dict())
df184 = df_meta_stress
globals()["df184"] = df184

# AUC + Cohen's d with CIs
# Control (0) vs Treated (1); washout samples are not part of this comparison.
# NOTE(review): GR_UP/GR_DN were selected from the Treated-vs-Control DE in these same
# samples, so this AUC / d is an in-sample (optimistic) calibration, not held-out validation.
y_auc = np.array([0] * len(ctrl_cols) + [1] * len(treated_cols))
s_auc = np.array(list(stress_score[ctrl_cols].values) + list(stress_score[treated_cols].values))
auc_res = auc_with_ci(y_auc, s_auc, n_boot=N_BOOT, seed=SEED)
auc_ct = auc_res["auc"]

d_res = cohens_d_with_ci(stress_score[treated_cols].values, stress_score[ctrl_cols].values,
                         n_boot=N_BOOT, seed=SEED)
d_treat_vs_ctrl = d_res["d"]

# Drug-type specific AUCs
# Each drug's treated samples vs all controls; drugs with < 2 treated samples are skipped.
# Fewer bootstrap replicates (at most 500) keep this sub-analysis fast.
drug_breakdown = []
for drug_name, drug_cols in [("dex", dex_treated), ("cort", cort_treated)]:
    if len(drug_cols) >= 2:
        ya = np.array([0] * len(ctrl_cols) + [1] * len(drug_cols))
        sa = np.array(list(stress_score[ctrl_cols].values) + list(stress_score[drug_cols].values))
        ar = auc_with_ci(ya, sa, n_boot=min(N_BOOT, 500), seed=SEED + 10)
        dr = cohens_d_with_ci(stress_score[drug_cols].values, stress_score[ctrl_cols].values,
                              n_boot=min(N_BOOT, 500), seed=SEED + 10)
        drug_breakdown.append({"drug": drug_name, "n_treated": len(drug_cols), "n_ctrl": len(ctrl_cols),
                               "auc": ar["auc"], "auc_ci_lo": ar["ci_lo"], "auc_ci_hi": ar["ci_hi"],
                               "d": dr["d"], "d_ci_lo": dr["ci_lo"], "d_ci_hi": dr["ci_hi"]})
df_drug_bd = pd.DataFrame(drug_breakdown)

# GR signature sensitivity: top 30/50/100
# Re-scores samples with the top-N genes of sig_up / sig_dn (sizes with < 5 up genes are skipped).
gr_sens_rows = []
for topN in [30, 50, 100]:
    up_n = sig_up["gene"].head(topN).astype(str).tolist()
    dn_n = sig_dn["gene"].head(topN).astype(str).tolist()
    if len(up_n) < 5: continue
    ss = sample_score_from_log(df_log, up_n, dn_n)
    ya = np.array([0] * len(ctrl_cols) + [1] * len(treated_cols))
    sa = np.array(list(ss[ctrl_cols].values) + list(ss[treated_cols].values))
    ar = auc_with_ci(ya, sa, n_boot=min(N_BOOT, 500), seed=SEED + topN)
    gr_sens_rows.append({"topN": topN, "n_up": len(up_n), "n_dn": len(dn_n),
                         "auc": ar["auc"], "auc_ci_lo": ar["ci_lo"], "auc_ci_hi": ar["ci_hi"]})
df_gr_sens = pd.DataFrame(gr_sens_rows)

print(f"\n=== GR SIGNATURE SUMMARY ===")
print(f"UP={len(GR_UP)} DN={len(GR_DN)} | AUC={auc_ct:.3f} [{auc_res['ci_lo']:.3f},{auc_res['ci_hi']:.3f}] | d={d_treat_vs_ctrl:.2f} [{d_res['ci_lo']:.2f},{d_res['ci_hi']:.2f}]")

# Main Fig 4A: volcano with key gene labels
# Points use the same thresholds as the GR signature selection (PADJ_T / LOG2FC_T).
# padj is floored at 1e-50 so -log10 stays finite; NaN-padj genes are not drawn.
df_plot = df_de.copy()
df_plot["-log10(padj)"] = -np.log10(df_plot["padj"].clip(lower=1e-50))
fig, ax = plt.subplots(figsize=(6.0, 5.0))
passes_padj = (df_plot["padj"] < PADJ_T).to_numpy()
fc_vals = df_plot["log2FC"].to_numpy()
colors = np.array([C_NEU] * len(df_plot), dtype=object)
colors[passes_padj & (fc_vals > LOG2FC_T)] = C_UP
colors[passes_padj & (fc_vals < -LOG2FC_T)] = C_DN
ax.scatter(df_plot["log2FC"], df_plot["-log10(padj)"], c=colors, s=8, alpha=0.65, linewidths=0)
ax.axhline(-np.log10(PADJ_T), color="#333333", linestyle="--", linewidth=0.8, alpha=0.6)
ax.axvline(LOG2FC_T, color="#333333", linestyle="--", linewidth=0.8, alpha=0.6)
ax.axvline(-LOG2FC_T, color="#333333", linestyle="--", linewidth=0.8, alpha=0.6)

# Label key GR genes (case-insensitive match) that reach -log10(padj) > 1.
label_genes = set(GR_CORE_CURATED) | {"FKBP5", "NR3C1", "TSC22D3", "DDIT4", "KLF9", "PER1", "GILZ", "SGK1"}
label_upper = {g.upper() for g in label_genes}  # built once instead of once per row
is_label_gene = df_plot["gene"].astype(str).str.upper().isin(label_upper)
for _, row in df_plot[is_label_gene].iterrows():
    if np.isfinite(row["-log10(padj)"]) and row["-log10(padj)"] > 1:
        ax.annotate(row["gene"], (row["log2FC"], row["-log10(padj)"]),
                    fontsize=6, fontweight="bold", alpha=0.9,
                    xytext=(4, 4), textcoords="offset points")
# Also label the top 5 by significance that were not already labelled as GR genes
for _, row in df_plot.nlargest(5, "-log10(padj)").iterrows():
    if str(row["gene"]).upper() not in label_upper:
        ax.annotate(row["gene"], (row["log2FC"], row["-log10(padj)"]),
                    fontsize=5, alpha=0.7, xytext=(3, 3), textcoords="offset points")

ax.set_xlabel("log₂ Fold Change (Treated vs Control)")
ax.set_ylabel("-log₁₀(padj)")
ax.set_title(f"Main Fig 4A  GSE219208 perturbation DE (padj<{PADJ_T:g})")
_beautify_axes(ax)
save_fig(fig, "Fig4A", kind="Main")

# Main Fig 4B heatmap (z-scored)
# Up to 15 UP + 15 DOWN signature genes (or 25 UP genes when there is no DOWN set),
# z-scored per gene across Control + Treated samples (pandas std, ddof=1; washout excluded).
# The figure is silently skipped when fewer than 6 of those genes are in df_log.
hm_genes = (GR_UP[:15] + GR_DN[:15]) if len(GR_DN) else GR_UP[:25]
hm_genes = [g for g in hm_genes if g in df_log.index]
if len(hm_genes) >= 6:
    df_hm = df_log.loc[hm_genes, ctrl_cols + treated_cols].copy()
    df_hm = df_hm.sub(df_hm.mean(axis=1), axis=0).div(df_hm.std(axis=1) + 1e-8, axis=0)
    fig, ax = plt.subplots(figsize=(8.0, 6.0))
    if HAS_SNS:
        sns.heatmap(df_hm, ax=ax, cmap=CMAP_HEAT, vmin=-2, vmax=2, linewidths=0.3,
                    cbar_kws={"label": "Z-score (log₂CPM)", "shrink": 0.8},
                    xticklabels=True, yticklabels=True)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=5)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=6)
    else:
        # Fallback without seaborn: same colour range, explicit ticks and colorbar.
        im = ax.imshow(df_hm.values, aspect="auto", cmap=CMAP_HEAT, vmin=-2, vmax=2)
        ax.set_yticks(range(len(df_hm.index))); ax.set_yticklabels(df_hm.index, fontsize=6)
        ax.set_xticks(range(len(df_hm.columns))); ax.set_xticklabels(df_hm.columns, rotation=90, fontsize=5)
        plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02, label="Z-score")
    ax.set_title("Main Fig 4B  Perturbation-derived GR signature")
    save_fig(fig, "Fig4B", kind="Main")

# Main Fig 4C boxplot with significance annotations
# GR calibration score per group, with two-sided Mann-Whitney U tests drawn as brackets
# for Control vs Treated and Treated vs Washout (raw, uncorrected p-values).
fig, ax = plt.subplots(figsize=(6.4, 3.5))
order = ["Control", "Treated", "Washout"]
data = [df_meta_stress.loc[df_meta_stress["group"] == g, "GR_score"].dropna().values for g in order]
# Tick labels are set on the axis rather than through boxplot(labels=...): Matplotlib 3.9
# renamed that argument to `tick_labels` and deprecated the old name, whereas
# set_xticks/set_xticklabels behave the same on every version.
bp = ax.boxplot(data, showfliers=False, patch_artist=True,
                boxprops=dict(edgecolor="#333333", linewidth=0.9),
                medianprops=dict(color="#111111", linewidth=1.0),
                whiskerprops=dict(color="#333333", linewidth=0.9),
                capprops=dict(color="#333333", linewidth=0.9))
ax.set_xticks(range(1, len(order) + 1))
ax.set_xticklabels(order)
fill_cols = [C_BAR_MAIN, C_BAR_ACC, C_BAR_ACC2]
for patch, col in zip(bp["boxes"], fill_cols):
    patch.set_facecolor(col); patch.set_alpha(0.35)
# Significance annotations; bracket offsets (+0.05 / +0.15) are in data units.
if len(data[0]) >= 2 and len(data[1]) >= 2:
    _, p_ct = stats.mannwhitneyu(data[0], data[1], alternative="two-sided")
    s = sig_stars(p_ct)
    y_max = max(np.max(data[0]), np.max(data[1]))
    ax.plot([1, 2], [y_max + 0.05, y_max + 0.05], color="black", linewidth=0.8)
    ax.text(1.5, y_max + 0.06, f"{s} p={p_ct:.2e}", ha="center", fontsize=6)
if len(data[1]) >= 2 and len(data[2]) >= 2:
    _, p_tw = stats.mannwhitneyu(data[1], data[2], alternative="two-sided")
    s = sig_stars(p_tw)
    y_max2 = max(np.max(data[1]), np.max(data[2])) + 0.15
    ax.plot([2, 3], [y_max2, y_max2], color="black", linewidth=0.8)
    ax.text(2.5, y_max2 + 0.01, f"{s} p={p_tw:.2e}", ha="center", fontsize=6)

ax.set_ylabel("GR calibration score")
ax.set_title(f"Main Fig 4C  AUC={auc_ct:.2f} [{auc_res['ci_lo']:.2f},{auc_res['ci_hi']:.2f}]  d={d_treat_vs_ctrl:.2f} [{d_res['ci_lo']:.2f},{d_res['ci_hi']:.2f}]")
_beautify_axes(ax)
save_fig(fig, "Fig4C", kind="Main")

# Supplementary Table 3: stress-axis calibration summary plus supporting tables.
# The summary also records how many genes pass the strict DE thresholds. The signature
# step falls back to |log2FC| > 0.3 with no padj filter when fewer than 10 genes pass,
# so these fields show whether the reported thresholds actually produced GR_UP / GR_DN.
_passes_padj = df_de["padj"] < PADJ_T
n_up_strict = int((_passes_padj & (df_de["log2FC"] > LOG2FC_T)).sum())
n_dn_strict = int((_passes_padj & (df_de["log2FC"] < -LOG2FC_T)).sum())

save_excel({
    "Calibration_summary": pd.DataFrame([{
        "stress_file": str(stress_csv), "n_control": len(ctrl_cols),
        "n_treated": len(treated_cols), "n_washout": len(washout_cols),
        "auc": auc_ct, "auc_ci_lo": auc_res["ci_lo"], "auc_ci_hi": auc_res["ci_hi"],
        "cohens_d": d_treat_vs_ctrl, "d_ci_lo": d_res["ci_lo"], "d_ci_hi": d_res["ci_hi"],
        "log2fc_threshold": LOG2FC_T, "padj_threshold": PADJ_T,
        "n_up_strict": n_up_strict, "n_dn_strict": n_dn_strict,
        "up_used_fallback": n_up_strict < 10, "dn_used_fallback": n_dn_strict < 10,
        "fallback_log2fc_threshold": 0.3,
        "gr_up_is_curated_fallback": GR_UP == GR_CORE_CURATED,
    }]),
    "Drug_breakdown": df_drug_bd,
    "GR_size_sensitivity": df_gr_sens,
    "Sample_metadata": df_meta_stress,
    "GR_UP_signature": pd.DataFrame({"rank": range(1, len(GR_UP) + 1), "gene": GR_UP}),
    "GR_DOWN_signature": pd.DataFrame({"rank": range(1, len(GR_DN) + 1), "gene": GR_DN}) if len(GR_DN) else pd.DataFrame(),
    "DE_top500": df_de.sort_values("padj").head(500)
}, fname="Table3_StressCalibration", kind="Supplementary")

# 8) LOAD MICROGLIA COHORTS
# Each cohort is symbol-fixed, tagged, log-normalised for scoring, reduced to microglia,
# and appended to micro_adatas.
micro_adatas = []

# SEA-AD
seaad_path = RAW_DIR / "Brain" / "SEA-AD (AllenBrainMap)" / "SEA-AD_Microglia-and-Immune_multi-regional_final-nuclei_AAIC-pre-release.2025-07-24.h5ad"
assert seaad_path.exists()
adata_seaad = ensure_gene_symbols(sc.read_h5ad(seaad_path), "SEA-AD")
add_common_cols(adata_seaad, "SEA-AD", "microglia")
# Normalisation runs on the full object, before the microglia subset below.
prep_for_scoring(adata_seaad)
# Prefer the dataset's own annotation: the first obs column whose name contains both
# "cell" and "type", keeping labels matching "microglia" or "myeloid".
# NOTE(review): the "myeloid" match may also admit non-microglial myeloid cells.
# Marker-based subsetting is the fallback when no such column or no matching label exists.
ct_cols_s = [c for c in adata_seaad.obs.columns if ("cell" in c.lower() and "type" in c.lower())]
if ct_cols_s:
    vals = adata_seaad.obs[ct_cols_s[0]].astype(str).str.lower()
    keep = vals.str.contains("microglia|myeloid")
    adata_seaad = adata_seaad[keep].copy() if keep.any() else subset_microglia_by_markers(adata_seaad, 2, ["AIF1", "LST1", "TYROBP"])
else:
    adata_seaad = subset_microglia_by_markers(adata_seaad, 2, ["AIF1", "LST1", "TYROBP"])
micro_adatas.append(adata_seaad)
print(f"SEA-AD: {adata_seaad.n_obs:,} cells")

# Olah
# Cells are kept when at least 2 markers (P2RY12/CX3CR1/TMEM119/CSF1R plus AIF1/LST1/TYROBP) are expressed.
olah_path = RAW_DIR / "Brain" / "Olah et al. live human microglia (CELLxGENE Census artifact; AnnData).h5ad"
assert olah_path.exists()
adata_olah = ensure_gene_symbols(sc.read_h5ad(olah_path), "Olah")
add_common_cols(adata_olah, "Olah", "microglia")
prep_for_scoring(adata_olah)
adata_olah = subset_microglia_by_markers(adata_olah, 2, ["AIF1", "LST1", "TYROBP"])
micro_adatas.append(adata_olah)
print(f"Olah: {adata_olah.n_obs:,} cells")

# Tuddenham
# One 10x .h5 file per sample. Each file is symbol-fixed, tagged with its file stem as
# sample_id, normalised, and marker-filtered (>= 1 marker). Unreadable files are skipped
# with a warning (deliberate best effort).
tud_dir = RAW_DIR / "Cross-disease living human microglia (Tuddenham  De Jager lab)"
assert tud_dir.exists()
MAX_TUD_H5 = 20
all_tud_h5s = sorted(tud_dir.rglob("*.h5"))
h5s = all_tud_h5s[:MAX_TUD_H5]
if len(all_tud_h5s) > len(h5s):
    # Make the file cap visible instead of silently dropping samples.
    print(f"[WARN] Tuddenham: using the first {len(h5s)} of {len(all_tud_h5s)} .h5 files (sorted by path).")
tud_adatas = []
for h5 in h5s:
    try:
        a = sc.read_10x_h5(h5)
        a = ensure_gene_symbols(a, "Tuddenham_GSE204702")
        add_common_cols(a, "Tuddenham_GSE204702", "microglia")
        a.obs["sample_id"] = h5.stem
        prep_for_scoring(a)
        a = subset_microglia_by_markers(a, 1, ["AIF1", "LST1", "TYROBP"])
        if a.n_obs > 0: tud_adatas.append(a)
    except Exception as e:
        print(f"[WARN] Tuddenham: {h5.name} ({type(e).__name__})")
if tud_adatas:
    # 10x barcodes recur across samples, so plain concatenation would yield duplicate
    # obs_names (and anndata's warning is hidden by the global filter). index_unique
    # appends "-<position of the sample in tud_adatas>" to every cell name.
    adata_tud = ad.concat(tud_adatas, join="outer", merge="same", index_unique="-")
    micro_adatas.append(adata_tud)
    print(f"Tuddenham: {adata_tud.n_obs:,} cells")

# MS (GSE180759)
# Inputs: an expression matrix CSV (only its header is read in this part of the notebook)
# and a tab-separated per-nucleus annotation table.
ms_root = RAW_DIR / "MS lesion snRNA" / "Absinta et al. progressive MS lesion edge snRNA (GEO GSE180759)"
ms_counts = ms_root / "GSE180759_expression_matrix.csv"
ms_anno = ms_root / "GSE180759_annotation.txt"
assert ms_counts.exists() and ms_anno.exists()

df_ms_anno = pd.read_csv(ms_anno, sep="\t")
def guess_col(cols, contains_any):
    """Return the column whose lowercased name contains the highest-priority keyword.

    Keywords in *contains_any* are tried in order (most specific first), and for
    each keyword the columns are scanned in order. A generic fallback keyword such
    as ``"cell"`` therefore cannot shadow a more specific ``"nucleus_barcode"``
    match just because the matching column appears earlier in the table.

    Parameters
    ----------
    cols : iterable
        Column labels (non-string labels are compared via ``str()``).
    contains_any : iterable of str
        Substrings to search for, in priority order (case-insensitive).

    Returns
    -------
    The first matching column label, or ``None`` if nothing matches.
    """
    lowered = [(c, str(c).lower()) for c in cols]
    for key in contains_any:
        needle = str(key).lower()
        for original, low_name in lowered:
            if needle in low_name:
                return original
    return None
# Barcode column: highest-priority keyword match, else the first annotation column.
bc_col = guess_col(df_ms_anno.columns, ["nucleus_barcode", "barcode", "cell_id", "cellid", "cell"]) or df_ms_anno.columns[0]
df_ms_anno[bc_col] = df_ms_anno[bc_col].astype(str)
# Cell-type column: exact name match first, then a case-insensitive match.
ct_col = None
for key in ["cell_type", "celltype", "annotation", "class", "cell_class"]:
    if key in df_ms_anno.columns: ct_col = key; break
if ct_col is None:
    low = {c.lower(): c for c in df_ms_anno.columns}
    for key in ["cell_type", "celltype", "annotation", "class", "cell_class"]:
        if key in low: ct_col = low[key]; break

# Read only the header row of the (large) counts CSV to learn the barcode order.
with ms_counts.open("r", encoding="utf-8", errors="replace") as f:
    header_line = f.readline().strip()
# Strip surrounding double quotes as well as whitespace. R's write.csv quotes header
# fields, and quoted labels never match the annotation barcodes (which pandas has
# already unquoted). That would silently empty the annotation-based selection.
hdr = [h.strip().strip('"').strip() for h in header_line.split(",")]
def looks_like_barcode(x):
    """Heuristic: a barcode label is at least 10 characters long and contains a dash."""
    text = str(x)
    return len(text) >= 10 and "-" in text
# If the first header field already looks like a barcode, the header has no label
# for the gene column; otherwise drop that leading label.
header_barcodes = hdr[:] if looks_like_barcode(hdr[0]) else hdr[1:]
header_set = set(header_barcodes)

# Candidate barcodes: annotation rows whose cell-type label contains "immune" or
# "lymphocytes". Falls back to every annotated barcode when nothing matches or no
# cell-type column was found.
if ct_col is not None:
    ct = df_ms_anno[ct_col].astype(str).str.lower()
    keep_bcs_anno = df_ms_anno.loc[ct.str.contains("immune|lymphocytes", regex=True, na=False), bc_col].tolist()
    if not keep_bcs_anno: keep_bcs_anno = df_ms_anno[bc_col].astype(str).tolist()
else:
    keep_bcs_anno = df_ms_anno[bc_col].astype(str).tolist()
# Order-preserving de-duplication, dropping empty strings.
# NOTE(review): bc_col was cast to str above, so missing values arrive as the
# string "nan" and are not removed by pd.notna here.
keep_bcs_anno = list(dict.fromkeys([b for b in keep_bcs_anno if pd.notna(b) and str(b).strip()]))
# Keep only barcodes that actually occur in the counts CSV header.
keep_bcs_mapped = [b for b in keep_bcs_anno if b in header_set]

# Genes that must be kept while streaming even if they have no counts in the selected
# cells: every gene referenced by the GR sets, signatures, MES modules and expression
# checks, plus the MS microglia markers. Matching is case-insensitive (upper-cased).
MS_MARKERS = ["P2RY12", "CX3CR1", "TMEM119", "CSF1R", "AIF1", "TYROBP", "LST1", "SPI1", "FCER1G"]
REQ = set(GR_CORE_CURATED) | set(GR_UP) | set(GR_DN)
for v in SIG.values(): REQ |= set(v)
for mes in mes_cols: REQ |= set(mes_gene_sets[mes])
REQ |= set(EXPR_G) | set(MS_MARKERS)
REQ_UP = {g.upper() for g in REQ}
def stream_ms_required_to_sparse(ms_counts_path, header_barcodes, selected_barcodes, required_genes_upper, chunksize=2000):
    """Stream a dense genes x cells CSV into a sparse genes x selected-cells matrix.

    The CSV is read in row chunks, restricted to the gene column plus the selected
    barcode columns. A gene row is kept if it has any non-zero count in the
    selected cells, or if its upper-cased name is in *required_genes_upper*.

    Parameters
    ----------
    ms_counts_path : path-like
        Comma-separated file. The first row is a header (skipped), the first column
        holds gene names, and data column ``1 + i`` belongs to ``header_barcodes[i]``.
    header_barcodes : sequence of str
        Barcodes in file column order (excluding the gene column).
    selected_barcodes : iterable of str
        Barcodes to load. Unknown and duplicate barcodes are ignored.
    required_genes_upper : collection of str
        Upper-cased gene symbols to keep even when all-zero.
    chunksize : int
        Rows per chunk handed to ``pandas.read_csv``.

    Returns
    -------
    ``(kept_genes, sel, X)``, or ``None`` when fewer than 200 barcodes resolve or no
    gene row is kept. ``X`` is a CSR matrix of shape ``(len(kept_genes), len(sel))``.
    ``sel`` is in FILE column order, which is the order of X's columns.
    """
    pos = {b: i for i, b in enumerate(header_barcodes)}
    # pandas ignores the order of integer ``usecols`` and returns columns in file
    # order, so the labels must be sorted the same way or cells get mislabelled.
    # dict.fromkeys also drops duplicate selections, which usecols would collapse.
    sel = sorted(dict.fromkeys(b for b in selected_barcodes if b in pos), key=pos.__getitem__)
    if len(sel) < 200:
        return None
    usecols = [0] + [1 + pos[b] for b in sel]
    required = set(required_genes_upper)  # hoisted: reused for every chunk

    data, rows, cols, kept_genes = [], [], [], []
    row_idx = 0
    with pd.read_csv(ms_counts_path, sep=",", header=None, skiprows=1, usecols=usecols,
                     chunksize=chunksize, dtype={0: str}, keep_default_na=False) as reader:
        for chunk in reader:
            genes = chunk.iloc[:, 0].astype(str).str.strip().values
            X = chunk.iloc[:, 1:].apply(pd.to_numeric, errors="coerce").fillna(0.0).values.astype(np.float32)
            is_required = np.fromiter((g.upper() in required for g in genes), dtype=bool, count=len(genes))
            keep = (X.sum(axis=1) > 0) | is_required
            if not np.any(keep):
                continue
            gk = genes[keep]; Xk = X[keep]
            nzr, nzc = np.nonzero(Xk)
            if nzr.size:
                # Row indices are offset by the rows already kept from earlier chunks.
                data.append(Xk[nzr, nzc]); rows.append(nzr + row_idx); cols.append(nzc)
            kept_genes.extend(gk.tolist()); row_idx += Xk.shape[0]
    if row_idx == 0:
        return None
    data = np.concatenate(data) if data else np.array([], dtype=np.float32)
    rows = np.concatenate(rows) if rows else np.array([], dtype=np.int64)
    cols = np.concatenate(cols) if cols else np.array([], dtype=np.int64)
    return kept_genes, sel, sp.coo_matrix((data, (rows, cols)), shape=(row_idx, len(sel))).tocsr()

# Use the annotation-selected barcodes when at least 200 map to the CSV header;
# otherwise fall back to the first 20,000 header barcodes.
selected_bcs = keep_bcs_mapped if len(keep_bcs_mapped) >= 200 else header_barcodes[:20000]
res = stream_ms_required_to_sparse(ms_counts, header_barcodes, selected_bcs, REQ_UP, 2000)
assert res is not None
genes_kept, sel_bcs, X_gxC = res

# The streamed matrix is genes x cells; AnnData expects cells x genes, hence .T.
adata_ms = ad.AnnData(X=X_gxC.T.tocsr(), obs=pd.DataFrame(index=pd.Index(sel_bcs).astype(str)),
                       var=pd.DataFrame(index=pd.Index(genes_kept).astype(str)))
adata_ms.var_names_make_unique()
adata_ms = ensure_gene_symbols(adata_ms, "MS_GSE180759")
add_common_cols(adata_ms, "MS_GSE180759", "microglia")
prep_for_scoring(adata_ms)
adata_ms = subset_microglia_by_markers(adata_ms, 1, ["AIF1", "LST1", "TYROBP", "FCER1G"])
assert adata_ms.n_obs > 500
micro_adatas.append(adata_ms)
print(f"MS: {adata_ms.n_obs:,} cells")

# 9) SCORE COHORTS
# Cap on cells used for cell-level statistics (random subsample) when no donor column exists.
MAX_CELL_STATS = 50_000

# Per-cohort result accumulators, filled by the scoring loop that follows.
all_results = []
qc_rows = []
expr_rows = []
stress_interaction_rows = []
hk_sens_rows = []
lgals3_rows = []
fndc5_rows = []

for a in micro_adatas:
    ds = str(a.obs["dataset"].iloc[0])
    print(f"\n--- Scoring: {ds} ---")
    prep_for_scoring(a); compute_basic_qc(a)

    # Per-cell state scores; tolerance positioning = homeostatic − activation.
    score_geneset(a, SIG["microglia_homeostatic"], "score_homeostatic")
    score_geneset(a, SIG["microglia_activation"], "score_activation")
    score_geneset(a, SIG["oxphos"], "score_oxphos")
    score_geneset(a, SIG["glycolysis"], "score_glycolysis")
    a.obs["tolerance_positioning"] = a.obs["score_homeostatic"] - a.obs["score_activation"]

    # MES module scores, plus housekeeping-stripped variants for the HK sensitivity table.
    for mes, genes in mes_gene_sets.items():
        score_geneset(a, genes, f"{mes}_score")
    for mes, genes in mes_gene_sets_hk_stripped.items():
        score_geneset(a, genes, f"{mes}_score_hkstrip", min_genes=3)

    # GR composite: UP − DN when both have > 100 finite cells, UP alone when only
    # UP does, otherwise a score over the curated GR core set.
    score_geneset(a, GR_UP, "GR_UP", min_genes=3)
    score_geneset(a, GR_DN, "GR_DN", min_genes=3)
    upv = a.obs["GR_UP"].values; dnv = a.obs["GR_DN"].values
    if np.isfinite(upv).sum() > 100 and np.isfinite(dnv).sum() > 100:
        a.obs["GR_composite"] = upv - dnv
    elif np.isfinite(upv).sum() > 100:
        a.obs["GR_composite"] = upv
    else:
        score_geneset(a, GR_CORE_CURATED, "GR_composite", min_genes=2)

    # Median split on the GR composite (needs >= 200 finite cells).
    v = a.obs["GR_composite"].values; vfin = v[np.isfinite(v)]
    if vfin.size >= 200:
        gr_median = float(np.median(vfin))
        # Cells without a finite GR score stay NaN instead of being counted as
        # low-stress (NaN >= threshold is False, which .astype(int) made 0).
        a.obs["stress_high"] = np.where(np.isfinite(v), (v >= gr_median).astype(float), np.nan)
    else:
        a.obs["stress_high"] = np.nan

    # Raw expression of selected genes (case-insensitive lookup); NaN if absent.
    for g in EXPR_G:
        present = _genes_present_case_insensitive(a, [g])
        if present:
            x = a[:, present[0]].X
            if sp.issparse(x): x = x.toarray()
            a.obs[f"expr_{g}"] = np.asarray(x).reshape(-1)
        else:
            a.obs[f"expr_{g}"] = np.nan

    conf_cols = summarize_dataset_confounds(a)
    donor_col = pick_donor_col(a.obs)

    qc_rows.append({"dataset": ds, "n_cells": int(a.n_obs), "n_genes": int(a.n_vars),
                     "donor_col": donor_col or "(none)",
                     "confounds": ", ".join(conf_cols) if conf_cols else "(none)"})

    df_obs = a.obs.copy()
    needed_cols = list(dict.fromkeys(
        conf_cols + ["tolerance_positioning", "GR_composite", "stress_high"] +
        [f"{m}_score" for m in mes_cols] + [f"{m}_score_hkstrip" for m in mes_cols] +
        [f"expr_{g}" for g in EXPR_G]))

    # Analysis unit: donor aggregates when a donor column exists, otherwise cells
    # (randomly capped at MAX_CELL_STATS).
    if donor_col and donor_col in df_obs.columns:
        df_unit = donor_aggregate(df_obs, donor_col, needed_cols)
        level = "donor"; sample_note = "donor"
    else:
        df_unit = df_obs[needed_cols].copy()
        level = "cell"; sample_note = "cell-full"
        if df_unit.shape[0] > MAX_CELL_STATS:
            # Sample by position: obs_names can repeat (e.g. per-sample barcodes
            # concatenated without index_unique), and label-based .loc would then
            # return more than MAX_CELL_STATS rows.
            pick = np.random.default_rng(SEED).choice(df_unit.shape[0], size=MAX_CELL_STATS, replace=False)
            df_unit = df_unit.iloc[pick].copy()
            sample_note = f"cell-sub={MAX_CELL_STATS}"

    # Residualise outcome and predictors on the dataset's confounds (when any).
    X_conf = build_confound_matrix(df_unit, conf_cols)
    y = df_unit["tolerance_positioning"].values
    y_adj = residualize(y, X_conf) if X_conf is not None else y

    for mes in mes_cols:
        x = df_unit.get(f"{mes}_score", pd.Series(np.nan, index=df_unit.index)).values
        x_adj = residualize(x, X_conf) if X_conf is not None else x
        s_adj = corr_with_ci_p(x_adj, y_adj, n_boot=N_BOOT, n_perm=N_PERM, seed=SEED + 7)

        x2 = df_unit.get(f"{mes}_score_hkstrip", pd.Series(np.nan, index=df_unit.index)).values
        x2_adj = residualize(x2, X_conf) if X_conf is not None else x2
        s2_adj = corr_with_ci_p(x2_adj, y_adj, n_boot=min(N_BOOT, 500), n_perm=min(N_PERM, 500), seed=SEED + 17)

        all_results.append({
            "dataset": ds, "level": level, "N": int(s_adj["n"]), "sample_note": sample_note,
            "MES": mes,
            "r_adj": s_adj["r"], "ci_adj_low": s_adj["ci_low"], "ci_adj_high": s_adj["ci_high"],
            "p_adj_perm": s_adj["p_perm"],
            "confounds": ", ".join(conf_cols) if conf_cols else "(none)"})
        hk_sens_rows.append({
            "dataset": ds, "level": level, "N": int(s_adj["n"]), "MES": mes,
            "r_adj_original": s_adj["r"], "r_adj_hkstrip": s2_adj["r"],
            "delta": (s2_adj["r"] - s_adj["r"]) if (np.isfinite(s2_adj["r"]) and np.isfinite(s_adj["r"])) else np.nan})

    for gname, rows_store in [("LGALS3", lgals3_rows), ("FNDC5", fndc5_rows)]:
        if f"expr_{gname}" in df_unit.columns:
            x = df_unit[f"expr_{gname}"].values
            x_adj = residualize(x, X_conf) if X_conf is not None else x
            s = corr_with_ci_p(x_adj, y_adj, n_boot=N_BOOT, n_perm=N_PERM, seed=SEED + 123)
            rows_store.append({"dataset": ds, "level": level, "N": int(s["n"]),
                               "r_adj": s["r"], "ci_low": s["ci_low"], "ci_high": s["ci_high"],
                               "p_perm": s["p_perm"]})

    expr_rows.append({
        "dataset": ds, "level": level, "N": int(df_unit.shape[0]),
        "corr_LGALS3_vs_tolerance": safe_corr(df_unit.get("expr_LGALS3", np.nan), y),
        "corr_FNDC5_vs_tolerance": safe_corr(df_unit.get("expr_FNDC5", np.nan), y),
        "sample_note": sample_note})

    # GR moderation: mean MES–tolerance correlation in high- vs low-stress units,
    # with a label-permutation p-value for the difference.
    if "stress_high" in df_unit.columns and np.isfinite(df_unit["stress_high"].values).sum() >= 30:
        suball = df_unit[np.isfinite(df_unit["stress_high"].values)].copy()

        def mean_mes_corr(df_sub):
            """Mean over MES of the confound-adjusted MES-vs-tolerance correlation in *df_sub*."""
            yy = df_sub["tolerance_positioning"].values
            Xc = build_confound_matrix(df_sub, conf_cols)
            yy_adj = residualize(yy, Xc) if Xc is not None else yy
            rs = []
            for mes in mes_cols:
                xx = df_sub.get(f"{mes}_score", pd.Series(np.nan, index=df_sub.index)).values
                xx_adj = residualize(xx, Xc) if Xc is not None else xx
                rs.append(safe_corr(xx_adj, yy_adj))
            return float(np.nanmean(rs))

        low = suball[suball["stress_high"] == 0]
        high = suball[suball["stress_high"] == 1]
        obs_low = mean_mes_corr(low) if low.shape[0] >= 10 else np.nan
        obs_high = mean_mes_corr(high) if high.shape[0] >= 10 else np.nan
        obs_diff = (obs_high - obs_low) if (np.isfinite(obs_high) and np.isfinite(obs_low)) else np.nan

        p_diff = np.nan
        # Only test when an observed difference exists. With a NaN observed diff,
        # every comparison is False and p would collapse to 1 / (N_PERM + 1).
        if np.isfinite(obs_diff):
            # Permute among exactly the units that form the observed groups. At donor
            # level stress_high may be fractional (depends on donor_aggregate, not
            # visible here); such units are excluded from the observed comparison,
            # so they must be excluded from the null distribution too.
            perm_base = suball[suball["stress_high"].isin([0, 1])]
            labels = perm_base["stress_high"].values.astype(int)
            rng = np.random.default_rng(SEED + 99)
            diffs = []
            for _ in range(N_PERM):
                lp = rng.permutation(labels)
                lo = perm_base[lp == 0]; hi = perm_base[lp == 1]
                if lo.shape[0] < 10 or hi.shape[0] < 10: continue
                diffs.append(mean_mes_corr(hi) - mean_mes_corr(lo))
            diffs = np.asarray(diffs, float)
            diffs = diffs[np.isfinite(diffs)]  # NaN null draws would bias p downward
            if diffs.size:
                p_diff = float((np.sum(np.abs(diffs) >= abs(obs_diff)) + 1) / (diffs.size + 1))

        stress_interaction_rows.append({
            "dataset": ds, "level": level, "N": int(suball.shape[0]),
            "mean_r_low_stress": obs_low, "mean_r_high_stress": obs_high,
            "diff_high_minus_low": obs_diff, "p_perm": p_diff,
            "sample_note": sample_note})

df_res = pd.DataFrame(all_results)
df_qc = pd.DataFrame(qc_rows)
df_expr = pd.DataFrame(expr_rows)
df_hk_sens = pd.DataFrame(hk_sens_rows)
df_stress = pd.DataFrame(stress_interaction_rows)

# BH-FDR over all MES-vs-tolerance tests (every dataset x MES) and, separately,
# within each dataset. LGALS3/FNDC5/stress tests are not part of this family.
if len(df_res) and "p_adj_perm" in df_res.columns:
    df_res["q_global_BH"] = bh_fdr(df_res["p_adj_perm"].values)
    df_res["q_dataset_BH"] = df_res.groupby("dataset")["p_adj_perm"].transform(bh_fdr)

save_excel({
    "MES_vs_Tolerance_main": df_res.sort_values(["dataset", "MES"]),
    "Dataset_QC": df_qc,
    "Expression_checks": df_expr
}, fname="Table2", kind="Main")

sig_def = [{"signature": k, "genes": ", ".join(v)} for k, v in SIG.items()]
sig_def += [
    {"signature": "GR_UP derived", "genes": ", ".join(GR_UP[:25]) + ("..." if len(GR_UP) > 25 else "")},
    {"signature": "GR_DOWN derived", "genes": ", ".join(GR_DN[:25]) + ("..." if len(GR_DN) > 25 else "")},
    {"signature": "GR_CORE curated", "genes": ", ".join(GR_CORE_CURATED)}]
# MES gene lists are truncated to 25 names. "..." marks an actual truncation only,
# matching the GR rows above (previously it was appended unconditionally).
sig_def += [{"signature": mes,
             "genes": ", ".join(mes_gene_sets[mes][:25]) + ("..." if len(mes_gene_sets[mes]) > 25 else "")}
            for mes in mes_cols]

save_excel({
    "Signature_Definitions": pd.DataFrame(sig_def),
    "MES_vs_Tolerance_full": df_res.sort_values(["dataset", "level", "MES"])
}, fname="Table4", kind="Supplementary")

save_excel({"Stress_interaction": df_stress}, fname="Table6", kind="Supplementary")
save_excel({"HK_load_by_module": df_hk, "HK_strip_sensitivity": df_hk_sens}, fname="Table9_HK_Sensitivity", kind="Supplementary")

# FIGURES 2A-2D
# Pool a random per-cohort subsample for one shared embedding. The per-cohort target
# is max(8000, MAX_EMB_TOTAL / n_cohorts), capped at the cohort's size.
MAX_EMB_TOTAL = 120000
rng = np.random.default_rng(SEED)
micro_for_emb = []
for a in micro_adatas:
    take = min(a.n_obs, max(8000, int(MAX_EMB_TOTAL / max(len(micro_adatas), 1))))
    idx = rng.choice(a.n_obs, size=take, replace=False) if a.n_obs > take else np.arange(a.n_obs)
    micro_for_emb.append(a[idx].copy())
micro = ad.concat(micro_for_emb, join="outer", merge="same")
micro = sanitize_for_write(micro)
ensure_counts_layer(micro)
micro.X = micro.layers["counts"].copy()
if not looks_log1p(micro):
    sc.pp.normalize_total(micro, target_sum=1e4); sc.pp.log1p(micro)
# flavor="seurat_v3" is documented to expect count data. Select HVGs from the counts
# layer rather than the log-normalised X it was previously given.
sc.pp.highly_variable_genes(micro, n_top_genes=2000, flavor="seurat_v3", layer="counts")
micro2 = micro[:, micro.var_names[micro.var["highly_variable"]]].copy()
sc.tl.pca(micro2, n_comps=50, svd_solver="arpack")
sc.pp.neighbors(micro2, n_neighbors=15, n_pcs=30)
sc.tl.umap(micro2)

# Main Fig 2A
# UMAP of the pooled microglia subsample, coloured by tolerance positioning
# (all-NaN colours if the column is missing from the concatenated obs).
fig, ax = plt.subplots(figsize=(6.6, 5.2))
xy = micro2.obsm["X_umap"]
val = micro2.obs.get("tolerance_positioning", pd.Series(np.nan, index=micro2.obs_names)).values
sca = ax.scatter(xy[:, 0], xy[:, 1], c=val, s=2, alpha=0.75, linewidths=0, cmap=CMAP_CONT)
ax.set_xlabel("UMAP1"); ax.set_ylabel("UMAP2")
ax.set_title("Main Fig 2A  Microglia tolerance positioning")
_beautify_axes(ax)
plt.colorbar(sca, ax=ax, fraction=0.03, pad=0.02, label="Homeostatic − Activation")
save_fig(fig, "Fig2A", kind="Main")

# Main Fig 2B: bar + error bars (CI range)
# Per dataset: mean across MES of the adjusted r and of each CI bound; bars are
# sorted by mean r (descending). Pivot tables are reused by Fig 2C.
# NOTE(review): averaging CI endpoints across MES is a display summary, not a CI of the mean r.
mat_adj = df_res.pivot_table(index="dataset", columns="MES", values="r_adj")
mat_ci_lo = df_res.pivot_table(index="dataset", columns="MES", values="ci_adj_low")
mat_ci_hi = df_res.pivot_table(index="dataset", columns="MES", values="ci_adj_high")
df_bar = pd.DataFrame({
    "dataset": mat_adj.index,
    "mean_r_adj": np.nanmean(mat_adj.values, axis=1),
    "mean_ci_lo": np.nanmean(mat_ci_lo.values, axis=1),
    "mean_ci_hi": np.nanmean(mat_ci_hi.values, axis=1)
}).sort_values("mean_r_adj", ascending=False)

fig, ax = plt.subplots(figsize=(6.6, 3.4))
# Asymmetric error bars: distance from the mean r to each averaged CI bound.
yerr_lo = df_bar["mean_r_adj"] - df_bar["mean_ci_lo"]
yerr_hi = df_bar["mean_ci_hi"] - df_bar["mean_r_adj"]
ax.bar(df_bar["dataset"], df_bar["mean_r_adj"], color=C_BAR_MAIN,
       yerr=[yerr_lo.values, yerr_hi.values], capsize=3, error_kw={"linewidth": 0.8})
ax.set_ylabel("Mean adjusted ρ (95% CI)")
ax.set_title("Main Fig 2B  Average MES–tolerance correlation (adjusted)")
ax.tick_params(axis="x", rotation=45)
_beautify_axes(ax)
save_fig(fig, "Fig2B", kind="Main")

# Main Fig 2C: annotated seaborn heatmap
# Cells show the adjusted r, with significance stars from the global BH q-values.
fig, ax = plt.subplots(figsize=(8.5, 3.5))
if HAS_SNS:
    # Build significance matrix
    mat_q = df_res.pivot_table(index="dataset", columns="MES", values="q_global_BH") if "q_global_BH" in df_res.columns else None
    # Annotations live in an object-dtype frame. Writing strings into a copy of the
    # float matrix is an incompatible-dtype setitem, deprecated in pandas >= 2.1.
    annot_text = pd.DataFrame("", index=mat_adj.index, columns=mat_adj.columns, dtype=object)
    for ds in mat_adj.index:
        for mes in mat_adj.columns:
            r_val = mat_adj.loc[ds, mes]
            q_val = mat_q.loc[ds, mes] if mat_q is not None and ds in mat_q.index and mes in mat_q.columns else np.nan
            s = sig_stars(q_val)
            annot_text.loc[ds, mes] = f"{r_val:.2f}\n{s}" if np.isfinite(r_val) else ""

    sns.heatmap(mat_adj, ax=ax, cmap=CMAP_DIV, vmin=-0.5, vmax=0.5, linewidths=0.5,
                linecolor="white", annot=annot_text, fmt="", annot_kws={"fontsize": 6},
                cbar_kws={"label": "Adjusted ρ", "shrink": 0.8})
else:
    im = ax.imshow(mat_adj.values, aspect="auto", vmin=-0.5, vmax=0.5, cmap=CMAP_DIV)
    ax.set_yticks(range(mat_adj.shape[0])); ax.set_yticklabels(mat_adj.index)
    ax.set_xticks(range(mat_adj.shape[1])); ax.set_xticklabels(mat_adj.columns, rotation=90)
    plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02, label="Adjusted ρ")
ax.set_title("Main Fig 2C  MES–tolerance correlations (covariate-adjusted, FDR-annotated)")
save_fig(fig, "Fig2C", kind="Main")

# Main Fig 2D: LGALS3 with CIs
# Explicit columns keep dropna() valid when no cohort produced an LGALS3 row.
# pd.DataFrame([]) has no "r_adj" column, so dropna(subset=...) would raise KeyError.
df_lg = pd.DataFrame(lgals3_rows, columns=["dataset", "level", "N", "r_adj", "ci_low", "ci_high", "p_perm"]).dropna(subset=["r_adj"])
fig, ax = plt.subplots(figsize=(6.6, 3.2))
if len(df_lg):
    yerr_lo = df_lg["r_adj"] - df_lg["ci_low"]
    yerr_hi = df_lg["ci_high"] - df_lg["r_adj"]
    ax.bar(df_lg["dataset"], df_lg["r_adj"], color=C_BAR_ACC2,
           yerr=[yerr_lo.values, yerr_hi.values], capsize=3, error_kw={"linewidth": 0.8},
           edgecolor="black", linewidth=0.3)
    for i, (_, row) in enumerate(df_lg.iterrows()):
        s = sig_stars(row.get("p_perm"))
        if s != "ns":
            # Put the star above the CI, or above the bar when the CI is missing.
            y_top = row["ci_high"] if pd.notna(row["ci_high"]) else row["r_adj"]
            ax.text(i, y_top + 0.01, s, ha="center", fontsize=7, fontweight="bold")
ax.set_ylabel("Adjusted ρ (95% CI)")
ax.set_title("Main Fig 2D  LGALS3 vs tolerance (adjusted)")
ax.tick_params(axis="x", rotation=45)
_beautify_axes(ax)
save_fig(fig, "Fig2D", kind="Main")

# Supp Fig 2E: FNDC5 with CIs (only drawn when at least one cohort has a finite r)
# Explicit columns keep dropna() valid when fndc5_rows is empty.
df_fn = pd.DataFrame(fndc5_rows, columns=["dataset", "level", "N", "r_adj", "ci_low", "ci_high", "p_perm"]).dropna(subset=["r_adj"])
if len(df_fn):
    fig, ax = plt.subplots(figsize=(6.6, 3.2))
    yerr_lo = df_fn["r_adj"] - df_fn["ci_low"]
    yerr_hi = df_fn["ci_high"] - df_fn["r_adj"]
    ax.bar(df_fn["dataset"], df_fn["r_adj"], color=C_BAR_ACC,
           yerr=[yerr_lo.values, yerr_hi.values], capsize=3, error_kw={"linewidth": 0.8},
           edgecolor="black", linewidth=0.3)
    for i, (_, row) in enumerate(df_fn.iterrows()):
        s = sig_stars(row.get("p_perm"))
        if s != "ns":
            # Put the star above the CI, or above the bar when the CI is missing.
            y_top = row["ci_high"] if pd.notna(row["ci_high"]) else row["r_adj"]
            ax.text(i, y_top + 0.01, s, ha="center", fontsize=7, fontweight="bold")
    ax.set_ylabel("Adjusted ρ (95% CI)")
    ax.set_title("Supplementary Fig 2E  FNDC5 vs tolerance (exploratory)")
    ax.tick_params(axis="x", rotation=45)
    _beautify_axes(ax)
    save_fig(fig, "Fig2E_FNDC5", kind="Supplementary")

# Supp Fig 2G: HK robustness
# Per dataset: mean ± SD over MES of |r_hkstrip − r_original| (the "delta" column
# of df_hk_sens), sorted by mean (descending).
fig, ax = plt.subplots(figsize=(7.2, 3.0))
tmp = df_hk_sens.copy(); tmp["abs_delta"] = tmp["delta"].abs()
summ = tmp.groupby("dataset")["abs_delta"].agg(["mean", "std"]).sort_values("mean", ascending=False)
ax.bar(summ.index, summ["mean"], color=C_BAR_MAIN, yerr=summ["std"], capsize=3,
       error_kw={"linewidth": 0.8}, edgecolor="black", linewidth=0.3)
ax.set_ylabel("Mean |Δρ| ± SD")
ax.set_title("Supplementary Fig 2G  MES robustness to housekeeping removal")
ax.tick_params(axis="x", rotation=45)
_beautify_axes(ax)
save_fig(fig, "Fig2G_HK_Robustness", kind="Supplementary")


# 10) SPATIAL VALIDATION
# Locate the GSE220442 Visium export (first hit containing "counts_and_images").
vis_hits = list(RAW_DIR.rglob("GSE220442"))
assert len(vis_hits) >= 1
vis_root = None
for p in vis_hits:
    cand = p / "counts_and_images"
    if cand.exists(): vis_root = cand; break
assert vis_root is not None

samples = sorted([d for d in vis_root.iterdir() if d.is_dir()])
# Representative sample for Fig 3A: the first sample that actually has a matrix.
# Previously this was samples[0] unconditionally, so Fig 3A was silently skipped
# whenever that directory lacked the h5 file.
rep_sample = next((d.name for d in samples if (d / "filtered_feature_bc_matrix.h5").exists()), None)

spatial_rows = []
for sdir in samples:
    h5 = sdir / "filtered_feature_bc_matrix.h5"
    if not h5.exists(): continue
    a = sc.read_10x_h5(h5)
    a = ensure_gene_symbols(a, f"Visium_{sdir.name}")
    a.var_names_make_unique()
    # Keep raw counts, then CP10k + log1p for scoring.
    a.layers["counts"] = a.X.copy()
    sc.pp.normalize_total(a, target_sum=1e4); sc.pp.log1p(a)

    score_geneset(a, SIG["microglia_homeostatic"], "score_homeostatic")
    score_geneset(a, SIG["microglia_activation"], "score_activation")
    a.obs["tolerance_positioning"] = a.obs["score_homeostatic"] - a.obs["score_activation"]
    for mes, genes in mes_gene_sets.items():
        score_geneset(a, genes, f"{mes}_score")

    # ALL 8 MES correlations per sample (spot level, unadjusted)
    for mes in mes_cols:
        st = corr_with_ci_p(a.obs[f"{mes}_score"].values, a.obs["tolerance_positioning"].values,
                            n_boot=min(N_BOOT, 500), n_perm=min(N_PERM, 500), seed=SEED + 77)
        spatial_rows.append({
            "dataset": "GSE220442", "sample_id": sdir.name, "MES": mes,
            "n_spots": int(st["n"]), "r": float(st["r"]),
            "ci_low": float(st["ci_low"]), "ci_high": float(st["ci_high"]),
            "p_perm": float(st["p_perm"])})

    if sdir.name == rep_sample:
        fig, ax = plt.subplots(figsize=(6.0, 3.2))
        vv = a.obs["tolerance_positioning"].values
        vv = vv[np.isfinite(vv)]
        ax.hist(vv, bins=50, color=C_BAR_MAIN, alpha=0.85, edgecolor="black", linewidth=0.3)
        ax.set_title(f"Main Fig 3A  {sdir.name} tolerance distribution (n={len(vv)})")
        ax.set_xlabel("Tolerance positioning"); ax.set_ylabel("Spots")
        _beautify_axes(ax)
        save_fig(fig, "Fig3A", kind="Main")

# BH-FDR across all sample x MES spatial tests.
df_spatial = pd.DataFrame(spatial_rows)
if len(df_spatial):
    df_spatial["q_BH"] = bh_fdr(df_spatial["p_perm"].values)

# Main Fig 3B: heatmap per sample × MES
if len(df_spatial):
    fig, ax = plt.subplots(figsize=(8, max(3, 0.4 * df_spatial["sample_id"].nunique())))
    piv = df_spatial.pivot_table(index="sample_id", columns="MES", values="r", aggfunc="mean")
    ordered = [f"MES0{i}" for i in range(1, 9)]
    # Prefer the canonical MES01..MES08 order. Fall back to the pivot's own columns
    # when labels follow another scheme, instead of plotting an empty frame.
    cols = [c for c in ordered if c in piv.columns] or list(piv.columns)
    piv = piv[cols]

    if HAS_SNS:
        piv_q = df_spatial.pivot_table(index="sample_id", columns="MES", values="q_BH", aggfunc="mean")
        # Object-dtype annotation frame: writing strings into a float-dtype copy is
        # deprecated in pandas >= 2.1.
        annot = pd.DataFrame("", index=piv.index, columns=piv.columns, dtype=object)
        for ds in piv.index:
            for mes in piv.columns:
                r = piv.loc[ds, mes]
                q = piv_q.loc[ds, mes] if ds in piv_q.index and mes in piv_q.columns else np.nan
                annot.loc[ds, mes] = f"{r:.2f}{sig_stars(q)}" if np.isfinite(r) else ""
        sns.heatmap(piv, ax=ax, cmap=CMAP_DIV, vmin=-0.5, vmax=0.5, linewidths=0.5,
                    annot=annot, fmt="", annot_kws={"fontsize": 5},
                    cbar_kws={"label": "ρ", "shrink": 0.8})
    else:
        im = ax.imshow(piv.values, aspect="auto", cmap=CMAP_DIV, vmin=-0.5, vmax=0.5)
        ax.set_xticks(range(piv.shape[1])); ax.set_xticklabels(piv.columns, rotation=45)
        ax.set_yticks(range(piv.shape[0])); ax.set_yticklabels(piv.index)
        plt.colorbar(im, ax=ax, label="ρ", shrink=0.8)
    ax.set_title("Main Fig 3B  Visium spatial: all MES × sample (FDR-annotated)")
    save_fig(fig, "Fig3B", kind="Main")

# 11) CROSS-MODAL GSE233208
def load_mtx_export(mtx_dir, label):
    """Load a Matrix Market export directory into a cells x genes AnnData.

    Expects ``counts.mtx``, ``features.tsv.gz`` (gene names taken from column 2 when
    present, else column 1), ``barcodes.tsv.gz`` and ``meta.csv.gz`` (per-cell
    metadata indexed by barcode) inside *mtx_dir*.

    The matrix may be stored genes x cells (assumed when ambiguous, e.g. square) or
    cells x genes; orientation is inferred from the feature/barcode counts.

    Raises
    ------
    FileNotFoundError
        If any of the four files is missing.
    ValueError
        If the matrix shape matches neither orientation.
    """
    mtx = mtx_dir / "counts.mtx"; feats = mtx_dir / "features.tsv.gz"
    bcs = mtx_dir / "barcodes.tsv.gz"; meta = mtx_dir / "meta.csv.gz"
    if not all(p.exists() for p in [mtx, feats, bcs, meta]):
        raise FileNotFoundError(f"{label} missing files")
    X = scipy.io.mmread(str(mtx)).tocsr()
    feats_df = pd.read_csv(feats, sep="\t", header=None, compression="gzip")
    bcs_df = pd.read_csv(bcs, sep="\t", header=None, compression="gzip")
    barcodes = bcs_df[0].astype(str).values
    gene_names = (feats_df[1] if feats_df.shape[1] > 1 else feats_df[0]).astype(str).values

    n_genes, n_cells = len(gene_names), len(barcodes)
    if X.shape == (n_genes, n_cells):
        X = X.T.tocsr()          # genes x cells export -> cells x genes
    elif X.shape == (n_cells, n_genes):
        X = X.tocsr()            # already cells x genes
    else:
        raise ValueError(f"{label}: counts.mtx shape {X.shape} does not match "
                         f"{n_genes} features x {n_cells} barcodes")

    meta_df = pd.read_csv(meta, compression="gzip", index_col=0)
    # Numeric-looking barcodes may be parsed as integers; compare as strings so
    # reindexing does not silently yield all-NaN metadata.
    meta_df.index = meta_df.index.astype(str)
    meta_df = meta_df.reindex(barcodes)

    adata = ad.AnnData(X=X, obs=meta_df.copy(), var=pd.DataFrame(index=gene_names))
    adata.obs_names = barcodes; adata.var_names_make_unique()
    return adata

crossmodal_ran = False; df_concord = pd.DataFrame(); pair_key_used = "(not run)"
try:
    # Run only when both exports exist. Any failure (including missing path globals)
    # is reported below and leaves crossmodal_ran False.
    if (GSE233208_VIS_MTX / "counts.mtx").exists() and (GSE233208_SN_MTX / "counts.mtx").exists():
        ad_vis = load_mtx_export(GSE233208_VIS_MTX, "GSE233208_visium")
        ad_sn = load_mtx_export(GSE233208_SN_MTX, "GSE233208_snrna")
        crossmodal_ran = True

        # Visium: tolerance + MES scores. Spots at or above the 70th percentile of the
        # homeostatic score count as microglia-enriched (all spots kept when <= 50
        # finite scores exist).
        ad_vis = ensure_gene_symbols(ad_vis, "GSE233208_Visium")
        ensure_counts_layer(ad_vis); ad_vis.X = ad_vis.layers["counts"].copy()
        if not looks_log1p(ad_vis): sc.pp.normalize_total(ad_vis, target_sum=1e4); sc.pp.log1p(ad_vis)
        score_geneset(ad_vis, SIG["microglia_homeostatic"], "score_homeostatic", 3)
        score_geneset(ad_vis, SIG["microglia_activation"], "score_activation", 3)
        ad_vis.obs["tolerance_positioning"] = ad_vis.obs["score_homeostatic"] - ad_vis.obs["score_activation"]
        for mes, genes in mes_gene_sets.items(): score_geneset(ad_vis, genes, f"{mes}_score", 3)
        v = ad_vis.obs["score_homeostatic"].values
        thr = np.nanquantile(v[np.isfinite(v)], 0.70) if np.isfinite(v).sum() > 50 else np.nan
        ad_vis.obs["microglia_enriched_spot"] = (ad_vis.obs["score_homeostatic"] >= thr).astype(int) if np.isfinite(thr) else 1

        # snRNA: restrict to microglia/myeloid by annotation, else by markers; then score.
        ad_sn = ensure_gene_symbols(ad_sn, "GSE233208_snRNA"); prep_for_scoring(ad_sn)
        ct_c = [c for c in ad_sn.obs.columns if ("cell" in c.lower() and "type" in c.lower()) or "annotation" in c.lower()]
        if ct_c:
            vals = ad_sn.obs[ct_c[0]].astype(str).str.lower()
            keep = vals.str.contains("microglia|myeloid")
            ad_sn = ad_sn[keep].copy() if keep.any() else subset_microglia_by_markers(ad_sn, 2, ["AIF1","LST1","TYROBP"])
        else:
            ad_sn = subset_microglia_by_markers(ad_sn, 2, ["AIF1","LST1","TYROBP"])
        prep_for_scoring(ad_sn)
        score_geneset(ad_sn, SIG["microglia_homeostatic"], "score_homeostatic", 3)
        score_geneset(ad_sn, SIG["microglia_activation"], "score_activation", 3)
        ad_sn.obs["tolerance_positioning"] = ad_sn.obs["score_homeostatic"] - ad_sn.obs["score_activation"]
        for mes, genes in mes_gene_sets.items(): score_geneset(ad_sn, genes, f"{mes}_score", 3)

        def find_pair_key(dfA, dfB):
            """Return the first candidate sample/donor column present in both tables, else None."""
            for k in ["sample_id","sample","donor_id","donor","subject","patient","case_id","region","library_id","orig.ident"]:
                if k in dfA.columns and k in dfB.columns: return k
            return None
        # Without a shared key everything collapses into a single "all" group.
        pair_key = find_pair_key(ad_vis.obs, ad_sn.obs) or "__global__"
        if pair_key == "__global__": ad_vis.obs[pair_key] = "all"; ad_sn.obs[pair_key] = "all"
        pair_key_used = pair_key
        cols_scores = ["tolerance_positioning"] + [f"{m}_score" for m in mes_cols]
        # Group means (enriched Visium spots vs all snRNA microglia), then per-score concordance.
        dfv = ad_vis.obs[ad_vis.obs["microglia_enriched_spot"].astype(int) == 1].groupby(pair_key)[cols_scores].mean().reset_index()
        dfn = ad_sn.obs.groupby(pair_key)[cols_scores].mean().reset_index()
        df_merge = pd.merge(dfv, dfn, on=pair_key, suffixes=("_visium", "_snrna"))
        concord = []
        for c in cols_scores:
            cv = f"{c}_visium"; cn = f"{c}_snrna"
            if cv in df_merge.columns and cn in df_merge.columns:
                concord.append({"score": c, "n_groups": int(df_merge.shape[0]),
                                "pearson_r": safe_corr(df_merge[cv].values, df_merge[cn].values)})
        df_concord = pd.DataFrame(concord)
        save_excel({"Concordance": df_concord, "Paired": df_merge,
                    "pair_key": pd.DataFrame([{"pair_key": pair_key_used}])},
                   fname="Table8_CrossModal_GSE233208", kind="Supplementary")

        fig = plt.figure(figsize=(6.8, 3.2))
        gs = fig.add_gridspec(1, 2, wspace=0.35)
        ax1 = fig.add_subplot(gs[0, 0])
        if "tolerance_positioning_visium" in df_merge.columns:
            x = df_merge["tolerance_positioning_visium"].values; y2 = df_merge["tolerance_positioning_snrna"].values
            ax1.scatter(x, y2, s=18, alpha=0.85, color=C_BAR_ACC2)
            ax1.set_xlabel("Visium tolerance"); ax1.set_ylabel("snRNA tolerance")
            ax1.set_title(f"r = {safe_corr(x, y2):.2f}")
        ax2 = fig.add_subplot(gs[0, 1])
        # An empty concordance table has no "score" column. Guard so that an
        # already-saved table does not end up reported as "not run".
        dfm = df_concord[df_concord["score"].str.startswith("MES")] if "score" in df_concord.columns else df_concord
        if len(dfm): ax2.bar(dfm["score"], dfm["pearson_r"], color=C_BAR_MAIN)
        ax2.tick_params(axis="x", rotation=90); ax2.set_ylabel("Pearson r")
        ax2.set_title("MES concordance"); _beautify_axes(ax2)
        save_fig(fig, "Fig3C_CrossModal_GSE233208", kind="Supplementary")
except Exception as e:
    # Best-effort supplementary analysis: report and continue.
    crossmodal_ran = False
    print(f"[INFO] Cross-modal not run: {type(e).__name__}: {e}")

# 12) FIGURES 4D + 4E
# Fig 4D: per-dataset difference (high − low GR) in the mean adjusted MES–tolerance
# correlation from df_stress. Bars whose p_perm is not "ns" get stars plus the p-value,
# placed above positive bars and below negative ones.
fig, ax = plt.subplots(figsize=(6.8, 3.6))
if len(df_stress):
    diff = df_stress.set_index("dataset")
    ax.bar(diff.index, diff["diff_high_minus_low"], color=C_BAR_ACC, edgecolor="black", linewidth=0.3)
    for i, (ds, row) in enumerate(diff.iterrows()):
        s = sig_stars(row.get("p_perm"))
        if s != "ns":
            y_pos = row["diff_high_minus_low"] + (0.01 if row["diff_high_minus_low"] >= 0 else -0.02)
            ax.text(i, y_pos, f"{s}\np={row['p_perm']:.3f}", ha="center", fontsize=6, fontweight="bold")
ax.set_ylabel("Δ mean adjusted ρ (high − low GR)")
ax.set_title("Main Fig 4D  GR moderation of MES–tolerance link")
ax.tick_params(axis="x", rotation=45)
_beautify_axes(ax)
save_fig(fig, "Fig4D", kind="Main")

# Fig 4E: distribution of the per-cell GR composite for every cohort with >= 200
# finite scores.
data_gr = []; labels_gr = []
for a in micro_adatas:
    vv = a.obs.get("GR_composite", pd.Series([], dtype=float)).values
    vv = vv[np.isfinite(vv)]
    if vv.size >= 200: data_gr.append(vv); labels_gr.append(str(a.obs["dataset"].iloc[0]))
fig, ax = plt.subplots(figsize=(6.8, 3.6))
box_kwargs = dict(showfliers=False, patch_artist=True,
                  boxprops=dict(edgecolor="#333333", linewidth=0.9),
                  medianprops=dict(color="#111111", linewidth=1.0),
                  whiskerprops=dict(color="#333333", linewidth=0.9),
                  capprops=dict(color="#333333", linewidth=0.9))
try:
    # matplotlib >= 3.9 renamed ``labels`` to ``tick_labels`` (old name deprecated).
    bp = ax.boxplot(data_gr, tick_labels=labels_gr, **box_kwargs)
except TypeError:
    # Older matplotlib without ``tick_labels``.
    bp = ax.boxplot(data_gr, labels=labels_gr, **box_kwargs)
for patch in bp["boxes"]: patch.set_facecolor(C_BAR_MAIN); patch.set_alpha(0.30)
ax.set_ylabel("GR composite score")
ax.set_title("Supplementary Fig 4E  GR program across cohorts")
ax.tick_params(axis="x", rotation=45)
_beautify_axes(ax)
save_fig(fig, "Fig4E_GR_AcrossCohorts", kind="Supplementary")

# 13) INNATE MEMORY (GSE184241)
def read_gse184241_counts(path):
    """Read the GSE184241 whitespace-delimited raw-count table as genes x samples.

    The file is first read with its first column as the index. If at least 60% of
    the first 100 index labels look like gene symbols and fewer than 20% are purely
    numeric, that frame is used. Otherwise the file is re-read and its first column
    is explicitly promoted to the index. In both paths, double quotes and
    surrounding whitespace are stripped from labels, values are coerced to numbers
    (non-numeric -> 0), and duplicate gene rows are summed.
    """
    path = Path(path)
    read_kwargs = dict(sep=r"\s+", engine="python",
                       compression="gzip" if str(path).endswith(".gz") else None)

    def _strip_quotes(labels):
        # Remove embedded double quotes, then surrounding whitespace.
        return labels.astype(str).str.replace('"', '').str.strip()

    def _numeric_collapsed(frame):
        # Coerce to numeric and sum rows that share a gene label (first-seen order).
        frame = frame.apply(pd.to_numeric, errors="coerce").fillna(0.0)
        return frame.groupby(frame.index, sort=False).sum(min_count=1)

    counts = pd.read_csv(path, index_col=0, **read_kwargs)
    counts.index = _strip_quotes(counts.index)
    counts.columns = _strip_quotes(counts.columns)

    head = counts.index[:100].astype(str)
    frac_symbol = head.str.match(r"^[A-Za-z][A-Za-z0-9\-\.\+]*$").mean()
    frac_numeric = head.str.fullmatch(r"\d+").mean()
    if frac_symbol >= 0.6 and frac_numeric < 0.2:
        return _numeric_collapsed(counts)

    raw = pd.read_csv(path, **read_kwargs)
    raw.columns = _strip_quotes(raw.columns)
    gene_col = raw.columns[0]
    raw[gene_col] = _strip_quotes(raw[gene_col])
    return _numeric_collapsed(raw.set_index(gene_col))

def infer_condition_fixed(sample_name):
    """Map a sample name to a stimulation condition label.

    Matching is case-insensitive and the first matching rule wins.
    Underscore-delimited tokens (``_LPS_``, ``_RPMI_``) take priority over bare
    substrings, and LPS is tested before RPMI within each tier. Names containing
    ``CTRL`` or ``CONTROL`` map to ``"Control"``. Anything else is ``"Unknown"``.
    """
    name = str(sample_name).upper()
    rules = (
        (("_LPS_",), "LPS"),
        (("_RPMI_",), "RPMI"),
        (("LPS",), "LPS"),
        (("RPMI",), "RPMI"),
        (("CTRL", "CONTROL"), "Control"),
    )
    return next(
        (label for needles, label in rules if any(n in name for n in needles)),
        "Unknown",
    )

# Locate the GSE184241 combined count matrix (gzipped or plain) anywhere under RAW_DIR.
gse184_hits = list(RAW_DIR.rglob("GSE184241_combined_raw_counts.txt.gz")) + list(RAW_DIR.rglob("GSE184241_combined_raw_counts.txt"))
df_innate_scores = pd.DataFrame()
df_innate_stats = pd.DataFrame()  # stays empty unless the LPS vs RPMI tests produce rows
innate_stats_rows = []

if gse184_hits:
    gse184_file = gse184_hits[0]
    expr184 = read_gse184241_counts(gse184_file)
    MES_GENES = sorted({g for genes in mes_gene_sets.values() for g in genes})
    overlap = sorted(set(MES_GENES) & set(expr184.index))
    print(f"[GSE184241] MES overlap: {len(overlap)}/{len(MES_GENES)}")
    # Explicit check instead of `assert` so it still fires under `python -O`.
    if not overlap:
        raise RuntimeError("[GSE184241] no MES genes found in the count matrix index")

    # log2(CPM + 1); all-zero libraries become NaN instead of dividing by zero.
    lib = expr184.sum(axis=0).replace(0, np.nan)
    df_logm = np.log2(expr184.div(lib, axis=1) * 1e6 + 1.0)

    def sample_score(df_log_gxS, genes_up, genes_dn=None):
        """Per-sample score: mean of present up-genes minus mean of present down-genes.

        Returns an all-NaN Series when fewer than 3 up-genes are present. The
        down-set is only subtracted when at least 3 of its genes are present.
        """
        up = [g for g in genes_up if g in df_log_gxS.index]
        dn = [g for g in (genes_dn or []) if g in df_log_gxS.index]
        if len(up) < 3: return pd.Series(index=df_log_gxS.columns, data=np.nan)
        s_up = df_log_gxS.loc[up].mean(axis=0)
        s_dn = df_log_gxS.loc[dn].mean(axis=0) if len(dn) >= 3 else 0.0
        return (s_up - s_dn).astype(float)

    out = pd.DataFrame(index=df_logm.columns)
    out["sample"] = out.index.astype(str)
    out["condition"] = out["sample"].apply(infer_condition_fixed)
    out["homeostatic"] = sample_score(df_logm, SIG["microglia_homeostatic"])
    out["activation"] = sample_score(df_logm, SIG["microglia_activation"])
    out["tolerance_like"] = out["homeostatic"] - out["activation"]
    out["GR_score"] = sample_score(df_logm, GR_UP, GR_DN)
    for mes in mes_cols:
        out[f"{mes}_score"] = sample_score(df_logm, mes_gene_sets[mes])

    df_innate_scores = out.reset_index(drop=True)
    print("GSE184241 conditions:", df_innate_scores["condition"].value_counts().to_dict())

    # Formal LPS vs RPMI comparison: Mann-Whitney U, Cohen's d with CI
    # (n_boot capped at 500), then BH-FDR across the tested scores.
    lps = df_innate_scores[df_innate_scores["condition"] == "LPS"]
    rpmi = df_innate_scores[df_innate_scores["condition"] == "RPMI"]
    if len(lps) >= 3 and len(rpmi) >= 3:
        for sc_col in ["tolerance_like", "GR_score"] + [f"{m}_score" for m in mes_cols]:
            a_vals = lps[sc_col].dropna().values
            b_vals = rpmi[sc_col].dropna().values
            if len(a_vals) >= 3 and len(b_vals) >= 3:
                _, p_mw = stats.mannwhitneyu(a_vals, b_vals, alternative="two-sided")
                dr = cohens_d_with_ci(a_vals, b_vals, n_boot=min(N_BOOT, 500), seed=SEED + 200)
                innate_stats_rows.append({
                    "score": sc_col, "n_LPS": len(a_vals), "n_RPMI": len(b_vals),
                    "mean_LPS": float(np.mean(a_vals)), "mean_RPMI": float(np.mean(b_vals)),
                    "cohens_d": dr["d"], "d_ci_lo": dr["ci_lo"], "d_ci_hi": dr["ci_hi"],
                    "p_MW": float(p_mw)})
    # Only build the table when at least one score passed the n >= 3 filter.
    # A DataFrame built from an empty row list has no "p_MW" column, so the
    # BH-FDR step would otherwise raise KeyError.
    if innate_stats_rows:
        df_innate_stats = pd.DataFrame(innate_stats_rows)
        df_innate_stats["q_BH"] = bh_fdr(df_innate_stats["p_MW"].values)

    save_excel({
        "GSE184241_sample_scores": df_innate_scores,
        "LPS_vs_RPMI_stats": df_innate_stats,
        "note": pd.DataFrame([{"file": str(gse184_file),
                                "note": "Sample-level scoring + formal LPS vs RPMI comparison (MW + Cohen's d + BH-FDR)"}])
    }, fname="Table7_InnateMemory_GSE184241", kind="Supplementary")

    # Supp Fig 4F: per-sample MES scores (column-wise z-scores), rows grouped by condition.
    mes_score_cols = [f"{m}_score" for m in mes_cols]
    hm = df_innate_scores.set_index("sample")[mes_score_cols].copy()
    hmz = (hm - hm.mean(axis=0)) / (hm.std(axis=0) + 1e-8)
    cond = df_innate_scores.set_index("sample")["condition"].reindex(hmz.index).astype(str)
    order = np.argsort(cond.values, kind="stable")
    hmz_ord = hmz.iloc[order]
    cond_ord = cond.iloc[order].to_numpy()
    cond_colors = {"LPS": "#D6604D", "RPMI": "#4393C3", "Control": "#66C2A5", "Unknown": "#999999"}

    fig, ax = plt.subplots(figsize=(7.2, 4.8))
    if HAS_SNS:
        sns.heatmap(hmz_ord, ax=ax, cmap=CMAP_HEAT, vmin=-2, vmax=2, linewidths=0,
                    cbar_kws={"label": "Z-score", "shrink": 0.7},
                    xticklabels=True, yticklabels=False)
        row_offset = 0.0   # seaborn draws row i over y in [i, i + 1]
    else:
        im = ax.imshow(hmz_ord.values, aspect="auto", cmap=CMAP_HEAT, vmin=-2, vmax=2, interpolation="nearest")
        ax.set_xticks(range(hmz_ord.shape[1])); ax.set_xticklabels(hmz_ord.columns, rotation=90)
        plt.colorbar(im, ax=ax, label="Z-score", shrink=0.7)
        row_offset = -0.5  # imshow centres row i on y = i

    # sns.heatmap has no `row_colors` argument (that belongs to clustermap), so
    # mark the condition groups directly. Draw separator lines between runs of
    # the same condition, and place one colored y tick label per run at its midpoint.
    if cond_ord.size:
        breaks = np.flatnonzero(cond_ord[1:] != cond_ord[:-1]) + 1
        starts = np.concatenate(([0], breaks))
        ends = np.concatenate((breaks, [cond_ord.size]))
        for b in breaks:
            ax.axhline(b + row_offset, color="#333333", linewidth=0.8)
        ax.set_yticks((starts + ends) / 2.0 + row_offset)
        ax.set_yticklabels([f"{cond_ord[s]} (n={e - s})" for s, e in zip(starts, ends)], rotation=0)
        for lbl, s in zip(ax.get_yticklabels(), starts):
            lbl.set_color(cond_colors.get(cond_ord[s], "#999999"))
    ax.set_title("Supplementary Fig 4F  GSE184241 innate memory MES profiles")
    ax.set_ylabel("Samples (ordered by condition)")
    save_fig(fig, "Fig4F_InnateMemory_MES_Profiles", kind="Supplementary")
else:
    print("[INFO] GSE184241 not found; skipping.")

# 14) SAVE SCORED h5ad
# Write one sanitized .h5ad per scored microglia cohort into AIM2_DIR, named
# after the `dataset` label of the cohort's first cell.
for a in micro_adatas:
    ds = str(a.obs["dataset"].iloc[0])
    out = AIM2_DIR.joinpath(ds + "__microglia_scored.h5ad")
    sanitize_for_write(a).write_h5ad(out)
    print("[SAVED] " + str(out))

# 15) SUPPLEMENTARY TABLE 5
# Cross-modal concordance rows are tagged with their dataset and pair key. When
# the cross-modal step did not run, or produced nothing, an empty sheet is
# written instead.
df_crossmodal_summary = (
    df_concord.assign(dataset="GSE233208", pair_key=pair_key_used)
    if crossmodal_ran and len(df_concord)
    else pd.DataFrame()
)

crossmodal_status = pd.DataFrame([{
    "ran": bool(crossmodal_ran),
    "folder": str(GSE233208_EXPORT),
    "note": "Cross-modal runs only if MTX exports exist.",
}])
table5_sheets = {
    "GSE220442_spatial_allMES": df_spatial,
    "GSE233208_crossmodal": df_crossmodal_summary,
    "crossmodal_status": crossmodal_status,
}
save_excel(table5_sheets, fname="Table5", kind="Supplementary")

# 16) SUPPLEMENTARY TABLE INDEX
# One row per exported table, listed in table-number order. The cross-modal
# Table 8 row is only included when the cross-modal analysis ran.
index_rows = [
    {"label": "Main Table 2", "file": "Main_Table2.xlsx", "content": "MES–tolerance correlations + QC + global FDR"},
    {"label": "Supp Table 3", "file": "Supplementary_Table3_StressCalibration.xlsx",
     "content": "GSE219208 calibration + GR signatures + drug breakdown + sensitivity"},
    {"label": "Supp Table 4", "file": "Supplementary_Table4.xlsx", "content": "Signature definitions + full results"},
    {"label": "Supp Table 5", "file": "Supplementary_Table5.xlsx", "content": "Spatial (all 8 MES) + cross-modal"},
    {"label": "Supp Table 6", "file": "Supplementary_Table6.xlsx", "content": "Stress moderation (permutation)"},
    {"label": "Supp Table 7", "file": "Supplementary_Table7_InnateMemory_GSE184241.xlsx",
     "content": "Innate memory scores + LPS vs RPMI formal stats"},
]
if crossmodal_ran:
    index_rows.append({"label": "Supp Table 8", "file": "Supplementary_Table8_CrossModal_GSE233208.xlsx",
                        "content": "Visium+snRNA concordance"})
# Table 9 is appended after the optional Table 8 row so the index stays in numerical order.
index_rows.append({"label": "Supp Table 9", "file": "Supplementary_Table9_HK_Sensitivity.xlsx",
                   "content": "HK load + sensitivity"})

save_excel({"Index": pd.DataFrame(index_rows)}, fname="Table_Index", kind="Supplementary")

# Final run summary: a banner followed by a checklist of the analysis upgrades.
_banner = "=" * 70
print("\n" + _banner)
print("NB4 COMPLETE ✅")
print(_banner)
for _done in (
    "Bootstrap 400→2000, permutation 400→5000",
    "Cohen's d + AUC with bootstrap 95% CIs",
    "Key gene labels on volcano (FKBP5, NR3C1, TSC22D3...)",
    "Seaborn heatmaps with r-values + significance stars",
    "All bar plots with error bars / CIs",
    "Boxplot with MW significance annotations",
    "Drug-type breakdown (dex vs cort)",
    "GR signature size sensitivity (top 30/50/100)",
    "Spatial: all 8 MES per sample (not just best)",
    "Innate memory: formal LPS vs RPMI stats (MW + d + FDR)",
    "Global BH-FDR across all tests",
    "padj threshold 0.05 (was 0.1)",
):
    print(f"  ✓ {_done}")
print(_banner)