# Nerve Injury Signature Derivation, Validation & Benchmarking

**PeriNeuroImmuneMap Analysis Pipeline**

---

**Correspondence:** Wei Qin (<wqin@sjtu.edu.cn>)


In [None]:
from pathlib import Path

# Project-relative default locations, created up front.
# NOTE(review): BASE_DIR is reassigned by the next cell, and the code visible
# in this chunk writes to FIGURES_DIR / TABLES_DIR rather than to these
# ./outputs directories -- confirm they are still needed.
BASE_DIR = Path('.')
DATA_DIR = BASE_DIR / 'data'
OUTPUT_DIR = BASE_DIR / 'outputs'
FIGURE_DIR = OUTPUT_DIR / 'figures'
TABLE_DIR = OUTPUT_DIR / 'tables'

for _dir in (DATA_DIR, OUTPUT_DIR, FIGURE_DIR, TABLE_DIR):
    _dir.mkdir(parents=True, exist_ok=True)

print('✓ Paths configured')


In [None]:
# NOTEBOOK 3: Perineural Immune Landscape, Response Validation & External Validation

# IMPORTS
import os, sys, platform, time, json, pickle, warnings, gc
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union
from collections import defaultdict

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.stats import mannwhitneyu, spearmanr, pearsonr, chi2_contingency, fisher_exact
from scipy.spatial.distance import cdist, pdist, squareform
from sklearn.metrics import roc_curve, auc, precision_recall_curve, roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap

import anndata as ad
import scanpy as sc
import psutil

warnings.filterwarnings("ignore")
sc.settings.verbosity = 2
sc.settings.n_jobs = -1

# Optional dependencies are probed best-effort. ``except Exception`` (rather
# than a bare ``except:``) keeps that best-effort behaviour -- a package can
# fail at import time with errors other than ImportError, e.g. a missing native
# library -- without also swallowing KeyboardInterrupt / SystemExit.

# HEST package (optional, for official HEST-1k loading)
try:
    import hest
    HAS_HEST_PACKAGE = True
    print("HEST package available")
except Exception:
    HAS_HEST_PACKAGE = False
    print("HEST package not available (will use direct h5ad loading)")

# Survival analysis
try:
    from lifelines import KaplanMeierFitter, CoxPHFitter
    from lifelines.statistics import multivariate_logrank_test
    HAS_LIFELINES = True
except Exception:
    HAS_LIFELINES = False
    print("Warning: lifelines not available, survival analysis will be skipped")

# SETUP - Reproducibility + Paths
RUN_TS = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
import random
random.seed(RANDOM_SEED)

SAVE_DPI = 1200
VERSION = "v1.0"

# Project root. Defaults to the original workstation path; set the
# PNIM_BASE_DIR environment variable (non-empty) to run the notebook on another
# machine without editing the code.
BASE_DIR = Path(os.environ.get("PNIM_BASE_DIR") or r"D:\个人文件夹\Sanwal\Neuro")
MANUSCRIPT_DIR = BASE_DIR / "Manuscript Data"
FIGURES_DIR = MANUSCRIPT_DIR / "Figures"
TABLES_DIR = MANUSCRIPT_DIR / "Tables"
NB2_DIR = BASE_DIR / "processed" / "notebook2"      # inputs produced by Notebook 2
PROCESSED_DIR = BASE_DIR / "processed" / "notebook3"
RAW_DATA_DIR = BASE_DIR / "Raw data"
SUPP_DIR = RAW_DATA_DIR / "GSE289745_Nature_Supplements"

# Only output directories are created; input directories must already exist.
for d in [FIGURES_DIR, TABLES_DIR, PROCESSED_DIR]:
    d.mkdir(parents=True, exist_ok=True)

LOG_PATH = PROCESSED_DIR / "run_log.txt"

# SETUP - Plotting style
def set_n_style():
    """Install the compact journal-style Matplotlib defaults used by every figure.

    Sets 6-7 pt sans-serif fonts, thin axes/tick/line widths, inward-pointing
    ticks, the notebook-wide ``SAVE_DPI`` for saved figures, and TrueType
    (type 42) font embedding for PDF/PS output.
    """
    fonts = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
        "font.size": 7,
        "axes.titlesize": 7,
        "axes.labelsize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
    }
    strokes = {
        "axes.linewidth": 0.6,
        "lines.linewidth": 0.8,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    }
    ticks = {"xtick.direction": "in", "ytick.direction": "in"}
    output = {"savefig.dpi": SAVE_DPI, "pdf.fonttype": 42, "ps.fonttype": 42}
    # Merged in the same key order as the settings are listed above.
    mpl.rcParams.update({**fonts, **strokes, **ticks, **output})


set_n_style()

def log(msg: str):
    """Echo *msg* to stdout and append it to ``LOG_PATH``.

    Trailing whitespace is stripped and exactly one newline is written per call.
    """
    print(msg)
    entry = msg.rstrip() + "\n"
    with open(LOG_PATH, "a", encoding="utf-8") as handle:
        handle.write(entry)

# Start every run with a fresh log file (mode "w" truncates) and a two-line header.
with open(LOG_PATH, "w", encoding="utf-8") as f:
    f.writelines(["PeriNeuroImmuneMap Notebook 3 log\n", f"Start: {RUN_TS}\n"])

# Runtime environment, recorded for reproducibility.
log(f"Python: {sys.version.split()[0]}")
log(f"Platform: {platform.platform()}")
log(f"Seed: {RANDOM_SEED}")
log(f"CPU logical: {psutil.cpu_count(True)} | CPU physical: {psutil.cpu_count(False)}")
log(f"RAM total GB: {psutil.virtual_memory().total/(1024**3):.2f}")
log(f"scanpy: {sc.__version__} | anndata: {ad.__version__}")
log(f"lifelines: {'yes' if HAS_LIFELINES else 'no'}")

# HELPER FUNCTIONS
def safe_mean(X, axis=0):
    """Mean of *X* along *axis* for dense or SciPy-sparse input.

    Sparse ``.mean`` returns an ``np.matrix``; it is converted to an ndarray
    and flattened so sparse callers get a 1-D result. Dense input is passed
    straight to ``np.mean``.
    """
    if not sp.issparse(X):
        return np.mean(X, axis=axis)
    return np.asarray(X.mean(axis=axis)).ravel()

def safe_sum(X, axis=0):
    """Sum of *X* along *axis* for dense or SciPy-sparse input.

    Sparse ``.sum`` returns an ``np.matrix``; it is converted to an ndarray
    and flattened so sparse callers get a 1-D result. Dense input is passed
    straight to ``np.sum``.
    """
    if not sp.issparse(X):
        return np.sum(X, axis=axis)
    return np.asarray(X.sum(axis=axis)).ravel()

def get_layer(adata_obj: ad.AnnData, layer_name: str):
    """Return the expression matrix named *layer_name*, defaulting to ``adata_obj.X``.

    ``None`` or ``"X"`` select ``X`` explicitly. A name that is not present in
    ``adata_obj.layers`` also falls back to ``X`` silently (no warning).
    """
    wants_named_layer = layer_name is not None and layer_name != "X"
    if wants_named_layer and layer_name in adata_obj.layers:
        return adata_obj.layers[layer_name]
    return adata_obj.X

def find_genes_in_adata(genes: List[str], adata_var: pd.DataFrame) -> List[str]:
    """Resolve gene names to ``var_names`` entries, case-insensitively.

    For each requested gene, a match on the ``gene_symbol`` column (if present)
    is tried first, then a match on the var index itself. When several rows
    share the same upper-cased symbol/name, the first row wins.

    Args:
        genes: Requested gene symbols or var names (any case).
        adata_var: The ``AnnData.var`` frame to search.

    Returns:
        Matching var-index labels in request order, without duplicates.
        Genes that cannot be found are silently omitted.
    """
    # Build first-occurrence lookup tables once (O(n_vars)) instead of scanning
    # the whole column for every gene. Using setdefault also avoids
    # Index.get_loc, which returns a mask/slice -- not a position -- when the
    # upper-cased index contains duplicates (e.g. "Cd4" and "CD4").
    symbol_to_var = {}
    if "gene_symbol" in adata_var.columns:
        symbols = adata_var["gene_symbol"].astype(str).str.upper()
        for var_name, symbol in zip(adata_var.index, symbols):
            symbol_to_var.setdefault(symbol, var_name)

    name_to_var = {}
    for var_name, upper in zip(adata_var.index, adata_var.index.astype(str).str.upper()):
        name_to_var.setdefault(upper, var_name)

    found = []
    for g in genes:
        gu = str(g).upper()
        if gu in symbol_to_var:        # symbol match takes precedence
            found.append(symbol_to_var[gu])
        elif gu in name_to_var:        # var_name match
            found.append(name_to_var[gu])
    # dict.fromkeys: order-preserving de-duplication (pd.unique on a plain list
    # is deprecated in recent pandas).
    return list(dict.fromkeys(found))

def score_mean_z(adata_obj: ad.AnnData, genes_varnames: List[str], layer="log1p_norm") -> np.ndarray:
    """Per-observation mean of gene-wise z-scores over *genes_varnames*.

    Genes absent from ``adata_obj.var_names`` are ignored; if none remain, an
    all-zero vector is returned. Each selected gene column is standardised
    across all observations (population std, with ``1e-9`` added to avoid
    division by zero), and the z-scores are averaged per observation.

    Returns:
        float32 array of length ``adata_obj.n_obs``.
    """
    matrix = get_layer(adata_obj, layer)
    present = [g for g in genes_varnames if g in adata_obj.var_names]
    if not present:
        return np.zeros(adata_obj.n_obs, dtype=np.float32)

    columns = [adata_obj.var_names.get_loc(g) for g in present]
    block = matrix[:, columns]
    if sp.issparse(matrix):
        block = block.toarray()

    centered = block - block.mean(axis=0)
    zscores = centered / (block.std(axis=0) + 1e-9)
    return zscores.mean(axis=1).astype(np.float32)

def cohens_d(x1, x2):
    """Cohen's d: difference of means divided by the pooled sample SD (ddof=1).

    ``1e-9`` is added to the pooled SD so two zero-variance groups do not
    divide by zero. Positive values mean ``x1`` has the larger mean.
    """
    n1 = len(x1)
    n2 = len(x2)
    ss1 = (n1 - 1) * np.var(x1, ddof=1)
    ss2 = (n2 - 1) * np.var(x2, ddof=1)
    pooled_sd = np.sqrt((ss1 + ss2) / (n1 + n2 - 2))
    mean_gap = np.mean(x1) - np.mean(x2)
    return mean_gap / (pooled_sd + 1e-9)

def morans_I_sparse(score: np.ndarray, W: sp.csr_matrix) -> float:
    """Global Moran's I of *score* on the (non-row-standardised) weights *W*.

    I = (n / S0) * (z' W z) / (z' z), where z is the mean-centred score and
    S0 is the sum of all weights. Returns NaN when the score has no variance
    or the weights do not sum to a positive value.
    """
    z = score.astype(np.float64)   # astype copies, so in-place centring is safe
    z -= z.mean()

    zz = float(np.dot(z, z))
    if zz <= 0:
        return np.nan

    total_weight = float(W.sum())
    if total_weight <= 0:
        return np.nan

    lagged = W.dot(z)
    cross = float(np.dot(z, lagged))
    return float((z.shape[0] / total_weight) * (cross / zz))

# SECTION 0: OVERVIEW & DATA SOURCES

for _banner_line in ("\n" + "=" * 80, "SECTION 0: NOTEBOOK OVERVIEW & DATA SOURCES", "=" * 80):
    log(_banner_line)

# Run-wide analysis parameters; echoed to the log as JSON for provenance.
PARAMS = dict(
    # Perineural regions
    perineural_percentile=90,  # top 10% nerve score
    distance_thresholds_um=[50, 100, 200, 500],
    # Statistical
    fdr_threshold=0.05,
    min_spots_per_region=30,
    # Permutations
    n_permutations=1000,
    perm_log_every=200,
    # GeoMx
    geomx_high_percentile=75,  # top 25% for binary classification
    geomx_min_rois_per_patient=5,
)

log(f"PARAMS: {json.dumps(PARAMS, indent=2)}")

log("\nData inventory:")
for _inventory_line in (
    "  GSE289745 Visium: 11 samples, ~20K genes (treatment-naive)",
    "  GeoMx DSP: 457 ROIs, 36 patients, 85 proteins (CR/MPR vs PD)",
    "  HEST-1k: 169 samples (external validation)",
):
    log(_inventory_line)
# PART A: VISIUM SPATIAL ANALYSIS
# SECTION 1: Setup & Data Loading

log("\n" + "="*80)
log("PART A: VISIUM SPATIAL ANALYSIS (GSE289745)")
log("="*80)

log("\nSECTION 1: Setup & Data Loading (Visium)")

# Inputs produced by Notebook 2.
adata_path = NB2_DIR / "adata_with_signature_scores.h5ad"
sig_path = NB2_DIR / "nerve_injury_signature_v1.0_FINAL.csv"

if not adata_path.exists():
    raise FileNotFoundError(f"Run Notebook 2 first: {adata_path}")
if not sig_path.exists():
    raise FileNotFoundError(f"Signature file missing: {sig_path}")

log(f"Loading: {adata_path}")
adata = ad.read_h5ad(adata_path)
log(f"Loaded: spots={adata.n_obs:,} genes={adata.n_vars:,} samples={adata.obs['sample_id'].nunique()}")

sig_df = pd.read_csv(sig_path)
log(f"Signature: {sig_df.shape[0]} genes")

# Check required columns
required_cols = ["nerve_injury_score", "is_nerve_enriched", "canonical_nerve_score"]
for col in required_cols:
    if col not in adata.obs.columns:
        raise ValueError(f"Missing column from Notebook 2: {col}")

# Spatial graph (whichever connectivity matrix Notebook 2 stored)
if "spatial_connectivities" in adata.obsp:
    W = adata.obsp["spatial_connectivities"].tocsr()
elif "spatial_k6_connectivities" in adata.obsp:
    W = adata.obsp["spatial_k6_connectivities"].tocsr()
else:
    raise ValueError("No spatial connectivity matrix found")

# Canonicalise the sparse structure (merge duplicate entries, drop explicit zeros).
W.sum_duplicates()
W.eliminate_zeros()

# Define immune gene sets.
# Every entry must be an HGNC gene symbol: find_genes_in_adata only matches the
# gene_symbol column / var_names, so protein-style aliases were silently dropped.
# Replaced aliases: GITR -> TNFRSF18, CD25 -> IL2RA, OX40 -> TNFRSF4.
IMMUNE_GENE_SETS = {
    "Exhaustion": ["PDCD1", "LAG3", "HAVCR2", "TOX", "TIGIT", "CTLA4", "CD244", "CD160"],
    "Treg_suppression": ["FOXP3", "IL10", "TGFB1", "IL2RA", "ICOS", "TNFRSF18"],
    "Myeloid_suppression": ["ARG1", "IDO1", "CD274", "VEGFA", "IL10", "TGFB1", "CD163", "MRC1"],
    "Cytotoxicity": ["GZMB", "PRF1", "IFNG", "TNF", "GNLY", "NKG7", "GZMA"],
    "Activation": ["CD69", "IL2RA", "CD44", "ICOS", "TNFRSF4", "CD27", "CD28"],
    "IL6_axis": ["IL6", "IL6R", "IL6ST", "STAT3", "SOCS3", "JAK1"],
    "Type1_IFN": ["IFNA1", "IFNB1", "ISG15", "MX1", "OAS1", "IFIT1", "IRF7"],
    "Type2_IFN": ["IFNG", "CXCL9", "CXCL10", "IDO1", "GBP1", "STAT1"],
    "Proliferation": ["MKI67", "TOP2A", "PCNA", "CDK1", "CCNB1"],
}

# Find available genes (resolved var_names per program)
immune_genes = {}
for prog, genes in IMMUNE_GENE_SETS.items():
    found = find_genes_in_adata(genes, adata.var)
    immune_genes[prog] = found
    log(f"  {prog}: {len(found)}/{len(genes)} genes found")

# SECTION 2: Immune Cell-Type Deconvolution
# Each cell type is scored per spot as the mean z-score of its marker genes
# (score_mean_z). The z-scoring is computed over all spots of all samples
# pooled, so no deconvolution model is fitted.

# %%
log("\nSECTION 2: Immune Cell-Type Deconvolution (simplified marker-based)")

# Simplified marker-based approach
CELL_TYPE_MARKERS = {
    "CD8_T": ["CD8A", "CD8B"],
    "CD4_T": ["CD4", "IL7R"],
    "Treg": ["FOXP3", "IL2RA"],
    "B_cells": ["CD19", "MS4A1", "CD79A"],
    "NK": ["NKG7", "GNLY", "NCAM1"],
    "Macrophages": ["CD68", "CD163", "MRC1"],
    "Dendritic": ["CD1C", "CLEC9A", "BATF3"],
    "Neutrophils": ["S100A8", "S100A9", "FCGR3B"],
}

# Writes one obs column per cell type ("celltype_<name>") and keeps the scores
# in cell_type_scores; types with no detected markers are skipped.
cell_type_scores = {}
for ct, markers in CELL_TYPE_MARKERS.items():
    found = find_genes_in_adata(markers, adata.var)
    if len(found) > 0:
        score = score_mean_z(adata, found, layer="log1p_norm")
        adata.obs[f"celltype_{ct}"] = score
        cell_type_scores[ct] = score
        log(f"  {ct}: {len(found)}/{len(markers)} markers, mean={score.mean():.3f}")
    else:
        log(f"  {ct}: No markers found, skipping")

log(f"Computed {len(cell_type_scores)} cell type scores")

# Main Figure 3A: Cell type spatial maps
log("\nMaking Main Figure 3A: Immune cell type maps")

# Representative sample = alphabetically first sample_id.
rep_sample = sorted(adata.obs["sample_id"].unique())[0]
rep_mask = (adata.obs["sample_id"].values == rep_sample)
rep_data = adata[rep_mask].copy()

# At most six panels (2 x 3 grid), taken in CELL_TYPE_MARKERS order.
n_types = min(6, len(cell_type_scores))
types_to_plot = list(cell_type_scores.keys())[:n_types]

fig = plt.figure(figsize=(7.2, 4.8), dpi=SAVE_DPI)
gs = fig.add_gridspec(2, 3, hspace=0.30, wspace=0.25)

for i, ct in enumerate(types_to_plot):
    row = i // 3
    col = i % 3
    ax = fig.add_subplot(gs[row, col])
    
    coords = rep_data.obsm["spatial"]
    vals = rep_data.obs[f"celltype_{ct}"].values
    
    # Colour limits clipped to the 5th-95th percentile of this sample's values.
    # NOTE(review): plots spatial[:, 1] on x and spatial[:, 0] on y (then inverts
    # y) -- confirm this matches the column order stored in obsm["spatial"].
    scat = ax.scatter(coords[:,1], coords[:,0], c=vals, s=2, cmap="Reds", 
                     alpha=0.9, rasterized=True, vmin=np.percentile(vals, 5),
                     vmax=np.percentile(vals, 95))
    ax.invert_yaxis()
    ax.set_title(ct.replace("_", " "))
    ax.set_xticks([]); ax.set_yticks([])
    cbar = plt.colorbar(scat, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=5)

fig.suptitle("Main Figure 3A: Immune cell type scores (marker-based)", y=0.98)
fig3a_path = FIGURES_DIR / "Main_Figure_3A.png"
fig.savefig(fig3a_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig3a_path}")

# SECTION 3: Perineural Region Definition

log("\nSECTION 3: Perineural Region Definition")

# Perineural spots: nerve injury score at or above the configured percentile.
# The cut-off is computed on all spots pooled across samples (not per sample).
_nerve_vals = adata.obs["nerve_injury_score"].values
thr = np.percentile(_nerve_vals, PARAMS["perineural_percentile"])
adata.obs["is_perineural"] = (_nerve_vals >= thr).astype(int)

_peri_flags = adata.obs["is_perineural"].astype(bool)
n_peri = int(_peri_flags.sum())
n_rest = int((~_peri_flags).sum())
log(f"Perineural spots: {n_peri:,} ({n_peri/adata.n_obs*100:.1f}%)")
log(f"Control spots: {n_rest:,}")

# Export Supplementary Table 8 (one row per spot)
_supp8_obs_cols = ("nerve_injury_score", "is_perineural", "canonical_nerve_score")
supp8_df = pd.DataFrame({
    "sample_id": adata.obs["sample_id"].values,
    "barcode": adata.obs.index.values,
    **{c: adata.obs[c].values for c in _supp8_obs_cols},
})

supp8_path = TABLES_DIR / "Supplementary_Table_8.xlsx"
with pd.ExcelWriter(supp8_path, engine="openpyxl") as w:
    supp8_df.to_excel(w, index=False, sheet_name="Perineural_Definitions")
log(f"Saved {supp8_path}")

# SECTION 4: Immune Functional Program Scoring

log("\nSECTION 4: Immune Functional Program Scoring")

# Compute immune program scores (mean gene z-score); programs with fewer than
# two detected genes are not scored.
for prog, genes in immune_genes.items():
    if len(genes) >= 2:
        score = score_mean_z(adata, genes, layer="log1p_norm")
        adata.obs[f"immune_{prog}"] = score
        log(f"  {prog}: mean={score.mean():.3f}, std={score.std():.3f}")

# Compare perineural vs control: two-sided Mann-Whitney U + Cohen's d per program
peri_mask = adata.obs["is_perineural"].astype(bool).values
results = []

for prog in immune_genes.keys():
    col = f"immune_{prog}"
    if col not in adata.obs.columns:
        continue

    scores = adata.obs[col].values
    s_peri = scores[peri_mask]
    s_ctrl = scores[~peri_mask]

    u, p = mannwhitneyu(s_peri, s_ctrl, alternative="two-sided")
    d = cohens_d(s_peri, s_ctrl)

    results.append({
        "program": prog,
        "mean_perineural": float(s_peri.mean()),
        "mean_control": float(s_ctrl.mean()),
        "diff": float(s_peri.mean() - s_ctrl.mean()),
        "cohens_d": float(d),
        "mann_whitney_u": float(u),
        "pval": float(p),
    })

# Explicit columns keep results_df (and every later use of its columns) valid
# even when no program could be scored.
results_df = pd.DataFrame(results, columns=["program", "mean_perineural", "mean_control", "diff",
                                            "cohens_d", "mann_whitney_u", "pval"])
results_df["pval_adj"] = (
    multipletests(results_df["pval"].values, method="fdr_bh")[1]
    if len(results_df) > 0 else np.array([], dtype=float)
)
results_df = results_df.sort_values("pval_adj")

log("\nImmune programs enriched in perineural regions:")
for _, r in results_df.iterrows():
    if r["pval_adj"] < PARAMS["fdr_threshold"]:
        log(f"  {r['program']}: d={r['cohens_d']:.3f}, p_adj={r['pval_adj']:.2e}")

# Main Figure 3B: Immune program comparison
log("\nMaking Main Figure 3B: Immune programs (perineural vs control)")

# Plot FDR-significant programs; if none, fall back to the 8 best-ranked.
sig_progs = results_df.loc[results_df["pval_adj"] < PARAMS["fdr_threshold"], "program"].tolist()
if len(sig_progs) == 0:
    sig_progs = results_df["program"].head(8).tolist()

if len(sig_progs) == 0:
    log("  No immune programs scored; skipping Main Figure 3B")
else:
    n_cols = min(4, (len(sig_progs) + 1) // 2)
    # squeeze=False always yields a 2-D Axes array. Previously 1 program gave a
    # single-column 1-D array that was wrapped as [axes], so axes[0].boxplot failed.
    fig, axes = plt.subplots(2, n_cols, figsize=(7.2, 4.5), dpi=SAVE_DPI, squeeze=False)
    axes = axes.ravel()

    for i, prog in enumerate(sig_progs[:len(axes)]):
        col = f"immune_{prog}"
        if col not in adata.obs.columns:
            continue

        s = adata.obs[col].values
        data = [s[~peri_mask], s[peri_mask]]

        bp = axes[i].boxplot(data, labels=["Control", "Perineural"], patch_artist=True,
                             widths=0.6, showfliers=False)
        bp['boxes'][0].set_facecolor('lightgray')
        bp['boxes'][1].set_facecolor('salmon')

        r = results_df.loc[results_df["program"] == prog].iloc[0]
        axes[i].set_title(f"{prog}\nd={r['cohens_d']:.2f}, p={r['pval_adj']:.2e}", fontsize=6)
        axes[i].set_ylabel("Score")
        axes[i].tick_params(labelsize=6)

    # Hide extra axes
    for j in range(len(sig_progs), len(axes)):
        axes[j].set_visible(False)

    fig.suptitle("Main Figure 3B: Immune programs (perineural vs control)", y=0.98)
    fig3b_path = FIGURES_DIR / "Main_Figure_3B.png"
    fig.tight_layout()
    fig.savefig(fig3b_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {fig3b_path}")

# SECTION 5: Spatial Statistics - Nerve-Immune Coupling

log("\nSECTION 5: Spatial Statistics - Nerve-Immune Coupling")

# Distance-based correlation: within each sample, for every ordered spot pair
# (i, j) whose distance falls in a bin, correlate nerve score at i with immune
# score at j (Spearman).
coords_all = adata.obsm["spatial"]
nerve_scores = adata.obs["nerve_injury_score"].values

# NOTE(review): bins are labelled in um, but raw obsm["spatial"] distances are
# used; for Visium these are often full-resolution pixel coordinates -- confirm
# the unit (or rescale with the scalefactors) before reading the axis as um.
distance_bins = [(0, 50), (50, 100), (100, 200), (200, 500), (500, 1000)]
spatial_corr_results = []

# Programs actually scored in Section 4, in immune_genes order.
scored_progs = [prog for prog in immune_genes.keys() if f"immune_{prog}" in adata.obs.columns]
sample_labels = adata.obs["sample_id"].values

# Sample is the outer loop so the O(n^2) distance matrix and each bin's pair
# indices are built once per sample (previously once per program x sample).
# Rows are re-ordered program-major below, reproducing the previous layout.
for sample_id in adata.obs["sample_id"].unique():
    mask = (sample_labels == sample_id)
    coords = coords_all[mask]
    n_s = nerve_scores[mask]
    immune_by_prog = {prog: adata.obs[f"immune_{prog}"].values[mask] for prog in scored_progs}

    # Pairwise distances
    dists = cdist(coords, coords, metric="euclidean")

    for d_min, d_max in distance_bins:
        # NOTE(review): ordered pairs, so each unordered pair is counted twice, and
        # the first bin includes self-pairs (i == j, distance 0). Pairs are not
        # independent, so the reported p-values are anti-conservative.
        pair_mask = (dists >= d_min) & (dists < d_max)
        n_pairs = int(pair_mask.sum())
        if n_pairs < 100:
            continue

        idx_i, idx_j = np.where(pair_mask)
        del pair_mask
        nerve_i = n_s[idx_i]

        for prog in scored_progs:
            # Correlation for this distance bin
            r, p = spearmanr(nerve_i, immune_by_prog[prog][idx_j])

            spatial_corr_results.append({
                "sample_id": sample_id,
                "program": prog,
                "distance_min": d_min,
                "distance_max": d_max,
                "spearman_r": float(r) if np.isfinite(r) else np.nan,
                "pval": float(p) if np.isfinite(p) else np.nan,
                "n_pairs": n_pairs,
            })

    del dists  # release the n x n matrix before the next sample

spatial_corr_df = pd.DataFrame(
    spatial_corr_results,
    columns=["sample_id", "program", "distance_min", "distance_max", "spearman_r", "pval", "n_pairs"],
)
# Stable sort by program rank -> program, then sample, then bin.
_prog_rank = {prog: k for k, prog in enumerate(scored_progs)}
spatial_corr_df = (
    spatial_corr_df.assign(_rank=spatial_corr_df["program"].map(_prog_rank))
    .sort_values("_rank", kind="stable")
    .drop(columns="_rank")
    .reset_index(drop=True)
)
log(f"Spatial correlations computed: {spatial_corr_df.shape[0]} combinations")

# Main Figure 3C: Distance decay curves
log("\nMaking Main Figure 3C: Spatial coupling (distance decay)")

fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.2), dpi=SAVE_DPI)

# Panel A: Distance decay (top 4 programs from the perineural comparison)
top_progs = results_df["program"].head(4).tolist()
for prog in top_progs:
    subset = spatial_corr_df[spatial_corr_df["program"] == prog]
    if subset.shape[0] == 0:
        continue

    # Average across samples per distance bin, plotted at the true bin midpoint
    # (the former "distance_min + 25" was only right for the 50-wide bins).
    agg = (subset.groupby(["distance_min", "distance_max"])["spearman_r"].mean()
           .reset_index()
           .sort_values("distance_min"))
    mid_dist = (agg["distance_min"].values + agg["distance_max"].values) / 2
    axes[0].plot(mid_dist, agg["spearman_r"].values, marker="o", label=prog)

axes[0].axhline(y=0, linestyle="--", color="gray", linewidth=0.8, alpha=0.6)
axes[0].set_xlabel("Distance (um)")
axes[0].set_ylabel("Spearman r (nerve x immune)")
axes[0].set_title("A  Distance decay")
axes[0].legend(frameon=False, fontsize=6)

# Panel B: Heatmap of sample-averaged r per program x bin (left edge label)
pivot = spatial_corr_df.groupby(["program", "distance_min"])["spearman_r"].mean().reset_index()
pivot_mat = pivot.pivot(index="program", columns="distance_min", values="spearman_r")

if pivot_mat.size > 0:
    sns.heatmap(pivot_mat, cmap="coolwarm", center=0, ax=axes[1], cbar_kws={"shrink": 0.8},
                vmin=-0.3, vmax=0.3, annot=False, fmt=".2f", linewidths=0.5)
axes[1].set_title("B  Program x Distance")
axes[1].set_xlabel("Distance (um)")
axes[1].set_ylabel("")

fig.suptitle("Main Figure 3C: Spatial nerve-immune coupling", y=0.98)
fig3c_path = FIGURES_DIR / "Main_Figure_3C.png"
fig.tight_layout()
fig.savefig(fig3c_path, dpi=SAVE_DPI, bbox_inches="tight")
plt.close(fig)
log(f"Saved {fig3c_path}")

# SECTION 6: Ligand-Receptor Interactions

log("\nSECTION 6: Ligand-Receptor Interactions (marker-based approximation)")

# Simplified L-R pairs (one ligand gene, one receptor gene each)
LR_PAIRS = {
    "IL6-IL6R": ("IL6", "IL6R"),
    "IFNG-IFNGR1": ("IFNG", "IFNGR1"),
    "TNF-TNFRSF1A": ("TNF", "TNFRSF1A"),
    "TGFB1-TGFBR1": ("TGFB1", "TGFBR1"),
    "CSF1-CSF1R": ("CSF1", "CSF1R"),
}

lr_results = []
for pair_name, (lig, rec) in LR_PAIRS.items():
    lig_genes = find_genes_in_adata([lig], adata.var)
    rec_genes = find_genes_in_adata([rec], adata.var)

    if len(lig_genes) == 0 or len(rec_genes) == 0:
        log(f"  {pair_name}: missing genes")
        continue

    # Per-spot mean log-normalised expression of the ligand / receptor genes.
    X_log = adata.layers["log1p_norm"]
    lig_expr = safe_mean(X_log[:, [adata.var_names.get_loc(g) for g in lig_genes]], axis=1)
    rec_expr = safe_mean(X_log[:, [adata.var_names.get_loc(g) for g in rec_genes]], axis=1)

    # Compare perineural vs control
    lig_peri = lig_expr[peri_mask].mean()
    lig_ctrl = lig_expr[~peri_mask].mean()
    rec_peri = rec_expr[peri_mask].mean()
    rec_ctrl = rec_expr[~peri_mask].mean()

    lr_results.append({
        "pair": pair_name,
        "ligand": lig,
        "receptor": rec,
        "lig_peri": float(lig_peri),
        "lig_ctrl": float(lig_ctrl),
        "rec_peri": float(rec_peri),
        "rec_ctrl": float(rec_ctrl),
        # L-R score = geometric mean of the region-averaged ligand and receptor expression
        "lr_score_peri": float(np.sqrt(lig_peri * rec_peri)),
        "lr_score_ctrl": float(np.sqrt(lig_ctrl * rec_ctrl)),
    })

# Explicit columns: with no detectable pair the frame is empty but still has the
# columns used below (previously this raised KeyError).
lr_df = pd.DataFrame(lr_results, columns=["pair", "ligand", "receptor", "lig_peri", "lig_ctrl",
                                          "rec_peri", "rec_ctrl", "lr_score_peri", "lr_score_ctrl"])
lr_df["fold_enrichment"] = lr_df["lr_score_peri"] / (lr_df["lr_score_ctrl"] + 1e-9)

# Supplementary Table 9
supp9_path = TABLES_DIR / "Supplementary_Table_9.xlsx"
with pd.ExcelWriter(supp9_path, engine="openpyxl") as w:
    lr_df.to_excel(w, index=False, sheet_name="LR_Interactions")
log(f"Saved {supp9_path}")

log(f"L-R interactions computed: {lr_df.shape[0]} pairs")
for _, r in lr_df.iterrows():
    log(f"  {r['pair']}: fold={r['fold_enrichment']:.2f}")

# PART B: GeoMx RESPONSE VALIDATION
# SECTION 7: GeoMx Data Loading

log("\n" + "="*80)
log("PART B: GeoMx RESPONSE VALIDATION (CLINICAL TRIAL)")
log("="*80)

log("\nSECTION 7: GeoMx Data Loading & Protein Scoring")

geomx_file = SUPP_DIR / "41586_2025_9370_MOESM9_ESM.xlsx"
if not geomx_file.exists():
    log(f"WARNING: GeoMx file not found: {geomx_file}")
    log("Skipping GeoMx analysis")
    geomx_available = False
else:
    geomx_available = True

    # Load GeoMx data
    log(f"Loading GeoMx data: {geomx_file}")
    try:
        # The context manager closes the workbook handle once the sheets are read
        # (a bare pd.ExcelFile keeps the file open for the rest of the session).
        with pd.ExcelFile(geomx_file) as xl:
            log(f"  Available sheets: {xl.sheet_names}")

            # Keyword-based sheet detection; as before, the last matching sheet
            # wins. A sheet that also looks clinical (e.g. "Clinical_data", which
            # contains 'data') is only used for expression when no purely
            # expression-like sheet exists.
            expr_keys = ("expr", "geomx", "data")
            clin_keys = ("clinic", "patient", "response")
            expr_hits = [s for s in xl.sheet_names if any(k in s.lower() for k in expr_keys)]
            clin_hits = [s for s in xl.sheet_names if any(k in s.lower() for k in clin_keys)]
            expr_pool = [s for s in expr_hits if s not in clin_hits] or expr_hits

            expr_sheet = expr_pool[-1] if expr_pool else None
            clinical_sheet = clin_hits[-1] if clin_hits else None

            if expr_sheet is None:
                # Use first sheet as expression
                expr_sheet = xl.sheet_names[0]
                log(f"  Using first sheet as expression: {expr_sheet}")

            geomx_expr = pd.read_excel(xl, sheet_name=expr_sheet)
            log(f"  Expression data: {geomx_expr.shape}")
            log(f"  Columns: {list(geomx_expr.columns[:10])}")

            if clinical_sheet is not None:
                geomx_clinical = pd.read_excel(xl, sheet_name=clinical_sheet)
                log(f"  Clinical data: {geomx_clinical.shape}")
            else:
                log("  WARNING: No clinical sheet found, will check if clinical data is in expression sheet")
                geomx_clinical = pd.DataFrame()  # empty

    except Exception as e:
        # Best-effort: any read failure disables the GeoMx part instead of aborting.
        log(f"ERROR loading GeoMx: {type(e).__name__}: {e}")
        geomx_available = False

def _normalize_protein_name(name) -> str:
    """Upper-case *name* and drop '-' and ' ' so e.g. 'PD-L1' and 'pdl1' compare equal."""
    return str(name).upper().replace("-", "").replace(" ", "")


def _match_protein(target: str, names: List) -> Optional[str]:
    """Return the entry of *names* that corresponds to protein *target*, or None.

    Tried in order: exact string match; equality after normalisation; then a
    normalised substring match in either direction. The substring match is
    rejected when the character right after the shorter name inside the longer
    one is a digit, so 'CD4' does not pick up 'CD45', 'CD8' not 'CD80', and
    'CD3' not 'CD34'. NaN names and names that normalise to '' never match
    (previously '' matched every target).
    """
    if target in names:
        return target
    wanted = _normalize_protein_name(target)
    candidates = []
    for name in names:
        if pd.isna(name):
            continue
        cleaned = _normalize_protein_name(name)
        if cleaned:
            candidates.append((name, cleaned))
    for name, cleaned in candidates:
        if cleaned == wanted:
            return name
    for name, cleaned in candidates:
        shorter, longer = (wanted, cleaned) if len(wanted) <= len(cleaned) else (cleaned, wanted)
        pos = longer.find(shorter)
        if pos < 0:
            continue
        end = pos + len(shorter)
        if end < len(longer) and longer[end].isdigit():
            continue
        return name
    return None


if geomx_available:
    # GeoMx data is in WIDE format: proteins (rows) x ROIs (columns)
    log("\nRestructuring GeoMx data from wide to long format...")

    # Get protein info columns
    protein_info_cols = ["Panel", "ID", "Class", "Probe_ID", "SPOT_ID"]
    available_info_cols = [c for c in protein_info_cols if c in geomx_expr.columns]

    # Get ROI columns (all columns that look like sample IDs)
    # NOTE(review): any other non-numeric annotation column would also be treated
    # as an ROI here and break the float conversion below -- confirm the layout.
    roi_cols = [c for c in geomx_expr.columns if c not in available_info_cols + ["UNIPROT_LIST", "GI_LIST", "PANEL_LIST"]]
    log(f"  Detected {len(roi_cols)} ROIs")
    log(f"  Detected {geomx_expr.shape[0]} proteins")

    # Use "ID" or "SPOT_ID" as protein names
    if "ID" in geomx_expr.columns:
        protein_names = geomx_expr["ID"].tolist()
    elif "SPOT_ID" in geomx_expr.columns:
        protein_names = geomx_expr["SPOT_ID"].tolist()
    else:
        protein_names = [f"protein_{i}" for i in range(geomx_expr.shape[0])]

    # Transpose: ROIs as rows, proteins as columns
    geomx_long = geomx_expr[roi_cols].T
    geomx_long.columns = protein_names
    geomx_long.index.name = "ROI_ID"
    geomx_long = geomx_long.reset_index()

    # Extract patient_id from ROI_ID (format: P7669_DSP96225_001).
    # astype(str) first: Excel headers can be read as numbers, and .str would fail.
    geomx_long["patient_id"] = geomx_long["ROI_ID"].astype(str).str.split("_").str[0]

    log(f"  Reshaped to: {geomx_long.shape[0]} ROIs x {geomx_long.shape[1]-2} proteins")
    log(f"  Unique patients: {geomx_long['patient_id'].nunique()}")

    # Define protein-based immune programs (use actual protein names from data)
    GEOMX_IMMUNE_PROGRAMS = {
        "Exhaustion": ["PD-1", "LAG-3", "TIM-3", "CTLA-4", "VISTA"],
        "T_suppression": ["FOXP3", "CD25", "ICOS", "GITR"],
        "Myeloid_suppression": ["ARG1", "IDO1", "PD-L1", "CD163"],
        "Cytotoxicity": ["CD8", "Granzyme B"],
        "Activation": ["CD27", "CD44", "OX40", "4-1BB", "GITR"],
        "Pan_immune": ["CD45", "CD3", "CD4", "CD8"],
    }

    # Check availability
    log("\nGeoMx immune protein availability:")
    geomx_programs = {}
    for prog, prot_list in GEOMX_IMMUNE_PROGRAMS.items():
        hits = [_match_protein(p, protein_names) for p in prot_list]
        # dict.fromkeys removes duplicates with a deterministic order; set() order
        # depends on string-hash randomisation, which changed the log and the
        # floating-point summation order between runs.
        geomx_programs[prog] = list(dict.fromkeys(h for h in hits if h is not None))
        log(f"  {prog}: {len(geomx_programs[prog])}/{len(prot_list)} available")
        if len(geomx_programs[prog]) > 0:
            log(f"    {geomx_programs[prog]}")

    # Compute program scores
    for prog, prot_list in geomx_programs.items():
        if len(prot_list) == 0:
            continue

        prot_data = geomx_long[prot_list].values.astype(float)

        # Z-score per protein across all ROIs, then mean per ROI
        z_data = (prot_data - prot_data.mean(axis=0)) / (prot_data.std(axis=0) + 1e-9)
        score = z_data.mean(axis=1)

        geomx_long[f"score_{prog}"] = score
        log(f"  Computed {prog} score: mean={score.mean():.3f}")

    # Try to load clinical data if available in other sheets
    geomx_merged = geomx_long.copy()

    if not geomx_clinical.empty and "patient_id" in geomx_clinical.columns:
        log("\nMerging with clinical data...")
        geomx_merged = geomx_long.merge(geomx_clinical, on="patient_id", how="left")
        log(f"Merged GeoMx data: {geomx_merged.shape}")
    else:
        log("\nWARNING: No clinical data available")
        log("  Will skip response and PNI analysis")
        log("  Only protein-level analysis possible")

    # Check response labels
    if "response" in geomx_merged.columns:
        resp_counts = geomx_merged["response"].value_counts()
        log(f"Response distribution:\n{resp_counts}")
    else:
        log("  'response' column not found")

    if "PNI_status" in geomx_merged.columns:
        pni_counts = geomx_merged["PNI_status"].value_counts()
        log(f"PNI distribution:\n{pni_counts}")
    else:
        log("  'PNI_status' column not found")

# ## SECTION 8: Therapy Response Stratification

if geomx_available and "response" in geomx_merged.columns:
    log("\nSECTION 8: Therapy Response Stratification")
    
    # 8.1 PNI status x Response
    if "PNI_status" in geomx_merged.columns:
        log("\n8.1 PNI status x Response")
        
        ct = pd.crosstab(geomx_merged["PNI_status"], geomx_merged["response"])
        log(f"Contingency table:\n{ct}")
        
        chi2, p_chi, dof, expected = chi2_contingency(ct)
        log(f"Chi-square: chi2={chi2:.3f}, p={p_chi:.3e}")
        
        # Odds ratio (if 2x2)
        if ct.shape == (2, 2):
            or_val = (ct.iloc[0,0] * ct.iloc[1,1]) / (ct.iloc[0,1] * ct.iloc[1,0] + 1e-9)
            log(f"Odds ratio: {or_val:.3f}")
    
    # 8.2 Immune programs x Response
    log("\n8.2 Immune programs x Response")
    
    response_comp = []
    for prog in geomx_programs.keys():
        col = f"score_{prog}"
        if col not in geomx_merged.columns:
            continue
        
        # Binary response
        responder_mask = geomx_merged["response"].isin(["CR", "MPR", "PR"])
        
        s_resp = geomx_merged.loc[responder_mask, col].values
        s_nonresp = geomx_merged.loc[~responder_mask, col].values
        
        if len(s_resp) < 5 or len(s_nonresp) < 5:
            continue
        
        u, p = mannwhitneyu(s_resp, s_nonresp, alternative="two-sided")
        d = cohens_d(s_resp, s_nonresp)
        
        response_comp.append({
            "program": prog,
            "mean_responder": float(s_resp.mean()),
            "mean_nonresponder": float(s_nonresp.mean()),
            "diff": float(s_resp.mean() - s_nonresp.mean()),
            "cohens_d": float(d),
            "mann_whitney_p": float(p),
        })
    
    response_df = pd.DataFrame(response_comp)
    if response_df.shape[0] > 0:
        response_df["p_adj"] = multipletests(response_df["mann_whitney_p"].values, method="fdr_bh")[1]
        response_df = response_df.sort_values("p_adj")
        
        log("Programs associated with response:")
        for _, r in response_df.iterrows():
            if r["p_adj"] < 0.1:
                log(f"  {r['program']}: d={r['cohens_d']:.3f}, p={r['p_adj']:.3e}")

    # Main Figure 4: GeoMx response validation
    log("\nMaking Main Figure 4: GeoMx response validation")
    
    fig = plt.figure(figsize=(7.2, 6.0), dpi=SAVE_DPI)
    gs = fig.add_gridspec(2, 2, hspace=0.35, wspace=0.30)
    
    # Panel A: PNI x Response
    if "PNI_status" in geomx_merged.columns:
        ax1 = fig.add_subplot(gs[0, 0])
        ct_plot = pd.crosstab(geomx_merged["PNI_status"], geomx_merged["response"], normalize="index") * 100
        ct_plot.plot(kind="bar", stacked=True, ax=ax1, color=["lightgreen", "salmon"])
        ax1.set_title("A  PNI x Response")
        ax1.set_xlabel("PNI status")
        ax1.set_ylabel("% ROIs")
        ax1.legend(title="Response", frameon=False, fontsize=6)
        ax1.tick_params(labelsize=6, rotation=0)
    
    # Panel B: Immune program heatmap
    ax2 = fig.add_subplot(gs[0, 1])
    if response_df.shape[0] > 0:
        plot_progs = response_df["program"].head(6).tolist()
        heatmap_data = []
        for prog in plot_progs:
            col = f"score_{prog}"
            if col in geomx_merged.columns:
                heatmap_data.append(geomx_merged[col].values)
        
        if len(heatmap_data) > 0:
            heatmap_mat = np.array(heatmap_data)
            sns.heatmap(heatmap_mat, cmap="coolwarm", center=0, ax=ax2, 
                       yticklabels=plot_progs, xticklabels=False, cbar_kws={"shrink": 0.6})
            ax2.set_title("B  Program scores")
    
    # Panel C: ROC curves
    ax3 = fig.add_subplot(gs[1, 0])
    if response_df.shape[0] > 0:
        y_true = geomx_merged["response"].isin(["CR", "MPR", "PR"]).astype(int).values
        
        for prog in response_df["program"].head(4).tolist():
            col = f"score_{prog}"
            if col not in geomx_merged.columns:
                continue
            
            y_score = geomx_merged[col].values
            fpr, tpr, _ = roc_curve(y_true, y_score)
            roc_auc = auc(fpr, tpr)
            
            ax3.plot(fpr, tpr, label=f"{prog} (AUC={roc_auc:.2f})", linewidth=1.0)
        
        ax3.plot([0,1],[0,1], "k--", linewidth=0.8, alpha=0.5)
        ax3.set_xlabel("FPR")
        ax3.set_ylabel("TPR")
        ax3.set_title("C  ROC (response prediction)")
        ax3.legend(frameon=False, fontsize=6)
    
    # Panel D: Effect sizes
    ax4 = fig.add_subplot(gs[1, 1])
    if response_df.shape[0] > 0:
        top_progs = response_df["program"].head(8).tolist()
        y_pos = np.arange(len(top_progs))
        d_vals = [response_df.loc[response_df["program"]==p, "cohens_d"].values[0] for p in top_progs]
        
        colors = ["salmon" if d < 0 else "lightgreen" for d in d_vals]
        ax4.barh(y_pos, d_vals, color=colors)
        ax4.set_yticks(y_pos)
        ax4.set_yticklabels(top_progs, fontsize=6)
        ax4.axvline(x=0, linestyle="--", color="gray", linewidth=0.8)
        ax4.set_xlabel("Cohen's d")
        ax4.set_title("D  Effect sizes")
        ax4.invert_yaxis()
    
    fig.suptitle("Main Figure 4: GeoMx response validation", y=0.98)
    fig4_path = FIGURES_DIR / "Main_Figure_4.png"
    fig.savefig(fig4_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {fig4_path}")
    
    # Supplementary Table 10
    supp10_path = TABLES_DIR / "Supplementary_Table_10.xlsx"
    with pd.ExcelWriter(supp10_path, engine="openpyxl") as w:
        if response_df.shape[0] > 0:
            response_df.to_excel(w, index=False, sheet_name="Response_Programs")
        if "PNI_status" in geomx_merged.columns:
            ct.to_excel(w, sheet_name="PNI_Response_Contingency")
    log(f"Saved {supp10_path}")

else:
    log("\nSkipping GeoMx analysis (data not available or no response labels)")

# SECTION 9: Survival Analysis
#
# Exploratory Kaplan-Meier curves of overall survival stratified by PNI status,
# computed on patient-level aggregates of the GeoMx ROI table. Both survival
# columns are required up front so the aggregation below cannot fail midway.

if (
    geomx_available
    and HAS_LIFELINES
    and "OS_months" in geomx_merged.columns
    and "OS_event" in geomx_merged.columns
):
    n_patients = int(geomx_merged["patient_id"].nunique())
    log(f"\nSECTION 9: Survival Analysis (exploratory, n={n_patients})")

    # Aggregate to patient level: clinical fields take the first value per
    # patient, program scores are averaged over that patient's ROIs. Only
    # columns that exist are aggregated, so a missing PNI_status cannot raise.
    agg_spec = {
        col: "first"
        for col in ("OS_months", "OS_event", "PNI_status")
        if col in geomx_merged.columns
    }
    agg_spec.update({
        f"score_{p}": "mean"
        for p in geomx_programs.keys()
        if f"score_{p}" in geomx_merged.columns
    })
    patient_data = geomx_merged.groupby("patient_id").agg(agg_spec).reset_index()

    log(f"Patient-level data: n={patient_data.shape[0]}")

    # KM curves by PNI
    if "PNI_status" in patient_data.columns:
        kmf = KaplanMeierFitter()

        # lifelines rejects NaN durations/events, and a NaN PNI label would
        # select no rows, so incomplete patients are dropped first.
        km_data = patient_data.dropna(subset=["OS_months", "OS_event", "PNI_status"])

        fig, ax = plt.subplots(1, 1, figsize=(4.5, 3.5), dpi=SAVE_DPI)

        # Sorted (by string form) for a deterministic legend order.
        for pni in sorted(km_data["PNI_status"].unique(), key=str):
            grp = km_data[km_data["PNI_status"] == pni]
            kmf.fit(grp["OS_months"], grp["OS_event"], label=f"PNI {pni}")
            kmf.plot_survival_function(ax=ax)

        ax.set_xlabel("Months")
        ax.set_ylabel("Overall Survival")
        ax.set_title("Survival by PNI status")
        ax.legend(frameon=False, fontsize=6)

        km_path = FIGURES_DIR / "Supplementary_Figure_Survival_PNI.png"
        fig.savefig(km_path, dpi=SAVE_DPI, bbox_inches="tight")
        plt.close(fig)
        log(f"Saved {km_path}")

    log(f"CAVEAT: n={n_patients} patients, exploratory analysis only")

else:
    log("\nSkipping survival analysis (lifelines not available or no survival data)")

# PART C: NEGATIVE CONTROLS & SENSITIVITY

log("\n" + "="*80)
log("PART C: NEGATIVE CONTROLS & SENSITIVITY")
log("="*80)

# SECTION 11: Sensitivity Analyses
#
# Re-run the perineural-vs-control comparison for one immune program while
# sweeping the percentile cut-off that defines "perineural" spots, to check the
# effect does not hinge on a single threshold choice.

log("\nSECTION 11: Sensitivity Analyses")

# Threshold sensitivity
thresholds = [80, 85, 90, 95]
threshold_results = []

test_prog = "Exhaustion"
test_col = f"immune_{test_prog}"
if test_col in adata.obs.columns:
    # Loop-invariant: pull both score vectors out of adata.obs once.
    injury = adata.obs["nerve_injury_score"].values
    s = adata.obs[test_col].values

    for thr_pct in thresholds:
        thr_val = np.percentile(injury, thr_pct)
        mask = injury >= thr_val
        s_peri = s[mask]
        s_ctrl = s[~mask]

        # Require at least 30 spots per group before testing.
        if len(s_peri) < 30 or len(s_ctrl) < 30:
            continue

        _, p = mannwhitneyu(s_peri, s_ctrl, alternative="two-sided")
        d = cohens_d(s_peri, s_ctrl)

        threshold_results.append({
            "threshold_pct": thr_pct,
            "n_perineural": int(mask.sum()),
            "cohens_d": float(d),
            "pval": float(p),
        })

threshold_df = pd.DataFrame(threshold_results)
log(f"Threshold sensitivity: {threshold_df.shape[0]} thresholds tested")
# itertuples keeps each column's dtype (iterrows upcasts ints to float,
# which logged counts as e.g. "n=1234.0").
for r in threshold_df.itertuples(index=False):
    log(f"  {r.threshold_pct}%: n={r.n_perineural}, d={r.cohens_d:.3f}, p={r.pval:.2e}")

# SECTION 12: Negative Controls
#
# 12.1 Label-permutation null: shuffle the immune-program score across spots
# and recompute its Spearman correlation with the nerve-injury score.
# NOTE(review): this shuffle ignores spatial autocorrelation (it is not a
# spatially constrained permutation) despite the log label, so the resulting
# p-value is likely anti-conservative.

log("\nSECTION 12: Negative Controls")
log("\n12.1 Spatial permutation test")

test_prog = "Exhaustion"
if f"immune_{test_prog}" in adata.obs.columns:
    from scipy.stats import rankdata

    real_score = np.asarray(adata.obs[f"immune_{test_prog}"].values, dtype=float)
    nerve_score = np.asarray(adata.obs["nerve_injury_score"].values, dtype=float)

    # Drop spots where either score is undefined (Spearman would return NaN).
    finite = np.isfinite(real_score) & np.isfinite(nerve_score)
    if not finite.all():
        log(f"  Dropping {int((~finite).sum())} spots with non-finite scores")
    real_score = real_score[finite]
    nerve_score = nerve_score[finite]

    # Spearman r is the Pearson r of the ranks, and ranking commutes with
    # permutation, so rank once and z-score: every permutation then costs a
    # single O(n) dot product instead of two full re-sorts.
    n_spots = real_score.size
    rank_nerve = rankdata(nerve_score)
    rank_real = rankdata(real_score)

    n_perm = PARAMS["n_permutations"]
    perm_corrs = np.empty(0)
    perm_p = np.nan
    real_corr = np.nan

    if n_spots > 1 and rank_nerve.std() > 0 and rank_real.std() > 0:
        zx = (rank_nerve - rank_nerve.mean()) / rank_nerve.std()
        zy = (rank_real - rank_real.mean()) / rank_real.std()
        real_corr = float(np.dot(zx, zy) / n_spots)

        # Dedicated, seeded generator: the null is reproducible regardless of
        # how much of the global RNG stream earlier cells consumed.
        rng = np.random.default_rng(RANDOM_SEED)
        perm_corrs = np.empty(n_perm)
        for i in range(n_perm):
            perm_corrs[i] = np.dot(zx, zy[rng.permutation(n_spots)]) / n_spots

            if (i + 1) % PARAMS["perm_log_every"] == 0:
                log(f"  perm {i+1}/{PARAMS['n_permutations']}")

        # Two-sided permutation p-value with the +1 correction.
        perm_p = (np.sum(np.abs(perm_corrs) >= np.abs(real_corr)) + 1) / (len(perm_corrs) + 1)
    else:
        log("  Correlation undefined (constant or empty score); permutation test skipped")

    log(f"Real correlation: r={real_corr:.3f}")
    log(f"Permutation p-value: {perm_p:.4f}")

    # Supplementary Figure 6: Negative controls
    fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.0), dpi=SAVE_DPI)

    # Panel A: null distribution vs observed correlation
    if perm_corrs.size > 0:
        axes[0].hist(perm_corrs, bins=50, alpha=0.7, edgecolor="black", linewidth=0.4)
        axes[0].axvline(x=real_corr, linestyle="--", color="red", linewidth=1.2,
                        label=f"Real r={real_corr:.3f}")
        axes[0].legend(frameon=False, fontsize=6)
    axes[0].set_xlabel("Spearman r")
    axes[0].set_ylabel("Frequency")
    axes[0].set_title(f"A  Permutation test\np={perm_p:.4f}")

    # Panel B: Threshold sensitivity
    if threshold_df.shape[0] > 0:
        axes[1].plot(threshold_df["threshold_pct"], threshold_df["cohens_d"],
                     marker="o", linewidth=1.0)
        axes[1].set_xlabel("Threshold (%)")
        axes[1].set_ylabel("Cohen's d")
        axes[1].set_title("B  Threshold sensitivity")
        axes[1].axhline(y=0, linestyle="--", color="gray", linewidth=0.8)

    fig.suptitle("Supplementary Figure 6: Negative controls & sensitivity", y=0.98)
    supp_fig6_path = FIGURES_DIR / "Supplementary_Figure_6.png"
    fig.savefig(supp_fig6_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {supp_fig6_path}")

# PART D: EXTERNAL VALIDATION (HEST-1k)
# SECTION 13: HEST-1k Validation

log("\n" + "="*80)
log("PART D: EXTERNAL VALIDATION (HEST-1k)")
log("="*80)

log("\nSECTION 13: HEST-1k External Validation")

# 13.1 Load HEST-1k metadata and select samples
HEST_DIR = RAW_DATA_DIR / "HEST_1k"
HEST_META_FILE = HEST_DIR / "HEST_v1_2_1.csv"
HEST_DATA_DIR = HEST_DIR / "hest_subset" / "st"

log(f"\nHEST-1k directory: {HEST_DIR}")
log(f"  Exists: {HEST_DIR.exists()}")

log(f"Metadata file: {HEST_META_FILE}")
log(f"  Exists: {HEST_META_FILE.exists()}")

log(f"Data directory: {HEST_DATA_DIR}")
log(f"  Exists: {HEST_DATA_DIR.exists()}")

# Define the flag on every path: later cells test `if hest_available:` and
# previously hit NameError when the metadata came from an alternative location.
hest_available = False

if not HEST_META_FILE.exists():
    log(f"\nERROR: HEST metadata file not found at: {HEST_META_FILE}")

    # Other known names/locations (the default CSV path was just checked).
    alternatives = [
        HEST_DIR / "HEST_v1_2_1.xlsx",
        HEST_DIR / "hest_subset" / "HEST_v1_2_1.csv",
        HEST_DIR / "hest_subset" / "HEST_v1_2_1.xlsx",
    ]

    log("\nChecking alternative locations:")
    for alt in alternatives:
        log(f"  {alt}: exists={alt.exists()}")
        if alt.exists():
            HEST_META_FILE = alt
            log(f"  -> Using: {alt}")
            break

if HEST_META_FILE.exists():
    # Load metadata (CSV or XLSX) from whichever location was resolved above.
    try:
        if HEST_META_FILE.suffix.lower() == '.csv':
            hest_meta = pd.read_csv(HEST_META_FILE)
        else:
            hest_meta = pd.read_excel(HEST_META_FILE)
    except (OSError, ValueError, ImportError) as e:
        # Unreadable/corrupt file or missing Excel engine: skip validation
        # with a logged reason rather than aborting the notebook.
        log(f"ERROR reading {HEST_META_FILE}: {type(e).__name__}: {e}")
        log("Skipping HEST-1k validation")
    else:
        hest_available = True
        log(f"HEST metadata loaded: {hest_meta.shape[0]} samples")
        log(f"Columns: {list(hest_meta.columns[:15])}")
else:
    log("Skipping HEST-1k validation")

if hest_available:
    # Filter for cancer samples only
    hest_cancer = hest_meta[hest_meta["disease_state"] == "Cancer"].copy()
    log(f"\nCancer samples: {hest_cancer.shape[0]}")

    # Selection rule: cancer types with at least this many samples, capped at
    # the most frequent few.
    min_samples_per_type = 3
    max_cancer_types = 5

    # Check cancer type distribution
    if "oncotree_code" in hest_cancer.columns:
        cancer_types = hest_cancer["oncotree_code"].value_counts()

        # One representative organ per type, computed in one pass, tolerant
        # of a missing "organ" column (was one full re-filter per type).
        if "organ" in hest_cancer.columns:
            organ_by_type = hest_cancer.groupby("oncotree_code")["organ"].first()
        else:
            organ_by_type = pd.Series(dtype=object)

        log(f"\nCancer type distribution:")
        for code, count in cancer_types.items():
            organ = organ_by_type.get(code, "NA")
            log(f"  {code} ({organ}): {count} samples")

        # value_counts() is sorted descending, so the slice keeps the most
        # frequent qualifying types.
        selected_cancers = cancer_types[cancer_types >= min_samples_per_type].index.tolist()[:max_cancer_types]
        hest_selected = hest_cancer[hest_cancer["oncotree_code"].isin(selected_cancers)].copy()
    else:
        log("WARNING: 'oncotree_code' column not found, using all cancer samples")
        hest_selected = hest_cancer.copy()
        selected_cancers = []

    log(f"\nSelected {len(selected_cancers)} cancer types: {selected_cancers}")
    log(f"Total samples for validation: {hest_selected.shape[0]}")

if hest_available:
    # 13.2 Load HEST samples and apply nerve injury signature
    log("\n13.2 Loading HEST samples and applying signature")

    # Load signature from Notebook 2. Gene symbols are preferred because
    # HEST genes are matched against var_names by name below. Duplicates are
    # removed (order kept) so no gene is double-weighted in the score.
    if "gene_symbol" in sig_df.columns:
        sig_genes = list(dict.fromkeys(sig_df["gene_symbol"].dropna().tolist()))
        log(f"Nerve injury signature: {len(sig_genes)} genes (using gene symbols)")
    else:
        sig_genes = list(dict.fromkeys(sig_df["gene"].dropna().tolist()))
        log(f"Nerve injury signature: {len(sig_genes)} genes (using gene IDs)")

    # Immune programs scored in each HEST sample. Genes must be official HGNC
    # symbols to match var_names: aliases CD25/OX40/GITR never matched, so they
    # are written as IL2RA/TNFRSF4/TNFRSF18.
    IMMUNE_GENE_SETS_HEST = {
        "Exhaustion": ["PDCD1", "LAG3", "HAVCR2", "TOX", "TIGIT", "CTLA4", "CD244", "CD160"],
        "Treg_suppression": ["FOXP3", "IL10", "TGFB1", "IL2RA", "ICOS", "TNFRSF18"],
        "Myeloid_suppression": ["ARG1", "IDO1", "CD274", "VEGFA", "IL10", "TGFB1", "CD163", "MRC1"],
        "Cytotoxicity": ["GZMB", "PRF1", "IFNG", "TNF", "GNLY", "NKG7", "GZMA"],
        "Activation": ["CD69", "IL2RA", "CD44", "ICOS", "TNFRSF4", "CD27", "CD28"],
        "IL6_axis": ["IL6", "IL6R", "IL6ST", "STAT3", "SOCS3", "JAK1"],
        "Type1_IFN": ["IFNA1", "IFNB1", "ISG15", "MX1", "OAS1", "IFIT1", "IRF7"],
        "Type2_IFN": ["IFNG", "CXCL9", "CXCL10", "IDO1", "GBP1", "STAT1"],
        "Proliferation": ["MKI67", "TOP2A", "PCNA", "CDK1", "CCNB1"],
    }
    log(f"Defined {len(IMMUNE_GENE_SETS_HEST)} immune programs for HEST validation")

    hest_results = []
    failed_samples = []

    for _, row in hest_selected.iterrows():
        sample_id = row["id"]
        # 13.1 tolerates a missing oncotree_code/organ column, so read them
        # defensively here too (a KeyError would fail every sample).
        cancer_type = row.get("oncotree_code", "NA")
        organ = row.get("organ", "NA")
        h5ad_file = HEST_DATA_DIR / f"{sample_id}.h5ad"

        if not h5ad_file.exists():
            log(f"  SKIP {sample_id}: file not found")
            failed_samples.append(sample_id)
            continue

        try:
            # Load HEST sample
            hest_adata = ad.read_h5ad(h5ad_file)
            log(f"  {sample_id}: {hest_adata.n_obs} spots, {hest_adata.n_vars} genes")
            var_set = set(hest_adata.var_names)

            # Match signature genes
            matched_genes = [g for g in sig_genes if g in var_set]
            overlap_pct = 100 * len(matched_genes) / len(sig_genes) if sig_genes else 0.0
            log(f"    Signature overlap: {len(matched_genes)}/{len(sig_genes)} ({overlap_pct:.1f}%)")

            if len(matched_genes) < 10:
                log(f"    SKIP: insufficient gene overlap (<10)")
                failed_samples.append(sample_id)
                continue

            # Normalize if needed: keep a copy of X as "counts", then
            # library-size normalise + log1p in place.
            # NOTE(review): assumes X holds raw counts when no log1p_norm layer
            # exists; an already-log-transformed X would be transformed twice.
            if "log1p_norm" not in hest_adata.layers:
                hest_adata.layers["counts"] = hest_adata.X.copy()
                sc.pp.normalize_total(hest_adata, target_sum=1e4)
                sc.pp.log1p(hest_adata)
                hest_adata.layers["log1p_norm"] = hest_adata.X.copy()

            # Apply nerve injury signature
            nerve_score = score_mean_z(hest_adata, matched_genes, layer="log1p_norm")
            hest_adata.obs["nerve_injury_score"] = nerve_score

            # Perineural spots: same percentile cut-off as the training data.
            thr = np.percentile(nerve_score, PARAMS["perineural_percentile"])
            is_peri = (nerve_score >= thr).astype(int)
            hest_adata.obs["is_perineural"] = is_peri

            n_peri = int(is_peri.sum())
            log(f"    Perineural spots: {n_peri} ({n_peri/len(is_peri)*100:.1f}%)")

            # Score each immune program that has at least 2 genes present.
            prog_scores = {}
            for prog, gene_list in IMMUNE_GENE_SETS_HEST.items():
                matched_prog = [g for g in gene_list if g in var_set]
                if len(matched_prog) >= 2:
                    score = score_mean_z(hest_adata, matched_prog, layer="log1p_norm")
                    hest_adata.obs[f"immune_{prog}"] = score
                    prog_scores[prog] = score

            log(f"    Computed {len(prog_scores)} immune programs")

            # Compare perineural vs control
            peri_mask = (is_peri == 1)

            for prog, score in prog_scores.items():
                s_peri = score[peri_mask]
                s_ctrl = score[~peri_mask]

                # Same minimum group size as the training-side tests.
                if len(s_peri) < 30 or len(s_ctrl) < 30:
                    continue

                _, p = mannwhitneyu(s_peri, s_ctrl, alternative="two-sided")
                d = cohens_d(s_peri, s_ctrl)

                hest_results.append({
                    "sample_id": sample_id,
                    "cancer_type": cancer_type,
                    "organ": organ,
                    "program": prog,
                    "mean_perineural": float(s_peri.mean()),
                    "mean_control": float(s_ctrl.mean()),
                    "cohens_d": float(d),
                    "mann_whitney_p": float(p),
                    "n_spots": int(hest_adata.n_obs),
                    "n_perineural": int(n_peri),
                    "sig_overlap_pct": float(overlap_pct),
                })

            # Save annotated HEST sample
            hest_out = PROCESSED_DIR / f"hest_{sample_id}_annotated.h5ad"
            hest_adata.write_h5ad(hest_out)
            # Release this sample before the next read, so two samples are
            # never held in memory at the same time.
            del hest_adata

        except Exception as e:
            # Best-effort per sample: log and continue with the rest.
            log(f"  ERROR {sample_id}: {type(e).__name__}: {e}")
            failed_samples.append(sample_id)

    log(f"\nProcessed: {len(hest_selected) - len(failed_samples)}/{hest_selected.shape[0]} samples")
    if len(failed_samples) > 0:
        log(f"Failed (first 20): {failed_samples[:20]}")

if hest_available and len(hest_results) > 0:
    # 13.3 Validation analysis
    log("\n13.3 Validation Analysis")

    hest_df = pd.DataFrame(hest_results)
    log(f"Total comparisons: {hest_df.shape[0]}")

    # FDR correction across every sample x program comparison.
    hest_df["p_adj"] = multipletests(hest_df["mann_whitney_p"].values, method="fdr_bh")[1]

    fdr_thr = PARAMS["fdr_threshold"]

    # A comparison "replicates" when it is FDR-significant AND higher in
    # perineural spots (d > 0). The same directional rule is used for the
    # overall, per-program and per-cancer summaries. Previously the group
    # summaries also counted significant d < 0 results, so the numbers were
    # not comparable.
    sig_mask = (hest_df["p_adj"] < fdr_thr) & (hest_df["cohens_d"] > 0)
    n_sig = int(sig_mask.sum())
    n_total = hest_df.shape[0]
    replication_rate = 100 * n_sig / n_total

    log(f"\nReplication: {n_sig}/{n_total} ({replication_rate:.1f}%) at FDR<{fdr_thr}")

    def _replication_summary(df, replicated, group_col):
        """Per-group mean/std Cohen's d, replicating count, total and rate (%)."""
        grouped = df.groupby(group_col)
        summary = grouped["cohens_d"].agg(["mean", "std"]).rename(
            columns={"mean": "mean_d", "std": "std_d"}
        )
        summary["n_sig"] = replicated.groupby(df[group_col]).sum().astype(int)
        summary["n_total"] = grouped.size()
        summary["replication_pct"] = 100 * summary["n_sig"] / summary["n_total"]
        return summary

    # Per-program replication
    prog_replication = _replication_summary(hest_df, sig_mask, "program")

    log("\nPer-program replication:")
    for prog, r in prog_replication.iterrows():
        log(f"  {prog}: {int(r['n_sig'])}/{int(r['n_total'])} ({r['replication_pct']:.1f}%) | d={r['mean_d']:.3f}±{r['std_d']:.3f}")

    # Per-cancer replication
    cancer_replication = _replication_summary(hest_df, sig_mask, "cancer_type")

    log("\nPer-cancer replication:")
    for code, r in cancer_replication.iterrows():
        log(f"  {code}: {int(r['n_sig'])}/{int(r['n_total'])} ({r['replication_pct']:.1f}%) | d={r['mean_d']:.3f}±{r['std_d']:.3f}")

if hest_available and len(hest_results) > 0:
    # 13.4 Meta-analysis across HEST samples, one pooled effect per program.
    log("\n13.4 Meta-analysis")

    from scipy.stats import norm

    meta_results = []

    # sort=False keeps programs in order of first appearance.
    for prog, prog_data in hest_df.groupby("program", sort=False):
        d_i = prog_data["cohens_d"].to_numpy(dtype=float)
        n1 = prog_data["n_perineural"].to_numpy(dtype=float)
        n2 = (prog_data["n_spots"] - prog_data["n_perineural"]).to_numpy(dtype=float)

        # Large-sample variance of Cohen's d:
        #   V_d = (n1 + n2) / (n1 * n2) + d^2 / (2 (n1 + n2))
        # (previously 2/n1 + 2/n2, which is not the variance of d).
        # NOTE(review): spots are spatially autocorrelated, so nominal spot
        # counts overstate per-sample precision; the random-effects tau^2
        # below absorbs much of the resulting between-sample spread.
        var_i = (n1 + n2) / (n1 * n2) + d_i**2 / (2 * (n1 + n2))
        w_fixed = 1.0 / var_i

        # Cochran's Q and I^2, computed around the fixed-effect estimate.
        d_fixed = np.sum(w_fixed * d_i) / np.sum(w_fixed)
        q = float(np.sum(w_fixed * (d_i - d_fixed)**2))
        dof = len(d_i) - 1
        i2 = max(0.0, 100 * (q - dof) / q) if q > 0 else 0.0

        # DerSimonian-Laird between-sample variance (tau^2). The previous
        # comment said "random effects" but the code pooled with fixed-effect
        # weights only.
        c = np.sum(w_fixed) - np.sum(w_fixed**2) / np.sum(w_fixed)
        tau2 = max(0.0, (q - dof) / c) if dof > 0 and c > 0 else 0.0

        # Random-effects pooled estimate
        w_re = 1.0 / (var_i + tau2)
        pooled_d = np.sum(w_re * d_i) / np.sum(w_re)
        pooled_se = np.sqrt(1.0 / np.sum(w_re))

        # 95% CI
        ci_lower = pooled_d - 1.96 * pooled_se
        ci_upper = pooled_d + 1.96 * pooled_se

        # Two-sided Z-test; sf() avoids the 1 - cdf() underflow to 0 for large |z|.
        z = pooled_d / pooled_se
        p_meta = 2 * norm.sf(np.abs(z))

        meta_results.append({
            "program": prog,
            "pooled_d": float(pooled_d),
            "pooled_se": float(pooled_se),
            "ci_lower": float(ci_lower),
            "ci_upper": float(ci_upper),
            "p_meta": float(p_meta),
            "i2": float(i2),
            "tau2": float(tau2),
            "n_studies": int(len(prog_data)),
        })

    meta_df = pd.DataFrame(meta_results)
    meta_df["p_adj"] = multipletests(meta_df["p_meta"].values, method="fdr_bh")[1]
    meta_df = meta_df.sort_values("pooled_d", ascending=False)

    log("\nMeta-analysis results (pooled effect sizes):")
    for _, r in meta_df.iterrows():
        sig = "***" if r["p_adj"] < 0.001 else "**" if r["p_adj"] < 0.01 else "*" if r["p_adj"] < 0.05 else ""
        log(f"  {r['program']}: d={r['pooled_d']:.3f} [{r['ci_lower']:.3f}, {r['ci_upper']:.3f}] | I²={r['i2']:.1f}% {sig}")

if hest_available and len(hest_results) > 0:
    # Main Figure 5: HEST-1k validation results
    log("\nMaking Main Figure 5: HEST-1k validation")

    # Taller canvas so the forest plot and replication bars fit every program.
    fig = plt.figure(figsize=(7.2, 7.5), dpi=SAVE_DPI)
    gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
                          height_ratios=[1, 1])

    # Panel A: number of validated samples per cancer type (vertical bars).
    ax1 = fig.add_subplot(gs[0, 0])
    cancer_counts = hest_df.groupby("cancer_type")["sample_id"].nunique().sort_values(ascending=False)

    # Dark-to-light blue gradient: the most-sampled type gets the darkest bar.
    colors_a = plt.cm.Blues(np.linspace(0.9, 0.6, len(cancer_counts)))
    ax1.bar(range(len(cancer_counts)), cancer_counts.values, color=colors_a,
            edgecolor='black', linewidth=0.5, width=0.7)
    ax1.set_xticks(range(len(cancer_counts)))
    ax1.set_xticklabels(cancer_counts.index, fontsize=7, rotation=0)
    ax1.set_ylabel("N samples", fontsize=7)
    ax1.set_title("A  HEST-1k cancer types", fontsize=7, pad=10)
    ax1.grid(axis='y', alpha=0.3, linewidth=0.5)
    ax1.set_axisbelow(True)

    # Panel B: forest plot of pooled effects (up to 9 programs, spaced rows).
    ax2 = fig.add_subplot(gs[0, 1])
    top_progs = meta_df.head(9)
    y_pos = np.arange(len(top_progs)) * 1.2

    for y, (_, r) in zip(y_pos, top_progs.iterrows()):
        ax2.plot([r["ci_lower"], r["ci_upper"]], [y, y], 'k-', linewidth=1.0)
        color = "#d62728" if r["p_adj"] < 0.05 else "#7f7f7f"
        ax2.scatter(r["pooled_d"], y, s=60, c=color, zorder=3,
                    edgecolor='black', linewidth=0.5)

    ax2.axvline(x=0, linestyle="--", color="gray", linewidth=0.8, alpha=0.6)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(top_progs["program"].values, fontsize=7)
    ax2.set_xlabel("Pooled Cohen's d", fontsize=7)
    ax2.set_title("B  Meta-analysis", fontsize=7, pad=10)
    ax2.grid(axis='x', alpha=0.3, linewidth=0.5)
    ax2.set_axisbelow(True)
    # Inverted axis (first program on top): bottom limit is the larger value.
    # Pad half a row beyond the first and last rows; the previous
    # (y[-1] - 0.6, y[0] + 0.6) clipped both end markers.
    ax2.set_ylim(y_pos[-1] + 0.6, y_pos[0] - 0.6)

    # Panel C: Example validation sample
    ax3 = fig.add_subplot(gs[1, 0])
    example_id = hest_df["sample_id"].iloc[0]
    example_file = PROCESSED_DIR / f"hest_{example_id}_annotated.h5ad"
    ex_adata = ad.read_h5ad(example_file) if example_file.exists() else None

    # Fall back to placeholder text if the annotated file lacks spot
    # coordinates, instead of raising and losing the whole figure.
    if ex_adata is not None and "spatial" in ex_adata.obsm:
        coords = np.asarray(ex_adata.obsm["spatial"])
        scores = ex_adata.obs["nerve_injury_score"].values

        scat = ax3.scatter(coords[:, 0], coords[:, 1], c=scores, s=1.5,
                           cmap="RdYlBu_r", alpha=0.9, rasterized=True,
                           vmin=np.percentile(scores, 2), vmax=np.percentile(scores, 98))
        ax3.set_title(f"C  Example: {example_id}", fontsize=7, pad=10)
        ax3.set_xticks([]); ax3.set_yticks([])
        ax3.set_xlabel("", fontsize=7)
        ax3.set_ylabel("", fontsize=7)
        cbar = plt.colorbar(scat, ax=ax3, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=5.5)
        cbar.set_label("Nerve injury score", fontsize=6.5)
    else:
        ax3.text(0.5, 0.5, "Example not available", ha="center", va="center",
                 transform=ax3.transAxes, fontsize=7)
        ax3.set_xticks([]); ax3.set_yticks([])

    # Panel D: replication rate per program (ascending, so best is on top).
    ax4 = fig.add_subplot(gs[1, 1])

    prog_rep = prog_replication.sort_values("replication_pct", ascending=True)
    y_pos_d = np.arange(len(prog_rep)) * 1.2

    def _replication_color(pct):
        """Map a replication rate (%) to a green (high) / red (low) colour band."""
        if pct >= 70:
            return "#2ca02c"  # Strong green
        if pct >= 60:
            return "#78c878"  # Medium green
        if pct >= 50:
            return "#bcdebc"  # Light green
        if pct >= 40:
            return "#ffb3b3"  # Light red
        return "#ff6b6b"      # Strong red

    colors_d = [_replication_color(x) for x in prog_rep["replication_pct"].values]

    ax4.barh(y_pos_d, prog_rep["replication_pct"].values, color=colors_d,
             edgecolor='black', linewidth=0.5, height=0.8)
    ax4.set_yticks(y_pos_d)
    ax4.set_yticklabels(prog_rep.index, fontsize=7)
    ax4.set_xlabel("Replication rate (%)", fontsize=7)
    ax4.set_title("D  Replication summary", fontsize=7, pad=10)
    ax4.axvline(x=50, linestyle="--", color="gray", linewidth=0.8, alpha=0.6)
    ax4.grid(axis='x', alpha=0.3, linewidth=0.5)
    ax4.set_axisbelow(True)
    # Keep the original 85% frame, but widen it if a bar would be clipped.
    ax4.set_xlim(0, max(85.0, float(prog_rep["replication_pct"].max()) + 5))
    ax4.set_ylim(y_pos_d[0] - 0.6, y_pos_d[-1] + 0.6)

    fig.suptitle("Main Figure 5: HEST-1k external validation", y=0.98, fontsize=8, fontweight='bold')
    fig5_path = FIGURES_DIR / "Main_Figure_5.png"
    fig.savefig(fig5_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {fig5_path}")

    # Supplementary Table 11
    supp11_path = TABLES_DIR / "Supplementary_Table_11.xlsx"
    with pd.ExcelWriter(supp11_path, engine="openpyxl") as w:
        hest_df.to_excel(w, index=False, sheet_name="HEST_Validation_All")
        meta_df.to_excel(w, index=False, sheet_name="Meta_Analysis")
        prog_replication.to_excel(w, sheet_name="Program_Replication")
        cancer_replication.to_excel(w, sheet_name="Cancer_Replication")
    log(f"Saved {supp11_path}")

    # Supplementary Figure 7
    log("\nMaking Supplementary Figure 7: Extended HEST-1k validation")

    cancer_list = hest_df["cancer_type"].unique()
    n_cancers = len(cancer_list)
    n_cols = min(3, n_cancers)
    n_rows = (n_cancers + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array, so one flat view handles every
    # grid shape (including the 1x1 case).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7.2, 2.5*n_rows), dpi=SAVE_DPI,
                             squeeze=False)
    axes = axes.ravel()

    for ax, code in zip(axes, cancer_list):
        ct_data = hest_df[hest_df["cancer_type"] == code]
        prog_d = ct_data.groupby("program")["cohens_d"].mean().sort_values(ascending=True)

        bar_pos = np.arange(len(prog_d))
        bar_colors = ["salmon" if d < 0 else "lightgreen" for d in prog_d.values]

        ax.barh(bar_pos, prog_d.values, color=bar_colors)
        ax.set_yticks(bar_pos)
        ax.set_yticklabels(prog_d.index, fontsize=6)
        ax.axvline(x=0, linestyle="--", color="gray", linewidth=0.8)
        ax.set_xlabel("Cohen's d", fontsize=6)
        ax.set_title(code, fontsize=7)
        ax.tick_params(labelsize=6)

    # Hide extra axes
    for ax in axes[n_cancers:]:
        ax.set_visible(False)

    fig.suptitle("Supplementary Figure 7: HEST-1k validation per cancer type", y=0.98)
    supp_fig7_path = FIGURES_DIR / "Supplementary_Figure_7.png"
    fig.tight_layout()
    fig.savefig(supp_fig7_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {supp_fig7_path}")

else:
    log("\nHEST-1k validation skipped (no data or processing failed)")

    # Placeholder Figure 5 with four empty, labelled panels.
    fig, axes = plt.subplots(2, 2, figsize=(7.2, 6.0), dpi=SAVE_DPI)

    placeholder_panels = [
        (axes[0, 0], "HEST-1k cancer type map\n(no data)", "A  HEST-1k overview"),
        (axes[0, 1], "Forest plot\neffect sizes\n(no data)", "B  Meta-analysis"),
        (axes[1, 0], "Example validation\n(no data)", "C  Example validation"),
        (axes[1, 1], "Validation summary\n(no data)", "D  Summary"),
    ]
    for ax, message, title in placeholder_panels:
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title)
        ax.set_xticks([]); ax.set_yticks([])

    fig.suptitle("Main Figure 5: HEST-1k external validation (placeholder)", y=0.98)
    fig5_path = FIGURES_DIR / "Main_Figure_5_placeholder.png"
    fig.tight_layout()
    fig.savefig(fig5_path, dpi=SAVE_DPI, bbox_inches="tight")
    plt.close(fig)
    log(f"Saved {fig5_path}")

# SECTION 14: Final Exports

log("\nSECTION 14: Final Exports")


def _json_default(obj):
    """Convert numpy scalars/arrays (e.g. values inside PARAMS) for json.dump."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Export comprehensive results
results_export = {
    "perineural_definitions": supp8_df,
    "immune_programs_visium": results_df,
    "spatial_correlations": spatial_corr_df,
    "lr_interactions": lr_df,
}

if geomx_available and 'response_df' in locals() and response_df.shape[0] > 0:
    results_export["geomx_response"] = response_df

# HEST-1k tables exist only when external validation actually ran.
hest_ran = bool(locals().get("hest_available", False))
if hest_ran and 'hest_df' in locals():
    results_export["hest_validation"] = hest_df
    if 'meta_df' in locals():
        results_export["hest_meta_analysis"] = meta_df

for name, df in results_export.items():
    out_path = PROCESSED_DIR / f"{name}.csv"
    df.to_csv(out_path, index=False)
    log(f"Saved {out_path}")

# Save updated AnnData
adata_out = PROCESSED_DIR / "adata_with_immune_scores.h5ad"
adata.write_h5ad(adata_out)
log(f"Saved {adata_out}")

# Session info (flags cast to bool so numpy booleans stay JSON-serialisable)
session_info = {
    "run_timestamp": RUN_TS,
    "python_version": sys.version.split()[0],
    "scanpy_version": sc.__version__,
    "numpy_version": np.__version__,
    "pandas_version": pd.__version__,
    "random_seed": RANDOM_SEED,
    "parameters": PARAMS,
    "n_spots": int(adata.n_obs),
    "n_genes": int(adata.n_vars),
    "n_samples": int(adata.obs["sample_id"].nunique()),
    "geomx_available": bool(geomx_available),
    "hest_available": hest_ran,
}

session_path = PROCESSED_DIR / "session_info.json"
with open(session_path, "w", encoding="utf-8") as f:
    json.dump(session_info, f, indent=2, default=_json_default)
log(f"Saved {session_path}")

log("\n" + "="*80)
log("NOTEBOOK 3 COMPLETE")
log("="*80)
log(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
log(f"\nOutputs:")
log(f"  Figures: {FIGURES_DIR}")
log(f"  Tables: {TABLES_DIR}")
log(f"  Processed: {PROCESSED_DIR}")

print("\nDone.")
print(f"Main figures: {FIGURES_DIR}")
print(f"Tables: {TABLES_DIR}")
print(f"Processed data: {PROCESSED_DIR}")