Elastic-net Logistic Regression validation on an external cohort (NO fixed weights).

Goal:
- Evaluate whether a *gene set* (e.g., your 37 genes) can classify AD vs Control
  in an external dataset, using nested cross-validation (hyperparameter tuning inside CV).

What it does:
1) Loads expression gene×sample CSV + metadata CSV
2) Builds samples×genes matrix (HGNC symbol rows collapsed by mean)
3) Keeps only genes in GENE_LIST that exist in the dataset
4) Runs nested CV:
   - Outer CV: performance estimation (RepeatedStratifiedKFold)
   - Inner CV: tune elastic-net (C, l1_ratio) using GridSearchCV
5) Produces pooled out-of-fold probabilities (one per sample)
6) Reports:
   - Fold metrics (mean ± SD) at THRESHOLD
   - Pooled OOF metrics at THRESHOLD
   - Optional: pooled OOF metrics at "best threshold" (max balanced accuracy)
   - Bootstrap 95% CI on pooled OOF predictions
7) Saves results/oof_predictions.csv

In [1]:
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.calibration import calibration_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    average_precision_score,
    balanced_accuracy_score,
    brier_score_loss,
    confusion_matrix,
    matthews_corrcoef,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import (
    GridSearchCV,
    RepeatedStratifiedKFold,
    StratifiedKFold,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

In [7]:
EXPR_CSV = Path(r"data/GSE125583/DE_data/GSE125583_log2cpm_SELECTED_GENES.csv")
META_CSV = Path(r"data/GSE125583/DE_data/metadata_200samples.csv")

SAMPLE_COL = "geo_accession"
LABEL_COL = "diagnosis:ch1"
POSITIVE_LABEL = "Alzheimer's disease"

# If known, set it; else set None to use last column
GENE_COL_NAME: Optional[str] = "Gene"

# Paste your 37 genes here (HGNC symbols)
GENE_LIST: List[str] = [
    "ADAM33", "AEBP1", "CCDC102A","CLDN9", "GFAP","HSPB1","HSPB7","KANK2", "KLF15", "MRGPRF", "NUPR1", "PIK3R5", "PRELP", "PRX", "TCEA3", "TMPRSS5", "CHML", "ELOVL4",
    "GAD1", "GAD2", "HPRT1", "ITFG1", "MAS1", "NAP1L5", "NCALD", "NEUROD6", "NRN1", "OPN3", "RAB3B", "RAB3C", "RGS4", "RPH3A", "SCG2", "SERPINI1", "STAT4", "TRIM36"
]

# CV / evaluation
N_SPLITS = 5
N_REPEATS = 20
INNER_SPLITS = 5
SEED = 42
THRESHOLD = 0.50
BOOTSTRAP = 1000

# Elastic-net tuning grid
C_GRID = np.logspace(-3, 3, 13)
L1_RATIO_GRID = np.linspace(0.0, 1.0, 11)

# Outputs
OUTDIR = Path("results")
PREDICTIONS_OUT = OUTDIR / "oof_predictions.csv"
BESTPARAMS_OUT = OUTDIR / "best_params_per_fold.json"


In [9]:
@dataclass
class Metrics:
    auc: float
    auprc: float
    balanced_accuracy: float
    sensitivity: float
    specificity: float
    mcc: float
    brier: float


def expected_calibration_error(y_true: np.ndarray, y_prob: np.ndarray, bins: int = 10) -> float:
    frac_pos, mean_pred = calibration_curve(y_true, y_prob, n_bins=bins, strategy="quantile")
    if len(frac_pos) == 0:
        return float("nan")
    return float(np.mean(np.abs(frac_pos - mean_pred)))


def metrics_at_threshold(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Metrics:
    y_pred = (y_prob >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    sens = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    spec = tn / (tn + fp) if (tn + fp) > 0 else float("nan")

    return Metrics(
        auc=float(roc_auc_score(y_true, y_prob)),
        auprc=float(average_precision_score(y_true, y_prob)),
        balanced_accuracy=float(balanced_accuracy_score(y_true, y_pred)),
        sensitivity=float(sens),
        specificity=float(spec),
        mcc=float(matthews_corrcoef(y_true, y_pred)),
        brier=float(brier_score_loss(y_true, y_prob)),
    )


def summarize(name: str, values: List[float]) -> str:
    arr = np.array(values, dtype=float)
    return f"{name}: {np.nanmean(arr):.3f} ± {np.nanstd(arr):.3f}"


def bootstrap_ci(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    n_bootstrap: int,
    seed: int,
) -> Dict[str, Tuple[float, float]]:
    rng = np.random.default_rng(seed)
    n = len(y_true)

    keys = ["auc", "auprc", "balanced_accuracy", "sensitivity", "specificity", "mcc", "brier"]
    stats = {k: [] for k in keys}

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt = y_true[idx]
        yp = y_prob[idx]
        if len(np.unique(yt)) < 2:
            continue

        m = metrics_at_threshold(yt, yp, threshold)
        stats["auc"].append(m.auc)
        stats["auprc"].append(m.auprc)
        stats["balanced_accuracy"].append(m.balanced_accuracy)
        stats["sensitivity"].append(m.sensitivity)
        stats["specificity"].append(m.specificity)
        stats["mcc"].append(m.mcc)
        stats["brier"].append(m.brier)

    ci = {}
    for k in keys:
        vals = stats[k]
        if len(vals) == 0:
            ci[k] = (float("nan"), float("nan"))
        else:
            lo, hi = np.percentile(vals, [2.5, 97.5])
            ci[k] = (float(lo), float(hi))
    return ci


def choose_threshold_max_balanced_accuracy(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    fpr, tpr, thr = roc_curve(y_true, y_prob)

    # remove non-finite
    finite = np.isfinite(thr)
    fpr, tpr, thr = fpr[finite], tpr[finite], thr[finite]

    # avoid degenerate ~0 thresholds (common when scores are extremely close)
    keep = thr > 1e-6
    if keep.sum() == 0:
        keep = thr > 0
    if keep.sum() == 0:
        keep = np.ones_like(thr, dtype=bool)

    fpr2, tpr2, thr2 = fpr[keep], tpr[keep], thr[keep]
    bal_acc = (tpr2 + (1 - fpr2)) / 2.0
    best_idx = int(np.argmax(bal_acc))
    return float(thr2[best_idx])


def build_samples_x_genes(expr_csv: Path, meta_csv: Path) -> pd.DataFrame:
    expr = pd.read_csv(expr_csv)
    meta = pd.read_csv(meta_csv)

    gene_col = GENE_COL_NAME if GENE_COL_NAME is not None else expr.columns[-1]
    if gene_col not in expr.columns:
        raise ValueError(f"GENE_COL_NAME='{GENE_COL_NAME}' not found in expression CSV columns.")
    if SAMPLE_COL not in meta.columns:
        raise ValueError(f"SAMPLE_COL='{SAMPLE_COL}' not found in metadata CSV columns.")
    if LABEL_COL not in meta.columns:
        raise ValueError(f"LABEL_COL='{LABEL_COL}' not found in metadata CSV columns.")

    # sample columns = everything except gene_col and (likely) an ID column
    exclude = {gene_col}
    first_col = expr.columns[0]
    if first_col != gene_col and (
        first_col.lower() in {"id", "index", "ensembl", "ensg"} or expr[first_col].dtype == object
    ):
        exclude.add(first_col)

    sample_cols = [c for c in expr.columns if c not in exclude]
    gx = expr[[gene_col] + sample_cols].dropna(subset=[gene_col]).copy()
    gx[gene_col] = gx[gene_col].astype(str)

    # collapse duplicated symbols
    gx = gx.groupby(gene_col, as_index=True)[sample_cols].mean()

    # transpose -> samples × genes
    sxg = gx.T
    sxg.index.name = SAMPLE_COL
    sxg = sxg.reset_index()

    df = meta[[SAMPLE_COL, LABEL_COL]].merge(sxg, on=SAMPLE_COL, how="inner")
    if df.empty:
        raise ValueError(
            "After merging metadata and expression, got 0 rows. "
            "Check that metadata sample IDs match the expression sample column names."
        )
    return df

In [10]:
def main() -> None:
    if len(GENE_LIST) == 0:
        raise ValueError("GENE_LIST is empty. Paste your 37 genes into GENE_LIST first.")

    OUTDIR.mkdir(parents=True, exist_ok=True)

    df = build_samples_x_genes(EXPR_CSV, META_CSV)

    y = (df[LABEL_COL].astype(str) == str(POSITIVE_LABEL)).astype(int).to_numpy()
    if y.sum() == 0 or y.sum() == len(y):
        raise ValueError("Both classes are required (need positives and negatives).")

    # Use only the provided gene set
    genes = [g for g in GENE_LIST if g in df.columns]
    missing = sorted(set(GENE_LIST) - set(genes))

    print(f"Samples after merge: {len(df)} | AD: {int(y.sum())} | Control: {int(len(y) - y.sum())}")
    print(f"Genes requested: {len(GENE_LIST)} | present: {len(genes)} | missing: {len(missing)}")
    if missing:
        print("Missing genes (first 15):", missing[:15])

    if len(genes) < 5:
        raise ValueError("Too few genes present from GENE_LIST after intersection. Check gene symbols/case.")

    X = df.loc[:, genes].copy()
    X = X.apply(pd.to_numeric, errors="coerce")
    if X.isna().any().any():
        raise ValueError("Expression matrix contains non-numeric values after coercion. Fix input CSV.")

    # Model + tuning
    pipe = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", LogisticRegression(
            solver="saga",
            penalty="elasticnet",
            class_weight="balanced",
            max_iter=20000,
            random_state=SEED,
        )),
    ])

    param_grid = {
        "clf__C": C_GRID,
        "clf__l1_ratio": L1_RATIO_GRID,
    }

    outer_cv = RepeatedStratifiedKFold(n_splits=N_SPLITS, n_repeats=N_REPEATS, random_state=SEED)
    inner_cv = StratifiedKFold(n_splits=INNER_SPLITS, shuffle=True, random_state=SEED)

    pooled_prob = np.full(len(y), np.nan, dtype=float)
    per_fold: List[Metrics] = []
    best_params_per_fold: List[Dict[str, float]] = []

    total_folds = N_SPLITS * N_REPEATS
    for fold_idx, (train_idx, test_idx) in enumerate(outer_cv.split(X, y), start=1):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        gs = GridSearchCV(
            estimator=pipe,
            param_grid=param_grid,
            scoring="roc_auc",   # change to "average_precision" if you prefer PR-focused tuning
            cv=inner_cv,
            n_jobs=-1,
            refit=True,
        )
        gs.fit(X_train, y_train)
        y_prob = gs.predict_proba(X_test)[:, 1]

        pooled_prob[test_idx] = y_prob
        per_fold.append(metrics_at_threshold(y_test, y_prob, THRESHOLD))

        best_params_per_fold.append({
            "C": float(gs.best_params_["clf__C"]),
            "l1_ratio": float(gs.best_params_["clf__l1_ratio"]),
        })

        if fold_idx % N_SPLITS == 0:
            print(f"Processed fold {fold_idx}/{total_folds}")

    valid = np.isfinite(pooled_prob)
    y_true_oof = y[valid]
    y_prob_oof = pooled_prob[valid]

    # Threshold summaries
    best_threshold = choose_threshold_max_balanced_accuracy(y_true_oof, y_prob_oof)
    pred_pos_rate = float((y_prob_oof >= best_threshold).mean())

    print(f"\nOptimal threshold (max balanced accuracy on pooled OOF): {best_threshold:.6f}")
    m_opt = metrics_at_threshold(y_true_oof, y_prob_oof, best_threshold)
    print(f"Predicted positive rate @ optimal threshold: {pred_pos_rate:.3f}")
    print("\n=== Pooled OOF metrics @ optimal threshold ===")
    print(f"AUROC:             {m_opt.auc:.3f}")
    print(f"AUPRC:             {m_opt.auprc:.3f}")
    print(f"Balanced Accuracy: {m_opt.balanced_accuracy:.3f}")
    print(f"Sensitivity:       {m_opt.sensitivity:.3f}")
    print(f"Specificity:       {m_opt.specificity:.3f}")
    print(f"MCC:               {m_opt.mcc:.3f}")
    print(f"Brier:             {m_opt.brier:.3f}")

    # Fold metrics @ fixed threshold
    print(f"\n=== Fold metrics @ threshold={THRESHOLD:.2f} (mean ± SD) ===")
    print(summarize("AUROC", [m.auc for m in per_fold]))
    print(summarize("AUPRC", [m.auprc for m in per_fold]))
    print(summarize("Balanced Accuracy", [m.balanced_accuracy for m in per_fold]))
    print(summarize("Sensitivity", [m.sensitivity for m in per_fold]))
    print(summarize("Specificity", [m.specificity for m in per_fold]))
    print(summarize("MCC", [m.mcc for m in per_fold]))
    print(summarize("Brier", [m.brier for m in per_fold]))

    # Bootstrap CI on pooled OOF predictions @ optimal threshold
    ci = bootstrap_ci(y_true_oof, y_prob_oof, best_threshold, BOOTSTRAP, SEED)
    print("\n=== Bootstrapped 95% CI (pooled OOF, using optimal threshold) ===")
    for k, (lo, hi) in ci.items():
        print(f"{k}: [{lo:.3f}, {hi:.3f}]")

    # PR + calibration summaries (threshold-free)
    prevalence = float(y_true_oof.mean())
    auprc_oof = float(average_precision_score(y_true_oof, y_prob_oof))
    precision, recall, _ = precision_recall_curve(y_true_oof, y_prob_oof)
    ece = expected_calibration_error(y_true_oof, y_prob_oof, bins=10)

    print(f"\nClass prevalence (AD): {prevalence:.3f}")
    print(f"No-skill AUPRC baseline: {prevalence:.3f}")
    print(f"AUPRC lift over baseline: {auprc_oof - prevalence:.3f}")
    print(f"PR curve points: {len(precision)}")
    print(f"Expected calibration error (10-bin): {ece:.3f}")

    # Save OOF predictions
    out = df[[SAMPLE_COL, LABEL_COL]].copy()
    out["y_true"] = y
    out["y_prob_oof"] = pooled_prob
    out["y_pred_oof_thresh0p5"] = (pooled_prob >= THRESHOLD).astype(int)
    out["y_pred_oof_optthr"] = (pooled_prob >= best_threshold).astype(int)

    out.to_csv(PREDICTIONS_OUT, index=False)
    Path(BESTPARAMS_OUT).write_text(json.dumps(best_params_per_fold, indent=2))

    print(f"\nSaved out-of-fold predictions to: {PREDICTIONS_OUT.resolve()}")
    print(f"Saved best params per fold to:       {BESTPARAMS_OUT.resolve()}")


In [11]:
if __name__ == "__main__":
    main()

Samples after merge: 200 | AD: 158 | Control: 42
Genes requested: 36 | present: 36 | missing: 0




Processed fold 5/100




Processed fold 10/100




Processed fold 15/100




Processed fold 20/100




Processed fold 25/100




Processed fold 30/100




Processed fold 35/100




Processed fold 40/100




Processed fold 45/100




Processed fold 50/100




Processed fold 55/100




Processed fold 60/100




Processed fold 65/100




Processed fold 70/100




Processed fold 75/100




Processed fold 80/100




Processed fold 85/100




Processed fold 90/100




Processed fold 95/100




Processed fold 100/100

Optimal threshold (max balanced accuracy on pooled OOF): 0.729248
Predicted positive rate @ optimal threshold: 0.595

=== Pooled OOF metrics @ optimal threshold ===
AUROC:             0.866
AUPRC:             0.963
Balanced Accuracy: 0.816
Sensitivity:       0.728
Specificity:       0.905
MCC:               0.525
Brier:             0.141

=== Fold metrics @ threshold=0.50 (mean ± SD) ===
AUROC: 0.860 ± 0.056
AUPRC: 0.959 ± 0.019
Balanced Accuracy: 0.763 ± 0.081
Sensitivity: 0.842 ± 0.068
Specificity: 0.685 ± 0.163
MCC: 0.489 ± 0.143
Brier: 0.147 ± 0.038

=== Bootstrapped 95% CI (pooled OOF, using optimal threshold) ===
auc: [0.806, 0.914]
auprc: [0.943, 0.979]
balanced_accuracy: [0.751, 0.869]
sensitivity: [0.652, 0.792]
specificity: [0.800, 0.978]
mcc: [0.401, 0.626]
brier: [0.108, 0.176]

Class prevalence (AD): 0.790
No-skill AUPRC baseline: 0.790
AUPRC lift over baseline: 0.173
PR curve points: 201
Expected calibration error (10-bin): 0.100

Saved out-of-fold