# NB2 — Aim 1: Thymus Profiling, Negative Control & Innate Memory QC

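Single-cell QC and preprocessing of three thymus cohorts (Tabula Sapiens, Lavaert GSE144870, Le GSE139042) and a peripheral-blood myeloid negative control (Tabula Sapiens blood). The innate immune memory dataset (GSE229940) is only inventoried here (a file manifest is written; no QC is run on it), and the Zeng GSE133341 thymus files are excluded, with the justification recorded in the deferred-files table.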

**Paper:** Zafar SA, Qin W. *Thymus-Derived Myeloid Education Signatures Predict Microglial Tolerance Positioning and Are Modulated by Glucocorticoid Stress-Axis Activity.* Neuroimmunomodulation (2026).

> **Note:** Update the path variables in section 0 to match your local directory structure before running. Raw data can be obtained from the public repositories listed in Supplementary Table S1.


In [None]:
import warnings, re, json, gzip
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import scipy.sparse as sp
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

warnings.filterwarnings("ignore")
sc.settings.verbosity = 2
np.random.seed(42)  # global RNG seed; qc_snapshot() subsamples via np.random.choice


# 0) Paths

BASE_DIR = Path(".")  # <-- SET TO YOUR PROJECT ROOT
RAW_DIR  = BASE_DIR / "Raw Data"
PROC_DIR = BASE_DIR / "Process Data"

MANUSCRIPT_DIR = BASE_DIR / "outputs" / "manuscript"
FIG_DIR = MANUSCRIPT_DIR / "Figures"
TAB_DIR = MANUSCRIPT_DIR / "Tables"

OUT_AIM1     = PROC_DIR / "aim1_thymus"
OUT_NEGCTRL  = PROC_DIR / "negctrl"
OUT_IMM      = PROC_DIR / "innate_memory"

# Validate the input root BEFORE creating any output folders: with a wrong
# BASE_DIR the notebook should stop here instead of first creating empty
# "Process Data" / "outputs" trees under the wrong location.
if not RAW_DIR.exists():
    raise FileNotFoundError(f"RAW_DIR not found: {RAW_DIR} (check BASE_DIR)")

for d in [OUT_AIM1, OUT_NEGCTRL, OUT_IMM, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print("BASE_DIR:", BASE_DIR)
print("RAW_DIR :", RAW_DIR)
print("PROC_DIR:", PROC_DIR)
print("MANUSCRIPT_DIR:", MANUSCRIPT_DIR)
print("FIG_DIR :", FIG_DIR)
print("TAB_DIR :", TAB_DIR)


# 1) Figure style

# Manuscript-wide matplotlib defaults: Arial, 7-9 pt text, thin axis lines.
plt.rcParams.update({
    # typeface + base size
    "font.family": "Arial",
    "font.size": 8,
    # titles
    "figure.titlesize": 9,
    "axes.titlesize": 9,
    # axes
    "axes.labelsize": 8,
    "axes.linewidth": 0.8,
    # ticks + legend
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
})
FIG_DPI = 1200  # raster resolution passed to savefig by save_fig()

# Two-colour palettes (hex), index 0 then index 1.
PALETTE_STAGE = ["#3182bd", "#e6550d"]  # stage: pre-QC, post-QC
PALETTE_PASS  = ["#2ca25f", "#de2d26"]  # QC: pass, fail

def save_fig(fig, fname: str, kind: str = "Supplementary"):
    """
    Save *fig* as ``FIG_DIR/<kind>_<fname>.png`` (FIG_DPI, tight bbox), close it,
    and print the written path.

    *kind* must be "Main" or "Supplementary" (checked with assert). Closing the
    figure after saving keeps figures from piling up across the notebook.
    """
    assert kind in ("Main", "Supplementary")
    target = FIG_DIR.joinpath(f"{kind}_{fname}.png")
    fig.savefig(target, bbox_inches="tight", dpi=FIG_DPI)
    plt.close(fig)
    print(f"[SAVED FIG] {target}")

def save_table_xlsx(sheets: dict, fname: str, kind: str = "Supplementary"):
    """
    Write each DataFrame in *sheets* to its own sheet of
    ``TAB_DIR/<kind>_<fname>.xlsx`` (openpyxl engine, no index column).

    *kind* must be "Main" or "Supplementary" (checked with assert).

    Sheet names are made Excel-safe before writing:
      * the characters [ ] : * ? / \\ are replaced with "_";
      * names are truncated to 31 characters;
      * a name that collides with one already used in this workbook
        (compared case-insensitively) gets a "_<k>" suffix, so one sheet
        never silently writes over another's cells.
    Names that were already valid and unique are written unchanged.
    """
    assert kind in ("Main", "Supplementary")
    out = TAB_DIR / f"{kind}_{fname}.xlsx"
    used = set()
    with pd.ExcelWriter(out, engine="openpyxl") as w:
        for sn, df in sheets.items():
            base = re.sub(r"[\[\]:*?/\\]", "_", str(sn))[:31] or "Sheet"
            name, k = base, 1
            while name.lower() in used:
                suffix = f"_{k}"
                name = base[: 31 - len(suffix)] + suffix
                k += 1
            used.add(name.lower())
            df.to_excel(w, index=False, sheet_name=name)
    print(f"[SAVED TABLE] {out}")


# 2) Registry

# Required raw inputs for this notebook (dataset key -> file or directory).
REG = {
    "TS_Thymus_filtered":  RAW_DIR / "Thymus" / "TS_Thymus_filtered.h5ad",
    "GSE144870_Lavaert":   RAW_DIR / "Thymus" / "GSE144870",
    "GSE139042_Le":        RAW_DIR / "Thymus" / "GSE139042",
    "GSE133341_Zeng":      RAW_DIR / "Thymus" / "GSE133341",
    "TS_Blood_NegCtrl":    RAW_DIR / "Peripheral_Myeloid_NegCtrl" / "Tabula_Sapiens_Blood.h5ad",
    "GSE229940_dir":       RAW_DIR / "Innate_Immune_Memory" / "GSE229940",
}
# Check every entry and report ALL missing inputs in one error, instead of
# stopping at the first one and forcing a fix-rerun cycle per file.
_missing_inputs = [f"{k}: {p}" for k, p in REG.items() if not Path(p).exists()]
if _missing_inputs:
    raise FileNotFoundError("[MISSING] " + "\n  [MISSING] ".join(_missing_inputs))
print("[OK] Core NB2 inputs exist.")


# 3) Helpers

def normalize_gene_symbols(adata: ad.AnnData) -> ad.AnnData:
    """
    Canonicalise gene identifiers in place and return the same object.

    var_names are cast to str, stripped of surrounding whitespace and
    upper-cased; any duplicates this creates are then resolved by
    ``var_names_make_unique()``.
    """
    cleaned = adata.var_names.astype(str).str.strip().str.upper()
    adata.var_names = cleaned
    adata.var_names_make_unique()
    return adata

def add_mt_ribo_flags(adata: ad.AnnData) -> ad.AnnData:
    """
    Add boolean ``var["mt"]`` and ``var["ribo"]`` gene flags in place and
    return the same object.

    mt   : symbol starts with "MT-" (case-sensitive; this notebook upper-cases
           symbols with normalize_gene_symbols before calling this).
    ribo : ribosomal-protein genes. Starts from the RPS/RPL + digit prefix
           pattern, then
             * drops RPS6K* symbols (S6 kinase family, e.g. RPS6KA1, RPS6KB1),
               which share the prefix but are not ribosomal proteins;
             * drops prefix-sharing antisense/divergent lncRNA symbols ending
               in -AS<n>, -DT, -IT<n> or -OT<n> (e.g. RPL34-DT);
             * adds RPLP0, RPLP1, RPLP2 and RPSA, which the digit-based
               prefix pattern does not match.
    The ribo flag feeds pct_counts_ribo, which this notebook only reports.
    """
    adata.var_names = adata.var_names.astype(str)
    names = adata.var_names
    adata.var["mt"] = names.str.startswith("MT-")

    rp_prefix = names.str.match(r"^RP[SL]\d+")
    s6_kinase = names.str.startswith("RPS6K")
    # optional trailing "-<n>" tolerates var_names_make_unique() suffixes
    lnc_like = names.str.contains(r"-(?:AS|DT|IT|OT)\d*(?:-\d+)?$", regex=True)
    acidic_or_sa = names.isin(["RPLP0", "RPLP1", "RPLP2", "RPSA"])
    adata.var["ribo"] = (rp_prefix & ~s6_kinase & ~lnc_like) | acidic_or_sa
    return adata

def sanitize_for_write(adata: ad.AnnData) -> ad.AnnData:
    """
    Tidy obs/var so ``write_h5ad`` accepts them (in place; returns *adata*).

    For each of obs and var:
      * if the index has a name that is also a column name, the column is kept
        when it equals the index and renamed to "<name>_col" otherwise;
      * the index name is cleared;
      * column labels are cast to str.
    Finally obs_names and var_names are cast to str.
    """
    for attr in ("obs", "var"):
        frame = getattr(adata, attr)
        idx_name = frame.index.name
        if idx_name is not None:
            clashes = idx_name in frame.columns
            if clashes and not pd.Series(frame.index, index=frame.index).equals(frame[idx_name]):
                frame.rename(columns={idx_name: f"{idx_name}_col"}, inplace=True)
            frame.index.name = None
        frame.columns = frame.columns.astype(str)

    adata.obs_names = adata.obs_names.astype(str)
    adata.var_names = adata.var_names.astype(str)
    return adata

def standardize_obs(adata: ad.AnnData, dataset_id: str, cohort: str) -> ad.AnnData:
    """
    Stamp provenance onto every cell and tidy names (in place; returns *adata*).

    obs is replaced by a copy carrying constant ``dataset_id`` and ``cohort``
    columns (overwritten if already present). obs_names are then cast to str
    and var_names de-duplicated.
    """
    adata.obs = adata.obs.assign(dataset_id=dataset_id, cohort=cohort)
    adata.obs_names = adata.obs_names.astype(str)
    adata.var_names_make_unique()
    return adata


# 4) MAD-based adaptive QC

def median_abs_deviation(x: np.ndarray) -> float:
    """Return the unscaled MAD, ``median(|x - median(x)|)``, as a Python float."""
    center = np.median(x)
    deviations = np.abs(x - center)
    return float(np.median(deviations))

def mad_outlier_mask(values: np.ndarray, n_mads: float = 5.0,
                     direction: str = "both") -> np.ndarray:
    """
    Boolean mask of values more than ``n_mads`` unscaled MADs from the median.

    direction:
      'upper' -> values > median + n_mads * MAD
      'lower' -> values < median - n_mads * MAD
      anything else (default 'both') -> either condition.
    When the MAD is below 1e-6 (near-constant input) nothing is flagged,
    because both cut-offs would collapse onto the median.
    """
    center = np.median(values)
    spread = median_abs_deviation(values)
    if spread < 1e-6:
        return np.zeros(len(values), dtype=bool)

    above = values > center + n_mads * spread
    below = values < center - n_mads * spread
    if direction == "upper":
        return above
    if direction == "lower":
        return below
    return above | below

def _mad_upper_threshold(values: np.ndarray, n_mads: float) -> float:
    """
    Raw-scale cut-off applied by ``mad_outlier_mask(np.log1p(values), n_mads,
    direction="upper")``, i.e. expm1(median + n_mads * MAD) on the log1p scale.

    Returns NaN when that mask cannot flag anything: empty input, or a MAD
    below 1e-6 (or NaN). In those cases no adaptive cut was applied, and the
    returned value must not suggest that one was.
    """
    if len(values) == 0:
        return float("nan")
    logv = np.log1p(values)
    mad = median_abs_deviation(logv)
    if not mad >= 1e-6:  # mirrors mad_outlier_mask's `mad < 1e-6` bail-out; also catches NaN
        return float("nan")
    return float(np.expm1(np.median(logv) + n_mads * mad))


def qc_filter_mad(
    adata: ad.AnnData,
    min_genes: int = 200,
    min_cells_per_gene: int = 3,
    min_counts: int = 500,
    max_mt_pct: float = 20.0,
    n_mads_counts: float = 5.0,
    n_mads_genes: float = 5.0,
    max_counts_hard: Optional[int] = None,
) -> tuple:
    """
    MAD-based adaptive cell QC followed by gene filtering.

    A cell is removed when ANY of these holds (each is stored as a boolean
    obs column on the input):
      outlier_counts_hi   log1p(total_counts) > median + n_mads_counts * MAD
      outlier_genes_hi    log1p(n_genes)      > median + n_mads_genes  * MAD
      outlier_counts_lo   total_counts < min_counts
      outlier_genes_lo    n_genes_by_counts < min_genes
      outlier_mt          pct_counts_mt > max_mt_pct
      outlier_counts_hard total_counts > max_counts_hard (only if given)
    MAD is unscaled (see median_abs_deviation). When it is ~0 the adaptive
    upper bound is disabled (mad_outlier_mask flags nothing), and the
    matching effective_max_* stat is reported as NaN.

    Side effects on the INPUT adata: adds layers["counts"] (copy of X) when
    absent, adds var mt/ribo flags when either is missing, writes scanpy QC
    metrics (computed on X), the outlier_* columns and obs["qc_pass"].

    Returns
    -------
    (adata_filtered, qc_stats)
        adata_filtered : copy of the passing cells, without genes detected in
                         fewer than ``min_cells_per_gene`` of them.
        qc_stats       : dict with before/after sizes, medians of the kept
                         cells, per-criterion outlier counts and thresholds.
    """
    n0, g0 = adata.n_obs, adata.n_vars
    if "counts" not in adata.layers:
        adata.layers["counts"] = adata.X.copy()

    # calculate_qc_metrics below needs BOTH flags; add them if either is missing.
    if "mt" not in adata.var.columns or "ribo" not in adata.var.columns:
        adata = add_mt_ribo_flags(adata)

    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mt", "ribo"], percent_top=None, log1p=False, inplace=True
    )

    counts = adata.obs["total_counts"].values.astype(float)
    genes  = adata.obs["n_genes_by_counts"].values.astype(float)
    mt_pct = adata.obs["pct_counts_mt"].values.astype(float)

    # Adaptive upper-tail outliers on the log1p scale; fixed floors for the low end.
    outlier_counts_hi = mad_outlier_mask(np.log1p(counts), n_mads=n_mads_counts, direction="upper")
    outlier_genes_hi  = mad_outlier_mask(np.log1p(genes),  n_mads=n_mads_genes,  direction="upper")
    outlier_genes_lo  = genes < min_genes
    outlier_counts_lo = counts < min_counts
    outlier_mt        = mt_pct > max_mt_pct
    if max_counts_hard is not None:
        outlier_counts_hard = counts > max_counts_hard
    else:
        outlier_counts_hard = np.zeros(len(counts), dtype=bool)

    # Raw-scale thresholds actually applied by the adaptive masks (NaN if disabled).
    upper_count_thresh = _mad_upper_threshold(counts, n_mads_counts)
    upper_gene_thresh = _mad_upper_threshold(genes, n_mads_genes)

    # Per-cell QC flags (every criterion, including the hard cap, is recorded).
    adata.obs["outlier_counts_hi"]   = outlier_counts_hi
    adata.obs["outlier_counts_lo"]   = outlier_counts_lo
    adata.obs["outlier_counts_hard"] = outlier_counts_hard
    adata.obs["outlier_genes_hi"]    = outlier_genes_hi
    adata.obs["outlier_genes_lo"]    = outlier_genes_lo
    adata.obs["outlier_mt"]          = outlier_mt

    keep = ~(outlier_counts_hi | outlier_counts_lo | outlier_counts_hard |
             outlier_genes_hi | outlier_genes_lo | outlier_mt)
    adata.obs["qc_pass"] = keep

    adata_f = adata[keep].copy()
    sc.pp.filter_genes(adata_f, min_cells=min_cells_per_gene)

    n1, g1 = adata_f.n_obs, adata_f.n_vars

    qc_stats = {
        "n_cells_before": int(n0),
        "n_cells_after": int(n1),
        "pct_cells_kept": round(float(n1 / max(n0, 1) * 100.0), 2),
        "n_genes_before": int(g0),
        "n_genes_after": int(g1),
        "median_genes": round(float(np.median(adata_f.obs["n_genes_by_counts"])), 1) if n1 else np.nan,
        "median_umi": round(float(np.median(adata_f.obs["total_counts"])), 1) if n1 else np.nan,
        "median_mt_pct": round(float(np.median(adata_f.obs["pct_counts_mt"])), 2) if n1 else np.nan,
        "median_ribo_pct": round(float(np.median(adata_f.obs["pct_counts_ribo"])), 2) if n1 and "pct_counts_ribo" in adata_f.obs.columns else np.nan,
        "n_outlier_counts_hi": int(outlier_counts_hi.sum()),
        "n_outlier_counts_lo": int(outlier_counts_lo.sum()),
        "n_outlier_genes_hi": int(outlier_genes_hi.sum()),
        "n_outlier_genes_lo": int(outlier_genes_lo.sum()),
        "n_outlier_mt": int(outlier_mt.sum()),
        "n_outlier_counts_hard": int(outlier_counts_hard.sum()),
        "effective_max_counts": round(upper_count_thresh, 0),
        "effective_max_genes": round(upper_gene_thresh, 0),
        "n_mads_counts": float(n_mads_counts),
        "n_mads_genes": float(n_mads_genes),
        "min_genes": int(min_genes),
        "min_counts": int(min_counts),
        "max_mt_pct": float(max_mt_pct),
        "max_counts_hard": float(max_counts_hard) if max_counts_hard is not None else np.nan,
        "qc_method": "MAD-based adaptive",
    }
    return adata_f, qc_stats


# 5) Doublet detection (Scrublet)

def run_scrublet(adata: ad.AnnData, expected_doublet_rate: float = 0.06,
                 seed: int = 42) -> ad.AnnData:
    """
    Score doublets with Scrublet on one sample (in place; returns *adata*).

    Writes obs["doublet_score"] (float32, NaN when not scored) and
    obs["predicted_doublet"] (bool, False when not called). Uses
    layers["counts"] when present, otherwise X, converted to CSR.

    Not scored (score NaN, prediction False):
      * scrublet cannot be imported;
      * fewer than 50 cells;
      * Scrublet raises (the error is printed).
    Scored but not called (scores kept, prediction False, warning printed):
      * Scrublet returns predictions of None. This happens when it cannot
        choose a doublet threshold automatically. The computed scores are
        still valid and are kept rather than discarded.
    """
    try:
        import scrublet as scr
    except ImportError:
        print("  [WARN] scrublet not installed — skipping doublet detection")
        adata.obs["doublet_score"] = np.nan
        adata.obs["predicted_doublet"] = False
        return adata

    if adata.n_obs < 50:
        adata.obs["doublet_score"] = np.nan
        adata.obs["predicted_doublet"] = False
        return adata

    X = adata.layers["counts"] if "counts" in adata.layers else adata.X
    X = X.tocsr() if sp.issparse(X) else sp.csr_matrix(X)

    try:
        scrub = scr.Scrublet(X, expected_doublet_rate=expected_doublet_rate,
                             random_state=seed)
        scores, preds = scrub.scrub_doublets(min_counts=2, min_cells=3,
                                             min_gene_variability_pctl=85,
                                             n_prin_comps=min(30, adata.n_vars - 1),
                                             verbose=False)
    except Exception as e:
        print(f"  [WARN] Scrublet failed: {e}")
        adata.obs["doublet_score"] = np.nan
        adata.obs["predicted_doublet"] = False
        return adata

    adata.obs["doublet_score"] = np.asarray(scores, dtype=np.float32)
    if preds is None:
        print("  [WARN] Scrublet could not set a doublet threshold automatically; "
              "scores kept, no cells called as doublets")
        adata.obs["predicted_doublet"] = False
    else:
        adata.obs["predicted_doublet"] = np.asarray(preds, dtype=bool)
    return adata

def run_scrublet_per_sample(adata: ad.AnnData, sample_col: str = "sample_id",
                            expected_doublet_rate: float = 0.06) -> ad.AnnData:
    """
    Score doublets separately for each value of ``obs[sample_col]``.

    obs["doublet_score"] / obs["predicted_doublet"] are first reset to
    NaN / False. Each sample is then scored with run_scrublet on its own copy,
    and the results are written back. Cells whose sample value never compares
    equal to itself (e.g. NaN) therefore stay unscored. Without *sample_col*
    the whole object is scored as one sample. Returns *adata*.
    """
    adata.obs["doublet_score"] = np.nan
    adata.obs["predicted_doublet"] = False

    if sample_col not in adata.obs.columns:
        return run_scrublet(adata, expected_doublet_rate=expected_doublet_rate)

    for sample in adata.obs[sample_col].unique():
        in_sample = adata.obs[sample_col] == sample
        scored = run_scrublet(adata[in_sample].copy(),
                              expected_doublet_rate=expected_doublet_rate)
        for column in ("doublet_score", "predicted_doublet"):
            adata.obs.loc[in_sample, column] = scored.obs[column].values

    return adata


# 6) QC snapshot

def qc_snapshot(adata: ad.AnnData, dataset_id: str, stage: str,
                max_cells: int = 20000) -> pd.DataFrame:
    """
    Return per-cell QC metrics (tagged with dataset_id and stage) for plotting.

    Ensures mt/ribo flags and scanpy QC metrics exist, computing them on
    *adata* in place if missing. When there are more than *max_cells* cells,
    a random subset of that size is drawn with the global np.random state.
    The columns are total_counts, n_genes_by_counts, pct_counts_mt, plus
    pct_counts_ribo and doublet_score when present.
    """
    if "mt" not in adata.var.columns:
        adata = add_mt_ribo_flags(adata)
    if "total_counts" not in adata.obs.columns:
        sc.pp.calculate_qc_metrics(adata, qc_vars=["mt", "ribo"],
                                   percent_top=None, log1p=False, inplace=True)

    obs = adata.obs
    if adata.n_obs > max_cells:
        picked = np.random.choice(np.arange(adata.n_obs), size=max_cells, replace=False)
        obs = obs.iloc[picked]

    extras = [c for c in ("pct_counts_ribo", "doublet_score") if c in obs.columns]
    snap = obs[["total_counts", "n_genes_by_counts", "pct_counts_mt", *extras]].copy()
    snap["dataset_id"] = dataset_id
    snap["stage"] = stage
    return snap

def per_sample_qc_summary(adata: ad.AnnData, dataset_id: str,
                          sample_col: str = "sample_id") -> pd.DataFrame:
    """
    One summary row per (sorted) value of ``obs[sample_col]``.

    Each row holds dataset_id, sample_id, n_cells and the medians of
    n_genes_by_counts, total_counts and pct_counts_mt (these obs columns must
    exist). Where the columns are present it also holds median
    pct_counts_ribo, median non-NaN doublet_score, and the number and
    percentage of predicted_doublet cells. Without *sample_col* the whole
    object becomes a single row whose sample_id is *dataset_id*.
    """
    obs = adata.obs
    if sample_col in obs.columns:
        groups = [(label, obs[sample_col] == label)
                  for label in sorted(obs[sample_col].unique())]
    else:
        groups = [(dataset_id, np.ones(adata.n_obs, dtype=bool))]

    rows = []
    for label, mask in groups:
        part = obs.loc[mask]
        n_cells = int(mask.sum())
        row = {
            "dataset_id": dataset_id,
            "sample_id": label,
            "n_cells": n_cells,
            "median_genes": round(float(part["n_genes_by_counts"].median()), 1),
            "median_umi": round(float(part["total_counts"].median()), 1),
            "median_mt_pct": round(float(part["pct_counts_mt"].median()), 2),
        }
        if "pct_counts_ribo" in part.columns:
            row["median_ribo_pct"] = round(float(part["pct_counts_ribo"].median()), 2)
        if "doublet_score" in part.columns:
            scored = part["doublet_score"].dropna()
            row["median_doublet_score"] = round(float(scored.median()), 4) if len(scored) else np.nan
        if "predicted_doublet" in part.columns:
            n_doub = int(part["predicted_doublet"].sum())
            row["n_doublets"] = n_doub
            row["doublet_rate_pct"] = round(float(n_doub / max(n_cells, 1) * 100), 2)
        rows.append(row)

    return pd.DataFrame(rows)


# 7) Lavaert robust discovery + loaders

def read_tsv_any(path: Path) -> pd.Series:
    """
    Return the first column of a header-less TSV (plain or ``.gz``) as strings.

    Values are read verbatim as text (``dtype=str``, ``keep_default_na=False``).
    Barcodes and IDs are therefore never type-inferred: "0001" keeps its
    leading zeros, and tokens such as "NA" or "NULL" are not turned into
    NaN and then into the string "nan".
    """
    compression = "gzip" if path.suffix.lower() == ".gz" else "infer"
    df = pd.read_csv(path, sep="\t", header=None, compression=compression,
                     dtype=str, keep_default_na=False)
    return df.iloc[:, 0].astype(str)

def read_features_any(path: Path) -> pd.DataFrame:
    """
    Read a header-less features/genes TSV (plain or ``.gz``) as all-string columns.

    Like read_tsv_any, values are kept verbatim (``dtype=str``,
    ``keep_default_na=False``), so identifiers such as "NA" are not parsed
    as missing values and numeric-looking IDs are not reformatted.
    """
    compression = "gzip" if path.suffix.lower() == ".gz" else "infer"
    return pd.read_csv(path, sep="\t", header=None, compression=compression,
                       dtype=str, keep_default_na=False)

def _sniff_encoding(path: Path, nbytes: int = 4096) -> str:
    b = path.read_bytes()[:nbytes]
    if b.count(b"\x00") > 100:
        return "utf-16"
    return "utf-8-sig"

def read_mtx_any(path: Path) -> sp.csr_matrix:
    """
    Read a MatrixMarket coordinate file (plain text or ``.gz``) into CSR.

    Plain files are decoded with the encoding picked by _sniff_encoding.
    Gzip files are decoded as UTF-8 with undecodable bytes ignored.
    Lines starting with '%' are comments. The first data row is the
    "nrows ncols nnz" size line; the rest are 1-based (row, col[, value])
    triplets, converted here to 0-based indices.

    ``pattern`` matrices have no value column, and the parser pads the
    missing field with NaN. They are loaded with every stored entry equal
    to 1 (previously those entries became NaN). Duplicate (row, col) entries
    are summed by the COO -> CSR conversion. A warning is printed when the
    number of triplets read differs from the declared nnz (e.g. a truncated
    download).
    """
    read_kw = dict(sep=r"\s+", header=None, comment="%", engine="python")
    if path.suffix.lower() == ".gz":
        # Stream the decompressed text into the parser instead of first
        # materialising the whole file as one Python string.
        with gzip.open(path, "rt", encoding="utf-8", errors="ignore") as fh:
            df = pd.read_csv(fh, **read_kw)
    else:
        df = pd.read_csv(path, encoding=_sniff_encoding(path), **read_kw)
    df = df.dropna(how="all")

    nrows, ncols, nnz = (int(v) for v in df.iloc[0, :3])
    trip = df.iloc[1:]
    i = trip.iloc[:, 0].astype(np.int64).values - 1
    j = trip.iloc[:, 1].astype(np.int64).values - 1

    vals = trip.iloc[:, 2] if trip.shape[1] >= 3 else None
    if vals is None or vals.isna().all():
        x = np.ones(len(trip), dtype=np.float32)
    else:
        x = vals.astype(np.float32).values

    if len(trip) != nnz:
        print(f"  [WARN] {path.name}: header declares nnz={nnz} "
              f"but {len(trip)} entries were read")
    return sp.coo_matrix((x, (i, j)), shape=(nrows, ncols)).tocsr()

def find_10x_matrix_sets(root: Path):
    """
    Recursively find 10x-style (matrix, barcodes, features) file triples under *root*.

    Two layouts are recognised:
      * folder layout: ``matrix.mtx[.gz]`` with ``barcodes.tsv[.gz]`` and
        ``features.tsv[.gz]`` (or ``genes.tsv[.gz]``) in the same folder;
        prefix = folder name.
      * flat layout: ``<prefix>_matrix.mtx[.gz]`` next to
        ``<prefix>_barcodes.tsv[.gz]`` and ``<prefix>_features|genes.tsv[.gz]``.
    Matrices missing either companion are skipped. Each set is a dict with
    keys matrix, barcodes, features, folder, prefix. Sets are de-duplicated
    by resolved matrix path and returned SORTED by that path, so sample
    order (and everything downstream of it) does not depend on filesystem
    enumeration order.
    """
    def _first_existing(candidates):
        # first candidate path that exists, else None (candidates are in priority order)
        return next((p for p in candidates if p.exists()), None)

    sets = []
    # standard 10x folders
    for m in root.rglob("matrix.mtx*"):
        folder = m.parent
        bar = _first_existing([folder / "barcodes.tsv", folder / "barcodes.tsv.gz"])
        feat = _first_existing([folder / "features.tsv", folder / "features.tsv.gz",
                                folder / "genes.tsv", folder / "genes.tsv.gz"])
        if bar and feat:
            sets.append({"matrix": m, "barcodes": bar, "features": feat,
                         "folder": folder, "prefix": folder.name})
    # flat triples
    for m in root.rglob("*_matrix.mtx*"):
        folder = m.parent
        pref = re.sub(r"_matrix\.mtx(\.gz)?$", "", m.name)
        bar = _first_existing([folder / f"{pref}_barcodes.tsv",
                               folder / f"{pref}_barcodes.tsv.gz"])
        feat = _first_existing([folder / f"{pref}_features.tsv", folder / f"{pref}_features.tsv.gz",
                                folder / f"{pref}_genes.tsv", folder / f"{pref}_genes.tsv.gz"])
        if bar and feat:
            sets.append({"matrix": m, "barcodes": bar, "features": feat,
                         "folder": folder, "prefix": pref})

    uniq = {}
    for s in sets:
        uniq[str(Path(s["matrix"]).resolve())] = s
    return [uniq[key] for key in sorted(uniq)]

def load_10x_set(s):
    """
    Build a cells x genes AnnData from one set returned by find_10x_matrix_sets.

    Gene names come from the second column (index 1) of the features file
    when it has at least two columns, otherwise from its first column. The
    matrix is transposed when stored genes x cells, and kept as-is when it is
    already cells x genes. For a square matrix the genes x cells reading is
    chosen. var_names are de-duplicated.

    Returns
    -------
    (adata, dbg) : the AnnData and a dict of paths, shapes and the
                   orientation decision, for the loader debug table.

    Raises
    ------
    ValueError
        If the matrix shape matches neither (genes, cells) nor (cells, genes)
        as counted from the features/barcodes files. The message carries
        both shapes, so the caller's deferred log records the actual cause
        instead of a generic AnnData dimension error.
    """
    obs_names = read_tsv_any(s["barcodes"]).values
    gf = read_features_any(s["features"])
    name_col = 1 if gf.shape[1] >= 2 else 0
    var_names = gf.iloc[:, name_col].astype(str).values

    n_cells = len(obs_names)
    n_genes = len(var_names)

    M = read_mtx_any(s["matrix"])
    sh = M.shape

    if sh == (n_genes, n_cells):
        X = M.T.tocsr()
        orient = "genes_x_cells__transposed"
    elif sh == (n_cells, n_genes):
        X = M.tocsr()
        orient = "cells_x_genes__kept"
    else:
        raise ValueError(
            f"{s['matrix']}: matrix shape {sh} matches neither "
            f"(genes={n_genes}, cells={n_cells}) nor its transpose; "
            "check the barcodes/features companions (header row? wrong file?)"
        )

    adata = ad.AnnData(X=X, obs=pd.DataFrame(index=obs_names),
                       var=pd.DataFrame(index=var_names))
    adata.var_names_make_unique()
    dbg = {
        "prefix": s["prefix"],
        "folder": str(s["folder"]),
        "matrix": str(s["matrix"]),
        "barcodes": str(s["barcodes"]),
        "features": str(s["features"]),
        "mtx_shape": str(sh),
        "n_cells_from_barcodes": int(n_cells),
        "n_genes_from_features": int(n_genes),
        "orientation": orient,
    }
    return adata, dbg


# 8) NEGCTRL config auto-lock

# JSON file that locks the blood cell-type column choice across re-runs.
CONFIG_PATH = OUT_NEGCTRL / "config_negctrl.json"

# obs columns considered as blood cell-type labels; earlier entries win ties
# in pick_best_celltype_col.
CELLTYPE_COL_CANDIDATES = [
    "cell_type",
    "cell_type_label",
    "celltype",
    "annotation",
    "annotation_fine",
    "annotation_coarse",
]

def pick_best_celltype_col(obs: pd.DataFrame) -> Optional[str]:
    """
    Choose the candidate obs column whose labels mention myeloid cells most often.

    Every column of CELLTYPE_COL_CANDIDATES present in *obs* is scored by how
    many lower-cased labels match the myeloid regex below. The top scorer is
    returned, with the earliest candidate winning ties, even when all scores
    are 0. Returns None when no candidate column exists.
    """
    myeloid_re = r"monocyte|macroph|myeloid|dc\b|dendritic"
    scores = {}
    for column in CELLTYPE_COL_CANDIDATES:
        if column in obs.columns:
            labels = obs[column].astype(str).str.lower()
            scores[column] = int(labels.str.contains(myeloid_re, regex=True).sum())
    if not scores:
        return None
    # max() keeps the first maximal key in insertion (= candidate) order
    return max(scores, key=scores.get)

def load_or_set_negctrl_config(adata_obs: pd.DataFrame) -> dict:
    """
    Return the locked negative-control config, creating it on first use.

    If CONFIG_PATH exists, its JSON is returned unchanged, so the column
    choice stays fixed across re-runs. Otherwise BLOOD_CELLTYPE_COL is picked
    with pick_best_celltype_col, saved to CONFIG_PATH, announced and returned.
    """
    if not CONFIG_PATH.exists():
        cfg = {"BLOOD_CELLTYPE_COL": pick_best_celltype_col(adata_obs)}
        CONFIG_PATH.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
        print(f"[NEGCTRL CONFIG SAVED] {CONFIG_PATH} -> {cfg}")
        return cfg
    return json.loads(CONFIG_PATH.read_text(encoding="utf-8"))

def subset_blood_myeloid(adata: ad.AnnData, col: Optional[str]) -> tuple:
    """
    Keep monocyte/macrophage-labelled blood cells using ``obs[col]``.

    A strict label pattern is tried first. If it keeps fewer than 50 cells,
    a broader pattern that also admits myeloid/dendritic labels is used.
    Returns (subset_copy, flag). When *col* is None or absent, or nothing
    matches, the ORIGINAL object is returned unfiltered with a flag starting
    with "UNFILTERED".
    """
    if col is None or col not in adata.obs.columns:
        return adata, "UNFILTERED_no_celltype_col_found"

    labels = adata.obs[col].astype(str).str.lower()
    strict = labels.str.contains(r"monocyte|mono\b|macroph", regex=True)
    if strict.sum() >= 50:
        keep = strict
    else:
        keep = labels.str.contains(r"myeloid|monocyte|mono\b|macroph|dendritic|dc\b", regex=True)

    if not keep.any():
        return adata, f"UNFILTERED_fixed_col={col}_no_match"
    return adata[keep].copy(), f"FILTERED_fixed_col={col}"


# 9) Batch diagnostic PCA

def batch_diagnostic_pca(adata: ad.AnnData, dataset_id: str,
                         sample_col: str = "sample_id",
                         n_hvg: int = 2000, n_pcs: int = 30):
    """
    Build a throw-away HVG -> PCA -> kNN -> UMAP embedding for batch checks.

    Works on a copy. X is reset to layers["counts"] when that layer exists,
    then library-size normalised (target 1e4) and log1p-transformed. Returns
    the HVG-restricted copy (with PCA/UMAP computed), or None in these cases:
      * fewer than 100 cells;
      * fewer than 200 usable HVGs;
      * any embedding step raises (a warning is printed).

    HVG selection: flavor="seurat_v3" models raw counts. It is therefore run
    on layers["counts"] when that layer exists, instead of on the log1p data
    in X. Without a counts layer, the log-normalised X is used with
    flavor="seurat", which expects log data.

    sample_col is not used inside this function; it is kept for interface
    compatibility.
    """
    if adata.n_obs < 100:
        return None

    tmp = adata.copy()
    has_counts = "counts" in tmp.layers
    if has_counts:
        tmp.X = tmp.layers["counts"].copy()
    sc.pp.normalize_total(tmp, target_sum=1e4)
    sc.pp.log1p(tmp)

    n_hvg_eff = min(n_hvg, tmp.n_vars - 1)
    if n_hvg_eff < 200:
        return None

    try:
        if has_counts:
            sc.pp.highly_variable_genes(tmp, n_top_genes=n_hvg_eff,
                                        flavor="seurat_v3", layer="counts")
        else:
            sc.pp.highly_variable_genes(tmp, n_top_genes=n_hvg_eff, flavor="seurat")
        tmp = tmp[:, tmp.var["highly_variable"]].copy()
        sc.tl.pca(tmp, n_comps=min(n_pcs, tmp.n_vars - 1), svd_solver="arpack")
        sc.pp.neighbors(tmp, n_neighbors=15, n_pcs=min(n_pcs, tmp.n_vars - 1))
        sc.tl.umap(tmp)
    except Exception as e:
        print(f"  [WARN] Batch PCA failed for {dataset_id}: {e}")
        return None

    return tmp


# 10) Run loaders

# Input locations, taken from the registry.
TS_H5AD = Path(REG["TS_Thymus_filtered"])
LAVAERT_DIR = Path(REG["GSE144870_Lavaert"])
LE_DIR = Path(REG["GSE139042_Le"])
ZENG_DIR = Path(REG["GSE133341_Zeng"])
TS_BLOOD = Path(REG["TS_Blood_NegCtrl"])

# Accumulators filled by the loaders below.
qc_rows = []               # one QC-stats dict per sample (or an error row)
deferred = []              # skipped files with the reason
adata_dict = {}            # thymus dataset_id -> QC'd AnnData
adata_dict_neg = {}        # negative-control id -> QC'd AnnData
qc_snaps = []              # per-cell pre/post QC frames (qc_snapshot)
per_sample_summaries = []  # per_sample_qc_summary frames
lava_dbg_rows = []         # Lavaert loader debug dicts

print("\nAIM 1 / THYMUS: LOADING + QC")

# TS thymus
# Pipeline: read -> normalise symbols -> mt/ribo flags -> provenance columns
# -> "pre" snapshot (computes QC metrics on X in place; includes doublets)
# -> per-sample Scrublet -> drop doublets -> MAD QC -> "post" snapshot.
# Statement order matters: qc_snapshot draws its subsample from the global
# np.random state, so reordering calls changes which cells are plotted.
print("\nTS Thymus")
adata_ts = sc.read_h5ad(TS_H5AD)
adata_ts = normalize_gene_symbols(adata_ts)
adata_ts = add_mt_ribo_flags(adata_ts)
adata_ts = standardize_obs(adata_ts, "TS_Thymus_filtered", "thymus")
qc_snaps.append(qc_snapshot(adata_ts, "TS_Thymus_filtered", "pre"))

# Doublet detection
# NOTE(review): run_scrublet_per_sample groups by obs["sample_id"]; if the TS
# h5ad has no such column, all cells are scored as ONE sample — confirm the
# donor/sample column name in this file.
print("  Running Scrublet ...")
adata_ts = run_scrublet_per_sample(adata_ts, expected_doublet_rate=0.06)
n_doub_ts = int(adata_ts.obs["predicted_doublet"].sum())
print(f"  Doublets detected: {n_doub_ts} / {adata_ts.n_obs} "
      f"({n_doub_ts/max(adata_ts.n_obs,1)*100:.1f}%)")

# Remove doublets before QC
adata_ts = adata_ts[~adata_ts.obs["predicted_doublet"]].copy()

# MAD-based QC
adata_ts, st = qc_filter_mad(adata_ts, max_mt_pct=20.0, max_counts_hard=60000,
                              n_mads_counts=5.0, n_mads_genes=5.0)
qc_snaps.append(qc_snapshot(adata_ts, "TS_Thymus_filtered", "post"))
st.update({"dataset_id": "TS_Thymus_filtered", "source": "TabulaSapiens_filtered",
           "loader": "read_h5ad", "sample_id": "TS",
           "n_doublets_removed": n_doub_ts})
qc_rows.append(st)
# NOTE: this summary is built after doublet removal, so its n_doublets /
# doublet_rate_pct columns are 0 by construction; the number removed is in
# the qc_rows entry ("n_doublets_removed").
per_sample_summaries.append(per_sample_qc_summary(adata_ts, "TS_Thymus_filtered"))
adata_dict["TS_Thymus_filtered"] = adata_ts
print(f"  TS_Thymus_filtered -> {adata_ts.n_obs:,} cells, {adata_ts.n_vars:,} genes")

# Lavaert
print("\nLavaert discovery (recursive):", LAVAERT_DIR)
lava_sets = find_10x_matrix_sets(LAVAERT_DIR)
print(f"  Found {len(lava_sets)} candidate 10x matrix sets under GSE144870.")
for s in lava_sets[:8]:
    print("   -", s["matrix"])

# Each discovered set is one sample. Steps: load -> flags/provenance ->
# "pre" snapshot -> Scrublet -> drop doublets -> MAD QC -> "post" snapshot.
# A failure is recorded in `deferred` and as an error row in `qc_rows`, and
# the loop continues with the next set.
# NOTE(review): sample_id is the set's prefix. In the folder layout that is
# the folder name, so samples stored in identically named folders would share
# a sample_id — confirm against the GSE144870 directory layout.
lava_adatas = []
for s in lava_sets:
    try:
        adata, dbg = load_10x_set(s)
        lava_dbg_rows.append(dbg)

        adata = normalize_gene_symbols(adata)
        adata.obs["sample_id"] = dbg["prefix"]
        adata = add_mt_ribo_flags(adata)
        adata = standardize_obs(adata, "GSE144870_Lavaert", "thymus")

        qc_snaps.append(qc_snapshot(adata, "GSE144870_Lavaert", "pre"))

        # Doublet detection per sample
        adata = run_scrublet(adata, expected_doublet_rate=0.06)
        n_d = int(adata.obs["predicted_doublet"].sum())
        adata = adata[~adata.obs["predicted_doublet"]].copy()

        adata, st = qc_filter_mad(adata, min_genes=150, min_counts=300,
                                   max_mt_pct=20.0, n_mads_counts=5.0,
                                   n_mads_genes=5.0)
        qc_snaps.append(qc_snapshot(adata, "GSE144870_Lavaert", "post"))

        st.update({"dataset_id": "GSE144870_Lavaert", "source": "GEO",
                   "loader": "find_10x_matrix_sets", "sample_id": dbg["prefix"],
                   "n_doublets_removed": n_d})
        qc_rows.append(st)

        if adata.n_obs > 0:
            lava_adatas.append(adata)
        else:
            deferred.append({"file": dbg["matrix"], "reason": "0 cells after QC"})
    except Exception as e:
        deferred.append({"file": str(s["matrix"]), "reason": str(e)})
        qc_rows.append({"dataset_id": "GSE144870_Lavaert", "source": "GEO",
                        "loader": "find_10x_matrix_sets",
                        "sample_id": s.get("prefix", ""), "error": str(e)})

df_lava_dbg = pd.DataFrame(lava_dbg_rows)

if lava_adatas:
    # Outer join on genes; per-gene var columns are kept only where identical.
    adata_lava = ad.concat(lava_adatas, join="outer", merge="same")
    adata_lava = standardize_obs(adata_lava, "GSE144870_Lavaert", "thymus")
    # ad.concat keeps each sample's barcodes unchanged, and barcodes are not
    # unique across samples. De-duplicate obs_names so name-based indexing
    # stays unambiguous: only repeated names get a "-N" suffix, and
    # provenance remains in obs["sample_id"].
    adata_lava.obs_names_make_unique()
    per_sample_summaries.append(per_sample_qc_summary(adata_lava, "GSE144870_Lavaert"))
    adata_dict["GSE144870_Lavaert"] = adata_lava
    print(f"  Lavaert merged -> {adata_lava.n_obs:,} cells, {adata_lava.n_vars:,} genes")
else:
    print("  [WARN] Lavaert still not loaded. See Lavaert_Debug.")

# Le
print("\nLe GSE139042")
# Dense per-sample count tables: rows = genes (first column is the gene
# index), columns = cells. Both .txt (tab) and .csv (comma) files are taken,
# in one sorted order.
le_files = sorted([*LE_DIR.glob("GSM*Thymus*counts*.txt"),
                   *LE_DIR.glob("GSM*Thymus*counts*.csv")])
le_loaded = []
for f in le_files:
    try:
        sep = "\t" if f.suffix.lower() == ".txt" else ","
        df = pd.read_csv(f, sep=sep, index_col=0)
        # Non-numeric entries become 0 so the matrix can be cast to float32.
        df = df.apply(pd.to_numeric, errors="coerce").fillna(0.0)

        X = sp.csr_matrix(df.T.values.astype(np.float32))
        adata = ad.AnnData(X=X, obs=pd.DataFrame(index=df.columns.astype(str)),
                           var=pd.DataFrame(index=df.index.astype(str)))
        adata = normalize_gene_symbols(adata)
        adata.obs["sample_id"] = f.stem
        adata = add_mt_ribo_flags(adata)
        adata = standardize_obs(adata, "GSE139042_Le", "thymus")

        qc_snaps.append(qc_snapshot(adata, "GSE139042_Le", "pre"))

        # Doublet detection
        adata = run_scrublet(adata, expected_doublet_rate=0.06)
        n_d = int(adata.obs["predicted_doublet"].sum())
        adata = adata[~adata.obs["predicted_doublet"]].copy()

        adata, st = qc_filter_mad(adata, max_mt_pct=20.0, n_mads_counts=5.0,
                                   n_mads_genes=5.0)
        qc_snaps.append(qc_snapshot(adata, "GSE139042_Le", "post"))

        st.update({"dataset_id": "GSE139042_Le", "source": "GEO",
                   "loader": "txt/csv_dense_to_sparse_numeric", "sample_id": f.stem,
                   "n_doublets_removed": n_d})
        qc_rows.append(st)

        if adata.n_obs > 0:
            le_loaded.append(adata)
        else:
            deferred.append({"file": str(f), "reason": "0 cells after QC"})
    except Exception as e:
        deferred.append({"file": str(f), "reason": str(e)})
        qc_rows.append({"dataset_id": "GSE139042_Le", "source": "GEO",
                        "loader": "txt/csv_dense_to_sparse_numeric",
                        "sample_id": f.stem, "error": str(e)})

if le_loaded:
    adata_le = ad.concat(le_loaded, join="outer", merge="same")
    adata_le = standardize_obs(adata_le, "GSE139042_Le", "thymus")
    # Cell names from different sample files may repeat after concatenation;
    # de-duplicate them (a no-op when already unique). Provenance stays in
    # obs["sample_id"].
    adata_le.obs_names_make_unique()
    per_sample_summaries.append(per_sample_qc_summary(adata_le, "GSE139042_Le"))
    adata_dict["GSE139042_Le"] = adata_le
    print(f"  Le merged -> {adata_le.n_obs:,} cells, {adata_le.n_vars:,} genes")

# Zeng — documented skip with justification
# Thymus-named GSM text files are not loaded. Each is recorded in `deferred`
# with the fixed justification string below.
print("\nZeng GSE133341")
zeng_candidates = sorted(
    path for path in ZENG_DIR.glob("GSM*.txt") if "thymus" in path.name.lower()
)
deferred.extend(
    {"file": str(path),
     "reason": "Zeng GSE133341: excluded — bulk RNA-seq or "
               "non-standard gene×cell format incompatible "
               "with single-cell QC pipeline. Not a data loss; "
               "3 independent thymus cohorts remain (TS, Lavaert, Le)."}
    for path in zeng_candidates
)
print("  [INFO] Zeng excluded (documented justification in deferred table).")

# NEGCTRL
# Peripheral-blood negative control: restrict TS blood to myeloid-labelled
# cells, then apply the same Scrublet + MAD QC as the thymus cohorts (with
# looser floors: min_genes=150, min_counts=300, max_mt_pct=25).
print("\nNEGCTRL")
adata_blood = sc.read_h5ad(TS_BLOOD)
adata_blood = normalize_gene_symbols(adata_blood)
adata_blood = add_mt_ribo_flags(adata_blood)
adata_blood = standardize_obs(adata_blood, "TS_Blood_NegCtrl", "negctrl_peripheral")

# The cell-type column is chosen once and locked in config_negctrl.json.
cfg = load_or_set_negctrl_config(adata_blood.obs)
col = cfg.get("BLOOD_CELLTYPE_COL", None)
# NOTE: when `flag` starts with "UNFILTERED", adata_blood_my IS the whole
# blood object (no myeloid restriction). It is still saved under the
# "TS_Blood_NegCtrl_Myeloid" key below, so check the flag downstream.
adata_blood_my, flag = subset_blood_myeloid(adata_blood, col)
adata_blood_my.obs["negctrl_filter_flag"] = flag

qc_snaps.append(qc_snapshot(adata_blood_my, "TS_Blood_NegCtrl", "pre"))

# Doublet detection
adata_blood_my = run_scrublet_per_sample(adata_blood_my, expected_doublet_rate=0.06)
n_doub_blood = int(adata_blood_my.obs["predicted_doublet"].sum())
adata_blood_my = adata_blood_my[~adata_blood_my.obs["predicted_doublet"]].copy()

adata_blood_my, st_b = qc_filter_mad(adata_blood_my, min_genes=150, min_counts=300,
                                      max_mt_pct=25.0, n_mads_counts=5.0,
                                      n_mads_genes=5.0)
qc_snaps.append(qc_snapshot(adata_blood_my, "TS_Blood_NegCtrl", "post"))
# sample_id in the QC table is the myeloid-filter flag, not a donor id.
st_b.update({"dataset_id": "TS_Blood_NegCtrl", "source": "TabulaSapiens_Blood",
             "loader": "read_h5ad", "sample_id": flag,
             "n_doublets_removed": n_doub_blood})
qc_rows.append(st_b)
# Built after doublet removal: n_doublets / doublet_rate_pct here are 0 by
# construction (see n_doublets_removed in qc_rows for the removed count).
per_sample_summaries.append(per_sample_qc_summary(adata_blood_my, "TS_Blood_NegCtrl"))
adata_dict_neg["TS_Blood_NegCtrl_Myeloid"] = adata_blood_my
print(f"  Blood negctrl -> {adata_blood_my.n_obs:,} cells, {adata_blood_my.n_vars:,} genes, "
      f"flag={flag}, col={col}, doublets_removed={n_doub_blood}")

# Innate memory stats
# Only a manifest is produced here: the GSE229940 count-matrix path (plain
# .csv preferred over .csv.gz) and the column count of its first rows. The
# matrix itself is not loaded or QC'd in this notebook. If neither file
# exists, the manifest CSV is written empty.
gse229940 = next(
    (candidate
     for candidate in (Path(REG["GSE229940_dir"]) / "GSE229940_countmtx.csv",
                       Path(REG["GSE229940_dir"]) / "GSE229940_countmtx.csv.gz")
     if candidate.exists()),
    None,
)

imm_rows = []
if gse229940 is not None:
    head = pd.read_csv(
        gse229940,
        nrows=5,
        compression="gzip" if gse229940.suffix.lower() == ".gz" else "infer",
    )
    imm_rows.append({"file": str(gse229940), "cols": int(head.shape[1]),
                     "note": "innate memory anchor"})
df_imm = pd.DataFrame(imm_rows)
df_imm.to_csv(OUT_IMM / "GSE229940_manifest.csv", index=False)


# 11) Save processed h5ad

# Write every non-empty QC'd object to <out_dir>/<dataset_id>__qc.h5ad
# (thymus -> OUT_AIM1, negative control -> OUT_NEGCTRL), and log each file.
saved_rows = []
for _src_dict, _dst_dir in ((adata_dict, OUT_AIM1), (adata_dict_neg, OUT_NEGCTRL)):
    for ds_id, adata_ in _src_dict.items():
        if adata_.n_obs == 0:
            continue
        adata_ = sanitize_for_write(adata_)
        out = _dst_dir / f"{ds_id}__qc.h5ad"
        adata_.write_h5ad(out)
        saved_rows.append({"dataset_id": ds_id, "path": str(out),
                           "n_cells": int(adata_.n_obs), "n_genes": int(adata_.n_vars)})

# Tabular summaries of the run.
df_saved = pd.DataFrame(saved_rows)
df_qc = pd.DataFrame(qc_rows)
df_def = pd.DataFrame(deferred)
df_snaps = pd.concat(qc_snaps, ignore_index=True) if qc_snaps else pd.DataFrame()
df_per_sample = (pd.concat(per_sample_summaries, ignore_index=True)
                 if per_sample_summaries else pd.DataFrame())


# 12) Outlier sample flagging

def _mad_outlier_mask(values: pd.Series, n_mads: float) -> np.ndarray:
    """Return a boolean mask of entries more than ``n_mads`` MADs from the median.

    Median and MAD are computed on the non-missing values only; missing
    entries compare False and are never flagged. With fewer than three
    non-missing values, or a (near-)zero MAD, the mask is all False because a
    robust spread cannot be estimated.
    """
    observed = values.dropna().values
    if len(observed) < 3:
        return np.zeros(len(values), dtype=bool)
    med = np.median(observed)
    mad = median_abs_deviation(observed)
    if mad < 1e-6:
        return np.zeros(len(values), dtype=bool)
    return np.abs(values.values - med) > n_mads * mad


def flag_outlier_samples(
    df_ps: pd.DataFrame,
    metrics: tuple = ("median_genes", "median_umi", "median_mt_pct"),
    n_mads: float = 3.0,
    group_col: Optional[str] = None,
) -> pd.DataFrame:
    """Flag samples that are statistical outliers across QC metrics.

    For each metric column present in ``df_ps`` an ``outlier_<metric>`` boolean
    column is added. A sample is flagged when its value lies more than
    ``n_mads`` median absolute deviations (as returned by
    ``median_abs_deviation``) from the median. ``any_outlier`` is the row-wise
    OR of all ``outlier_*`` columns.

    Parameters
    ----------
    df_ps : per-sample QC summary, one row per sample. Not modified.
    metrics : metric columns to test; columns absent from ``df_ps`` are skipped.
    n_mads : flagging threshold in MAD units.
    group_col : optional column (e.g. ``"dataset_id"``) whose groups each get
        their own median/MAD. The default ``None`` pools all rows. Pooling
        samples from different cohorts can flag an entire cohort rather than
        individual samples.

    Returns
    -------
    A copy of ``df_ps`` with the flag columns added, or ``df_ps`` itself if it
    is empty.

    Raises
    ------
    KeyError : if ``group_col`` is given but is not a column of ``df_ps``.
    """
    if df_ps.empty:
        return df_ps
    if group_col is not None and group_col not in df_ps.columns:
        raise KeyError(f"group_col {group_col!r} not found in per-sample table")
    df_ps = df_ps.copy()
    for metric in metrics:
        if metric not in df_ps.columns:
            continue
        values = df_ps[metric]
        if group_col is None:
            mask = _mad_outlier_mask(values, n_mads)
        else:
            mask = np.zeros(len(df_ps), dtype=bool)
            # groupby drops rows whose group key is NaN; those stay unflagged.
            for positions in df_ps.groupby(group_col, sort=False).indices.values():
                mask[positions] = _mad_outlier_mask(values.iloc[positions], n_mads)
        df_ps[f"outlier_{metric}"] = mask

    outlier_cols = [c for c in df_ps.columns
                    if isinstance(c, str) and c.startswith("outlier_")]
    if outlier_cols:
        df_ps["any_outlier"] = df_ps[outlier_cols].any(axis=1)
    return df_ps

# Attach per-metric `outlier_*` flags and an `any_outlier` column to the
# per-sample table (median/MAD pooled across all datasets in the table);
# `any_outlier` drives the red bar borders in the cells-per-sample figure.
df_per_sample = flag_outlier_samples(df_per_sample)


# 13) QC figures

print("\nGENERATING QC FIGURES")

plot_order = ["TS_Thymus_filtered", "GSE144870_Lavaert", "GSE139042_Le", "TS_Blood_NegCtrl"]

# Fig S2a-c: QC distributions pre/post
# One violin per (dataset, stage) in plot_order; pre-QC bodies use
# PALETTE_STAGE[0] and post-QC bodies PALETTE_STAGE[1].
if not df_snaps.empty:
    df_snaps_plot = df_snaps[df_snaps["dataset_id"].isin(plot_order)].copy()

    for metric, ylabel in [("total_counts", "UMI"), ("n_genes_by_counts", "Genes"),
                           ("pct_counts_mt", "MT%")]:
        if metric not in df_snaps_plot.columns:
            continue
        groups, xticks, stages = [], [], []
        for ds in plot_order:
            for stage in ["pre", "post"]:
                sub = df_snaps_plot[(df_snaps_plot["dataset_id"] == ds) &
                                    (df_snaps_plot["stage"] == stage)][metric].dropna().values
                if len(sub) == 0:
                    continue
                groups.append(sub)
                xticks.append(f"{ds}\n{stage}")
                # Track the stage explicitly: testing `"pre" in label` would also
                # match any dataset id that happens to contain "pre".
                stages.append(stage)
        if not groups:
            # Nothing to draw for this metric; don't save an empty figure.
            continue

        fig = plt.figure(figsize=(8.6, 3.9))
        ax = fig.add_subplot(111)
        parts = ax.violinplot(groups, showmedians=True, showextrema=False)
        for body, stage in zip(parts["bodies"], stages):
            body.set_facecolor(PALETTE_STAGE[0] if stage == "pre" else PALETTE_STAGE[1])
            body.set_alpha(0.6)
        parts["cmedians"].set_color("black")
        ax.set_xticks(range(1, len(xticks) + 1))
        ax.set_xticklabels(xticks, rotation=45, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(f"QC distribution pre vs post: {ylabel}")
        ax.grid(axis="y", alpha=0.2)
        plt.tight_layout()
        save_fig(fig, f"Fig2_QC_{metric}_PrePost", kind="Supplementary")

# Fig S2d: Cells per sample barplot (colored by dataset, legend)
if not df_per_sample.empty:
    df_bar = df_per_sample.sort_values(["dataset_id", "sample_id"])
    positions = np.arange(len(df_bar))

    # Display names and colours for the four cohorts; any other dataset id
    # falls back to its raw id and grey.
    ds_labels = {"TS_Thymus_filtered": "TS Thymus", "GSE144870_Lavaert": "Lavaert",
                 "GSE139042_Le": "Le", "TS_Blood_NegCtrl": "Blood NegCtrl"}
    ds_colors = {"TS_Thymus_filtered": "#1b9e77", "GSE144870_Lavaert": "#d95f02",
                 "GSE139042_Le": "#7570b3", "TS_Blood_NegCtrl": "#e7298a"}
    fallback_color = "#999999"

    def _short_sample_label(raw_id):
        """Shorten a sample id for the x-axis.

        Drops a leading GSM accession and any trailing matrix/barcodes/
        features/filtered file-name suffix, then turns underscores into spaces.
        """
        label = re.sub(r"^GSM\d+_?", "", str(raw_id))
        label = re.sub(r"_?(matrix|barcodes|features|filtered).*", "", label, flags=re.I)
        label = label.replace("_", " ").strip()
        # Everything was stripped: fall back to the first 15 chars of the raw id.
        return label if label else str(raw_id)[:15]

    dataset_ids = df_bar["dataset_id"].tolist()
    if "any_outlier" in df_bar.columns:
        outlier_flags = df_bar["any_outlier"].tolist()
    else:
        outlier_flags = [False] * len(df_bar)

    fig = plt.figure(figsize=(9.0, 3.8))
    ax = fig.add_subplot(111)
    # Outlier samples are outlined with a thick red edge; others get a thin
    # white edge.
    ax.bar(
        positions,
        df_bar["n_cells"].values,
        color=[ds_colors.get(d, fallback_color) for d in dataset_ids],
        edgecolor=["#de2d26" if flagged else "white" for flagged in outlier_flags],
        linewidth=[1.5 if flagged else 0.3 for flagged in outlier_flags],
    )
    ax.set_xticks(positions)
    ax.set_xticklabels([_short_sample_label(s) for s in df_bar["sample_id"]],
                       rotation=90, fontsize=3.5)
    ax.set_ylabel("Cells after QC")
    ax.set_title("Cells per sample (red border = outlier)")
    ax.grid(axis="y", alpha=0.2)

    # One legend entry per dataset, in order of first appearance on the axis.
    legend_handles = [
        mpatches.Patch(color=ds_colors.get(d, fallback_color), label=ds_labels.get(d, d))
        for d in dict.fromkeys(dataset_ids)
    ]
    ax.legend(handles=legend_handles, fontsize=5.5, frameon=False, loc="upper right")

    plt.tight_layout()
    save_fig(fig, "Fig2_QC_CellsPerSample", kind="Supplementary")

# Fig S2e: Scatter genes vs UMI colored by MT%
# Every obs column the scatter reads. A dataset lacking any of them is skipped
# rather than raising KeyError part-way through the loop.
_scatter_qc_cols = ("total_counts", "n_genes_by_counts", "pct_counts_mt")
for ds_id, adata_ in adata_dict.items():
    if adata_.n_obs < 50 or not all(c in adata_.obs.columns for c in _scatter_qc_cols):
        continue
    fig = plt.figure(figsize=(5.5, 4.2))
    ax = fig.add_subplot(111)
    # Plot at most 15,000 randomly chosen cells (global NumPy RNG, seeded at the
    # top of the notebook) to keep the rasterized figure light.
    n_plot = min(adata_.n_obs, 15000)
    idx = np.random.choice(adata_.n_obs, size=n_plot, replace=False)
    obs_sub = adata_.obs.iloc[idx]
    sc_plot = ax.scatter(
        obs_sub["total_counts"], obs_sub["n_genes_by_counts"],
        c=obs_sub["pct_counts_mt"], cmap="RdYlGn_r", s=2, alpha=0.5,
        # Colour scale capped at 20% MT or the 99th percentile, whichever is lower.
        vmin=0, vmax=min(20, obs_sub["pct_counts_mt"].quantile(0.99)),
        rasterized=True
    )
    ax.set_xlabel("Total UMI counts")
    ax.set_ylabel("Genes detected")
    ax.set_title(f"{ds_id}: genes vs UMI (color = MT%)")
    cbar = plt.colorbar(sc_plot, ax=ax, fraction=0.03, pad=0.02)
    cbar.set_label("MT%")
    plt.tight_layout()
    save_fig(fig, f"Fig2_QC_scatter_{ds_id}", kind="Supplementary")

# Fig S2f: Doublet score distribution per dataset
# At most this many scores are histogrammed per dataset. They are drawn at
# random rather than taking the first N rows, so the histogram stays
# representative when cells are stored grouped by sample. A local Generator is
# used so the global NumPy RNG state is not consumed.
_MAX_DOUBLET_SCORES = 10_000
_doublet_rng = np.random.default_rng(42)

doublet_data = []
for ds_id, adata_ in {**adata_dict, **adata_dict_neg}.items():
    if "doublet_score" not in adata_.obs.columns:
        continue
    scores = adata_.obs["doublet_score"].dropna().to_numpy()
    if scores.size == 0:
        continue
    if scores.size > _MAX_DOUBLET_SCORES:
        scores = _doublet_rng.choice(scores, size=_MAX_DOUBLET_SCORES, replace=False)
    doublet_data.append(pd.DataFrame({"dataset": ds_id, "doublet_score": scores}))

if doublet_data:
    df_doub_plot = pd.concat(doublet_data, ignore_index=True)
    datasets = df_doub_plot["dataset"].unique()
    fig, axes = plt.subplots(1, len(datasets), figsize=(3.5 * len(datasets), 3.2),
                             constrained_layout=True, squeeze=False)
    for i, ds in enumerate(datasets):
        ax = axes[0, i]
        vals = df_doub_plot[df_doub_plot["dataset"] == ds]["doublet_score"].values
        ax.hist(vals, bins=50, color="#636363", edgecolor="white", linewidth=0.3,
                density=True)
        ax.set_xlabel("Doublet score")
        ax.set_ylabel("Density")
        ax.set_title(ds, fontsize=7)
        ax.grid(alpha=0.15)
    save_fig(fig, "Fig2_QC_DoubletScores", kind="Supplementary")

# Fig S2g: Batch diagnostic PCA/UMAP per dataset
print("\nBATCH DIAGNOSTIC PCA")
for ds_id, adata_ in adata_dict.items():
    if adata_.n_obs < 200:
        continue
    sample_col = "sample_id"
    if sample_col not in adata_.obs.columns:
        continue
    n_samples = adata_.obs[sample_col].nunique()
    if n_samples < 2:
        continue

    print(f"  Batch PCA for {ds_id} ({n_samples} samples) ...")
    tmp = batch_diagnostic_pca(adata_, ds_id, sample_col=sample_col)
    if tmp is None:
        continue

    # Drop category levels with no cells in `tmp` so the palette and legend
    # only cover samples that are actually plotted.
    samples = tmp.obs[sample_col].astype("category").cat.remove_unused_categories()
    sample_levels = list(samples.cat.categories)
    # "Set2" has only 8 distinct colours and seaborn cycles beyond that, which
    # made samples 9+ reuse colours (including in the <=10-sample legend);
    # switch to an evenly spaced "husl" palette for larger sample counts.
    palette_name = "Set2" if len(sample_levels) <= 8 else "husl"
    palette = sns.color_palette(palette_name, n_colors=len(sample_levels))
    sample_masks = [(samples == samp).values for samp in sample_levels]

    fig, axes = plt.subplots(1, 2, figsize=(9.0, 3.8), constrained_layout=True)

    # PCA
    ax = axes[0]
    pca_xy = tmp.obsm["X_pca"][:, :2]
    for samp, mask, color in zip(sample_levels, sample_masks, palette):
        ax.scatter(pca_xy[mask, 0], pca_xy[mask, 1], s=2, alpha=0.4,
                   color=color, label=samp, rasterized=True)
    ax.set_xlabel("PC1"); ax.set_ylabel("PC2")
    ax.set_title(f"{ds_id}: PCA by sample")
    if n_samples <= 10:
        ax.legend(markerscale=4, fontsize=5, frameon=False,
                  bbox_to_anchor=(1.01, 1.0), loc="upper left")

    # UMAP
    ax = axes[1]
    umap_xy = tmp.obsm["X_umap"]
    for samp, mask, color in zip(sample_levels, sample_masks, palette):
        ax.scatter(umap_xy[mask, 0], umap_xy[mask, 1], s=2, alpha=0.4,
                   color=color, label=samp, rasterized=True)
    ax.set_xlabel("UMAP1"); ax.set_ylabel("UMAP2")
    ax.set_title(f"{ds_id}: UMAP by sample")

    save_fig(fig, f"Fig2_QC_BatchPCA_{ds_id}", kind="Supplementary")


# 14) Table 2

# Supplementary Table 2: one workbook sheet per QC / bookkeeping table
# produced by this notebook.
table2_sheets = dict(
    Aim1_Thymus_QC=df_qc,
    Per_Sample_QC_Summary=df_per_sample,
    Saved_Outputs=df_saved,
    Deferred_or_Skipped=df_def,
    Lavaert_Debug=df_lava_dbg,
    InnateMemory_Stats=df_imm,
)
save_table_xlsx(sheets=table2_sheets, fname="Table2", kind="Supplementary")


# 15) Final report

print("\nNB2 DONE")
print("NEGCTRL config:", CONFIG_PATH)
# List the h5ad files that section 11 actually wrote. Empty datasets are
# skipped there, so a hard-coded list could name files that do not exist.
_h5ad_paths = df_saved["path"].tolist() if "path" in df_saved.columns else []
_h5ad_block = "\n".join(f"    - {p}" for p in _h5ad_paths) or "    - (none written)"
print(f"""
OUTPUTS:
  h5ad:
{_h5ad_block}
  Tables:
    - Supplementary_Table2.xlsx (expanded: Per_Sample_QC_Summary sheet added)
  Figures:
    - Supplementary_Fig2_QC_total_counts_PrePost.png
    - Supplementary_Fig2_QC_n_genes_by_counts_PrePost.png
    - Supplementary_Fig2_QC_pct_counts_mt_PrePost.png
    - Supplementary_Fig2_QC_CellsPerSample.png (new)
    - Supplementary_Fig2_QC_scatter_*.png (new, per dataset)
    - Supplementary_Fig2_QC_DoubletScores.png (new)
    - Supplementary_Fig2_QC_BatchPCA_*.png (new, per dataset)
""")