# Data Processing, Quality Control & Spatial Feature Engineering

**PeriNeuroImmuneMap Analysis Pipeline**

---

**Correspondence:** Wei Qin (<wqin@sjtu.edu.cn>)


In [None]:
from pathlib import Path

# Default project layout relative to the current working directory.
# NOTE(review): the following cell rebinds BASE_DIR to an absolute path and
# uses FIGURES_DIR / TABLES_DIR (plural) rather than the names defined here.
BASE_DIR = Path('.')
DATA_DIR = BASE_DIR.joinpath('data')
OUTPUT_DIR = BASE_DIR.joinpath('outputs')
FIGURE_DIR = OUTPUT_DIR.joinpath('figures')
TABLE_DIR  = OUTPUT_DIR.joinpath('tables')

# Create every directory up front (no error if it already exists).
for directory in (DATA_DIR, OUTPUT_DIR, FIGURE_DIR, TABLE_DIR):
    directory.mkdir(parents=True, exist_ok=True)

print('✓ Paths configured')


In [None]:
# Notebook 1: Data Processing, QC & Spatial Feature Engineering

import os, sys, platform, time, json, gzip, pickle, warnings, hashlib, gc
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.io import mmread
from scipy.stats import median_abs_deviation, spearmanr
from scipy.spatial import Delaunay

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import patheffects as pe

from sklearn.neighbors import NearestNeighbors

# NOTE: silences *all* Python warnings for the rest of the session, including
# library deprecation and input-validation warnings.
warnings.filterwarnings("ignore")

import anndata as ad
import scanpy as sc
import psutil

sc.settings.verbosity = 2
sc.settings.n_jobs = -1

# Reproducibility: run timestamp plus seeds for NumPy's global RNG and `random`.
RUN_TS = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
import random
random.seed(RANDOM_SEED)

SAVE_DPI = 1200
RUN_BENCHMARK = True  # toggle for the runtime/memory benchmark section

# Project root. Defaults to the original workstation path; set the
# PERINEURO_BASE_DIR environment variable to run the notebook elsewhere.
BASE_DIR = Path(os.environ.get("PERINEURO_BASE_DIR", r"D:\个人文件夹\Sanwal\Neuro"))
RAW_DATA_DIR = BASE_DIR / "Raw data"

MANUSCRIPT_DIR = BASE_DIR / "Manuscript Data"
FIGURES_DIR = MANUSCRIPT_DIR / "Figures"
TABLES_DIR = MANUSCRIPT_DIR / "Tables"
PROCESSED_DIR = BASE_DIR / "processed" / "notebook1"

for d in [FIGURES_DIR, TABLES_DIR, PROCESSED_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# The raw GEO series folder must already exist; fail fast otherwise.
GSE_DIR = RAW_DATA_DIR / "GSE289745"
if not GSE_DIR.exists():
    raise FileNotFoundError(f"GSE289745 folder not found: {GSE_DIR}")

# Provenance files written by this notebook.
LOG_PATH = PROCESSED_DIR / "run_log.txt"
REQ_PATH = PROCESSED_DIR / "requirements.txt"
CHK_PATH = PROCESSED_DIR / "checksums.md5"
CHECKLIST_PATH = PROCESSED_DIR / "reproducibility_checklist.txt"

def set_n_style():
    """Apply the compact journal-style Matplotlib defaults used by every figure.

    Sets 6-7 pt sans-serif text, thin axis/tick/line widths, inward ticks,
    SAVE_DPI as the default save resolution, and TrueType (type 42) font
    embedding for PDF/PS output.
    """
    text_params = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
        "font.size": 7,
        "axes.titlesize": 7,
        "axes.labelsize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
    }
    stroke_params = {
        "axes.linewidth": 0.6,
        "lines.linewidth": 0.8,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
        "xtick.direction": "in",
        "ytick.direction": "in",
    }
    export_params = {
        "savefig.dpi": SAVE_DPI,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    # The three groups have disjoint keys, so applying them in turn is the
    # same as one combined update.
    for params in (text_params, stroke_params, export_params):
        mpl.rcParams.update(params)


set_n_style()


def log(msg: str):
    """Print *msg* and append it, right-stripped and newline-terminated, to LOG_PATH."""
    print(msg)
    with LOG_PATH.open("a", encoding="utf-8") as sink:
        sink.write(f"{msg.rstrip()}\n")


# Start a fresh run log (mode "w" truncates any log from a previous run).
with open(LOG_PATH, "w", encoding="utf-8") as f:
    f.write("PeriNeuroImmuneMap Notebook 1 log\n")
    f.write(f"Start: {RUN_TS}\n")

log(f"Python: {sys.version}")
log(f"Platform: {platform.platform()}")
log(f"Seed: {RANDOM_SEED}")
log(f"CPU logical: {psutil.cpu_count(True)} | CPU physical: {psutil.cpu_count(False)}")
log(f"RAM total GB: {psutil.virtual_memory().total/(1024**3):.2f}")
log(f"scanpy: {sc.__version__} | anndata: {ad.__version__}")


# requirements.txt: pin the versions actually installed in this environment,
# including the core analysis packages (numpy, pandas, scanpy, anndata).
# A package that cannot be resolved falls back to the minimum version this
# notebook was written against. The interpreter version is written as a
# comment because "python==X.Y.Z" is not an installable requirement and
# makes `pip install -r requirements.txt` fail.
from importlib import metadata as _importlib_metadata

_REQUIREMENT_FALLBACKS = {
    "numpy": "numpy",
    "pandas": "pandas",
    "scipy": "scipy>=1.9.0",
    "scanpy": "scanpy",
    "anndata": "anndata",
    "matplotlib": "matplotlib>=3.6.0",
    "scikit-learn": "scikit-learn>=1.1.0",
    "scikit-misc": "scikit-misc>=0.3.1",
    "openpyxl": "openpyxl>=3.0.0",
    "psutil": "psutil>=5.9.0",
}

_req_lines = [
    "# PeriNeuroImmuneMap Notebook 1 requirements",
    f"# Generated: {RUN_TS}",
    f"# python=={sys.version.split()[0]}",
]
for _dist, _fallback in _REQUIREMENT_FALLBACKS.items():
    try:
        _req_lines.append(f"{_dist}=={_importlib_metadata.version(_dist)}")
    except _importlib_metadata.PackageNotFoundError:
        _req_lines.append(_fallback)

REQ_PATH.write_text("\n".join(_req_lines) + "\n", encoding="utf-8")
log(f"Saved requirements: {REQ_PATH}")


def find_existing(base: Path, candidates):
    """Return the first ``base / name`` that exists, trying *candidates* in order.

    Returns None when none of the candidate names exists under *base*.
    """
    paths = (base / name for name in candidates)
    return next((path for path in paths if path.exists()), None)


def open_maybe_gz(path: Path, mode="rt", encoding=None):
    """Open *path*, transparently decompressing it when its suffix is ``.gz``.

    Parameters
    ----------
    path : Path or str
        File to open. gzip is chosen purely from the ``.gz`` suffix.
    mode : str
        Open mode passed through unchanged (e.g. ``"rt"``, ``"rb"``).
        Note that for gzip a mode without ``"t"`` (such as ``"r"``) is binary.
    encoding : str, optional
        Encoding for text modes. Defaults to UTF-8 so text reads do not
        depend on the platform locale (e.g. GBK on a Chinese-locale Windows
        machine). Ignored for binary modes, which reject an encoding.

    Returns
    -------
    A file object usable as a context manager.
    """
    path = Path(path)
    is_gz = path.suffix == ".gz"
    # gzip.open is binary unless "t" is given; builtin open is text unless "b" is given.
    text_mode = ("t" in mode) if is_gz else ("b" not in mode)
    if not text_mode:
        encoding = None
    elif encoding is None:
        encoding = "utf-8"
    if is_gz:
        return gzip.open(path, mode, encoding=encoding)
    return open(path, mode, encoding=encoding)


def md5_file(path: Path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of *path*, streamed in *chunk_size*-byte reads."""
    digest = hashlib.md5()
    with open(path, "rb") as stream:
        # iter() with a b"" sentinel stops at end of file.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def safe_sparse_sum(x, axis=None):
    """Sum *x* along *axis* and return a flat 1-D ndarray, for sparse or dense input."""
    summed = x.sum(axis=axis) if sp.issparse(x) else np.sum(x, axis=axis)
    return np.asarray(summed).ravel()


def load_visium_components(sample_id: str, base_dir: Path) -> ad.AnnData:
    """Load one Visium sample from per-sample GEO-style files into an AnnData.

    Looks in *base_dir* for ``{sample_id}_matrix.mtx``, ``_features.tsv``,
    ``_barcodes.tsv``, ``_tissue_positions_list.csv`` and (optional)
    ``_scalefactors_json.json``. Each may be plain or ``.gz``-compressed; the
    ``.gz`` name is tried first.

    The returned object is spots x genes and is restricted to spots with
    ``in_tissue == 1``. Its fields are:

    - ``obs_names``: barcodes.
    - ``var_names``: gene IDs, made unique.
    - ``obsm["spatial"]``: full-resolution pixel coordinates in
      (pxl_row, pxl_col) order.
    - ``obs["sample_id"]`` and ``obs["gsm_id"]``: the latter is the part of
      *sample_id* before the first "_".
    - ``uns["scalefactors"]``: parsed JSON, or an empty dict if the file is
      absent or cannot be parsed.

    Raises
    ------
    FileNotFoundError
        If the matrix, features, barcodes or positions file is missing.
    ValueError
        If the features table has fewer than two columns.
    """
    mtx = find_existing(base_dir, [f"{sample_id}_matrix.mtx.gz", f"{sample_id}_matrix.mtx"])
    feat = find_existing(base_dir, [f"{sample_id}_features.tsv.gz", f"{sample_id}_features.tsv"])
    bc  = find_existing(base_dir, [f"{sample_id}_barcodes.tsv.gz", f"{sample_id}_barcodes.tsv"])
    pos = find_existing(base_dir, [f"{sample_id}_tissue_positions_list.csv.gz", f"{sample_id}_tissue_positions_list.csv"])
    sf  = find_existing(base_dir, [f"{sample_id}_scalefactors_json.json.gz", f"{sample_id}_scalefactors_json.json"])

    # Scalefactors are optional; the other four files are required.
    if any(x is None for x in [mtx, feat, bc, pos]):
        missing = [n for n,x in [("mtx",mtx),("features",feat),("barcodes",bc),("positions",pos)] if x is None]
        raise FileNotFoundError(f"{sample_id}: missing {missing}")

    # Transpose so rows are spots (assumes the 10x features x barcodes layout).
    with open_maybe_gz(mtx, "rb") as f:
        X = mmread(f).T.tocsr()

    # Features: gene_id, gene_symbol[, feature_type]. A 2-column file gets
    # feature_type defaulted to "Gene Expression"; extra columns are dropped.
    feat_df = pd.read_csv(feat, sep="\t", header=None)
    if feat_df.shape[1] >= 3:
        feat_df = feat_df.iloc[:, :3].copy()
        feat_df.columns = ["gene_id", "gene_symbol", "feature_type"]
    elif feat_df.shape[1] == 2:
        feat_df.columns = ["gene_id", "gene_symbol"]
        feat_df["feature_type"] = "Gene Expression"
    else:
        raise ValueError(f"{sample_id}: unexpected features.tsv cols={feat_df.shape[1]}")

    feat_df["gene_id"] = feat_df["gene_id"].astype(str)
    feat_df["gene_symbol"] = feat_df["gene_symbol"].astype(str)
    feat_df["feature_type"] = feat_df["feature_type"].astype(str)

    # Both files are read without a header row (header=None); position
    # columns are named in the order listed below.
    barcodes = pd.read_csv(bc, sep="\t", header=None, names=["barcode"])
    positions = pd.read_csv(
        pos, header=None,
        names=["barcode","in_tissue","array_row","array_col","pxl_row_in_fullres","pxl_col_in_fullres"]
    )

    a = ad.AnnData(X=X, obs=barcodes, var=feat_df)

    a.var_names = a.var["gene_id"].astype(str)
    a.var_names_make_unique()

    # Left merge keeps barcode order; barcodes missing from the positions file
    # get NaN in_tissue and are dropped by the filter below.
    a.obs = a.obs.merge(positions, on="barcode", how="left")
    a.obs.index = a.obs["barcode"].astype(str)

    a = a[a.obs["in_tissue"] == 1, :].copy()
    # Coordinates stored as (row, col) in full-resolution pixels.
    a.obsm["spatial"] = a.obs[["pxl_row_in_fullres","pxl_col_in_fullres"]].to_numpy()

    a.obs["sample_id"] = sample_id
    a.obs["gsm_id"] = sample_id.split("_")[0]

    # Best-effort: an unreadable scalefactors JSON silently yields {}.
    a.uns["scalefactors"] = {}
    if sf is not None:
        with open_maybe_gz(sf, "rt") as f:
            try:
                a.uns["scalefactors"] = json.load(f)
            except Exception:
                a.uns["scalefactors"] = {}

    return a


# Detect sample IDs: each "<sample>_matrix.mtx[.gz]" file in GSE_DIR defines one sample.
all_files = sorted(p.name for p in GSE_DIR.glob("*") if p.is_file())
matrix_files = [name for name in all_files if name.endswith(("_matrix.mtx.gz", "_matrix.mtx"))]
sample_ids = sorted({name.split("_matrix.mtx")[0] for name in matrix_files})
log(f"Detected Visium samples: {len(sample_ids)}")
if not sample_ids:
    raise RuntimeError("No _matrix.mtx(.gz) found in GSE folder.")


# Checksums (subset): MD5 of the raw input files of the first 3 samples,
# written as "<digest>  <file name>" lines. Both the .gz and the plain variant
# of each file are hashed if present.
_CHECKSUM_SUFFIXES = [
    "_matrix.mtx", "_features.tsv", "_barcodes.tsv",
    "_tissue_positions_list.csv", "_scalefactors_json.json",
]
to_hash = []
for sid in sample_ids[:3]:
    for base_suffix in _CHECKSUM_SUFFIXES:
        for suf in (base_suffix + ".gz", base_suffix):
            p = GSE_DIR / f"{sid}{suf}"
            if p.exists():
                to_hash.append(p)

n_checksum_failed = 0
with open(CHK_PATH, "w", encoding="utf-8") as f:
    f.write(f"# MD5 checksums (subset)\n# Generated: {RUN_TS}\n\n")
    for p in to_hash:
        try:
            digest = md5_file(p)
        except OSError as e:
            # Checksums are best-effort, but record the failure instead of
            # silently omitting the file.
            n_checksum_failed += 1
            f.write(f"# FAILED {p.name}: {type(e).__name__}: {e}\n")
            continue
        f.write(f"{digest}  {p.name}\n")
log(f"Saved checksums: {CHK_PATH}")
if n_checksum_failed:
    log(f"Checksum failures: {n_checksum_failed}")


# Load all samples and concatenate them: inner join on genes, var annotations
# kept where identical across samples (merge="same"), and spot names suffixed
# with "_<sample_id>" (index_unique="_").
adata_list = []
for sid in sample_ids:
    a = load_visium_components(sid, GSE_DIR)
    adata_list.append(a)
    log(f"Loaded {sid}: spots={a.n_obs:,} genes={a.n_vars:,}")

# Use the sample IDs themselves as concat keys. They equal each object's
# obs["sample_id"], but unlike obs["sample_id"].iloc[0] they do not raise
# IndexError for a sample with no in-tissue spots.
adata = ad.concat(
    adata_list,
    label="sample_id_cat",
    keys=list(sample_ids),
    index_unique="_",
    join="inner",
    merge="same"
)

log(f"Combined spots={adata.n_obs:,} genes={adata.n_vars:,} samples={adata.obs['sample_id'].nunique()}")
log(f"adata.var columns: {list(adata.var.columns)}")

# Quick pre-QC summary: UMI per spot and detected genes per spot.
tc0 = safe_sparse_sum(adata.X, axis=1)
ng0 = safe_sparse_sum(adata.X > 0, axis=1)
log(f"Mean UMI/spot={tc0.mean():.1f} median genes/spot={np.median(ng0):.1f}")


# Supplementary Table 1: per-sample metadata plus count summaries
# (spots, total / mean UMI per spot, median detected genes per spot).
def _supp1_row(sample):
    """Build the Supplementary Table 1 row for one sample of `adata`."""
    spots = adata[adata.obs["sample_id"] == sample]
    umi = safe_sparse_sum(spots.X, axis=1)
    detected = safe_sparse_sum(spots.X > 0, axis=1)
    return {
        "Analysis_Date": RUN_TS,
        "Sample_ID": sample,
        "GSM_ID": sample.split("_")[0],
        "N_Spots": int(spots.n_obs),
        "Total_UMI": int(umi.sum()),
        "Mean_UMI_per_Spot": float(umi.mean()),
        "Median_Genes_per_Spot": float(np.median(detected)),
        "Cancer_Type": "Cutaneous squamous cell carcinoma",
        "Technology": "10x Visium",
        "Treatment_Status": "Treatment-naïve",
        "Dataset": "GSE289745",
    }


supp1 = pd.DataFrame([_supp1_row(sid) for sid in sorted(adata.obs["sample_id"].unique())])
supp1_path = TABLES_DIR / "Supplementary_Table_1.xlsx"
with pd.ExcelWriter(supp1_path, engine="openpyxl") as w:
    supp1.to_excel(w, index=False, sheet_name="Sample_Metadata")
log(f"Saved {supp1_path}")


# QC metrics per spot (total_counts, n_genes_by_counts, pct_counts_mt,
# pct_counts_ribo, ...) plus per-gene summaries, computed in one call.
sym = adata.var["gene_symbol"].astype(str).str.upper()
adata.var["mt"] = sym.str.startswith("MT-")
# Ribosomal-protein genes (RPS*/RPL*). The RPS6K* family (RPS6KA1-6,
# RPS6KB1/2, RPS6KC1, RPS6KL1) are ribosomal protein S6 *kinases*, not
# ribosomal proteins, so they are excluded from the flag.
adata.var["ribo"] = sym.str.match(r"^RP[SL]") & ~sym.str.startswith("RPS6K")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt","ribo"], percent_top=None, log1p=False, inplace=True)

def mad_bounds(x, n_mads=5):
    """Return ``(median - n_mads*MAD, median + n_mads*MAD)`` of *x* as Python floats.

    MAD is scipy's unscaled median absolute deviation.
    """
    center = np.median(x)
    spread = n_mads * median_abs_deviation(x)
    return float(center - spread), float(center + spread)

# Spot-level QC thresholds, pooled over all samples:
#   - total counts and detected genes: median ± 5·MAD, with the lower bounds
#     floored at 500 UMI and 200 genes;
#   - mito %: median + 5·MAD capped at 20% (100% when no MT- genes are flagged).
tc, ng, mt = (
    adata.obs[col].to_numpy()
    for col in ("total_counts", "n_genes_by_counts", "pct_counts_mt")
)

tc_lo, tc_hi = mad_bounds(tc, 5)
ng_lo, ng_hi = mad_bounds(ng, 5)
tc_lo, ng_lo = max(tc_lo, 500.0), max(ng_lo, 200.0)

if int(adata.var["mt"].sum()) == 0:
    mt_hi = 100.0
else:
    mt_hi = min(float(np.median(mt) + 5 * median_abs_deviation(mt)), 20.0)

adata.uns["qc_thresholds"] = dict(
    total_counts_min=tc_lo,
    total_counts_max=tc_hi,
    n_genes_min=ng_lo,
    n_genes_max=ng_hi,
    pct_counts_mt_max=mt_hi,
    method="MAD(5)+caps",
)
log(f"QC thresholds: counts [{tc_lo:.0f},{tc_hi:.0f}] genes [{ng_lo:.0f},{ng_hi:.0f}] mito<{mt_hi:.2f}%")


# Keep spots inside every bound (all comparisons inclusive).
pass_mask = (
    adata.obs["total_counts"].between(tc_lo, tc_hi)
    & adata.obs["n_genes_by_counts"].between(ng_lo, ng_hi)
    & (adata.obs["pct_counts_mt"] <= mt_hi)
)

n_spots_before = adata.n_obs
adata = adata[pass_mask].copy()
log(f"Spot filter: {n_spots_before:,}->{adata.n_obs:,} ({adata.n_obs/n_spots_before*100:.1f}%)")


# Gene filter: keep genes detected (count > 0) in at least 1% of retained spots.
gene_n = safe_sparse_sum(adata.X > 0, axis=0)
min_spots = max(1, int(0.01 * adata.n_obs))
keep_gene = gene_n >= min_spots
n_genes_before = adata.n_vars
adata = adata[:, keep_gene].copy()

# MT/ribo logging then removal (flags recomputed on the filtered gene set).
# RPS6K* genes are ribosomal protein S6 *kinases*: they match "^RP[SL]" but
# are not ribosomal proteins, so exclude them from removal.
sym = adata.var["gene_symbol"].astype(str).str.upper()
adata.var["mt"] = sym.str.startswith("MT-")
adata.var["ribo"] = sym.str.match(r"^RP[SL]") & ~sym.str.startswith("RPS6K")
mt_before = int(adata.var["mt"].sum())
ribo_before = int(adata.var["ribo"].sum())
log(f"MT genes before removal: {mt_before}")
log(f"Ribo genes before removal: {ribo_before}")

adata = adata[:, ~(adata.var["mt"] | adata.var["ribo"])].copy()
log(f"Gene filter: {n_genes_before:,}->{adata.n_vars:,} (min_spots={min_spots}; removed MT/ribo={mt_before+ribo_before})")


# Batch check (fast): PCA on a throw-away copy, then Spearman correlation of
# PC1 / PC2 with the integer sample code. |rho| > 0.5 on either PC sets
# "flag_strong".
# NOTE(review): sample codes are arbitrary category codes, so rho is only a
# coarse heuristic, not a formal batch-effect test.
tmp = adata.copy()
# flavor="seurat_v3" models raw counts. Keep them in a layer and point the HVG
# call at it instead of running it on log-normalised values.
tmp.layers["counts"] = tmp.X.copy()
sc.pp.normalize_total(tmp, target_sum=1e4)
sc.pp.log1p(tmp)
sc.pp.highly_variable_genes(tmp, n_top_genes=2000, batch_key="sample_id", flavor="seurat_v3", layer="counts")
tmp = tmp[:, tmp.var["highly_variable"]].copy()
sc.pp.scale(tmp, max_value=10, zero_center=False)
sc.tl.pca(tmp, n_comps=30, svd_solver="arpack")
sample_codes = pd.Categorical(tmp.obs["sample_id"]).codes
pc1_corr = spearmanr(tmp.obsm["X_pca"][:, 0], sample_codes).correlation
pc2_corr = spearmanr(tmp.obsm["X_pca"][:, 1], sample_codes).correlation
batch_flag = (abs(pc1_corr) > 0.5) or (abs(pc2_corr) > 0.5)
adata.uns["batch_assessment"] = {"pc1_spearman": float(pc1_corr), "pc2_spearman": float(pc2_corr), "flag_strong": bool(batch_flag)}
log(f"Batch check: PC1={pc1_corr:.3f} PC2={pc2_corr:.3f} strong={batch_flag}")
del tmp
gc.collect()


# Supplementary Figure 1: post-filter QC overview. Top row shows distributions
# of UMI, genes and mito %; bottom row shows UMI vs genes, UMI vs mito %, and
# UMI per sample.
# Refresh per-spot metrics on the filtered gene set. No qc_vars are passed,
# so pct_counts_mt is not recomputed here and keeps its earlier values.
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

fig = plt.figure(figsize=(7.2, 4.6), dpi=SAVE_DPI)
gs = fig.add_gridspec(2, 3, hspace=0.45, wspace=0.40)

ax1 = fig.add_subplot(gs[0,0])
ax1.boxplot(adata.obs["total_counts"].values, showfliers=False)
ax1.set_title("Total UMI")
ax1.set_ylabel("UMI")
ax1.set_xticks([])

ax2 = fig.add_subplot(gs[0,1])
ax2.boxplot(adata.obs["n_genes_by_counts"].values, showfliers=False)
ax2.set_title("Genes")
ax2.set_ylabel("Genes")
ax2.set_xticks([])

ax3 = fig.add_subplot(gs[0,2])
ax3.boxplot(adata.obs["pct_counts_mt"].values, showfliers=False)
ax3.axhline(y=mt_hi, linestyle="--", linewidth=0.8)  # QC mito ceiling
ax3.set_title("Mito %")
ax3.set_ylabel("%")
ax3.set_xticks([])

ax4 = fig.add_subplot(gs[1,0])
ax4.scatter(adata.obs["total_counts"], adata.obs["n_genes_by_counts"], s=1, alpha=0.20, rasterized=True)
ax4.set_xlabel("UMI")
ax4.set_ylabel("Genes")
ax4.set_title("UMI vs Genes")

ax5 = fig.add_subplot(gs[1,1])
ax5.scatter(adata.obs["total_counts"], adata.obs["pct_counts_mt"], s=1, alpha=0.20, rasterized=True)
ax5.axhline(y=mt_hi, linestyle="--", linewidth=0.8)
ax5.set_xlabel("UMI")
ax5.set_ylabel("Mito %")
ax5.set_title("UMI vs Mito %")

ax6 = fig.add_subplot(gs[1,2])
sample_order = sorted(adata.obs["sample_id"].unique())
groups = [adata.obs.loc[adata.obs["sample_id"] == s, "total_counts"].values for s in sample_order]
ax6.boxplot(groups, showfliers=False)
# Label the ticks explicitly rather than via boxplot(labels=...). Matplotlib
# 3.9 renamed that keyword to tick_labels; this form works on old and new
# versions. S1..Sn follow sorted sample_id order.
ax6.set_xticks(list(range(1, len(groups) + 1)))
ax6.set_xticklabels([f"S{i+1}" for i in range(len(groups))])
ax6.set_title("UMI by sample")
ax6.set_ylabel("UMI")
ax6.tick_params(axis="x", rotation=45)

fig.suptitle("Supplementary Figure 1: QC", y=0.99)
supp_fig1_path = FIGURES_DIR / "Supplementary_Figure_1.png"
fig.savefig(supp_fig1_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {supp_fig1_path}")


# Preserve raw counts before normalisation, both as a layer (used for the
# seurat_v3 HVG call below) and as a full `.raw` snapshot.
adata.layers["counts"] = adata.X.copy()
adata.raw = adata.copy()


# Library-size normalise to 10,000 counts per spot, then log1p. Keep a copy of
# the log-normalised matrix because `.X` is overwritten by scaling later on.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.layers["log1p_norm"] = adata.X.copy()
log("Normalization done")


# HVGs from the raw-count layer (seurat_v3 models counts), selected per sample.
hvg_params = dict(
    n_top_genes=3000,
    batch_key="sample_id",
    flavor="seurat_v3",
    layer="counts",
)
sc.pp.highly_variable_genes(adata, **hvg_params)
n_hvg = int(adata.var["highly_variable"].sum())
log(f"HVGs: {n_hvg}")


# Spatial graphs
def build_knn_graph_blockdiag(adata_in, k: int):
    """Build a symmetric k-nearest-neighbour spatial graph, one block per sample.

    Neighbours are searched only among spots of the same ``obs["sample_id"]``,
    using ``obsm["spatial"]`` coordinates, so no edge crosses samples. Samples
    with <= 2 spots (or k < 1) contribute no edges.

    Parameters
    ----------
    adata_in : AnnData-like
        Needs ``n_obs``, ``obs["sample_id"]`` and ``obsm["spatial"]``.
    k : int
        Neighbours per spot, capped at the sample size minus one.

    Returns
    -------
    (A, D) : scipy.sparse.csr_matrix, each of shape (n_obs, n_obs)
        A holds 0/1 connectivities. D holds Euclidean distances in the units
        of ``obsm["spatial"]``. Both are symmetrised with an element-wise
        maximum, so an edge exists if either spot lists the other as a
        neighbour.
    """
    n = adata_in.n_obs
    sample_labels = adata_in.obs["sample_id"].values
    coords_all = adata_in.obsm["spatial"]
    row_parts, col_parts, dist_parts = [], [], []
    for sid in sorted(adata_in.obs["sample_id"].unique()):
        idx = np.where(sample_labels == sid)[0]
        if len(idx) <= 2 or k < 1:
            continue
        coords = coords_all[idx]
        nn = NearestNeighbors(n_neighbors=min(k+1, len(idx)), algorithm="auto").fit(coords)
        dist, nbr = nn.kneighbors(coords)
        # Column 0 is normally the query spot itself; the rest are neighbours.
        # Build all (spot, neighbour) pairs at once instead of a per-spot loop.
        src = np.repeat(idx, nbr.shape[1] - 1)
        dst = idx[nbr[:, 1:]].ravel()
        d = dist[:, 1:].ravel()
        # With duplicated coordinates the query spot may not come first;
        # never emit self-loops.
        keep = src != dst
        row_parts.append(src[keep])
        col_parts.append(dst[keep])
        dist_parts.append(d[keep])

    if row_parts:
        rows = np.concatenate(row_parts)
        cols = np.concatenate(col_parts)
        dists = np.concatenate(dist_parts).astype(np.float32)
    else:
        rows = cols = np.zeros(0, dtype=np.int64)
        dists = np.zeros(0, dtype=np.float32)

    A = sp.csr_matrix((np.ones(len(rows), dtype=np.float32), (rows, cols)), shape=(n, n))
    D = sp.csr_matrix((dists, (rows, cols)), shape=(n, n))
    A = A.maximum(A.T)
    D = D.maximum(D.T)
    return A, D

def build_delaunay_graph_blockdiag(adata_in):
    """Build a symmetric Delaunay-triangulation spatial graph, one block per sample.

    Each sample's ``obsm["spatial"]`` points are triangulated separately, so
    no edge crosses samples.

    Parameters
    ----------
    adata_in : AnnData-like
        Needs ``n_obs``, ``obs["sample_id"]`` and ``obsm["spatial"]``.

    Returns
    -------
    A : scipy.sparse.csr_matrix, shape (n_obs, n_obs), float32
        Binary (0/1) connectivities: 1 where two spots share a triangle edge.
    D : scipy.sparse.csr_matrix, shape (n_obs, n_obs), float32
        Euclidean edge lengths, in the units of ``obsm["spatial"]``, on A's
        sparsity pattern.
    skipped : list of (sample_id, reason)
        Samples with < 4 spots ("too_few_points") or whose triangulation
        raised ("qhull_fail:<ExceptionName>").
    """
    n = adata_in.n_obs
    sample_labels = adata_in.obs["sample_id"].values
    coords_all = adata_in.obsm["spatial"]
    edge_parts = []
    skipped = []
    for sid in sorted(adata_in.obs["sample_id"].unique()):
        idx = np.where(sample_labels == sid)[0]
        if len(idx) < 4:
            skipped.append((sid, "too_few_points"))
            continue
        try:
            # "QJ" joggles the input so degenerate point sets still triangulate.
            tri = Delaunay(coords_all[idx], qhull_options="QJ")
        except Exception as e:
            skipped.append((sid, f"qhull_fail:{type(e).__name__}"))
            continue
        s = tri.simplices
        # Three undirected edges per triangle, mapped back to global spot indices.
        local_edges = np.vstack([s[:, [0, 1]], s[:, [1, 2]], s[:, [0, 2]]])
        edge_parts.append(idx[local_edges])

    if edge_parts:
        edges = np.vstack(edge_parts)
        rows = np.concatenate([edges[:, 0], edges[:, 1]])
        cols = np.concatenate([edges[:, 1], edges[:, 0]])
    else:
        rows = cols = np.zeros(0, dtype=np.int64)

    A = sp.csr_matrix((np.ones(len(rows), dtype=np.float32), (rows, cols)), shape=(n, n))
    A = A.maximum(A.T)
    # The sparse constructor sums duplicate (row, col) entries. An interior
    # edge shared by two triangles would otherwise get weight 2, so clamp
    # every stored edge to 1.
    A.data[:] = 1.0
    rr, cc = A.nonzero()
    dist = np.sqrt(((coords_all[rr] - coords_all[cc])**2).sum(axis=1)).astype(np.float32)
    D = sp.csr_matrix((dist, (rr, cc)), shape=(n, n))
    return A, D, skipped

log("Building spatial graphs...")
A6, D6 = build_knn_graph_blockdiag(adata, 6)
A10, D10 = build_knn_graph_blockdiag(adata, 10)
Adel, Ddel, del_skipped = build_delaunay_graph_blockdiag(adata)

# Store each graph in .obsp under "spatial_<tag>_connectivities/_distances".
# The k=6 kNN graph is also copied to the un-tagged default keys.
graph_pairs = {"k6": (A6, D6), "k10": (A10, D10), "delaunay": (Adel, Ddel)}
for tag, (conn, dist_mat) in graph_pairs.items():
    adata.obsp[f"spatial_{tag}_connectivities"] = conn
    adata.obsp[f"spatial_{tag}_distances"] = dist_mat
adata.obsp["spatial_connectivities"] = A6.copy()
adata.obsp["spatial_distances"] = D6.copy()


# µm↔px calibration: assuming a 55 µm Visium spot diameter, each sample's
# spot_diameter_fullres (in full-resolution pixels) gives its pixel size.
VISIUM_SPOT_DIAMETER_UM = 55.0
scale_by_sample = {}
missing_sf = 0
invalid_sf = []
for sid in sorted(adata.obs["sample_id"].unique()):
    sf_path = find_existing(GSE_DIR, [f"{sid}_scalefactors_json.json.gz", f"{sid}_scalefactors_json.json"])
    if sf_path is None:
        missing_sf += 1
        continue
    with open_maybe_gz(sf_path, "rt") as f:
        sf = json.load(f)
    # Guard against a missing, non-numeric or non-positive diameter.
    # float(None) raises TypeError and 55/0 raises ZeroDivisionError, so such
    # samples are skipped and reported instead.
    try:
        spot_diam_px = float(sf.get("spot_diameter_fullres"))
    except (TypeError, ValueError):
        spot_diam_px = float("nan")
    if not np.isfinite(spot_diam_px) or spot_diam_px <= 0:
        invalid_sf.append(sid)
        continue
    um_per_px = VISIUM_SPOT_DIAMETER_UM / spot_diam_px
    px_per_um = 1.0 / um_per_px
    scale_by_sample[sid] = {
        "spot_diameter_fullres_px": float(spot_diam_px),
        "um_per_px": float(um_per_px),
        "px_per_um": float(px_per_um)
    }

adata.uns["spatial_scale"] = {"visium_spot_diameter_um_assumed": VISIUM_SPOT_DIAMETER_UM, "per_sample": scale_by_sample}
log(f"Spatial scale missing scalefactors: {missing_sf}")
if invalid_sf:
    log(f"Spatial scale invalid spot_diameter_fullres: {invalid_sf}")

# Perineural distance thresholds in µm, converted to pixels for every sample
# that has a valid scale.
perineural_thresholds_um = [50, 100, 200, 500]
adata.uns["perineural_thresholds_um"] = perineural_thresholds_um
adata.uns["perineural_thresholds_px_by_sample"] = {
    sid: {f"{u}um": float(u * scale_by_sample[sid]["px_per_um"]) for u in perineural_thresholds_um}
    for sid in scale_by_sample.keys()
}


# Moran's I 
def morans_i_sparse(X, W_csr, n_perms=100, seed=42):
    """Moran's I for every column of *X* with a two-sided permutation p-value.

    Parameters
    ----------
    X : ndarray, shape (n, g)
        Dense values, one column per gene (float32 in this notebook).
    W_csr : scipy sparse matrix, shape (n, n)
        Spatial weights, row-normalised in this notebook. The statistic uses
        the actual total weight S0 = sum(W), so rows with no neighbours
        (all-zero rows) are handled correctly.
    n_perms : int
        Number of random permutations of the spots. 0 skips the test and
        returns NaN p-values.
    seed : int
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    (I, p) : two float64 arrays of length g
        I = (n / S0) * (z' W z) / (z' z) with z the column-centred values.
        p = (#{|I_perm| >= |I|} + 1) / (n_perms + 1). Constant columns (and
        any column whose I is not finite) get NaN for both I and p, rather
        than a spuriously minimal p-value.
    """
    rng = np.random.default_rng(seed)
    n, g = X.shape
    # center
    Xc = X - X.mean(axis=0, keepdims=True)
    denom = (Xc * Xc).sum(axis=0)
    denom = np.where(denom == 0, np.nan, denom)

    S0 = float(W_csr.sum())
    scale = n / S0 if S0 > 0 else np.nan

    WX = W_csr @ Xc  # (n, g)
    num = (Xc * WX).sum(axis=0)
    I = scale * num / denom

    p = np.full(g, np.nan, dtype=float)
    if n_perms > 0:
        abs_I = np.abs(I)
        ge = np.zeros(g, dtype=int)
        for _ in range(n_perms):
            perm = rng.permutation(n)
            Xp = Xc[perm, :]
            WXp = W_csr @ Xp
            nump = (Xp * WXp).sum(axis=0)
            Ip = scale * nump / denom
            ge += (np.abs(Ip) >= abs_I).astype(int)
        p = np.where(np.isfinite(I), (ge + 1.0) / (n_perms + 1.0), np.nan)

    return np.asarray(I, dtype=float), np.asarray(p, dtype=float)

def bh_fdr(pvals):
    """Benjamini-Hochberg adjusted p-values; non-finite inputs map to NaN.

    Only finite p-values count toward the number of tests m. The result has
    the input's shape, is monotone in the sorted p-values, and is clipped to
    [0, 1].
    """
    p = np.asarray(pvals, dtype=float)
    q = np.full_like(p, np.nan, dtype=float)
    is_finite = np.isfinite(p)
    m = np.sum(is_finite)
    if m == 0:
        return q
    finite_idx = np.where(is_finite)[0]
    order = finite_idx[np.argsort(p[finite_idx])]
    raw = p[order] * m / np.arange(1, m + 1)
    # Running minimum from the largest p-value downwards enforces monotonicity.
    monotone = np.minimum.accumulate(raw[::-1])[::-1]
    q[order] = np.clip(monotone, 0, 1)
    return q

# Build row-normalized W from k6 adjacency. Spots without neighbours keep an
# all-zero row (their row sum is replaced by 1 to avoid dividing by zero).
W = A6.tocsr().astype(np.float32)
row_sums = np.asarray(W.sum(axis=1)).ravel().astype(np.float32)
row_sums[row_sums == 0] = 1.0
W = sp.diags(1.0 / row_sums) @ W  # row-normalize

# Pick the top 500 HVGs, most variable first. With flavor="seurat_v3"
# (used for HVG selection in this notebook) scanpy reports
# highly_variable_rank (lower = more variable; plus highly_variable_nbatches
# when a batch_key is given) and variances_norm, but no dispersions_norm.
# Rank by whichever of these columns is present, and fall back to var order
# if none is.
hv_mask = adata.var["highly_variable"].to_numpy(dtype=bool)
hvg_idx = np.where(hv_mask)[0]
if "highly_variable_rank" in adata.var.columns:
    rank = adata.var["highly_variable_rank"].to_numpy(dtype=float)[hvg_idx]
    if "highly_variable_nbatches" in adata.var.columns:
        nbatches = adata.var["highly_variable_nbatches"].to_numpy(dtype=float)[hvg_idx]
        # Mirrors scanpy's own seurat_v3 ordering with a batch key: more
        # batches first, then lower median rank.
        order = np.lexsort((rank, -nbatches))
    else:
        order = np.argsort(rank, kind="stable")
    hvg_idx = hvg_idx[order]
elif "variances_norm" in adata.var.columns:
    vnorm = adata.var["variances_norm"].to_numpy(dtype=float)
    hvg_idx = hvg_idx[np.argsort(-vnorm[hvg_idx], kind="stable")]
elif "dispersions_norm" in adata.var.columns:
    disp = adata.var["dispersions_norm"].to_numpy()
    hvg_idx = hvg_idx[np.argsort(disp[hvg_idx])[::-1]]
top_n = 500
top_hvg_idx = hvg_idx[:min(top_n, len(hvg_idx))]
top_hvg_names = adata.var_names[top_hvg_idx]

# Use log1p normalized expression for Moran (adata.X has not been scaled yet).
X_moran = adata[:, top_hvg_names].X
if sp.issparse(X_moran):
    X_moran = X_moran.toarray()
X_moran = X_moran.astype(np.float32)

log(f"Moran's I: genes={X_moran.shape[1]} perms=100")
I, pvals = morans_i_sparse(X_moran, W, n_perms=100, seed=RANDOM_SEED)
qvals = bh_fdr(pvals)

# Store Moran into adata.var for those genes (NaN for all other genes).
adata.var["morans_i"] = np.nan
adata.var["morans_p"] = np.nan
adata.var["morans_q"] = np.nan
adata.var.loc[top_hvg_names, "morans_i"] = I
adata.var.loc[top_hvg_names, "morans_p"] = pvals
adata.var.loc[top_hvg_names, "morans_q"] = qvals


# Supplementary Table 2: all HVGs with their selection statistics and, for the
# genes tested, Moran's I / permutation p / BH q.
hvg_df = adata.var.loc[adata.var["highly_variable"]].copy()
hvg_df["var_id"] = hvg_df.index.astype(str)
# Include the statistics written by either HVG flavour:
#   - seurat_v3: means / variances / variances_norm / highly_variable_rank,
#     plus highly_variable_nbatches when a batch_key is given;
#   - dispersion-based flavours: dispersions / dispersions_norm.
wanted_cols = [
    "var_id", "gene_id", "gene_symbol", "feature_type",
    "means", "variances", "variances_norm",
    "highly_variable_rank", "highly_variable_nbatches",
    "dispersions", "dispersions_norm",
    "morans_i", "morans_p", "morans_q",
]
cols = [c for c in wanted_cols if c in hvg_df.columns]
supp2 = hvg_df[cols].copy()
if "highly_variable_rank" in supp2.columns:
    if "highly_variable_nbatches" in supp2.columns:
        supp2 = supp2.sort_values(["highly_variable_nbatches", "highly_variable_rank"], ascending=[False, True])
    else:
        supp2 = supp2.sort_values("highly_variable_rank", ascending=True)
elif "dispersions_norm" in supp2.columns:
    supp2 = supp2.sort_values("dispersions_norm", ascending=False)
supp2_path = TABLES_DIR / "Supplementary_Table_2.xlsx"
with pd.ExcelWriter(supp2_path, engine="openpyxl") as w:
    supp2.to_excel(w, index=False, sheet_name="HVG_Moran")
log(f"Saved {supp2_path}")


# Embeddings for Main Figure 1. Scale adata.X in place (clipped at 10), run
# PCA on the HVGs, then build the kNN graph, UMAP and Leiden clusters on the
# first n_pcs_use PCs.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, n_comps=50, use_highly_variable=True, svd_solver="arpack")

# Number of PCs: the smallest count whose cumulative explained variance
# exceeds 80% (30 if never reached), capped at 30.
cum = np.cumsum(adata.uns["pca"]["variance_ratio"])
over_80 = np.flatnonzero(cum > 0.8)
n_pcs = int(over_80[0] + 1) if over_80.size else 30
n_pcs_use = min(n_pcs, 30)
adata.uns["n_pcs_use"] = n_pcs_use

sc.pp.neighbors(adata, n_neighbors=15, n_pcs=n_pcs_use)
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=0.5, key_added="leiden_r0.5")

# Representative section for the spatial panel: first sample in sorted order.
rep_sample = min(adata.obs["sample_id"].unique())
rep = adata[adata.obs["sample_id"] == rep_sample].copy()


# Main Figure 1 (three panels):
#   A  UMAP coloured by sample, legend placed below the axes
#   B  UMAP coloured by Leiden cluster, cluster IDs drawn at each cluster's
#      median UMAP position (no legend)
#   C  total UMI per spot in pixel coordinates for `rep_sample`
fig = plt.figure(figsize=(7.2, 2.4), dpi=SAVE_DPI)
gs = fig.add_gridspec(1, 3, wspace=0.35)
umap_xy = adata.obsm["X_umap"]

ax1 = fig.add_subplot(gs[0,0])
for s in sorted(adata.obs["sample_id"].unique()):
    m = (adata.obs["sample_id"] == s).to_numpy()
    ax1.scatter(umap_xy[m,0], umap_xy[m,1], s=2, alpha=0.7, label=s, rasterized=True)
ax1.set_title("A  UMAP by sample")
ax1.set_xlabel("UMAP1")
ax1.set_ylabel("UMAP2")

# Put legend below the axis; the bottom margin is enlarged before saving.
ax1.legend(
    loc="upper center",
    bbox_to_anchor=(0.5, -0.20),
    ncol=2,
    frameon=False,
    handletextpad=0.3,
    columnspacing=0.8,
    borderaxespad=0.0,
    markerscale=2.5
)

ax2 = fig.add_subplot(gs[0,1])
# Cluster IDs sorted numerically when they are digit strings.
clusters = adata.obs["leiden_r0.5"].astype(str).values
uniq = sorted(np.unique(clusters), key=lambda x: int(x) if x.isdigit() else x)
# Stable colour per cluster from tab20.
# NOTE: colours repeat once there are more than 20 clusters.
cmap = plt.cm.tab20
colors = {c: cmap(i % 20) for i, c in enumerate(uniq)}
for c in uniq:
    m = (clusters == c)
    ax2.scatter(umap_xy[m,0], umap_xy[m,1], s=2, alpha=0.7, color=colors[c], rasterized=True)

# Annotate cluster IDs at the median UMAP position of clusters with >= 50 spots.
for c in uniq:
    m = (clusters == c)
    if m.sum() < 50:
        continue
    x = np.median(umap_xy[m,0])
    y = np.median(umap_xy[m,1])
    t = ax2.text(x, y, c, ha="center", va="center", fontsize=6, color="black")
    t.set_path_effects([pe.withStroke(linewidth=2.0, foreground="white")])

ax2.set_title("B  UMAP by Leiden")
ax2.set_xlabel("UMAP1")
ax2.set_ylabel("UMAP2")

ax3 = fig.add_subplot(gs[0,2])
coords = rep.obsm["spatial"]
tc_rep = rep.obs["total_counts"].to_numpy()
# obsm["spatial"] is stored as (row, col): plot col as x and row as y, and
# invert y so row 0 is at the top.
scat = ax3.scatter(coords[:,1], coords[:,0], c=tc_rep, s=5, alpha=0.9, rasterized=True)
ax3.invert_yaxis()
ax3.set_title("C  Spatial Total UMI\n" + rep_sample)
ax3.set_xlabel("X")
ax3.set_ylabel("Y")
cbar = fig.colorbar(scat, ax=ax3, fraction=0.05, pad=0.02)
cbar.set_label("Total UMI")

# Give bottom space for Panel A legend
fig.subplots_adjust(bottom=0.28)

main_fig1_path = FIGURES_DIR / "Main_Figure_1.png"
fig.savefig(main_fig1_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {main_fig1_path}")


# Benchmark + Supplementary Figure 2
def bench_step(fn, label):
    """Run ``fn()`` once and report its elapsed time and resident-memory change.

    Timing uses ``time.perf_counter`` (monotonic, high resolution) rather
    than ``time.time``, which can jump if the system clock is adjusted.

    Returns
    -------
    dict
        ``step`` (label), ``time_sec`` (elapsed seconds) and ``mem_delta_mb``
        (RSS after minus RSS before, in units of 1024**2 bytes; can be
        negative if memory was released).
    """
    proc = psutil.Process(os.getpid())
    mem0 = proc.memory_info().rss
    t0 = time.perf_counter()
    fn()
    elapsed = time.perf_counter() - t0
    mem1 = proc.memory_info().rss
    return {"step": label, "time_sec": float(elapsed), "mem_delta_mb": float((mem1-mem0)/(1024**2))}

# Benchmark loop: elapsed time and RSS change for building the k=6 spatial
# graph and for scale + 30-PC PCA, on random subsets of 10/25/50/100% of spots.
# Produces Supplementary Table 3 and Supplementary Figure 2. The whole
# section runs only when RUN_BENCHMARK is True.
if RUN_BENCHMARK:
    bench_rows = []
    n_total = adata.n_obs
    for frac in [0.10, 0.25, 0.50, 1.00]:
        if frac < 1.0:
            n_sub = int(n_total * frac)
            idx = np.random.choice(n_total, n_sub, replace=False)
            a_sub = adata[idx].copy()
            label = f"{int(frac*100)}%"
        else:
            a_sub = adata.copy()
            label = "100%"

        # Bind this iteration's subset as a default argument so each closure
        # is explicitly tied to it (no late binding of the loop variable).
        def step_spatial_k6(a=a_sub):
            _A, _D = build_knn_graph_blockdiag(a, 6)
            _ = _A.nnz + _D.nnz

        def step_pca30(a=a_sub):
            sc.pp.scale(a, max_value=10)
            sc.tl.pca(a, n_comps=30, svd_solver="arpack")

        gc.collect()
        r1 = bench_step(step_spatial_k6, f"{label}_spatial_k6")
        gc.collect()
        r2 = bench_step(step_pca30, f"{label}_pca")
        r1["spots"] = int(a_sub.n_obs); r2["spots"] = int(a_sub.n_obs)
        bench_rows.extend([r1, r2])
        # Drop the subset (and the closures referencing it) before the next,
        # larger copy is made.
        del a_sub, step_spatial_k6, step_pca30

    bench_df = pd.DataFrame(bench_rows)
    supp3_path = TABLES_DIR / "Supplementary_Table_3.xlsx"
    with pd.ExcelWriter(supp3_path, engine="openpyxl") as w:
        bench_df.to_excel(w, index=False, sheet_name="Benchmark")
    log(f"Saved {supp3_path}")

    fig = plt.figure(figsize=(7.2, 2.4), dpi=SAVE_DPI)
    gs = fig.add_gridspec(1, 2, wspace=0.35)
    ax1 = fig.add_subplot(gs[0,0])
    ax2 = fig.add_subplot(gs[0,1])

    t1 = bench_df[bench_df["step"].str.contains("spatial_k6")].copy()
    t2 = bench_df[bench_df["step"].str.contains("_pca")].copy()

    ax1.plot(t1["spots"], t1["time_sec"], marker="o")
    ax1.set_title("Spatial k6 runtime")
    ax1.set_xlabel("Spots")
    ax1.set_ylabel("Sec")

    ax2.plot(t2["spots"], t2["time_sec"], marker="o")
    ax2.set_title("PCA runtime")
    ax2.set_xlabel("Spots")
    ax2.set_ylabel("Sec")

    supp_fig2_path = FIGURES_DIR / "Supplementary_Figure_2.png"
    fig.savefig(supp_fig2_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {supp_fig2_path}")
else:
    log("Benchmark skipped (RUN_BENCHMARK=False)")


# Save processed artifacts: the AnnData object, a pickle bundling every
# spatial graph with the QC / batch / scale metadata, and a plain-text
# reproducibility checklist.
adata.obs.index.name = "obs_id"

adata_path = PROCESSED_DIR / "processed_spatial_adata.h5ad"
adata.write_h5ad(adata_path)

# Same keys and insertion order as the bundle's documented layout:
# six tagged graphs, the default pair, metadata from .uns, Delaunay skips.
graph_keys = [
    "spatial_k6_connectivities", "spatial_k6_distances",
    "spatial_k10_connectivities", "spatial_k10_distances",
    "spatial_delaunay_connectivities", "spatial_delaunay_distances",
]
graphs = {key: adata.obsp[key] for key in graph_keys}
graphs["default_connectivities"] = adata.obsp["spatial_connectivities"]
graphs["default_distances"] = adata.obsp["spatial_distances"]
for key in ("qc_thresholds", "batch_assessment", "spatial_scale",
            "perineural_thresholds_um", "perineural_thresholds_px_by_sample"):
    graphs[key] = adata.uns.get(key)
graphs["delaunay_skipped"] = del_skipped

graphs_path = PROCESSED_DIR / "spatial_graphs.pkl"
with open(graphs_path, "wb") as f:
    pickle.dump(graphs, f)

manuscript_outputs = (
    [FIGURES_DIR / name for name in ("Main_Figure_1.png", "Supplementary_Figure_1.png", "Supplementary_Figure_2.png")]
    + [TABLES_DIR / name for name in ("Supplementary_Table_1.xlsx", "Supplementary_Table_2.xlsx", "Supplementary_Table_3.xlsx")]
)
intermediate_outputs = [adata_path, graphs_path, REQ_PATH, LOG_PATH, CHK_PATH]
checklist = [
    "Notebook 1 reproducibility checklist",
    f"Run timestamp: {RUN_TS}",
    f"Seed: {RANDOM_SEED}",
    "Manuscript outputs:",
    *map(str, manuscript_outputs),
    "Intermediate outputs:",
    *map(str, intermediate_outputs),
]
CHECKLIST_PATH.write_text("\n".join(checklist), encoding="utf-8")

for saved_path in (adata_path, graphs_path, CHECKLIST_PATH):
    log(f"Saved {saved_path}")
log(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

print("\nDone.")
print("Figures:", FIGURES_DIR)
print("Tables :", TABLES_DIR)


In [None]:
# Notebook 2: Nerve Injury Signature Derivation, Validation & Benchmarking

import os, sys, platform, time, json, pickle, warnings, gc
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import scipy.sparse as sp

from scipy.stats import mannwhitneyu
from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, roc_auc_score

import matplotlib as mpl
import matplotlib.pyplot as plt

import anndata as ad
import scanpy as sc
import psutil

# NOTE: silences all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
sc.settings.verbosity = 2
sc.settings.n_jobs = -1

# Optional enrichment: HAS_GSEAPY records whether gseapy could be imported.
try:
    import gseapy as gp
    HAS_GSEAPY = True
except Exception:
    HAS_GSEAPY = False

# Repro + paths
RUN_TS = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
import random
random.seed(RANDOM_SEED)

SAVE_DPI = 1200
VERSION = "v1.0"

# Project root. Defaults to the original workstation path; set the
# PERINEURO_BASE_DIR environment variable to run the notebook elsewhere.
BASE_DIR = Path(os.environ.get("PERINEURO_BASE_DIR", r"D:\个人文件夹\Sanwal\Neuro"))
MANUSCRIPT_DIR = BASE_DIR / "Manuscript Data"
FIGURES_DIR = MANUSCRIPT_DIR / "Figures"
TABLES_DIR = MANUSCRIPT_DIR / "Tables"
NB1_DIR = BASE_DIR / "processed" / "notebook1"  # Notebook 1's processed-output folder
PROCESSED_DIR = BASE_DIR / "processed" / "notebook2"
for d in [FIGURES_DIR, TABLES_DIR, PROCESSED_DIR]:
    d.mkdir(parents=True, exist_ok=True)

LOG_PATH = PROCESSED_DIR / "run_log.txt"

def set_n_style():
    """Apply the compact journal-style Matplotlib defaults used by every figure.

    Sets 6-7 pt sans-serif text, thin axis/tick/line widths, inward ticks,
    SAVE_DPI as the default save resolution, and TrueType (type 42) font
    embedding for PDF/PS output.
    """
    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
        "font.size": 7,
        "axes.titlesize": 7,
        "axes.labelsize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
    }
    geometry = {
        "axes.linewidth": 0.6,
        "lines.linewidth": 0.8,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
        "xtick.direction": "in",
        "ytick.direction": "in",
    }
    saving = {
        "savefig.dpi": SAVE_DPI,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    # Disjoint key groups: applying them one after another equals one update.
    for group in (typography, geometry, saving):
        mpl.rcParams.update(group)
set_n_style()

def log(msg: str):
    """Print *msg* and append it, right-stripped and newline-terminated, to LOG_PATH."""
    print(msg)
    with LOG_PATH.open("a", encoding="utf-8") as sink:
        sink.write(f"{msg.rstrip()}\n")

# Start a fresh run log (truncating any previous one), then record the environment.
LOG_PATH.write_text(
    "PeriNeuroImmuneMap Notebook 2 log\n" + f"Start: {RUN_TS}\n",
    encoding="utf-8",
)

environment_lines = [
    f"Python: {sys.version.split()[0]}",
    f"Platform: {platform.platform()}",
    f"Seed: {RANDOM_SEED}",
    f"CPU logical: {psutil.cpu_count(True)} | CPU physical: {psutil.cpu_count(False)}",
    f"RAM total GB: {psutil.virtual_memory().total/(1024**3):.2f}",
    f"scanpy: {sc.__version__} | anndata: {ad.__version__}",
    f"gseapy: {'yes' if HAS_GSEAPY else 'no'}",
]
for line in environment_lines:
    log(line)

# Helpers
def safe_mean(X, axis=0):
    """Mean of *X* along *axis*; sparse results are flattened to a 1-D ndarray."""
    if not sp.issparse(X):
        return np.mean(X, axis=axis)
    return np.asarray(X.mean(axis=axis)).ravel()

def safe_sum(X, axis=0):
    """Sum of *X* along *axis* for dense or scipy-sparse input.

    Sparse results (2-D matrices from scipy) are flattened to a 1-D ndarray;
    dense input is summed with ``np.sum``.
    """
    if not sp.issparse(X):
        return np.sum(X, axis=axis)
    totals = X.sum(axis=axis)
    return np.asarray(totals).ravel()

def get_layer(adata_obj: ad.AnnData, layer_name: str):
    """Return the expression matrix named *layer_name* from *adata_obj*.

    ``None`` or ``"X"`` select ``adata_obj.X``. A name missing from
    ``adata_obj.layers`` also falls back to ``adata_obj.X`` without warning.
    """
    use_main_matrix = (
        layer_name is None
        or layer_name == "X"
        or layer_name not in adata_obj.layers
    )
    return adata_obj.X if use_main_matrix else adata_obj.layers[layer_name]

def find_genes_in_adata(genes: List[str], adata_var: pd.DataFrame) -> List[str]:
    """
    Return var_names (gene_id) for requested gene SYMBOLS (case-insensitive),
    falling back to matching var_names directly.

    Parameters
    ----------
    genes : list of str
        Requested gene symbols (or var_names); matching ignores case.
    adata_var : pd.DataFrame
        ``adata.var``: its index holds the var_names, and an optional
        ``gene_symbol`` column holds the symbols.

    Returns
    -------
    list
        Matched var_names in request order without duplicates; requested genes
        with no match are skipped. When several rows share the same upper-cased
        symbol (or var_name), the first row in ``adata_var`` wins.
    """
    # Build first-occurrence lookups once instead of rescanning the table for
    # every requested gene. A dict also sidesteps Index.get_loc, which returns a
    # slice/boolean mask (not a position) when the key is duplicated.
    symbol_to_var = {}
    if "gene_symbol" in adata_var.columns:
        symbols = adata_var["gene_symbol"].astype(str).str.upper()
        for sym, var_name in zip(symbols, adata_var.index):
            symbol_to_var.setdefault(sym, var_name)

    name_to_var = {}
    for var_name in adata_var.index:
        name_to_var.setdefault(str(var_name).upper(), var_name)

    found = []
    for g in genes:
        gu = str(g).upper()
        if gu in symbol_to_var:  # symbol match takes precedence
            found.append(symbol_to_var[gu])
        elif gu in name_to_var:  # fall back to var_name match
            found.append(name_to_var[gu])
    # dict.fromkeys de-duplicates while keeping first-seen order (like pd.unique)
    # without pandas' deprecated list input.
    return list(dict.fromkeys(found))

def morans_I_sparse(score: np.ndarray, W: sp.csr_matrix) -> float:
    """Global Moran's I of *score* on the unstandardized weight matrix *W*.

    I = (n / S0) * (z^T W z) / (z^T z), where z is the mean-centred score and
    S0 is the sum of all weights in W.

    Returns NaN when z^T z <= 0 (constant score) or S0 <= 0 (empty graph).
    """
    z = score.astype(np.float64)  # astype copies, so the in-place centring is safe
    z -= z.mean()
    sum_sq = float(z @ z)
    if sum_sq <= 0:
        return np.nan
    total_weight = float(W.sum())
    if total_weight <= 0:
        return np.nan
    spatial_lag = W @ z
    cross = float(z @ spatial_lag)
    n_obs = z.shape[0]
    return float((n_obs / total_weight) * (cross / sum_sq))

def score_mean_z(adata_obj: ad.AnnData, genes_varnames: List[str], layer="log1p_norm") -> np.ndarray:
    """Per-spot mean z-score across the requested genes.

    Each gene column of *layer* (via ``get_layer``) is standardised across
    all spots (1e-9 added to the std so constant genes give z = 0). The
    per-spot average of those z-scores is returned as float32. Genes absent
    from ``var_names`` are ignored; if none remain, an all-zero vector is
    returned.
    """
    X = get_layer(adata_obj, layer)
    present = [g for g in genes_varnames if g in adata_obj.var_names]
    if not present:
        return np.zeros(adata_obj.n_obs, dtype=np.float32)
    cols = [adata_obj.var_names.get_loc(g) for g in present]
    block = X[:, cols].toarray() if sp.issparse(X) else X[:, cols]
    centred = block - block.mean(axis=0)
    z_scores = centred / (block.std(axis=0) + 1e-9)
    return z_scores.mean(axis=1).astype(np.float32)

def score_weighted_mean(adata_obj: ad.AnnData, genes_varnames: List[str], weights: Dict[str, float], layer="log1p_norm") -> np.ndarray:
    """Per-spot weighted sum of expression, normalised by the total absolute weight.

    Only genes present in both ``var_names`` and *weights* are used. Weights
    may be signed; dividing by sum(|w|) (or 1.0 if that is not positive)
    keeps the score on the expression scale. Returns float32; an all-zero
    vector is returned when no gene qualifies.
    """
    X = get_layer(adata_obj, layer)
    usable = [g for g in genes_varnames if g in adata_obj.var_names and g in weights]
    if not usable:
        return np.zeros(adata_obj.n_obs, dtype=np.float32)
    cols = [adata_obj.var_names.get_loc(g) for g in usable]
    w = np.array([weights[g] for g in usable], dtype=np.float32)
    abs_total = float(np.sum(np.abs(w)))
    wsum = abs_total if abs_total > 0 else 1.0
    if sp.issparse(X):
        expr = X[:, cols].toarray().astype(np.float32)
    else:
        expr = X[:, cols].astype(np.float32)
    return (expr @ w / wsum).astype(np.float32)

# Load NB1 output
adata_path = NB1_DIR / "processed_spatial_adata.h5ad"
if not adata_path.exists():
    raise FileNotFoundError(f"Run Notebook 1 first: {adata_path}")

log(f"Loading: {adata_path}")
adata = ad.read_h5ad(adata_path)
if "sample_id" not in adata.obs.columns:
    raise ValueError("Notebook 1 must save adata.obs['sample_id'] (per-spot sample label).")
log(f"Loaded: spots={adata.n_obs:,} genes={adata.n_vars:,} samples={adata.obs['sample_id'].nunique()}")
log(f"Layers: {list(adata.layers.keys())}")

if "counts" not in adata.layers:
    raise ValueError("Notebook 1 must save adata.layers['counts'] (raw counts).")
if "log1p_norm" not in adata.layers:
    raise ValueError("Notebook 1 must save adata.layers['log1p_norm'] (log1p normalized).")

# Graph: use the first available of the generic key and the 'spatial_k6' fallback.
W = None
for graph_key in ("spatial_connectivities", "spatial_k6_connectivities"):
    if graph_key in adata.obsp:
        # copy=True: tocsr() on a matrix that is already CSR returns the same object,
        # so the in-place clean-up below would otherwise rewrite adata.obsp[graph_key].
        W = adata.obsp[graph_key].tocsr(copy=True)
        log(f"Spatial graph: adata.obsp['{graph_key}'] (nnz={W.nnz:,})")
        break
if W is None:
    raise ValueError("No spatial connectivity matrix found in adata.obsp (need kNN graph from Notebook 1).")
if W.shape != (adata.n_obs, adata.n_obs):
    raise ValueError(f"Spatial graph shape {W.shape} does not match n_obs={adata.n_obs}.")

# Merge duplicate entries and drop explicit zeros so W.sum()/W.dot() see a canonical matrix.
W.sum_duplicates()
W.eliminate_zeros()

# Parameters
# Analysis knobs for this notebook; the whole dict is echoed to the run log below.
PARAMS = {
    "nerve_region_top_percent": 10.0,  # top-% cutoff on canonical_nerve_score for nerve-enriched spots
    "min_markers_for_nerve": 2,  # minimum canonical markers with >0 counts for a nerve-enriched spot
    "de_method": "wilcoxon",  # differential-expression test name
    "de_log2fc_threshold": 0.5,  # |log2FC| cutoff for significant DE genes
    "de_padj_threshold": 0.05,  # adjusted p-value cutoff for significant DE genes
    "signature_sizes": [50, 100, 200, 500],  # candidate signature sizes benchmarked in Section 5
    "n_bootstrap": 100,  # bootstrap resamples for gene-stability analysis
    "n_permutations": 1000,  # permutations for the Moran's I null distribution
    "perm_log_every": 200,  # progress-log interval for the permutation loop
}

log("PARAMS: " + json.dumps(PARAMS, indent=2))

# Canonical markers + stress sets
# Gene SYMBOLS; resolved to var_names with find_genes_in_adata (case-insensitive).
CANONICAL_NERVE_MARKERS = ["ATF3","JUN","SOX11","GAP43","SPRR1A","NEFM","NEFL"]

# Generic cellular-stress gene sets, scored later as comparison baselines for
# the signature (sets with >= 5 genes found are used).
GENERIC_STRESS_SETS = {
    "Heat_Shock": ["HSPA1A","HSPA1B","HSPA6","HSPA8","HSP90AA1","HSP90AB1","HSPB1","HSPH1","DNAJA1","DNAJB1"],
    "Hypoxia": ["HIF1A","VEGFA","LDHA","PGK1","ENO1","SLC2A1","PDK1","NDRG1","BNIP3","CA9"],
    "ER_Stress": ["XBP1","ATF4","ATF6","DDIT3","ERN1","EIF2AK3","HSPA5","CALR","CANX"],
    "Oxidative_Stress": ["SOD1","SOD2","CAT","GPX1","PRDX1","HMOX1","NQO1","TXNRD1","GSR","GCLC"],
}

# Resolve canonical markers; at least one is required to define nerve-enriched spots.
canonical_genes = find_genes_in_adata(CANONICAL_NERVE_MARKERS, adata.var)
log(f"Canonical markers found: {len(canonical_genes)}/{len(CANONICAL_NERVE_MARKERS)}")
if len(canonical_genes) == 0:
    raise ValueError("No canonical nerve markers found. Check adata.var['gene_symbol'] and var_names mapping.")

# Map each stress set name -> list of matched var_names.
stress_genes = {}
for k, genes in GENERIC_STRESS_SETS.items():
    stress_genes[k] = find_genes_in_adata(genes, adata.var)
    log(f"Stress {k}: found {len(stress_genes[k])}/{len(genes)}")

# Section 2: canonical nerve scoring + nerve-enriched definition
X_log = adata.layers["log1p_norm"]
X_counts = adata.layers["counts"]

# canonical score = mean log1p_norm across available canonical genes
canon_idx = [adata.var_names.get_loc(g) for g in canonical_genes]
if sp.issparse(X_log):
    canon_mat = X_log[:, canon_idx].toarray().astype(np.float32)
else:
    canon_mat = np.asarray(X_log[:, canon_idx], dtype=np.float32)

adata.obs["canonical_nerve_score"] = canon_mat.mean(axis=1)

# marker hits = number of canonical genes detected (>0 counts) per spot
if sp.issparse(X_counts):
    hits = (X_counts[:, canon_idx] > 0).sum(axis=1)
    hits = np.asarray(hits).ravel().astype(int)
else:
    # np.asarray guards against np.matrix inputs, whose row sums would stay 2-D.
    hits = np.asarray(X_counts[:, canon_idx] > 0).sum(axis=1).astype(int)

adata.obs["canonical_marker_hits"] = hits

# Nerve-enriched = in the top-N% of the canonical score AND at least
# min_markers_for_nerve canonical markers detected.
thr = np.percentile(adata.obs["canonical_nerve_score"].values, 100.0 - PARAMS["nerve_region_top_percent"])
adata.obs["is_nerve_enriched"] = ((adata.obs["canonical_nerve_score"].values >= thr) &
                                  (adata.obs["canonical_marker_hits"].values >= PARAMS["min_markers_for_nerve"])).astype(int)

n_nerve = int(adata.obs["is_nerve_enriched"].sum())
log(f"Nerve-enriched spots: {n_nerve:,}/{adata.n_obs:,} ({n_nerve/adata.n_obs*100:.2f}%)")
log(f"Threshold canonical_nerve_score (top {PARAMS['nerve_region_top_percent']}%): {thr:.4f}")
log(f"Min markers hit: {PARAMS['min_markers_for_nerve']}")

# DE (nerve vs rest), ROC benchmarking and Mann-Whitney tests all need both groups.
# Fail here with an actionable message instead of deep inside scanpy/sklearn.
if n_nerve == 0 or n_nerve == adata.n_obs:
    raise ValueError(
        f"Degenerate nerve-enriched labelling ({n_nerve:,}/{adata.n_obs:,} spots); "
        "adjust PARAMS['nerve_region_top_percent'] / PARAMS['min_markers_for_nerve']."
    )

# Moran's I sanity (patched)
I_canon = morans_I_sparse(adata.obs["canonical_nerve_score"].values, W)
log(f"Canonical score Moran's I (W unstandardized): {I_canon:.4f}")

# Main Figure 2A
# Layout for one representative sample: the top row shows per-gene log1p_norm maps
# for up to 4 canonical markers; the bottom row shows the composite canonical
# score (left) and the binary nerve-enriched mask (right).
log("Making Main Figure 2A...")

# Representative sample = first sample_id in sorted order.
rep_sample = sorted(adata.obs["sample_id"].unique())[0]
rep = adata[adata.obs["sample_id"] == rep_sample].copy()
# Boolean mask over ALL spots; used to slice the full-size X_log matrix directly.
rep_mask_full = (adata.obs["sample_id"].values == rep_sample)

# choose up to 4 canonical genes to display
show_genes = canonical_genes[:4]
n_show = max(1, len(show_genes))

fig = plt.figure(figsize=(7.2, 4.9), dpi=SAVE_DPI)
# At least 3 columns so the bottom row can be split between the two summary panels.
gs = fig.add_gridspec(2, max(3, n_show), hspace=0.35, wspace=0.30)

# top row: markers
for i, gid in enumerate(show_genes):
    ax = fig.add_subplot(gs[0, i])
    j = adata.var_names.get_loc(gid)
    if sp.issparse(X_log):
        vals = X_log[rep_mask_full, j].toarray().ravel()
    else:
        vals = np.asarray(X_log[rep_mask_full, j]).ravel()
    coords = rep.obsm["spatial"]
    # NOTE(review): column 1 is drawn on x and column 0 on y (y then inverted);
    # confirm this matches the coordinate order Notebook 1 stores in obsm["spatial"].
    sca = ax.scatter(coords[:,1], coords[:,0], c=vals, s=3, cmap="viridis", alpha=0.95, rasterized=True)
    ax.invert_yaxis()
    # Prefer the gene symbol as the panel title; fall back to the var_name.
    title = adata.var.loc[gid, "gene_symbol"] if "gene_symbol" in adata.var.columns else gid
    ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])
    cbar = plt.colorbar(sca, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=5)

# bottom-left: composite score (left half of the bottom-row columns, integer division)
ax_comp = fig.add_subplot(gs[1, :max(1, (max(3, n_show)//2))])
coords = rep.obsm["spatial"]
comp = rep.obs["canonical_nerve_score"].values
sca = ax_comp.scatter(coords[:,1], coords[:,0], c=comp, s=3, cmap="Reds", alpha=0.95, rasterized=True)
ax_comp.invert_yaxis()
ax_comp.set_title("Composite canonical score")
ax_comp.set_xticks([]); ax_comp.set_yticks([])
cbar = plt.colorbar(sca, ax=ax_comp, fraction=0.046, pad=0.04)
cbar.ax.tick_params(labelsize=5)

# bottom-right: binary nerve mask (remaining bottom-row columns); the rest is drawn
# first in gray so nerve-enriched spots plot on top.
ax_bin = fig.add_subplot(gs[1, max(1, (max(3, n_show)//2)) :])
mask = rep.obs["is_nerve_enriched"].values.astype(bool)
ax_bin.scatter(coords[~mask,1], coords[~mask,0], s=2, c="lightgray", alpha=0.5, rasterized=True, label="Rest")
ax_bin.scatter(coords[mask,1], coords[mask,0], s=4, c="red", alpha=0.9, rasterized=True, label="Nerve-enriched")
ax_bin.invert_yaxis()
ax_bin.set_title(f"Nerve-enriched (n={mask.sum():,})")
ax_bin.set_xticks([]); ax_bin.set_yticks([])
ax_bin.legend(loc="upper right", frameon=False, fontsize=6, markerscale=2)

fig.suptitle("Main Figure 2A: Canonical nerve markers", y=0.98)
fig2a_path = FIGURES_DIR / "Main_Figure_2A.png"
fig.savefig(fig2a_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig2a_path}")

# Section 3: DE nerve vs rest

log("Running DE: nerve vs rest (rank on log1p_norm; log2FC from normalized counts)...")

adata_de = adata.copy()
adata_de.X = adata.layers["log1p_norm"]  # ranking layer
# rank_genes_groups looks up the categories of the groupby column, so store it as
# a categorical rather than a plain string column.
adata_de.obs["group"] = pd.Categorical(
    np.where(adata_de.obs["is_nerve_enriched"].values == 1, "nerve", "rest"),
    categories=["nerve", "rest"],
)

sc.tl.rank_genes_groups(
    adata_de,
    groupby="group",
    groups=["nerve"],
    reference="rest",
    method=PARAMS["de_method"],
    use_raw=False,
    pts=True,
    tie_correct=True,
)

de = sc.get.rank_genes_groups_df(adata_de, group="nerve").rename(columns={"names":"gene"})
log(f"DE table rows: {de.shape[0]:,}")

# map symbol
if "gene_symbol" in adata.var.columns:
    gene_map = adata.var["gene_symbol"].to_dict()
    de["gene_symbol"] = de["gene"].map(gene_map)
else:
    de["gene_symbol"] = de["gene"]

# log2FC from counts normalized to 1e4 per spot (library size); the 1e-9
# pseudocount only guards against log2(0) and division by zero.
nerve_mask = (adata.obs["is_nerve_enriched"].values == 1)
rest_mask = ~nerve_mask

Xn = adata.layers["counts"]
lib = safe_sum(Xn, axis=1)
lib[lib == 0] = 1.0
if sp.issparse(Xn):
    Xn_norm = Xn.multiply(1e4 / lib[:, None]).tocsr()
else:
    Xn_norm = (Xn / lib[:, None]) * 1e4

# mn / mr / log2fc are vectors in adata.var_names order.
mn = safe_mean(Xn_norm[nerve_mask], axis=0)
mr = safe_mean(Xn_norm[rest_mask], axis=0)
log2fc = np.log2((mn + 1e-9) / (mr + 1e-9))

# rank_genes_groups_df lists genes in rank (score) order, NOT var_names order, so
# the var-ordered vectors must be re-indexed to each row's gene before assignment.
gene_pos = adata.var_names.get_indexer(de["gene"].astype(str))
if (gene_pos < 0).any():
    raise ValueError("DE table contains genes that are not in adata.var_names.")
de["log2FC"] = log2fc[gene_pos]
de["mean_nerve"] = mn[gene_pos]
de["mean_rest"] = mr[gene_pos]
de["pval"] = de["pvals"]
de["pval_adj"] = de["pvals_adj"]

sig_mask = (de["pval_adj"] < PARAMS["de_padj_threshold"]) & (np.abs(de["log2FC"]) > PARAMS["de_log2fc_threshold"])
de_sig = de.loc[sig_mask].copy()

log(f"Significant DE: {de_sig.shape[0]:,} (padj<{PARAMS['de_padj_threshold']}, |log2FC|>{PARAMS['de_log2fc_threshold']})")

# Supplementary Table 4
supp4_path = TABLES_DIR / "Supplementary_Table_4.xlsx"
with pd.ExcelWriter(supp4_path, engine="openpyxl") as w:
    de.sort_values("scores", ascending=False).to_excel(w, index=False, sheet_name="DE_Nerve_vs_Rest")
log(f"Saved {supp4_path}")

# Main Figure 2B Volcano
# All genes in gray, significant genes (sig_mask) in red; dashed lines mark the
# adjusted-p and |log2FC| thresholds from PARAMS.
log("Making Main Figure 2B...")
fig, ax = plt.subplots(1, 1, figsize=(3.6, 3.6), dpi=SAVE_DPI)

# 1e-300 floor keeps -log10 finite when an adjusted p-value is exactly 0.
y = -np.log10(de["pval_adj"].values + 1e-300)
ax.scatter(de["log2FC"].values, y, s=1, alpha=0.25, c="gray", rasterized=True)

ax.scatter(de.loc[sig_mask, "log2FC"].values,
           -np.log10(de.loc[sig_mask, "pval_adj"].values + 1e-300),
           s=2, alpha=0.7, c="red", rasterized=True)

ax.axhline(y=-np.log10(PARAMS["de_padj_threshold"]), linestyle="--", linewidth=0.8, alpha=0.6)
ax.axvline(x=PARAMS["de_log2fc_threshold"], linestyle="--", linewidth=0.8, alpha=0.6)
ax.axvline(x=-PARAMS["de_log2fc_threshold"], linestyle="--", linewidth=0.8, alpha=0.6)

# annotate a few top by score (top 8 by scanpy score; labels truncated to 12 chars)
top = de.sort_values("scores", ascending=False).head(8)
for _, r in top.iterrows():
    ax.text(r["log2FC"], -np.log10(r["pval_adj"] + 1e-300), str(r["gene_symbol"])[:12], fontsize=5, alpha=0.85)

ax.set_xlabel("log2 fold-change")
ax.set_ylabel("-log10 adj. p-value")
ax.set_title("Main Figure 2B: DE volcano (nerve vs rest)")
fig2b_path = FIGURES_DIR / "Main_Figure_2B.png"
fig.savefig(fig2b_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig2b_path}")

# Section 4: Pathway enrichment (optional). Supplementary Table 5 and Supplementary
# Figure 3 are always written; when enrichment is skipped or fails they hold a note.
supp5_path = TABLES_DIR / "Supplementary_Table_5.xlsx"
supp_fig3_path = FIGURES_DIR / "Supplementary_Figure_3.png"


def _write_enrichment_placeholder(note: str, fig_text: str) -> None:
    """Write a one-row ``note`` table as Supp Table 5 and a text-only Supp Figure 3."""
    with pd.ExcelWriter(supp5_path, engine="openpyxl") as w:
        pd.DataFrame({"note": [note]}).to_excel(w, index=False, sheet_name="Enrichment")
    fig, ax = plt.subplots(1, 1, figsize=(5.2, 3.2), dpi=SAVE_DPI)
    try:
        ax.text(0.02, 0.5, fig_text, fontsize=7, transform=ax.transAxes)
        ax.set_axis_off()
        fig.savefig(supp_fig3_path, dpi=SAVE_DPI, bbox_inches="tight")
    finally:
        plt.close(fig)
    log(f"Saved {supp5_path}")
    log(f"Saved {supp_fig3_path}")


if HAS_GSEAPY and de_sig.shape[0] >= 10:
    log("Pathway enrichment (gseapy)...")
    # use upregulated symbols if available
    up = de_sig.loc[de_sig["log2FC"] > 0, "gene_symbol"].dropna().astype(str).unique().tolist()
    up = [g for g in up if g and g.upper() != "NAN"]
    if len(up) >= 10:
        # Only the remote Enrichr query and result parsing are best-effort (network and
        # service errors are expected). Local write/plot errors surface instead of
        # silently overwriting a valid results table with a placeholder.
        try:
            enr = gp.enrichr(
                gene_list=up,
                gene_sets=["GO_Biological_Process_2021", "KEGG_2021_Human"],
                organism="Human",
                outdir=None
            )
            enr_df = enr.results.copy()
            enr_sig = enr_df[enr_df["Adjusted P-value"] < 0.05].copy()
        except Exception as e:
            log(f"Enrichment failed: {type(e).__name__}: {e}")
            _write_enrichment_placeholder(
                f"Enrichment failed: {type(e).__name__}: {e}",
                "Enrichment failed (see Supplementary Table 5).",
            )
        else:
            with pd.ExcelWriter(supp5_path, engine="openpyxl") as w:
                enr_sig.to_excel(w, index=False, sheet_name="Enrichment")
            log(f"Saved {supp5_path}")

            fig, ax = plt.subplots(1, 1, figsize=(5.2, 3.8), dpi=SAVE_DPI)
            try:
                if enr_sig.shape[0] > 0:
                    top_terms = enr_sig.nsmallest(20, "Adjusted P-value").copy()
                    ax.barh(range(len(top_terms)), -np.log10(top_terms["Adjusted P-value"].values + 1e-300))
                    ax.set_yticks(range(len(top_terms)))
                    ax.set_yticklabels(top_terms["Term"].values, fontsize=6)
                    ax.invert_yaxis()
                    ax.set_xlabel("-log10 adj. p-value")
                    ax.set_title("Supplementary Figure 3: Enriched pathways (up genes)")
                else:
                    ax.text(0.02, 0.5, "No enriched terms at FDR<0.05", fontsize=7, transform=ax.transAxes)
                    ax.set_axis_off()
                fig.savefig(supp_fig3_path, dpi=SAVE_DPI, bbox_inches="tight")
            finally:
                plt.close(fig)  # always release the figure, even if plotting fails
            log(f"Saved {supp_fig3_path}")
    else:
        log("Skipping enrichment: too few upregulated genes after filtering.")
        _write_enrichment_placeholder(
            "Skipped: too few upregulated genes for enrichment.",
            "Enrichment skipped (too few genes).",
        )
else:
    log("Skipping enrichment (gseapy unavailable or insufficient DE genes).")
    _write_enrichment_placeholder(
        f"Skipped: gseapy={HAS_GSEAPY}, de_sig={de_sig.shape[0]}",
        "Enrichment skipped.",
    )

# Section 5: Signature construction + benchmarking sizes/methods
log("Building candidate signature list...")

# rank metric: scanpy score * |log2FC|. Down-regulated genes have negative
# Wilcoxon scores, so the descending sort puts up-regulated genes first.
de_sig["rank_metric"] = de_sig["scores"].values * np.abs(de_sig["log2FC"].values)
de_sig = de_sig.sort_values("rank_metric", ascending=False)

# optional Moran filter (only if present and not too destructive)
use_moran = ("morans_i" in adata.var.columns)
if use_moran:
    m = adata.var["morans_i"]
    mor_map = m.to_dict()
    de_sig["morans_i"] = de_sig["gene"].map(mor_map)
    before = de_sig.shape[0]
    de_sig_m = de_sig[np.isfinite(de_sig["morans_i"].values) & (de_sig["morans_i"].values > 0)].copy()
    if de_sig_m.shape[0] >= min(PARAMS["signature_sizes"]):
        de_sig = de_sig_m
        log(f"Moran filter: {before}->{de_sig.shape[0]} genes with morans_i>0")
    else:
        log(f"Moran filter would be too aggressive ({before}->{de_sig_m.shape[0]}). Skipping Moran filter.")

# An empty candidate list would silently yield all-zero scores and an empty signature.
if de_sig.shape[0] == 0:
    raise ValueError(
        "No significant DE genes passed the thresholds; cannot build a signature. "
        "Relax PARAMS['de_padj_threshold'] / PARAMS['de_log2fc_threshold']."
    )

# benchmark only mean_z and weighted (fast + stable)
log("Benchmarking signature sizes / scoring methods...")

bench_rows = []
y_true = adata.obs["is_nerve_enriched"].values.astype(int)
# ROC AUC is only defined when both classes are present (loop-invariant).
has_both_classes = len(np.unique(y_true)) == 2

for n_genes in PARAMS["signature_sizes"]:
    cand = de_sig.head(n_genes)
    genes_n = cand["gene"].tolist()
    weights_n = cand.set_index("gene")["log2FC"].to_dict()
    if len(genes_n) < n_genes:
        log(f"  n={n_genes}: only {len(genes_n)} candidate genes available; size is capped")

    s_mean = score_mean_z(adata, genes_n, layer="log1p_norm")
    s_w = score_weighted_mean(adata, genes_n, weights_n, layer="log1p_norm")

    auc_mean = roc_auc_score(y_true, s_mean) if has_both_classes else np.nan
    auc_w = roc_auc_score(y_true, s_w) if has_both_classes else np.nan

    bench_rows.append({"n_genes": n_genes, "n_genes_used": len(genes_n), "method": "mean_z", "auc": float(auc_mean)})
    bench_rows.append({"n_genes": n_genes, "n_genes_used": len(genes_n), "method": "weighted", "auc": float(auc_w)})

    log(f"  n={n_genes}: mean_z AUC={auc_mean:.3f} | weighted AUC={auc_w:.3f}")

bench_df = pd.DataFrame(bench_rows)
best = bench_df.loc[bench_df["auc"].idxmax()]
FINAL_METHOD = str(best["method"])

final_cand = de_sig.head(int(best["n_genes"])).copy()
final_genes = final_cand["gene"].tolist()
final_weights = final_cand.set_index("gene")["log2FC"].to_dict()
# FINAL_N is the number of genes actually in the signature. A requested size larger
# than the candidate list would otherwise mislabel the signature and inflate the
# bootstrap top-N used to assess gene stability.
FINAL_N = len(final_genes)
if FINAL_N < int(best["n_genes"]):
    log(f"Requested n={int(best['n_genes'])} exceeds available candidates; using n={FINAL_N}")
log(f"Selected: n={FINAL_N}, method={FINAL_METHOD}, AUC={best['auc']:.3f}")

if FINAL_METHOD == "mean_z":
    adata.obs["nerve_injury_score"] = score_mean_z(adata, final_genes, layer="log1p_norm")
else:
    adata.obs["nerve_injury_score"] = score_weighted_mean(adata, final_genes, final_weights, layer="log1p_norm")

# separation (one-sided: nerve-enriched spots expected to score higher)
s1 = adata.obs.loc[adata.obs["is_nerve_enriched"] == 1, "nerve_injury_score"].values
s0 = adata.obs.loc[adata.obs["is_nerve_enriched"] == 0, "nerve_injury_score"].values
u, p = mannwhitneyu(s1, s0, alternative="greater")
log(f"Score separation (Mann-Whitney): p={p:.2e} | nerve mean={s1.mean():.3f} rest mean={s0.mean():.3f}")

# Export Supplementary Table 7 (benchmarking)
supp7_path = TABLES_DIR / "Supplementary_Table_7.xlsx"
with pd.ExcelWriter(supp7_path, engine="openpyxl") as w:
    bench_df.sort_values(["auc","n_genes"], ascending=[False, True]).to_excel(w, index=False, sheet_name="SigSize_Benchmark")
log(f"Saved {supp7_path}")

# Export signature v1.0
sig = final_cand.copy()
sig["signature_version"] = VERSION
sig["derivation_date"] = RUN_TS
sig["scoring_method"] = FINAL_METHOD
sig_path = PROCESSED_DIR / f"nerve_injury_signature_{VERSION}.csv"
sig.to_csv(sig_path, index=False)
log(f"Saved signature: {sig_path}")

# Export per-spot scores
scores_df = adata.obs[["sample_id","canonical_nerve_score","canonical_marker_hits","is_nerve_enriched","nerve_injury_score"]].copy()
scores_path = PROCESSED_DIR / "nerve_injury_scores_per_spot.csv"
scores_df.to_csv(scores_path, index=True)
log(f"Saved scores: {scores_path}")

# Main Figure 2D (spatial validation + permutation histogram)
# Observed global Moran's I of the final signature score on the spatial graph W.
I_sig = morans_I_sparse(adata.obs["nerve_injury_score"].values, W)
log(f"Signature score Moran's I (W unstandardized): {I_sig:.4f}")

# Null distribution: shuffle the scores across spots while keeping W fixed.
# Draws come from NumPy's global RNG, seeded at the top of the notebook.
log(f"Permutation tests: n={PARAMS['n_permutations']} (shuffle score on fixed graph)...")
perm = []
sv = adata.obs["nerve_injury_score"].values.astype(np.float32)
for i in range(PARAMS["n_permutations"]):
    perm_idx = np.random.permutation(sv.shape[0])
    perm.append(morans_I_sparse(sv[perm_idx], W))
    if (i + 1) % PARAMS["perm_log_every"] == 0:
        log(f"  perm {i+1}/{PARAMS['n_permutations']}")

# Drop non-finite permutation statistics, then compute a two-sided empirical
# p-value with the +1 correction: (#{|I_perm| >= |I_obs|} + 1) / (n_perm + 1).
perm = np.array([x for x in perm if np.isfinite(x)], dtype=np.float32)
perm_p = float((np.sum(np.abs(perm) >= np.abs(I_sig)) + 1) / (len(perm) + 1))
log(f"Permutation p-value: {perm_p:.4f}")

# figure
fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.2), dpi=SAVE_DPI)

# spatial map (rep sample) -- same representative sample as Figure 2A
rep_mask_full = (adata.obs["sample_id"].values == rep_sample)
coords = adata[rep_mask_full].obsm["spatial"]
rep_scores = adata.obs.loc[rep_mask_full, "nerve_injury_score"].values
scat = axes[0].scatter(coords[:,1], coords[:,0], c=rep_scores, s=3, cmap="coolwarm", alpha=0.95, rasterized=True)
axes[0].invert_yaxis()
axes[0].set_title(f"A  Nerve injury score\n{rep_sample}")
axes[0].set_xticks([]); axes[0].set_yticks([])
cbar = plt.colorbar(scat, ax=axes[0], fraction=0.046, pad=0.04)
cbar.ax.tick_params(labelsize=6)
cbar.set_label("Score")

# permutation hist, with the observed statistic as a dashed red line
axes[1].hist(perm, bins=50, alpha=0.75, edgecolor="black", linewidth=0.4)
axes[1].axvline(x=I_sig, linestyle="--", linewidth=1.0, color="red", label=f"Real I={I_sig:.3f}")
axes[1].set_title(f"B  Permutation test\np={perm_p:.4f}")
axes[1].set_xlabel("Moran's I")
axes[1].set_ylabel("Frequency")
axes[1].legend(frameon=False, fontsize=6)

fig.suptitle("Main Figure 2D: Spatial validation", y=0.98)
fig2d_path = FIGURES_DIR / "Main_Figure_2D.png"
fig.savefig(fig2d_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig2d_path}")

# Section 6: Bootstrap stability (PATCHED: CSR conversion)
# For each resample of spots (with replacement): re-derive nerve labels with the
# resampled score threshold, recompute library-normalised log2FC for every gene,
# take the top FINAL_N genes by log2FC, and count how often each final signature
# gene appears. Resamples with < 50 spots in either group are skipped.
# NOTE(review): the bootstrap ranks ALL genes by raw log2FC (1e-9 pseudocount),
# whereas the signature was ranked by scores*|log2FC| among significant genes;
# genes near-absent from the rest group get very large log2FC here. Confirm the
# resulting frequencies are meant to be compared to the 80% core-gene cutoff.
log(f"Bootstrap stability: n_boot={PARAMS['n_bootstrap']}...")

boot_presence = {g: 0 for g in final_genes}
boot_kept = 0

counts = adata.layers["counts"]

for b in range(PARAMS["n_bootstrap"]):
    idx = np.random.choice(adata.n_obs, adata.n_obs, replace=True)

    # Re-label the resampled spots with the same top-% + min-markers rule as Section 2.
    score_b = adata.obs["canonical_nerve_score"].values[idx]
    thr_b = np.percentile(score_b, 100.0 - PARAMS["nerve_region_top_percent"])
    hits_b = adata.obs["canonical_marker_hits"].values[idx]
    yb = (score_b >= thr_b) & (hits_b >= PARAMS["min_markers_for_nerve"])

    if yb.sum() < 50 or (~yb).sum() < 50:
        continue

    Xb = counts[idx]
    libb = np.asarray(Xb.sum(axis=1)).ravel()
    libb[libb == 0] = 1.0  # avoid division by zero for empty spots

    if sp.issparse(Xb):
        # .multiply() may return a COO matrix, which does not support the boolean
        # row indexing below, hence the explicit CSR conversion.
        Xb_norm = Xb.multiply(1e4 / libb[:, None]).tocsr()  # <- critical
    else:
        Xb_norm = (Xb / libb[:, None]) * 1e4

    mn_b = safe_mean(Xb_norm[yb], axis=0)
    mr_b = safe_mean(Xb_norm[~yb], axis=0)
    l2 = np.log2((mn_b + 1e-9) / (mr_b + 1e-9))

    # Positions of the FINAL_N largest log2FC values, mapped back to var_names.
    top_idx = np.argsort(l2)[::-1][:FINAL_N]
    top_genes = set(adata.var_names[top_idx].tolist())
    boot_kept += 1

    for g in final_genes:
        if g in top_genes:
            boot_presence[g] += 1

    if (b + 1) % 20 == 0:
        log(f"  boot {b+1}/{PARAMS['n_bootstrap']} (kept={boot_kept})")

# Frequencies are relative to the resamples actually kept (max(1, ...) avoids /0).
boot_df = pd.DataFrame({
    "gene": final_genes,
    "bootstrap_count": [boot_presence[g] for g in final_genes],
})
boot_df["bootstrap_frequency_pct"] = 100.0 * boot_df["bootstrap_count"] / max(1, boot_kept)
boot_df = boot_df.sort_values("bootstrap_frequency_pct", ascending=False).reset_index(drop=True)

# Core genes: present in the bootstrap top-N in at least 80% of kept resamples.
core_genes = boot_df.loc[boot_df["bootstrap_frequency_pct"] >= 80.0, "gene"].tolist()
log(f"Bootstrap kept: {boot_kept}/{PARAMS['n_bootstrap']}")
log(f"Core genes (>=80%): {len(core_genes)}")

# Supplementary Table 6
supp6_path = TABLES_DIR / "Supplementary_Table_6.xlsx"
with pd.ExcelWriter(supp6_path, engine="openpyxl") as w:
    boot_df.to_excel(w, index=False, sheet_name="Bootstrap")
log(f"Saved {supp6_path}")

# Supplementary Figure 4: (A) histogram of frequencies with the 80% cutoff,
# (B) the 30 most frequent genes.
fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.0), dpi=SAVE_DPI)

axes[0].hist(boot_df["bootstrap_frequency_pct"].values, bins=15, edgecolor="black", linewidth=0.4)
axes[0].axvline(x=80, linestyle="--", linewidth=0.9, color="red")
axes[0].set_title("A  Bootstrap frequency")
axes[0].set_xlabel("Frequency (%)")
axes[0].set_ylabel("Genes")

top30 = boot_df.head(30).copy()
axes[1].barh(range(len(top30)), top30["bootstrap_frequency_pct"].values)
axes[1].set_yticks(range(len(top30)))
axes[1].set_yticklabels(top30["gene"].values, fontsize=6)
axes[1].invert_yaxis()
axes[1].set_title("B  Top genes")
axes[1].set_xlabel("Frequency (%)")

fig.suptitle("Supplementary Figure 4: Bootstrap stability", y=0.98)
supp_fig4_path = FIGURES_DIR / "Supplementary_Figure_4.png"
fig.savefig(supp_fig4_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {supp_fig4_path}")

# Section 8: Benchmark vs alternatives + Main Figure 2E
# Compare the final signature score against simpler per-spot scores at
# recovering the is_nerve_enriched labels (y_true).
# NOTE(review): y_true is the same labelling used to derive the DE signature, so
# these metrics are in-sample and favour the derived signature.
alternatives = {}

# naive: canonical mean_z (all canonical genes available)
alternatives["Naive_canonical_meanZ"] = score_mean_z(adata, canonical_genes, layer="log1p_norm")

# stress sets (only those with at least 5 matched genes)
for name, genes in stress_genes.items():
    if len(genes) >= 5:
        alternatives[f"Stress_{name}_meanZ"] = score_mean_z(adata, genes, layer="log1p_norm")

# kitchen sink (top 500 by Wilcoxon score)
kitchen = de.sort_values("scores", ascending=False).head(500)["gene"].tolist()
alternatives["KitchenSink_top500_meanZ"] = score_mean_z(adata, kitchen, layer="log1p_norm")

# our final
alternatives[f"OurSignature_{FINAL_N}_{FINAL_METHOD}"] = adata.obs["nerve_injury_score"].values.astype(np.float32)

bench2_rows = []
for name, s in alternatives.items():
    fpr, tpr, _ = roc_curve(y_true, s)
    roc_auc = auc(fpr, tpr)
    # NOTE(review): trapezoidal area under the PR curve; scikit-learn recommends
    # average_precision_score because trapezoidal PR-AUC can be optimistic.
    prec, rec, _ = precision_recall_curve(y_true, s)
    pr_auc = auc(rec, prec)
    # Binarise at the 90th percentile of the score (top ~10%; ties can push it above 10%).
    thr90 = np.percentile(s, 90)  # top 10%
    yhat = (s >= thr90).astype(int)
    f1 = f1_score(y_true, yhat)

    # Cohen's d (pooled SD with ddof=1; 1e-12 keeps the denominator non-zero)
    s_pos = s[y_true == 1]
    s_neg = s[y_true == 0]
    denom = np.sqrt(((s_pos.size - 1) * np.var(s_pos, ddof=1) + (s_neg.size - 1) * np.var(s_neg, ddof=1)) / (s_pos.size + s_neg.size - 2) + 1e-12)
    d = float((np.mean(s_pos) - np.mean(s_neg)) / denom)

    bench2_rows.append({
        "method": name,
        "auc_roc": float(roc_auc),
        "auc_pr": float(pr_auc),
        "f1_top10pct": float(f1),
        "cohens_d": float(d),
    })
    log(f"  {name}: AUC={roc_auc:.3f} | F1={f1:.3f} | d={d:.3f}")

bench2 = pd.DataFrame(bench2_rows).sort_values("auc_roc", ascending=False).reset_index(drop=True)

# Main Figure 2E
log("Making Main Figure 2E...")
fig, axes = plt.subplots(1, 2, figsize=(7.2, 3.3), dpi=SAVE_DPI)

# ROC curves (our signature drawn thicker, in red)
for name, s in alternatives.items():
    fpr, tpr, _ = roc_curve(y_true, s)
    roc_auc = auc(fpr, tpr)
    if name.startswith("OurSignature"):
        axes[0].plot(fpr, tpr, linewidth=1.5, color="red", label=f"{name} (AUC={roc_auc:.3f})")
    else:
        axes[0].plot(fpr, tpr, linewidth=0.9, alpha=0.75, label=f"{name} (AUC={roc_auc:.3f})")
axes[0].plot([0,1],[0,1], "k--", linewidth=0.8, alpha=0.5)
axes[0].set_xlabel("False positive rate")
axes[0].set_ylabel("True positive rate")
axes[0].set_title("A  ROC")
axes[0].legend(frameon=False, fontsize=5, loc="lower right")

# bar AUC (rows in bench2 order, best first after invert_yaxis)
axes[1].barh(range(bench2.shape[0]), bench2["auc_roc"].values)
axes[1].set_yticks(range(bench2.shape[0]))
axes[1].set_yticklabels(bench2["method"].values, fontsize=6)
axes[1].invert_yaxis()
axes[1].set_xlabel("AUC (ROC)")
axes[1].set_title("B  Comparison")

# highlight our bar (patches are in the same order as bench2 rows)
for i, m in enumerate(bench2["method"].values):
    if m.startswith("OurSignature"):
        axes[1].patches[i].set_color("red")

fig.suptitle("Main Figure 2E: Benchmarking", y=0.98)
fig2e_path = FIGURES_DIR / "Main_Figure_2E.png"
fig.savefig(fig2e_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig2e_path}")

# Supplementary Table 7 already used for sig-size benchmarking.
# Export method benchmarking as Supplementary Table 7B (same file, second sheet) to avoid extra tables.
with pd.ExcelWriter(supp7_path, engine="openpyxl", mode="a", if_sheet_exists="replace") as w:
    bench2.to_excel(w, index=False, sheet_name="Method_Benchmark")
log(f"Updated {supp7_path} with Method_Benchmark sheet")

# Section 9: Published CINI comparison + Supplementary Figure 5
# Published signature as gene SYMBOLS; resolved to var_names case-insensitively.
PUBLISHED_CINI = [
    "TAGAP","KCNJ8","COL1A1","PECAM1","TMEM119","ATF3","JUN","KLF6","NOCT","LMO7",
    "CSF1","ENTPD1","UCHL1","PINK1","BHLHE41","ITGAM","CHL1","SNCA","SCPEP1","VEGFA"
]
pub_genes = find_genes_in_adata(PUBLISHED_CINI, adata.var)
log(f"Published CINI genes found: {len(pub_genes)}/{len(PUBLISHED_CINI)}")

supp_fig5_path = FIGURES_DIR / "Supplementary_Figure_5.png"

# Compare only when at least 5 published genes are present; otherwise write a placeholder figure.
if len(pub_genes) >= 5:
    pub_score = score_mean_z(adata, pub_genes, layer="log1p_norm")
    our_score = adata.obs["nerve_injury_score"].values.astype(np.float32)

    fpr_p, tpr_p, _ = roc_curve(y_true, pub_score)
    fpr_o, tpr_o, _ = roc_curve(y_true, our_score)
    auc_p = auc(fpr_p, tpr_p)
    auc_o = auc(fpr_o, tpr_o)

    # overlap in SYMBOL space (upper-cased). n(Published) below counts every listed
    # published symbol, not only those found in this AnnData.
    our_syms = set(adata.var.loc[final_genes, "gene_symbol"].astype(str).str.upper()) if "gene_symbol" in adata.var.columns else set([g.upper() for g in final_genes])
    pub_syms = set([g.upper() for g in PUBLISHED_CINI])
    ov = sorted(list(our_syms.intersection(pub_syms)))

    fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.1), dpi=SAVE_DPI)

    axes[0].plot(fpr_o, tpr_o, color="red", linewidth=1.5, label=f"Our (AUC={auc_o:.3f})")
    axes[0].plot(fpr_p, tpr_p, color="black", linewidth=0.9, alpha=0.85, label=f"Published CINI (AUC={auc_p:.3f})")
    axes[0].plot([0,1],[0,1], "k--", linewidth=0.8, alpha=0.5)
    axes[0].set_title("A  ROC comparison")
    axes[0].set_xlabel("FPR")
    axes[0].set_ylabel("TPR")
    axes[0].legend(frameon=False, fontsize=6, loc="lower right")

    # Panel B is a text summary of the overlap.
    axes[1].axis("off")
    txt = "Overlap (symbol):\n" + (", ".join(ov) if len(ov) > 0 else "None")
    txt += f"\n\nn(Our)={len(our_syms)}\nn(Published)={len(pub_syms)}\n|Overlap|={len(ov)}"
    axes[1].text(0.02, 0.98, txt, va="top", ha="left", fontsize=7)

    fig.suptitle("Supplementary Figure 5: Published signature comparison", y=0.98)
    fig.savefig(supp_fig5_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {supp_fig5_path}")
else:
    fig, ax = plt.subplots(1, 1, figsize=(6.0, 2.5), dpi=SAVE_DPI)
    ax.text(0.02, 0.6, "Published CINI comparison skipped\n(too few genes found in this AnnData).", fontsize=7, transform=ax.transAxes)
    ax.set_axis_off()
    fig.savefig(supp_fig5_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {supp_fig5_path}")

# Final exports
# Add bootstrap freq to signature and write FINAL
# (same rows as the v1.0 signature CSV plus a bootstrap_frequency_pct column)
boot_map = boot_df.set_index("gene")["bootstrap_frequency_pct"].to_dict()
sig_final = sig.copy()
sig_final["bootstrap_frequency_pct"] = sig_final["gene"].map(boot_map)
sig_final_path = PROCESSED_DIR / f"nerve_injury_signature_{VERSION}_FINAL.csv"
sig_final.to_csv(sig_final_path, index=False)
log(f"Saved FINAL signature: {sig_final_path}")

# Supplementary Table 6 already.
# Save annotated AnnData (adata.obs now carries the canonical score, marker hits,
# nerve-enriched flag and nerve_injury_score added in this notebook).
adata_out = PROCESSED_DIR / "adata_with_signature_scores.h5ad"
adata.write_h5ad(adata_out)
log(f"Saved annotated adata: {adata_out}")

# README
# Short provenance note next to the processed outputs; the run-specific values
# (FINAL_N, FINAL_METHOD, Moran's I, permutation p, bootstrap count) are filled in.
readme_path = PROCESSED_DIR / "README_signature.md"
readme_path.write_text(
f"""# Nerve injury signature ({VERSION})

Derived from GSE289745 Visium (cSCC) using canonical nerve marker enrichment to define nerve-enriched spots,
DE ranking (Wilcoxon in log space), and candidate-gene selection with effect-size filtering.

Key outputs:
- {sig_final_path.name}
- nerve_injury_scores_per_spot.csv
- Main_Figure_2A/B/D/E.png
- Supplementary_Table_4/5/6/7.xlsx
- Supplementary_Figure_3/4/5.png

Scoring:
- FINAL_N={FINAL_N}
- FINAL_METHOD={FINAL_METHOD}
- Moran's I (signature)={I_sig:.4f} | perm p={perm_p:.4f}
- Bootstrap kept={boot_kept}

""",
    encoding="utf-8"
)
log(f"Saved README: {readme_path}")

# Run summary for the log and console.
log("Notebook 2 COMPLETE")
log(f"FINAL_N={FINAL_N} | FINAL_METHOD={FINAL_METHOD} | AUC={best['auc']:.3f}")
log(f"Signature Moran's I={I_sig:.4f} | perm p={perm_p:.4f}")
log(f"Bootstrap kept={boot_kept} | core>=80%: {len(core_genes)}")
log(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

print("\nDone.")
print("Figures:", FIGURES_DIR)
print("Tables :", TABLES_DIR)
print("Processed:", PROCESSED_DIR)
