# NB5 — Correlation, Meta-Analysis & Robustness

Donor-level Spearman correlations, DerSimonian-Laird random-effects meta-analysis, leave-one-dataset-out (LODO) sensitivity, housekeeping gene robustness, and GR-stratified differential expression.

**Paper:** Zafar SA, Qin W. *Thymus-Derived Myeloid Education Signatures Predict Microglial Tolerance Positioning and Are Modulated by Glucocorticoid Stress-Axis Activity.* Neuroimmunomodulation (2026).

> **Note:** Before running, update the path variables (`PROC_DIR`, `MANUS_DIR`) in the `# User config` block of the code cell below to match your local directory structure. Raw data can be obtained from the public repositories listed in Supplementary Table S1.


In [None]:
from __future__ import annotations
import os, re, time, math, json, gc, glob, warnings, traceback
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any
from itertools import combinations

import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")

import matplotlib
# Global Matplotlib defaults applied to every figure produced below:
# Arial text at 7-9 pt, thin axes and ticks, 150 dpi on screen and 1200 dpi
# with a tight bounding box when saving.
matplotlib.rcParams.update({
    "font.family": "Arial",
    "font.size": 8,
    "axes.titlesize": 9,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.dpi": 150,
    "savefig.dpi": 1200,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.05,
    "axes.linewidth": 0.8,
    "xtick.major.width": 0.6,
    "ytick.major.width": 0.6,
    # NOTE(review): font type 42 is presumably chosen so text in PDF/PS output
    # stays editable in vector editors -- confirm against journal requirements.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

import matplotlib.pyplot as plt

try:
    import seaborn as sns
    sns.set_style("ticks")
    HAS_SNS = True
except ImportError:
    HAS_SNS = False

from adjustText import adjust_text

import anndata as ad
import scanpy as sc
import scipy.sparse as sp
from scipy import stats as scipy_stats

try:
    import statsmodels.api as sm
    from statsmodels.stats.multitest import multipletests
    HAS_SM = True
except ImportError:
    HAS_SM = False

# User config
# Seed NumPy's legacy global RNG once for the session.
np.random.seed(42)

# Input / output locations; both default to paths relative to the working directory.
PROC_DIR  = Path(".") / "data" / "processed" / "aim2_microglia"  # <-- SET PATH
MANUS_DIR = Path(".") / "outputs" / "manuscript"  # <-- SET PATH
FIG_DIR   = MANUS_DIR / "Figures"
TAB_DIR   = MANUS_DIR / "Tables"
# Output folders are created as soon as this cell runs.
FIG_DIR.mkdir(parents=True, exist_ok=True)
TAB_DIR.mkdir(parents=True, exist_ok=True)

# NB4 tables are looked up in the same folder this notebook writes its tables to.
NB4_TABLE_DIR = TAB_DIR
# Glob pattern for the scored per-cohort .h5ad files inside PROC_DIR.
SCORED_GLOB   = str(PROC_DIR / "*__microglia_scored.h5ad")

# Cell caps for cell-level analyses. MAX_CELLS_FOR_CELL_RECOMPUTE is passed to
# stratified_subsample_by_donor_positions in Part B.
# NOTE(review): MAX_CELLS_FOR_CELL_DE, N_PERM and ALPHA are not referenced in the
# visible part of this file -- presumably used further down; confirm.
MAX_CELLS_FOR_CELL_DE        = 50000
MAX_CELLS_FOR_CELL_RECOMPUTE = 50000
DPI      = 1200   # resolution passed to savefig by save_fig
N_BOOT   = 2000   # default number of bootstrap resamples in corr_with_pvalue
N_PERM   = 5000
ALPHA    = 0.05

# Run-mode switches (not read in the visible part of this file -- TODO confirm usage).
RUN_DIAGNOSTICS_ONLY = False
DO_ENRICHMENT        = False

# Color palettes
# Eight Set2 colours: from seaborn when available, otherwise sampled from Matplotlib's Set2 colormap.
PAL_DATASETS = sns.color_palette("Set2", 8) if HAS_SNS else plt.cm.Set2(np.linspace(0, 1, 8))
# Fixed colours for the two GR strata.
PAL_GR       = {"GR_high": "#D6604D", "GR_low": "#4393C3"}


# Small utilities
def _now():  return time.strftime("%Y-%m-%d %H:%M:%S")

def _rss_gb():
    try:
        import psutil
        return psutil.Process(os.getpid()).memory_info().rss / (1024**3)
    except Exception:
        return float("nan")

def _cpu_pct():
    try:
        import psutil
        return psutil.Process(os.getpid()).cpu_percent(interval=None)
    except Exception:
        return float("nan")

def log(msg: str):
    """Print ``msg`` prefixed with a bracketed timestamp and flush stdout immediately."""
    stamp = _now()
    print(f"[{stamp}] {msg}", flush=True)

def fmt_td(seconds: float) -> str:
    """Format a duration in seconds as ``M:SS``, or ``H:MM:SS`` from one hour up.

    Negative durations are clamped to zero and fractional seconds are truncated.
    """
    whole = int(max(0.0, float(seconds)))
    hours = whole // 3600
    minutes = (whole % 3600) // 60
    secs = whole % 60
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"


@dataclass
class Progress:
    """Step counter that renders a one-line progress report.

    ``start_t`` defaults to the time of construction; ``done`` never exceeds
    ``total``.
    """
    label: str
    total: int
    start_t: Optional[float] = None
    done: int = 0

    def __post_init__(self):
        # Stamp the start time only when the caller did not supply one.
        if self.start_t is None:
            self.start_t = time.time()

    def tick(self, inc=1):
        """Advance the counter by ``inc`` steps, capped at ``total``."""
        advanced = self.done + int(inc)
        self.done = advanced if advanced < self.total else self.total

    def line(self, extra: str = "") -> str:
        """Return the status line: counts, percent, elapsed, ETA, CPU and RSS.

        ``extra`` is appended after a separator when non-empty. With a
        non-positive ``total`` the fraction is reported as 100% and the ETA as 0.
        """
        elapsed = time.time() - self.start_t
        if self.total > 0:
            frac = self.done / self.total
            # Linear extrapolation; the floor on frac avoids division by zero.
            eta = elapsed / max(frac, 1e-9) - elapsed
        else:
            frac = 1.0
            eta = 0.0
        suffix = f" | {extra}" if extra else ""
        return (f"PROG {self.label}: {self.done}/{self.total} ({frac*100:.1f}%) | "
                f"elapsed={fmt_td(elapsed)} | ETA={fmt_td(eta)} | "
                f"cpu={_cpu_pct():.1f}% | rss={_rss_gb():.2f}GB"
                + suffix)


class Heartbeat:
    """Background logger that periodically reports elapsed time, CPU and RSS.

    ``start()`` launches a daemon thread that logs one line immediately and
    then one every ``every_s`` seconds; ``update()`` sets a short message
    appended to those lines; ``stop()`` ends the loop and logs a final line.

    Each ``start()`` gets its own stop event, so ``stop()`` wakes the loop
    immediately, and a restarted instance never ends up with two live loops.
    """

    def __init__(self, label: str, every_s: int = 30):
        self.label = label
        self.every_s = int(every_s)
        self._t0 = None
        self._running = False
        self._last_msg = ""
        self._thread = None
        self._stop_evt = None

    def start(self, msg: str = "start"):
        """Log ``msg`` and start the periodic heartbeat thread."""
        import threading
        # Retire a loop left over from an earlier start() on this instance.
        if self._stop_evt is not None:
            self._stop_evt.set()
        stop_evt = threading.Event()
        self._stop_evt = stop_evt
        self._t0 = time.time()
        self._running = True
        log(f"HB {self.label}: {msg}")
        # Never wait less than 1 s, so every_s <= 0 cannot turn into a busy loop.
        period = max(self.every_s, 1)

        def _loop():
            # stop_evt is captured per start(), so this loop only obeys its own signal.
            while not stop_evt.is_set():
                elapsed = time.time() - self._t0
                log(f"HB {self.label}: running | elapsed={fmt_td(elapsed)} | "
                    f"cpu={_cpu_pct():.1f}% | rss={_rss_gb():.2f}GB"
                    + (f" | {self._last_msg}" if self._last_msg else ""))
                stop_evt.wait(period)

        self._thread = threading.Thread(target=_loop, daemon=True)
        self._thread.start()

    def update(self, msg: str):
        """Set the message appended to subsequent heartbeat lines."""
        self._last_msg = str(msg)

    def stop(self, msg: str = "done"):
        """Stop the heartbeat thread and log ``msg`` with the total elapsed time."""
        import threading
        self._running = False
        if self._stop_evt is not None:
            self._stop_evt.set()
        worker = self._thread
        if (worker is not None and worker.is_alive()
                and worker is not threading.current_thread()):
            # Bounded wait so no "running" line is printed after the final one.
            worker.join(timeout=2.0)
        elapsed = time.time() - (self._t0 or time.time())
        log(f"HB {self.label}: {msg} | elapsed={fmt_td(elapsed)}")


def save_fig(path: Path, fig=None):
    """Save a figure to ``path`` at ``DPI`` with a tight bounding box, then close it.

    With ``fig=None`` the current pyplot figure is saved and closed instead.
    The parent directory is created if needed and the path is logged.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    render_opts = {"dpi": DPI, "bbox_inches": "tight"}
    if fig is None:
        plt.savefig(path, **render_opts)
        plt.close()
    else:
        fig.savefig(path, **render_opts)
        plt.close(fig)
    log(f"SAVED FIG {path}")


def save_xlsx(df: pd.DataFrame, path: Path, sheet_name="Sheet1"):
    """Write ``df`` (without its index) as a single-sheet .xlsx file at ``path``.

    ``None`` is written as an empty table. The parent directory is created if
    needed and the path is logged.
    """
    frame = pd.DataFrame() if df is None else df
    path.parent.mkdir(parents=True, exist_ok=True)
    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        frame.to_excel(writer, index=False, sheet_name=sheet_name)
    log(f"SAVED TABLE {path}")


def save_xlsx_multi(dfs: Dict[str, pd.DataFrame], path: Path):
    """Write each non-empty frame in ``dfs`` to its own sheet of one .xlsx file.

    Sheet names are the dict keys truncated to 31 characters. ``None`` and
    empty frames are skipped. If nothing qualifies, a single empty sheet named
    ``"Empty"`` is written, because a workbook without any sheet cannot be saved.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with pd.ExcelWriter(path, engine="openpyxl") as w:
        for name, df in dfs.items():
            if df is None or df.empty:
                continue
            df.to_excel(w, index=False, sheet_name=name[:31])
            written += 1
        if written == 0:
            # Placeholder so closing the writer does not fail on a sheet-less workbook.
            pd.DataFrame().to_excel(w, index=False, sheet_name="Empty")
    log(f"SAVED TABLE (multi-sheet) {path}")


def sig_stars(p):
    """Map a p-/q-value to ``"***"`` (<0.001), ``"**"`` (<0.01), ``"*"`` (<0.05) or ``"ns"``.

    Missing or non-finite values map to ``"ns"``.
    """
    if pd.isna(p) or not np.isfinite(p):
        return "ns"
    # Thresholds are checked from strictest to loosest; all comparisons are strict.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < cutoff:
            return stars
    return "ns"


# Statistical toolkit 

def safe_numeric(x) -> np.ndarray:
    """Convert ``x`` to a NumPy array of numbers; unparseable entries become NaN."""
    as_series = pd.Series(x)
    coerced = pd.to_numeric(as_series, errors="coerce")
    return coerced.to_numpy()

def _align_finite(x, y, min_n=5):
    """Return the positions of ``x`` and ``y`` where both are finite numbers.

    Returns ``(x_kept, y_kept, n)``, or ``(None, None, 0)`` when fewer than
    ``min_n`` pairs survive.
    """
    x_vals = safe_numeric(x)
    y_vals = safe_numeric(y)
    both_ok = np.isfinite(x_vals) & np.isfinite(y_vals)
    n_ok = int(both_ok.sum())
    if n_ok < min_n:
        return None, None, 0
    return x_vals[both_ok], y_vals[both_ok], n_ok


def corr_with_pvalue(x, y, method="spearman", min_n=5, n_boot=N_BOOT):
    """Correlation with p-value and percentile-bootstrap 95% CI.

    Parameters
    ----------
    x, y : array-like, paired observations; non-numeric/non-finite pairs are dropped.
    method : ``"pearson"`` or ``"spearman"``; any other value is treated as
        Spearman for BOTH the point estimate and the bootstrap.
    min_n : minimum number of finite pairs needed to compute anything.
    n_boot : number of bootstrap resamples (fixed seed 42, so results are reproducible).

    Returns
    -------
    dict with keys ``r``, ``p``, ``n``, ``ci_lo``, ``ci_hi``, ``method``
    (``method`` is the statistic actually used). Values are NaN, and ``n`` is
    0, when fewer than ``min_n`` pairs are available; the CI stays NaN unless
    more than 10 resamples gave a finite coefficient.
    """
    xa, ya, n = _align_finite(x, y, min_n)
    # Resolve the statistic once, so estimate and CI can never disagree.
    effective = "pearson" if method == "pearson" else "spearman"
    corr_fn = scipy_stats.pearsonr if effective == "pearson" else scipy_stats.spearmanr
    out = {"r": np.nan, "p": np.nan, "n": n,
           "ci_lo": np.nan, "ci_hi": np.nan, "method": effective}
    if xa is None:
        return out

    r, p = corr_fn(xa, ya)
    out["r"] = float(r)
    out["p"] = float(p)

    rng = np.random.RandomState(42)
    boots = np.full(n_boot, np.nan)
    for i in range(n_boot):
        idx = rng.randint(0, n, n)  # resample pairs with replacement
        boots[i] = corr_fn(xa[idx], ya[idx])[0]
    # Degenerate resamples (e.g. constant input) yield NaN and are discarded.
    finite_boots = boots[np.isfinite(boots)]
    if len(finite_boots) > 10:
        out["ci_lo"] = float(np.percentile(finite_boots, 2.5))
        out["ci_hi"] = float(np.percentile(finite_boots, 97.5))
    return out


def partial_corr(df, x, y, covars, method="spearman", min_n=10):
    """Partial correlation of ``x`` and ``y`` controlling for covariates.

    Both variables are regressed on the covariates (OLS with intercept) over
    the complete cases, and the residuals are correlated with
    ``corr_with_pvalue``.

    Falls back to the plain (unadjusted) correlation when statsmodels is
    unavailable, no usable covariate exists, or fewer than ``min_n`` complete
    rows remain. The result always carries ``partial`` (True only when an
    adjustment was actually made) and ``covariates`` (comma-joined names, or "").
    A missing ``x`` or ``y`` column yields an all-NaN result instead of raising.

    NOTE(review): for ``method="spearman"`` the raw values (not their ranks)
    are residualised, and the p-value's degrees of freedom are not reduced by
    the number of covariates -- confirm this is the intended estimator.
    """
    if x not in df.columns or y not in df.columns:
        return {"r": np.nan, "p": np.nan, "n": 0,
                "ci_lo": np.nan, "ci_hi": np.nan, "method": method,
                "partial": False, "covariates": ""}

    # Keep covariates that exist, are not x/y themselves, and are not repeated;
    # duplicated labels would make d[x] two-dimensional.
    actual_covars = []
    for c in covars:
        if c in df.columns and c not in (x, y) and c not in actual_covars:
            actual_covars.append(c)
    cols = list(dict.fromkeys([x, y] + actual_covars))
    d = df[cols].dropna()

    if d.shape[0] < min_n or not HAS_SM or len(actual_covars) == 0:
        out = corr_with_pvalue(df[x], df[y], method=method, min_n=min_n)
        out["partial"] = False
        out["covariates"] = ""
        return out

    C = sm.add_constant(d[actual_covars].values)
    resid_x = sm.OLS(d[x].values, C).fit().resid
    resid_y = sm.OLS(d[y].values, C).fit().resid
    out = corr_with_pvalue(resid_x, resid_y, method=method, min_n=min_n)
    out["partial"] = True
    out["covariates"] = ",".join(actual_covars)
    return out


def bh_fdr(pvals) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values, leaving non-finite entries as NaN.

    Uses statsmodels' ``multipletests`` when available, otherwise an
    equivalent NumPy implementation. The output has the same shape as the input.
    """
    p = np.asarray(pvals, float)
    q = np.full_like(p, np.nan)
    finite = np.isfinite(p)
    if not finite.any():
        return q

    valid = p[finite]
    if HAS_SM:
        q[finite] = multipletests(valid, method="fdr_bh")[1]
        return q

    # Manual BH: scale the sorted p-values by m/rank, enforce monotonicity
    # from the largest p downwards, then undo the sort.
    m = valid.size
    order = np.argsort(valid)
    scaled = valid[order] * m / (np.arange(m) + 1.0)
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(m, dtype=float)
    adjusted[order] = monotone
    q[finite] = adjusted
    return q


# Column detection helpers

# Canonical obs column names: MES01_score ... MES08_score, plus the GR composite.
MES_COLS = [f"MES{idx:02d}_score" for idx in range(1, 9)]
GR_COL = "GR_composite"


def detect_donor_col(adata):
    """Return the name of the donor-identifier column in ``adata.obs``, or None.

    Search order: (1) exact match against a fixed list of known names,
    (2) first column whose lowercase name contains "donor" and either contains
    "id" or ends with "donor", (3) first column whose lowercase name contains
    "id" together with "patient", "subject" or "individual".
    """
    names = [str(c) for c in adata.obs.columns]
    known = ("donor_id", "Donor ID", "donor", "DonorID",
             "patient_id", "PatientID", "subject_id", "individual")
    exact = next((cand for cand in known if cand in names), None)
    if exact is not None:
        return exact

    lowered = [(name, name.lower()) for name in names]
    for name, low in lowered:
        if "donor" in low and ("id" in low or low.endswith("donor")):
            return name
    for name, low in lowered:
        if "id" in low and any(tok in low for tok in ("patient", "subject", "individual")):
            return name
    return None


def detect_mes_cols(obs):
    """Return the MES score columns present in ``obs``.

    If every canonical name in ``MES_COLS`` is present, a copy of that list is
    returned. Otherwise all columns matching ``MES[0]1-8`` with an optional
    ``_score`` suffix (case-insensitive) are returned in column order.
    """
    names = [str(c) for c in obs.columns]
    if not set(MES_COLS).difference(names):
        return list(MES_COLS)
    mes_pattern = re.compile(r"^MES0?[1-8](_score)?$", re.IGNORECASE)
    return [name for name in names if mes_pattern.match(name)]


def detect_gr_col(obs):
    """Return the name of the GR (glucocorticoid receptor) score column in ``obs``, or None.

    Preference order:
      1. the canonical ``GR_COL`` name;
      2. a column whose lowercase name contains "composite" or "score" and has
         "gr" as a stand-alone token (not embedded in a longer word);
      3. a column whose lowercase name contains "composite" or "score" and the
         substring "gr" anywhere (the original loose rule, kept as a fallback).

    Pass 2 stops names such as "aggregate_score" from shadowing a genuine
    "GR_score" column that appears later in ``obs``.
    """
    cols = list(map(str, obs.columns))
    if GR_COL in cols:
        return GR_COL

    score_like = [c for c in cols if "composite" in c.lower() or "score" in c.lower()]

    # "gr" not directly preceded or followed by another letter.
    gr_token = re.compile(r"(?<![a-z])gr(?![a-z])")
    for c in score_like:
        if gr_token.search(c.lower()):
            return c
    for c in score_like:
        if "gr" in c.lower():
            return c
    return None


def detect_sex_col(obs):
    """Return the first of ``Sex``, ``sex``, ``gender``, ``Gender`` present in ``obs``, else None."""
    available = {str(c) for c in obs.columns}
    candidates = ("Sex", "sex", "gender", "Gender")
    return next((name for name in candidates if name in available), None)


def detect_pmi_col(obs):
    """Return the first known post-mortem-interval column name present in ``obs``, else None."""
    available = {str(c) for c in obs.columns}
    candidates = ("PMI", "pmi", "post_mortem_interval", "PostMortemInterval")
    return next((name for name in candidates if name in available), None)


def detect_batch_col(obs):
    """Return the first known batch column name present in ``obs``, else None."""
    available = {str(c) for c in obs.columns}
    candidates = ("batch", "Batch", "library_prep_batch", "sequencing_batch")
    return next((name for name in candidates if name in available), None)


def canonical_mes_label(col: str) -> str:
    """Normalise an MES score column name to ``MES01`` ... ``MES08``.

    Accepts ``MES<n>`` or ``MES0<n>`` (n = 1-8) in any letter case, with an
    optional ``_score``/``score`` suffix that is also matched
    case-insensitively -- the same tolerance ``detect_mes_cols`` applies.
    Names that are not MES labels are returned with ``"_score"`` and
    ``"Score"`` removed and surrounding whitespace stripped.
    """
    mes_pattern = re.compile(r"^MES0?([1-8])(?:_?score)?$", re.IGNORECASE)
    text = str(col).strip()
    hit = mes_pattern.match(text)
    if hit is None:
        # Legacy clean-up; also covers repeated suffixes such as "MES1_score_score".
        text = text.replace("_score", "").replace("Score", "").strip()
        hit = mes_pattern.match(text)
    if hit is not None:
        return f"MES0{hit.group(1)}"
    return text


# Covariate builder for donor-level DataFrames

def build_covariates(donor_df, obs, donor_col):
    """Detect and encode donor-level covariates; return the usable column names.

    ``donor_df`` is modified in place and must be indexed by donor id as a
    string (the key used for grouping ``obs``). Columns that may be added:

    * ``_cov_sex_num``   - 0.0 for labels starting with "m", 1.0 for "f", NaN otherwise
      (first label seen per donor);
    * ``_cov_pmi``       - per-donor mean of the PMI column, coerced to numeric;
    * ``_cov_batch_num`` - integer code of the donor's first batch label, from one
      shared, sorted list of labels (only when there are 2-19 distinct batches).

    A column is reported as a covariate only if it has more than 5 non-missing
    values AND more than one distinct value; a constant column cannot adjust
    for anything.

    NOTE(review): batch is nominal, so a single integer code imposes an
    arbitrary order; one-hot encoding would be the stricter choice.
    """
    covars = []
    sx  = detect_sex_col(obs)
    pmi = detect_pmi_col(obs)
    bat = detect_batch_col(obs)
    donors = obs[donor_col].astype(str)

    def _usable(col):
        vals = donor_df[col].dropna()
        return len(vals) > 5 and vals.nunique() > 1

    if sx and sx in obs.columns:
        sex_map = obs.groupby(donors)[sx].first()

        def _sex_code(d):
            label = str(sex_map.get(d, "")).lower()
            if label.startswith("m"):
                return 0.0
            if label.startswith("f"):
                return 1.0
            return np.nan

        donor_df["_cov_sex_num"] = [_sex_code(d) for d in donor_df.index]
        if _usable("_cov_sex_num"):
            covars.append("_cov_sex_num")

    if pmi and pmi in obs.columns:
        # Coerce first so string/categorical PMI columns do not break the mean.
        pmi_num = pd.to_numeric(obs[pmi].astype(object), errors="coerce")
        pmi_map = pmi_num.groupby(donors).mean()
        donor_df["_cov_pmi"] = pmi_map.reindex(donor_df.index).to_numpy(dtype=float)
        if _usable("_cov_pmi"):
            covars.append("_cov_pmi")

    if bat and bat in obs.columns:
        bat_map = obs.groupby(donors)[bat].first()
        levels = sorted({str(b) for b in bat_map.dropna().unique()})
        if 1 < len(levels) < 20:
            # One shared mapping for all donors. Encoding each donor in its own
            # single-element Categorical would give every donor code 0.
            code_of = {label: float(i) for i, label in enumerate(levels)}
            per_donor = bat_map.reindex(donor_df.index)
            donor_df["_cov_batch_num"] = [
                code_of.get(str(b), np.nan) if pd.notna(b) else np.nan
                for b in per_donor
            ]
            if _usable("_cov_batch_num"):
                covars.append("_cov_batch_num")
    return covars


# Data integrity utilities

def ensure_log1p_X(adata, prefer_layer_counts="counts"):
    """Make ``adata.X`` look log-normalised, modifying ``adata`` in place.

    Heuristic: if the maximum of ``X`` exceeds 30 the matrix is treated as raw
    counts and replaced by total-count normalisation (target 1e4) followed by
    ``log1p``. The source is the ``prefer_layer_counts`` layer when it exists,
    otherwise ``X`` itself. If the maximum is at most 30, or cannot be
    computed, ``X`` is only coerced to a NumPy array when it is not sparse.

    The counts layer is copied before normalisation because the scanpy
    preprocessing calls may operate in place, which would otherwise overwrite
    the stored raw counts.
    """
    try:
        mx = float(adata.X.max())
    except Exception:
        # Best-effort probe: an unreadable X is treated as already normalised.
        mx = None

    looks_raw = mx is not None and mx > 30
    if looks_raw:
        if prefer_layer_counts in adata.layers:
            X = adata.layers[prefer_layer_counts].copy()
            tmp = ad.AnnData(X=X, obs=adata.obs.copy(), var=adata.var.copy())
            sc.pp.normalize_total(tmp, target_sum=1e4)
            sc.pp.log1p(tmp)
            adata.X = tmp.X
        else:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
    elif not sp.issparse(adata.X):
        adata.X = np.asarray(adata.X)


def stratified_subsample_by_donor_positions(obs, donor_col, max_cells=50000, seed=42):
    """Pick up to ``max_cells`` row positions of ``obs``, balanced across donors.

    If ``obs`` has at most ``max_cells`` rows, all positions are returned in
    order. Otherwise each donor (in order of first appearance) contributes up
    to ``max(10, max_cells // n_donors)`` randomly chosen rows; iteration
    stops once the running total reaches ``max_cells``, and any excess is
    trimmed by a further random draw. The result is an integer array that is
    not sorted. Sampling is reproducible for a given ``seed``.
    """
    rng = np.random.default_rng(seed)
    donor_labels = obs[donor_col].astype(str).values
    n_cells = donor_labels.shape[0]
    if n_cells <= max_cells:
        return np.arange(n_cells, dtype=int)

    unique_donors = pd.unique(donor_labels)
    quota = max(10, int(np.floor(max_cells / max(len(unique_donors), 1))))

    picks = []
    n_picked = 0
    for donor in unique_donors:
        members = np.flatnonzero(donor_labels == donor)
        if members.size == 0:
            continue
        picked = rng.choice(members, size=min(quota, members.size), replace=False)
        picks.append(picked)
        n_picked += picked.size
        if n_picked >= max_cells:
            break

    if picks:
        selected = np.concatenate(picks)
    else:
        selected = rng.choice(np.arange(n_cells), size=max_cells, replace=False)
    if selected.size > max_cells:
        selected = rng.choice(selected, size=max_cells, replace=False)
    return selected.astype(int)


def add_gr_group(adata, gr_col, out_col="GR_group"):
    """Add a median-split GR stratum column to ``adata.obs``.

    Cells with a GR value at or above the median of the finite values are
    labelled "GR_high", those below "GR_low", and non-numeric/non-finite
    values "NA". The column is categorical with categories
    ``["GR_low", "GR_high", "NA"]``. Nothing happens if ``gr_col`` is absent.
    """
    if gr_col not in adata.obs.columns:
        return
    levels = ["GR_low", "GR_high", "NA"]
    scores = pd.to_numeric(adata.obs[gr_col], errors="coerce").values
    is_finite = np.isfinite(scores)
    if not is_finite.any():
        labels = ["NA"] * adata.n_obs
    else:
        cutoff = float(np.nanmedian(scores[is_finite]))
        split = np.where(scores >= cutoff, "GR_high", "GR_low")
        labels = np.where(is_finite, split, "NA")
    adata.obs[out_col] = pd.Categorical(labels, categories=levels)


# Scanpy DE helper (PATCH 4 + pseudobulk-friendly)

def de_rank_genes_groups(adata, group_col, group1="GR_high", group0="GR_low",
                         method="wilcoxon", n_genes=None, min_n_per_group=10):
    """Differential expression of ``group1`` versus ``group0`` via scanpy.

    Works on a copy restricted to the two groups, whose ``X`` is passed
    through ``ensure_log1p_X`` before ``sc.tl.rank_genes_groups`` is run with
    ``group0`` as reference.

    Returns a DataFrame with columns ``gene``, ``logFC``, ``score``, ``pval``,
    ``padj`` (one row per gene, duplicates dropped). An empty DataFrame is
    returned when ``group_col`` is missing, either group has fewer than
    ``min_n_per_group`` observations, or scanpy stored no result.
    """
    if group_col not in adata.obs.columns:
        return pd.DataFrame()
    labels = adata.obs[group_col].astype(str)
    vc = labels.value_counts(dropna=False)
    if int(vc.get(group1, 0)) < min_n_per_group or int(vc.get(group0, 0)) < min_n_per_group:
        return pd.DataFrame()
    mask = labels.isin([group0, group1]).values
    if int(mask.sum()) < max(2 * min_n_per_group, 10):
        return pd.DataFrame()

    aa = adata[mask, :].copy()
    ensure_log1p_X(aa, prefer_layer_counts="counts")
    sc.tl.rank_genes_groups(aa, groupby=group_col, groups=[group1],
                            reference=group0, method=method,
                            n_genes=n_genes if n_genes else aa.n_vars,
                            use_raw=False)
    res = aa.uns.get("rank_genes_groups", None)
    if res is None:
        return pd.DataFrame()

    def _extract(x):
        """Pull the per-gene vector for ``group1`` out of one result field.

        Handles a dict keyed by group, a structured array with a ``group1``
        field, and plain arrays (a 2-D array with a single row or column is
        flattened). Returns None when the field is absent or unconvertible.
        """
        if x is None:
            return None
        if isinstance(x, dict):
            return np.array(x.get(group1, np.nan))
        field_names = getattr(getattr(x, "dtype", None), "names", None)
        if field_names and group1 in field_names:
            return np.array(x[group1])
        try:
            arr = np.array(x)
        except Exception:
            # Only conversion failures are swallowed; interrupts still propagate.
            return None
        if arr.ndim == 2:
            if arr.shape[1] == 1:
                return arr[:, 0]
            if arr.shape[0] == 1:
                return arr[0, :]
        return arr

    genes = _extract(res.get("names"))
    if genes is None:
        return pd.DataFrame()
    df = pd.DataFrame({
        "gene":  pd.Series(genes).astype(str),
        "logFC": pd.to_numeric(pd.Series(_extract(res.get("logfoldchanges"))), errors="coerce"),
        "score": pd.to_numeric(pd.Series(_extract(res.get("scores"))), errors="coerce"),
        "pval":  pd.to_numeric(pd.Series(_extract(res.get("pvals"))), errors="coerce"),
        "padj":  pd.to_numeric(pd.Series(_extract(res.get("pvals_adj"))), errors="coerce"),
    })
    return df.dropna(subset=["gene"]).drop_duplicates(subset=["gene"])


# Diagnostics

def diagnostics():
    """Collect path and column-detection checks, save them to NB5_DIAG.xlsx, and return them.

    The first scored file (if any) is opened to report which GR, MES, donor,
    sex, PMI and batch columns are detected; any error while peeking is
    recorded as a "peek error" row rather than raised.
    """
    files = sorted(glob.glob(SCORED_GLOB))
    rows = [
        ("PROC_DIR exists", PROC_DIR.exists()),
        ("PROC_DIR", str(PROC_DIR)),
        ("MANUS_DIR exists", MANUS_DIR.exists()),
        ("SCORED_GLOB", SCORED_GLOB),
        ("scored matches", len(files)),
    ]
    if files:
        rows.append(("first scored", Path(files[0]).name))
        rows.append(("last scored",  Path(files[-1]).name))
        try:
            a = sc.read_h5ad(files[0], backed=None)
            if not a.obs_names.is_unique:
                a.obs_names_make_unique()
            if not a.var_names.is_unique:
                a.var_names_make_unique()
            probes = (
                ("detected GR col",    lambda: detect_gr_col(a.obs)),
                ("detected MES cols",  lambda: ",".join(detect_mes_cols(a.obs))),
                ("detected donor col", lambda: detect_donor_col(a)),
                ("detected sex col",   lambda: detect_sex_col(a.obs)),
                ("detected PMI col",   lambda: detect_pmi_col(a.obs)),
                ("detected batch col", lambda: detect_batch_col(a.obs)),
            )
            # Append one row at a time so earlier rows survive a later failure.
            for check, probe in probes:
                rows.append((check, probe() or "NONE"))
        except Exception as e:
            rows.append(("peek error", repr(e)))
    df = pd.DataFrame(rows, columns=["check", "value"])
    save_xlsx(df, TAB_DIR / "NB5_DIAG.xlsx")
    return df


# Load scored cohorts (PATCH 1: unique names)

def load_nb4_tables():
    """Load the NB4 result tables that exist in ``NB4_TABLE_DIR``.

    Returns a dict mapping table name (file stem) to DataFrame. Missing files
    are skipped silently; files that exist but cannot be read are skipped with
    a logged warning, so a corrupt table is not mistaken for an absent one.
    """
    table_names = (
        "Main_Table5",
        "Main_Table1",
        "Supplementary_Table1",
        "Supplementary_Table2",
        "Supplementary_Table3",
        "Supplementary_Table5",
    )
    tables = {}
    for key in table_names:
        p = NB4_TABLE_DIR / f"{key}.xlsx"
        if not p.exists():
            continue
        try:
            tables[key] = pd.read_excel(p)
        except Exception as e:
            log(f"[WARN] load_nb4_tables: could not read {p}: {e!r}")
    return tables


def load_scored_cohorts():
    """Read every scored cohort matching ``SCORED_GLOB`` into memory.

    Returns a dict mapping dataset name (file name without the
    ``__microglia_scored.h5ad`` suffix) to AnnData, with obs and var names
    made unique. Progress is logged per file. The heartbeat thread is stopped
    even when a read fails, so it cannot keep logging after an exception.
    """
    files = sorted(glob.glob(SCORED_GLOB))
    log(f"[NB5] Found scored cohorts: {[Path(x).name for x in files]}")
    pr = Progress("Load cohorts", total=len(files))
    hb = Heartbeat("load_scored_cohorts", every_s=15)
    hb.start("start")
    cohorts = {}
    status = "failed"
    try:
        for f in files:
            hb.update(f"reading {Path(f).name}")
            a = sc.read_h5ad(f)
            if not a.obs_names.is_unique:
                a.obs_names_make_unique()
            if not a.var_names.is_unique:
                a.var_names_make_unique()
            ds = Path(f).name.replace("__microglia_scored.h5ad", "")
            cohorts[ds] = a
            pr.tick(1)
            log(pr.line(extra=f"{ds} | shape={a.n_obs}x{a.n_vars}"))
        status = "done"
    finally:
        hb.stop(status)
    return cohorts


# PART A: Per-dataset corr + random-effects meta + LODO

def random_effects_meta(r_list, n_list):
    """DerSimonian-Laird random-effects pooling of correlation coefficients.

    Correlations are Fisher z-transformed with sampling variance ``1/(n-3)``;
    studies with missing ``r``/``n`` or ``n <= 3`` are dropped.

    Returns a dict with ``pooled_r`` and its 95% CI (``pooled_ci_lo``,
    ``pooled_ci_hi``), heterogeneity statistics ``I2`` (percent), ``tau2``,
    ``Q``, ``Q_p`` (chi-squared test with k-1 df; NaN when k < 2), the number
    of studies ``k``, and the pooled estimate on the z scale (``pooled_z``,
    ``se_z``). All statistics are NaN and ``k`` is 0 when no study qualifies.

    NOTE(review): ``1/(n-3)`` is the Pearson variance of Fisher's z; callers
    pass Spearman coefficients, for which it is an approximation.
    """
    items = [(r, n) for r, n in zip(r_list, n_list)
             if pd.notna(r) and pd.notna(n) and float(n) > 3]
    # The key set is identical whether or not any study qualifies.
    out = {"pooled_r": np.nan, "pooled_ci_lo": np.nan, "pooled_ci_hi": np.nan,
           "I2": np.nan, "tau2": np.nan, "Q": np.nan, "Q_p": np.nan, "k": 0,
           "pooled_z": np.nan, "se_z": np.nan}
    if len(items) == 0:
        return out

    r = np.array([x[0] for x in items], float)
    n = np.array([x[1] for x in items], float)
    # Clip so that |r| == 1 does not map to an infinite z.
    z = np.arctanh(np.clip(r, -0.999999, 0.999999))
    v = 1.0 / (n - 3.0)
    w = 1.0 / v

    z_fixed = np.sum(w * z) / np.sum(w)
    Q = float(np.sum(w * (z - z_fixed) ** 2))
    df = max(0, len(z) - 1)
    if df > 0:
        C = np.sum(w) - (np.sum(w**2) / np.sum(w))
        tau2 = max(0.0, (Q - df) / max(C, 1e-12))
    else:
        # One study: between-study variance is not estimable. Report 0 rather
        # than a rounding artefact of Q / 1e-12.
        tau2 = 0.0

    w_re = 1.0 / (v + tau2)
    z_re = float(np.sum(w_re * z) / np.sum(w_re))
    se_z_re = float(1.0 / np.sqrt(np.sum(w_re)))
    ci_lo_z = z_re - 1.96 * se_z_re
    ci_hi_z = z_re + 1.96 * se_z_re

    I2 = 0.0 if Q <= df else float(max(0.0, (Q - df) / max(Q, 1e-12)) * 100.0)
    # The survival function keeps precision for large Q, where 1 - cdf rounds to 0.
    Q_p = float(scipy_stats.chi2.sf(Q, df)) if df > 0 else np.nan

    out.update({
        "pooled_r": float(np.tanh(z_re)),
        "pooled_ci_lo": float(np.tanh(ci_lo_z)),
        "pooled_ci_hi": float(np.tanh(ci_hi_z)),
        "I2": I2,
        "tau2": float(tau2),
        "Q": Q,
        "Q_p": Q_p,
        "k": int(len(z)),
        "pooled_z": z_re,
        "se_z": se_z_re,
    })
    return out


def compute_mes_corr_table(cohorts):
    """Per-dataset Spearman correlations between the GR score and each MES score.

    For datasets with a donor column, scores are averaged per donor and both a
    raw and a covariate-adjusted (partial) correlation are computed; otherwise
    the raw correlation is computed across cells. Datasets lacking a GR or MES
    column contribute a single placeholder row with ``note="missing GR/MES"``.

    Returns one row per dataset x MES with ``r``, ``p``, the bootstrap CI,
    ``r_partial``/``p_partial``, the covariates used, and BH-adjusted
    ``q_BH``/``q_partial_BH`` computed across all rows. ``N`` is the number of
    finite pairs that entered the correlation (falling back to the dataset
    size when no correlation could be computed), since that is the sample size
    the meta-analysis weights rely on; ``N_total`` is the number of donors or
    cells in the dataset.
    """
    rows = []
    pr = Progress("PartA per-dataset corr", total=len(cohorts))
    hb = Heartbeat("PartA correlations", every_s=20)
    hb.start("start")
    status = "failed"
    try:
        for ds, a in cohorts.items():
            hb.update(ds)
            donor_col = detect_donor_col(a)
            obs = a.obs  # read-only below, so no copy is needed
            gr_col   = detect_gr_col(obs)
            mes_cols = detect_mes_cols(obs)
            if gr_col is None or len(mes_cols) == 0:
                rows.append({"dataset": ds, "level": "NA", "N": np.nan, "note": "missing GR/MES"})
                pr.tick(1)
                continue

            if donor_col is not None:
                # Donor-level aggregation: one mean score per donor.
                grp    = obs.groupby(obs[donor_col].astype(str))
                frame  = grp[[gr_col] + mes_cols].mean(numeric_only=True)
                level  = "donor"
                covars = build_covariates(frame, obs, donor_col)
            else:
                frame  = obs
                level  = "cell"
                covars = None
            N = int(frame.shape[0])

            for mes in mes_cols:
                raw = corr_with_pvalue(frame[gr_col], frame[mes], method="spearman")
                if level == "donor":
                    pcr = partial_corr(frame, gr_col, mes, covars, method="spearman")
                    r_partial, p_partial = pcr["r"], pcr["p"]
                    covariates = pcr.get("covariates", "")
                else:
                    r_partial, p_partial, covariates = np.nan, np.nan, ""
                n_pairs = int(raw["n"])
                rows.append({
                    "dataset": ds, "level": level,
                    "N": n_pairs if n_pairs > 0 else N,
                    "N_total": N,
                    "donor_col": donor_col,
                    "MES": canonical_mes_label(mes),
                    "r": raw["r"], "p": raw["p"],
                    "ci_lo": raw["ci_lo"], "ci_hi": raw["ci_hi"],
                    "r_partial": r_partial, "p_partial": p_partial,
                    "covariates": covariates,
                    "GR_col_used": gr_col, "MES_col_used": mes,
                })
            pr.tick(1)
            log(pr.line(extra=f"done={ds} | level={level} | N={N}"))
        status = "done"
    finally:
        # Always stop the heartbeat so a failure does not leave it logging forever.
        hb.stop(status)

    df = pd.DataFrame(rows)
    if len(df) and "p" in df.columns:
        df["q_BH"] = bh_fdr(df["p"].values)
        if "p_partial" in df.columns:
            df["q_partial_BH"] = bh_fdr(df["p_partial"].values)
    return df


def make_forest_plot(meta_df, corr_df, out_png):
    """Forest plot: one panel per MES with per-study CIs and a pooled diamond.

    Each panel shows every dataset's correlation with its bootstrap CI (stars
    mark BH-significant rows, falling back to the raw p-value when ``q_BH`` is
    absent) and, at y = -1, a red diamond spanning the pooled random-effects
    CI. The figure is written to ``out_png``. Nothing is drawn, and a warning
    is logged, when ``meta_df`` contains no MES.
    """
    if "MES" not in meta_df.columns:
        log("[WARN] make_forest_plot: meta_df has no 'MES' column; figure skipped.")
        return
    mes_list = sorted(meta_df["MES"].dropna().unique())
    if len(mes_list) == 0:
        # plt.subplots cannot create a grid with zero columns.
        log("[WARN] make_forest_plot: no MES to plot; figure skipped.")
        return

    fig, axes = plt.subplots(1, len(mes_list),
                             figsize=(3.5 * len(mes_list),
                                      max(3.0, 0.5 * len(corr_df["dataset"].unique()))),
                             sharey=True, squeeze=False)
    axes = axes.ravel()
    datasets = sorted(corr_df["dataset"].dropna().unique())
    y_pos = {ds: i for i, ds in enumerate(reversed(datasets))}

    for ax_i, mes in enumerate(mes_list):
        ax = axes[ax_i]
        meta_rows = meta_df[meta_df["MES"] == mes]
        sub_meta = meta_rows.iloc[0] if len(meta_rows) else None
        sub_corr = corr_df[(corr_df["MES"] == mes) & corr_df["r"].notna()]
        color = PAL_DATASETS[ax_i % len(PAL_DATASETS)]
        for _, row in sub_corr.iterrows():
            ds = row["dataset"]
            if ds not in y_pos:
                continue
            yi = y_pos[ds]
            ci_lo = row.get("ci_lo", np.nan)
            ci_hi = row.get("ci_hi", np.nan)
            # A percentile-bootstrap CI can exclude the point estimate; clamp
            # at 0 because errorbar rejects negative error lengths.
            xerr_lo = max(0.0, row["r"] - ci_lo) if np.isfinite(ci_lo) else 0.0
            xerr_hi = max(0.0, ci_hi - row["r"]) if np.isfinite(ci_hi) else 0.0
            ax.errorbar(row["r"], yi, xerr=[[xerr_lo], [xerr_hi]],
                        fmt="o", color=color, ecolor="black",
                        elinewidth=0.7, capsize=2, markersize=5)
            star = sig_stars(row.get("q_BH", row.get("p")))
            if star != "ns":
                ax.text(ci_hi + 0.02 if np.isfinite(ci_hi) else row["r"] + 0.05,
                        yi, star, va="center", fontsize=6, fontweight="bold")
        # Pooled diamond
        if sub_meta is not None and np.isfinite(sub_meta.get("pooled_r", np.nan)):
            pooled = sub_meta["pooled_r"]
            plo = sub_meta.get("pooled_ci_lo", pooled)
            phi = sub_meta.get("pooled_ci_hi", pooled)
            dy = -1.0
            diamond_x = [plo, pooled, phi, pooled, plo]
            diamond_y = [dy, dy-0.3, dy, dy+0.3, dy]
            ax.fill(diamond_x, diamond_y, color="red", alpha=0.7)
            ax.text(phi + 0.02, dy, f"r={pooled:.2f}", va="center", fontsize=6, color="red")
        ax.axvline(0, color="grey", linewidth=0.6, linestyle="--")
        ax.set_xlabel("Spearman ρ")
        ax.set_title(mes, fontsize=8, fontweight="bold")
        if ax_i == 0:
            ax.set_yticks(list(y_pos.values()))
            ax.set_yticklabels(list(y_pos.keys()))
    fig.suptitle("Meta-analysis: corr(GR, MES) — per-study CIs + pooled (◆)",
                 fontsize=9, y=1.02)
    fig.tight_layout()
    save_fig(out_png, fig)


def make_heatmap(corr_df, out_png):
    """Dataset x MES heatmap of correlations with significance markers.

    Cells show the mean ``r`` per dataset/MES on a diverging colormap centred
    on 0; cells whose mean ``q_BH`` is significant get stars. Seaborn is used
    when available; the Matplotlib fallback uses the same symmetric colour
    scale and the same stars. The figure is written to ``out_png``. Nothing is
    drawn, and a warning is logged, when there is no finite correlation.
    """
    df = corr_df[corr_df["MES"].notna()].copy()
    piv_r = df.pivot_table(index="dataset", columns="MES", values="r", aggfunc="mean")
    if piv_r.empty:
        log("[WARN] make_heatmap: no finite correlations to plot; figure skipped.")
        return
    piv_q = (df.pivot_table(index="dataset", columns="MES", values="q_BH", aggfunc="mean")
             if "q_BH" in df.columns else None)

    # Canonical MES order first, then any other labels in their pivot order.
    ordered = [f"MES0{i}" for i in range(1, 9)]
    cols = [c for c in ordered if c in piv_r.columns] + [c for c in piv_r.columns if c not in ordered]
    piv_r = piv_r[cols]
    if piv_q is not None:
        piv_q = piv_q.reindex(columns=cols, index=piv_r.index)

    fig, ax = plt.subplots(figsize=(10, max(2.5, 0.5 * piv_r.shape[0])))
    if HAS_SNS:
        sns.heatmap(piv_r, ax=ax, annot=True, fmt=".2f", cmap="RdBu_r", center=0,
                    linewidths=0.5, linecolor="white",
                    cbar_kws={"label": "Spearman ρ", "shrink": 0.8})
        cell_offset = 0.5  # seaborn cells span [j, j+1]
    else:
        # Symmetric limits keep white at r = 0, matching seaborn's center=0.
        vmax = float(np.nanmax(np.abs(piv_r.values)))
        if not np.isfinite(vmax) or vmax == 0.0:
            vmax = 1.0
        im = ax.imshow(piv_r.values, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
        ax.set_xticks(np.arange(piv_r.shape[1])); ax.set_xticklabels(piv_r.columns, rotation=45, ha="right")
        ax.set_yticks(np.arange(piv_r.shape[0])); ax.set_yticklabels(piv_r.index)
        plt.colorbar(im, ax=ax, label="r", shrink=0.8)
        cell_offset = 0.0  # imshow cells are centred on integer coordinates

    # Significance markers, placed just below each cell centre.
    if piv_q is not None:
        for i, ds in enumerate(piv_r.index):
            for j, mes in enumerate(piv_r.columns):
                q = piv_q.loc[ds, mes] if ds in piv_q.index and mes in piv_q.columns else np.nan
                s = sig_stars(q)
                if s != "ns":
                    ax.text(j + cell_offset, i + cell_offset + 0.32, s,
                            ha="center", va="center", fontsize=6)
    ax.set_title("Per-dataset corr(GR, MES) — donor-level where available")
    save_fig(out_png, fig)


def _meta_by_mes(corr_df, mes_list):
    """Random-effects meta-analysis of the per-dataset correlations, one row per MES."""
    meta_rows = []
    for mes in mes_list:
        sub = corr_df[corr_df["MES"] == mes]
        meta_rows.append({"MES": mes, **random_effects_meta(sub["r"].tolist(), sub["N"].tolist())})
    return pd.DataFrame(meta_rows)


def _lodo_by_mes(corr_df, mes_list):
    """Leave-one-dataset-out pooled estimates, one row per MES x left-out dataset."""
    lodo = []
    for mes in mes_list:
        sub = corr_df[corr_df["MES"] == mes]
        for leave in sorted(sub["dataset"].dropna().unique().tolist()):
            kept = sub[sub["dataset"] != leave]
            out = random_effects_meta(kept["r"].tolist(), kept["N"].tolist())
            lodo.append({"MES": mes, "leave_out": leave,
                         "pooled_r": out["pooled_r"],
                         "pooled_ci_lo": out["pooled_ci_lo"],
                         "pooled_ci_hi": out["pooled_ci_hi"],
                         "I2": out["I2"], "k": out["k"]})
    return pd.DataFrame(lodo)


def _plot_i2_bars(i2_df, out_png):
    """Bar chart of I² per MES, with stars for a significant Q-test."""
    fig, ax = plt.subplots(figsize=(6, 3))
    colors = ["#D6604D" if i2 > 50 else "#FDAE6B" if i2 > 25 else "#66C2A5"
              for i2 in i2_df["I2"].values]
    # Explicit numeric positions: ticks are fixed before labels are assigned.
    xs = np.arange(len(i2_df))
    ax.bar(xs, i2_df["I2"].values, color=colors, edgecolor="black", linewidth=0.5)
    for x, (_, row) in zip(xs, i2_df.iterrows()):
        s = sig_stars(row.get("Q_p"))
        if s != "ns" and np.isfinite(row["I2"]):
            ax.text(x, row["I2"] + 1, s, ha="center", fontsize=7, fontweight="bold")
    for level in (25, 50, 75):
        ax.axhline(level, color="grey", linewidth=0.5, linestyle=":")
    ax.set_ylabel("I² (%)")
    ax.set_title("Heterogeneity by MES (Q-test significance annotated)")
    ax.set_xticks(xs)
    ax.set_xticklabels(i2_df["MES"].tolist(), rotation=45, ha="right")
    if HAS_SNS: sns.despine(ax=ax)
    save_fig(out_png, fig)


def _plot_lodo(lodo_df, mes_list, out_png):
    """Line/error-bar plot of the LODO pooled estimates, one series per MES."""
    fig, ax = plt.subplots(figsize=(7, 3.5))
    for mes in mes_list:
        sub = lodo_df[lodo_df["MES"] == mes]
        if sub.empty: continue
        ax.errorbar(sub["leave_out"], sub["pooled_r"],
                    yerr=[sub["pooled_r"] - sub["pooled_ci_lo"],
                          sub["pooled_ci_hi"] - sub["pooled_r"]],
                    fmt="o-", label=mes, markersize=4, linewidth=0.8,
                    capsize=2, alpha=0.8)
    ax.axhline(0, color="grey", linewidth=0.5, linestyle="--")
    ax.set_ylabel("LODO pooled ρ (95% CI)")
    ax.set_title("Leave-one-dataset-out sensitivity")
    ax.legend(fontsize=6, ncol=4, loc="upper right")
    plt.xticks(rotation=45, ha="right")
    if HAS_SNS: sns.despine(ax=ax)
    save_fig(out_png, fig)


def run_partA(cohorts):
    """Part A: per-dataset correlations, random-effects meta-analysis and LODO.

    Writes the per-dataset correlation table, the meta-analysis table
    (``Main_Table5.xlsx``), the I² table and the LODO table to ``TAB_DIR``,
    and the forest plot, heatmap, I² bar chart and LODO plot to ``FIG_DIR``.

    Returns ``(corr_df, meta_df, i2_df, lodo_df)``; the last three are empty
    DataFrames when no correlation rows could be computed.
    """
    log("=" * 50)
    log("PART A: Correlations + Meta-analysis")
    corr_df = compute_mes_corr_table(cohorts)
    save_xlsx(corr_df, TAB_DIR / "Supplementary_Table5_PerDatasetCorr.xlsx")

    if corr_df.shape[0] == 0 or "MES" not in corr_df.columns:
        log("[WARN] Part A: 0 correlation rows.")
        return corr_df, pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    mes_list = sorted(corr_df["MES"].dropna().unique().tolist())

    # Meta-analysis per MES
    meta_df = _meta_by_mes(corr_df, mes_list)

    # Figures
    if meta_df.shape[0] > 0 and corr_df.shape[0] > 0:
        make_forest_plot(meta_df, corr_df, FIG_DIR / "Main_Fig5A_Forest_Meta.png")
    if corr_df.shape[0] > 0:
        make_heatmap(corr_df, FIG_DIR / "Main_Fig5B_PooledHeatmap.png")

    # I2 table + figure
    i2_df = meta_df[["MES", "I2", "tau2", "Q", "Q_p", "k"]].copy() if len(meta_df) else pd.DataFrame()
    save_xlsx(meta_df, TAB_DIR / "Main_Table5.xlsx")
    save_xlsx(i2_df,   TAB_DIR / "Supplementary_Table5_I2.xlsx")
    if len(i2_df) > 0:
        _plot_i2_bars(i2_df, FIG_DIR / "Supplementary_Fig5B_I2.png")

    # LODO
    lodo_df = _lodo_by_mes(corr_df, mes_list)
    save_xlsx(lodo_df, TAB_DIR / "Supplementary_Table5_LODO.xlsx")
    if len(lodo_df) > 0:
        _plot_lodo(lodo_df, mes_list, FIG_DIR / "Supplementary_Fig5A_LODO.png")

    return corr_df, meta_df, i2_df, lodo_df


# PART B: Robustness (raw vs adjusted; donor vs cell)

def run_partB(corr_df, cohorts):
    """Part B robustness: cell-level recompute, raw-vs-adjusted summary, donor-vs-cell overlap.

    Parameters
    ----------
    corr_df : pandas.DataFrame
        Per-dataset correlation table (produced by Part A in the main runner).
        Columns read here: "dataset", "MES", "level", "r", "r_partial",
        "ci_lo", "ci_hi".
    cohorts : dict
        Maps dataset name -> annotated data object exposing ``.obs``,
        ``.n_obs`` and positional slicing ``a[pos, :]`` (presumably AnnData —
        confirm against ``load_scored_cohorts``).

    Returns
    -------
    (df_cell_recompute, overlap) : tuple of pandas.DataFrame
        ``df_cell_recompute`` has one row per (dataset, MES) with the
        recomputed cell-level Spearman statistics; ``overlap`` is the inner
        join of donor-level rows from ``corr_df`` with those recomputed rows
        (empty when either side is empty).

    Side effects
    ------------
    Writes three tables under ``TAB_DIR`` and two figures under ``FIG_DIR``
    via ``save_xlsx`` / ``save_fig``; placeholder figures/tables are written
    when there is no data.
    """
    log("=" * 50)
    log("PART B: Robustness")

    # Cell-level recompute for datasets that have a donor column.
    # (Runs first even though the original numbering called it "step 2".)
    hb2 = Heartbeat("PartB cell recompute", every_s=30)
    hb2.start("start")
    # Keep only cohorts for which a donor column is detected (truthy result).
    donor_datasets = [(ds, detect_donor_col(a)) for ds, a in cohorts.items() if detect_donor_col(a)]
    pr = Progress("PartB cell recompute", total=len(donor_datasets))
    out_rows = []
    for ds, donor_col in donor_datasets:
        a = cohorts[ds]; obs = a.obs
        gr_col   = detect_gr_col(obs);  mes_cols = detect_mes_cols(obs)
        # Nothing to correlate without a GR column and at least one MES column.
        if gr_col is None or not mes_cols: pr.tick(1); continue
        # Row positions to keep; presumably a donor-stratified subsample capped at
        # MAX_CELLS_FOR_CELL_RECOMPUTE with seed 42 — confirm against the helper.
        pos = stratified_subsample_by_donor_positions(obs, donor_col, MAX_CELLS_FOR_CELL_RECOMPUTE, 42)
        aa = a[pos, :].copy()
        for mes in mes_cols:
            # Spearman correlation of GR vs this MES column over the selected cells.
            raw = corr_with_pvalue(aa.obs[gr_col], aa.obs[mes], method="spearman")
            out_rows.append({
                "dataset": ds, "donor_col": donor_col,
                "N_used": int(aa.n_obs),
                "donors": int(pd.Series(aa.obs[donor_col].astype(str)).nunique()),
                "MES": canonical_mes_label(mes),
                "r_cell_recompute": raw["r"], "p_cell": raw["p"],
                "ci_lo_cell": raw["ci_lo"], "ci_hi_cell": raw["ci_hi"],
            })
        pr.tick(1)
        log(pr.line(extra=f"done={ds}"))
        # Drop the subsample copy before moving to the next dataset.
        del aa; gc.collect()
    hb2.stop("done")
    df_cell_recompute = pd.DataFrame(out_rows)
    save_xlsx(df_cell_recompute, TAB_DIR / "Supplementary_Table5_DonorDatasets_CellRecompute.xlsx")

    # Raw vs adjusted summary figure + table, built from corr_df.
    hb = Heartbeat("PartB figs", every_s=30)
    hb.start("start")
    df = corr_df.copy()
    if df.shape[0] == 0 or "MES" not in df.columns:
        # No usable correlations: still emit an empty table and a placeholder figure
        # so the expected output files exist.
        save_xlsx(pd.DataFrame(), TAB_DIR / "Supplementary_Table5_RawVsAdjusted.xlsx")
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.05, 0.5, "No data", fontsize=10); ax.axis("off")
        save_fig(FIG_DIR / "Main_Fig5C_RawVsAdjusted.png", fig)
    else:
        df = df[df["MES"].notna()]
        # One row per (dataset, MES), one column per value of "level" holding mean r.
        piv = df.pivot_table(index=["dataset", "MES"], columns="level", values="r", aggfunc="mean").reset_index()
        # Guarantee both level columns exist so the expressions below never KeyError.
        if "cell"  not in piv.columns: piv["cell"]  = np.nan
        if "donor" not in piv.columns: piv["donor"] = np.nan
        # "adjusted" = donor-level r where present, otherwise fall back to cell-level r.
        piv["adjusted"] = piv["donor"].where(piv["donor"].notna(), piv["cell"])
        # Also include partial r where available
        piv_partial = df.pivot_table(index=["dataset","MES"], values="r_partial", aggfunc="mean").reset_index()
        if len(piv_partial):
            piv = piv.merge(piv_partial, on=["dataset","MES"], how="left")
        # Collapse across MES: one mean value per dataset for each estimate type.
        summ2 = piv.groupby("dataset")[["cell","adjusted"]].mean(numeric_only=True).reset_index()
        if "r_partial" in piv.columns:
            summ_partial = piv.groupby("dataset")["r_partial"].mean().reset_index()
            summ2 = summ2.merge(summ_partial, on="dataset", how="left")

        # Figure height scales with the number of datasets (minimum 2.5 in).
        fig, ax = plt.subplots(figsize=(6.5, max(2.5, 0.45 * summ2.shape[0])))
        # Reversed positions so the first dataset is drawn at the top.
        y = np.arange(summ2.shape[0])[::-1]
        ax.scatter(summ2["cell"],     y, label="raw (cell)",     marker="s", s=40, zorder=3)
        ax.scatter(summ2["adjusted"], y, label="adjusted (donor)", marker="o", s=50, zorder=3)
        if "r_partial" in summ2.columns:
            # Offset by 0.15 vertically so the partial markers do not hide the others.
            ax.scatter(summ2["r_partial"], y + 0.15, label="partial (covar-adj)",
                       marker="D", s=35, color="#66C2A5", zorder=3)
        ax.set_yticks(y); ax.set_yticklabels(summ2["dataset"])
        ax.axvline(0, linewidth=0.8, color="grey", linestyle="--")
        ax.set_xlabel("Mean ρ across MES")
        ax.set_title("Robustness: raw vs donor-adjusted vs covariate-adjusted")
        ax.legend(fontsize=7)
        if HAS_SNS: sns.despine(ax=ax)
        save_fig(FIG_DIR / "Main_Fig5C_RawVsAdjusted.png", fig)
        save_xlsx(summ2, TAB_DIR / "Supplementary_Table5_RawVsAdjusted.xlsx")

    # Donor vs cell overlap: pair donor-level rows from corr_df with the
    # recomputed cell-level rows for the same (dataset, MES).
    # NOTE(review): if corr_df is non-empty but lacks "MES", the column selection
    # below would raise KeyError — confirm Part A always emits that column.
    overlap = pd.DataFrame()
    if df.shape[0] > 0 and df_cell_recompute.shape[0] > 0:
        donor_only = df[df["level"]=="donor"][["dataset","MES","r","ci_lo","ci_hi"]].rename(
            columns={"r": "r_donor", "ci_lo": "ci_lo_donor", "ci_hi": "ci_hi_donor"})
        cell_only  = df_cell_recompute[["dataset","MES","r_cell_recompute","ci_lo_cell","ci_hi_cell"]].rename(
            columns={"r_cell_recompute": "r_cell"})
        overlap = donor_only.merge(cell_only, on=["dataset","MES"], how="inner")
    save_xlsx(overlap, TAB_DIR / "Supplementary_Table5_DonorVsCell_OverlapUsedForFig.xlsx")

    if overlap.shape[0] > 0:
        fig, ax = plt.subplots(figsize=(4.5, 4.5))
        ax.scatter(overlap["r_cell"], overlap["r_donor"], s=30, alpha=0.7,
                   color="#4393C3", edgecolors="black", linewidth=0.3)
        # Identity-line extent: overall min/max of both estimates, padded by 0.1.
        lo = float(np.nanmin([overlap["r_cell"].min(), overlap["r_donor"].min()])) - 0.1
        hi = float(np.nanmax([overlap["r_cell"].max(), overlap["r_donor"].max()])) + 0.1
        ax.plot([lo, hi], [lo, hi], color="grey", linewidth=0.8, linestyle="--")
        ax.axhline(0, linewidth=0.5, color="grey"); ax.axvline(0, linewidth=0.5, color="grey")
        # Pearson correlation between the donor-level and cell-level estimates,
        # reported in the title as the concordance.
        cc = corr_with_pvalue(overlap["r_cell"], overlap["r_donor"], method="pearson")
        ax.set_title(f"Donor vs Cell ρ | concordance r={cc['r']:.2f}, p={cc['p']:.2e}")
        ax.set_xlabel("Cell-level ρ (recomputed)")
        ax.set_ylabel("Donor-level ρ (Part A)")
        if HAS_SNS: sns.despine(ax=ax)
        save_fig(FIG_DIR / "Supplementary_Fig5C_DonorVsCell.png", fig)
    else:
        # Placeholder so the figure file always exists.
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.05, 0.5, "No donor+cell overlap", fontsize=10); ax.axis("off")
        save_fig(FIG_DIR / "Supplementary_Fig5C_DonorVsCell.png", fig)
    hb.stop("done")
    return df_cell_recompute, overlap


# PART C: GR stratified DE + volcano + cross-dataset overlap

def pseudobulk_by_donor(adata, donor_col, layer_counts="counts",
                        use_log_norm=True, progress_label=None):
    """Sum expression over the cells of each donor into one pseudobulk profile.

    Parameters
    ----------
    adata : AnnData-like
        Must expose ``.obs``, ``.layers``, ``.X`` and ``.var``.
    donor_col : str
        Column of ``adata.obs`` holding donor identifiers (compared as strings).
    layer_counts : str
        Layer to aggregate. When the layer is absent ``adata.X`` is summed
        instead and a warning is logged, because ``X`` may not hold raw counts.
    use_log_norm : bool
        When True, apply ``sc.pp.normalize_total(target_sum=1e6)`` followed by
        ``sc.pp.log1p`` to the pseudobulk matrix.
    progress_label : str or None
        Free text appended to the progress log lines.

    Returns
    -------
    AnnData or None
        One observation per donor (in order of first appearance), or ``None``
        when ``donor_col`` is missing or fewer than 4 distinct donors exist.
        If a GR column is detected in ``adata.obs``, its per-donor mean is
        carried over into ``pb.obs``.
    """
    if donor_col not in adata.obs.columns:
        return None
    donors = adata.obs[donor_col].astype(str)
    uniq = donors.unique().tolist()
    if len(uniq) < 4:
        return None

    if layer_counts in adata.layers:
        Xsrc = adata.layers[layer_counts]
    else:
        # Best-effort fallback kept from the original, but no longer silent.
        log(f"[WARN] PartC pseudobulk: layer '{layer_counts}' not found; "
            f"summing adata.X instead | {progress_label or ''}")
        Xsrc = adata.X

    donor_values = donors.values
    log_every = max(1, len(uniq) // 4)
    kept, rows = [], []
    for i, d in enumerate(uniq, 1):
        pos = np.flatnonzero(donor_values == d)
        if pos.size:
            # .sum on a sparse matrix returns np.matrix; flatten to a 1-D vector.
            rows.append(np.asarray(Xsrc[pos, :].sum(axis=0)).ravel())
            kept.append(d)
        if i == 1 or i == len(uniq) or i % log_every == 0:
            log(f"PROG PartC pseudobulk: {i}/{len(uniq)} | {progress_label or ''}")
    if not rows:
        return None

    # Build obs from the donors actually aggregated so rows and labels always
    # stay aligned, even if a donor were ever skipped above.
    pb = ad.AnnData(X=np.vstack(rows),
                    obs=pd.DataFrame({donor_col: kept}),
                    var=adata.var.copy())
    pb.obs_names = pd.Index(kept)

    gr_col = detect_gr_col(adata.obs)
    if gr_col:
        # Donor-level GR value = mean of the cell-level values of that donor.
        donor_means = adata.obs.groupby(donors)[gr_col].mean(numeric_only=True)
        pb.obs[gr_col] = donor_means.reindex(kept).values
    if use_log_norm:
        sc.pp.normalize_total(pb, target_sum=1e6)
        sc.pp.log1p(pb)
    return pb


def maybe_run_enrichment(df_de, label):
    """Run an optional Enrichr GO-BP enrichment on the top DE genes.

    Parameters
    ----------
    df_de : pandas.DataFrame or None
        DE table with at least the columns "padj" and "gene".
    label : str
        Value written into a leading "label" column of the result.

    Returns
    -------
    pandas.DataFrame
        Enrichr results for the 300 genes with the smallest "padj", or an
        empty DataFrame when enrichment is disabled (``DO_ENRICHMENT``), the
        input is empty, fewer than 20 genes are available, ``gseapy`` cannot
        be imported, or the Enrichr call fails. The function never raises for
        those cases; failures are logged instead of being silently dropped.
    """
    # Validate the argument first so an empty input never depends on config.
    if df_de is None or df_de.shape[0] == 0 or not DO_ENRICHMENT:
        return pd.DataFrame()
    try:
        import gseapy as gp
    except Exception as e:  # optional dependency: stay best-effort, but say why
        log(f"[WARN] enrichment skipped for {label}: cannot import gseapy: {repr(e)}")
        return pd.DataFrame()

    genes = df_de.sort_values("padj").head(300)["gene"].astype(str).tolist()
    if len(genes) < 20:
        return pd.DataFrame()
    try:
        enr = gp.enrichr(gene_list=genes, gene_sets=["GO_Biological_Process_2021"],
                         organism="Human", outdir=None, no_plot=True)
        res = enr.results.copy()
        res.insert(0, "label", label)
        return res
    except Exception as e:
        # Enrichr is a remote service; a failure must not abort the DE pipeline,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        log(f"[WARN] enrichment failed for {label}: {repr(e)}")
        return pd.DataFrame()


def make_volcano(df_de, ds_name, out_png):
    """Draw and save a volcano plot for one dataset's GR-stratified DE table.

    ``df_de`` must provide the columns "padj", "logFC" and "gene". Genes with
    padj < 0.05 and |logFC| > 0.25 are coloured as up/down; the ten genes with
    the largest -log10(padj) are labelled and de-overlapped with adjustText.
    Labels are plain text (no bbox) joined to their points by thin light-grey
    connector lines. The figure is written to ``out_png`` via ``save_fig``.
    """
    data = df_de.copy()
    # Clip to avoid -log10(0) = inf for p-values that underflowed to zero.
    data["neg_log10_padj"] = -np.log10(data["padj"].clip(lower=1e-300))
    data["significant"] = (data["padj"] < 0.05) & (data["logFC"].abs() > 0.25)

    fig, ax = plt.subplots(figsize=(5.5, 4.5))

    is_sig = data["significant"]
    ns = data[~is_sig]
    up = data[is_sig & (data["logFC"] > 0)]
    dn = data[is_sig & (data["logFC"] < 0)]

    # Drawn in this order so the legend reads: ns, Up, Down.
    layers = (
        (ns, dict(s=5, alpha=0.3, color="#CCCCCC", label="ns")),
        (up, dict(s=8, alpha=0.6, color="#D6604D", label=f"Up ({len(up)})")),
        (dn, dict(s=8, alpha=0.6, color="#4393C3", label=f"Down ({len(dn)})")),
    )
    for subset, style in layers:
        ax.scatter(subset["logFC"], subset["neg_log10_padj"], **style)

    # Threshold guides: padj = 0.05 (horizontal) and |logFC| = 0.25 (vertical).
    ax.axhline(-np.log10(0.05), color="grey", linewidth=0.5, linestyle="--")
    for x_cut in (0.25, -0.25):
        ax.axvline(x_cut, color="grey", linewidth=0.5, linestyle=":")

    # Label the ten most significant genes.
    top = data.nlargest(10, "neg_log10_padj")
    texts = [
        ax.text(x, y, gene, fontsize=5.5, fontweight="bold")
        for x, y, gene in zip(top["logFC"], top["neg_log10_padj"], top["gene"])
    ]
    connector = dict(arrowstyle="-", color="#999999", lw=0.5)
    adjust_text(texts, ax=ax,
                force_points=(0.6, 0.9),
                force_text=(0.6, 0.9),
                expand_points=(1.5, 1.5),
                expand_text=(1.2, 1.2),
                arrowprops=connector)

    ax.set_xlabel("log₂FC (GR_high vs GR_low)")
    ax.set_ylabel("-log₁₀(padj)")
    ax.set_title(f"{ds_name}: GR-stratified DE")
    ax.legend(fontsize=6, loc="upper left")
    if HAS_SNS:
        sns.despine(ax=ax)
    save_fig(out_png, fig)


def compute_de_overlap(de_tables):
    """Summarise how significant DE genes overlap across datasets.

    Parameters
    ----------
    de_tables : dict
        Maps dataset name -> DE DataFrame with columns "gene", "padj", "logFC".

    Returns
    -------
    pandas.DataFrame
        Empty when fewer than two datasets have any significant gene
        (padj < 0.05 and |logFC| > 0.25). Otherwise:

        * one row per dataset pair with set sizes, overlap, union, Jaccard
          index and up to 50 shared genes (alphabetical, comma-joined);
        * followed by up to 50 rows marked ``ds1 == "REPLICATED"`` for genes
          significant in >= 2 datasets, ordered by descending dataset count
          and then by gene name, so the output is reproducible across runs
          (set iteration order of strings varies with hash randomisation).
    """
    from collections import Counter

    if len(de_tables) < 2:
        return pd.DataFrame()

    sig_genes = {}
    for ds, df in de_tables.items():
        mask = (df["padj"] < 0.05) & (df["logFC"].abs() > 0.25)
        # Cast to str so sorting/joining below cannot fail on non-string labels.
        genes = set(df.loc[mask, "gene"].astype(str))
        if genes:
            sig_genes[ds] = genes
    if len(sig_genes) < 2:
        return pd.DataFrame()

    rows = []
    # Pairwise overlaps
    for (ds1, g1), (ds2, g2) in combinations(sig_genes.items(), 2):
        shared = g1 & g2
        union = g1 | g2
        rows.append({
            "ds1": ds1, "ds2": ds2,
            "n_sig_ds1": len(g1), "n_sig_ds2": len(g2),
            "n_overlap": len(shared), "n_union": len(union),
            "jaccard": len(shared) / len(union) if union else 0,
            "overlap_genes": ",".join(sorted(shared)[:50]),
        })

    # Genes significant in >= 2 datasets, with a deterministic tie-break.
    gene_counts = Counter(g for genes in sig_genes.values() for g in genes)
    replicated = sorted(
        ((g, c) for g, c in gene_counts.items() if c >= 2),
        key=lambda item: (-item[1], item[0]),
    )
    for g, c in replicated[:50]:
        rows.append({
            "ds1": "REPLICATED", "ds2": f"in_{c}_datasets",
            "n_overlap": c, "overlap_genes": g,
        })
    return pd.DataFrame(rows)


def _record_gr_de(ds, df_de, level, de_tables, eff_rows, enrich_rows):
    """Store one dataset's DE table and emit its volcano, top-50 and enrichment.

    Mutates the three accumulator arguments in place: ``de_tables[ds]`` is
    set, the 50 rows with the smallest "padj" (tagged with dataset and level)
    are appended to ``eff_rows``, and a non-empty enrichment result is
    appended to ``enrich_rows``.
    """
    de_tables[ds] = df_de
    make_volcano(df_de, ds, FIG_DIR / f"Supplementary_Fig5_Volcano_{ds}.png")
    top = df_de.sort_values("padj").head(50).copy()
    top["dataset"] = ds
    top["level"] = level
    eff_rows.append(top)
    enr = maybe_run_enrichment(df_de, f"{ds}|{level}")
    if enr is not None and enr.shape[0] > 0:
        enrich_rows.append(enr)


def run_partC(cohorts):
    """Part C: GR_high vs GR_low differential expression per dataset.

    For every cohort with a detectable GR column, DE is attempted on donor
    pseudobulk profiles first (Wilcoxon, >= 3 per group). If no donor column
    is found, pseudobulk cannot be built, or that DE comes back empty, the
    analysis falls back to cell-level DE on a subsample of at most
    ``MAX_CELLS_FOR_CELL_DE`` cells (Wilcoxon, >= 10 per group).

    Parameters
    ----------
    cohorts : dict
        Maps dataset name -> annotated data object (``.obs``, ``.n_obs``,
        positional slicing).

    Returns
    -------
    tuple
        ``(de_tables, df_enrich, placeholder, df_gr_eff, df_de_overlap)`` where
        ``de_tables`` maps dataset -> full DE table, ``df_enrich`` concatenates
        enrichment results, ``placeholder`` is always an empty DataFrame (the
        caller unpacks it as ``df_path_meta``), ``df_gr_eff`` concatenates the
        per-dataset top-50 tables and ``df_de_overlap`` is the cross-dataset
        overlap table.

    Side effects
    ------------
    Writes one volcano figure per dataset with a non-empty DE result.
    """
    log("=" * 50)
    log("PART C: GR-stratified DE + Volcano + Overlap")
    hb = Heartbeat("PartC DE", every_s=30)
    hb.start("start")
    pr_ds = Progress("PartC datasets", total=len(cohorts))
    de_tables = {}
    eff_rows = []
    enrich_rows = []

    for ds, a in cohorts.items():
        hb.update(f"starting {ds}")
        log(f"DE start: {ds} (cells={a.n_obs})")
        donor_col = detect_donor_col(a)
        gr_col = detect_gr_col(a.obs)
        if gr_col is None:
            log(f"[WARN] {ds}: no GR column; skipping.")
            pr_ds.tick(1)
            continue
        # Stays None until one of the two paths yields a non-empty DE table.
        chosen_level = None

        # Preferred path: donor-level pseudobulk.
        pb = None
        if donor_col is not None:
            pb = pseudobulk_by_donor(a, donor_col=donor_col, progress_label=ds)
        if pb is not None:
            add_gr_group(pb, gr_col=gr_col, out_col="GR_group")
            df_de = de_rank_genes_groups(pb, "GR_group", "GR_high", "GR_low",
                                         method="wilcoxon", min_n_per_group=3)
            if df_de is not None and df_de.shape[0] > 0:
                chosen_level = "donor_pseudobulk"
                _record_gr_de(ds, df_de, chosen_level, de_tables, eff_rows, enrich_rows)
                log(f"DE end: {ds} (level={chosen_level}, DE_rows={df_de.shape[0]})")

        # Fallback: cell-level DE on a subsample when pseudobulk gave nothing.
        if chosen_level is None:
            if donor_col and donor_col in a.obs.columns:
                pos = stratified_subsample_by_donor_positions(a.obs, donor_col, MAX_CELLS_FOR_CELL_DE, 42)
            elif a.n_obs > MAX_CELLS_FOR_CELL_DE:
                # No donor information: plain random subsample with a fixed seed.
                rng = np.random.default_rng(42)
                pos = rng.choice(np.arange(a.n_obs, dtype=int), size=MAX_CELLS_FOR_CELL_DE, replace=False)
            else:
                pos = np.arange(a.n_obs, dtype=int)
            aa = a[pos, :].copy()
            add_gr_group(aa, gr_col=gr_col, out_col="GR_group")
            df_de = de_rank_genes_groups(aa, "GR_group", "GR_high", "GR_low",
                                         method="wilcoxon", min_n_per_group=10)
            if df_de is not None and df_de.shape[0] > 0:
                chosen_level = "cell_level_subsampled"
                _record_gr_de(ds, df_de, chosen_level, de_tables, eff_rows, enrich_rows)
            else:
                chosen_level = "empty"
            # Release the subsample copy before the next dataset.
            del aa
            gc.collect()

        pr_ds.tick(1)
        log(pr_ds.line(extra=f"done {ds} | level={chosen_level}"))
    hb.stop("done")

    df_gr_eff = pd.concat(eff_rows, ignore_index=True) if eff_rows else pd.DataFrame()
    df_enrich = pd.concat(enrich_rows, ignore_index=True) if enrich_rows else pd.DataFrame()

    # Cross-dataset overlap
    df_de_overlap = compute_de_overlap(de_tables)
    if len(df_de_overlap):
        log(f"DE overlap: {len(df_de_overlap)} rows")
    return de_tables, df_enrich, pd.DataFrame(), df_gr_eff, df_de_overlap


# Supplementary Table 5 index

def write_table5_index():
    """Write the index workbook describing every Table 5 artifact.

    Each entry records the artifact file name, what it contains and which
    part of the notebook produces it. The table is saved to
    ``TAB_DIR / "Supplementary_Table5_Index.xlsx"`` via ``save_xlsx``.
    """
    # (artifact, role, produced_by)
    entries = [
        ("Main_Table5.xlsx",
         "Pooled random-effects meta: pooled_r, 95% CI, I², tau², Q, Q_p per MES.",
         "Part A"),
        ("Supplementary_Table5_PerDatasetCorr.xlsx",
         "Per-dataset ρ(GR, MES) with p-values, bootstrap CIs, partial ρ, BH-FDR q.",
         "Part A"),
        ("Supplementary_Table5_I2.xlsx",
         "Heterogeneity: I², tau², Q, Q_p per MES.",
         "Part A"),
        ("Supplementary_Table5_LODO.xlsx",
         "LODO sensitivity with pooled CIs.",
         "Part A"),
        ("Supplementary_Table5_RawVsAdjusted.xlsx",
         "Robustness: raw cell vs donor vs covariate-adjusted mean ρ.",
         "Part B"),
        ("Supplementary_Table5_DonorDatasets_CellRecompute.xlsx",
         "Recomputed cell-level ρ for donor datasets with p-values + CIs.",
         "Part B"),
        ("Supplementary_Table5_DonorVsCell_OverlapUsedForFig.xlsx",
         "Donor vs cell overlap with CIs.",
         "Part B"),
        ("Supplementary_Table5_GR_DE_AllDatasets.xlsx",
         "Full DE results (GR_high vs GR_low) per dataset.",
         "Part C"),
        ("Supplementary_Table5_GR_DE_Top50_PerDataset.xlsx",
         "Top 50 DE genes per dataset.",
         "Part C"),
        ("Supplementary_Table5_GR_DE_Overlap.xlsx",
         "Cross-dataset DE gene overlap + Jaccard + replicated genes.",
         "Part C"),
        ("Supplementary_Table5_GR_Enrichment.xlsx",
         "Enrichment results (if DO_ENRICHMENT=True).",
         "Part C"),
        ("NB5_DIAG.xlsx",
         "Diagnostics.",
         "Diagnostics"),
    ]
    index_df = pd.DataFrame(entries, columns=["artifact", "role", "produced_by"])
    save_xlsx(index_df, TAB_DIR / "Supplementary_Table5_Index.xlsx")


# MAIN RUNNER

# Notebook entry point: diagnostics first, then Parts A-C and their tables.
log("=" * 60)
log("NB5 START")
log("=" * 60)

# Diagnostics are best-effort: a failure is logged and the run continues.
try:
    diag_df = diagnostics()
    for _, diag_row in diag_df.iterrows():
        if diag_row["check"] not in ("scored matches", "detected GR col", "detected MES cols",
                                     "detected donor col", "detected sex col", "detected PMI col"):
            continue
        log(f"[DIAG] {diag_row['check']}: {diag_row['value']}")
except Exception as e:
    log(f"[WARN] diagnostics failed: {repr(e)}")

if RUN_DIAGNOSTICS_ONLY:
    log("RUN_DIAGNOSTICS_ONLY=True -> stopping.")
else:
    tables_nb4 = load_nb4_tables()
    cohorts    = load_scored_cohorts()

    # Part A: correlations, meta-analysis, LODO.
    log(f"START Part A | RSS={_rss_gb():.2f} GB")
    corr_df, meta_df, i2_df, lodo_df = run_partA(cohorts)
    log(f"END Part A | RSS={_rss_gb():.2f} GB")

    # Part B: robustness.
    log(f"START Part B | RSS={_rss_gb():.2f} GB")
    df_cell_recompute, df_overlap_B = run_partB(corr_df, cohorts)
    log(f"END Part B | RSS={_rss_gb():.2f} GB")

    # Part C: GR-stratified DE.
    log(f"START Part C | RSS={_rss_gb():.2f} GB")
    de_tables, df_enrich, df_path_meta, df_gr_eff, df_de_overlap = run_partC(cohorts)
    log(f"END Part C | RSS={_rss_gb():.2f} GB")

    # Part C outputs (file names unchanged). The two DE tables are always
    # written, falling back to an empty sheet when there are no results.
    all_de = []
    for ds, df in de_tables.items():
        tmp = df.copy()
        tmp.insert(0, "dataset", ds)
        all_de.append(tmp)
    save_xlsx(pd.concat(all_de, ignore_index=True) if all_de else pd.DataFrame(),
              TAB_DIR / "Supplementary_Table5_GR_DE_AllDatasets.xlsx")

    has_top50 = df_gr_eff is not None and df_gr_eff.shape[0] > 0
    save_xlsx(df_gr_eff if has_top50 else pd.DataFrame(),
              TAB_DIR / "Supplementary_Table5_GR_DE_Top50_PerDataset.xlsx")

    # Overlap and enrichment tables are written only when they have rows.
    for optional_df, optional_name in (
        (df_de_overlap, "Supplementary_Table5_GR_DE_Overlap.xlsx"),
        (df_enrich, "Supplementary_Table5_GR_Enrichment.xlsx"),
    ):
        if optional_df is not None and optional_df.shape[0] > 0:
            save_xlsx(optional_df, TAB_DIR / optional_name)

    write_table5_index()

    log("=" * 60)
    log("NB5 COMPLETE")
    log("=" * 60)