# NB3 — Consensus NMF: Myeloid Education Signature (MES) Derivation (K=8)

Consensus NMF training on Tabula Sapiens thymus (K=8), a rank (K) sweep for model selection, cross-cohort module matching (Hungarian algorithm), and functional-axis enrichment for the eight MES modules.

**Paper:** Zafar SA, Qin W. *Thymus-Derived Myeloid Education Signatures Predict Microglial Tolerance Positioning and Are Modulated by Glucocorticoid Stress-Axis Activity.* Neuroimmunomodulation (2026).

> **Note:** Update the path variables in section 0 to match your local directory structure before running. Raw data can be obtained from the public repositories listed in Supplementary Table S1.


In [None]:
# Imports
import warnings, itertools
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import scipy.sparse as sp
import scipy.stats as stats
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

from sklearn.decomposition import NMF
from sklearn.metrics import silhouette_score
from scipy.optimize import linear_sum_assignment
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import squareform, pdist

# NOTE(review): this silences every warning, including scanpy/sklearn
# convergence and input-type warnings — consider narrowing it.
warnings.filterwarnings("ignore")
sc.settings.verbosity = 2
np.random.seed(42)

# 0) Paths
BASE_DIR  = Path(".")  # <-- SET TO YOUR PROJECT ROOT
PROC_DIR  = BASE_DIR / "Process Data"
AIM1_DIR  = PROC_DIR / "aim1_thymus"
NEG_DIR   = PROC_DIR / "negctrl"
# Prefer outputs/manuscript when it already exists; otherwise fall back to
# the "Manuscript Data" folder under the project root.
MANUSCRIPT_DIR = (
    BASE_DIR / "outputs" / "manuscript"
    if (BASE_DIR / "outputs" / "manuscript").exists()
    else BASE_DIR / "Manuscript Data"
)
FIG_DIR = MANUSCRIPT_DIR / "Figures"
TAB_DIR = MANUSCRIPT_DIR / "Tables"
for d in [AIM1_DIR, NEG_DIR, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# QC'd inputs produced by NB2.
TS_QC    = AIM1_DIR / "TS_Thymus_filtered__qc.h5ad"
LAV_QC   = AIM1_DIR / "GSE144870_Lavaert__qc.h5ad"
LE_QC    = AIM1_DIR / "GSE139042_Le__qc.h5ad"
BLOOD_QC = NEG_DIR  / "TS_Blood_NegCtrl_Myeloid__qc.h5ad"

# Explicit raises (not assert): asserts are stripped when Python runs with -O.
if not (TS_QC.exists() and LAV_QC.exists() and LE_QC.exists()):
    raise FileNotFoundError("NB2 thymus outputs missing. Re-run NB2 first.")
if not BLOOD_QC.exists():
    raise FileNotFoundError("NB2 negctrl output missing. Re-run NB2 first.")

# 1) Figure style: Arial everywhere, 8-pt base font, 9-pt titles, 7-pt
#    ticks/legends and thin axes. FIG_DPI is the export resolution.
plt.rcParams.update(
    {
        # typography
        "font.family": "Arial",
        "font.size": 8,
        # titles
        "axes.titlesize": 9,
        "figure.titlesize": 9,
        # axis labels, ticks, legend
        "axes.labelsize": 8,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
        "legend.fontsize": 7,
        # frame
        "axes.linewidth": 0.8,
    }
)
FIG_DPI = 1200  # read by save_fig

# Colour palettes
PALETTE_MES = sns.color_palette("Set2", n_colors=12)
PALETTE_DUAL = ["#3182bd", "#e6550d"]  # [thymus, blood]
CMAP_RDBU = plt.cm.RdBu_r         # diverging
CMAP_VIRIDIS = plt.cm.viridis     # sequential
CMAP_HEAT = LinearSegmentedColormap.from_list(
    name="custom_heat",
    colors=["#f7f7f7", "#fee08b", "#fc8d59", "#d73027", "#67001f"],
)

def save_fig(fig, fname: str, kind: str = "Main"):
    """Save *fig* as ``FIG_DIR/{kind}_{fname}.png`` at ``FIG_DPI`` and close it.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure to write; it is always closed afterwards.
    fname : str
        File stem, without the ``{kind}_`` prefix or the extension.
    kind : {"Main", "Supplementary"}
        Prefix that routes the figure to the main or supplementary set.

    Raises
    ------
    ValueError
        If *kind* is not "Main" or "Supplementary".
    """
    if kind not in ("Main", "Supplementary"):
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        raise ValueError(f"kind must be 'Main' or 'Supplementary', got {kind!r}")
    out = FIG_DIR / f"{kind}_{fname}.png"
    try:
        fig.savefig(out, dpi=FIG_DPI, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so figures do not pile up.
        plt.close(fig)
    print(f"[SAVED FIG] {out}")

def save_excel(sheets: dict, fname: str, kind: str = "Main"):
    """Write *sheets* to ``TAB_DIR/{kind}_{fname}.xlsx``, one worksheet per entry.

    Parameters
    ----------
    sheets : dict
        Mapping ``sheet name -> DataFrame``; each frame is written without
        its index.
    fname : str
        File stem, without the ``{kind}_`` prefix or the extension.
    kind : {"Main", "Supplementary"}
        Prefix that routes the table to the main or supplementary set.

    Sheet names are made Excel-safe: the characters ``[ ] : * ? / \\`` are
    replaced by ``_`` and names are truncated to Excel's 31-character limit.

    Raises
    ------
    ValueError
        If *kind* is invalid, *sheets* is empty, or two keys map to the same
        sheet name after sanitising. With the openpyxl writer pandas reuses
        an already-written sheet of the same name, so such a collision would
        silently overlay two tables.
    """
    if kind not in ("Main", "Supplementary"):
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        raise ValueError(f"kind must be 'Main' or 'Supplementary', got {kind!r}")
    if not sheets:
        raise ValueError("save_excel needs at least one sheet to write")

    # Validate names before opening the writer so no partial file is left.
    illegal = str.maketrans({ch: "_" for ch in "[]:*?/\\"})
    names = [str(sn).translate(illegal)[:31] for sn in sheets]
    if len(set(names)) != len(names):
        raise ValueError(f"sheet names collide after sanitising: {names}")

    out = TAB_DIR / f"{kind}_{fname}.xlsx"
    with pd.ExcelWriter(out, engine="openpyxl") as w:
        for name, df in zip(names, sheets.values()):
            df.to_excel(w, index=False, sheet_name=name)
    print(f"[SAVED TABLE] {out}")


def sanitize_for_write(adata: ad.AnnData) -> ad.AnnData:
    """Prepare *adata* for h5ad writing; modifies it in place and returns it.

    For each of ``obs`` and ``var``:
    * if the index has a name, clear it. If a column carries that same name
      and its contents differ from the index, first rename the column to
      ``<name>_col`` so it survives;
    * coerce the column labels to ``str``.
    Finally coerce ``obs_names`` and ``var_names`` to ``str``.
    """
    for frame in (adata.obs, adata.var):
        name = frame.index.name
        if name is not None:
            same_named_col = name in frame.columns
            if same_named_col:
                as_series = pd.Series(frame.index, index=frame.index)
                if not as_series.equals(frame[name]):
                    frame.rename(columns={name: f"{name}_col"}, inplace=True)
            frame.index.name = None
        frame.columns = frame.columns.astype(str)

    adata.obs_names = adata.obs_names.astype(str)
    adata.var_names = adata.var_names.astype(str)
    return adata


# 2) Load QC datasets (TS thymus, Lavaert, Le, blood negative control)
print("\nNB3: Loading QC datasets\n")
ts, lav, le, blood = (
    sc.read_h5ad(path) for path in (TS_QC, LAV_QC, LE_QC, BLOOD_QC)
)

def ensure_counts(a: ad.AnnData):
    """Make sure *a* has a sparse ``counts`` layer and return *a*.

    If the layer is absent it is seeded with a copy of ``a.X``. A dense
    layer is converted to CSR; an already-sparse layer is left untouched.
    """
    layers = a.layers
    if "counts" not in layers:
        layers["counts"] = a.X.copy()
    if sp.issparse(layers["counts"]):
        return a
    layers["counts"] = sp.csr_matrix(layers["counts"])
    return a

ts, lav, le, blood = map(ensure_counts, (ts, lav, le, blood))

# Dataset size summary (one line per dataset).
print("\n".join([
    f"  TS:          {ts.n_obs:>7,} cells × {ts.n_vars:>6,} genes",
    f"  Lavaert:     {lav.n_obs:>7,} cells × {lav.n_vars:>6,} genes",
    f"  Le:          {le.n_obs:>7,} cells × {le.n_vars:>6,} genes",
    f"  Blood NEGCTRL: {blood.n_obs:>7,} cells × {blood.n_vars:>6,} genes",
]))

# 3) Axis priors: curated gene sets for the four functional axes. Symbols
#    are upper-cased so they can be matched case-insensitively later.
AXIS = dict(
    chemokine=[
        "CXCL12", "CXCR4", "CCL19", "CCR7", "CCL21", "CXCL8", "CXCL10",
        "CXCR3", "CCL5", "CCR5", "CXCL9", "CXCL11", "CCR2", "CCL2",
    ],
    ecm_adhesion=[
        "FN1", "ITGA5", "ITGB1", "ITGA4", "ITGB7", "LAMC1", "LAMA4", "LAMB1",
        "COL1A1", "COL1A2", "COL3A1", "VCAN", "ICAM1", "VCAM1", "LGALS3",
    ],
    guidance=[
        "SEMA3A", "SEMA3C", "SEMA4D", "SEMA7A", "NRP1", "NRP2", "EPHA2",
        "EPHA4", "EPHB2", "EFNA1", "EFNA5", "EFNB2", "SLIT2", "ROBO1", "ROBO2",
    ],
    tolerance_assoc=[
        "FOXP3", "IL2RA", "CTLA4", "IKZF2", "TNFRSF18", "LAG3", "TIGIT",
        "ICOS", "TNFRSF4", "IL10", "TGFB1", "BATF", "IRF4",
    ],
)
AXIS = {axis_name: list(map(str.upper, members))
        for axis_name, members in AXIS.items()}

# 4) TS training gene universe (NO LEAKAGE)
# Gene selection uses TS cells only; the external cohorts are not consulted.
print("\n=== Preprocess TS (training only — no leakage) ===")
ts_train = ts.copy()
ts_train.X = ts_train.layers["counts"].copy()
sc.pp.normalize_total(ts_train, target_sum=1e4)
sc.pp.log1p(ts_train)
# flavor="seurat_v3" fits its mean–variance model to raw counts, so it must
# read the untouched "counts" layer. At this point X holds log-normalised values.
sc.pp.highly_variable_genes(
    ts_train, n_top_genes=2500, flavor="seurat_v3", layer="counts"
)

# Universe = TS HVGs + every axis-prior gene present in TS (matched
# case-insensitively), de-duplicated with first-seen order preserved.
hvg_ts       = ts_train.var_names[ts_train.var["highly_variable"]].tolist()
ts_var_upper = pd.Index([str(v).upper() for v in ts_train.var_names])
axis_all     = {g for lst in AXIS.values() for g in lst}
axis_present = [ts_train.var_names[i]
                for i, u in enumerate(ts_var_upper) if u in axis_all]
GENE_UNIVERSE = list(dict.fromkeys(hvg_ts + axis_present))
GENE_UNI_UP   = [str(g).strip().upper() for g in GENE_UNIVERSE]

print(f"  TS HVG: {len(hvg_ts)} | Axis present: {len(axis_present)} "
      f"| Training gene universe: {len(GENE_UNIVERSE)}")

# 5) Gene ID harmonisation + diagnostic
def norm_sym(s: str) -> str:
    """Return the lookup key for a gene symbol: str(s), trimmed, upper-cased."""
    as_text = str(s)
    return as_text.strip().upper()

def build_symbol_index(adata: ad.AnnData) -> dict:
    """Map normalised gene symbols (see ``norm_sym``) to var names of *adata*.

    Every var name is registered under its own normalised key; for duplicated
    keys the last var name wins. Then each var column whose label contains
    "symbol", "gene", "feature" or "name" contributes aliases. An alias is
    added only if its key is non-empty and not yet present, so earlier
    columns and rows take precedence.

    Missing values (NaN/None) in those columns are skipped. Otherwise they
    would be stringified and registered as the literal symbols "NAN"/"NONE".

    Returns
    -------
    dict
        ``normalised symbol -> var name``.
    """
    idx = {norm_sym(v): v for v in adata.var_names}

    keywords = ("symbol", "gene", "feature", "name")
    # str(c): tolerate non-string column labels instead of crashing on .lower().
    candidates = [c for c in adata.var.columns
                  if any(k in str(c).lower() for k in keywords)]
    for c in candidates:
        col = adata.var[c]
        as_text = col.astype(str).values
        missing = col.isna().values
        for vname, sym, is_missing in zip(adata.var_names, as_text, missing):
            if is_missing:
                continue
            k = norm_sym(sym)
            if k and (k not in idx):
                idx[k] = vname
    return idx

def gene_overlap_report(genes_ref: list, adata: ad.AnnData, label: str) -> dict:
    """Summarise how well *adata*'s genes cover the reference list *genes_ref*.

    Returns a flat dict (one diagnostic row) with the overlap count and
    percentage, how many of the first five var names start with "ENSG", up
    to eight var columns whose names look gene-related, and the first five
    var names.
    """
    symbol_index = build_symbol_index(adata)
    n_hit = len([g for g in genes_ref if norm_sym(g) in symbol_index])

    head = [str(x) for x in list(adata.var_names[:5])]
    n_ensembl = sum(x.upper().startswith("ENSG") for x in head)

    gene_like_cols = [
        c for c in adata.var.columns
        if any(k in c.lower() for k in ["symbol", "gene", "feature", "name"])
    ]

    denom = max(len(genes_ref), 1)
    return {
        "dataset": label,
        "n_overlap": int(n_hit),
        "pct_overlap": round(float(n_hit / denom * 100.0), 2),
        "varnames_look_ensembl": int(n_ensembl),
        "var_columns_with_gene": ", ".join(gene_like_cols[:8]),
        "varname_head": "; ".join(head),
    }

# One diagnostic row per external dataset, in a fixed order.
df_diag = pd.DataFrame(
    [
        gene_overlap_report(GENE_UNIVERSE, dataset, tag)
        for dataset, tag in ((lav, "Lavaert"), (le, "Le"), (blood, "Blood_NEGCTRL"))
    ]
)
print("\n[DIAGNOSTIC] Gene overlap with TS training universe:")
print(df_diag.to_string(index=False))

# 6) CORE NMF UTILITIES — consensus, cophenetic, projection

def _X(a: ad.AnnData, genes: list):
    """Return ``a.X`` restricted to the columns *genes*, as a CSR matrix."""
    sub = a[:, genes].X
    if sp.issparse(sub):
        return sub.tocsr()
    return sp.csr_matrix(sub)


def _subsample(n_obs: int, fit_n: int, rng: np.random.RandomState):
    if n_obs > fit_n:
        return rng.choice(n_obs, size=fit_n, replace=False)
    return np.arange(n_obs)


# single NMF fit
def fit_nmf(X: sp.csr_matrix, k: int, seed: int = 42, max_iter: int = 800):
    """Fit a single k-component NMF (``init="nndsvda"``) on *X*.

    Returns ``(model, W, H, reconstruction_err)``, where ``W`` is the
    per-row loading matrix and ``H`` is ``model.components_``.
    """
    nmf = NMF(n_components=k, init="nndsvda", random_state=seed,
              max_iter=max_iter)
    loadings = nmf.fit_transform(X)
    return nmf, loadings, nmf.components_, nmf.reconstruction_err_


# Consensus NMF (multiple seeds → connectivity → cophenetic)
def consensus_nmf(X: sp.csr_matrix, k: int, n_runs: int = 20,
                  max_iter: int = 800, base_seed: int = 42,
                  init: str = "random", max_consensus_cells=None):
    """
    Consensus NMF: ``n_runs`` restarts -> consensus matrix -> cophenetic corr.

    Run r fits ``NMF(n_components=k, init=init, random_state=base_seed + r)``
    on all rows of ``X``. Each row (cell) is labelled by its dominant
    component (argmax of W). The consensus matrix C holds, for each pair of
    rows, the fraction of runs in which they share a label. The cophenetic
    correlation of an average-linkage tree on ``1 - C`` summarises how
    reproducible that partition is across restarts.

    ``init`` defaults to ``"random"``. With ``"nndsvda"`` the start is
    SVD-based and scikit-learn's default coordinate-descent solver does not
    shuffle, so seed-varied runs come out (near-)identical. C is then ~0/1
    and the cophenetic correlation ~1 whatever K is, which defeats its use
    for rank selection. Pass ``init="nndsvda"`` to reproduce the old output.

    C is a dense float32 (n × n) array, and ``linkage`` needs the condensed
    distance vector on top of it, so memory grows as O(n²).
    ``max_consensus_cells`` (default None = use every row) evaluates the
    consensus on a fixed random subset of that many rows, drawn with
    ``base_seed``. The NMF fits themselves always use every row.

    Returns
    -------
    best_run : tuple
        ``(model, W, H)`` of the run with the lowest reconstruction error.
    C : np.ndarray
        Consensus matrix, shape (n, n), n = rows used for the consensus.
    coph_corr : float
        Cophenetic correlation; NaN if hierarchical clustering failed.
    all_errors : list of float
        ``reconstruction_err_`` of each run, in run order.

    Raises
    ------
    ValueError
        If ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    n_obs = X.shape[0]
    if max_consensus_cells is not None and n_obs > max_consensus_cells:
        cons_rows = np.sort(np.random.RandomState(base_seed).choice(
            n_obs, size=int(max_consensus_cells), replace=False))
    else:
        cons_rows = np.arange(n_obs)
    n = cons_rows.size

    connectivity_sum = np.zeros((n, n), dtype=np.float32)
    best_err  = np.inf
    best_run  = None
    all_errors = []

    for r in range(n_runs):
        model = NMF(n_components=k, init=init, random_state=base_seed + r,
                    max_iter=max_iter)
        W = model.fit_transform(X)
        H = model.components_
        err = model.reconstruction_err_
        all_errors.append(err)
        if err < best_err:
            best_err = err
            best_run = (model, W, H)
        # Connectivity: rows assigned to the same component ⇒ 1.
        # In-place add of the boolean matrix avoids an extra float32 n×n temp.
        labels = np.argmax(W, axis=1)[cons_rows]
        connectivity_sum += labels[:, None] == labels[None, :]

    C = connectivity_sum / n_runs            # consensus matrix

    # Cophenetic correlation of an average-linkage tree on 1 - C.
    dist = 1.0 - C
    np.fill_diagonal(dist, 0.0)
    np.clip(dist, 0.0, None, out=dist)
    dist = (dist + dist.T) / 2.0            # ensure exact symmetry
    try:
        condensed = squareform(dist, checks=False)
        Z = linkage(condensed, method="average")
        coph_corr, _ = cophenet(Z, condensed)
    except Exception as exc:  # best-effort metric: report instead of crashing
        print(f"[WARN] consensus_nmf K={k}: cophenetic correlation "
              f"unavailable ({type(exc).__name__}: {exc})")
        coph_corr = np.nan

    return best_run, C, float(coph_corr), all_errors


# Projection helper
def project_with_ts_nmf(nmf_model, adata_counts: ad.AnnData,
                        genes_ref: list) -> tuple:
    """Project a dataset onto a fitted TS NMF basis.

    Raw counts from ``adata_counts.layers["counts"]`` are library-size
    normalised to 1e4 over *all* genes and log1p-transformed. This is the
    same recipe used for the TS training matrix and the external refits.
    Only after that are they restricted to the reference genes. Normalising
    after subsetting (the previous behaviour) rescales every cell by
    total/subset-total and puts projected scores on a different scale from
    the TS training scores.

    Reference genes are located case-insensitively via
    ``build_symbol_index``. Genes absent from the dataset are left as
    all-zero columns in the matrix passed to ``nmf_model.transform``.

    Returns
    -------
    (W, info)
        W : ndarray (n_obs, n_components), or None if no reference gene was
            found.
        info : dict with "n_present", "n_missing" (and "note" on zero
            overlap).
    """
    idx_map = build_symbol_index(adata_counts)
    actual_vars = [idx_map.get(norm_sym(g)) for g in genes_ref]
    present_mask = np.array([v is not None for v in actual_vars], dtype=bool)

    n_present = int(present_mask.sum())
    n_missing = int((~present_mask).sum())
    if n_present == 0:
        return None, {"n_present": 0, "n_missing": n_missing, "note": "ZERO OVERLAP"}

    present_vars = [v for v in actual_vars if v is not None]

    # Normalise on the full transcriptome first, then subset.
    full = ad.AnnData(
        X=adata_counts.layers["counts"].copy(),
        obs=pd.DataFrame(index=adata_counts.obs_names),
        var=pd.DataFrame(index=adata_counts.var_names),
    )
    sc.pp.normalize_total(full, target_sum=1e4)
    sc.pp.log1p(full)
    Xp = full[:, present_vars].X
    Xp = Xp.tocsr() if sp.issparse(Xp) else sp.csr_matrix(Xp)

    # Scatter the present columns into their positions in genes_ref order;
    # missing reference genes stay as zero columns.
    ref_positions = np.where(present_mask)[0].astype(np.int32)
    Xcoo     = Xp.tocoo()
    new_cols = ref_positions[Xcoo.col]
    Xref     = sp.coo_matrix(
        (Xcoo.data, (Xcoo.row, new_cols)),
        shape=(full.n_obs, len(genes_ref))
    ).tocsr()

    # Ensure dtype matches NMF components (sklearn requires identical dtypes)
    target_dtype = nmf_model.components_.dtype
    if Xref.dtype != target_dtype:
        Xref = Xref.astype(target_dtype)
    W = nmf_model.transform(Xref)
    return W, {"n_present": n_present, "n_missing": n_missing}


# Component-level correlation + Hungarian matching
def corr_components(H_a, H_b):
    """Pearson correlation between every row of *H_a* and every row of *H_b*.

    Both inputs are (components × genes) over the same gene columns. The
    result has shape (H_a rows, H_b rows). Constant rows yield ~0 because of
    the 1e-12 guard on the standard deviation.
    """
    def _zscore_rows(H):
        centred = H - H.mean(axis=1, keepdims=True)
        return centred / (H.std(axis=1, keepdims=True) + 1e-12)

    Za = _zscore_rows(H_a)
    Zb = _zscore_rows(H_b)
    n_cols = Za.shape[1]
    return (Za @ Zb.T) / n_cols

def hungarian_match(C):
    """One-to-one assignment of rows to columns of *C* maximising the total.

    Returns ``(row_idx, col_idx, matched_values)``.
    """
    rows, cols = linear_sum_assignment(C, maximize=True)
    matched = C[rows, cols]
    return rows, cols, matched


# Refit NMF on external dataset
def fit_nmf_on_dataset(ad_log: ad.AnnData, genes_actual: list,
                       k: int, seed: int = 42, fit_n: int = 30000,
                       max_iter: int = 800):
    """Refit a k-component NMF on (up to ``fit_n`` rows of) an external dataset.

    Rows are subsampled with ``RandomState(seed)`` and columns restricted to
    *genes_actual*. The NMF uses ``init="nndsvda"``. Returns only
    ``components_`` (the H matrix, k × len(genes_actual)).
    """
    rows = _subsample(ad_log.n_obs, fit_n, np.random.RandomState(seed))
    subset = ad_log[rows, genes_actual].copy()
    nmf = NMF(n_components=k, init="nndsvda", random_state=seed,
              max_iter=max_iter)
    nmf.fit(_X(subset, genes_actual))
    return nmf.components_


# 7) STATISTICAL UTILITIES

def benjamini_hochberg(pvals: np.ndarray) -> np.ndarray:
    """Benjamini–Hochberg FDR-adjusted p-values, returned in input order.

    Parameters
    ----------
    pvals : array-like of float
        Raw p-values; lists and other array-likes are accepted.

    Returns
    -------
    np.ndarray of float
        q-values, each capped at 1. NaN inputs stay NaN and are ignored by
        the monotonicity pass, but still count towards ``n``.
    """
    # asarray: with a plain list, ``pvals * n`` would repeat the list.
    p = np.asarray(pvals, dtype=float)
    n = p.size
    if n == 0:
        return p.copy()

    order = np.argsort(p)                       # NaNs sort last
    ranks = np.arange(1, n + 1)
    adj_sorted = np.minimum(1.0, p[order] * n / ranks)
    # Step-up monotonicity: q_(i) = min over j >= i of the raw adjusted value.
    # fmin skips NaN, so trailing NaNs do not poison the finite q-values.
    adj_sorted = np.fmin.accumulate(adj_sorted[::-1])[::-1]

    adj = np.empty_like(adj_sorted)
    adj[order] = adj_sorted
    return adj


def cohens_d(a: np.ndarray, b: np.ndarray) -> float:
    """Cohen's d of *a* vs *b* using the pooled (ddof=1) standard deviation.

    Returns 0.0 when either sample has fewer than two values. A 1e-12 term
    guards against division by zero.
    """
    n_a = len(a)
    n_b = len(b)
    if min(n_a, n_b) < 2:
        return 0.0
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    mean_diff = a.mean() - b.mean()
    return float(mean_diff / (np.sqrt(pooled_var) + 1e-12))


def bootstrap_ci(arr: np.ndarray, stat_fn=np.mean, n_boot: int = 2000,
                 alpha: float = 0.05, seed: int = 42) -> tuple:
    """Percentile bootstrap CI for ``stat_fn(arr)``.

    Draws ``n_boot`` resamples of *arr* with replacement from
    ``RandomState(seed)``.

    Returns
    -------
    (estimate, lower, upper)
        Point estimate on the full data, and the alpha/2 and 1 - alpha/2
        percentiles of the bootstrap distribution.
    """
    rng = np.random.RandomState(seed)
    size = len(arr)
    resampled = [stat_fn(rng.choice(arr, size=size, replace=True))
                 for _ in range(n_boot)]
    boot_stats = np.array(resampled)
    lower, upper = (np.percentile(boot_stats, q)
                    for q in (100 * alpha / 2, 100 * (1 - alpha / 2)))
    estimate = stat_fn(arr)
    return float(estimate), float(lower), float(upper)


def permutation_test_corr(H_ref, H_tgt, n_perm: int = 500, seed: int = 42):
    """
    Permutation null for the mean Hungarian-matched component correlation.

    The statistic is the mean of the optimally matched row correlations
    between *H_ref* and *H_tgt*. The null shuffles the gene columns of
    *H_tgt*, which breaks gene-level correspondence, ``n_perm`` times.

    Returns
    -------
    (observed_mean, null_means, p_value)
        ``p_value = (#{null >= observed} + 1) / (n_perm + 1)``.
    """
    def _mean_matched(H_other):
        _, _, matched = hungarian_match(corr_components(H_ref, H_other))
        return float(np.mean(matched))

    obs_mean = _mean_matched(H_tgt)

    rng = np.random.RandomState(seed)
    n_genes = H_tgt.shape[1]
    null_means = np.empty(n_perm)
    for i in range(n_perm):
        null_means[i] = _mean_matched(H_tgt[:, rng.permutation(n_genes)])

    n_extreme = np.count_nonzero(null_means >= obs_mean)
    p_val = float((n_extreme + 1) / (n_perm + 1))
    return obs_mean, null_means, p_val


def hypergeom_enrichment(module_genes: list, gene_set: list,
                         universe_size: int) -> dict:
    """One-sided hypergeometric test for over-representation of *gene_set*.

    The p-value is ``P(X >= overlap)`` for ``X ~ Hypergeom(M=universe_size,
    n=|gene_set|, N=|module_genes|)``, and 1.0 when there is no overlap.
    Fold enrichment is ``(overlap/N) / (n/M)``, rounded to 3 decimals, and
    0.0 when ``n`` or ``M`` is zero. Both inputs are de-duplicated as sets.
    """
    module_set = set(module_genes)
    gs_set = set(gene_set)
    overlap = module_set & gs_set
    k, M, n, N = len(overlap), universe_size, len(gs_set), len(module_set)

    if k > 0:
        pval = float(stats.hypergeom.sf(k - 1, M, n, N))
    else:
        pval = 1.0

    if n > 0 and M > 0:
        fold = (k / max(N, 1)) / (n / max(M, 1))
    else:
        fold = 0.0

    return dict(
        overlap=k,
        module_size=N,
        gene_set_size=n,
        universe=M,
        fold_enrichment=round(fold, 3),
        pval_hypergeom=pval,
        overlap_genes=", ".join(sorted(overlap)),
    )


# 8) K-SWEEP WITH MULTI-CRITERIA MODEL SELECTION
# For each K: consensus NMF on the TS subsample (cophenetic correlation +
# reconstruction error), cosine silhouette of dominant-module labels, and
# agreement with K-component refits on Lavaert / Le (Hungarian-matched r).
print("")
print("K-SWEEP: consensus NMF, cophenetic, reconstruction error, silhouette")
print("")

K_LIST       = [5, 8, 10, 12]
N_CONSENSUS  = 20          # consensus runs per K
FIT_N_TS     = 60000       # max TS cells used for fitting
FIT_N_EXT    = 30000       # max cells per external refit

# Prepare TS matrix once: fixed random subset (seed 42) of at most FIT_N_TS
# cells, restricted to the training gene universe.
rng_ks = np.random.RandomState(42)
idx_ks = _subsample(ts_train.n_obs, FIT_N_TS, rng_ks)
ts_fit = ts_train[idx_ks, GENE_UNIVERSE].copy()
X_ts_fit = _X(ts_fit, GENE_UNIVERSE)

# Pre-log external datasets (reused later): raw counts -> normalize_total(1e4)
# -> log1p, the same recipe as the TS training matrix.
lav_log = lav.copy(); lav_log.X = lav_log.layers["counts"].copy()
sc.pp.normalize_total(lav_log, target_sum=1e4); sc.pp.log1p(lav_log)
le_log = le.copy();  le_log.X = le_log.layers["counts"].copy()
sc.pp.normalize_total(le_log, target_sum=1e4);  sc.pp.log1p(le_log)

# Universe genes found in each cohort (reference symbol + the cohort's own
# var name), and each universe gene's column position in the TS H matrix.
lav_idx  = build_symbol_index(lav)
le_idx   = build_symbol_index(le)
genes_lav        = [g for g in GENE_UNIVERSE if norm_sym(g) in lav_idx]
genes_le         = [g for g in GENE_UNIVERSE if norm_sym(g) in le_idx]
genes_lav_actual = [lav_idx[norm_sym(g)] for g in genes_lav]
genes_le_actual  = [le_idx[norm_sym(g)]  for g in genes_le]
pos_ref = {norm_sym(g): i for i, g in enumerate(GENE_UNIVERSE)}

sweep_rows = []
consensus_results = {}   # store for K=8 reuse

for K in K_LIST:
    print(f"\n  K={K} — consensus NMF ({N_CONSENSUS} runs) …")
    (model_k, W_k, H_k), C_mat, coph, errors = consensus_nmf(
        X_ts_fit, K, n_runs=N_CONSENSUS, base_seed=42
    )

    # Reconstruction error (mean across runs)
    mean_err = float(np.mean(errors))
    std_err  = float(np.std(errors))

    # Silhouette on dominant module assignment. silhouette_score accepts CSR
    # input and draws its own row sample, so the full matrix is not densified.
    labels = np.argmax(W_k, axis=1)
    n_labels = len(np.unique(labels))
    if n_labels >= 2:
        sil = float(silhouette_score(
            X_ts_fit, labels, metric="cosine",
            sample_size=min(10000, X_ts_fit.shape[0]),
            random_state=42
        ))
    else:
        sil = np.nan

    # External stability vs Lavaert + Le (refit + Hungarian)
    H_lavK = fit_nmf_on_dataset(lav_log, genes_lav_actual, K, seed=42,
                                fit_n=FIT_N_EXT)
    H_leK  = fit_nmf_on_dataset(le_log,  genes_le_actual,  K, seed=42,
                                fit_n=FIT_N_EXT)
    H_k_lav = H_k[:, [pos_ref[norm_sym(g)] for g in genes_lav]]
    H_k_le  = H_k[:, [pos_ref[norm_sym(g)] for g in genes_le]]
    C_lav_k = corr_components(H_k_lav, H_lavK)
    C_le_k  = corr_components(H_k_le,  H_leK)
    _, _, vals_lav_k = hungarian_match(C_lav_k)
    _, _, vals_le_k  = hungarian_match(C_le_k)

    sweep_rows.append({
        "K": K,
        "cophenetic": round(coph, 4),
        "silhouette_cosine": round(sil, 4) if not np.isnan(sil) else np.nan,
        "recon_error_mean": round(mean_err, 2),
        "recon_error_std": round(std_err, 2),
        "ext_corr_Lavaert_mean": round(float(np.mean(vals_lav_k)), 4),
        "ext_corr_Le_mean": round(float(np.mean(vals_le_k)), 4),
        "ext_corr_Lavaert_min": round(float(np.min(vals_lav_k)), 4),
        "ext_corr_Le_min": round(float(np.min(vals_le_k)), 4),
    })
    consensus_results[K] = {
        "model": model_k, "W": W_k, "H": H_k,
        "C_mat": C_mat, "coph": coph,
        "C_lav": C_lav_k, "C_le": C_le_k,
        "vals_lav": vals_lav_k, "vals_le": vals_le_k,
        "H_lav_refit": H_lavK, "H_le_refit": H_leK,
    }
    print(f"    cophenetic={coph:.4f}  silhouette={sil:.4f}  "
          f"ext_Lav={np.mean(vals_lav_k):.4f}  ext_Le={np.mean(vals_le_k):.4f}")

df_sweep = pd.DataFrame(sweep_rows)
print("\n[K-SWEEP SUMMARY]")
print(df_sweep.to_string(index=False))

# 9) MAIN FIT — CONSENSUS NMF K=8 + WITHIN-TS BOOTSTRAP STABILITY
K_MAIN   = 8
N_SEEDS  = 10       # within-TS multi-seed stability
mes_cols = [f"MES{k+1:02d}" for k in range(K_MAIN)]

print(f"\n{'='*72}")
print(f"MAIN FIT: Consensus NMF K={K_MAIN} on full TS training set")
print("")

# Full TS (not subsampled for final scores)
X_ts_full = _X(ts_train, GENE_UNIVERSE)
nmf_ts_best = consensus_results[K_MAIN]["model"]
H_ts        = consensus_results[K_MAIN]["H"]

# Re-transform full TS with the best model; X is cast to the components'
# dtype because sklearn's transform requires identical dtypes.
if X_ts_full.dtype != nmf_ts_best.components_.dtype:
    X_ts_full = X_ts_full.astype(nmf_ts_best.components_.dtype)
W_ts_full = nmf_ts_best.transform(X_ts_full)
ts_scores = pd.DataFrame(W_ts_full, columns=mes_cols, index=ts_train.obs_names)

# Within-TS multi-seed stability with bootstrap CI
print(f"\n  Within-TS stability ({N_SEEDS} seeds) …")
stab_within_rows = []
all_within_corrs = []

for seed in range(N_SEEDS):
    # Random initialisation so the seed actually perturbs the fit. An
    # "nndsvda" start (as in fit_nmf) is SVD-based, and the default CD
    # solver does not shuffle, so seed-varied nndsvda refits come out
    # (near-)identical and would report trivially high stability.
    model_s = NMF(n_components=K_MAIN, init="random", random_state=seed,
                  max_iter=800)
    W_s = model_s.fit_transform(X_ts_fit)
    H_s = model_s.components_
    C_s = corr_components(H_ts, H_s)
    _, _, vals = hungarian_match(C_s)
    all_within_corrs.extend(vals.tolist())
    stab_within_rows.append({
        "seed": seed,
        "mean_corr": round(float(np.mean(vals)), 4),
        "median_corr": round(float(np.median(vals)), 4),
        "min_corr": round(float(np.min(vals)), 4),
    })

df_stab_within = pd.DataFrame(stab_within_rows)
within_arr = np.array(all_within_corrs)
# NOTE(review): the N_SEEDS × K_MAIN matched correlations are pooled and
# resampled as if independent; the CI is therefore optimistic.
within_mean, within_ci_lo, within_ci_hi = bootstrap_ci(within_arr)
print(f"    Within-TS mean matched corr = {within_mean:.4f} "
      f"[{within_ci_lo:.4f}, {within_ci_hi:.4f}] 95% CI")

# 10) EXTERNAL PROJECTION + STABILITY WITH PERMUTATION TESTS
# Two checks against the Lavaert and Le thymus cohorts:
#   (a) projection: per-cell MES scores from the TS model (W_lav / W_le are
#       reused in the compartment-transfer step);
#   (b) refit agreement: the K_MAIN-component refits stored by the K-sweep,
#       Hungarian-matched to the TS components by row-wise Pearson r of H.
print(f"\n{'='*72}")
print("EXTERNAL VALIDATION: projection + permutation null")
print("")

# NOTE(review): project_with_ts_nmf returns W=None when no reference gene is
# found in the cohort; later code that builds DataFrames from W assumes
# non-None.
W_lav, info_lav   = project_with_ts_nmf(nmf_ts_best, lav, GENE_UNIVERSE)
W_le,  info_le    = project_with_ts_nmf(nmf_ts_best, le,  GENE_UNIVERSE)

# Retrieve stored refit results from K-sweep
H_lav_refit = consensus_results[K_MAIN]["H_lav_refit"]
H_le_refit  = consensus_results[K_MAIN]["H_le_refit"]
C_lav       = consensus_results[K_MAIN]["C_lav"]
C_le        = consensus_results[K_MAIN]["C_le"]
vals_lav    = consensus_results[K_MAIN]["vals_lav"]
vals_le     = consensus_results[K_MAIN]["vals_le"]

# Restrict TS loadings to the genes found in each cohort, in the same order
# (genes_lav / genes_le) used for the refits, so H columns line up.
H_ts_lav = H_ts[:, [pos_ref[norm_sym(g)] for g in genes_lav]]
H_ts_le  = H_ts[:, [pos_ref[norm_sym(g)] for g in genes_le]]

# Permutation tests (500 permutations)
# Null: gene columns of the refit H are shuffled (see permutation_test_corr);
# p = (#null >= observed + 1) / (n_perm + 1).
print("  Permutation test — TS vs Lavaert …")
obs_lav, null_lav, pval_lav = permutation_test_corr(H_ts_lav, H_lav_refit,
                                                     n_perm=500, seed=42)
print(f"    Observed={obs_lav:.4f}  p={pval_lav:.4f}")

print("  Permutation test — TS vs Le …")
obs_le, null_le, pval_le = permutation_test_corr(H_ts_le, H_le_refit,
                                                  n_perm=500, seed=42)
print(f"    Observed={obs_le:.4f}  p={pval_le:.4f}")

# Bootstrap CI on matched correlations
# NOTE(review): each CI resamples only K_MAIN matched correlations, so it is
# coarse.
lav_mean, lav_ci_lo, lav_ci_hi = bootstrap_ci(vals_lav)
le_mean,  le_ci_lo,  le_ci_hi  = bootstrap_ci(vals_le)

# Re-run the assignment to recover the row/column indices; the stored vals_*
# only hold the matched correlation values.
r_lav, c_lav, _ = hungarian_match(C_lav)
r_le,  c_le,  _ = hungarian_match(C_le)

# Long table: one row per matched (TS module, refit component) pair per cohort.
# NOTE(review): refit components are labelled "MESxx" by index only; refit
# component i is not the same module as TS MESi.
df_stab8 = pd.DataFrame({
    "target_dataset": (["Lavaert"] * K_MAIN) + (["Le"] * K_MAIN),
    "ref_module_TS": [f"MES{i+1:02d}" for i in np.concatenate([r_lav, r_le])],
    "target_module_refit": [f"MES{i+1:02d}" for i in np.concatenate([c_lav, c_le])],
    "corr": np.concatenate([vals_lav, vals_le]).astype(float).round(4),
})
# Per-cohort summary: mean matched r, bootstrap 95% CI, permutation p-value.
df_stab8_summary = pd.DataFrame([
    {"target": "Lavaert", "mean_corr": round(lav_mean, 4),
     "CI95_lo": round(lav_ci_lo, 4), "CI95_hi": round(lav_ci_hi, 4),
     "perm_pval": pval_lav},
    {"target": "Le", "mean_corr": round(le_mean, 4),
     "CI95_lo": round(le_ci_lo, 4), "CI95_hi": round(le_ci_hi, 4),
     "perm_pval": pval_le},
])

# 11) NEGCTRL — FORMAL STATISTICAL COMPARISON
# Per module: one-sided Mann–Whitney U (thymus > blood) on per-cell scores,
# Cohen's d effect size, and BH-FDR across the K_MAIN modules.
print(f"\n{'='*72}")
print("NEGATIVE CONTROL: formal statistical test per module")
print("")

# NOTE(review): ts_scores come from nmf_ts_best.transform on the TS
# log-normalised matrix, while blood_scores come from project_with_ts_nmf.
# The comparison assumes both paths use the same normalisation — confirm.
# If no reference gene is found in blood, W_blood is None and blood_scores
# holds no scores.
W_blood, info_blood = project_with_ts_nmf(nmf_ts_best, blood, GENE_UNIVERSE)
blood_scores = pd.DataFrame(W_blood, columns=mes_cols, index=blood.obs_names)

negctrl_stat_rows = []
for col in mes_cols:
    t_vals = ts_scores[col].values
    b_vals = blood_scores[col].values
    # With many cells p-values become tiny; cohens_d reports effect size.
    U_stat, mwu_pval = stats.mannwhitneyu(t_vals, b_vals, alternative="greater")
    d = cohens_d(t_vals, b_vals)
    negctrl_stat_rows.append({
        "module": col,
        "thymus_mean": round(float(t_vals.mean()), 5),
        "thymus_median": round(float(np.median(t_vals)), 5),
        "blood_mean": round(float(b_vals.mean()), 5),
        "blood_median": round(float(np.median(b_vals)), 5),
        "mannwhitney_U": float(U_stat),
        "pval_one_sided": mwu_pval,
        "cohens_d": round(d, 3),
    })

df_negctrl_stat = pd.DataFrame(negctrl_stat_rows)
df_negctrl_stat["pval_BH"] = benjamini_hochberg(
    df_negctrl_stat["pval_one_sided"].values
)
df_negctrl_stat["significant_BH005"] = df_negctrl_stat["pval_BH"] < 0.05
print(df_negctrl_stat[["module","thymus_mean","blood_mean","cohens_d",
                        "pval_BH","significant_BH005"]].to_string(index=False))

# Summary table for backward compat
# Long format: one row per (dataset, module) with mean / median / 95th pct.
df_negctrl_summary = pd.concat([
    pd.DataFrame({
        "dataset": "TS_thymus", "module": mes_cols,
        "mean": ts_scores.mean().values,
        "median": ts_scores.median().values,
        "p95": ts_scores.quantile(0.95).values,
    }),
    pd.DataFrame({
        "dataset": "Blood_negctrl", "module": mes_cols,
        "mean": blood_scores.mean().values,
        "median": blood_scores.median().values,
        "p95": blood_scores.quantile(0.95).values,
    }),
], ignore_index=True)

# 12) THYMIC COMPARTMENT ENRICHMENT — WILCOXON + EFFECT SIZES
print(f"\n{'='*72}")
print("THYMIC COMPARTMENT: differential enrichment per module")
print("")

def pick_best_compartment_col(obs: pd.DataFrame):
    """Choose the obs column that best encodes thymic compartments.

    First match wins:
    1. a column from a preferred list of common annotation names with 2–80
       distinct non-null values;
    2. otherwise any text-like column with 2–60 distinct non-null values.
       Text-like means object, categorical, or pandas ``string`` dtype. The
       ``string`` dtype was previously missed, and pandas 3 uses it by
       default for text.

    Returns the column name, or None if nothing qualifies.
    """
    preferred = [
        "cell_type","cell_type_label","celltype","annotation",
        "annotation_fine","annotation_coarse","compartment",
        "lineage","subset","cluster","leiden","louvain",
    ]
    for c in preferred:
        if c in obs.columns and 2 <= obs[c].nunique(dropna=True) <= 80:
            return c
    for c in obs.columns:
        dtype = obs[c].dtype
        is_text = (dtype == object
                   or isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype)))
        if is_text and 2 <= obs[c].nunique(dropna=True) <= 60:
            return c
    return None

comp_col = pick_best_compartment_col(ts.obs)
if comp_col is None:
    raise RuntimeError("No suitable thymus compartment column in TS obs.")

# Attach the compartment label to each cell's MES scores. astype(str) turns
# missing labels into the string "nan", which then forms its own group.
ts_comp = ts_scores.copy()
ts_comp[comp_col] = ts.obs.loc[ts_comp.index, comp_col].astype(str).values

# Per-module per-compartment Wilcoxon rank-sum (one-vs-rest)
# One-sided (compartment > rest); pairs with < 5 cells on either side are
# skipped. BH-FDR is applied across all tested (module, compartment) pairs.
enrichment_rows = []
compartments = sorted(ts_comp[comp_col].unique())
for mod in mes_cols:
    for ct in compartments:
        in_group  = ts_comp.loc[ts_comp[comp_col] == ct, mod].values
        out_group = ts_comp.loc[ts_comp[comp_col] != ct, mod].values
        if len(in_group) < 5 or len(out_group) < 5:
            continue
        U, pval = stats.mannwhitneyu(in_group, out_group, alternative="greater")
        d = cohens_d(in_group, out_group)
        enrichment_rows.append({
            "module": mod, "compartment": ct,
            "n_cells_in": len(in_group),
            "mean_in": round(float(in_group.mean()), 5),
            "mean_out": round(float(out_group.mean()), 5),
            "cohens_d": round(d, 3),
            "pval_wilcox": pval,
        })

df_enrichment = pd.DataFrame(enrichment_rows)
if len(df_enrichment):
    df_enrichment["pval_BH"] = benjamini_hochberg(
        df_enrichment["pval_wilcox"].values
    )
    df_enrichment["significant_BH005"] = df_enrichment["pval_BH"] < 0.05

# Mean MES score per TS compartment (compartments × modules).
df_comp_ts = ts_comp.groupby(comp_col)[mes_cols].mean()

# Compartment-level transfer
def spearman_corr(a, b):
    """Spearman rank correlation of *a* and *b* as a ``(rho, p)`` pair of floats."""
    result = stats.spearmanr(a, b)
    rho, pval = result[0], result[1]
    return float(rho), float(pval)

def compartment_means_projected(W, adata_ref, col):
    """Average projected MES scores within each category of ``adata_ref.obs[col]``.

    Parameters
    ----------
    W : array-like
        Per-cell MES scores; its rows must line up with ``adata_ref.obs_names``
        and its columns with the module-level ``mes_cols``.
    adata_ref : AnnData
        Dataset that supplies the cell index and the grouping column.
    col : str
        Name of the ``.obs`` column to group by; labels are stringified first.

    Returns
    -------
    pandas.DataFrame
        One row per category (index named *col*), one column per module.
    """
    group_labels = adata_ref.obs[col].astype(str).values
    scores = pd.DataFrame(W, columns=mes_cols, index=adata_ref.obs_names)
    scores = scores.assign(**{col: group_labels})
    return scores.groupby(col)[mes_cols].mean()

# For every external cohort that carries the same compartment column, correlate
# (Spearman, across modules) each shared compartment's mean MES profile with the
# TS reference profile. Lavaert rows come first, then Le.
transfer_rows = []
for tgt_label, tgt_adata, tgt_W in (("Lavaert", lav, W_lav), ("Le", le, W_le)):
    if comp_col not in tgt_adata.obs.columns:
        continue
    tgt_means = compartment_means_projected(tgt_W, tgt_adata, comp_col)
    for c in sorted(set(df_comp_ts.index).intersection(tgt_means.index)):
        r, p = spearman_corr(df_comp_ts.loc[c, mes_cols], tgt_means.loc[c, mes_cols])
        transfer_rows.append({"target": tgt_label, "compartment": c,
                              "spearman_r": round(r, 4), "spearman_p": p})

df_transfer = pd.DataFrame(transfer_rows)
if not df_transfer.empty:
    # BH across all (cohort, compartment) correlations.
    df_transfer["pval_BH"] = benjamini_hochberg(df_transfer["spearman_p"].values)

# 13) MODULE ANNOTATION — GENE TABLES + HYPERGEOMETRIC ENRICHMENT
print(f"\n{'='*72}")
print("MODULE ANNOTATION: top genes + axis enrichment (hypergeometric)")
print("")

TOP_K = 20          # genes per module listed in the TopGenes table
N_TOP_ENRICH = 100  # top-weighted genes per module tested against each axis set
n_gene_universe = len(GENE_UNI_UP)

# Single pass over modules: rank the module's gene weights once, run one
# hypergeometric test per axis, and feed those same results into the
# per-module summary, the full module × axis table and the top-K gene table.
# (Previously every test was computed twice and each ranking sorted three times.)
module_rows, top_rows, enrich_full_rows = [], [], []
for comp in range(K_MAIN):
    mod_name = f"MES{comp+1:02d}"
    w_ranked = pd.Series(H_ts[comp, :], index=GENE_UNI_UP).sort_values(ascending=False)
    top_enrich = w_ranked.head(N_TOP_ENRICH).index.tolist()

    enr = {ax_name: hypergeom_enrichment(top_enrich, genes_ax, n_gene_universe)
           for ax_name, genes_ax in AXIS.items()}

    # Full-table rows: the test result with module/axis labels appended last.
    enrich_full_rows.extend({**res, "module": mod_name, "axis": ax_name}
                            for ax_name, res in enr.items())

    # "Best" axis = largest fold enrichment (first axis in AXIS order wins ties).
    best_axis = max(enr, key=lambda ax_key: enr[ax_key]["fold_enrichment"])
    best = enr[best_axis]
    module_rows.append({
        "module": mod_name,
        "best_axis": best_axis,
        "best_fold_enrichment": best["fold_enrichment"],
        "best_pval_hypergeom": best["pval_hypergeom"],
        "best_overlap_genes": best["overlap_genes"],
        **{f"fold_{k}": v["fold_enrichment"] for k, v in enr.items()},
        **{f"pval_{k}": v["pval_hypergeom"] for k, v in enr.items()},
    })

    # Top-K genes by NMF weight.
    for rank, (g, val) in enumerate(w_ranked.head(TOP_K).items(), start=1):
        top_rows.append({"module": mod_name, "rank": rank,
                         "gene": g, "weight": round(float(val), 6)})

df_module = pd.DataFrame(module_rows)
# BH correct the best p-values.
# NOTE(review): each module contributes only its best-of-AXIS p-value, so this
# correction does not account for selecting the best axis. Treat best_pval_BH as
# optimistic and prefer the full-table pval_BH below for inference.
df_module["best_pval_BH"] = benjamini_hochberg(
    df_module["best_pval_hypergeom"].values
)

df_top = pd.DataFrame(top_rows)
df_weights = pd.DataFrame(
    H_ts.T, index=GENE_UNI_UP, columns=mes_cols
).reset_index().rename(columns={"index": "gene"})

df_module["top_genes_10"] = (
    df_top.groupby("module")["gene"]
    .apply(lambda x: ", ".join(list(x)[:10]))
    .reindex(df_module["module"]).values
)

# Full enrichment table (all axes × all modules), BH across every test.
df_enrich_full = pd.DataFrame(enrich_full_rows)
df_enrich_full["pval_BH"] = benjamini_hochberg(
    df_enrich_full["pval_hypergeom"].values
)

# 14) PUBLICATION FIGURES
print(f"\n{'='*72}")
print("FIGURES: publication-grade, multi-panel")
print("")

# UMAP (TS training only): embed at most N_EMB cells (random subsample without
# replacement from the global NumPy RNG) for plotting.
print("  Computing UMAP (TS training) …")
N_EMB = 80000
subsample_needed = ts_train.n_obs > N_EMB
if subsample_needed:
    idx_emb = np.random.choice(ts_train.n_obs, size=N_EMB, replace=False)
    ts_emb = ts_train[idx_emb].copy()
else:
    ts_emb = ts_train.copy()
ts_emb_scores = ts_scores.loc[ts_emb.obs_names] if subsample_needed else ts_scores.copy()

# PCA -> kNN graph -> UMAP on the TS highly variable genes.
ts_emb = ts_emb[:, hvg_ts].copy()
sc.tl.pca(ts_emb, n_comps=50, svd_solver="arpack")
sc.pp.neighbors(ts_emb, n_neighbors=15, n_pcs=30)
sc.tl.umap(ts_emb)

# Attach MES scores and the per-cell dominant (arg-max) module.
emb_score_mat = ts_emb_scores.loc[ts_emb.obs_names, mes_cols].values
ts_emb.obs[mes_cols] = emb_score_mat
ts_emb.obs["MES_dom"] = np.asarray(mes_cols)[emb_score_mat.argmax(axis=1)]

# MAIN FIG 1A — UMAP dominant module
# Each module is coloured by its fixed position in mes_cols, so it keeps the same
# colour even when some module is never dominant. The previous per-present-category
# index shifted every later module's colour whenever a category was missing.
fig, ax = plt.subplots(figsize=(6.8, 5.2))
xy = ts_emb.obsm["X_umap"]
dom_mod = ts_emb.obs["MES_dom"].astype(str).to_numpy()
for mod_idx, mod in enumerate(mes_cols):
    in_mod = dom_mod == mod
    if not in_mod.any():
        continue  # module is never dominant: no points, no legend entry
    ax.scatter(xy[in_mod, 0], xy[in_mod, 1], s=1.5, alpha=0.55, linewidths=0,
               color=PALETTE_MES[mod_idx], label=mod, rasterized=True)
ax.set_xlabel("UMAP1"); ax.set_ylabel("UMAP2")
ax.set_title("Main Fig 1A Thymus MES dominant module per cell")
ax.legend(markerscale=5, bbox_to_anchor=(1.02, 1.0), loc="upper left",
          frameon=False, fontsize=6)
save_fig(fig, "Fig1A", kind="Main")

# MAIN FIG 1B — Gene-weight heatmap (clustered)
# Imported unconditionally: Fig S5 also calls leaves_list, and importing it only
# inside the clustering branch below left it undefined when that branch was skipped.
from scipy.cluster.hierarchy import leaves_list

# Weight matrix: modules (rows) × union of every module's top genes (columns).
gene_union = sorted(df_top["gene"].unique())
wmap = {g: i for i, g in enumerate(GENE_UNI_UP)}
mat = np.zeros((K_MAIN, len(gene_union)), dtype=float)
for j, g in enumerate(gene_union):
    gi = wmap.get(g)
    if gi is not None:  # genes absent from the universe stay at 0
        mat[:, j] = H_ts[:K_MAIN, gi]

# Hierarchical clustering of genes (Ward, on each gene's weight profile).
if len(gene_union) > 2:
    gene_order = leaves_list(linkage(mat.T, method="ward", metric="euclidean"))
else:
    gene_order = np.arange(len(gene_union))

mat_ordered = mat[:, gene_order]
gene_labels_ordered = [gene_union[i] for i in gene_order]

fig, ax = plt.subplots(figsize=(10.5, 3.8))
im = ax.imshow(mat_ordered, aspect="auto", cmap=CMAP_HEAT)
ax.set_yticks(range(K_MAIN))
ax.set_yticklabels(mes_cols)
ax.set_xticks(range(len(gene_labels_ordered)))
ax.set_xticklabels(gene_labels_ordered, rotation=90, fontsize=2.5)
ax.set_title("Main Fig 1B MES gene weight heatmap union of top genes")
cbar = plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
cbar.set_label("NMF weight")
plt.tight_layout()
save_fig(fig, "Fig1B", kind="Main")

# MAIN FIG 1C — External stability heatmaps
# Side-by-side TS-vs-cohort module correlation matrices; every cell is annotated
# with its value and each title carries the permutation p-value.
fig, axes = plt.subplots(1, 2, figsize=(10.2, 3.4), constrained_layout=True)

stability_panels = (
    ("Lavaert", C_lav, pval_lav),
    ("Le", C_le, pval_le),
)
for ax, (tgt_label, C_ext, perm_p) in zip(axes, stability_panels):
    im = ax.imshow(C_ext, aspect="equal", cmap=CMAP_RDBU, vmin=-1, vmax=1)
    ax.set_xticks(range(K_MAIN))
    ax.set_xticklabels(mes_cols, rotation=45, ha="right", fontsize=6)
    ax.set_yticks(range(K_MAIN))
    ax.set_yticklabels(mes_cols, fontsize=6)
    ax.set_xlabel(f"{tgt_label} refit modules")
    ax.set_ylabel("TS reference modules")
    # White text on strongly coloured cells, black elsewhere.
    for ri, ci in itertools.product(range(K_MAIN), repeat=2):
        v = C_ext[ri, ci]
        ax.text(ci, ri, f"{v:.2f}", ha="center", va="center",
                fontsize=5, color="white" if abs(v) > 0.5 else "black")
    ax.set_title(f"Main Fig 1C TS vs {tgt_label} (perm p={perm_p:.3g})")

# Shared colour bar (both panels use the same fixed [-1, 1] scale).
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("Pearson r")
save_fig(fig, "Fig1C", kind="Main")

# MAIN FIG 1D — Permutation null distributions
# Histogram of each cohort's permutation null with the observed statistic overlaid.
fig, axes = plt.subplots(1, 2, figsize=(9.5, 3.4), constrained_layout=True)

null_panels = (
    ("Lavaert", null_lav, obs_lav, pval_lav),
    ("Le", null_le, obs_le, pval_le),
)
for ax, (label, null_dist, obs_val, perm_p) in zip(axes, null_panels):
    ax.hist(null_dist, bins=40, color="#bdbdbd", edgecolor="white",
            linewidth=0.5, density=True, label="Permutation null")
    ax.axvline(obs_val, color="#d7191c", lw=2, ls="--",
               label=f"Observed = {obs_val:.4f}")
    ax.set(xlabel="Mean matched correlation", ylabel="Density",
           title=f"Main Fig 1D TS vs {label} (p = {perm_p:.3g})")
    ax.legend(frameon=False, fontsize=6)
    ax.grid(alpha=0.15)

save_fig(fig, "Fig1D", kind="Main")

# SUPPLEMENTARY FIG S3A — K-sweep multi-criteria
# 2×2 rank-selection diagnostics over K_LIST: (a) cophenetic correlation,
# (b) cosine silhouette, (c) reconstruction error ± SD, (d) external stability.
fig_s3, axes_s3 = plt.subplots(2, 2, figsize=(10.0, 7.2), constrained_layout=True)
ax_coph, ax_sil, ax_rec, ax_ext = axes_s3.ravel()
k_vals = df_sweep["K"]

ax_coph.plot(k_vals, df_sweep["cophenetic"], "o-", color="#2c7fb8", lw=1.5)
ax_sil.plot(k_vals, df_sweep["silhouette_cosine"], "s-", color="#d95f0e", lw=1.5)
ax_rec.errorbar(k_vals, df_sweep["recon_error_mean"],
                yerr=df_sweep["recon_error_std"], fmt="D-", color="#7570b3",
                lw=1.5, capsize=4)
for cohort, style, colour in (("Lavaert", "o-", PALETTE_DUAL[0]),
                              ("Le", "s-", PALETTE_DUAL[1])):
    ax_ext.plot(k_vals, df_sweep[f"ext_corr_{cohort}_mean"], style,
                label=cohort, color=colour, lw=1.5)

# Shared axis dressing: (axes, y-label, panel title).
panel_text = (
    (ax_coph, "Cophenetic correlation", "a  Cophenetic coefficient"),
    (ax_sil, "Silhouette (cosine)", "b  Silhouette score"),
    (ax_rec, "Reconstruction error (Frobenius)", "c  Reconstruction error"),
    (ax_ext, "Mean matched correlation", "d  External stability vs K"),
)
for ax, ylab, title in panel_text:
    ax.set_xlabel("K")
    ax.set_ylabel(ylab)
    ax.set_title(title, loc="left")
    ax.set_xticks(K_LIST)
    ax.grid(alpha=0.2)
ax_ext.legend(frameon=False)

save_fig(fig_s3, "Fig3A_MES_K_sensitivity", kind="Supplementary")

# SUPPLEMENTARY FIG S4 — NEGCTRL with significance
fig_s4, axes_s4 = plt.subplots(1, 2, figsize=(11.5, 4.0),
                                gridspec_kw={"width_ratios": [1.3, 1]},
                                constrained_layout=True)

# S4 left: split violin, thymus vs blood, per module.
N_VIOLIN_MAX = 5000  # cells per group drawn in the violins (for speed only)
# Seeded random subsample. The old `.values[:5000]` took the first rows, which is
# biased when cells are stored grouped (e.g. by donor). A local Generator leaves the
# global NumPy RNG stream untouched.
rng_violin = np.random.default_rng(0)
violin_frames = []
for group, scores in (("Thymus", ts_scores), ("Blood negctrl", blood_scores)):
    n_take = min(len(scores), N_VIOLIN_MAX)
    take = np.sort(rng_violin.choice(len(scores), size=n_take, replace=False))
    long_df = scores.iloc[take][list(mes_cols)].melt(var_name="module",
                                                     value_name="score")
    violin_frames.append(long_df.assign(group=group))
df_plot = pd.concat(violin_frames, ignore_index=True)

ax = axes_s4[0]
sns.violinplot(data=df_plot, x="module", y="score", hue="group",
               order=list(mes_cols), split=True, inner="quart",
               palette=PALETTE_DUAL, linewidth=0.6, ax=ax,
               density_norm="width", cut=0)
ax.set_xlabel(""); ax.set_ylabel("MES score")
ax.set_title("a  MES specificity: thymus vs peripheral myeloid", loc="left")
ax.legend(frameon=False, fontsize=6, loc="upper right")
ax.tick_params(axis="x", rotation=45)


def _p_to_stars(p):
    """Map a BH-adjusted p-value to significance stars ('' when p >= 0.05 or NaN)."""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return ""


# Align the statistics to the plotted module order. The old code placed stars at
# the DataFrame index label and drew bars in the table's row order, which only
# matched the axes when df_negctrl_stat happened to be in mes_cols order with a
# 0..K-1 index.
stat_by_module = df_negctrl_stat.set_index("module").reindex(mes_cols)
for x_pos, mod in enumerate(mes_cols):
    star = _p_to_stars(stat_by_module.at[mod, "pval_BH"])
    if star:
        ymax = df_plot.loc[df_plot["module"] == mod, "score"].quantile(0.98)
        ax.text(x_pos, ymax * 1.02, star, ha="center", va="bottom", fontsize=7)

# S4 right: effect size bars (green when d > 0.5).
ax = axes_s4[1]
d_vals = stat_by_module["cohens_d"].to_numpy(dtype=float)
colors = ["#2ca25f" if d > 0.5 else "#bdbdbd" for d in d_vals]
y_pos = np.arange(len(mes_cols))
ax.barh(y_pos, d_vals, color=colors, edgecolor="white", linewidth=0.5)
ax.set_yticks(y_pos)
ax.set_yticklabels(mes_cols)
ax.set_xlabel("Cohen's d (thymus > blood)")
ax.set_title("b  Effect sizes", loc="left")
ax.axvline(0.5, ls="--", color="grey", lw=0.8, label="d=0.5 (medium)")
ax.axvline(0.8, ls=":", color="grey", lw=0.8, label="d=0.8 (large)")
ax.legend(frameon=False, fontsize=6)
ax.grid(axis="x", alpha=0.2)

save_fig(fig_s4, "FigS4_NEGCTRL_Specificity", kind="Supplementary")

# SUPPLEMENTARY FIG S5 — Compartment enrichment
# Local import: this figure must not rely on leaves_list having been imported
# inside a conditional branch of the Fig 1B code (NameError if that branch is skipped).
from scipy.cluster.hierarchy import leaves_list

n_comp_rows, n_comp_cols = df_comp_ts.shape
fig_s5, ax = plt.subplots(figsize=(10.0, max(3.5, 0.22 * n_comp_rows + 2.0)))

# Cluster rows (compartments) and columns (modules) with Ward linkage; with two
# or fewer entries on an axis, keep the original order.
if n_comp_rows > 2:
    row_order = leaves_list(linkage(df_comp_ts.values, method="ward"))
else:
    row_order = np.arange(n_comp_rows)

if n_comp_cols > 2:
    col_order = leaves_list(linkage(df_comp_ts.values.T, method="ward"))
else:
    col_order = np.arange(n_comp_cols)

mat_comp = df_comp_ts.values[np.ix_(row_order, col_order)]
row_labels = [df_comp_ts.index[i] for i in row_order]
col_labels = [df_comp_ts.columns[i] for i in col_order]

im = ax.imshow(mat_comp, aspect="auto", cmap=CMAP_HEAT)
ax.set_yticks(range(len(row_labels)))
ax.set_yticklabels(row_labels)
ax.set_xticks(range(len(col_labels)))
ax.set_xticklabels(col_labels, rotation=45, ha="right")
ax.set_title(f"Thymus compartment enrichment (column: {comp_col}; "
             f"hierarchically clustered)")
cbar = plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02, shrink=0.85)
cbar.set_label("Mean MES score")
plt.tight_layout()
save_fig(fig_s5, "FigS5_ThymusCompartment_Enrichment", kind="Supplementary")

# SUPPLEMENTARY FIG S6 — Compartment transfer
TRANSFER_TARGETS = ("Lavaert", "Le")
if not df_transfer.empty:
    # Fixed colour per cohort so both panels agree even when a cohort is missing.
    target_color = {tgt: PALETTE_DUAL[i] for i, tgt in enumerate(TRANSFER_TARGETS)}
    fig_s6, axes_s6 = plt.subplots(1, 2, figsize=(9.5, 3.8),
                                    constrained_layout=True)

    # Left: box + jittered strip plot of Spearman r per cohort.
    ax = axes_s6[0]
    data_box, labels_box = [], []
    for tgt in TRANSFER_TARGETS:
        vals = df_transfer.loc[df_transfer["target"] == tgt, "spearman_r"].dropna().values
        if len(vals):
            data_box.append(vals)
            labels_box.append(tgt)
    if data_box:
        # Tick labels are set explicitly: boxplot's `labels=` keyword is deprecated
        # since Matplotlib 3.9 (renamed `tick_labels`), and boxes sit at 1..n.
        bp = ax.boxplot(data_box, showfliers=False, patch_artist=True, widths=0.45)
        ax.set_xticks(range(1, len(labels_box) + 1))
        ax.set_xticklabels(labels_box)
        for patch, tgt in zip(bp["boxes"], labels_box):
            patch.set_facecolor(target_color[tgt]); patch.set_alpha(0.4)
        for pos, (tgt, vals) in enumerate(zip(labels_box, data_box), start=1):
            ax.scatter(np.full_like(vals, pos) + np.random.normal(0, 0.04, len(vals)),
                       vals, s=18, alpha=0.7, color=target_color[tgt], zorder=3,
                       edgecolors="white", linewidths=0.3)
    ax.set_ylabel("Spearman ρ of MES profile")
    ax.set_title("a  Compartment-level transfer", loc="left")
    ax.grid(axis="y", alpha=0.2)
    ax.axhline(0, ls="--", color="grey", lw=0.6)

    # Right: per-compartment bars. Every compartment gets ONE y slot shared by both
    # cohorts (ordered by mean ρ across cohorts), so each tick label names the bars
    # drawn beside it. The old code sorted each cohort independently and then
    # overwrote the tick labels with the last cohort's order, mislabelling the
    # other cohort's bars.
    ax = axes_s6[1]
    comp_order = (df_transfer.groupby("compartment")["spearman_r"].mean()
                  .sort_values(kind="mergesort").index.tolist())
    slot = {c: i for i, c in enumerate(comp_order)}
    for idx_t, tgt in enumerate(TRANSFER_TARGETS):
        sub = df_transfer[df_transfer["target"] == tgt]
        if sub.empty:
            continue
        y_pos = sub["compartment"].map(slot).to_numpy(dtype=float) + idx_t * 0.35
        ax.barh(y_pos, sub["spearman_r"].values, height=0.3,
                color=target_color[tgt], alpha=0.75, label=tgt,
                edgecolor="white", linewidth=0.3)
    ax.set_yticks(np.arange(len(comp_order)) + 0.175)
    ax.set_yticklabels(comp_order, fontsize=5.5)
    ax.set_xlabel("Spearman ρ")
    ax.set_title("b  Per-compartment Spearman ρ", loc="left")
    ax.legend(frameon=False, fontsize=6); ax.grid(axis="x", alpha=0.2)

    save_fig(fig_s6, "FigS6_Compartment_level_transfer_profile_correlation",
             kind="Supplementary")

# SUPPLEMENTARY FIG S7 — Enrichment dot plot (module × axis; size = fold
# enrichment, colour = −log10 BH-adjusted p)
fig_s7, ax = plt.subplots(figsize=(6.5, 4.5))

df_dot = df_enrich_full.copy()
df_dot["-log10_pBH"] = -np.log10(df_dot["pval_BH"].clip(lower=1e-20))
# expand=False yields a Series for the single capture group (not a 1-column frame).
df_dot["module_idx"] = (
    df_dot["module"].str.extract(r"(\d+)", expand=False).astype(int) - 1
)
ax_names = sorted(AXIS.keys())
df_dot["axis_idx"] = df_dot["axis"].map({a: i for i, a in enumerate(ax_names)})

# Colour-scale ceiling computed once (previously recomputed for every point) and
# shared by the dots and the colour bar.
vmax_dot = max(3, df_dot["-log10_pBH"].quantile(0.95))
# One vectorised scatter instead of one call per row. np.fmax keeps the 5-pt
# minimum even for NaN fold enrichment, matching the old builtin max().
dots = ax.scatter(df_dot["axis_idx"], df_dot["module_idx"],
                  s=np.fmax(5, df_dot["fold_enrichment"].to_numpy(dtype=float) * 35),
                  c=df_dot["-log10_pBH"], cmap=CMAP_VIRIDIS, vmin=0, vmax=vmax_dot,
                  edgecolors="black", linewidths=0.3)

ax.set_xticks(range(len(ax_names)))
ax.set_xticklabels([a.replace("_", "\n") for a in ax_names], fontsize=7)
ax.set_yticks(range(K_MAIN))
ax.set_yticklabels(mes_cols)
ax.set_title("Module–axis enrichment (size = fold enrichment; colour = −log₁₀ FDR)",
             fontsize=8)
cbar = plt.colorbar(dots, ax=ax, fraction=0.025, pad=0.03)
cbar.set_label("−log₁₀ FDR", fontsize=7)
plt.tight_layout()
save_fig(fig_s7, "FigS7_Module_Axis_Enrichment_DotPlot", kind="Supplementary")

# SUPPLEMENTARY FIG S8 — Consensus matrix for K = K_MAIN
fig_s8, ax = plt.subplots(figsize=(5.5, 5.0))
res_main = consensus_results[K_MAIN]
C_consensus = res_main["C_mat"]
# Reorder rows/columns by each sample's dominant (arg-max) module so that
# same-module blocks line up along the diagonal.
reorder = np.argsort(np.argmax(res_main["W"], axis=1))
C_sorted = C_consensus[np.ix_(reorder, reorder)]

im = ax.imshow(C_sorted, cmap="YlOrRd", vmin=0, vmax=1, aspect="equal")
ax.set_title(f"Consensus matrix K={K_MAIN} ({N_CONSENSUS} runs)\n"
             f"Cophenetic = {res_main['coph']:.4f}",
             fontsize=8)
ax.set_xlabel("Cells (ordered by dominant module)")
ax.set_ylabel("Cells")
ax.set_xticks([]); ax.set_yticks([])  # too many cells for readable tick labels
cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
cbar.set_label("Consensus", fontsize=7)
plt.tight_layout()
save_fig(fig_s8, "FigS8_Consensus_Matrix_K8", kind="Supplementary")

# 15) TABLES
print(f"\n{'='*72}")
print("SAVING TABLES")
print("")


def _or_empty(df):
    """Return *df* when it has rows, otherwise a fresh empty DataFrame."""
    return df if len(df) else pd.DataFrame()


# Main Table 1 (same filename). Dict insertion order = sheet order in the workbook.
gene_coverage = pd.DataFrame([
    {"dataset": name, **info}
    for name, info in (("Lavaert", info_lav),
                       ("Le", info_le),
                       ("Blood_NEGCTRL", info_blood))
])
compartment_means = df_comp_ts.reset_index().rename(columns={comp_col: "compartment"})

table1_sheets = {
    "Module_Summary": df_module.sort_values("module"),
    "TopGenes": df_top.sort_values(["module", "rank"]),
    "GeneWeights": df_weights,
    "Enrichment_HyperGeom": df_enrich_full.sort_values(["module", "axis"]),
    "Stability_within_TS": df_stab_within,
    "Stability_external_K8": df_stab8,
    "Stability_ext_summary": df_stab8_summary,
    "NEGCTRL_stats": df_negctrl_stat,
    "NEGCTRL_summary": df_negctrl_summary.sort_values(["module", "dataset"]),
    "Gene_coverage": gene_coverage,
    "Compartment_means": compartment_means,
    "Compartment_enrichment": _or_empty(df_enrichment),
    "Transfer_compartment": _or_empty(df_transfer),
    "DIAGNOSTIC_gene_overlap": df_diag,
}
save_excel(sheets=table1_sheets, fname="Table1", kind="Main")

# Supplementary Table 3 (same filename): K-sweep and K8 matching detail.
table3_sheets = {
    "K_sweep_multi_criteria": df_sweep.sort_values("K"),
    "K8_matches_detail": df_stab8.sort_values(
        ["target_dataset", "corr"], ascending=[True, False]),
    "K8_summary_with_CI": df_stab8_summary,
}
save_excel(sheets=table3_sheets, fname="Table3", kind="Supplementary")

# 16) SAVE SCORED ARTIFACTS
print(f"\n{'='*72}")
print("SAVING SCORED ARTIFACTS")
print("")


def _write_scored(adata, W, out_path):
    """Write a copy of *adata* carrying MES scores and the dominant module.

    ``W`` supplies one row per cell of *adata* and one column per entry of
    ``mes_cols``. The copy is passed through ``sanitize_for_write`` and then
    written to *out_path*. Returns *out_path*.
    """
    scored = adata.copy()
    scored.obs[mes_cols] = W
    scored.obs["MES_dom"] = np.array(mes_cols)[np.argmax(W, axis=1)]
    scored = sanitize_for_write(scored)
    scored.write_h5ad(out_path)
    print(f"  [SAVED] {out_path}")
    return out_path


# TS: scores come from ts_scores, aligned to the full TS cell order.
out_ts = _write_scored(ts, ts_scores.loc[ts.obs_names, mes_cols].values,
                       AIM1_DIR / "TS_thymus_scored_K8.h5ad")
# External thymus cohorts: projected scores.
out_lav = _write_scored(lav, W_lav, AIM1_DIR / "Lavaert_projected_scored_K8.h5ad")
out_le = _write_scored(le, W_le, AIM1_DIR / "Le_projected_scored_K8.h5ad")
# Blood NEGCTRL (written to the negative-control directory).
out_blood = _write_scored(blood, W_blood,
                          NEG_DIR / "Blood_negctrl_projected_scored_K8.h5ad")

# FINAL REPORT
# Human-readable run summary. The NEGCTRL figure entry previously claimed
# "Violin + strip plots", but Fig S4 draws split violins with stars and no strip
# plot (the box + strip plot is in Fig S6).
# NOTE(review): "500 perms" is hard-coded text — confirm it matches the
# permutation count actually used for the external-stability null.
print("")
print("NB3 COMPLETED")
print("")
print(f"""
STATISTICAL UPGRADES APPLIED:
  ✓ Consensus NMF ({N_CONSENSUS} runs) with cophenetic correlation
  ✓ Multi-criteria K selection (cophenetic, silhouette, recon error, ext stability)
  ✓ Permutation null for external stability (500 perms, p-values)
  ✓ Bootstrap 95% CI on matched correlations
  ✓ Mann-Whitney U + BH correction for NEGCTRL (one-sided)
  ✓ Cohen's d effect sizes for NEGCTRL
  ✓ Wilcoxon rank-sum + BH for compartment enrichment (one-vs-rest)
  ✓ Hypergeometric enrichment for axis gene sets + BH correction
  ✓ Spearman ρ with p-values for compartment transfer

FIGURE:
  ✓ Separate Main Fig 1A/1B/1C/1D panels
  ✓ Hierarchically clustered heatmaps
  ✓ Split-violin plots for NEGCTRL with significance stars
  ✓ Effect size bars
  ✓ Permutation null histograms (new: Main Fig 1D)
  ✓ Enrichment dot plot (new: Fig S7)
  ✓ Consensus matrix visualisation (new: Fig S8)
  ✓ K-sweep multi-criteria panels (upgraded: Fig S3A)

OUTPUTS:
  Figures:
    - Main_Fig1A.png / Main_Fig1B.png / Main_Fig1C.png
    - Main_Fig1D.png  (NEW — permutation null)
    - Supplementary_Fig3A_MES_K_sensitivity
    - Supplementary_FigS4_NEGCTRL_Specificity
    - Supplementary_FigS5_ThymusCompartment_Enrichment
    - Supplementary_FigS6_Compartment_level_transfer_profile_correlation
    - Supplementary_FigS7_Module_Axis_Enrichment_DotPlot  (NEW)
    - Supplementary_FigS8_Consensus_Matrix_K8  (NEW)
  Tables:
    - Main_Table1.xlsx 
    - Supplementary_Table3.xlsx  (multi-criteria K-sweep)
  Scored artifacts:
    - {out_ts}
    - {out_lav}
    - {out_le}
    - {out_blood}
""")