In [None]:
%pip install -q GEOparse scanpy anndata scikit-learn matplotlib seaborn pandas numpy joblib

# Breast cancer ML pipeline

This notebook splits the original large cell into modular cells and adds caching: if the processed outputs (`X_corrected.csv` and `metadata.csv`) already exist in `processed_data`, the notebook loads them instead of re-downloading and re-processing the GSE datasets.

In [None]:
# Imports and global settings
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
import joblib
warnings.filterwarnings("ignore")

# GEO parsing and batch correction libraries
import GEOparse
import scanpy as sc

# sklearn: selection, models, metrics
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import VarianceThreshold
from sklearn.linear_model import LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler, LabelEncoder
import seaborn as sns

# Single seed reused for NumPy's global RNG and passed as `random_state` to the
# train/test split and the estimators defined in later cells.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

def ensure_dir(p):
    """Create directory ``p`` together with any missing parents.

    Calling it on a directory that already exists is a no-op. Returns ``None``.
    """
    target = Path(p)
    target.mkdir(parents=True, exist_ok=True)

def infer_label_from_gsm(gsm: "GEOparse.GSM") -> "str | None":
    """Heuristically classify a GEO sample as ``'cancer'`` or ``'normal'``.

    The text of the ``characteristics_ch1``, ``source_name_ch1``,
    ``description`` and ``title`` metadata fields is concatenated, lower-cased
    and searched with two keyword patterns. The "normal" pattern is checked
    first, so a sample mentioning both (e.g. "adjacent normal of tumor") is
    labelled ``'normal'``.

    Parameters
    ----------
    gsm : object exposing a ``metadata`` mapping, or ``None``
        A missing sample (``None``) or a sample without metadata is treated as
        having no usable text.

    Returns
    -------
    str or None
        ``'normal'``, ``'cancer'``, or ``None`` when no keyword matches.
    """
    metadata = getattr(gsm, "metadata", None) or {}

    fragments = []
    for key in ("characteristics_ch1", "source_name_ch1", "description", "title"):
        value = metadata.get(key)
        if not value:
            continue
        if isinstance(value, (list, tuple)):
            fragments.extend(str(item) for item in value)
        else:
            fragments.append(str(value))
    combined = " ".join(fragments).lower()

    # `(?<!ab)` keeps "abnormal" from being read as "normal".
    if re.search(r"(?<!ab)normal|healthy|control|adjacent|non[-\s]?tumor|non[-\s]?tumour", combined):
        return "normal"
    if re.search(r"tumor|tumour|cancer|carcinoma|brca|malignant|diseased", combined):
        return "cancer"
    return None

In [None]:
# Download & extract single GSE (keeps GEOparse caching behavior)
def _gsm_metadata_record(gsm):
    """Flatten one GSM's metadata into a dict that always carries a ``label`` entry."""
    if gsm is None:
        # Sample id present in the expression table but absent from gse.gsms.
        return {"label": None}
    record = {"label": infer_label_from_gsm(gsm)}
    for key, value in gsm.metadata.items():
        # List-valued fields are joined with ';' so each one fits in a single cell.
        record[key] = ";".join(map(str, value)) if isinstance(value, list) else value
    return record


def download_and_extract_gse(gse_id, outdir="geo_raw"):
    """Fetch one GEO series and return its expression matrix and sample metadata.

    Two extraction strategies are tried in order:

    1. If the series object exposes a non-empty ``table`` DataFrame with
       ``GSM*`` columns, those columns become the samples and the first table
       column supplies the feature ids.
    2. Otherwise every per-sample table is scanned for a value column
       (``value``/``exprs``/``intensity``/``signal``/``log2``, else the last
       column) and an id column (``id``/``id_ref``/``probe_id``/``gene_id``,
       else the first column).

    In both paths non-numeric expression cells are coerced to NaN.

    Parameters
    ----------
    gse_id : str
        GEO series accession, passed to ``GEOparse.get_GEO``.
    outdir : str
        Directory handed to GEOparse as ``destdir``; created if missing.

    Returns
    -------
    (proc, meta) : tuple of DataFrames
        ``proc`` is samples x features; ``meta`` is indexed by sample id and
        holds a heuristic ``label`` column plus the flattened GSM metadata.

    Raises
    ------
    RuntimeError
        If no expression table could be extracted.
    """
    ensure_dir(outdir)
    print(f"Fetching {gse_id} (GEOparse will reuse cached files in {outdir} if present)...")
    gse = GEOparse.get_GEO(geo=gse_id, destdir=outdir, silent=True)

    # Try series matrix table first
    series_table = getattr(gse, "table", None)
    if isinstance(series_table, pd.DataFrame) and not series_table.empty:
        df = series_table.copy()
        sample_cols = [c for c in df.columns if str(c).startswith("GSM")]
        if sample_cols:
            proc = df[sample_cols].transpose()
            proc.index.name = "sample"
            proc.columns = df.iloc[:, 0].values
            proc = proc.apply(pd.to_numeric, errors="coerce")
            records = [
                {"sample": gsm_id, **_gsm_metadata_record(gse.gsms.get(gsm_id))}
                for gsm_id in proc.index
            ]
            meta = pd.DataFrame(records).set_index("sample")
            return proc, meta

    # Fallback: iterate GSM tables
    value_names = ("value", "exprs", "intensity", "signal", "log2")
    id_names = ("id", "id_ref", "probe_id", "gene_id")
    exprs = {}
    meta_rows = {}
    for gsm_id, gsm in gse.gsms.items():
        tbl = getattr(gsm, "table", None)
        if not isinstance(tbl, pd.DataFrame) or tbl.empty:
            continue
        val_col = next((c for c in tbl.columns if str(c).lower() in value_names), tbl.columns[-1])
        idx_col = next((c for c in tbl.columns if str(c).lower() in id_names), tbl.columns[0])
        # Coerce like the series-matrix path so a stray text cell becomes NaN
        # instead of aborting the whole download at the float cast below.
        values = pd.to_numeric(tbl.set_index(idx_col)[val_col], errors="coerce")
        # Repeated feature ids would make the column-wise concat fail; keep the first.
        values = values[~values.index.duplicated(keep="first")]
        exprs[gsm_id] = values
        meta_rows[gsm_id] = _gsm_metadata_record(gsm)

    if not exprs:
        raise RuntimeError(f"No expression tables found for {gse_id} using GEOparse.")

    df_all = pd.concat(exprs, axis=1)
    if isinstance(df_all.columns, pd.MultiIndex):
        df_all.columns = [c[0] for c in df_all.columns]
    proc = df_all.transpose().astype(float)
    meta = pd.DataFrame.from_dict(meta_rows, orient="index")
    proc.index.name = "sample"
    return proc, meta

In [None]:
# Build merged dataset from multiple GSEs
def build_merged_dataset(gse_list, outdir="geo_raw", manual_label_csv=None, min_common_genes=50):
    """Download several GEO series and stack them into one labelled dataset.

    Parameters
    ----------
    gse_list : iterable of str
        GEO series accessions; each one is recorded as its own ``batch``.
    outdir : str
        Download/cache directory forwarded to ``download_and_extract_gse``.
    manual_label_csv : str or None
        Optional CSV with ``sample`` and ``GROUP`` columns. Non-missing
        ``GROUP`` values override the heuristic label of matching samples;
        rows naming samples that are not in the merged data are ignored.
    min_common_genes : int
        If fewer features than this are shared by all series, the union of
        features is used instead (missing values are left as NaN).

    Returns
    -------
    (merged_X, merged_meta) : tuple of DataFrames
        Row-aligned expression matrix and metadata restricted to samples whose
        label normalises to ``'cancer'`` or ``'normal'``.

    Raises
    ------
    ValueError
        If ``gse_list`` is empty.
    """
    gse_list = list(gse_list)
    if not gse_list:
        raise ValueError("gse_list must contain at least one GSE accession.")

    exprs = []
    metas = []
    for gse in gse_list:
        X, meta = download_and_extract_gse(gse, outdir=outdir)
        meta = meta.copy()
        meta["batch"] = gse
        meta.index = meta.index.astype(str)
        X.index = X.index.astype(str)
        exprs.append(X)
        metas.append(meta)

    # Align columns - intersection preferred
    common_genes = set.intersection(*(set(df.columns) for df in exprs))
    if len(common_genes) < min_common_genes:
        print("Intersection small; using union and filling missing values.")
        merged_X = pd.concat(exprs, axis=0, sort=True)
    else:
        print(f"Using intersection of genes: {len(common_genes)} genes.")
        shared = sorted(common_genes)
        merged_X = pd.concat([df.loc[:, shared] for df in exprs], axis=0)
    merged_meta = pd.concat(metas, axis=0)

    if "label" not in merged_meta.columns:
        merged_meta["label"] = None

    # Try manual labels if provided
    if manual_label_csv and os.path.exists(manual_label_csv):
        manual = pd.read_csv(manual_label_csv)
        manual.index = manual["sample"].astype(str)
        manual = manual[~manual.index.duplicated(keep="last")]
        n_unknown = int((~manual.index.isin(merged_meta.index)).sum())
        if n_unknown:
            print(f"Note: {n_unknown} manual label(s) refer to samples not in the merged data; ignoring them.")
        # Map per sample instead of assigning through .loc, which raises
        # KeyError as soon as the CSV names a sample that was not downloaded.
        mapped = np.asarray(merged_meta.index.map(manual["GROUP"]), dtype=object)
        override = pd.notna(mapped)
        labels = merged_meta["label"].to_numpy(dtype=object, copy=True)
        labels[override] = mapped[override]
        merged_meta["label"] = labels

    # Drop unlabeled
    n_unlabeled = merged_meta["label"].isnull().sum()
    if n_unlabeled > 0:
        print(f"Warning: {n_unlabeled} samples remain unlabeled. Dropping them.")
        keep = merged_meta[~merged_meta["label"].isnull()].index
        merged_meta = merged_meta.loc[keep]
        merged_X = merged_X.loc[keep]

    synonyms = {
        "tumor": "cancer",
        "tumour": "cancer",
        "brca": "cancer",
        "control": "normal",
        "adjacent normal": "normal",
    }
    merged_meta["label"] = merged_meta["label"].astype(str).str.lower().replace(synonyms)
    merged_meta = merged_meta[merged_meta["label"].isin(["cancer", "normal"])]
    merged_X = merged_X.loc[merged_meta.index]
    return merged_X, merged_meta

In [None]:
# Preprocessing and batch correction
def preprocess_and_batch_correct(X_df: pd.DataFrame, meta_df: pd.DataFrame, log_transform=True):
    """Impute, optionally log-transform, standardise and ComBat-correct expression data.

    Steps, in order: drop features with no observed value, fill remaining NaNs
    with the per-feature median, apply ``log1p`` (unless disabled or unsafe),
    z-score every feature with ``StandardScaler``, then run
    ``sc.pp.combat`` using ``meta_df['batch']`` as the batch key.

    Parameters
    ----------
    X_df : DataFrame
        Samples x features expression matrix.
    meta_df : DataFrame
        Must contain a ``batch`` column and be indexable by the sample ids of
        ``X_df``.
    log_transform : bool
        Apply ``np.log1p`` before scaling. Skipped with a message when the
        data contain values <= -1, for which ``log1p`` is undefined.

    Returns
    -------
    (X_corrected, scaler) : (DataFrame, StandardScaler)
        The batch-corrected matrix (same rows as ``X_df``; all-NaN features
        removed) and the scaler that was fitted before batch correction.
    """
    X = X_df.copy()
    X.columns = X.columns.astype(str)

    # A feature with no observed value has a NaN median, so the imputation
    # below could not fill it and the NaNs would flow into ComBat.
    all_missing = X.isna().all(axis=0).to_numpy()
    if all_missing.any():
        print(f"Dropping {int(all_missing.sum())} feature(s) with no observed values.")
        X = X.loc[:, ~all_missing]

    X = X.fillna(X.median(axis=0))

    if log_transform:
        if (X.to_numpy() <= -1).any():
            # log1p(x) is NaN/-inf for x <= -1; such data are not raw intensities.
            print("Data contain values <= -1; skipping log1p transform.")
        else:
            X = np.log1p(X)

    scaler = StandardScaler(with_mean=True, with_std=True)
    X_scaled = pd.DataFrame(scaler.fit_transform(X), index=X.index, columns=X.columns)
    adata = sc.AnnData(X_scaled)
    adata.obs["batch"] = meta_df.loc[adata.obs_names, "batch"].values
    sc.pp.combat(adata, key="batch")
    X_corrected = pd.DataFrame(adata.X, index=X.index, columns=X.columns)
    return X_corrected, scaler

In [None]:
def run_full_pipeline(
    gse_list,
    geo_raw_dir="geo_raw",
    manual_label_csv=None,
    save_processed=True,
    processed_dir="processed_data"
):
    """Return batch-corrected expression data, loading from cache when possible.

    If ``X_corrected.csv`` and ``metadata.csv`` both exist in
    ``processed_dir`` they are loaded (together with ``scaler.joblib`` when
    present). Otherwise the full path runs:
    GEO download/merge -> preprocess -> batch correct -> optional save.

    Parameters
    ----------
    gse_list : sequence of str
        GEO series accessions to merge. Only used when no cache is found; a
        message is printed if the cached batches differ from this list.
    geo_raw_dir : str
        Download directory for raw GEO files.
    manual_label_csv : str or None
        Optional manual label file forwarded to ``build_merged_dataset``.
    save_processed : bool
        When true, write the corrected matrix, metadata, raw matrix and
        fitted scaler into ``processed_dir``. When false nothing is written.
    processed_dir : str
        Cache directory.

    Returns
    -------
    (X_corrected, meta, scaler)
        ``scaler`` is ``None`` on the cached path when no readable
        ``scaler.joblib`` exists.
    """
    # ---------------------------------------------------------------------
    # STEP 0: CHECK IF PROCESSED FILES ALREADY EXIST
    # ---------------------------------------------------------------------
    Xc_path = os.path.join(processed_dir, "X_corrected.csv")
    meta_path = os.path.join(processed_dir, "metadata.csv")
    scaler_path = os.path.join(processed_dir, "scaler.joblib")

    if os.path.exists(Xc_path) and os.path.exists(meta_path):
        print("\n==============================")
        print("LOADING CACHED PROCESSED DATA")
        print("==============================")

        X_corrected = pd.read_csv(Xc_path, index_col=0)
        meta = pd.read_csv(meta_path, index_col=0)

        meta["label"] = meta["label"].astype(str)
        X_corrected.columns = X_corrected.columns.astype(str)

        # The cache is keyed only by directory, so point out when it was
        # built from a different set of series than the one requested now.
        if "batch" in meta.columns:
            cached_batches = set(meta["batch"].astype(str))
            requested = {str(g) for g in gse_list}
            if cached_batches != requested:
                print(
                    f"Warning: cached batches {sorted(cached_batches)} differ from requested "
                    f"{sorted(requested)}; delete {processed_dir}/ to rebuild."
                )

        scaler = None
        if os.path.exists(scaler_path):
            # NOTE: joblib.load unpickles the file; only load caches you produced yourself.
            try:
                scaler = joblib.load(scaler_path)
            except Exception as exc:
                print(f"Warning: could not load cached scaler ({exc}); continuing without it.")

        return X_corrected, meta, scaler

    # ---------------------------------------------------------------------
    # OTHERWISE -> RUN FULL PREPROCESSING PIPELINE
    # ---------------------------------------------------------------------

    print("\n==============================")
    print("STEP 1: BUILD MERGED DATASET")
    print("==============================")

    X_raw, meta = build_merged_dataset(
        gse_list=gse_list,
        outdir=geo_raw_dir,
        manual_label_csv=manual_label_csv
    )

    print(f"\nMerged shape: {X_raw.shape}")
    print(meta['label'].value_counts())

    print("\n==============================")
    print("STEP 2: PREPROCESS + BATCH CORRECT")
    print("==============================")

    X_corrected, scaler = preprocess_and_batch_correct(X_raw, meta)

    if save_processed:
        ensure_dir(processed_dir)
        joblib.dump(scaler, scaler_path)
        print("Scaler saved to:", scaler_path)
        X_corrected.to_csv(Xc_path)
        meta.to_csv(meta_path)
        X_raw.to_csv(os.path.join(processed_dir, "X_raw.csv"))
        print(f"\nProcessed data saved in: {processed_dir}/")

    return X_corrected, meta, scaler

In [None]:
def _top_variance_features(X, n):
    """Return the column labels of the ``n`` highest-variance features of ``X``."""
    return X.var(axis=0).sort_values(ascending=False).head(n).index


def feature_selection(X, y, var_thresh=0.01, max_features=500):
    """Select features by variance filtering followed by L1-penalised logistic regression.

    Stages:
      1. ``VarianceThreshold(var_thresh)``; if it raises ``ValueError`` the
         filter is skipped and all features are kept.
      2. If nothing is left, the 200 highest-variance features are used.
      3. ``LogisticRegressionCV`` with an L1 penalty; features with
         ``|coef| > 1e-6`` are kept, capped at ``max_features`` by
         coefficient magnitude.
      4. If the L1 step selects nothing or raises, the ``max_features``
         highest-variance features of the original ``X`` are returned.

    Returns
    -------
    (X_sel, selected) : (DataFrame, list)
        The reduced matrix and the list of selected column labels.
    """
    print("\n==============================")
    print("STEP 3: FEATURE SELECTION")
    print("==============================")

    # ---- Stage 1: variance filter -------------------------------------
    try:
        selector = VarianceThreshold(threshold=var_thresh)
        filtered_values = selector.fit_transform(X)
        X_v = pd.DataFrame(
            filtered_values,
            index=X.index,
            columns=X.columns[selector.get_support()],
        )
        print(f"Features after variance threshold: {X_v.shape[1]}")
    except ValueError:
        print("⚠️ Variance threshold removed ALL features. Skipping variance threshold.")
        X_v = X.copy()

    # ---- Stage 2: guard against an empty matrix -----------------------
    if X_v.shape[1] == 0:
        print("⚠️ No features remain after threshold. Falling back to top variance genes.")
        top_genes = _top_variance_features(X, 200)
        X_v = X[top_genes]
        print(f"Selected top 200 variance genes: {len(top_genes)}")

    # Column labels are normalised to strings before model fitting.
    X_v.columns = X_v.columns.astype(str)

    # ---- Stage 3: L1 logistic selection -------------------------------
    try:
        l1_model = LogisticRegressionCV(
            Cs=10,
            cv=5,
            penalty="l1",
            solver="saga",
            scoring="roc_auc",
            max_iter=2000,
            random_state=RANDOM_STATE,
            n_jobs=-1,
        )
        l1_model.fit(X_v, y)

        weights = np.abs(l1_model.coef_).ravel()
        nonzero = weights > 1e-6
        selected = X_v.columns[nonzero]
        print(f"L1 selected: {len(selected)}")

        if len(selected) == 0:
            print("⚠️ L1 selected ZERO features. Falling back to top variance genes.")
            selected = _top_variance_features(X, max_features)
            return X[selected], list(selected)

        if len(selected) > max_features:
            # Keep the features with the largest absolute coefficients.
            order = np.argsort(-weights[nonzero])
            selected = selected[order[:max_features]]
            print(f"Truncated to top {max_features} L1 features.")

        return X_v[selected], list(selected)

    except Exception as e:
        # Best-effort: any failure in the L1 stage degrades to a variance ranking.
        print("⚠️ L1 selection failed due to:", e)
        print("Fallback: Selecting top variance genes.")
        selected = _top_variance_features(X, max_features)
        return X[selected], list(selected)

In [None]:


def train_and_evaluate(X, y, outdir="results"):
    """Train a random forest and an MLP on a stratified split and write evaluation artefacts.

    Labels are integer-encoded with ``LabelEncoder``; 20% of the samples are
    held out for testing. For each model the ROC-AUC is computed from the
    predicted probability of the class encoded as ``1``.

    Files written to ``outdir``: ``rf_model.joblib``, ``mlp_model.joblib``,
    ``roc_curves.png``, ``rf_confusion_matrix.png`` and
    ``top_features_heatmap.png`` (clustered heatmap of the 30 most important
    random-forest features).

    Returns
    -------
    dict
        Keys ``rf_auc``, ``mlp_auc``, ``rf_model``, ``mlp_model``,
        ``label_encoder`` and ``top_features``.
    """
    ensure_dir(outdir)

    print("\n==============================")
    print("STEP 4: TRAIN & EVALUATE MODELS")
    print("==============================")

    # Encode labels
    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    classes = le.classes_

    # Stratified hold-out split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_enc, stratify=y_enc, test_size=0.2, random_state=RANDOM_STATE
    )
    print(f"Train samples: {X_train.shape[0]}, Test: {X_test.shape[0]}")

    def _fit_and_score(model):
        """Fit ``model`` on the training split; return (probabilities, predictions, AUC)."""
        model.fit(X_train, y_train)
        proba = model.predict_proba(X_test)[:, 1]
        pred = model.predict(X_test)
        return proba, pred, roc_auc_score(y_test, proba)

    # -------------------- Random Forest --------------------
    rf = RandomForestClassifier(n_estimators=300, random_state=RANDOM_STATE, n_jobs=-1)
    y_proba_rf, y_pred_rf, auc_rf = _fit_and_score(rf)
    print(f"RF ROC-AUC = {auc_rf:.4f}")

    # ------------------------- MLP --------------------------
    mlp = MLPClassifier(
        hidden_layer_sizes=(512, 128),
        max_iter=500,
        random_state=RANDOM_STATE,
        early_stopping=True,
    )
    y_proba_mlp, y_pred_mlp, auc_mlp = _fit_and_score(mlp)
    print(f"MLP ROC-AUC = {auc_mlp:.4f}")

    # Persist both fitted models
    joblib.dump(rf, os.path.join(outdir, "rf_model.joblib"))
    joblib.dump(mlp, os.path.join(outdir, "mlp_model.joblib"))

    # Text reports
    print("\nRandom Forest classification report:")
    print(classification_report(y_test, y_pred_rf, target_names=classes))

    print("\nMLP classification report:")
    print(classification_report(y_test, y_pred_mlp, target_names=classes))

    # ROC curves for both models on one set of axes
    roc_fig, roc_ax = plt.subplots(figsize=(6, 6))
    fpr_rf, tpr_rf, _ = roc_curve(y_test, y_proba_rf)
    fpr_mlp, tpr_mlp, _ = roc_curve(y_test, y_proba_mlp)
    roc_ax.plot(fpr_rf, tpr_rf, label=f"RF (AUC={auc_rf:.3f})")
    roc_ax.plot(fpr_mlp, tpr_mlp, label=f"MLP (AUC={auc_mlp:.3f})")
    roc_ax.plot([0, 1], [0, 1], "k--")  # chance diagonal
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_title("ROC Curves")
    roc_ax.legend()
    roc_fig.savefig(os.path.join(outdir, "roc_curves.png"), dpi=200)

    # Confusion matrix for the random forest
    cm = confusion_matrix(y_test, y_pred_rf)
    cm_fig, cm_ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=classes, yticklabels=classes, ax=cm_ax)
    cm_ax.set_title("RF Confusion Matrix")
    cm_fig.savefig(os.path.join(outdir, "rf_confusion_matrix.png"), dpi=200)

    # Clustered heatmap of the most important RF features (descending importance)
    importances = rf.feature_importances_
    top_idx = np.argsort(importances)[-30:][::-1]
    top_genes = X.columns[top_idx]

    # clustermap creates its own figure, which becomes pyplot's current figure.
    sns.clustermap(X[top_genes].T, z_score=0, cmap="vlag", figsize=(8, 10))
    plt.suptitle("Heatmap (Top 30 Features)")
    plt.savefig(os.path.join(outdir, "top_features_heatmap.png"), dpi=200)

    return {
        "rf_auc": auc_rf,
        "mlp_auc": auc_mlp,
        "rf_model": rf,
        "mlp_model": mlp,
        "label_encoder": le,
        "top_features": list(top_genes),
    }



In [None]:
def run_full_training_pipeline(
    gse_list,
    geo_raw_dir="geo_raw",
    manual_label_csv=None,
    processed_dir="processed_data",
    results_dir="results",
    var_thresh=0.01,
    max_features=500
):
    """Run the whole workflow: load/build data, select features, train and evaluate.

    Parameters
    ----------
    gse_list : sequence of str
        GEO series accessions forwarded to ``run_full_pipeline``.
    geo_raw_dir, manual_label_csv, processed_dir
        Forwarded unchanged to ``run_full_pipeline``.
    results_dir : str
        Output directory for ``selected_genes.txt`` and everything written by
        ``train_and_evaluate``.
    var_thresh, max_features
        Forwarded to ``feature_selection``.

    Returns
    -------
    dict
        The dict returned by ``train_and_evaluate`` plus a ``selected_genes``
        entry.

    Raises
    ------
    ValueError
        If some rows of the expression matrix have no metadata entry.
    """
    print("\n======================================")
    print(" FULL TRAINING PIPELINE ")
    print("======================================")

    # Step 1+2 are cached
    X_corrected, meta, scaler = run_full_pipeline(
        gse_list=gse_list,
        geo_raw_dir=geo_raw_dir,
        manual_label_csv=manual_label_csv,
        save_processed=True,
        processed_dir=processed_dir
    )

    # The downstream estimators pair X rows and y entries by position, so the
    # labels are reordered to follow the expression matrix explicitly.
    missing = X_corrected.index.difference(meta.index)
    if len(missing) > 0:
        raise ValueError(
            f"{len(missing)} sample(s) in the expression matrix have no metadata row, "
            f"e.g. {list(missing[:5])}"
        )
    y = meta.loc[X_corrected.index, "label"]

    # Step 3: Feature selection
    X_selected, selected_genes = feature_selection(
        X_corrected, y,
        var_thresh=var_thresh,
        max_features=max_features
    )
    ensure_dir(results_dir)
    genes_path = os.path.join(results_dir, "selected_genes.txt")
    with open(genes_path, "w", encoding="utf-8") as f:
        for gene in selected_genes:
            f.write(f"{gene}\n")
    print("Selected genes saved to:", genes_path)

    # Step 4: Train models
    results = train_and_evaluate(
        X_selected,
        y,
        outdir=results_dir
    )

    results["selected_genes"] = selected_genes
    return results

In [None]:
# GEO series to merge; each accession is tagged as its own batch when datasets are merged.
gse_ids = ["GSE2034", "GSE15852"]
# Additional accessions currently left out of the run: "GSE70947", "GSE42568"
# NOTE: if processed_data/ already holds X_corrected.csv and metadata.csv, those cached
# files are loaded and the GEO download/merge for gse_ids is skipped.
results = run_full_training_pipeline(
    gse_list=gse_ids,
    geo_raw_dir="geo_raw",
    processed_dir="processed_data",
    results_dir="results"
)