<a href="https://colab.research.google.com/github/Jay99Sohn/GEOexosome/blob/main/GEOexosome.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 0: Environment & Library Setup

# Install required packages (needed on a fresh Google Colab runtime).
# Comment out the line below if they are already installed in this environment:
!pip install GEOparse imbalanced-learn shap seaborn matplotlib

# ============================================================
# Standard Library & Third-party Imports
# ============================================================

import os
import sys
import random
import json

import numpy as np
import pandas as pd
import GEOparse

from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    accuracy_score
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression

from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline

import matplotlib.pyplot as plt
import seaborn as sns
import shap

# ============================================================
# Configuration & Reproducibility
# ============================================================

# One seed shared by every stochastic step in the notebook.
SEED = 42

# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts,
# so assigning it here presumably does not change hashing in the kernel that
# is already running -- confirm before relying on it for reproducibility.
os.environ.update(PYTHONHASHSEED=str(SEED))

# Seed the NumPy and standard-library global random generators.
np.random.seed(SEED)
random.seed(SEED)

# Plot defaults: seaborn whitegrid style, 100 dpi on screen, 300 dpi on save.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.dpi": 100, "savefig.dpi": 300})

print("\n" + "=" * 60, "ENVIRONMENT & LIBRARY SETUP", "=" * 60, sep="\n")

# ============================================================
# Environment Detection & Path Configuration
# ============================================================
# Outside Colab, results go to a folder next to the notebook.
# On Colab, Google Drive is mounted and results are written there.

if "google.colab" not in sys.modules:
    base_save_path = "./geoexosome_results"
    print(f"[INFO] Local environment detected. Saving results to: {base_save_path}")
else:
    from google.colab import drive
    print("[INFO] Google Colab detected. Mounting Google Drive...")
    drive.mount("/content/drive")
    base_save_path = "/content/drive/MyDrive/geoexosome_results"
    print(f"[INFO] Drive mounted. Saving results to: {base_save_path}")

# Create the output folder if needed; an existing folder is left as is.
os.makedirs(base_save_path, exist_ok=True)

print(
    "\n" + "=" * 60,
    "✓ Setup Complete",
    f"  - Random seed: {SEED}",
    f"  - Output path: {base_save_path}",
    "=" * 60 + "\n",
    sep="\n",
)


Collecting GEOparse
  Downloading GEOparse-2.0.4-py3-none-any.whl.metadata (6.5 kB)
Downloading GEOparse-2.0.4-py3-none-any.whl (29 kB)
Installing collected packages: GEOparse
Successfully installed GEOparse-2.0.4

ENVIRONMENT & LIBRARY SETUP
[INFO] Google Colab detected. Mounting Google Drive...
Mounted at /content/drive
[INFO] Drive mounted. Saving results to: /content/drive/MyDrive/geoexosome_results

✓ Setup Complete
  - Random seed: 42
  - Output path: /content/drive/MyDrive/geoexosome_results



In [None]:
# ==============================================================================
# Cell 1: GEO Dataset Loading and Quality Control
# ==============================================================================
"""
Purpose:
    Download and parse the GSE39833 dataset from NCBI GEO, extract expression
    matrices, assign sample labels with full traceability, and perform
    comprehensive quality control checks.

Dataset:
    GSE39833 - Serum exosome miRNA microarray from colorectal cancer patients
    Platform: GPL14767 (Agilent human miRNA microarray)
    Samples: 99 total (11 healthy controls, 88 CRC patients)

Outputs:
    1. df_expression: DataFrame with probe-level expression values and labels
    2. mapping_df: Probe-to-miRNA mapping table
    3. label_assignment_log.csv: Full traceability of label assignments
    4. probe_to_miRNA_mapping.csv: Complete annotation mapping
    5. unmapped_probes.txt: List of probes without miRNA annotation
    6. data_quality_report.txt: Comprehensive QC summary for Methods section

Quality Controls Implemented:
    - Probe ID consistency validation across all samples
    - Missing value detection and quantification
    - Constant feature detection
    - miRNA mapping coverage assessment
    - Expression value range verification

Author: [Jungho Sohn]
Date: 2025-12-20
Version: 1.0
"""

# ==============================================================================
# STEP 1: LOAD GEO DATASET FROM NCBI
# ==============================================================================
print("\n" + "=" * 80, "STEP 1: LOAD GEO DATASET (GSE39833)", "=" * 80, sep="\n")

gse_id = "GSE39833"
print(
    f"[INFO] Downloading GEO dataset: {gse_id}",
    "[INFO] This may take 1-2 minutes depending on network speed...",
    sep="\n",
)

# Fetch the series into ./data and parse it. annotate_gpl=True is passed so
# that platform annotation is available for the probe-to-miRNA mapping in
# Step 4 (see the GEOparse documentation for the exact semantics).
gse = GEOparse.get_GEO(geo=gse_id, destdir="./data", annotate_gpl=True)

print(
    f"[INFO] Successfully loaded {gse_id}",
    f"  - GSM samples: {len(gse.gsms)}",
    f"  - GPL platforms: {len(gse.gpls)}",
    sep="\n",
)

# ==============================================================================
# STEP 2: EXTRACT EXPRESSION MATRIX AND ASSIGN SAMPLE LABELS
# ==============================================================================
"""
Label Assignment Strategy:
    Priority 1: Sample title parsing (most reliable for this dataset)
        - "hc_*" → Healthy control (label=0)
        - "crc*" → CRC patient (label=1)

    Priority 2: Metadata characteristics (fallback)
        - Cancer keywords: tnm, stage, cancer, adenocarcinoma, tumor
        - Healthy keywords: healthy, control, normal

    All label assignments are logged with their source for transparency and
    reproducibility. This log can be included in Supplementary Materials.
"""
print("\n" + "=" * 80)
print("STEP 2: EXTRACT EXPRESSION MATRIX AND ASSIGN LABELS")
print("=" * 80)

# Initialize containers for expression data and labels
samples = []           # Sample IDs (GSM accessions)
expression_rows = []   # Expression values for each sample
labels = []            # Binary labels (0=healthy, 1=CRC)
label_assignment_log = []  # Traceability log for label assignments

# Keyword lists for the metadata fallback (case-insensitive substring match).
# Defined once here instead of being rebuilt for every sample.
cancer_keywords = ["tnm", "stage", "cancer", "adenocarcinoma", "tumor"]
healthy_keywords = ["healthy", "control", "normal"]

# Iterate through all samples in the GEO dataset
for gsm_name, gsm in gse.gsms.items():

    # Validate that expression data is available
    tbl = gsm.table
    if "VALUE" not in tbl.columns:
        print(f"[WARNING] {gsm_name} missing VALUE column. Skipping...")
        continue

    # Extract raw expression values and convert to float
    expr_vals = tbl["VALUE"].astype(float).values
    expression_rows.append(expr_vals)
    samples.append(gsm_name)

    # -------------------------------------------------------------------------
    # Label Assignment with Source Tracking
    # -------------------------------------------------------------------------
    # `or [""]` covers both a missing title and a present-but-empty title list,
    # so indexing [0] below cannot raise IndexError.
    title_list = gsm.metadata.get("title") or [""]
    title = title_list[0].lower()
    label_value = None
    label_source = None

    # PRIMARY METHOD: Title-based labeling
    if title.startswith("hc_"):
        # Healthy control samples
        label_value = 0
        label_source = "title (starts with 'hc_')"
    elif title.startswith("crc"):
        # CRC patient samples
        label_value = 1
        label_source = "title (starts with 'crc')"

    # FALLBACK METHOD: Metadata-based labeling
    # Used only if title-based labeling fails.
    if label_value is None:
        characteristics = (
            gsm.metadata.get("characteristics_ch1", []) +
            gsm.metadata.get("characteristics_ch2", [])
        )
        chars_low = [c.lower() for c in characteristics]

        has_cancer_keyword = any(
            keyword in c for c in chars_low for keyword in cancer_keywords
        )
        has_healthy_keyword = any(
            keyword in c for c in chars_low for keyword in healthy_keywords
        )

        # A label is accepted only when exactly one keyword family matches.
        # If both match (e.g. a control record that also carries a "stage"
        # field) the metadata is ambiguous; label_value stays None so the
        # error path below halts for manual verification instead of silently
        # labelling the sample as CRC.
        if has_cancer_keyword and not has_healthy_keyword:
            label_value = 1
            label_source = "metadata (cancer-related keywords detected)"
        elif has_healthy_keyword and not has_cancer_keyword:
            label_value = 0
            label_source = "metadata (healthy control keywords detected)"

    # -------------------------------------------------------------------------
    # Error Handling: Failed Label Assignment
    # -------------------------------------------------------------------------
    # If both primary and fallback methods fail (or the fallback is ambiguous),
    # halt execution and display detailed metadata to enable manual
    # verification and rule updates.
    if label_value is None:
        print("\n" + "=" * 80)
        print(f"[ERROR] Unable to determine label for sample: {gsm_name}")
        print("=" * 80)
        print(f"\nSample Metadata:")
        print(f"  - Title: {title_list}")
        print(f"  - Characteristics (ch1): {gsm.metadata.get('characteristics_ch1', [])}")
        print(f"  - Characteristics (ch2): {gsm.metadata.get('characteristics_ch2', [])}")
        print(f"  - Source: {gsm.metadata.get('source_name_ch1', ['N/A'])}")
        print(f"  - Description: {gsm.metadata.get('description', ['N/A'])}")
        print(f"\nPossible Causes:")
        print(f"  1. Unexpected metadata format (not matching expected patterns)")
        print(f"  2. Sample naming convention differs from other samples")
        print(f"  3. Ambiguous or missing label information in metadata")
        print(f"\nAction Required:")
        print(f"  Please verify the sample metadata above and update the label")
        print(f"  assignment logic in this cell accordingly.")
        print("=" * 80 + "\n")
        raise ValueError(f"Label assignment failed for {gsm_name}")

    # Record successful label assignment with source for transparency
    labels.append(label_value)
    label_assignment_log.append({
        "Sample_ID": gsm_name,
        "Label": label_value,
        "Label_Name": "Healthy_Control" if label_value == 0 else "CRC_Patient",
        "Assignment_Source": label_source,
        "Sample_Title": title_list[0]
    })

# Fail with a readable message instead of a cryptic np.vstack error when no
# sample was usable or when samples report different numbers of probes.
if not expression_rows:
    raise ValueError(
        "No samples with a VALUE column were found; cannot build the expression matrix."
    )

row_lengths = sorted({len(row) for row in expression_rows})
if len(row_lengths) > 1:
    raise ValueError(
        "Samples report different numbers of probes "
        f"({row_lengths}); cannot stack them into one matrix."
    )

# Convert list of expression arrays to 2D numpy array
# Shape: (n_samples, n_probes)
expression_data = np.vstack(expression_rows)

print(f"[INFO] Successfully extracted expression data for {len(samples)} samples")
print(f"  - Healthy controls (label=0): {labels.count(0)}")
print(f"  - CRC patients (label=1): {labels.count(1)}")

# ==============================================================================
# STEP 3: BUILD EXPRESSION DATAFRAME WITH PROBE-LEVEL DATA
# ==============================================================================
"""
Data Structure:
    - Rows: Samples (GSM IDs)
    - Columns: Probe IDs + 'label' column
    - Values: Raw microarray intensity values

Quality Check:
    Validate that all samples have identical probe IDs in the same order.
    This is critical for ensuring data integrity in downstream analyses.
"""
print("\n" + "=" * 80)
print("STEP 3: BUILD EXPRESSION DATAFRAME")
print("=" * 80)

# Extract probe IDs from the first sample as reference
first_gsm = gse.gsms[samples[0]]
probe_ids = first_gsm.table["ID_REF"].tolist()

# -------------------------------------------------------------------------
# Quality Control: Validate Probe ID Consistency
# -------------------------------------------------------------------------
# Expression rows were stacked positionally in Step 2, so a single sample
# with a different probe order would silently misalign every column for that
# sample. Every sample is therefore checked, not just a spot-check subset;
# the comparison is a plain list equality and is cheap at this data size.
print("[INFO] Validating probe consistency across all samples...")

samples_to_check = list(samples)
for gsm_name in samples_to_check:
    current_probes = gse.gsms[gsm_name].table["ID_REF"].tolist()

    # Check probe count
    if len(current_probes) != len(probe_ids):
        raise ValueError(
            f"[ERROR] Probe count mismatch detected in {gsm_name}\n"
            f"Expected {len(probe_ids)} probes matching {samples[0]}, "
            f"but found {len(current_probes)} probes."
        )

    # Check probe order
    if current_probes != probe_ids:
        raise ValueError(
            f"[ERROR] Probe order mismatch detected in {gsm_name}\n"
            f"Probe IDs do not match the reference sample {samples[0]}."
        )

print(f"[INFO] ✓ Probe consistency verified across {len(samples_to_check)} samples")
print(f"[INFO] All samples contain {len(probe_ids)} probes in identical order")

# Assemble the expression frame: one row per GSM sample, one column per probe,
# followed by a 'label' column holding the class assigned in Step 2.
df_expression = pd.DataFrame(data=expression_data, index=samples, columns=probe_ids)
df_expression["label"] = labels

print(
    f"\n[INFO] Expression DataFrame created",
    f"  - Shape: {df_expression.shape}",
    f"  - Samples: {df_expression.shape[0]}",
    f"  - Probes (features): {df_expression.shape[1] - 1}",  # 'label' is not a probe
    f"\n[INFO] Label distribution:",
    df_expression["label"].value_counts().to_string(),
    sep="\n",
)

# Persist the per-sample label provenance collected in Step 2 as a CSV
# (intended for the Supplementary Materials).
label_log_path = os.path.join(base_save_path, "label_assignment_log.csv")
label_log_df = pd.DataFrame(label_assignment_log)
label_log_df.to_csv(label_log_path, index=False)
print(
    f"\n[INFO] Label assignment log saved to: {label_log_path}",
    "[NOTE] Include this file in Supplementary Materials for full transparency",
    sep="\n",
)

# ==============================================================================
# STEP 4: BUILD PROBE-TO-miRNA MAPPING FROM PLATFORM ANNOTATION
# ==============================================================================
"""
Purpose:
    Map microarray probe IDs to known miRNA identifiers using the GPL
    platform annotation file. This enables biological interpretation of
    features in downstream analysis.

Coverage Assessment:
    Calculate and report the percentage of probes successfully mapped to
    miRNAs. Low coverage (<60%) may indicate platform compatibility issues.

Unmapped Probes:
    Probes without miRNA annotation will be excluded from downstream analysis
    to ensure all features have biological interpretability. The list of
    excluded probes is saved for transparency.
"""
print("\n" + "=" * 80)
print("STEP 4: PROBE-TO-miRNA MAPPING")
print("=" * 80)

print("[INFO] Loading platform (GPL) annotation...")

# Extract platform annotation table (first platform attached to the series)
gpl = list(gse.gpls.values())[0]
gpl_table = gpl.table

# Validate GPL table structure
if "ID" not in gpl_table.columns:
    raise KeyError("[ERROR] GPL table missing 'ID' column. Cannot build mapping.")

# Identify miRNA annotation column
# Look for columns containing 'mir' (case-insensitive)
mirna_cols = [c for c in gpl_table.columns if "mir" in c.lower()]

n_total = len(probe_ids)

# -------------------------------------------------------------------------
# Handle Missing miRNA Annotation
# -------------------------------------------------------------------------
# n_mapped / coverage_pct are set in both branches so the QC section below
# never hits a NameError.
if len(mirna_cols) == 0:
    print("[WARNING] No miRNA annotation column detected in GPL table.")
    print("[WARNING] Probe-to-miRNA mapping will be unavailable.")
    mapping_df = None
    n_mapped = 0
    coverage_pct = 0.0
else:
    # Use the first miRNA annotation column found
    mirna_col = mirna_cols[0]
    print(f"[INFO] Using miRNA annotation column: '{mirna_col}'")

    # Build probe-to-miRNA dictionary for fast lookup
    probe_to_mirna = dict(zip(gpl_table["ID"], gpl_table[mirna_col]))

    # Map all probe IDs to miRNA names (NaN if not found).
    # Blank or whitespace-only annotations are also converted to NaN:
    # pd.notna("") is True, so without this step an empty annotation would be
    # counted as a successful mapping and inflate the reported coverage.
    mirna_names = []
    for pid in probe_ids:
        annotation = probe_to_mirna.get(pid, np.nan)
        if isinstance(annotation, str) and not annotation.strip():
            annotation = np.nan
        mirna_names.append(annotation)

    # -------------------------------------------------------------------------
    # Calculate Mapping Coverage Statistics
    # -------------------------------------------------------------------------
    n_mapped = sum(pd.notna(m) for m in mirna_names)
    # Guard the division so an empty probe list reports 0% instead of raising.
    coverage_pct = 100.0 * n_mapped / n_total if n_total else 0.0

    print(f"\n[INFO] Mapping Coverage:")
    print(f"  - Total probes: {n_total}")
    print(f"  - Successfully mapped: {n_mapped} ({coverage_pct:.1f}%)")
    print(f"  - Unmapped probes: {n_total - n_mapped} ({100 - coverage_pct:.1f}%)")

    # Create mapping DataFrame
    mapping_df = pd.DataFrame({
        "Probe_ID": probe_ids,
        "miRNA": mirna_names
    })

    # Save complete mapping table
    mapping_path = os.path.join(base_save_path, "probe_to_miRNA_mapping.csv")
    mapping_df.to_csv(mapping_path, index=False)
    print(f"\n[INFO] Probe-to-miRNA mapping saved to: {mapping_path}")

    # -------------------------------------------------------------------------
    # Save Unmapped Probes for Transparency
    # -------------------------------------------------------------------------
    # Document which probes lack an annotation; the file is written only
    # when at least one probe is unmapped.
    unmapped_probes = [
        probe for probe, mirna in zip(probe_ids, mirna_names)
        if pd.isna(mirna)
    ]

    if unmapped_probes:
        unmapped_path = os.path.join(base_save_path, "unmapped_probes.txt")
        with open(unmapped_path, 'w', encoding='utf-8') as f:
            f.write(f"Unmapped Probes (no miRNA annotation): {len(unmapped_probes)} total\n")
            f.write("=" * 80 + "\n\n")
            f.write("These probes will be excluded in downstream preprocessing (Cell 2) ")
            f.write("due to lack of miRNA annotation in the platform GPL file.\n\n")
            f.write("This exclusion ensures that all analyzed features have biological ")
            f.write("interpretability as known miRNAs.\n\n")
            f.write("List of unmapped probe IDs:\n")
            f.write("-" * 80 + "\n")
            for probe in unmapped_probes:
                f.write(f"{probe}\n")
        print(f"[INFO] Unmapped probes list saved to: {unmapped_path}")
        print("[NOTE] Include in Supplementary Materials to justify feature exclusion")

# ==============================================================================
# STEP 5: COMPREHENSIVE DATA QUALITY VALIDATION
# ==============================================================================
"""
Quality Control Metrics:
    1. Missing values: Count and percentage of NaN/null values
    2. Constant probes: Features with zero variance (uninformative)
    3. Expression range: Min/max values to detect outliers or errors
    4. Summary statistics: Mean, SD for manuscript reporting

Assessment Criteria:
    - Missing values: PASS if ≤0.01%, NOTICE if 0.01-5%, WARNING if >5%
    - Constant probes: PASS if 0, WARNING otherwise
    - Mapping coverage: PASS if ≥75%, NOTICE if 60-75%, WARNING if <60%

Output:
    Comprehensive report file (data_quality_report.txt) formatted for
    direct use in manuscript Methods section.
"""
print("\n" + "=" * 80)
print("STEP 5: DATA QUALITY SUMMARY")
print("=" * 80)

# Extract feature columns (exclude 'label' column)
feature_cols = [col for col in df_expression.columns if col != "label"]
expr_matrix = df_expression[feature_cols]

# -------------------------------------------------------------------------
# Calculate Quality Metrics
# -------------------------------------------------------------------------

# Dataset dimensions
n_samples = df_expression.shape[0]
n_features = len(feature_cols)

# Missing value analysis (guarded so an empty matrix reports 0% rather than
# dividing by zero)
n_missing = expr_matrix.isna().sum().sum()
total_values = n_samples * n_features
missing_pct = 100.0 * n_missing / total_values if total_values else 0.0

# Constant feature detection (variance = 0)
expr_var = expr_matrix.var(axis=0)
n_constant = (expr_var == 0).sum()

# Expression value statistics
# Global statistics are computed over every observed value. The NaN-aware
# NumPy reductions are used because a plain ndarray.std() returns NaN as soon
# as one measurement is missing, and a mean of per-column means weights
# columns unequally when they contain different numbers of missing values.
expr_values = expr_matrix.to_numpy(dtype=float)
expr_min = expr_matrix.min().min()
expr_max = expr_matrix.max().max()
expr_mean_global = np.nanmean(expr_values)          # Mean of all observed values
expr_std_global = np.nanstd(expr_values)            # Population SD of all observed values
expr_feature_std_median = expr_matrix.std().median()  # Median SD across features

# -------------------------------------------------------------------------
# Display Quality Control Summary
# -------------------------------------------------------------------------

# Console overview: dataset dimensions first, then the raw quality metrics.
print(
    f"\n[Dataset Dimensions]",
    f"  - Total samples: {n_samples}",
    f"  - Total probes (features): {n_features}",
    f"  - Healthy controls: {(df_expression['label'] == 0).sum()}",
    f"  - CRC patients: {(df_expression['label'] == 1).sum()}",
    sep="\n",
)

print(
    f"\n[Data Quality Metrics]",
    f"  - Missing values: {n_missing} ({missing_pct:.4f}% of all measurements)",
    f"  - Constant probes (variance = 0): {n_constant}",
    f"  - Expression value range: [{expr_min:.2f}, {expr_max:.2f}]",
    f"  - Mean expression (global): {expr_mean_global:.2f}",
    f"  - SD (global): {expr_std_global:.2f}",
    f"  - Median SD across features: {expr_feature_std_median:.2f}",
    sep="\n",
)

# -------------------------------------------------------------------------
# Quality Control Assessment
# -------------------------------------------------------------------------
# Each check picks the message lines for its tier and prints them together.

print(f"\n[Quality Control Assessment]")

# Check 1: zero-variance probes
if n_constant == 0:
    constant_msgs = (f"  ✓ PASS: No constant probes detected.",)
else:
    constant_msgs = (
        f"  ⚠ WARNING: {n_constant} constant probes detected.",
        f"     → These should be removed before feature selection.",
    )
print(*constant_msgs, sep="\n")

# Check 2: share of missing measurements (>5% warning, >0.01% notice)
if missing_pct > 5.0:
    missing_msgs = (
        f"  ⚠ WARNING: Missing values exceed 5% threshold ({missing_pct:.2f}%).",
        f"     → Consider imputation or removal of problematic probes.",
    )
elif missing_pct > 0.01:
    missing_msgs = (
        f"  ⚠ NOTICE: Low level of missing values detected ({missing_pct:.4f}%).",
        f"     → Acceptable for most analyses without imputation.",
    )
else:
    missing_msgs = (f"  ✓ PASS: Negligible missing values ({missing_pct:.4f}%).",)
print(*missing_msgs, sep="\n")

# Check 3: probe-to-miRNA mapping coverage (<60% warning, <75% notice)
if coverage_pct < 60:
    coverage_msgs = (
        f"  ⚠ WARNING: miRNA mapping coverage is low ({coverage_pct:.1f}%).",
        f"     → Verify platform annotation compatibility.",
    )
elif coverage_pct < 75:
    coverage_msgs = (
        f"  ⚠ NOTICE: Moderate miRNA mapping coverage ({coverage_pct:.1f}%).",
        f"     → Acceptable for most downstream analyses.",
    )
else:
    coverage_msgs = (f"  ✓ PASS: Good miRNA mapping coverage ({coverage_pct:.1f}%).",)
print(*coverage_msgs, sep="\n")

# -------------------------------------------------------------------------
# Save Comprehensive Quality Report for Manuscript
# -------------------------------------------------------------------------
# This report is formatted for direct use in the Methods section
# and provides all necessary QC information for reproducibility

qc_report_path = os.path.join(base_save_path, "data_quality_report.txt")

# Class counts are reused several times in the report.
n_controls = (df_expression['label'] == 0).sum()
n_crc = (df_expression['label'] == 1).sum()

# Status labels use the same thresholds as the console assessment so the
# written report can never disagree with what was printed.
constant_status = "PASS" if n_constant == 0 else "WARNING"

if missing_pct > 5.0:
    missing_status = "WARNING"
elif missing_pct > 0.01:
    missing_status = "NOTICE"
else:
    missing_status = "PASS"

if coverage_pct < 60:
    coverage_status = "WARNING"
elif coverage_pct < 75:
    coverage_status = "NOTICE"
else:
    coverage_status = "PASS"

# mapping_df is None when the platform table had no miRNA annotation column;
# in that case the report must not claim that a mapping was performed.
mapping_available = mapping_df is not None

with open(qc_report_path, 'w', encoding='utf-8') as f:
    f.write("=" * 80 + "\n")
    f.write("DATA QUALITY REPORT - GSE39833\n")
    f.write("=" * 80 + "\n\n")
    f.write(f"Dataset: {gse_id}\n")
    f.write(f"Analysis Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

    f.write("DATASET DIMENSIONS\n")
    f.write("-" * 80 + "\n")
    f.write(f"Total samples: {n_samples}\n")
    f.write(f"Total probes: {n_features}\n")
    f.write(f"Healthy controls: {n_controls}\n")
    f.write(f"CRC patients: {n_crc}\n\n")

    f.write("DATA QUALITY METRICS\n")
    f.write("-" * 80 + "\n")
    f.write(f"Missing values: {n_missing} ({missing_pct:.4f}%)\n")
    f.write(f"Constant probes: {n_constant}\n")
    f.write(f"Expression range: [{expr_min:.2f}, {expr_max:.2f}]\n")
    f.write(f"Mean (global): {expr_mean_global:.2f}\n")
    f.write(f"SD (global): {expr_std_global:.2f}\n")
    f.write(f"Median SD across features: {expr_feature_std_median:.2f}\n\n")

    f.write("miRNA MAPPING COVERAGE\n")
    f.write("-" * 80 + "\n")
    if mapping_available:
        f.write(f"Total probes: {n_total}\n")
        f.write(f"Mapped probes: {n_mapped} ({coverage_pct:.1f}%)\n")
        f.write(f"Unmapped probes: {n_total - n_mapped} ({100 - coverage_pct:.1f}%)\n\n")
    else:
        f.write("No miRNA annotation column was found in the platform table; ")
        f.write("probe-to-miRNA mapping was not performed.\n\n")

    f.write("QUALITY CONTROL ASSESSMENT\n")
    f.write("-" * 80 + "\n")
    f.write(f"Constant probes: {constant_status}\n")
    f.write(f"Missing values: {missing_status}\n")
    f.write(f"Mapping coverage: {coverage_status}\n\n")

    f.write("NOTES FOR MANUSCRIPT (Methods Section)\n")
    f.write("-" * 80 + "\n")
    f.write("Use the following information when writing the Methods section:\n\n")

    f.write("1. Sample Composition:\n")
    f.write(f"   The GSE39833 dataset comprised {n_samples} serum exosome samples, ")
    f.write(f"including {n_controls} healthy controls and ")
    f.write(f"{n_crc} colorectal cancer (CRC) patients.\n\n")

    f.write("2. Data Quality:\n")
    f.write("   Data quality was verified prior to analysis. ")
    if n_constant == 0:
        f.write("No constant features were detected. ")
    else:
        # Nothing is removed in this cell, so the report only states that
        # the features were flagged.
        f.write(f"{n_constant} constant features were identified and flagged ")
        f.write("for removal during preprocessing. ")
    f.write(f"Missing values accounted for {missing_pct:.4f}% of all measurements")
    if missing_status == "PASS":
        f.write("; no imputation was performed due to negligible missingness.\n\n")
    else:
        f.write(".\n\n")

    f.write("3. Feature Annotation:\n")
    if mapping_available:
        f.write(f"   Of the {n_total} microarray probes, {n_mapped} ({coverage_pct:.1f}%) were ")
        f.write("successfully mapped to known miRNAs in the platform annotation file. ")
        f.write("Unmapped probes were excluded from downstream analysis to ensure ")
        f.write("biological interpretability of all features.\n\n")
    else:
        f.write("   The platform annotation file contained no miRNA identifier column, ")
        f.write("so probe-to-miRNA mapping could not be performed.\n\n")

    f.write("4. Label Assignment:\n")
    f.write("   Sample labels were assigned based on standardized metadata fields ")
    f.write("(sample titles and characteristics). All label assignments were recorded ")
    f.write("in a traceability log (label_assignment_log.csv) to ensure transparency ")
    f.write("and reproducibility.\n")

print(f"\n[INFO] Comprehensive quality report saved to: {qc_report_path}")
print("[NOTE] Use this report when writing the Methods section")

# ==============================================================================
# FINAL SUMMARY
# ==============================================================================
print("\n" + "=" * 80)
print("✓ DATASET LOADING AND QUALITY CONTROL COMPLETE")
print("=" * 80)
print(f"\nFiles saved to: {base_save_path}")

# List only the files this cell actually wrote. The mapping CSV exists only
# when a miRNA annotation column was found (mapping_df is not None), and the
# unmapped-probe list only when that mapping left at least one probe unmapped.
saved_files = ["label_assignment_log.csv - Sample label traceability"]
if mapping_df is not None:
    saved_files.append("probe_to_miRNA_mapping.csv - Probe annotation mapping")
    if n_total - n_mapped > 0:
        saved_files.append("unmapped_probes.txt - Probes without miRNA annotation")
saved_files.append("data_quality_report.txt - Comprehensive QC summary")

for file_no, description in enumerate(saved_files, start=1):
    print(f"  {file_no}. {description}")

print(f"\n[Summary Statistics]")
print(f"  - Dataset: {gse_id}")
print(f"  - Samples: {n_samples} ({(df_expression['label'] == 0).sum()} controls, {(df_expression['label'] == 1).sum()} CRC)")
print(f"  - Features: {n_features} probes")
print(f"  - Mapped to miRNA: {n_mapped}/{n_total} ({coverage_pct:.1f}%)")
print(f"  - Data quality: {n_constant} constant, {missing_pct:.4f}% missing")

print(f"\nNext step: Proceed to Cell 2 for preprocessing and feature selection")
print("=" * 80 + "\n")

27-Dec-2025 06:16:32 INFO GEOparse - Downloading ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE39nnn/GSE39833/soft/GSE39833_family.soft.gz to ./data/GSE39833_family.soft.gz
INFO:GEOparse:Downloading ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE39nnn/GSE39833/soft/GSE39833_family.soft.gz to ./data/GSE39833_family.soft.gz



STEP 1: LOAD GEO DATASET (GSE39833)
[INFO] Downloading GEO dataset: GSE39833
[INFO] This may take 1-2 minutes depending on network speed...


100%|██████████| 11.4M/11.4M [00:00<00:00, 38.4MB/s]
27-Dec-2025 06:16:32 DEBUG downloader - Size validation passed
DEBUG:GEOparse:Size validation passed
27-Dec-2025 06:16:32 DEBUG downloader - Moving /tmp/tmpm3_xvvc3 to /content/data/GSE39833_family.soft.gz
DEBUG:GEOparse:Moving /tmp/tmpm3_xvvc3 to /content/data/GSE39833_family.soft.gz
27-Dec-2025 06:16:32 DEBUG downloader - Successfully downloaded ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE39nnn/GSE39833/soft/GSE39833_family.soft.gz
DEBUG:GEOparse:Successfully downloaded ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE39nnn/GSE39833/soft/GSE39833_family.soft.gz
27-Dec-2025 06:16:32 INFO GEOparse - Parsing ./data/GSE39833_family.soft.gz: 
INFO:GEOparse:Parsing ./data/GSE39833_family.soft.gz: 
27-Dec-2025 06:16:32 DEBUG GEOparse - DATABASE: GeoMiame
DEBUG:GEOparse:DATABASE: GeoMiame
27-Dec-2025 06:16:32 DEBUG GEOparse - SERIES: GSE39833
DEBUG:GEOparse:SERIES: GSE39833
27-Dec-2025 06:16:32 DEBUG GEOparse - PLATFORM: GPL14767
DEBUG:GEOparse:PLATF

[INFO] Successfully loaded GSE39833
  - GSM samples: 99
  - GPL platforms: 1

STEP 2: EXTRACT EXPRESSION MATRIX AND ASSIGN LABELS
[INFO] Successfully extracted expression data for 99 samples
  - Healthy controls (label=0): 11
  - CRC patients (label=1): 88

STEP 3: BUILD EXPRESSION DATAFRAME
[INFO] Validating probe consistency across all samples...
[INFO] ✓ Probe consistency verified across 4 samples
[INFO] All samples contain 15739 probes in identical order

[INFO] Expression DataFrame created
  - Shape: (99, 15740)
  - Samples: 99
  - Probes (features): 15739

[INFO] Label distribution:
label
1    88
0    11

[INFO] Label assignment log saved to: /content/drive/MyDrive/geoexosome_results/label_assignment_log.csv
[NOTE] Include this file in Supplementary Materials for full transparency

STEP 4: PROBE-TO-miRNA MAPPING
[INFO] Loading platform (GPL) annotation...
[INFO] Using miRNA annotation column: 'miRNA_ID'

[INFO] Mapping Coverage:
  - Total probes: 15739
  - Successfully mapped: 15

In [None]:
# =============================================================================
# Cell 2: Nested Cross-Validation with Repeated Stratified K-Fold
# =============================================================================
"""
VERSION 3.0 MODIFICATIONS (Based on Literature Review):

CRITICAL CHANGES:
1. SMOTE → Class Weights (safer for n=11 minority class)
   - Demircioğlu 2024: SMOTE before CV causes +0.34 AUC bias
   - With only 9 controls per training fold, SMOTE's k=5 neighbors is problematic

2. Stability Threshold: 80% → 70% (realistic per Liu et al. 2025)
   - Kuncheva index 0.50-0.75 is realistic for high-dimensional small-sample (HDSS) data
   - 70% provides balance between stringency and practicality

3. Dual-Path Feature Selection (Lewis 2023, Parvandeh 2020)
   - Path A: Production model (full-data retrained) for external validation
   - Path B: CV-stable features (>70% of folds) for biomarker reporting

UNCHANGED:
- Consensus voting: 2-of-3 (same as before)
- Biological filtering thresholds
- All other logic

Reference:
- Demircioğlu A (2024) Sci Rep: Oversampling before CV causes bias
- Liu et al. (2025) Comput Methods: Kuncheva stability benchmarks
- Lewis et al. (2023) nestedcv R package: Dual-path strategy
- Parvandeh et al. (2020) Bioinformatics: Consensus nested CV

Author: Jungho Sohn
Date: 2025-12-28
Version: 3.0 (Class weights + Realistic stability + Dual-path)
"""

import json
import warnings
import numpy as np
import pandas as pd
from collections import Counter
from scipy import stats
from scipy.stats import mannwhitneyu

# =============================================================================
# IMPORTS (unchanged except SMOTE removal from pipeline)
# =============================================================================
from sklearn.model_selection import (
    StratifiedKFold,
    RepeatedStratifiedKFold,
    GridSearchCV
)
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LassoCV, LogisticRegression
from sklearn.svm import SVC
from sklearn.feature_selection import RFE
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix
)
# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ CHANGE 1: Import Pipeline from sklearn instead of imblearn                 │
# │ SMOTE is no longer used - we use class_weight='balanced' instead           │
# └─────────────────────────────────────────────────────────────────────────────┘
from sklearn.pipeline import Pipeline

import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore', category=FutureWarning)

# =============================================================================
# Cell 2 configuration (CHANGE 2 / CHANGE 3 of v3.0)
#
# Every tunable value used by this cell is declared here so that a reader only
# has to edit one place. The start-up banner below is built from these
# constants, so it always reports the values that are actually in effect.
# =============================================================================

# Feature Selection Thresholds (ADJUSTABLE)
LOG2FC_THRESHOLD = 1.0      # Default: 1.0 (2-fold change), Try: 1.5 for stricter
FDR_THRESHOLD = 0.05        # Default: 0.05, Try: 0.01 for stricter

# Cross-Validation Configuration
N_OUTER_SPLITS = 5          # Number of folds
N_REPEATS = 10              # Number of repetitions
N_INNER_SPLITS = 3          # Inner CV for hyperparameter tuning

# CHANGE 3: Stability Threshold 80% -> 70% (realistic per Liu et al. 2025)
STABILITY_THRESHOLD = 0.70  # Features must appear in 70%+ of folds (was 0.80)

# Fail fast on mistyped settings instead of part-way through a long CV run.
assert LOG2FC_THRESHOLD >= 0, "LOG2FC_THRESHOLD must be non-negative"
assert 0 < FDR_THRESHOLD <= 1, "FDR_THRESHOLD must lie in (0, 1]"
assert N_OUTER_SPLITS >= 2 and N_INNER_SPLITS >= 2, "CV needs at least 2 folds"
assert N_REPEATS >= 1, "N_REPEATS must be at least 1"
assert 0 < STABILITY_THRESHOLD <= 1, "STABILITY_THRESHOLD must lie in (0, 1]"

# The second "Stability threshold" line used to hard-code "70%"; it now follows
# STABILITY_THRESHOLD like the line above it (identical output at the default).
print(f"""
=====================================================================
Cell 2: Nested Cross-Validation with REPEATED Stratified K-Fold
         [VERSION 3.0 - Class Weights + Realistic Stability]

CONFIGURATION:
  Outer CV:            {N_OUTER_SPLITS}-fold × {N_REPEATS} repeats = {N_OUTER_SPLITS * N_REPEATS} iterations
  Inner CV:            {N_INNER_SPLITS}-fold (hyperparameter tuning)

  Feature Selection:
    |log2FC| threshold:  > {LOG2FC_THRESHOLD}
    FDR q-value:         < {FDR_THRESHOLD}
    Consensus:           2-of-3 methods (unchanged)

  Stability Threshold:   {STABILITY_THRESHOLD*100:.0f}% (lowered from 80%)

  CRITICAL CHANGES (v3.0):
    ✓ SMOTE removed (class_weight='balanced' used instead)
    ✓ Stability threshold: 80% → {STABILITY_THRESHOLD*100:.0f}% (realistic for n=99)
    ✓ Dual-path feature selection implemented

This provides more stable performance estimates with variance.
=====================================================================
""")

# ==============================================================================
# 1. Load Expression Data from Cell 1 (NO Normalization Applied!)
# ==============================================================================
print("\n" + "=" * 80)
print("STEP 1: Load Expression Data from Cell 1")
print("=" * 80)

# `df_expression` is produced by Cell 1. Every column except 'label' is treated
# as a probe; the index supplies the sample identifiers.
expr_raw = df_expression.drop(columns=['label'])
labels_full = df_expression['label'].values
sample_ids_full = df_expression.index.tolist()

# Scale heuristic: a maximum above 20 is taken to mean the values are still on
# the linear scale. np.nanmax is used because ndarray.max() returns NaN as soon
# as one value is missing, which makes the comparison False and would silently
# skip the transformation.
if np.nanmax(expr_raw.values) > 20:
    print("[INFO] Applying log2(x + 1) transformation to raw expression values")
    expr_log_full = np.log2(expr_raw + 1.0)
else:
    print("[INFO] Data appears log2-transformed. Using as-is.")
    expr_log_full = expr_raw.copy()

# Handle infinite values from log2 transformation
expr_log_full = expr_log_full.replace([np.inf, -np.inf], np.nan)

# NOTE(review): the per-probe medians below are computed over ALL samples,
# including samples that later serve as held-out test folds. When values are
# actually missing this is a (small) source of information leakage, so the
# number of imputed cells is reported rather than filled in silently.
n_imputed = int(expr_log_full.isna().values.sum())
if n_imputed > 0:
    print(f"[WARN] Imputing {n_imputed} missing/non-finite values with per-probe "
          f"medians computed on all samples")
expr_log_full = expr_log_full.fillna(expr_log_full.median(axis=0))

# Convert to DataFrame with proper indices
expr_log_full = pd.DataFrame(
    expr_log_full.values,
    index=sample_ids_full,
    columns=expr_raw.columns
)

n_controls = int((labels_full == 0).sum())
n_cases = int((labels_full == 1).sum())
# Guard the ratio: with zero controls the original expression divided by zero.
imbalance_ratio = f"{n_cases / n_controls:.1f}" if n_controls > 0 else "n/a"

print(f"[INFO] Expression matrix loaded: {expr_log_full.shape}")
print(f"  Samples: {len(sample_ids_full)}")
print(f"  Probes: {expr_log_full.shape[1]}")
print(f"  Healthy controls: {n_controls}")
print(f"  CRC patients: {n_cases}")
print(f"  Class imbalance ratio: 1:{imbalance_ratio}")

# CRITICAL: Ensure NO normalization has been applied globally.
# Heuristic only: a global mean near 0 together with a global std near 1 is
# read as a sign that the matrix was already standardized.
global_mean = expr_log_full.values.mean()
global_std = expr_log_full.values.std()

print("\n[VERIFICATION] Checking for data leakage risks:")
print(f"  Max value: {expr_log_full.values.max():.2f}")
print(f"  Min value: {expr_log_full.values.min():.2f}")
print(f"  Global mean: {global_mean:.2f}")
print(f"  Global std: {global_std:.2f}")
if abs(global_mean) < 0.1 and abs(global_std - 1.0) < 0.1:
    print("  ⚠️  WARNING: Data appears globally normalized! Risk of data leakage!")
else:
    print("  ✓ Data is NOT globally normalized. Safe for fold-wise processing.")

# ==============================================================================
# 2. Cross-Validation Configuration
# ==============================================================================
print("\n" + "=" * 80)
print("STEP 2: Configure Cross-Validation Strategy")
print("=" * 80)

# Outer CV: N_OUTER_SPLITS stratified folds, reshuffled N_REPEATS times.
# The fixed random_state makes the sequence of splits reproducible.
outer_cv = RepeatedStratifiedKFold(
    n_splits=N_OUTER_SPLITS,
    n_repeats=N_REPEATS,
    random_state=SEED
)

# Inner CV: 3-fold stratified for hyperparameter tuning (unchanged)
inner_cv = StratifiedKFold(n_splits=N_INNER_SPLITS, shuffle=True, random_state=SEED)

total_iterations = N_OUTER_SPLITS * N_REPEATS
print(f"[INFO] Outer CV: {N_OUTER_SPLITS}-fold × {N_REPEATS} repeats = {total_iterations} iterations")
print(f"[INFO] Inner CV: {inner_cv.get_n_splits()}-fold stratified")
print(f"[INFO] Random seed: {SEED}")

# Create results directory. `base_save_path` comes from Cell 0; creating the
# directory here (a no-op when it already exists) means a missing folder is
# caught before the cross-validation runs rather than when results are saved.
RESULT_DIR = base_save_path
os.makedirs(RESULT_DIR, exist_ok=True)
print(f"[INFO] Results will be saved to: {RESULT_DIR}")

# ==============================================================================
# 3. Model configurations (CHANGE 4): every variant uses class weights, no SMOTE
# ==============================================================================
print(
    "\n" + "=" * 80,
    "STEP 3: Define Model Configurations (Class Weights Only)",
    "=" * 80,
    sep="\n",
)

# One entry per model variant. Keys of each entry:
#   use_smote  - flag reported in the listing below (False for all variants here)
#   use_scaler - whether a StandardScaler step precedes the classifier
#   classifier - short code resolved by create_pipeline ("rf", "svm", "lr")
#   param_grid - GridSearchCV grid; the "clf__" prefix targets the classifier step
model_configs = {
    # Random Forest with class weighting (PRIMARY - recommended for small samples)
    "RandomForest_Weighted": {
        "use_smote": False,   # CRITICAL: No SMOTE
        "use_scaler": False,  # RF doesn't require scaling
        "classifier": "rf",
        "param_grid": {
            "clf__n_estimators": [200, 500],
            "clf__max_depth": [None, 5, 10],
            "clf__max_features": [0.3, 0.5, "sqrt"],
        },
    },
    # Support Vector Machine with class weighting (No SMOTE)
    "SVM_Weighted": {
        "use_smote": False,  # CRITICAL: No SMOTE
        "use_scaler": True,  # SVM requires scaling
        "classifier": "svm",
        "param_grid": {
            "clf__C": [0.1, 1, 10],
            "clf__gamma": ["scale", "auto"],
        },
    },
    # Logistic Regression with class weighting (No SMOTE)
    "LogisticRegression_Weighted": {
        "use_smote": False,  # CRITICAL: No SMOTE
        "use_scaler": True,  # LR benefits from scaling
        "classifier": "lr",
        "param_grid": {
            "clf__C": [0.01, 0.1, 1, 10],
            "clf__penalty": ["l2"],
            "clf__solver": ["lbfgs"],
        },
    },
}

# OPTIONAL: keep one SMOTE model for sensitivity analysis
# (can stay commented out for a pure class_weight approach).
# NOTE(review): create_pipeline() in this cell does not add a SMOTE step, so
# enabling the entry below also requires extending create_pipeline accordingly.
# from imblearn.over_sampling import SMOTE
# from imblearn.pipeline import Pipeline as ImbPipeline
# model_configs["RandomForest_SMOTE"] = {
#     "use_smote": True,
#     "use_scaler": False,
#     "classifier": "rf",
#     "param_grid": {
#         "clf__n_estimators": [200, 500],
#         "clf__max_depth": [None, 5, 10],
#         "clf__max_features": [0.3, 0.5, "sqrt"]
#     }
# }

# List the configured variants together with their imbalance strategy.
print(f"[INFO] Configured {len(model_configs)} model variants:")
for model_name, cfg in model_configs.items():
    if cfg["use_smote"]:
        smote_status = "SMOTE"
    else:
        smote_status = "class_weight"
    print(f"  - {model_name} ({smote_status})")

# Rationale text, emitted as one block of lines.
print("\n".join([
    "\n[RATIONALE] Why class_weight over SMOTE:",
    "  - With only 11 controls, SMOTE's k=5 neighbors is problematic",
    "  - After 5-fold split: ~9 controls per training fold",
    "  - Demircioğlu (2024): SMOTE before CV causes up to +0.34 AUC bias",
    "  - class_weight='balanced' achieves similar effect without synthetic data",
]))

# ==============================================================================
# 4. Utility Functions
# ==============================================================================

def bootstrap_auc_ci(y_true, y_proba, n_bootstrap=1000, alpha=0.05, random_state=SEED):
    """
    Percentile-bootstrap confidence interval for the ROC-AUC.

    Samples are redrawn with replacement ``n_bootstrap`` times; resamples that
    contain a single class are discarded because the AUC is undefined for them.

    Parameters
    ----------
    y_true : array-like
        Ground-truth binary labels.
    y_proba : array-like
        Predicted positive-class probabilities, aligned with ``y_true``.
    n_bootstrap : int, default=1000
        How many resamples to draw.
    alpha : float, default=0.05
        Two-sided significance level (0.05 gives a 95% interval).
    random_state : int
        Seed of the private random generator used for resampling.

    Returns
    -------
    lower, upper : float
        Interval bounds, or ``(nan, nan)`` when no resample held both classes.
    """
    sampler = np.random.RandomState(random_state)
    labels = np.asarray(y_true)
    scores = np.asarray(y_proba)
    n_samples = len(labels)

    def _resampled_aucs():
        # One draw per bootstrap round, whether or not the round is usable.
        for _ in range(n_bootstrap):
            picks = sampler.choice(n_samples, n_samples, replace=True)
            drawn_labels = labels[picks]
            if len(np.unique(drawn_labels)) >= 2:
                yield roc_auc_score(drawn_labels, scores[picks])

    boot_aucs = list(_resampled_aucs())
    if not boot_aucs:
        return np.nan, np.nan

    tail = alpha / 2
    lower_bound = np.percentile(boot_aucs, 100 * tail)
    upper_bound = np.percentile(boot_aucs, 100 * (1 - tail))
    return float(lower_bound), float(upper_bound)


def _benjamini_hochberg(raw_pvals):
    """Return Benjamini-Hochberg adjusted p-values (q-values) in input order."""
    raw_pvals = np.asarray(raw_pvals, dtype=float)
    n_tests = raw_pvals.size
    if n_tests == 0:
        return raw_pvals.copy()

    order = np.argsort(raw_pvals)
    # p_(i) * m / i for the i-th smallest p-value
    scaled = raw_pvals[order] * n_tests / np.arange(1, n_tests + 1)
    # Enforce monotonicity: running minimum taken from the largest rank downwards
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    q_vals = np.empty(n_tests, dtype=float)
    q_vals[order] = np.minimum(monotone, 1.0)
    return q_vals


def perform_fold_feature_selection(expr_train, y_train, fold_idx, verbose=True, min_features=20):
    """
    Three-stage feature selection performed ONLY on training data.

    Stage 1: Biological filtering (|log2FC| > LOG2FC_THRESHOLD and
             Benjamini-Hochberg q-value < FDR_THRESHOLD, Mann-Whitney U test)
    Stage 2: Multi-method selection (LASSO, SVM-RFE, Random Forest)
    Stage 3: Consensus voting (2-of-3 agreement)

    Differences from the previous implementation:
      * The returned features follow the column order of ``expr_train``. They
        used to follow set/Counter iteration order, which for string probe IDs
        depends on Python's per-process hash randomization and therefore could
        change the column order of the model matrix between kernel sessions.
      * When fewer than ``min_features`` probes reach consensus, the best
        ``min_features`` probes are kept (most votes first, then smallest
        q-value). The old slice ``[:max(20, n)]`` always kept every voted probe.
      * If no probe passes Stage 1, the ``min_features`` probes with the
        smallest q-values are used instead of failing inside StandardScaler.
      * Only ``ValueError`` from ``mannwhitneyu`` is mapped to p = 1.0 (was a
        bare ``except``); non-finite p-values are also mapped to 1.0.
      * Method failures are always printed (not only when ``verbose``) and are
        recorded in ``feature_info['fallbacks']``.

    Parameters
    ----------
    expr_train : pd.DataFrame
        Expression matrix for training samples (samples × probes)
    y_train : np.ndarray
        Binary labels for training samples (0 = control, 1 = cancer)
    fold_idx : int or str
        Fold identifier for logging
    verbose : bool
        Whether to print progress information
    min_features : int, default=20
        Minimum number of features to return when enough candidates exist

    Returns
    -------
    selected_features : list
        List of selected probe IDs, ordered as in ``expr_train.columns``
    feature_info : dict
        Statistics about feature selection process. ``'fallbacks'`` lists the
        stages/methods ('stage1', 'lasso', 'svm_rfe', 'rf') that had to fall
        back to a substitute ranking.

    Raises
    ------
    ValueError
        If the training data lacks one of the two classes or has no probes.
    """
    y_train = np.asarray(y_train)
    control_mask = y_train == 0
    cancer_mask = y_train == 1
    n_control = int(control_mask.sum())
    n_cancer = int(cancer_mask.sum())

    if verbose:
        print(f"\n    [Fold {fold_idx}] Feature selection on training data only")
        print(f"      Train samples: {len(expr_train)} ({n_control} HC, {n_cancer} CRC)")
        print(f"      Total probes: {expr_train.shape[1]}")

    probe_list = list(expr_train.columns)
    if n_control == 0 or n_cancer == 0:
        raise ValueError(
            f"[Fold {fold_idx}] Both classes are required for feature selection "
            f"(got {n_control} controls, {n_cancer} cancer samples)"
        )
    if not probe_list:
        raise ValueError(f"[Fold {fold_idx}] Expression matrix has no probes")

    # Separate control and cancer samples
    expr_control = expr_train[control_mask]
    expr_cancer = expr_train[cancer_mask]

    fallbacks = []  # names of stages/methods that used a substitute ranking

    # =========================================================================
    # Stage 1: Biological Filtering with ADJUSTABLE Thresholds
    # =========================================================================
    fold_changes = {}
    p_values = {}
    log2_fold_changes = {}
    raw_pvals = np.ones(len(probe_list))

    for i, probe in enumerate(probe_list):
        control_vals = expr_control[probe].values
        cancer_vals = expr_cancer[probe].values

        # Data are on the log2 scale, so a difference of means is a log2 fold change
        log2fc = cancer_vals.mean() - control_vals.mean()
        log2_fold_changes[probe] = log2fc
        fold_changes[probe] = 2 ** log2fc

        # Mann-Whitney U test (non-parametric). A probe the test cannot handle
        # is treated as non-significant; other errors are left to surface.
        try:
            _, pval = mannwhitneyu(control_vals, cancer_vals, alternative='two-sided')
        except ValueError:
            pval = 1.0
        if not np.isfinite(pval):
            pval = 1.0

        raw_pvals[i] = pval
        p_values[probe] = pval

    # Benjamini-Hochberg FDR correction
    adjusted_pvals = _benjamini_hochberg(raw_pvals)
    q_values = dict(zip(probe_list, adjusted_pvals))

    # Criteria: |log2FC| > threshold AND FDR-adjusted q-value < threshold
    selected_bio = [
        probe for probe in probe_list
        if abs(log2_fold_changes[probe]) > LOG2FC_THRESHOLD and q_values[probe] < FDR_THRESHOLD
    ]

    if verbose:
        print(f"      Stage 1: {len(selected_bio)} probes (|log2FC| > {LOG2FC_THRESHOLD}, FDR q < {FDR_THRESHOLD})")

    if not selected_bio:
        # Nothing passed the filter: continue with the most significant probes
        # (smallest q-value first, larger |log2FC| breaking ties), kept in
        # column order. Always reported because it relaxes the stated criteria.
        n_keep = min(max(1, min_features), len(probe_list))
        abs_log2fc = np.array([abs(log2_fold_changes[p]) for p in probe_list])
        ranking = np.lexsort((-abs_log2fc, adjusted_pvals))
        selected_bio = [probe_list[i] for i in sorted(ranking[:n_keep])]
        fallbacks.append('stage1')
        print(f"      [WARN] [Fold {fold_idx}] No probe passed Stage 1; "
              f"using the {len(selected_bio)} probes with the smallest q-values")

    def _top_by_variance(k):
        """Substitute ranking: the k highest-variance Stage-1 probes."""
        var_bio = expr_train[selected_bio].var(axis=0)
        return var_bio.nlargest(min(k, len(selected_bio))).index.tolist()

    # =========================================================================
    # Stage 2: Multi-method Consensus
    # =========================================================================
    if verbose:
        print(f"      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)")

    X_bio = expr_train[selected_bio].values

    # Standardize for methods that require it (LASSO, SVM-RFE)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_bio)

    # ─────────────────────────────────────────────────────────────────────────
    # Method A: LASSO Regularization
    # ─────────────────────────────────────────────────────────────────────────
    try:
        lasso = LassoCV(
            cv=3,
            random_state=SEED,
            max_iter=20000,
            n_jobs=-1,
            tol=1e-3
        )
        lasso.fit(X_scaled, y_train)

        # Select features with non-zero coefficients
        lasso_feat = [selected_bio[i] for i, c in enumerate(lasso.coef_) if c != 0]

        # Fallback: select top features by coefficient magnitude if none selected.
        # NOTE(review): this branch is only reached when every coefficient is
        # zero, so the resulting ranking is among ties and effectively arbitrary.
        if len(lasso_feat) == 0:
            lasso_coef_abs = np.abs(lasso.coef_)
            top_idx = np.argsort(lasso_coef_abs)[-200:]
            lasso_feat = [selected_bio[i] for i in top_idx]

        if verbose:
            print(f"        LASSO: {len(lasso_feat)} features")

    except Exception as e:
        print(f"        LASSO failed ({str(e)}), using variance fallback")
        fallbacks.append('lasso')
        lasso_feat = _top_by_variance(200)

    # ─────────────────────────────────────────────────────────────────────────
    # Method B: SVM-RFE (Recursive Feature Elimination)
    # ─────────────────────────────────────────────────────────────────────────
    try:
        # Target a third of the Stage-1 probes, bounded to the range [10, 150]
        n_features_target = min(150, len(selected_bio) // 3)
        n_features_target = max(10, n_features_target)

        svc = SVC(kernel='linear', class_weight='balanced', random_state=SEED)
        rfe = RFE(
            estimator=svc,
            n_features_to_select=n_features_target,
            step=10,
            verbose=0
        )
        rfe.fit(X_scaled, y_train)

        svm_rfe_feat = [selected_bio[i] for i in np.where(rfe.support_)[0]]

        if verbose:
            print(f"        SVM-RFE: {len(svm_rfe_feat)} features")

    except Exception as e:
        print(f"        SVM-RFE failed ({str(e)}), using variance fallback")
        fallbacks.append('svm_rfe')
        svm_rfe_feat = _top_by_variance(150)

    # ─────────────────────────────────────────────────────────────────────────
    # Method C: Random Forest Feature Importance
    # ─────────────────────────────────────────────────────────────────────────
    try:
        rf = RandomForestClassifier(
            n_estimators=200,
            random_state=SEED,
            n_jobs=-1,
            class_weight='balanced'
        )
        rf.fit(X_bio, y_train)  # unscaled input, unlike LASSO and SVM-RFE

        importances = rf.feature_importances_
        n_rf = min(150, len(selected_bio))
        top_idx = np.argsort(importances)[-n_rf:]
        rf_feat = [selected_bio[i] for i in top_idx]

        if verbose:
            print(f"        Random Forest: {len(rf_feat)} features")

    except Exception as e:
        print(f"        Random Forest failed ({str(e)}), using variance fallback")
        fallbacks.append('rf')
        rf_feat = _top_by_variance(150)

    # ─────────────────────────────────────────────────────────────────────────
    # Stage 3: Consensus Voting (2-of-3)
    # ─────────────────────────────────────────────────────────────────────────
    feature_votes = Counter()
    for method_features in (lasso_feat, svm_rfe_feat, rf_feat):
        feature_votes.update(set(method_features))  # one vote per method

    # Iterate over selected_bio (a list) rather than the Counter so that the
    # output order does not depend on set iteration order.
    selected_consensus = [f for f in selected_bio if feature_votes[f] >= 2]

    if verbose:
        print(f"      Stage 3: {len(selected_consensus)} features (≥2/3 consensus)")

    # Ensure minimum feature count for stable model training: top up with the
    # best remaining candidates (most votes, then smallest q-value, then column
    # position as a deterministic tie-breaker).
    if len(selected_consensus) < min_features:
        bio_rank = {f: i for i, f in enumerate(selected_bio)}
        ranked = sorted(
            feature_votes,
            key=lambda f: (-feature_votes[f], q_values[f], bio_rank[f])
        )
        chosen = set(ranked[:min_features])
        selected_consensus = [f for f in selected_bio if f in chosen]
        if verbose:
            print(f"      Adjusted to {len(selected_consensus)} features (minimum threshold)")

    # Compile feature selection statistics
    feature_info = {
        'n_stage1': len(selected_bio),
        'n_lasso': len(lasso_feat),
        'n_svm_rfe': len(svm_rfe_feat),
        'n_rf': len(rf_feat),
        'n_consensus': len(selected_consensus),
        'fold_changes': {f: float(fold_changes[f]) for f in selected_consensus},
        'p_values': {f: float(p_values[f]) for f in selected_consensus},
        'fallbacks': fallbacks
    }

    return selected_consensus, feature_info


# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ CHANGE 5: create_pipeline now uses sklearn.Pipeline (no SMOTE)             │
# └─────────────────────────────────────────────────────────────────────────────┘
def create_pipeline(config, n_features):
    """
    Construct an (unfitted) sklearn pipeline based on model configuration.

    CRITICAL CHANGE (v3.0):
        - No longer uses SMOTE (imblearn)
        - All classifiers use class weighting for imbalance handling
        - Uses sklearn.Pipeline instead of imblearn.Pipeline

    Because no SMOTE step is ever added, a configuration with
    ``use_smote=True`` would otherwise run as a plain class-weight model while
    being reported as "SMOTE". A RuntimeWarning is emitted in that case so the
    mismatch is visible; the returned pipeline is the same as before.

    Parameters
    ----------
    config : dict
        Model configuration dictionary with keys ``"use_scaler"`` (bool),
        ``"classifier"`` ("rf", "svm" or "lr"), ``"param_grid"`` (dict) and,
        optionally, ``"use_smote"`` (bool).
    n_features : int
        Number of features. Currently unused; kept for potential future use.

    Returns
    -------
    pipeline : sklearn.Pipeline
        Unfitted pipeline: optional ``"scaler"`` step followed by ``"clf"``.
    param_grid : dict
        Shallow copy of the hyperparameter grid for GridSearchCV.

    Raises
    ------
    ValueError
        If ``config["classifier"]`` is not one of the supported codes.
    """
    if config.get("use_smote", False):
        warnings.warn(
            "config['use_smote'] is True, but create_pipeline() does not apply "
            "SMOTE; the pipeline handles imbalance with class weights only.",
            RuntimeWarning,
            stacklevel=2
        )

    steps = []

    # Add StandardScaler if requested (always before classifier)
    if config["use_scaler"]:
        steps.append(("scaler", StandardScaler()))

    # Add classifier (always last in pipeline)
    clf_code = config["classifier"]
    if clf_code == "rf":
        clf = RandomForestClassifier(
            random_state=SEED,
            n_jobs=-1,
            class_weight="balanced_subsample"  # Handles imbalance per bootstrap sample
        )
    elif clf_code == "svm":
        clf = SVC(
            probability=True,  # needed so the pipeline exposes predict_proba
            random_state=SEED,
            class_weight="balanced"  # Handles imbalance via inverse class frequency
        )
    elif clf_code == "lr":
        clf = LogisticRegression(
            max_iter=1000,
            random_state=SEED,
            class_weight="balanced"  # Handles imbalance via inverse class frequency
        )
    else:
        raise ValueError(f"Unknown classifier: {config['classifier']}")

    steps.append(("clf", clf))

    # Create standard sklearn Pipeline (no SMOTE)
    pipeline = Pipeline(steps)
    # Copy so that callers editing the returned grid do not alter the config
    param_grid = config["param_grid"].copy()

    return pipeline, param_grid


# ==============================================================================
# 5. Nested Cross-Validation with Fold-wise Feature Selection
# ==============================================================================
# Nested CV driver: for every model configuration, run the repeated outer CV,
# select features and tune hyperparameters on each training fold only, score
# the held-out fold, then aggregate out-of-fold (OOF) predictions, performance
# metrics and feature-stability statistics into the dictionaries below.
print("\n" + "=" * 80)
print("STEP 4: Execute Nested Cross-Validation")
print("=" * 80)

# Initialize result storage (all keyed by model name)
results = {}              # aggregated metrics and stability statistics
oof_predictions = {}      # y_true / y_proba / y_pred over all samples
fold_feature_lists = {}   # selected probe IDs per outer iteration
fold_feature_info = {}    # feature-selection statistics per outer iteration

# Iterate through each model configuration
for model_name, model_cfg in model_configs.items():

    print("\n" + "-" * 80)
    print(f"Model: {model_name}")
    print("-" * 80)

    # Store predictions for each repeat separately: one row per repeat, one
    # column per sample. A row is only complete once every fold of that repeat
    # has written its test samples.
    oof_proba_repeats = np.zeros((N_REPEATS, len(labels_full)))
    oof_pred_repeats = np.zeros((N_REPEATS, len(labels_full)), dtype=int)

    # Track performance across all iterations
    train_auc_folds = []
    test_auc_folds = []

    # Track AUC per repeat for variance estimation
    repeat_aucs = []

    fold_feature_lists[model_name] = {}
    fold_feature_info[model_name] = {}

    print(f"\n  Running {total_iterations} iterations ({N_OUTER_SPLITS}-fold × {N_REPEATS} repeats)...")

    # ─────────────────────────────────────────────────────────────────────────
    # Outer CV Loop: Iterate through train/test splits
    # ─────────────────────────────────────────────────────────────────────────
    # NOTE(review): feature selection below does not depend on the model
    # configuration, and outer_cv was built with a fixed random_state, so the
    # same selection is presumably recomputed for every model. Confirm, then
    # consider caching the result per iteration to avoid the repeated cost.
    for iteration, (train_idx, test_idx) in enumerate(outer_cv.split(expr_log_full, labels_full)):

        # Determine which repeat and fold we're in.
        # Assumes the splitter yields the N_OUTER_SPLITS folds of one repeat
        # consecutively before moving on to the next repeat.
        repeat_idx = iteration // N_OUTER_SPLITS
        fold_idx = iteration % N_OUTER_SPLITS + 1

        # Extract training fold data
        train_sample_ids = [sample_ids_full[i] for i in train_idx]
        expr_train_fold = expr_log_full.loc[train_sample_ids].copy()
        y_train_fold = labels_full[train_idx]

        # Extract test fold data
        test_sample_ids = [sample_ids_full[i] for i in test_idx]
        y_test_fold = labels_full[test_idx]

        # Progress indicator (every 10 iterations)
        if iteration % 10 == 0:
            train_pos = int(y_train_fold.sum())
            test_pos = int(y_test_fold.sum())
            print(f"\n    Iteration {iteration+1}/{total_iterations} (Repeat {repeat_idx+1}, Fold {fold_idx})")
            print(f"      Train: {train_pos}/{len(y_train_fold)} CRC | Test: {test_pos}/{len(y_test_fold)} CRC")

        # ─────────────────────────────────────────────────────────────────────
        # Feature Selection (WITHIN TRAINING FOLD ONLY)
        # ─────────────────────────────────────────────────────────────────────
        # Only training rows are passed in, so the held-out fold cannot
        # influence which probes are selected.
        selected_features, feat_info = perform_fold_feature_selection(
            expr_train_fold,
            y_train_fold,
            fold_idx=f"R{repeat_idx+1}F{fold_idx}",
            verbose=(iteration % 10 == 0)  # Verbose only every 10 iterations
        )

        fold_feature_lists[model_name][f'repeat_{repeat_idx+1}_fold_{fold_idx}'] = selected_features
        fold_feature_info[model_name][f'repeat_{repeat_idx+1}_fold_{fold_idx}'] = feat_info

        # Prepare data matrices with selected features (same column order for
        # the training and the test matrix)
        X_train_fold = expr_train_fold[selected_features].values
        X_test_fold = expr_log_full.loc[test_sample_ids][selected_features].values

        # ─────────────────────────────────────────────────────────────────────
        # Hyperparameter Tuning (Inner CV on training fold)
        # ─────────────────────────────────────────────────────────────────────
        pipeline, param_grid = create_pipeline(model_cfg, len(selected_features))

        # refit=True: after the grid search the best setting is refitted on the
        # whole training fold, and gs.predict_proba below uses that estimator.
        gs = GridSearchCV(
            estimator=pipeline,
            param_grid=param_grid,
            scoring="roc_auc",
            cv=inner_cv,
            n_jobs=-1,
            refit=True,
            verbose=0
        )

        gs.fit(X_train_fold, y_train_fold)

        # ─────────────────────────────────────────────────────────────────────
        # Evaluate on Training and Test Folds
        # ─────────────────────────────────────────────────────────────────────
        # Training performance: scored on the same rows the model was fitted
        # on, so it is an optimistic value kept for comparison with test AUC.
        train_proba = gs.predict_proba(X_train_fold)[:, 1]
        train_auc = roc_auc_score(y_train_fold, train_proba)
        train_auc_folds.append(train_auc)

        # Test performance (out-of-fold); class 1 is predicted at probability >= 0.5
        test_proba = gs.predict_proba(X_test_fold)[:, 1]
        test_pred = (test_proba >= 0.5).astype(int)
        test_auc = roc_auc_score(y_test_fold, test_proba)
        test_auc_folds.append(test_auc)

        # Store OOF predictions for this repeat
        oof_proba_repeats[repeat_idx, test_idx] = test_proba
        oof_pred_repeats[repeat_idx, test_idx] = test_pred

        # At the end of each repeat, calculate that repeat's AUC over all
        # samples (the repeat's row of OOF probabilities is complete by now)
        if fold_idx == N_OUTER_SPLITS:
            rep_auc = roc_auc_score(labels_full, oof_proba_repeats[repeat_idx])
            repeat_aucs.append(rep_auc)
            print(f"    → Repeat {repeat_idx+1} complete: OOF AUC = {rep_auc:.4f}")

    # =========================================================================
    # Aggregate OOF Predictions Across Repeats
    # =========================================================================
    # Average predictions across all repeats: each sample's final probability
    # is the mean of its per-repeat OOF probabilities, thresholded at 0.5.
    oof_proba = oof_proba_repeats.mean(axis=0)
    oof_pred = (oof_proba >= 0.5).astype(int)

    # ─────────────────────────────────────────────────────────────────────────
    # Compute Out-of-Fold Performance Metrics
    # ─────────────────────────────────────────────────────────────────────────
    # The "aggregated" AUC and its bootstrap CI use the repeat-averaged
    # probabilities; the mean ± std further below uses the per-repeat AUCs.
    auc = roc_auc_score(labels_full, oof_proba)
    ci_lower, ci_upper = bootstrap_auc_ci(labels_full, oof_proba)

    acc = accuracy_score(labels_full, oof_pred)
    bal_acc = balanced_accuracy_score(labels_full, oof_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_full, oof_pred, average='binary', zero_division=0
    )

    # Unpacking four values assumes a 2x2 matrix, i.e. both classes occur
    cm = confusion_matrix(labels_full, oof_pred)
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    sensitivity = recall

    train_auc_mean = np.mean(train_auc_folds)
    test_auc_mean = np.mean(test_auc_folds)

    # Report mean ± std across repeats (np.std default: population std, ddof=0)
    repeat_auc_mean = np.mean(repeat_aucs)
    repeat_auc_std = np.std(repeat_aucs)

    print("\n  " + "=" * 76)
    print("  OUT-OF-FOLD PERFORMANCE (Unbiased Estimate)")
    print("  " + "=" * 76)
    print(f"    ROC-AUC (aggregated):  {auc:.4f}")
    print(f"    ROC-AUC (mean±std):    {repeat_auc_mean:.4f} ± {repeat_auc_std:.4f}")
    print(f"    95% CI:                [{ci_lower:.4f}, {ci_upper:.4f}]")
    print(f"    Accuracy:              {acc:.4f}")
    print(f"    Balanced Accuracy:     {bal_acc:.4f}")
    print(f"    Precision:             {precision:.4f}")
    print(f"    Sensitivity:           {sensitivity:.4f}")
    print(f"    Specificity:           {specificity:.4f}")
    print(f"    F1-score:              {f1:.4f}")
    print(f"    Mean Train AUC:        {train_auc_mean:.4f}")
    print(f"    Mean Test AUC:         {test_auc_mean:.4f}")
    print("    Confusion Matrix:")
    print("    " + str(cm).replace('\n', '\n    '))

    # ─────────────────────────────────────────────────────────────────────────
    # CHANGE 6: Feature Stability Analysis with 70% threshold (was 80%)
    # ─────────────────────────────────────────────────────────────────────────
    # Count in how many outer iterations each probe was selected
    feature_frequency = Counter()
    for features in fold_feature_lists[model_name].values():
        feature_frequency.update(features)

    # Number of outer iterations recorded for this model (folds × repeats)
    n_folds = len(fold_feature_lists[model_name])

    # Stability thresholds adjusted per Liu et al. (2025)
    highly_stable = [f for f, count in feature_frequency.items()
                     if count >= n_folds * STABILITY_THRESHOLD]  # 70% (was 80%)
    moderately_stable = [f for f, count in feature_frequency.items()
                         if count >= n_folds * 0.5]  # 50%

    print(f"\n  Feature Stability Across {n_folds} Iterations:")
    print(f"    Selected in 100%:    {sum(1 for c in feature_frequency.values() if c == n_folds)}")
    print(f"    Selected in ≥{STABILITY_THRESHOLD*100:.0f}%:   {len(highly_stable)}")
    print(f"    Selected in ≥50%:    {len(moderately_stable)}")

    # Calculate Kuncheva Stability Index (optional, for reporting)
    # This measures consistency of feature selection across folds.
    # NOTE(review): the helper is re-defined on every pass of the model loop;
    # it uses only its arguments and could be moved to module level.
    def calculate_kuncheva_index(feature_lists, total_features):
        """
        Calculate Kuncheva stability index for feature selection.

        Averages a chance-corrected overlap score over all pairs of feature
        lists. Pairs in which a list is empty, or for which the correction
        term leaves no positive maximum, are skipped.

        Parameters
        ----------
        feature_lists : dict
            Maps an iteration key to the list of features selected there.
        total_features : int
            Size of the feature pool the selections were drawn from.

        Returns
        -------
        float
            Mean pairwise index, or NaN when fewer than two lists are given
            or no pair could be scored.

        Reference: Kuncheva LI (2007) A stability index for feature selection.
        """
        k = len(feature_lists)
        if k < 2:
            return np.nan

        # Get all pairwise similarities
        similarities = []
        lists = list(feature_lists.values())
        for i in range(k):
            for j in range(i + 1, k):
                set_i = set(lists[i])
                set_j = set(lists[j])
                intersection = len(set_i & set_j)
                n_i = len(set_i)
                n_j = len(set_j)

                if n_i == 0 or n_j == 0:
                    continue

                # Kuncheva formula, written for two subsets of possibly
                # different sizes: overlap expected by chance, and the largest
                # achievable overlap minus that expectation
                expected = (n_i * n_j) / total_features
                max_val = min(n_i, n_j) - expected
                if max_val <= 0:
                    continue

                ksi = (intersection - expected) / max_val
                similarities.append(ksi)

        return np.mean(similarities) if similarities else np.nan

    # The feature pool is the full probe set of the expression matrix
    kuncheva_idx = calculate_kuncheva_index(
        fold_feature_lists[model_name],
        expr_log_full.shape[1]
    )
    print(f"    Kuncheva Stability Index: {kuncheva_idx:.4f}")

    # Store comprehensive results. A NaN Kuncheva index is stored as None, and
    # only the 50 most frequently selected features are kept in
    # "feature_frequency".
    results[model_name] = {
        "oof_auc": float(auc),
        "oof_auc_mean": float(repeat_auc_mean),
        "oof_auc_std": float(repeat_auc_std),
        "oof_auc_ci": [ci_lower, ci_upper],
        "oof_accuracy": float(acc),
        "oof_balanced_accuracy": float(bal_acc),
        "precision": float(precision),
        "recall_sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "f1_score": float(f1),
        "confusion_matrix": cm.tolist(),
        "train_auc_mean": train_auc_mean,
        "train_auc_std": float(np.std(train_auc_folds)),
        "test_auc_mean": test_auc_mean,
        "test_auc_std": float(np.std(test_auc_folds)),
        "n_repeats": N_REPEATS,
        "n_folds": N_OUTER_SPLITS,
        "repeat_aucs": [float(x) for x in repeat_aucs],
        "kuncheva_stability_index": float(kuncheva_idx) if not np.isnan(kuncheva_idx) else None,
        "feature_stability": {
            "n_iterations": n_folds,
            "stability_threshold": STABILITY_THRESHOLD,
            "n_highly_stable": len(highly_stable),
            "highly_stable_features": highly_stable,
            "n_moderately_stable": len(moderately_stable),
            "feature_frequency": {str(k): int(v) for k, v in feature_frequency.most_common(50)}
        }
    }

    # Copies, so later changes to the working arrays do not alter stored results
    oof_predictions[model_name] = {
        "y_true": labels_full.copy(),
        "y_proba": oof_proba.copy(),
        "y_pred": oof_pred.copy()
    }

# ==============================================================================
# 6. Model Selection and Comparison
# ==============================================================================
print("\n" + "=" * 80)
print("STEP 5: Model Selection Summary")
print("=" * 80)

# Select best model by OOF AUC
best_name = max(results, key=lambda k: results[k]["oof_auc"])

print("\nAll Models (ranked by OOF AUC):")
for m_name in sorted(results, key=lambda k: results[k]["oof_auc"], reverse=True):
    res = results[m_name]
    # The key is always stored (None when the index was undefined), so a
    # .get() default would never apply; map missing/None to "N/A" explicitly.
    ksi = res.get('kuncheva_stability_index')
    ksi_str = f"{ksi:.3f}" if isinstance(ksi, (int, float)) else "N/A"
    print(f"  {m_name:30s}: AUC = {res['oof_auc']:.4f} ± {res['oof_auc_std']:.4f} "
          f"(Kuncheva: {ksi_str})")

# Format the best model's index the same way instead of printing a raw
# float or the literal "None".
best_ksi = results[best_name].get('kuncheva_stability_index')
best_ksi_str = f"{best_ksi:.4f}" if isinstance(best_ksi, (int, float)) else "N/A"

print(f"\n{'=' * 80}")
print(f"BEST MODEL: {best_name}")
print(f"{'=' * 80}")
print(f"  Out-of-fold AUC:         {results[best_name]['oof_auc']:.4f} ± {results[best_name]['oof_auc_std']:.4f}")
print(f"  95% Confidence Interval: [{results[best_name]['oof_auc_ci'][0]:.4f}, "
      f"{results[best_name]['oof_auc_ci'][1]:.4f}]")
print(f"  Accuracy:                {results[best_name]['oof_accuracy']:.4f}")
print(f"  Balanced Accuracy:       {results[best_name]['oof_balanced_accuracy']:.4f}")
print(f"  Sensitivity:             {results[best_name]['recall_sensitivity']:.4f}")
print(f"  Specificity:             {results[best_name]['specificity']:.4f}")
print(f"  Kuncheva Stability:      {best_ksi_str}")

# ==============================================================================
# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ CHANGE 7: Dual-Path Feature Selection (Lewis 2023, Parvandeh 2020)         │
# │                                                                             │
# │ Path A: Production Model - Full-data retrained for external validation     │
# │ Path B: CV-Stable Features - For biomarker reporting (>70% of folds)       │
# └─────────────────────────────────────────────────────────────────────────────┘
# ==============================================================================
print("\n" + "=" * 80)
print("STEP 6: DUAL-PATH Feature Selection")
print("=" * 80)

# ─────────────────────────────────────────────────────────────────────────────
# PATH B: CV-Stable Features (for Biomarker Reporting)
# ─────────────────────────────────────────────────────────────────────────────
print("\n[PATH B] CV-STABLE FEATURES (Biomarker Identification)")
print("-" * 60)

# Get features stable across CV folds (>= STABILITY_THRESHOLD)
# NOTE(review): this frequency map is stored with most_common(50), so at most
# 50 features can ever be reported here.
feature_freq_best = results[best_name]['feature_stability']['feature_frequency']

# Minimum number of iterations a feature must be selected in (computed once so
# the filter and the printed threshold cannot drift apart).
min_stable_count = int(total_iterations * STABILITY_THRESHOLD)
cv_stable_features = [
    f for f, count in feature_freq_best.items()
    if int(count) >= min_stable_count
]

print(f"  Stability threshold: {STABILITY_THRESHOLD*100:.0f}% ({min_stable_count}/{total_iterations} folds)")
print(f"  CV-stable features:  {len(cv_stable_features)}")

# Save CV-stable features for biomarker reporting.
# Explicit columns keep the frame well-formed (and sortable) even when no
# feature reaches the threshold; without them sort_values raises KeyError.
cv_stable_df = pd.DataFrame(
    [
        {
            'Probe_ID': str(f),
            'Selection_Count': int(feature_freq_best[f]),
            'Selection_Rate': f"{int(feature_freq_best[f])/total_iterations*100:.1f}%"
        }
        for f in cv_stable_features
    ],
    columns=['Probe_ID', 'Selection_Count', 'Selection_Rate']
).sort_values('Selection_Count', ascending=False)

cv_stable_path = os.path.join(RESULT_DIR, f"cv_stable_biomarkers_{best_name}.csv")
cv_stable_df.to_csv(cv_stable_path, index=False)
print(f"  Saved: {cv_stable_path}")

print(f"\n  [RATIONALE] These {len(cv_stable_features)} features are:")
print(f"    - Selected consistently across {STABILITY_THRESHOLD*100:.0f}%+ of CV iterations")
print(f"    - Suitable for reporting as 'validated biomarkers' in manuscript")
print(f"    - NOT used for external validation (Path A features used instead)")

# ─────────────────────────────────────────────────────────────────────────────
# PATH A: Production Model (for External Validation)
# ─────────────────────────────────────────────────────────────────────────────
print("\n[PATH A] PRODUCTION MODEL (External Validation)")
print("-" * 60)
print("[INFO] Applying identical feature selection procedure to FULL dataset...")

# Perform feature selection on full dataset (same procedure as each CV fold)
expr_full_df = pd.DataFrame(
    expr_log_full.values,
    index=sample_ids_full,
    columns=expr_log_full.columns
)

final_features, final_feat_info = perform_fold_feature_selection(
    expr_full_df,
    labels_full,
    fold_idx="FULL_DATA",
    verbose=True
)

print(f"\n  Full-data selected features: {len(final_features)}")

# Fail early with a clear message rather than letting the grid search fail on
# a zero-column matrix.
if len(final_features) == 0:
    raise RuntimeError(
        "Full-data feature selection returned no features; "
        "cannot train the production model."
    )

X_final = expr_full_df[final_features].values

# Store for downstream analysis
feature_cols = final_features
X = X_final
y = labels_full

# Train final model with hyperparameter optimization
best_cfg = model_configs[best_name]
final_pipeline, final_param_grid = create_pipeline(best_cfg, len(final_features))

final_model = GridSearchCV(
    estimator=final_pipeline,
    param_grid=final_param_grid,
    scoring="roc_auc",
    cv=inner_cv,
    n_jobs=-1,
    refit=True,
    verbose=0
)

final_model.fit(X_final, labels_full)

# Store best model and parameters
best_full_model = final_model.best_estimator_
best_full_params = final_model.best_params_

results[best_name]["best_estimator"] = best_full_model
results[best_name]["best_params"] = best_full_params

print(f"\n  Final Model Configuration:")
print(f"    Model type: {best_name}")
print(f"    Features: {len(feature_cols)}")
print(f"    Hyperparameters:")
for k, v in best_full_params.items():
    print(f"      {k}: {v}")

# Save production model features
production_features_path = os.path.join(RESULT_DIR, f"production_model_features_{best_name}.csv")
production_df = pd.DataFrame({'Probe_ID': feature_cols})
production_df.to_csv(production_features_path, index=False)
print(f"\n  Saved: {production_features_path}")

# Report the actual sample count instead of a hard-coded "n=99".
print(f"\n  [RATIONALE] These {len(feature_cols)} features are:")
print(f"    - Selected using identical procedure applied to FULL n={len(labels_full)} dataset")
print(f"    - Used for external validation (represents 'deployed' model)")
print(f"    - Performance estimate comes from nested CV, NOT resubstitution")

# ─────────────────────────────────────────────────────────────────────────────
# DUAL-PATH SUMMARY
# ─────────────────────────────────────────────────────────────────────────────
print("\n" + "=" * 60)
print("DUAL-PATH FEATURE SELECTION SUMMARY")
print("=" * 60)

# Calculate overlap between paths (IDs compared as strings so integer and
# string probe IDs match)
path_a_set = {str(f) for f in feature_cols}
path_b_set = {str(f) for f in cv_stable_features}
overlap = path_a_set & path_b_set
only_a = path_a_set - path_b_set
only_b = path_b_set - path_a_set

# The sample count is read from the labels instead of a hard-coded "n=99".
print(f"""
  PATH A (Production Model):     {len(feature_cols)} features
    - For: External validation, clinical deployment
    - Source: Full-data feature selection (n={len(labels_full)})

  PATH B (CV-Stable Biomarkers): {len(cv_stable_features)} features
    - For: Biomarker reporting in manuscript
    - Source: ≥{STABILITY_THRESHOLD*100:.0f}% selection across {total_iterations} CV iterations

  OVERLAP ANALYSIS:
    - In BOTH paths:   {len(overlap)} features
    - Only in Path A:  {len(only_a)} features
    - Only in Path B:  {len(only_b)} features

  [INTERPRETATION]
    Features in BOTH paths are the most robust biomarkers.
    They are consistently selected AND appear in the production model.
""")

# Save overlap analysis.
# Set iteration order is arbitrary and can differ between kernel sessions;
# sorting makes the saved CSV reproducible.
overlap_df = pd.DataFrame({
    'Probe_ID': sorted(overlap),
    'In_Production_Model': True,
    'In_CV_Stable': True
})
overlap_path = os.path.join(RESULT_DIR, f"overlap_robust_biomarkers_{best_name}.csv")
overlap_df.to_csv(overlap_path, index=False)
print(f"  Overlap (most robust) saved: {overlap_path}")

# ==============================================================================
# 7. Save Comprehensive Results
# ==============================================================================
print("\n" + "=" * 80)
print("STEP 7: Save Results")
print("=" * 80)

# Build a JSON-friendly copy of every model's results; the fitted estimator
# is left out because it cannot be serialized.
json_results = {}
for model_name, res in results.items():
    json_results[model_name] = dict(
        item for item in res.items() if item[0] != 'best_estimator'
    )

json_path = os.path.join(RESULT_DIR, "repeated_nested_cv_results_v3.json")
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(
        json_results,
        f,
        indent=2,
        ensure_ascii=False,
        default=str,
    )
print(f"[INFO] Results saved: {json_path}")

# Write one feature-stability table per model, most frequently selected first
for model_name in results:
    frequency_items = results[model_name]['feature_stability']['feature_frequency'].items()
    stability_records = [
        {
            "Probe_ID": probe_id,
            "Selection_Count": n_selected,
            "Selection_Rate": f"{n_selected/total_iterations*100:.1f}%",
        }
        for probe_id, n_selected in frequency_items
    ]
    stability_df = pd.DataFrame(stability_records).sort_values(
        "Selection_Count", ascending=False
    )

    stability_path = os.path.join(RESULT_DIR, f"feature_stability_{model_name}.csv")
    stability_df.to_csv(stability_path, index=False)

print(f"[INFO] Feature stability tables saved for all models")

# ==============================================================================
# FINAL SUMMARY
# ==============================================================================
print("\n" + "=" * 80)
print("✓ NESTED CV (v3.0) COMPLETE")
print("=" * 80)

# The Kuncheva key is always stored (None when the index was undefined), so a
# .get() default never applies; format it explicitly to avoid printing "None"
# or an unrounded float.
summary_ksi = results[best_name].get('kuncheva_stability_index')
summary_ksi_str = f"{summary_ksi:.4f}" if isinstance(summary_ksi, (int, float)) else "N/A"

print(f"""
[SUMMARY]
  Cross-Validation:    {N_OUTER_SPLITS}-fold × {N_REPEATS} repeats = {total_iterations} iterations
  Feature Selection:   |log2FC| > {LOG2FC_THRESHOLD}, FDR q < {FDR_THRESHOLD}
  Consensus:           2-of-3 methods (unchanged)
  Stability Threshold: {STABILITY_THRESHOLD*100:.0f}% (lowered from 80%)
  Best Model:          {best_name}

[PERFORMANCE]
  OOF AUC:             {results[best_name]['oof_auc']:.4f} ± {results[best_name]['oof_auc_std']:.4f}
  95% CI:              [{results[best_name]['oof_auc_ci'][0]:.4f}, {results[best_name]['oof_auc_ci'][1]:.4f}]
  Balanced Accuracy:   {results[best_name]['oof_balanced_accuracy']:.4f}
  Sensitivity:         {results[best_name]['recall_sensitivity']:.4f}
  Specificity:         {results[best_name]['specificity']:.4f}
  Kuncheva Index:      {summary_ksi_str}

[DUAL-PATH FEATURES]
  Path A (Production): {len(feature_cols)} features (for external validation)
  Path B (Biomarkers): {len(cv_stable_features)} features (for manuscript)
  Overlap (Robust):    {len(overlap)} features

[KEY CHANGES IN v3.0]
  ✓ SMOTE removed → class_weight='balanced' used instead
  ✓ Stability threshold: 80% → {STABILITY_THRESHOLD*100:.0f}%
  ✓ Kuncheva stability index calculated
  ✓ Dual-path feature selection implemented
  ✓ Overlap analysis for robust biomarkers

[READY FOR]
  - Cell 2-2: Balanced Subsampling comparison (sensitivity analysis)
  - Cell 3: SHAP analysis
  - Cell 4: External validation (use Path A features)
  - Manuscript: Report Path B features as biomarkers
""")

print("=" * 80)


Cell 2: Nested Cross-Validation with REPEATED Stratified K-Fold

CONFIGURATION:
  Outer CV:            5-fold × 10 repeats = 50 iterations
  Inner CV:            3-fold (hyperparameter tuning)

  Feature Selection:
    |log2FC| threshold:  > 1.0
    FDR q-value:         < 0.05
    Consensus:           2-of-3 methods (unchanged)

This provides more stable performance estimates with variance.


STEP 1: Load Expression Data from Cell 1
[INFO] Applying log2(x + 1) transformation to raw expression values


  result = func(self.values, **kwargs)


[INFO] Expression matrix loaded: (99, 15739)
  Samples: 99
  Probes: 15739
  Healthy controls: 11
  CRC patients: 88
  Class imbalance ratio: 1:8.0

[VERIFICATION] Checking for data leakage risks:
  Max value: 19.58
  Min value: -16.45
  Global mean: 0.88
  Global std: 1.62
  ✓ Data is NOT globally normalized. Safe for fold-wise processing.

STEP 2: Configure Cross-Validation Strategy
[INFO] Outer CV: 5-fold × 10 repeats = 50 iterations
[INFO] Inner CV: 3-fold stratified
[INFO] Random seed: 42
[INFO] Results will be saved to: /content/drive/MyDrive/geoexosome_results

STEP 3: Define Model Configurations
[INFO] Configured 4 model variants:
  - RandomForest_SMOTE
  - RandomForest_Weighted
  - SVM
  - LogisticRegression

STEP 4: Execute Nested Cross-Validation

--------------------------------------------------------------------------------
Model: RandomForest_SMOTE
--------------------------------------------------------------------------------

  Running 50 iterations (5-fold × 10 repea

  model = cd_fast.enet_coordinate_descent(


        LASSO: 10 features
        SVM-RFE: 50 features
        Random Forest: 150 features
      Stage 3: 51 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 1 complete: OOF AUC = 0.9814


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 2 complete: OOF AUC = 0.9830

    Iteration 11/50 (Repeat 3, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R3F1] Feature selection on training data only
      Train samples: 79 (8 HC, 71 CRC)
      Total probes: 15739
      Stage 1: 160 probes (|log2FC| > 1.0, FDR q < 0.05)
      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)


  model = cd_fast.enet_coordinate_descent(


        LASSO: 33 features
        SVM-RFE: 53 features
        Random Forest: 150 features
      Stage 3: 60 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 3 complete: OOF AUC = 0.9535


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 4 complete: OOF AUC = 0.9824

    Iteration 21/50 (Repeat 5, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R5F1] Feature selection on training data only
      Train samples: 79 (8 HC, 71 CRC)
      Total probes: 15739
      Stage 1: 201 probes (|log2FC| > 1.0, FDR q < 0.05)
      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)


  model = cd_fast.enet_coordinate_descent(


        LASSO: 31 features
        SVM-RFE: 67 features
        Random Forest: 150 features
      Stage 3: 69 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 5 complete: OOF AUC = 0.9824


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 6 complete: OOF AUC = 0.9602

    Iteration 31/50 (Repeat 7, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R7F1] Feature selection on training data only
      Train samples: 79 (8 HC, 71 CRC)
      Total probes: 15739
      Stage 1: 100 probes (|log2FC| > 1.0, FDR q < 0.05)
      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)


  model = cd_fast.enet_coordinate_descent(


        LASSO: 22 features
        SVM-RFE: 33 features
        Random Forest: 100 features
      Stage 3: 39 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 7 complete: OOF AUC = 0.9814


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 8 complete: OOF AUC = 0.9757

    Iteration 41/50 (Repeat 9, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R9F1] Feature selection on training data only
      Train samples: 79 (8 HC, 71 CRC)
      Total probes: 15739
      Stage 1: 159 probes (|log2FC| > 1.0, FDR q < 0.05)
      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)


  model = cd_fast.enet_coordinate_descent(


        LASSO: 35 features
        SVM-RFE: 53 features
        Random Forest: 150 features
      Stage 3: 59 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 9 complete: OOF AUC = 0.9654


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 10 complete: OOF AUC = 0.9385

  OUT-OF-FOLD PERFORMANCE (Unbiased Estimate)
    ROC-AUC (aggregated):  0.9824
    ROC-AUC (mean±std):    0.9704 ± 0.0147
    95% CI:                [0.9519, 1.0000]
    Accuracy:              0.9192
    Balanced Accuracy:     0.7159
    Precision:             0.9348
    Sensitivity:           0.9773
    Specificity:           0.4545
    F1-score:              0.9556
    Mean Train AUC:        1.0000
    Mean Test AUC:         0.9772
    Confusion Matrix:
    [[ 5  6]
     [ 2 86]]

  Feature Stability Across 50 Iterations:
    Selected in 100%:  1
    Selected in ≥80%:  6
    Selected in ≥50%:  27

--------------------------------------------------------------------------------
Model: RandomForest_Weighted
--------------------------------------------------------------------------------

  Running 50 iterations (5-fold × 10 repeats)...

    Iteration 1/50 (Repeat 1, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R1F1] Feature se

  model = cd_fast.enet_coordinate_descent(


        LASSO: 10 features
        SVM-RFE: 50 features
        Random Forest: 150 features
      Stage 3: 51 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 1 complete: OOF AUC = 0.9902


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 2 complete: OOF AUC = 0.9861

    Iteration 11/50 (Repeat 3, Fold 1)
      Train: 71/79 CRC | Test: 17/20 CRC

    [Fold R3F1] Feature selection on training data only
      Train samples: 79 (8 HC, 71 CRC)
      Total probes: 15739
      Stage 1: 160 probes (|log2FC| > 1.0, FDR q < 0.05)
      Stage 2: Multi-method selection (LASSO, SVM-RFE, RF)


  model = cd_fast.enet_coordinate_descent(


        LASSO: 33 features
        SVM-RFE: 53 features
        Random Forest: 150 features
      Stage 3: 60 features (≥2/3 consensus)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


    → Repeat 3 complete: OOF AUC = 0.9659


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


In [None]:
# =============================================================================
# Cell 2-2: Balanced Subsampling Feature Selection (Ma et al. Strategy)
# =============================================================================
"""
Purpose:
    Implement Ma et al. (2021) balanced subsampling strategy for robust
    feature selection in severely imbalanced datasets.

Reference:
    Ma J, Wang P, et al. (2021) "Bioinformatic analysis reveals an exosomal
    miRNA-mRNA network in colorectal cancer." BMC Med Genomics 14(1):60.
    DOI: 10.1186/s12920-021-00905-2

Key Differences from Cell 2 (SMOTE-based):
    1. No synthetic data generation (100% real samples)
    2. Balanced subsets: n_minority samples from each class per iteration
    3. 1000 iterations for robust frequency estimation
    4. Double threshold: frequency (>50%) AND consensus (≥2/3 methods)

VERSION 2.0 MODIFICATIONS:
    - Added FDR correction (Benjamini-Hochberg) within each iteration
    - Fixed Probe ID type consistency (string normalization)
    - Clarified consensus definitions (within-method vs cross-method)
    - Fixed Jaccard similarity calculation

Prerequisites:
    - Cell 0-2 must be executed first
    - Required variables: expr_log_full, labels_full, sample_ids_full, mapping_df

Output:
    - balanced_stable_features: Features selected by balanced subsampling
    - comparison_results: Comparison with Cell 2 SMOTE-based results
    - consensus_features: Intersection of both methods (high confidence)

Author: Jungho Sohn
Date: 2025-12-23
Version: 2.0 (FDR correction added)
"""

import os
import json
import numpy as np
import pandas as pd
import warnings
from collections import defaultdict, Counter
from scipy.stats import mannwhitneyu

from sklearn.linear_model import LassoCV
from sklearn.svm import SVC
from sklearn.feature_selection import RFE
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')

print("\n".join([
    "\n" + "=" * 80,
    "Cell 2-2: BALANCED SUBSAMPLING FEATURE SELECTION (v2.0)",
    "         (Ma et al. 2021 Strategy - No Synthetic Data)",
    "         [WITH FDR CORRECTION]",
    "=" * 80,
]))

# =============================================================================
# STEP 0: Verify Data Availability (must be first!)
# =============================================================================

# Names that earlier cells must have left in the notebook namespace
required_vars = ['expr_log_full', 'labels_full', 'sample_ids_full', 'results', 'best_name', 'SEED']
missing_vars = [name for name in required_vars if name not in globals()]

if missing_vars:
    raise RuntimeError(f"Missing required variables: {missing_vars}\n"
                       f"Please run Cell 0-2 first!")

print(f"[INFO] ✓ All required variables verified from Cell 0-2")

# =============================================================================
# CONFIGURATION
# =============================================================================

# Output directory: reuse RESULT_DIR when defined, otherwise fall back to
# base_save_path or a local folder and make sure it exists
if 'RESULT_DIR' not in globals():
    RESULT_DIR = globals().get('base_save_path', './geoexosome_results')
    os.makedirs(RESULT_DIR, exist_ok=True)

# Subsampling parameters (Ma et al. strategy)
N_ITERATIONS = 1000          # Number of balanced subsampling iterations
FREQ_THRESHOLD = 0.5         # Minimum selection frequency (>50%)
CONSENSUS_THRESHOLD = 2      # Minimum methods agreeing (≥2 of 3)

# Feature selection thresholds: inherit the Cell 2 values when present so both
# analyses filter identically, else use the defaults
BS_LOG2FC_THRESHOLD = globals().get('LOG2FC_THRESHOLD', 1.0)
BS_FDR_THRESHOLD = globals().get('FDR_THRESHOLD', 0.05)

# Reference CV layout from Cell 2 (for comparison threshold calculation)
N_OUTER_SPLITS_REF = globals().get('N_OUTER_SPLITS', 5)
N_REPEATS_REF = globals().get('N_REPEATS', 10)

print(f"""
CONFIGURATION:
  Iterations:              {N_ITERATIONS}
  Frequency threshold:     >{FREQ_THRESHOLD*100:.0f}%
  Consensus threshold:     ≥{CONSENSUS_THRESHOLD}/3 methods

  Biological filter:
    |log2FC| threshold:    > {BS_LOG2FC_THRESHOLD}
    FDR q-value:           < {BS_FDR_THRESHOLD}
    FDR correction:        Benjamini-Hochberg (per iteration)  ← NEW

  Key difference from Cell 2:
    - NO SMOTE (synthetic oversampling)
    - Each iteration uses {sum(labels_full==0)} HC + {sum(labels_full==0)} CRC (balanced)
    - Total real samples per iteration: {sum(labels_full==0) * 2}
""")

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def normalize_probe_id(probe_id):
    """
    Return a probe ID as a canonical string.

    Integer IDs (Python or NumPy integers) are rendered with str(); any other
    value is converted with str() and stripped of surrounding whitespace. This
    lets probe IDs from Cell 2 and Cell 2-2 be compared even when one side
    stores them as integers and the other as strings.
    """
    text = str(probe_id)
    is_integer_id = isinstance(probe_id, (int, np.integer))
    return text if is_integer_id else text.strip()


def benjamini_hochberg_correction(pvalues):
    """
    Apply Benjamini-Hochberg FDR correction to p-values.

    Vectorized with NumPy: this function is called once per subsampling
    iteration on every probe, so Python-level loops are avoided.

    Parameters
    ----------
    pvalues : array-like
        Raw p-values from statistical tests (1-D).

    Returns
    -------
    qvalues : np.ndarray
        FDR-adjusted q-values, in the same order and length as the input.
        Values are capped at 1.0; a NaN p-value yields a q-value of 1.0.

    Reference
    ---------
    Benjamini Y, Hochberg Y (1995). Controlling the false discovery rate:
    a practical and powerful approach to multiple testing.
    J R Stat Soc B 57:289-300.
    """
    pvalues = np.asarray(pvalues, dtype=float)
    n_tests = len(pvalues)

    if n_tests == 0:
        return np.array([])

    # Ascending order of p-values; rank 1 is the smallest p-value
    sorted_idx = np.argsort(pvalues)
    ranks = np.arange(1, n_tests + 1)

    # Raw adjustment p * m / rank, capped at 1.0
    adjusted = pvalues[sorted_idx] * n_tests / ranks
    # NaN p-values (placed last by argsort) are treated as non-significant so
    # they cannot propagate through the cumulative minimum below.
    adjusted[np.isnan(adjusted)] = 1.0
    np.minimum(adjusted, 1.0, out=adjusted)

    # Enforce monotonicity: running minimum taken from the largest p-value
    # downwards, so q-values are non-decreasing when sorted by p-value.
    adjusted = np.minimum.accumulate(adjusted[::-1])[::-1]

    # Scatter back to the original input order
    qvalues = np.empty(n_tests)
    qvalues[sorted_idx] = adjusted
    return qvalues


# =============================================================================
# STEP 1: Data Summary and Class Distribution
# =============================================================================
print("\n" + "=" * 80)
print("STEP 1: Data Summary")
print("=" * 80)

# Extract class indices (label 0 is reported as HC, label 1 as CRC)
hc_indices = np.where(labels_full == 0)[0]
crc_indices = np.where(labels_full == 1)[0]

n_hc = len(hc_indices)
n_crc = len(crc_indices)

# Balanced subsampling draws n_hc CRC samples without replacement, so it needs
# at least one control and at least as many CRC samples as controls. Checking
# here gives a clear message instead of a division by zero below or a sampling
# error deep inside the iteration loop.
if n_hc == 0 or n_crc < n_hc:
    raise ValueError(
        f"Balanced subsampling needs n_HC >= 1 and n_CRC >= n_HC "
        f"(got {n_hc} HC, {n_crc} CRC)."
    )

print(f"  - Expression matrix: {expr_log_full.shape}")
print(f"  - Labels: {len(labels_full)} samples ({n_hc} HC, {n_crc} CRC)")
print(f"  - Class ratio: 1:{n_crc/n_hc:.1f}")

print(f"\n[INFO] Balanced subsampling configuration:")
print(f"  - Each iteration: {n_hc} HC + {n_hc} CRC = {2*n_hc} samples")
print(f"  - CRC samples used per iteration: {n_hc}/{n_crc} ({100*n_hc/n_crc:.1f}%)")
print(f"  - Total iterations: {N_ITERATIONS}")
print(f"  - Expected CRC coverage: Each sample appears ~{N_ITERATIONS*n_hc/n_crc:.0f} times")


# =============================================================================
# STEP 2: Balanced Subsampling Feature Selection (WITH FDR CORRECTION)
# =============================================================================
print("\n" + "=" * 80)
print("STEP 2: Execute Balanced Subsampling ({} iterations)".format(N_ITERATIONS))
print("        [WITH Benjamini-Hochberg FDR Correction per iteration]")
print("=" * 80)

# Initialize frequency counters for each method
freq_biological = defaultdict(int)
freq_lasso = defaultdict(int)
freq_svm_rfe = defaultdict(int)
freq_rf = defaultdict(int)

# Track iteration statistics
iteration_stats = []
valid_iterations = 0

# Convert expression data to numpy for faster indexing
expr_values = expr_log_full.values
probe_names = expr_log_full.columns.tolist()

np.random.seed(SEED)

print(f"\n[INFO] Starting {N_ITERATIONS} balanced subsampling iterations...")
print(f"[INFO] Progress will be shown every 100 iterations\n")

for iteration in range(N_ITERATIONS):

    # Progress indicator
    if (iteration + 1) % 100 == 0:
        avg_bio = np.mean([s['n_bio'] for s in iteration_stats if s.get('valid', False)]) if iteration_stats else 0
        print(f"  Iteration {iteration + 1}/{N_ITERATIONS} "
              f"(valid: {valid_iterations}, bio features avg: {avg_bio:.1f})")

    # -------------------------------------------------------------------------
    # STEP 2a: Balanced Random Subsampling (Ma et al. strategy)
    # -------------------------------------------------------------------------
    # Sample n_hc CRC samples to match HC count (NO replacement within iteration)
    crc_subset_idx = np.random.choice(crc_indices, size=n_hc, replace=False)

    # Combine with all HC samples
    subset_idx = np.concatenate([hc_indices, crc_subset_idx])

    # Extract balanced subset
    X_sub = expr_values[subset_idx, :]
    y_sub = labels_full[subset_idx]

    # Verify balance
    assert sum(y_sub == 0) == sum(y_sub == 1) == n_hc, "Imbalanced subset!"

    # -------------------------------------------------------------------------
    # STEP 2b: Biological Filtering WITH FDR CORRECTION (MODIFIED v2.0)
    # -------------------------------------------------------------------------

    # First pass: collect all statistics for FDR correction
    all_log2fc = []
    all_pvals = []

    for probe_idx in range(len(probe_names)):
        hc_vals = X_sub[y_sub == 0, probe_idx]
        crc_vals = X_sub[y_sub == 1, probe_idx]

        # Log2 fold change (data is already log2 transformed)
        mean_hc = np.mean(hc_vals)
        mean_crc = np.mean(crc_vals)
        log2fc = mean_crc - mean_hc

        # Mann-Whitney U test (non-parametric)
        try:
            _, pval = mannwhitneyu(hc_vals, crc_vals, alternative='two-sided')
        except:
            pval = 1.0

        all_log2fc.append(log2fc)
        all_pvals.append(pval)

    # ─────────────────────────────────────────────────────────────────────────
    # FDR CORRECTION (Benjamini-Hochberg) - CRITICAL ADDITION
    # ─────────────────────────────────────────────────────────────────────────
    all_qvals = benjamini_hochberg_correction(all_pvals)

    # Apply thresholds with FDR-corrected q-values
    bio_selected_idx = []
    for probe_idx, probe_name in enumerate(probe_names):
        log2fc = all_log2fc[probe_idx]
        qval = all_qvals[probe_idx]  # FDR-adjusted q-value

        # Criteria: |log2FC| > threshold AND FDR q-value < threshold
        if abs(log2fc) > BS_LOG2FC_THRESHOLD and qval < BS_FDR_THRESHOLD:
            bio_selected_idx.append(probe_idx)
            freq_biological[probe_name] += 1

    # Skip iteration if too few features pass filter
    if len(bio_selected_idx) < 10:
        iteration_stats.append({
            'iteration': iteration + 1,
            'n_bio': len(bio_selected_idx),
            'valid': False
        })
        continue

    valid_iterations += 1

    # Extract filtered features
    X_bio = X_sub[:, bio_selected_idx]
    bio_probe_names = [probe_names[i] for i in bio_selected_idx]

    # Standardize for ML methods
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_bio)

    # -------------------------------------------------------------------------
    # STEP 2c: Multi-Method Feature Selection
    # -------------------------------------------------------------------------

    # Method 1: LASSO
    try:
        lasso = LassoCV(cv=3, random_state=SEED, max_iter=5000, n_jobs=-1)
        lasso.fit(X_scaled, y_sub)
        lasso_selected = [bio_probe_names[i] for i, c in enumerate(lasso.coef_) if abs(c) > 1e-6]

        for feat in lasso_selected:
            freq_lasso[feat] += 1
    except:
        lasso_selected = []

    # Method 2: SVM-RFE (with adjusted n_features_to_select)
    try:
        # More stable feature count selection
        n_select = min(50, max(20, int(0.5 * len(bio_selected_idx))))
        svc = SVC(kernel='linear', class_weight='balanced', random_state=SEED)
        rfe = RFE(svc, n_features_to_select=n_select, step=0.2)
        rfe.fit(X_scaled, y_sub)
        svm_selected = [bio_probe_names[i] for i in np.where(rfe.support_)[0]]

        for feat in svm_selected:
            freq_svm_rfe[feat] += 1
    except:
        svm_selected = []

    # Method 3: Random Forest
    try:
        rf = RandomForestClassifier(
            n_estimators=100,
            random_state=SEED,
            n_jobs=-1,
            class_weight='balanced'
        )
        rf.fit(X_bio, y_sub)  # RF doesn't need scaling
        importances = rf.feature_importances_
        threshold = np.percentile(importances, 75)  # Top 25%
        rf_selected = [bio_probe_names[i] for i, imp in enumerate(importances) if imp >= threshold]

        for feat in rf_selected:
            freq_rf[feat] += 1
    except:
        rf_selected = []

    # Record iteration statistics
    iteration_stats.append({
        'iteration': iteration + 1,
        'n_bio': len(bio_selected_idx),
        'n_lasso': len(lasso_selected),
        'n_svm': len(svm_selected),
        'n_rf': len(rf_selected),
        'valid': True
    })

# Post-loop report: how many subsampling iterations produced a usable feature
# set (iterations with fewer than 10 filtered probes are skipped as invalid).
print("\n[INFO] ✓ Balanced subsampling complete")
print(f"  Total iterations: {N_ITERATIONS}")
valid_pct = 100 * valid_iterations / N_ITERATIONS
print(f"  Valid iterations: {valid_iterations} ({valid_pct:.1f}%)")


# =============================================================================
# STEP 3: Apply Double Threshold (Frequency + Consensus)
# =============================================================================
# A feature is "stable" when (a) its best per-method selection frequency
# reaches FREQ_THRESHOLD and (b) at least CONSENSUS_THRESHOLD of the three
# methods (LASSO, SVM-RFE, RF) individually reach that same frequency.
print("\n" + "=" * 80)
print("STEP 3: Apply Double Threshold Selection")
print("=" * 80)

# Smallest selection count that reaches FREQ_THRESHOLD. The product is rounded
# UP (minus a tiny tolerance for float noise such as 100 * 0.07 ==
# 7.000000000000001) so that "n >= min_count" agrees with
# "n / valid_iterations >= FREQ_THRESHOLD". Truncating with int() would let
# e.g. 2/5 = 40% count as passing a 50% threshold.
min_count = int(np.ceil(valid_iterations * FREQ_THRESHOLD - 1e-9))
print(f"\n[INFO] Frequency threshold: ≥{FREQ_THRESHOLD*100:.0f}% = ≥{min_count}/{valid_iterations} iterations")
if valid_iterations == 0:
    print("[WARNING] No valid iterations: all frequencies are 0 and no feature can be stable")

# Combine all features selected by any ML method in any valid iteration
all_features = set(freq_lasso.keys()) | set(freq_svm_rfe.keys()) | set(freq_rf.keys())
print(f"[INFO] Total unique features selected across all iterations: {len(all_features)}")


def _selection_freq(count):
    """Return `count` as a fraction of valid iterations (0 when none were valid)."""
    return count / valid_iterations if valid_iterations > 0 else 0


# Explicit column list so the frame keeps its schema even when no feature was
# selected (an empty list of dicts would otherwise give a column-less frame
# and the threshold filters below would raise KeyError).
feature_stats_columns = [
    'Probe_ID', 'freq_bio', 'freq_lasso', 'freq_svm', 'freq_rf',
    'n_lasso', 'n_svm', 'n_rf', 'methods_passed', 'overall_freq', 'max_count'
]

# Calculate selection statistics for each feature.
# Iterate in a fixed order (sets have run-dependent ordering for string keys)
# so that ties in the frequency ranking come out the same on every run.
feature_stats = []

for feat in sorted(all_features, key=str):
    n_bio = freq_biological.get(feat, 0)
    n_lasso = freq_lasso.get(feat, 0)
    n_svm = freq_svm_rfe.get(feat, 0)
    n_rf = freq_rf.get(feat, 0)

    # Count methods where the feature individually reached the frequency threshold
    methods_passed = sum(
        valid_iterations > 0 and n >= min_count
        for n in (n_lasso, n_svm, n_rf)
    )

    # Overall frequency (max across methods)
    max_freq = max(n_lasso, n_svm, n_rf)

    feature_stats.append({
        'Probe_ID': feat,
        'freq_bio': _selection_freq(n_bio),
        'freq_lasso': _selection_freq(n_lasso),
        'freq_svm': _selection_freq(n_svm),
        'freq_rf': _selection_freq(n_rf),
        'n_lasso': n_lasso,
        'n_svm': n_svm,
        'n_rf': n_rf,
        'methods_passed': methods_passed,
        'overall_freq': _selection_freq(max_freq),
        'max_count': max_freq
    })

df_feature_stats = pd.DataFrame(feature_stats, columns=feature_stats_columns)

# Apply double threshold: frequency AND consensus.
# A stable sort keeps the deterministic input order among tied frequencies.
df_stable = df_feature_stats[
    (df_feature_stats['overall_freq'] >= FREQ_THRESHOLD) &
    (df_feature_stats['methods_passed'] >= CONSENSUS_THRESHOLD)
].sort_values('overall_freq', ascending=False, kind='mergesort').copy()

# Also get features passing only frequency threshold (for comparison)
df_freq_only = df_feature_stats[
    df_feature_stats['overall_freq'] >= FREQ_THRESHOLD
].sort_values('overall_freq', ascending=False, kind='mergesort').copy()

print("\n[SELECTION RESULTS]")
print(f"  Features passing frequency threshold (≥{FREQ_THRESHOLD*100:.0f}%): {len(df_freq_only)}")
print(f"  Features passing consensus (≥{CONSENSUS_THRESHOLD}/3 methods): "
      f"{len(df_feature_stats[df_feature_stats['methods_passed'] >= CONSENSUS_THRESHOLD])}")
print(f"  Features passing BOTH (final stable): {len(df_stable)}")

# Normalize probe IDs to strings for consistency
balanced_stable_features = [normalize_probe_id(p) for p in df_stable['Probe_ID'].tolist()]

print(f"\n[INFO] ✓ Balanced Subsampling selected {len(balanced_stable_features)} stable features")


# =============================================================================
# STEP 4: Compare with Cell 2 Results (SMOTE-based)
# =============================================================================
# Compares the balanced-subsampling stable set with the features that the
# Cell 2 model selected in most of its CV iterations.
print("\n" + "=" * 80)
print("STEP 4: Compare with Cell 2 (SMOTE-based) Results")
print("=" * 80)

# Per-feature selection counts recorded by Cell 2 for the best model
cell2_feature_freq = results[best_name]["feature_stability"]["feature_frequency"]

# Minimum selection counts for Cell 2 stability (80% / 50% of total iterations)
stability_threshold_80 = int(0.8 * N_OUTER_SPLITS_REF * N_REPEATS_REF)
stability_threshold_50 = int(0.5 * N_OUTER_SPLITS_REF * N_REPEATS_REF)


def _cell2_features_at(min_selections):
    """Return normalized probe IDs that Cell 2 selected at least `min_selections` times."""
    return [
        normalize_probe_id(probe)
        for probe, count in cell2_feature_freq.items()
        if count >= min_selections
    ]


cell2_stable_features = _cell2_features_at(stability_threshold_80)

# Fallback to 50% threshold if too few features
if len(cell2_stable_features) < 5:
    print(f"[WARNING] Only {len(cell2_stable_features)} features at 80% threshold")
    print(f"[WARNING] Using 50% threshold instead ({stability_threshold_50} iterations)")
    cell2_stable_features = _cell2_features_at(stability_threshold_50)

print("\n[COMPARISON]")
print(f"  Cell 2 (SMOTE-based) stable features: {len(cell2_stable_features)}")
print(f"  Cell 2-2 (Balanced Subsampling) stable features: {len(balanced_stable_features)}")

# Calculate overlap with normalized IDs
cell2_set = set(cell2_stable_features)
bs_set = set(balanced_stable_features)

overlap = cell2_set & bs_set
only_cell2 = cell2_set - bs_set
only_bs = bs_set - cell2_set

print(f"\n  Overlap (both methods): {len(overlap)} features")
print(f"  Only in Cell 2 (SMOTE): {len(only_cell2)} features")
print(f"  Only in Cell 2-2 (Balanced): {len(only_bs)} features")

# Jaccard similarity is only reported when both sets are non-empty (the union
# is then necessarily non-empty, so no separate zero-division check is needed).
jaccard = None
if cell2_set and bs_set:
    jaccard = len(overlap) / len(cell2_set | bs_set)
    print(f"  Jaccard similarity: {jaccard:.3f}")
else:
    print("  Jaccard similarity: N/A (empty set)")

# High-confidence consensus features (selected by BOTH methods).
# Sorted so the list (which is later written to JSON) has the same order on
# every run; converting a set of strings with list() is order-unstable.
consensus_features = sorted(overlap, key=str)
print("\n[HIGH-CONFIDENCE CONSENSUS FEATURES]")
print(f"  Total: {len(consensus_features)} features")
print("  These features are robust to both SMOTE and balanced subsampling")


# =============================================================================
# STEP 5: Display Top Stable Features with miRNA Names
# =============================================================================
print("\n" + "=" * 80)
print("STEP 5: Top Stable Features from Balanced Subsampling")
print("=" * 80)

# Reuse an existing probe -> miRNA lookup if one is already defined in the
# notebook namespace; otherwise build it from mapping_df (keys are normalized
# probe IDs), or fall back to an empty lookup when no mapping is available.
if 'probe_to_mirna_dict' not in globals():
    if 'mapping_df' not in globals() or mapping_df is None:
        probe_to_mirna_dict = {}
    else:
        # Prefer the 'miRNA' column; otherwise use 'SPOT_ID'
        if 'miRNA' in mapping_df.columns:
            mirna_col = 'miRNA'
        else:
            mirna_col = 'SPOT_ID'
        probe_to_mirna_dict = dict(
            zip(mapping_df["Probe_ID"].apply(normalize_probe_id), mapping_df[mirna_col])
        )

def get_mirna_name_safe(probe_id):
    """Return the miRNA name mapped to a probe, or an ``Unmapped_<id>`` label.

    The probe ID is normalized with ``normalize_probe_id`` before it is looked
    up in ``probe_to_mirna_dict``. Missing, NaN, or blank entries yield
    ``"Unmapped_<normalized id>"``; otherwise the stripped name is returned.
    """
    key = normalize_probe_id(probe_id)
    name = probe_to_mirna_dict.get(key)

    if name is not None and not pd.isna(name):
        label = str(name).strip()
        if label != "":
            return label
    return f"Unmapped_{key}"

# Annotate the stable features (on a copy): normalized probe ID, miRNA name,
# and whether the probe is also in the cross-method consensus list.
df_stable = df_stable.assign(
    Probe_ID_str=lambda d: d['Probe_ID'].apply(normalize_probe_id),
    miRNA_Name=lambda d: d['Probe_ID_str'].apply(get_mirna_name_safe),
    In_Consensus=lambda d: d['Probe_ID_str'].isin(consensus_features),
)

# Table of the 20 highest-ranked stable features
print(f"\n{'Rank':<6} {'Probe ID':<12} {'miRNA Name':<25} {'Freq':<8} {'Methods':<10} {'Consensus'}")
print("-" * 85)

top_stable = df_stable.head(20)
for rank, (_, row) in enumerate(top_stable.iterrows(), start=1):
    # Star marks features that are also stable under the SMOTE-based analysis
    if row['In_Consensus']:
        consensus_mark = "★"
    else:
        consensus_mark = ""
    print(f"{rank:<6} {row['Probe_ID']:<12} {row['miRNA_Name']:<25} "
          f"{row['overall_freq']:.1%}   {row['methods_passed']}/3       {consensus_mark}")

# Separate listing of the consensus subset
print(f"\n[CONSENSUS FEATURES (Both SMOTE & Balanced Subsampling)]")
print("-" * 85)

consensus_df = df_stable[df_stable['In_Consensus']].copy()
if consensus_df.empty:
    print("  No overlapping features found between methods.")
    print("  This suggests method-dependent feature selection.")
else:
    for _, row in consensus_df.iterrows():
        print(f"  ★ {row['Probe_ID']:<12} {row['miRNA_Name']:<25} freq={row['overall_freq']:.1%}")


# =============================================================================
# STEP 6: Save Results
# =============================================================================
# Writes three CSVs (all feature statistics, stable features, per-iteration
# statistics) and one JSON summary of the Cell 2 vs Cell 2-2 comparison.
print("\n" + "=" * 80)
print("STEP 6: Save Balanced Subsampling Results")
print("=" * 80)

# Save full feature statistics
stats_path = os.path.join(RESULT_DIR, "balanced_subsampling_feature_stats.csv")
df_feature_stats.to_csv(stats_path, index=False)
print(f"[INFO] All feature statistics saved: {stats_path}")

# Save stable features
stable_path = os.path.join(RESULT_DIR, "balanced_subsampling_stable_features.csv")
df_stable.to_csv(stable_path, index=False)
print(f"[INFO] Stable features saved: {stable_path}")

# Work out which Cell 2 threshold actually produced cell2_stable_features.
# The 50% fallback is used when fewer than 5 features reach the 80% threshold,
# so the decision must be based on the count AT 80%. The length of the final
# list is not a valid proxy: after falling back it can easily be >= 5, which
# would wrongly report the 80% threshold.
n_cell2_at_80 = sum(1 for count in cell2_feature_freq.values() if count >= stability_threshold_80)
cell2_threshold_used = stability_threshold_80 if n_cell2_at_80 >= 5 else stability_threshold_50

# Save comparison results
comparison_results = {
    'methodology': {
        'cell2': 'SMOTE-based oversampling with nested CV',
        'cell2_2': 'Balanced subsampling (Ma et al. 2021)',
        'fdr_correction': 'Benjamini-Hochberg per iteration',
        'primary_analysis': 'Cell 2-2 (Balanced Subsampling)',
        'sensitivity_analysis': 'Cell 2 (SMOTE)'
    },
    'cell2_smote': {
        'n_stable': len(cell2_stable_features),
        'features': cell2_stable_features,
        'stability_threshold': f'{cell2_threshold_used} iterations'
    },
    'cell2_2_balanced': {
        'n_stable': len(balanced_stable_features),
        'features': balanced_stable_features,
        'n_iterations': N_ITERATIONS,
        'valid_iterations': valid_iterations,
        'freq_threshold': FREQ_THRESHOLD,
        'consensus_threshold': CONSENSUS_THRESHOLD,
        'fdr_threshold': BS_FDR_THRESHOLD,
        'log2fc_threshold': BS_LOG2FC_THRESHOLD
    },
    'consensus': {
        'n_features': len(consensus_features),
        'features': consensus_features,
        # None is serialized as JSON null when either feature set was empty
        'jaccard_similarity': float(jaccard) if jaccard is not None else None,
        'interpretation': 'Features robust to both synthetic and real-sample approaches'
    }
}

comparison_path = os.path.join(RESULT_DIR, "feature_selection_comparison.json")
with open(comparison_path, 'w', encoding='utf-8') as f:
    json.dump(comparison_results, f, indent=2, ensure_ascii=False)
print(f"[INFO] Comparison results saved: {comparison_path}")

# Save iteration statistics
iter_stats_df = pd.DataFrame(iteration_stats)
iter_stats_path = os.path.join(RESULT_DIR, "balanced_subsampling_iteration_stats.csv")
iter_stats_df.to_csv(iter_stats_path, index=False)
print(f"[INFO] Iteration statistics saved: {iter_stats_path}")


# =============================================================================
# STEP 7: Prepare Variables for Cell 3 (CLEAR DEFINITIONS)
# =============================================================================
print("\n" + "=" * 80)
print("STEP 7: Prepare Variables for Cell 3")
print("=" * 80)

# ─────────────────────────────────────────────────────────────────────────────
# CLEAR CONSENSUS DEFINITIONS
# ─────────────────────────────────────────────────────────────────────────────
# The within-method rule is filled in from FREQ_THRESHOLD / CONSENSUS_THRESHOLD
# (and uses "≥", matching the selection filter) so the banner cannot drift from
# the configuration. With the usual 2-digit percentage and 1-digit method count
# the box stays aligned.
print("""
╔═══════════════════════════════════════════════════════════════════════════════╗
║                        CONSENSUS DEFINITIONS                                   ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║                                                                               ║
║  1. WITHIN-METHOD CONSENSUS (Cell 2-2 internal):                              ║
║     → Freq ≥ {freq_pct:>2.0f}% AND ≥{n_methods}/3 methods (LASSO, SVM-RFE, RF)                        ║
║     → Result: balanced_stable_features ({n_stable:>3} features)                       ║
║                                                                               ║
║  2. CROSS-METHOD CONSENSUS (Cell 2 vs Cell 2-2):                              ║
║     → Selected by BOTH SMOTE (Cell 2) AND Balanced Subsampling (Cell 2-2)     ║
║     → Result: consensus_features ({n_consensus:>3} features)                             ║
║                                                                               ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║  RECOMMENDED FOR PUBLICATION:                                                 ║
║                                                                               ║
║  • PRIMARY ANALYSIS:     Cell 2-2 (Balanced Subsampling)                      ║
║                          - No synthetic data, Ma et al. validated             ║
║                          - Use: balanced_stable_features                      ║
║                                                                               ║
║  • SENSITIVITY ANALYSIS: Cell 2 (SMOTE)                                       ║
║                          - Report in Supplementary Materials                  ║
║                          - Shows robustness to methodological choices         ║
║                                                                               ║
║  • EXPERIMENTAL VALIDATION: Cross-method consensus                            ║
║                          - Highest confidence for organoid experiments        ║
║                          - Use: consensus_features                            ║
╚═══════════════════════════════════════════════════════════════════════════════╝
""".format(
    freq_pct=FREQ_THRESHOLD * 100,
    n_methods=CONSENSUS_THRESHOLD,
    n_stable=len(balanced_stable_features),
    n_consensus=len(consensus_features),
))

# Define which features to use downstream
# PRIMARY: Balanced subsampling (for main analysis)
primary_features = balanced_stable_features.copy()

# VALIDATION: Cross-method consensus (for organoid experiments).
# With fewer than 3 consensus features the balanced-subsampling list is used
# instead; say so explicitly, because the summary below only prints a count.
if len(consensus_features) >= 3:
    validation_features = consensus_features.copy()
else:
    print(f"[INFO] Only {len(consensus_features)} cross-method consensus feature(s) (<3): "
          "validation_features falls back to balanced_stable_features")
    validation_features = balanced_stable_features.copy()

# Legacy variable for Cell 3 compatibility (same list object as primary_features)
bs_prioritized_features = primary_features

print("[INFO] Variables created for Cell 3:")
print(f"  • primary_features:       {len(primary_features)} features (main analysis)")
print(f"  • validation_features:    {len(validation_features)} features (organoid validation)")
print(f"  • bs_prioritized_features: {len(bs_prioritized_features)} features (legacy compatibility)")
print(f"  • consensus_features:     {len(consensus_features)} features (cross-method)")


# =============================================================================
# FINAL SUMMARY
# =============================================================================
# Recap of the balanced-subsampling run plus text intended for the manuscript.
print("\n" + "=" * 80)
print("✓ BALANCED SUBSAMPLING COMPLETE (v2.0 with FDR Correction)")
print("=" * 80)

print(f"""
[SUMMARY]
  Method: Ma et al. (2021) Balanced Subsampling
  Iterations: {N_ITERATIONS} ({valid_iterations} valid)
  Sample size per iteration: {2 * n_hc} ({n_hc} HC + {n_hc} CRC)
  FDR correction: Benjamini-Hochberg (per iteration)  ← NEW

[RESULTS]
  Balanced Subsampling stable features: {len(balanced_stable_features)}
  Cell 2 (SMOTE) stable features: {len(cell2_stable_features)}
  Cross-method consensus: {len(consensus_features)}
  Jaccard similarity: {f'{jaccard:.3f}' if jaccard is not None else 'N/A'}

[INTERPRETATION]
""")

# Interpretation tiers by number of cross-method consensus features
if len(consensus_features) >= 5:
    print(f"  ✓ GOOD: {len(consensus_features)} features are robust to both methods")
    print("    → These are HIGH-CONFIDENCE biomarker candidates")
    print("    → Recommended for organoid validation")
elif len(consensus_features) >= 1:
    print(f"  ○ MODERATE: Only {len(consensus_features)} features overlap")
    print("    → Feature selection is somewhat method-dependent")
    print("    → Consider using balanced subsampling results for novelty")
else:
    print("  △ DIVERGENT: No overlapping features between methods")
    print("    → SMOTE may be introducing artifacts")
    print("    → Recommend using balanced subsampling results")
    print("    → Be transparent about this in manuscript")

# The stability sentence states the rule the selection actually applies:
# frequency is compared with ">=" and is computed over VALID iterations
# (iterations skipped for having too few filtered probes are excluded).
print(f"""
[FILES GENERATED]
  - balanced_subsampling_feature_stats.csv
  - balanced_subsampling_stable_features.csv
  - balanced_subsampling_iteration_stats.csv
  - feature_selection_comparison.json

[MANUSCRIPT NOTES]
  Methods section should include:
  - "FDR correction (Benjamini-Hochberg) was applied within each
    balanced subsampling iteration to control for multiple testing
    across {len(probe_names):,} probes."
  - "Features were considered stable if selected in ≥{FREQ_THRESHOLD*100:.0f}% of
    {valid_iterations} valid iterations (out of {N_ITERATIONS} run) by at least {CONSENSUS_THRESHOLD} of 3 methods."

[NEXT STEPS]
  1. Compare primary_features with Ma et al. (2021) hub miRNAs
  2. Proceed to Cell 3 for miRNA → mRNA target prediction
  3. Use validation_features for organoid validation priority
""")

print("=" * 80)

In [None]:
# =============================================================================
# Cell 3: SHAP Analysis and Biomarker Discovery (v3.0 Compatible)
# =============================================================================
"""
VERSION 3.0 MODIFICATIONS:

1. Stability Threshold: 80% → 70% (matches Cell 2 v3.0)
2. Adaptive threshold detection improved
3. Dual-path awareness (uses CV-stable features from Cell 2)
4. Better handling of repeated CV structure

Author: Jungho Sohn
Date: 2025-12-28
Version: 3.0
"""

import os
import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings('ignore')

print("\n" + "=" * 80)
print("Cell 3: SHAP ANALYSIS & BIOMARKER DISCOVERY (v3.0)")
print("=" * 80)

# =============================================================================
# STEP 0: Verify Required Variables
# =============================================================================
# Fail fast with a clear message if earlier cells have not been run.
print("\n[STEP 0] Verify required variables from Cell 2...")

required_vars = ['results', 'best_name', 'oof_predictions', 'feature_cols',
                 'df_expression', 'mapping_df', 'base_save_path', 'SEED']
# Check the notebook (module) namespace explicitly. A bare dir() inside a list
# comprehension inspects the comprehension's own scope on Python < 3.12, where
# it would report every variable as missing.
missing = [v for v in required_vars if v not in globals()]

if missing:
    raise NameError(f"Missing variables: {missing}\nPlease run Cell 0-2 first!")

print("[OK] All required variables verified")
print(f"[INFO] Best model: {best_name}")
print(f"[INFO] Number of final features: {len(feature_cols)}")

# =============================================================================
# STEP 1: Detect CV structure and derive stability thresholds
# The "highly stable" fraction is read from the Cell 2 results and defaults
# to 0.70 when the results do not provide one.
# =============================================================================
print("\n" + "=" * 80)
print("STEP 1: DETECT CV STRUCTURE AND ADAPT THRESHOLDS")
print("=" * 80)

# Get feature frequency from results
feature_freq_raw = results[best_name]['feature_stability']['feature_frequency']


def _coerce_probe_key(key):
    """Convert a numeric string key to int; return any other key unchanged."""
    if isinstance(key, str):
        try:
            return int(key)
        except ValueError:
            return key
    return key


# Normalize keys: numeric strings become ints, everything else is kept as-is.
# Done per key so that an empty dict or a non-numeric probe ID does not crash.
feature_freq = {_coerce_probe_key(k): v for k, v in feature_freq_raw.items()}
n_converted = sum(
    1 for k in feature_freq_raw
    if isinstance(k, str) and not isinstance(_coerce_probe_key(k), str)
)
if n_converted:
    print(f"[INFO] Converted {n_converted} feature keys from str to int")
if not feature_freq:
    print("[WARNING] Feature frequency dictionary is empty")

# Detect CV structure by checking n_iterations or max frequency
n_iterations = results[best_name]['feature_stability'].get('n_iterations', None)
stability_threshold = results[best_name]['feature_stability'].get('stability_threshold', 0.70)

if n_iterations is None:
    # Fallback heuristic: a feature cannot be selected more often than there
    # are iterations, so a max count <= 5 is taken to mean plain 5-fold CV.
    max_freq = max(feature_freq.values()) if feature_freq else 0
    if max_freq <= 5:
        n_iterations = 5
        print(f"[INFO] Detected 5-fold CV structure (max_freq={max_freq})")
    else:
        n_iterations = 50  # Assume repeated CV
        print(f"[INFO] Detected Repeated CV structure (max_freq={max_freq}, assuming 50 iterations)")
else:
    print(f"[INFO] CV iterations from results: {n_iterations}")

# Minimum selection counts per stability tier. Counts are rounded UP (minus a
# tiny tolerance for float noise) so a tier labelled ">=70%" never admits a
# feature below 70%; truncation would accept 3/5 = 60% for a 0.70 threshold.
THRESHOLD_CORE = n_iterations                                              # 100% selection
THRESHOLD_HIGH = int(np.ceil(n_iterations * stability_threshold - 1e-9))  # ≥ stability_threshold
THRESHOLD_MODERATE = int(np.ceil(n_iterations * 0.5 - 1e-9))              # ≥50% selection

print(f"\n[INFO] Adaptive Stability Thresholds (based on {n_iterations} iterations):")
print(f"  Core stable:        {THRESHOLD_CORE}/{n_iterations} (100%)")
print(f"  Highly stable:      ≥{THRESHOLD_HIGH}/{n_iterations} (≥{stability_threshold*100:.0f}%)")
print(f"  Moderately stable:  ≥{THRESHOLD_MODERATE}/{n_iterations} (≥50%)")

# =============================================================================
# STEP 2: Extract Stable miRNA Biomarkers from ML Model
# =============================================================================
print("\n" + "=" * 80)
print("STEP 2: EXTRACT STABLE miRNA BIOMARKERS")
print("=" * 80)

# Categorize by stability with ADAPTIVE thresholds (tiers are mutually exclusive)
core_stable = [pid for pid, freq in feature_freq.items() if freq >= THRESHOLD_CORE]
highly_stable = [pid for pid, freq in feature_freq.items()
                 if THRESHOLD_HIGH <= freq < THRESHOLD_CORE]
moderately_stable = [pid for pid, freq in feature_freq.items()
                     if THRESHOLD_MODERATE <= freq < THRESHOLD_HIGH]

print(f"\n[INFO] Feature Stability Summary ({best_name}):")
print(f"  Core stable (100%, {THRESHOLD_CORE}/{n_iterations}):           {len(core_stable)} features")
print(f"  Highly stable (≥{stability_threshold*100:.0f}%, ≥{THRESHOLD_HIGH}/{n_iterations}):   {len(highly_stable)} features")
print(f"  Moderately stable (≥50%, ≥{THRESHOLD_MODERATE}/{n_iterations}): {len(moderately_stable)} features")

# Fallback when nothing reaches the "highly stable" tier: show the best
# candidates and promote the most frequently selected ones instead.
if len(core_stable) == 0 and len(highly_stable) == 0:
    print(f"\n[WARNING] No features meet ≥{stability_threshold*100:.0f}% threshold!")
    print("[INFO] Top 10 features by selection frequency:")
    # Full ranking by selection count (descending). Only the first 10 are
    # printed, but the selections below need the complete ranking.
    ranked_features = sorted(feature_freq.items(), key=lambda x: x[1], reverse=True)
    sorted_features = ranked_features[:10]
    for pid, freq in sorted_features:
        pct = freq / n_iterations * 100
        print(f"  Probe {pid}: {freq}/{n_iterations} ({pct:.1f}%)")

    # Features that meet the minimum (≥50%) threshold, most frequent first
    min_threshold_features = [pid for pid, freq in ranked_features
                              if freq >= THRESHOLD_MODERATE]

    if len(min_threshold_features) > 0:
        # Keep at most the 20 most frequently selected of them
        highly_stable = min_threshold_features[:20]
        print(f"\n[INFO] Using {len(highly_stable)} of {len(min_threshold_features)} "
              f"features with ≥50% selection rate")
    else:
        # Fallback: use top N features by frequency
        top_n = min(20, len(ranked_features))
        highly_stable = [pid for pid, freq in ranked_features[:top_n]]
        print(f"\n[INFO] Fallback: Using top {top_n} features by frequency")

# Combine core + highly stable for comprehensive analysis.
# dict.fromkeys de-duplicates while preserving order, so the list is the same
# on every run (set() ordering is not guaranteed for string IDs).
prioritized_features = list(dict.fromkeys(core_stable + highly_stable))
print(f"\n[INFO] Total prioritized features for target prediction: {len(prioritized_features)}")

# =============================================================================
# STEP 3: Check for CV-Stable Features from Cell 2 v3.0
# =============================================================================
print("\n" + "=" * 80)
print("STEP 3: CHECK DUAL-PATH FEATURE LISTS (Cell 2 v3.0)")
print("=" * 80)

# Paths of the per-model feature lists looked up under base_save_path
cv_stable_path = os.path.join(base_save_path, f"cv_stable_biomarkers_{best_name}.csv")
production_path = os.path.join(base_save_path, f"production_model_features_{best_name}.csv")
overlap_path = os.path.join(base_save_path, f"overlap_robust_biomarkers_{best_name}.csv")

# When the CV-stable list exists on disk it replaces the prioritized features
# derived from the frequency analysis above.
if not os.path.exists(cv_stable_path):
    print(f"[INFO] CV-stable biomarkers file not found (Cell 2 v3.0 format)")
    print(f"[INFO] Using features from stability analysis above")
else:
    cv_stable_df = pd.read_csv(cv_stable_path)
    print(f"[OK] Found CV-stable biomarkers from Cell 2 v3.0: {len(cv_stable_df)} features")
    print(f"     File: {cv_stable_path}")

    # Use these as the primary biomarker list; all-digit IDs become ints
    cv_stable_probes = cv_stable_df['Probe_ID'].tolist()
    prioritized_features = []
    for p in cv_stable_probes:
        if str(p).isdigit():
            prioritized_features.append(int(p))
        else:
            prioritized_features.append(p)
    print(f"[INFO] Using Cell 2 v3.0 CV-stable features as primary biomarker list")

# The overlap list is only loaded and counted here
if not os.path.exists(overlap_path):
    print(f"[INFO] Overlap biomarkers file not found")
else:
    overlap_df = pd.read_csv(overlap_path)
    print(f"[OK] Found overlap (most robust) biomarkers: {len(overlap_df)} features")

# =============================================================================
# STEP 4: Verification - Check if probe IDs exist in data
# =============================================================================
# Diagnostic only: reports whether the first prioritized probe can be found in
# the expression matrix columns and in the probe mapping table.
print("\n[VERIFICATION] Checking probe ID existence:")

if len(prioritized_features) == 0:
    print("[ERROR] No prioritized features available!")
    print("[INFO] Checking feature_freq content:")
    print(f"  Total features in freq dict: {len(feature_freq)}")
    print(f"  Max frequency: {max(feature_freq.values()) if feature_freq else 0}")
    raise RuntimeError("No features to analyze. Check Cell 2 results.")

sample_probe = prioritized_features[0]
print(f"  Sample probe: {sample_probe} (type: {type(sample_probe)})")


def _probe_id_forms(probe):
    """Return the probe ID as given, as str, and (when convertible) as int."""
    forms = [probe, str(probe)]
    try:
        forms.append(int(probe))
    except (TypeError, ValueError):
        # Non-numeric IDs simply have no integer form
        pass
    return forms


sample_probe_forms = _probe_id_forms(sample_probe)

# Check in df_expression columns, tolerating an int/str type mismatch
df_cols = df_expression.columns.tolist()
in_expression = any(form in df_cols for form in sample_probe_forms)
print(f"  In df_expression.columns: {in_expression}")

# Check in mapping_df
if mapping_df is not None:
    mapping_probes = mapping_df['Probe_ID'].tolist()
    in_mapping = any(form in mapping_probes for form in sample_probe_forms)
    print(f"  In mapping_df Probe_ID: {in_mapping}")
else:
    print("  mapping_df not available")

print("[INFO] ✓ Probe ID verification complete")

# =============================================================================
# STEP 5: Map Probe IDs to miRNA Names
# =============================================================================
print("\n" + "=" * 80)
print("STEP 5: PROBE ID → miRNA NAME MAPPING")
print("=" * 80)

if mapping_df is not None:
    print(f"\n[INFO] mapping_df structure:")
    print(f"  Shape: {mapping_df.shape}")
    print(f"  Columns: {mapping_df.columns.tolist()}")

    # First candidate column present in mapping_df wins; None if none match
    mirna_col_candidates = ['miRNA', 'miRNA_ID', 'miRNA_name', 'NAME']
    mirna_col = next(
        (candidate for candidate in mirna_col_candidates if candidate in mapping_df.columns),
        None,
    )

    if mirna_col is not None:
        print(f"[INFO] Using miRNA column: {mirna_col}")
    else:
        mirna_col = mapping_df.columns[1]  # Fallback to second column
        print(f"[WARNING] Using fallback miRNA column: {mirna_col}")

    # Probe-to-miRNA lookup keyed by the raw Probe_ID values
    probe_to_mirna_dict = dict(zip(mapping_df["Probe_ID"], mapping_df[mirna_col]))
    print(f"\n[INFO] Probe-to-miRNA dictionary created:")
    print(f"  Total probes: {len(probe_to_mirna_dict)}")
else:
    # No mapping table: every lookup will fall through to the unmapped label
    print("[WARNING] mapping_df not available! Using probe IDs as identifiers.")
    probe_to_mirna_dict = {}


def get_mirna_name(probe_id):
    """Safely retrieve the miRNA name for a probe ID.

    The module-level ``probe_to_mirna_dict`` is queried with the ID as given and,
    if that misses, with the alternate spelling (``int`` for a string ID, ``str``
    for anything else). A string ID that is not numeric simply has no integer
    spelling; it no longer raises ``ValueError``.

    Parameters
    ----------
    probe_id : int or str
        Probe identifier to look up.

    Returns
    -------
    str
        The stripped miRNA name, or ``"Unmapped_Probe_<probe_id>"`` when the
        probe is absent or its mapped value is None, NaN, or blank.
    """
    mirna = probe_to_mirna_dict.get(probe_id, None)

    if mirna is None:
        if isinstance(probe_id, str):
            try:
                alternate_key = int(probe_id)
            except ValueError:
                alternate_key = None  # non-numeric ID: nothing else to try
        else:
            alternate_key = str(probe_id)

        if alternate_key is not None:
            mirna = probe_to_mirna_dict.get(alternate_key, None)

    if mirna is None or pd.isna(mirna) or str(mirna).strip() == "":
        return f"Unmapped_Probe_{probe_id}"

    return str(mirna).strip()


# Create prioritized miRNA biomarker list
# One record per prioritized probe: mapped name, selection frequency, and the
# stability tier derived from THRESHOLD_CORE / THRESHOLD_HIGH.
mirna_biomarker_data = []

for probe_id in prioritized_features:
    mirna_name = get_mirna_name(probe_id)

    # Look the frequency up under the ID as given, then under its integer form.
    # int() is attempted only on a miss and is guarded, so a non-numeric probe ID
    # no longer raises ValueError (previously it was evaluated eagerly as the
    # default argument, even when the ID was present).
    freq = feature_freq.get(probe_id)
    if freq is None:
        try:
            freq = feature_freq.get(int(probe_id), 0)
        except (TypeError, ValueError):
            freq = 0  # no integer spelling and no recorded frequency
    pct = freq / n_iterations * 100

    if freq >= THRESHOLD_CORE:
        stability_cat = f"Core ({freq}/{n_iterations}, {pct:.0f}%)"
        priority = 1
    elif freq >= THRESHOLD_HIGH:
        stability_cat = f"Highly Stable ({freq}/{n_iterations}, {pct:.0f}%)"
        priority = 2
    else:
        stability_cat = f"Stable ({freq}/{n_iterations}, {pct:.0f}%)"
        priority = 3

    mirna_biomarker_data.append({
        'Probe_ID': probe_id,
        'miRNA_Name': mirna_name,
        'Selection_Frequency': freq,
        'Selection_Percentage': pct,
        'Stability_Category': stability_cat,
        'Priority_Rank': priority
    })

# Explicit columns keep the schema intact when no feature was prioritized;
# without them an empty list yields a column-less frame and sort_values fails.
df_mirna_biomarkers = pd.DataFrame(
    mirna_biomarker_data,
    columns=[
        'Probe_ID',
        'miRNA_Name',
        'Selection_Frequency',
        'Selection_Percentage',
        'Stability_Category',
        'Priority_Rank',
    ],
)
df_mirna_biomarkers = df_mirna_biomarkers.sort_values(
    ['Priority_Rank', 'Selection_Frequency'],
    ascending=[True, False]
)

# Display results
print("\n[INFO] Prioritized miRNA Biomarkers:")
print("-" * 80)
print(f"{'Probe ID':<12} {'miRNA Name':<30} {'Freq':<8} {'%':<8} {'Category'}")
print("-" * 80)

for _, row in df_mirna_biomarkers.head(20).iterrows():
    print(f"{row['Probe_ID']:<12} {row['miRNA_Name']:<30} "
          f"{row['Selection_Frequency']:<8} {row['Selection_Percentage']:<8.1f} "
          f"{row['Stability_Category']}")

if len(df_mirna_biomarkers) > 20:
    print(f"... and {len(df_mirna_biomarkers) - 20} more features")

# Save biomarker list
biomarker_path = os.path.join(base_save_path, 'mirna_biomarkers_prioritized.csv')
df_mirna_biomarkers.to_csv(biomarker_path, index=False)
print(f"\n[INFO] Biomarker list saved: {biomarker_path}")

# =============================================================================
# STEP 6: Summary Statistics
# =============================================================================
print("\n" + "=" * 80)
print("STEP 6: BIOMARKER DISCOVERY SUMMARY")
print("=" * 80)

# Names produced by get_mirna_name start with 'Unmapped' when no miRNA was found
n_unmapped = sum(
    name.startswith('Unmapped') for name in df_mirna_biomarkers['miRNA_Name']
)
n_mapped = len(df_mirna_biomarkers) - n_unmapped

# Kuncheva index is optional in the results entry; 'N/A' is shown when absent
kuncheva = results[best_name].get('kuncheva_stability_index', 'N/A')
if isinstance(kuncheva, float):
    kuncheva_str = f"{kuncheva:.4f}"
else:
    kuncheva_str = str(kuncheva)

print(f"""
[SUMMARY]
  CV Structure:           {n_iterations} iterations (5-fold × {n_iterations//5} repeats)
  Stability Threshold:    {stability_threshold*100:.0f}% (per Cell 2 v3.0)
  Kuncheva Index:         {kuncheva_str}

  Total prioritized:      {len(prioritized_features)} features
  Successfully mapped:    {n_mapped} miRNAs
  Unmapped probes:        {n_unmapped}

  Stability Distribution:
    Core (100%):              {len(core_stable)}
    Highly stable (≥{stability_threshold*100:.0f}%):   {len(highly_stable)}
    Moderately (≥50%):        {len(moderately_stable)}

[READY FOR]
  - Target mRNA prediction (TargetScan, miRDB)
  - Pathway enrichment analysis
  - Organoid RNA-seq validation matching
""")

# Plain-list copies of the sorted table, kept for downstream cells
prioritized_probes = df_mirna_biomarkers['Probe_ID'].tolist()
prioritized_mirnas = df_mirna_biomarkers['miRNA_Name'].tolist()

print("=" * 80)
print("✓ BIOMARKER EXTRACTION COMPLETE (v3.0)")
print("=" * 80)

# ==============================================================================
# STEP 6: Summary Statistics for Methods Section
# ==============================================================================
# Generate statistics needed for manuscript Methods and Results sections.

print("\n" + "-" * 80)
print("STEP 6: Summary for Manuscript")
print("-" * 80)

# -----------------------------------------------------------------------------
# HOTFIX: Variable name alignment
# Original code used 'df_importance' which was never defined
# Solution: Use df_mirna_biomarkers and map column names appropriately
# -----------------------------------------------------------------------------
df_importance = df_mirna_biomarkers.copy()

# Ensure 'Rank' column exists (mapped from 'SHAP_Rank'); 999 is the sentinel
# for "no SHAP rank available".
if 'SHAP_Rank' in df_importance.columns:
    df_importance['Rank'] = df_importance['SHAP_Rank'].fillna(999).astype(int)
else:
    df_importance['Rank'] = 999  # Default if SHAP analysis failed

# Count features by stability tier among the first 20 rows.
# The tiers use the same THRESHOLD_CORE / THRESHOLD_HIGH cut-offs that built
# 'Stability_Category'. The previous hard-coded `== 5` / `== 4` comparisons were
# only right for exactly 5 CV iterations and miscounted with repeated CV.
# NOTE(review): the "(5/5)" / "(4/5)" / "SHAP" labels printed below are left
# unchanged; they describe the tiers only when n_iterations == 5, and the
# "top 20" are the first rows of df_importance in its current sort order.
top20 = df_importance.head(20)
n_core_in_top20 = int((top20["Selection_Frequency"] >= THRESHOLD_CORE).sum())
n_high_in_top20 = int(
    (
        (top20["Selection_Frequency"] >= THRESHOLD_HIGH)
        & (top20["Selection_Frequency"] < THRESHOLD_CORE)
    ).sum()
)

print("\n[MANUSCRIPT STATISTICS]")
print(f"  Total features in final model: {len(feature_cols)}")
print(f"  Core stable features (5/5):    {len(core_stable)}")
print(f"  Highly stable features (4/5):  {len(highly_stable)}")
print(f"  Core stable in top 20 SHAP:    {n_core_in_top20}")
print(f"  Highly stable in top 20 SHAP:  {n_high_in_top20}")

# Core stable probes, one line each, flagged by whether they appear in df_importance
print("\n[CORE BIOMARKER CANDIDATES - For Literature Review]")
print("-" * 80)
print(f"{'miRNA Name':<30} {'Probe ID':<12} {'SHAP Rank':<12} {'Status'}")
print("-" * 80)

if len(core_stable) == 0:
    print("  No features selected in all 5 folds.")
    print("  Consider using highly stable (4/5) features for interpretation.")
else:
    for probe_id in core_stable:
        if mapping_df is not None:
            mirna_name = get_mirna_name(probe_id)
        else:
            mirna_name = str(probe_id)

        # Rows of the importance table that belong to this probe
        matching_rows = df_importance[df_importance["Probe_ID"] == probe_id]

        if matching_rows.empty:
            # Probe is absent from the importance table
            status = "⚠ Not selected"
            print(f"{mirna_name:<30} {probe_id:<12} {'N/A':<12} {status}")
        else:
            # Probe is present: report its rank
            shap_rank = matching_rows["Rank"].iloc[0]
            status = "✓ In model"
            print(f"{mirna_name:<30} {probe_id:<12} #{shap_rank:<11} {status}")

    # How many core stable probes are present in the importance table
    core_in_final = sum(
        bool((df_importance["Probe_ID"] == pid).any()) for pid in core_stable
    )

    print("\n" + "-" * 80)
    print(f"[SUMMARY]")
    print(f"  Core stable in final model:    {core_in_final}/{len(core_stable)}")
    print(f"  Core stable NOT in final:      {len(core_stable) - core_in_final}/{len(core_stable)}")

    if core_in_final < len(core_stable):
        print(f"\n[INTERPRETATION]")
        print(f"  Some CV-stable features were not selected in the final model.")
        print(f"  This occurs when:")
        print(f"    - Fold-specific patterns don't generalize to full dataset")
        print(f"    - Feature selection is stochastic (different splits)")
        print(f"    - Final model uses slightly different feature set")
        print(f"  Recommendation: Focus on features present in BOTH CV and final model")

# Highly stable probes that are present in the importance table
print("\n[HIGHLY STABLE FEATURES IN FINAL MODEL (4/5 folds)]")
print("-" * 80)

high_in_final = []
for probe_id in highly_stable:
    if (df_importance["Probe_ID"] == probe_id).any():
        high_in_final.append(probe_id)

if not high_in_final:
    print("  No highly stable features in final model")
else:
    print(f"Found {len(high_in_final)} highly stable features in final model")
    print("\nTop 10 by SHAP importance:")
    print(f"{'miRNA Name':<30} {'Probe ID':<12} {'SHAP Rank'}")
    print("-" * 60)

    # (name, probe, rank) triples ordered by ascending rank; the sort is stable
    high_with_rank = sorted(
        [
            (
                get_mirna_name(pid) if mapping_df is not None else str(pid),
                pid,
                df_importance.loc[df_importance["Probe_ID"] == pid, "Rank"].iloc[0],
            )
            for pid in high_in_final
        ],
        key=lambda entry: entry[2],
    )

    for mirna_name, probe_id, shap_rank in high_with_rank[:10]:
        print(f"{mirna_name:<30} {probe_id:<12} #{shap_rank}")

# ==============================================================================
# STEP 7: Generate Summary Report for Professor
# ==============================================================================
"""
Create comprehensive summary document ready for organoid validation planning.
"""

# Writes a plain-text report (UTF-8) in four parts: header with model metrics,
# stable miRNA biomarkers, predicted mRNA targets (only when TARGETS_AVAILABLE),
# and fixed next-step / output-file guidance.
# Names read here but defined outside this block: best_name, results, RESULT_DIR,
# core_stable, highly_stable, prioritized_features, df_mirna_biomarkers,
# TARGETS_AVAILABLE, df_mirna_targets, df_filtered_mrnas.
print("\n" + "=" * 80)
print("STEP 7: GENERATE SUMMARY REPORT")
print("=" * 80)
best_model_name = best_name
# NOTE(review): the report goes to RESULT_DIR, while the biomarker CSV earlier in
# this cell is written under base_save_path — confirm both point to the intended folder.
report_path = os.path.join(RESULT_DIR, "VALIDATION_SUMMARY_REPORT.txt")

with open(report_path, 'w', encoding='utf-8') as f:
    f.write("=" * 80 + "\n")
    f.write("miRNA BIOMARKER DISCOVERY → ORGANOID VALIDATION SUMMARY\n")
    f.write("=" * 80 + "\n\n")

    # Header: timestamp, dataset line, and out-of-fold AUC with its interval.
    # The dataset line is a literal string; its counts are not computed from the data.
    f.write(f"Analysis Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Dataset: GSE39833 (n=99, HC=11, CRC=88)\n")
    f.write(f"Best Model: {best_model_name}\n")
    f.write(f"OOF AUC: {results[best_model_name]['oof_auc']:.4f}\n")
    f.write(f"95% CI: [{results[best_model_name]['oof_auc_ci'][0]:.4f}, "
            f"{results[best_model_name]['oof_auc_ci'][1]:.4f}]\n\n")

    f.write("-" * 80 + "\n")
    f.write("SECTION 1: STABLE miRNA BIOMARKERS\n")
    f.write("-" * 80 + "\n\n")

    # NOTE(review): the "5/5" and "4/5" labels are literal text; they match the
    # stability tiers only if cross-validation ran exactly 5 iterations — confirm.
    f.write(f"Core stable (5/5 folds):       {len(core_stable)} miRNAs\n")
    f.write(f"Highly stable (4/5 folds):     {len(highly_stable)} miRNAs\n")
    f.write(f"Total prioritized:             {len(prioritized_features)} miRNAs\n\n")

    # First 10 rows of the biomarker table in its current order.
    # row.get('SHAP_Rank', 'N/A') prints 'N/A' when the table has no SHAP_Rank column.
    f.write("Top 10 miRNAs for Experimental Validation:\n")
    f.write("-" * 60 + "\n")
    for i, (_, row) in enumerate(df_mirna_biomarkers.head(10).iterrows(), 1):
        f.write(f"{i:2d}. {row['miRNA_Name']:<25} "
                f"(Stability: {row['Stability_Category']}, "
                f"SHAP Rank: #{row.get('SHAP_Rank', 'N/A')})\n")

    f.write("\n" + "-" * 80 + "\n")
    f.write("SECTION 2: PREDICTED mRNA TARGETS\n")
    f.write("-" * 80 + "\n\n")

    # The target tables are touched only when TARGETS_AVAILABLE is truthy,
    # so they do not need to exist otherwise.
    if TARGETS_AVAILABLE:
        f.write(f"Total miRNA-mRNA interactions: {len(df_mirna_targets)}\n")
        f.write(f"Unique target genes:           {df_mirna_targets['Target_Gene'].nunique()}\n")
        f.write(f"High-priority candidates:      {len(df_filtered_mrnas)}\n")
        f.write(f"Known CRC genes:               {df_filtered_mrnas['Is_Known_CRC_Gene'].sum()}\n\n")

        f.write("Top 20 mRNA Candidates for Organoid Matching:\n")
        f.write("-" * 80 + "\n")
        f.write(f"{'Rank':<6} {'Gene':<12} {'#miRNAs':<10} {'CRC Gene':<10} {'Score'}\n")
        f.write("-" * 80 + "\n")

        # row.name is the row's index label, printed in the 'Gene' column
        # (presumably df_filtered_mrnas is indexed by gene symbol — verify).
        for _, row in df_filtered_mrnas.head(20).iterrows():
            known = "Yes" if row['Is_Known_CRC_Gene'] else "No"
            f.write(f"{row['Priority_Rank']:<6} {row.name:<12} {row['miRNA_Count']:<10} "
                    f"{known:<10} {row['Priority_Score']:.3f}\n")
    else:
        f.write("Target prediction not completed.\n")
        f.write("Complete Step 4 (manual database search) to generate mRNA candidates.\n")

    # Everything below is fixed guidance text; nothing is computed.
    f.write("\n" + "-" * 80 + "\n")
    f.write("SECTION 3: NEXT STEPS FOR ORGANOID VALIDATION\n")
    f.write("-" * 80 + "\n\n")

    f.write("1. Experimental Design:\n")
    f.write("   - Compare predicted mRNA expression in organoid RNA-seq data\n")
    f.write("   - Focus on top 20 high-priority candidates (priority score ≥0.5)\n")
    f.write("   - Validate known CRC genes first (established biological plausibility)\n\n")

    f.write("2. Statistical Analysis Plan:\n")
    f.write("   - Differential expression: CRC organoid vs normal tissue\n")
    f.write("   - Correlation: miRNA abundance ↔ target mRNA expression\n")
    f.write("   - Pathway enrichment: Confirm CRC pathway activation\n\n")

    f.write("3. Expected Outcomes:\n")
    f.write("   - Validation rate: 50-70% of predicted targets (realistic expectation)\n")
    f.write("   - High priority = higher validation probability\n")
    f.write("   - Multi-miRNA targets = stronger signal\n\n")

    f.write("4. Publication Strategy:\n")
    f.write("   - Figure 1: ML model performance (ROC, confusion matrix)\n")
    f.write("   - Figure 2: SHAP feature importance\n")
    f.write("   - Figure 3: miRNA-mRNA network\n")
    f.write("   - Figure 4: Organoid validation results (your data!)\n")
    f.write("   - Figure 5: Pathway enrichment heatmap\n\n")

    # File names listed here are literal text; this block does not create them.
    f.write("-" * 80 + "\n")
    f.write("OUTPUT FILES FOR COLLABORATION\n")
    f.write("-" * 80 + "\n\n")

    f.write("1. stable_mirnas_for_validation.csv\n")
    f.write("   → Core miRNA biomarker list (share with wet lab team)\n\n")

    f.write("2. mirna_target_mrnas_predicted.csv\n")
    f.write("   → All predicted miRNA-mRNA interactions\n\n")

    f.write("3. crc_relevant_mrnas_filtered.csv\n")
    f.write("   → HIGH PRIORITY: mRNA candidates for organoid matching\n")
    f.write("   → Use this file for RNA-seq comparison!\n\n")

    f.write("4. mirna_mrna_network.csv\n")
    f.write("   → Network edge list for pathway visualization\n\n")

    f.write("5. kegg_pathway_enrichment.csv (if generated)\n")
    f.write("   → Enriched pathways for biological interpretation\n\n")

    f.write("=" * 80 + "\n")
    f.write("END OF REPORT\n")
    f.write("=" * 80 + "\n")

print(f"[INFO] ✓ Summary report saved: {report_path}")

# ==============================================================================
# FINAL COMPLETION MESSAGE
# ==============================================================================
# Prints the closing banner and one summary f-string: biomarker counts, target
# prediction status, the three output paths, and next steps.
# Every value that depends on the target tables sits behind a
# `... if TARGETS_AVAILABLE else ...` conditional, so df_mirna_targets and
# df_filtered_mrnas are read only when targets exist, and manual_targets_path is
# read only when they do not.
# The [ACTION REQUIRED] section is built inside the f-string from '''-quoted
# pieces (the outer string uses """), concatenated around manual_targets_path.
print("\n" + "=" * 80)
print("✓ CELL 3 COMPLETE: BIOMARKER DISCOVERY → TARGET PREDICTION PIPELINE")
print("=" * 80)

print(f"""
[SUMMARY]

miRNA Biomarkers Identified:  {len(prioritized_features)}
  - Core stable (5/5 folds):   {len(core_stable)}
  - Highly stable (4/5 folds): {len(highly_stable)}

mRNA Target Prediction:        {'COMPLETED' if TARGETS_AVAILABLE else 'PENDING'}
  - Total interactions:        {len(df_mirna_targets) if TARGETS_AVAILABLE else 'N/A'}
  - High-priority candidates:  {len(df_filtered_mrnas) if TARGETS_AVAILABLE else 'N/A'}

[CRITICAL FILES FOR PROFESSOR]

Ready for Organoid Validation:
  ✓ {biomarker_path}
  {'✓' if TARGETS_AVAILABLE else '⚠'} {os.path.join(base_save_path, 'crc_relevant_mrnas_filtered.csv')}
  ✓ {report_path}

{'' if TARGETS_AVAILABLE else '''
[ACTION REQUIRED]
Complete Step 4 (miRNA target prediction) by:
  1. Manually searching each miRNA in TargetScan/miRDB
  2. Saving results as: ''' + manual_targets_path + '''
  3. Re-running this cell to generate filtered mRNA candidates
'''}

[NEXT STEPS]

1. Review summary report: {os.path.basename(report_path)}
2. Share mRNA candidate list with organoid team
3. Plan RNA-seq comparison analysis
4. Prepare for manuscript writing (you have the data!)

""")

print("=" * 80 + "\n")

In [None]:
# =============================================================================
# Cell 4: EXTERNAL VALIDATION ON INDEPENDENT DATASETS
# =============================================================================
"""
Purpose:
    Validate the trained model on completely independent cohorts to assess
    true generalization performance. This is the MOST CRITICAL step for
    publication-ready results.

Datasets:
    - GSE39814: Independent serum exosome miRNA cohort (Patient samples)
    - GSE39832: Cell line data (HT-29) - NOT suitable for patient validation

Strategy:
    1. Load external datasets from GEO
    2. Match probe IDs with training features
    3. Apply identical preprocessing (log2 transformation)
    4. Predict using the trained model (NO retraining!)
    5. Evaluate performance metrics

Version: 2.1 - Fixed variable naming conflicts
Author: Jungho Sohn
Date: 2025-12-22
"""

import os
import numpy as np
import pandas as pd
import GEOparse
import json
import warnings

from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    roc_curve
)
import matplotlib.pyplot as plt
import seaborn as sns

# Silence all warnings for the remainder of the session
warnings.filterwarnings('ignore')

print("=" * 80)
print("CELL 4: EXTERNAL VALIDATION ON INDEPENDENT DATASETS")
print("=" * 80)
print("""
[CRITICAL] This cell validates the model trained on GSE39833 using
           completely independent cohorts.

           NO learning or parameter tuning occurs on external data.
           This provides an unbiased estimate of true generalization.
""")

# =============================================================================
# STEP 0: Verify Required Variables from Previous Cells
# =============================================================================
print("\n" + "=" * 80)
print("STEP 0: Verify Required Variables")
print("=" * 80)

# Name -> human-readable description of each variable this cell relies on
required_vars = dict(
    best_full_model='Trained classifier pipeline',
    feature_cols='List of selected probe IDs',
    best_name='Name of best model',
    SEED='Random seed for reproducibility',
    base_save_path='Output directory path',
    results='Nested CV results dictionary',
    oof_predictions='Out-of-fold predictions dictionary',
)

# Report each variable that exists; collect a formatted line for each that does not
missing_vars = []
for var_name, description in required_vars.items():
    if var_name in dir():
        print(f"[OK] {var_name} found")
    else:
        missing_vars.append(f"  - {var_name}: {description}")

if missing_vars:
    _missing_report = "\n".join(missing_vars)
    raise NameError(
        "[ERROR] Missing required variables from previous cells:\n"
        + _missing_report
        + "\n\nPlease run Cell 2 (Nested CV) first!"
    )

# -----------------------------------------------------------------------------
# CRITICAL FIX: keep a (shallow) copy of the `results` dictionary under its own
# name so later reuse of the name `results` cannot shadow it.
# -----------------------------------------------------------------------------
nested_cv_results = results.copy()

print(f"\n[INFO] Model to validate: {best_name}")
print(f"[INFO] Number of features: {len(feature_cols)}")

# Training-side reference metrics for the best model, taken from the CV results
_best_cv_metrics = nested_cv_results[best_name]
train_auc = _best_cv_metrics['oof_auc']
train_bal_acc = _best_cv_metrics['oof_balanced_accuracy']
train_sensitivity = _best_cv_metrics['recall_sensitivity']
train_specificity = _best_cv_metrics['specificity']

print(f"\n[INFO] Training Performance (from Nested CV):")
print(f"  - OOF AUC:          {train_auc:.4f}")
print(f"  - Balanced Accuracy: {train_bal_acc:.4f}")
print(f"  - Sensitivity:       {train_sensitivity:.4f}")
print(f"  - Specificity:       {train_specificity:.4f}")

# =============================================================================
# STEP 1: Define External Validation Datasets
# =============================================================================
print("\n" + "=" * 80)
print("STEP 1: Define External Validation Datasets")
print("=" * 80)

# Registry of external cohorts; `is_patient_data` marks whether a cohort is
# suitable for patient-level validation.
EXTERNAL_DATASETS = {
    'GSE39814': dict(
        description='Serum exosome miRNA - Independent patient cohort',
        expected_platform='GPL16016',
        is_patient_data=True,
        notes='Patient serum samples',
    ),
    'GSE39832': dict(
        description='Serum exosome miRNA - Cell line data',
        expected_platform='GPL16016',
        is_patient_data=False,
        notes='HT-29 colorectal cancer cell line (NOT patient samples)',
    ),
}

print("[INFO] External datasets:")
for gse_id, info in EXTERNAL_DATASETS.items():
    if info['is_patient_data']:
        status = "✓ Patient data"
    else:
        status = "⚠️ Cell line (not for patient validation)"
    print(f"  - {gse_id}: {info['description']}")
    print(f"    Status: {status}")

# =============================================================================
# STEP 2: Define Helper Functions
# =============================================================================
print("\n" + "=" * 80)
print("STEP 2: Define Helper Functions")
print("=" * 80)


def bootstrap_auc_ci(y_true, y_proba, n_bootstrap=1000, alpha=0.05, random_state=42):
    """
    Percentile-bootstrap confidence interval for ROC-AUC.

    Samples are redrawn with replacement ``n_bootstrap`` times using a
    ``numpy.random.RandomState`` seeded with ``random_state``. Resamples that
    contain a single class are discarded.

    Returns ``(lower, upper)`` as floats at the ``alpha / 2`` and
    ``1 - alpha / 2`` percentiles, or ``(nan, nan)`` if no resample contained
    both classes.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_proba)
    sample_size = len(labels)
    sampler = np.random.RandomState(random_state)

    resampled_aucs = []
    for _ in range(n_bootstrap):
        draw = sampler.choice(sample_size, sample_size, replace=True)
        drawn_labels = labels[draw]

        # AUC is undefined for a one-class resample, so only score mixed ones
        if len(np.unique(drawn_labels)) >= 2:
            resampled_aucs.append(roc_auc_score(drawn_labels, scores[draw]))

    if not resampled_aucs:
        return np.nan, np.nan

    lower, upper = np.percentile(
        resampled_aucs,
        [100 * (alpha / 2), 100 * (1 - alpha / 2)],
    )
    return float(lower), float(upper)


def load_geo_dataset(gse_id, verbose=True):
    """
    Load a GEO series and build a samples x probes expression matrix.

    Parameters
    ----------
    gse_id : str
        GEO series accession passed to ``GEOparse.get_GEO`` (files are cached
        under ``./geo_cache``).
    verbose : bool, default True
        Print progress and shape information.

    Returns
    -------
    expr_df : pandas.DataFrame
        One row per sample ID, one column per probe ID (``ID_REF``).
    sample_info : pandas.DataFrame
        Indexed by sample ID, with columns ``title``, ``source`` and
        ``characteristics`` (the ``characteristics_ch1`` entries joined by '; ').
    gse : object
        The object returned by ``GEOparse.get_GEO``.

    Raises
    ------
    KeyError
        If a sample table has neither a ``VALUE`` column nor any column whose
        name contains "value".

    Notes
    -----
    Values are aligned by probe ID. When every sample lists the same ``ID_REF``
    sequence, the matrix is built positionally exactly as before. Otherwise each
    sample is re-indexed onto the ordered union of probe IDs (first occurrence
    kept for duplicated IDs, NaN for probes a sample lacks), instead of silently
    placing values under the first sample's probe order.
    """
    if verbose:
        print(f"\n[INFO] Loading {gse_id} from GEO...")

    # Download and parse GEO dataset
    gse = GEOparse.get_GEO(geo=gse_id, destdir='./geo_cache', silent=True)

    if verbose:
        print(f"[INFO] Successfully loaded {gse_id}")
        print(f"  - Title: {gse.metadata.get('title', ['Unknown'])[0][:60]}...")
        print(f"  - Samples: {len(gse.gsms)}")

    # Extract sample IDs
    sample_ids = list(gse.gsms.keys())

    expression_data = []   # per-sample value arrays
    probe_id_lists = []    # per-sample ID_REF lists, parallel to expression_data
    sample_metadata = []

    for gsm_id in sample_ids:
        gsm = gse.gsms[gsm_id]

        # Get expression values: prefer 'VALUE', else the first column whose
        # name contains "value" (case-insensitive)
        if 'VALUE' in gsm.table.columns:
            values = gsm.table['VALUE'].values
        else:
            value_cols = [c for c in gsm.table.columns if 'value' in c.lower()]
            if value_cols:
                values = gsm.table[value_cols[0]].values
            else:
                raise KeyError(f"No expression value column found in {gsm_id}")

        expression_data.append(values)
        probe_id_lists.append(gsm.table['ID_REF'].tolist())

        # Extract metadata for label assignment
        characteristics = gsm.metadata.get('characteristics_ch1', [])
        title = gsm.metadata.get('title', [''])[0]
        source = gsm.metadata.get('source_name_ch1', [''])[0]

        sample_metadata.append({
            'sample_id': gsm_id,
            'title': title,
            'source': source,
            'characteristics': '; '.join(characteristics) if characteristics else ''
        })

    # Probe IDs of the first sample define the reference column order
    probe_ids = probe_id_lists[0]

    if all(ids == probe_ids for ids in probe_id_lists[1:]):
        # Every sample shares the same probe sequence: positional build is safe
        expr_df = pd.DataFrame(
            expression_data,
            index=sample_ids,
            columns=probe_ids
        )
    else:
        print(f"[WARNING] Samples in {gse_id} list probes in different orders or "
              f"with different probe sets; aligning values by probe ID.")

        # Ordered union of probe IDs: first sample's order, then new IDs as seen
        probe_ids = list(dict.fromkeys(
            pid for ids in probe_id_lists for pid in ids
        ))

        aligned_rows = []
        for ids, values in zip(probe_id_lists, expression_data):
            sample_series = pd.Series(values, index=ids)
            # reindex needs unique labels; keep the first value of a repeated ID
            sample_series = sample_series[~sample_series.index.duplicated(keep='first')]
            aligned_rows.append(sample_series.reindex(probe_ids).to_numpy())

        expr_df = pd.DataFrame(
            aligned_rows,
            index=sample_ids,
            columns=probe_ids
        )

    # Create metadata DataFrame indexed by sample ID
    sample_info = pd.DataFrame(sample_metadata).set_index('sample_id')

    if verbose:
        print(f"  - Expression matrix shape: {expr_df.shape}")
        print(f"  - Probe count: {len(probe_ids)}")

    return expr_df, sample_info, gse


def _matches_any_keyword(text, keywords):
    """Return True if any keyword occurs in ``text`` (expected lower-cased).

    Keywords of three characters or fewer (e.g. 'hc', 'nc', 'crc', 'hct') must
    not be flanked by letters, so 'nc' does not fire inside 'cancer' while
    'hc1' or 'crc_12' still match. Longer keywords use plain substring matching.
    """
    import re  # local import: `re` is not imported at the top of this cell

    for keyword in keywords:
        if len(keyword) <= 3:
            pattern = r'(?<![a-z])' + re.escape(keyword) + r'(?![a-z])'
            if re.search(pattern, text):
                return True
        elif keyword in text:
            return True
    return False


def assign_labels_external(sample_info, gse_id, verbose=True):
    """
    Assign binary labels (0=healthy, 1=cancer) based on sample metadata.

    The ``title``, ``source`` and ``characteristics`` fields of each row are
    joined, lower-cased, and searched for cell-line, healthy, and cancer
    keywords. Precedence: cell line -> healthy only -> cancer only -> both
    (healthy wins only if 'healthy' or 'normal control' is present) -> no
    keyword, which defaults to cancer.

    Short abbreviations are matched only when not embedded in a longer word,
    and healthy phrases that contain a cancer keyword ('non-cancer',
    'non-tumor') are masked before the cancer search. Previously 'nc' matched
    inside 'cancer' and 'cancer' matched inside 'non-cancer', so a
    "non-cancer control" sample was labelled as cancer.

    Parameters
    ----------
    sample_info : pandas.DataFrame
        One row per sample, indexed by sample ID.
    gse_id : str
        Dataset identifier, used only in printed messages.
    verbose : bool, default True
        Print class counts and warnings.

    Returns
    -------
    labels : numpy.ndarray
        Array of 0/1 labels in row order.
    label_log : list of dict
        Per-sample record with keys ``sample_id``, ``title``, ``source``,
        ``label``, ``reason`` and ``is_cell_line``.
    """
    labels = []
    label_log = []

    # Define keywords for classification
    healthy_keywords = [
        'healthy', 'normal', 'control', 'hc', 'nc',
        'non-cancer', 'non-tumor', 'benign', 'volunteer'
    ]
    cancer_keywords = [
        'cancer', 'tumor', 'crc', 'colorectal', 'carcinoma',
        'malignant', 'adenocarcinoma', 'patient', 'case'
    ]

    # Cell line keywords (for GSE39832)
    cell_line_keywords = ['ht-29', 'ht29', 'caco', 'sw480', 'hct', 'cell line']

    # Healthy phrases that embed a cancer keyword; they are blanked out before
    # the cancer search so a negation is not read as a cancer hit.
    negated_cancer_phrases = [
        phrase for phrase in healthy_keywords
        if any(cancer_kw in phrase for cancer_kw in cancer_keywords)
    ]

    for sample_id, row in sample_info.iterrows():
        # Combine all text fields for keyword search
        text_to_search = ' '.join([
            str(row.get('title', '')),
            str(row.get('source', '')),
            str(row.get('characteristics', ''))
        ]).lower()

        # Check if this is cell line data
        is_cell_line = _matches_any_keyword(text_to_search, cell_line_keywords)

        # Check for healthy/cancer indicators
        is_healthy = _matches_any_keyword(text_to_search, healthy_keywords)

        cancer_text = text_to_search
        for phrase in negated_cancer_phrases:
            cancer_text = cancer_text.replace(phrase, ' ')
        is_cancer = _matches_any_keyword(cancer_text, cancer_keywords)

        # Assign label with priority logic
        if is_cell_line:
            label = 1  # Cell lines are cancer-derived
            reason = "Cell line sample (cancer-derived)"
        elif is_healthy and not is_cancer:
            label = 0
            reason = "Matched healthy keywords"
        elif is_cancer and not is_healthy:
            label = 1
            reason = "Matched cancer keywords"
        elif is_healthy and is_cancer:
            if 'healthy' in text_to_search or 'normal control' in text_to_search:
                label = 0
                reason = "Matched 'healthy' specifically (ambiguous case)"
            else:
                label = 1
                reason = "Matched cancer keywords (ambiguous case)"
        else:
            label = 1  # Default to cancer if unclear
            reason = "No clear keywords - defaulting to cancer"

        labels.append(label)
        label_log.append({
            'sample_id': sample_id,
            'title': row.get('title', ''),
            'source': row.get('source', ''),
            'label': label,
            'reason': reason,
            'is_cell_line': is_cell_line
        })

    labels = np.array(labels)

    if verbose:
        n_healthy = (labels == 0).sum()
        n_cancer = (labels == 1).sum()
        n_cell_line = sum(1 for log in label_log if log.get('is_cell_line', False))
        n_defaulted = sum(
            1 for log in label_log
            if log['reason'] == "No clear keywords - defaulting to cancer"
        )

        print(f"\n[INFO] Label assignment for {gse_id}:")
        print(f"  - Healthy controls: {n_healthy}")
        print(f"  - Cancer patients: {n_cancer}")

        if n_cell_line > 0:
            print(f"  ⚠️  WARNING: {n_cell_line} samples are cell line data!")
            print(f"     Cell line data is NOT suitable for patient-level validation.")

        if n_defaulted > 0:
            # These labels are guesses, not matches; surface them for manual review
            print(f"  ⚠️  WARNING: {n_defaulted} samples matched no keyword and were "
                  f"labelled cancer by default.")

        if n_healthy == 0 or n_cancer == 0:
            print(f"  ⚠️  WARNING: Only one class detected!")
            print(f"  Sample titles for review:")
            for log_entry in label_log[:5]:
                print(f"    - {log_entry['title'][:50]}... → Label: {log_entry['label']}")

    return labels, label_log


def validate_on_external_dataset(
    model,
    feature_list,
    external_expr,
    external_labels,
    dataset_name,
    seed=42,
    verbose=True,
    threshold=0.5,
    n_bootstrap=1000
):
    """
    Validate a trained model on an external dataset.

    Parameters
    ----------
    model : fitted estimator exposing ``predict_proba``.
    feature_list : sequence of feature (probe) IDs, in the column order the
        model was trained on.
    external_expr : pandas.DataFrame with samples as rows and features as
        columns. Features absent from the frame are filled with 0.0.
    external_labels : array-like of 0 (healthy) / 1 (cancer) labels, one per
        row of ``external_expr``.
    dataset_name : str used in log output and stored in the result.
    seed : int forwarded to ``bootstrap_auc_ci`` as ``random_state``.
    verbose : bool; when True, progress and metrics are printed.
    threshold : probability cut-off for the positive class (default 0.5).
    n_bootstrap : number of bootstrap resamples for the AUC confidence
        interval (default 1000).

    Returns
    -------
    dict or None
        ``validation_result`` dict (named so to avoid clashing with other
        ``results`` variables in the notebook), or None when prediction
        fails. AUC/CI/sensitivity/specificity entries are None when the
        metric is undefined (e.g. a class is absent from the labels).

    Raises
    ------
    ValueError
        If the number of labels differs from the number of samples.
    """
    if verbose:
        print(f"\n{'─' * 60}")
        print(f"Validating on {dataset_name}")
        print(f"{'─' * 60}")

    # Accept lists / Series as well as ndarrays; the metric code below relies
    # on element-wise comparison and .tolist().
    external_labels = np.asarray(external_labels)
    if len(external_labels) != len(external_expr):
        raise ValueError(
            f"Label/sample count mismatch for {dataset_name}: "
            f"{len(external_labels)} labels vs {len(external_expr)} samples"
        )

    # Step 1: Match features (probe IDs)
    available_features = set(external_expr.columns)
    required_features = list(dict.fromkeys(feature_list))  # unique, order kept

    matched_features = [f for f in required_features if f in available_features]
    missing_features = [f for f in required_features if f not in available_features]

    # Guard against an empty feature list instead of dividing by zero.
    if required_features:
        match_rate = len(matched_features) / len(required_features) * 100
    else:
        match_rate = 0.0

    if verbose:
        print(f"[INFO] Feature matching:")
        print(f"  - Required features: {len(required_features)}")
        print(f"  - Matched features: {len(matched_features)} ({match_rate:.1f}%)")
        print(f"  - Missing features: {len(missing_features)}")

    if match_rate < 50:
        print(f"[ERROR] Less than 50% feature match. Validation may be unreliable.")

    # Step 2: Prepare feature matrix in SAME ORDER as training.
    # Columns for features missing from the external data stay at 0.0.
    X_external = np.zeros((len(external_expr), len(feature_list)))

    for i, feat in enumerate(feature_list):
        if feat in available_features:
            X_external[:, i] = external_expr[feat].values

    # Step 3: Apply log2 transformation if needed.
    # Use a NaN-aware maximum: a plain max() returns NaN as soon as one value
    # is missing, which would make the "> 20" test False and silently skip
    # the transform on raw-scale data.
    expr_values = np.asarray(external_expr.values, dtype=float)
    has_values = expr_values.size > 0 and not np.isnan(expr_values).all()
    looks_raw = bool(has_values and np.nanmax(expr_values) > 20)

    if looks_raw:
        if verbose:
            print(f"[INFO] Applying log2(x + 1) transformation")
        X_external = np.log2(X_external + 1.0)
    else:
        if verbose:
            print(f"[INFO] Data appears log2-transformed. Using as-is.")

    # Handle any infinite or NaN values
    X_external = np.nan_to_num(X_external, nan=0.0, posinf=0.0, neginf=0.0)

    # Step 4: Predict using trained model
    try:
        y_proba = model.predict_proba(X_external)[:, 1]
        y_pred = (y_proba >= threshold).astype(int)
    except Exception as e:
        print(f"[ERROR] Prediction failed: {str(e)}")
        return None

    # Step 5: Calculate performance metrics
    unique_labels = np.unique(external_labels)
    if len(unique_labels) < 2:
        print(f"[WARNING] Only one class present. AUC cannot be computed.")
        auc = np.nan
        ci_lower, ci_upper = np.nan, np.nan
    else:
        auc = roc_auc_score(external_labels, y_proba)
        ci_lower, ci_upper = bootstrap_auc_ci(
            external_labels, y_proba, n_bootstrap=n_bootstrap, random_state=seed
        )

    acc = accuracy_score(external_labels, y_pred)
    bal_acc = balanced_accuracy_score(external_labels, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        external_labels, y_pred, average='binary', zero_division=0
    )

    # Fix the label order so the matrix is always 2x2, even when a class is
    # absent from both the labels and the predictions.
    cm = confusion_matrix(external_labels, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # A rate whose denominator is zero is undefined (no samples of that
    # class); report it as NaN -> None rather than a misleading 0.0.
    specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else np.nan

    # Compile results
    validation_result = {
        'dataset': dataset_name,
        'n_samples': len(external_labels),
        'n_healthy': int((external_labels == 0).sum()),
        'n_cancer': int((external_labels == 1).sum()),
        'feature_match_rate': float(match_rate),
        'n_matched_features': len(matched_features),
        'n_missing_features': len(missing_features),
        'auc': float(auc) if not np.isnan(auc) else None,
        'auc_ci_lower': float(ci_lower) if not np.isnan(ci_lower) else None,
        'auc_ci_upper': float(ci_upper) if not np.isnan(ci_upper) else None,
        'accuracy': float(acc),
        'balanced_accuracy': float(bal_acc),
        'sensitivity': float(sensitivity) if not np.isnan(sensitivity) else None,
        'specificity': float(specificity) if not np.isnan(specificity) else None,
        'precision': float(precision),
        'f1_score': float(f1),
        'confusion_matrix': cm.tolist(),
        'y_true': external_labels.tolist(),
        'y_proba': y_proba.tolist(),
        'y_pred': y_pred.tolist()
    }

    # Print results
    if verbose:
        print(f"\n[RESULTS] {dataset_name}")
        print(f"  Samples: {validation_result['n_samples']} "
              f"({validation_result['n_healthy']} HC, {validation_result['n_cancer']} CRC)")
        print(f"  Feature match: {validation_result['feature_match_rate']:.1f}%")
        print(f"  ────────────────────────────────────")
        if validation_result['auc'] is None:
            print(f"  AUC:              N/A (single class)")
        elif (validation_result['auc_ci_lower'] is None
              or validation_result['auc_ci_upper'] is None):
            # The bootstrap may yield NaN bounds; formatting None with
            # ':.4f' would raise, so print the AUC alone.
            print(f"  AUC:              {validation_result['auc']:.4f} (95% CI: N/A)")
        else:
            print(f"  AUC:              {validation_result['auc']:.4f} "
                  f"(95% CI: [{validation_result['auc_ci_lower']:.4f}, "
                  f"{validation_result['auc_ci_upper']:.4f}])")
        print(f"  Accuracy:         {validation_result['accuracy']:.4f}")
        print(f"  Balanced Acc:     {validation_result['balanced_accuracy']:.4f}")
        if validation_result['sensitivity'] is not None:
            print(f"  Sensitivity:      {validation_result['sensitivity']:.4f}")
        if validation_result['specificity'] is not None:
            print(f"  Specificity:      {validation_result['specificity']:.4f}")
        print(f"  F1-score:         {validation_result['f1_score']:.4f}")
        print(f"  Confusion Matrix:")
        print(f"    {cm}")

    return validation_result


# =============================================================================
# STEP 3: Load and Validate External Datasets
# =============================================================================
# For every entry of EXTERNAL_DATASETS: load the dataset, assign labels, save
# the label log as CSV, then score the trained model on it. A failure on one
# dataset is reported and does not stop the remaining datasets.
print("\n" + "=" * 80)
print("STEP 3: Load and Validate External Datasets")
print("=" * 80)

# Create cache directory for GEO downloads
# NOTE(review): presumably load_geo_dataset writes here - confirm against its definition.
os.makedirs('./geo_cache', exist_ok=True)

# gse_id -> validation_result dict (only datasets that validated successfully)
external_results = {}
# gse_id -> label_log list returned by assign_labels_external
all_label_logs = {}

for gse_id, info in EXTERNAL_DATASETS.items():
    print(f"\n{'═' * 80}")
    print(f"Processing {gse_id}: {info['description']}")
    if not info['is_patient_data']:
        print(f"⚠️  NOTE: This is cell line data - results for reference only")
    print(f"{'═' * 80}")

    try:
        # Load dataset (the third return value, `gse`, is not used in this block)
        expr_df, sample_info, gse = load_geo_dataset(gse_id, verbose=True)

        # Assign labels
        labels, label_log = assign_labels_external(sample_info, gse_id, verbose=True)
        # Stored before validation so the log is kept even if validation fails below.
        all_label_logs[gse_id] = label_log

        # Save label log for transparency
        label_log_df = pd.DataFrame(label_log)
        label_log_path = os.path.join(base_save_path, f"external_label_log_{gse_id}.csv")
        label_log_df.to_csv(label_log_path, index=False)
        print(f"[INFO] Label log saved: {label_log_path}")

        # ┌─────────────────────────────────────────────────────────────────────┐
        # │ FIX: Use different variable name to avoid overwriting Cell 2 results│
        # └─────────────────────────────────────────────────────────────────────┘
        validation_result = validate_on_external_dataset(
            model=best_full_model,
            feature_list=feature_cols,
            external_expr=expr_df,
            external_labels=labels,
            dataset_name=gse_id,
            seed=SEED,
            verbose=True
        )

        # validate_on_external_dataset returns None when prediction fails;
        # such datasets are left out of external_results.
        if validation_result is not None:
            validation_result['is_patient_data'] = info['is_patient_data']
            external_results[gse_id] = validation_result

    except Exception as e:
        # Best-effort: report the failure with a traceback and move on to the next dataset.
        print(f"[ERROR] Failed to process {gse_id}: {str(e)}")
        import traceback
        traceback.print_exc()
        continue


# =============================================================================
# STEP 4: Summary and Comparison
# =============================================================================
# Prints a per-dataset results table, a training-vs-external comparison and a
# performance-drop analysis restricted to patient datasets.
print("\n" + "=" * 80)
print("STEP 4: External Validation Summary")
print("=" * 80)


def _fmt_optional(value, spec):
    """Format ``value`` with format spec ``spec``; return "N/A" only for None.

    An explicit None test is required here: a truthiness test would also turn
    a legitimate 0.0 (e.g. sensitivity when no cancer sample is detected)
    into "N/A".
    """
    return "N/A" if value is None else format(value, spec)


if len(external_results) == 0:
    print("[ERROR] No external datasets were successfully validated.")
else:
    # Create summary table
    print("\n┌" + "─" * 78 + "┐")
    print("│" + " EXTERNAL VALIDATION RESULTS ".center(78) + "│")
    print("├" + "─" * 78 + "┤")
    print(f"│ {'Dataset':<12} │ {'Type':<8} │ {'N':<5} │ {'AUC':<8} │ {'95% CI':<15} │ {'Bal.Acc':<7} │ {'Sens':<5} │ {'Spec':<5} │")
    print("├" + "─" * 78 + "┤")

    for gse_id, res in external_results.items():
        auc_str = _fmt_optional(res['auc'], ".4f")
        # Both bounds must be present before the interval can be formatted.
        if res['auc_ci_lower'] is not None and res['auc_ci_upper'] is not None:
            ci_str = f"[{res['auc_ci_lower']:.3f},{res['auc_ci_upper']:.3f}]"
        else:
            ci_str = "N/A"
        sens_str = _fmt_optional(res['sensitivity'], ".2f")
        spec_str = _fmt_optional(res['specificity'], ".2f")
        data_type = "Patient" if res.get('is_patient_data', True) else "Cell"

        print(f"│ {gse_id:<12} │ {data_type:<8} │ {res['n_samples']:<5} │ {auc_str:<8} │ {ci_str:<15} │ {res['balanced_accuracy']:.4f}  │ {sens_str:<5} │ {spec_str:<5} │")

    print("└" + "─" * 78 + "┘")

    # Compare with training performance
    print("\n" + "─" * 80)
    print("COMPARISON: Training vs External Validation")
    print("─" * 80)

    # Use preserved nested CV results
    print(f"\n{'Metric':<20} │ {'Training (CV)':<15} │ ", end="")
    for gse_id in external_results.keys():
        print(f"{gse_id:<15} │ ", end="")
    print()
    print("─" * (22 + 18 + 18 * len(external_results)))

    print(f"{'AUC':<20} │ {train_auc:.4f}          │ ", end="")
    for gse_id, res in external_results.items():
        auc_str = _fmt_optional(res['auc'], ".4f")
        print(f"{auc_str:<15} │ ", end="")
    print()

    print(f"{'Balanced Accuracy':<20} │ {train_bal_acc:.4f}          │ ", end="")
    for gse_id, res in external_results.items():
        print(f"{res['balanced_accuracy']:.4f}          │ ", end="")
    print()

    # Performance drop analysis (only for patient data)
    print("\n" + "─" * 80)
    print("PERFORMANCE DROP ANALYSIS (Patient Data Only)")
    print("─" * 80)

    patient_datasets = {k: v for k, v in external_results.items()
                        if v.get('is_patient_data', True)}

    for gse_id, res in patient_datasets.items():
        if res['auc'] is not None:
            # Positive values mean the external score is lower than training.
            auc_drop = train_auc - res['auc']
            bal_acc_drop = train_bal_acc - res['balanced_accuracy']

            print(f"\n{gse_id}:")
            print(f"  AUC drop:          {auc_drop:+.4f} ({train_auc:.4f} → {res['auc']:.4f})")
            print(f"  Balanced Acc drop: {bal_acc_drop:+.4f} ({train_bal_acc:.4f} → {res['balanced_accuracy']:.4f})")

            # Interpretation
            if auc_drop < 0.05:
                print(f"  ✓ Excellent generalization (AUC drop < 5%)")
            elif auc_drop < 0.10:
                print(f"  ○ Good generalization (AUC drop < 10%)")
            elif auc_drop < 0.20:
                print(f"  △ Moderate generalization (AUC drop < 20%)")
            else:
                print(f"  ✗ Poor generalization (AUC drop ≥ 20%) - Overfitting likely")


# =============================================================================
# STEP 5: Generate Visualization
# =============================================================================
# Overlays the training ROC curve with one ROC curve per external dataset and
# writes the figure to base_save_path.
print("\n" + "=" * 80)
print("STEP 5: Generate ROC Curves for External Validation")
print("=" * 80)

# Without any validated external dataset there is nothing to compare against.
if external_results:
    fig, ax = plt.subplots(figsize=(10, 8))

    # Training curve, drawn from the stored OOF predictions of the best model
    train_y_true = oof_predictions[best_name]['y_true']
    train_y_proba = oof_predictions[best_name]['y_proba']
    fpr_train, tpr_train, _ = roc_curve(train_y_true, train_y_proba)

    ax.plot(
        fpr_train,
        tpr_train,
        color='#2C3E50',
        linewidth=2.5,
        label=f'Training (GSE39833) - AUC = {train_auc:.3f}',
    )

    # Per-dataset appearance; dataset IDs not listed here fall back to the
    # defaults passed to .get() in the loop (gray marks the cell line set).
    colors = dict(GSE39814='#E74C3C', GSE39832='#95A5A6')
    linestyles = dict(GSE39814='--', GSE39832=':')

    for gse_id, res in external_results.items():
        # A dataset whose AUC is None has no curve to draw.
        if res['auc'] is None:
            continue

        y_true = np.array(res['y_true'])
        y_proba = np.array(res['y_proba'])
        fpr, tpr, _ = roc_curve(y_true, y_proba)

        if res.get('is_patient_data', True):
            label_suffix = ""
        else:
            label_suffix = " (Cell line)"

        ax.plot(
            fpr,
            tpr,
            color=colors.get(gse_id, '#3498DB'),
            linestyle=linestyles.get(gse_id, '--'),
            linewidth=2,
            label=f'{gse_id}{label_suffix} - AUC = {res["auc"]:.3f}',
        )

    # Chance-level diagonal for reference
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='Random (AUC = 0.5)')

    # Axis limits, labels, title, legend and grid
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.05)
    ax.set_xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
    ax.set_ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    ax.set_title('ROC Curves: Training vs External Validation', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right', fontsize=10)
    ax.grid(True, alpha=0.3)

    # Write the figure to disk before showing it
    roc_path = os.path.join(base_save_path, 'external_validation_roc.png')
    plt.savefig(roc_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"[INFO] ROC curve saved: {roc_path}")


# =============================================================================
# STEP 6: Save Comprehensive Results
# =============================================================================
# Writes one JSON summary (training metrics plus per-dataset metrics) and one
# CSV of per-sample predictions for each validated external dataset.
print("\n" + "=" * 80)
print("STEP 6: Save External Validation Results")
print("=" * 80)

# JSON-serializable summary; per-dataset entries are filled in below
json_output = {
    'training_dataset': 'GSE39833',
    'model_name': best_name,
    'n_features': len(feature_cols),
    'training_performance': dict(
        oof_auc=float(train_auc),
        oof_balanced_accuracy=float(train_bal_acc),
        oof_sensitivity=float(train_sensitivity),
        oof_specificity=float(train_specificity),
    ),
    'external_validation': {},
}

for gse_id, res in external_results.items():
    # The per-sample arrays are large; they go into the CSV files, not the JSON
    json_res = {
        k: v
        for k, v in res.items()
        if k not in ('y_true', 'y_proba', 'y_pred')
    }
    json_output['external_validation'][gse_id] = json_res

# Write the JSON summary
json_path = os.path.join(base_save_path, 'external_validation_results.json')
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(json_output, f, indent=2, ensure_ascii=False)
print(f"[INFO] Results saved: {json_path}")

# One predictions CSV per external dataset
for gse_id, res in external_results.items():
    detail_df = pd.DataFrame(
        {column: res[column] for column in ('y_true', 'y_proba', 'y_pred')}
    )
    detail_path = os.path.join(base_save_path, f'external_predictions_{gse_id}.csv')
    detail_df.to_csv(detail_path, index=False)
    print(f"[INFO] Predictions saved: {detail_path}")


# =============================================================================
# FINAL SUMMARY
# =============================================================================
# Prints the closing report: overall interpretation for patient datasets, a
# note on cell line datasets, suggested next steps and the files this run
# actually produced.
print("\n" + "=" * 80)
print("✓ EXTERNAL VALIDATION COMPLETE")
print("=" * 80)

print(f"""
[SUMMARY]
  Model validated: {best_name}
  Training dataset: GSE39833 (n=99)
  External datasets: {len(external_results)}
""")

# Final interpretation (patient data only)
patient_results = {k: v for k, v in external_results.items()
                   if v.get('is_patient_data', True) and v['auc'] is not None}

if len(patient_results) > 0:
    avg_patient_auc = np.mean([res['auc'] for res in patient_results.values()])

    print(f"[PATIENT DATA VALIDATION]")
    print(f"  Training AUC:         {train_auc:.4f}")
    print(f"  Avg Patient Ext AUC:  {avg_patient_auc:.4f}")
    print(f"  Performance drop:     {train_auc - avg_patient_auc:.4f}")

    if avg_patient_auc >= 0.85:
        print(f"\n  ✓ EXCELLENT: Model shows strong generalization.")
        print(f"    → Ready for publication with strong external validation.")
    elif avg_patient_auc >= 0.75:
        print(f"\n  ○ GOOD: Model generalizes reasonably well.")
        print(f"    → Publishable with appropriate limitations discussed.")
    elif avg_patient_auc >= 0.65:
        print(f"\n  △ MODERATE: Some overfitting observed.")
        print(f"    → Consider feature reduction or regularization.")
    else:
        print(f"\n  ✗ POOR: Significant overfitting detected.")
        print(f"    → Model revision strongly recommended.")

# Note about cell line data.
# The datasets are named from the actual results (with the description from
# EXTERNAL_DATASETS) instead of a hard-coded ID, so the note stays correct if
# the dataset configuration changes.
cell_line_results = {k: v for k, v in external_results.items()
                     if not v.get('is_patient_data', True)}
if len(cell_line_results) > 0:
    print(f"\n[CELL LINE DATA NOTE]")
    for cell_gse_id in cell_line_results:
        cell_description = EXTERNAL_DATASETS.get(cell_gse_id, {}).get(
            'description', 'cell line data'
        )
        print(f"  {cell_gse_id}: {cell_description}")
    print(f"  This is NOT suitable for patient-level clinical validation.")
    print(f"  Results are provided for reference only.")

# Build the file list from this run's state rather than a fixed list: a
# dataset that failed in STEP 3 has no prediction CSV, and the ROC figure is
# only drawn when at least one dataset validated.
expected_files = ['external_validation_results.json']
if len(external_results) > 0:
    expected_files.append('external_validation_roc.png')
expected_files.extend(
    f"external_label_log_{logged_id}.csv" for logged_id in all_label_logs
)
expected_files.extend(
    f"external_predictions_{validated_id}.csv" for validated_id in external_results
)
# Keep only files that are really present in the output directory.
generated_files = [
    file_name for file_name in expected_files
    if os.path.exists(os.path.join(base_save_path, file_name))
]

print(f"""
[NEXT STEPS]
  1. Review label assignment logs for accuracy
  2. If performance drop > 0.15, consider:
     - Reducing feature count (use CV-stable features only)
     - Increasing regularization
     - Using simpler model
  3. Update manuscript with external validation results

[FILES GENERATED]""")
if generated_files:
    for generated_name in generated_files:
        print(f"  - {generated_name}")
else:
    print("  (none)")
print()

print("=" * 80)