# In [5]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
from scipy.stats import mannwhitneyu
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import warnings
# Suppress all warnings to keep the notebook output readable.
# NOTE(review): this global filter also hides deprecation/future warnings from
# pandas and scikit-learn; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
# (seeds NumPy's global RNG; the scikit-learn calls in this file also pass an
# explicit random_state)
np.random.seed(42)

def calculate_cohens_d(group1, group2):
    """Calculate Cohen's d effect size: (mean(group1) - mean(group2)) / pooled SD.

    The pooled standard deviation is built from the within-group sums of
    squared deviations (equivalent to the (n - 1)-weighted sample variances),
    so a group with a single observation contributes zero spread instead of
    turning the whole result into NaN.

    Parameters
    ----------
    group1, group2 : array-like of numbers
        The two samples to compare. NaNs are not removed here; callers in this
        file call ``dropna()`` first.

    Returns
    -------
    float
        Signed Cohen's d. NaN when either group is empty or there are fewer
        than three observations in total (no degrees of freedom); 0.0 when both
        groups have zero spread and equal means; +/-inf when both groups have
        zero spread but different means (perfect separation).
    """
    g1 = np.asarray(group1, dtype=float)
    g2 = np.asarray(group2, dtype=float)
    n1, n2 = g1.size, g2.size
    dof = n1 + n2 - 2
    if n1 == 0 or n2 == 0 or dof <= 0:
        return np.nan

    mean1, mean2 = g1.mean(), g2.mean()
    # Sum of squared deviations == (n - 1) * sample variance, but is 0 (not NaN) for n == 1.
    ss1 = np.sum((g1 - mean1) ** 2)
    ss2 = np.sum((g2 - mean2) ** 2)
    pooled_std = np.sqrt((ss1 + ss2) / dof)
    diff = mean1 - mean2

    if pooled_std == 0:
        # No within-group spread: no effect if the means coincide, otherwise the
        # groups are perfectly separated.
        return 0.0 if diff == 0 else float(np.copysign(np.inf, diff))
    return diff / pooled_std

def load_and_prepare_data(
    clinical_path="/Users/heweilin/Desktop/P056/7Clinical_data50.csv",
    dmr_path="/Users/heweilin/Desktop/P056/4DNA_DMRs.csv",
    methylation_path="/Users/heweilin/Desktop/P056_Code_2/Processed_Data/1_PD_PromoterRegion_CpGs.csv",
    output_dir="/Users/heweilin/Desktop/P056_Code_3/Data",
):
    """Load the clinical, methylation and DMR tables and align their samples.

    File locations default to the original hard-coded paths. They can be
    overridden so the pipeline runs on another machine or on small fixtures.

    Parameters
    ----------
    clinical_path : str
        CSV with at least ``DNA_ID`` and ``B12_status`` columns.
    dmr_path : str
        CSV with DMR results (loaded and returned; not used for alignment).
    methylation_path : str
        CSV with a ``Gene_Symbol`` column plus one column per sample, where the
        sample columns are named by DNA ID.
    output_dir : str
        Directory for result files; created if it does not exist.

    Returns
    -------
    tuple
        ``(clinical_data, methylation_features, dmr_data, output_dir)``.
        ``clinical_data`` is a copy holding only samples present in the
        methylation table. ``methylation_features`` holds ``Gene_Symbol`` plus
        those sample columns, in methylation-file column order.
    """
    print("="*60)
    print("LOADING AND PREPARING DATA")
    print("="*60)

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory created: {output_dir}")

    print("\n1. Loading clinical data...")
    clinical_data = pd.read_csv(clinical_path)
    print(f"Clinical data shape: {clinical_data.shape}")
    print("B12 status distribution:")
    print(clinical_data['B12_status'].value_counts())

    print("\n2. Loading methylation data...")
    methylation_data = pd.read_csv(methylation_path)
    print(f"Methylation data shape: {methylation_data.shape}")

    # DMR table is loaded for mechanism-driven work and returned unchanged.
    print("\n3. Loading DMR data...")
    dmr_data = pd.read_csv(dmr_path)
    print(f"DMR data shape: {dmr_data.shape}")

    dna_ids = clinical_data['DNA_ID'].tolist()
    print(f"Number of DNA samples: {len(dna_ids)}")

    print("\n4. Aligning methylation data with clinical samples...")
    # Set lookup instead of scanning the ID list once per methylation column.
    clinical_id_set = set(dna_ids)
    available_dna_ids = [col for col in methylation_data.columns if col in clinical_id_set]
    print(f"Available DNA IDs in methylation data: {len(available_dna_ids)}")

    # .copy() so later column assignments act on an independent frame, not a slice.
    clinical_data = clinical_data[clinical_data['DNA_ID'].isin(available_dna_ids)].copy()
    n_dropped = len(dna_ids) - len(clinical_data)
    if n_dropped > 0:
        print(f"Dropped {n_dropped} clinical samples without methylation data")
    print(f"Final clinical data shape after alignment: {clinical_data.shape}")

    methylation_features = methylation_data[['Gene_Symbol'] + available_dna_ids]
    print(f"Final methylation data shape: {methylation_features.shape}")

    return clinical_data, methylation_features, dmr_data, output_dir

def prepare_clinical_features(clinical_data):
    """Select, impute and encode the clinical covariates used by the models.

    Parameters
    ----------
    clinical_data : pandas.DataFrame
        Must contain ``DNA_ID`` and ``B12_status``. Configured covariates that
        are absent are reported and skipped.

    Returns
    -------
    tuple
        ``(clinical_subset, clinical_features)``. ``clinical_subset`` is a copy
        holding the selected covariates plus ``DNA_ID`` and ``B12_status``,
        with missing values imputed. ``clinical_features`` lists the covariate
        column names actually used.

    Notes
    -----
    Imputation statistics (median / mode) are computed over all rows passed in,
    i.e. before any train/test split.
    """
    print("\n" + "="*60)
    print("PREPARING CLINICAL FEATURES")
    print("="*60)

    # Core covariates for the main model (actual column names in the clinical CSV).
    clinical_features = [
        'age',                    # Age (continuous)
        'BMI',                    # BMI (continuous, not BMI_cat)
        'parity',                 # Parity (recoded to a 0/1 indicator below)
        'B12_supplementation',    # B12 supplementation (binary)
        'Folic_acid_use',         # Folic acid use (binary)
        'Multivitamin_use',       # Multivitamin use (binary)
        'smoking'                 # Smoking (binary)
    ]
    # Covariates treated as 0/1 (mode-imputed, then cast to int). These names must
    # match the list above. Otherwise a binary column falls through to median
    # imputation and can end up with non-binary values such as 0.5.
    binary_features = {'B12_supplementation', 'Folic_acid_use', 'Multivitamin_use', 'smoking'}

    print(f"Selected clinical features (using actual column names): {clinical_features}")
    print("\nExcluded variables:")
    print("  - BMI_cat: Removed (redundant with continuous BMI)")
    print("  - ethnicity: Removed (sample imbalance concern)")
    print("  - batch_*: Removed (already corrected in preprocessing)")
    print("  - v3n_Gender: Removed (all female cohort)")

    existing_features = [f for f in clinical_features if f in clinical_data.columns]
    missing_features = [f for f in clinical_features if f not in clinical_data.columns]

    if missing_features:
        print(f"\nWarning: Missing features in dataset: {missing_features}")
        print("Available clinical columns:")
        clinical_cols = [col for col in clinical_data.columns if col not in ['DNA_ID', 'B12_status']]
        print(f"  {clinical_cols}")

    clinical_features = existing_features
    print(f"\nFinal selected features: {clinical_features}")

    # Work on a copy so the caller's frame is never modified.
    clinical_subset = clinical_data[clinical_features + ['DNA_ID', 'B12_status']].copy()

    print("\nMissing values before imputation:")
    missing_counts = clinical_subset.isnull().sum()
    for feature in clinical_features:
        if missing_counts[feature] > 0:
            print(f"  {feature}: {missing_counts[feature]} missing")
        else:
            print(f"  {feature}: no missing values")

    # Every branch assigns the result back to the column. A chained
    # `df[col].fillna(..., inplace=True)` is a no-op under pandas Copy-on-Write.
    for col in clinical_features:
        values = clinical_subset[col]
        missing_before = values.isnull().sum()

        if col == 'parity':
            median_val = values.median()
            # Impute first, then recode: 1 = parous (parity > 0), 0 = nulliparous.
            clinical_subset[col] = (values.fillna(median_val) > 0).astype(int)
            print("Parity: imputed missing values, then recoded (1=parous, parity>0; 0=nulliparous)")

        elif col in binary_features:
            modes = values.mode()
            mode_val = modes.iloc[0] if len(modes) > 0 else 0
            clinical_subset[col] = values.fillna(mode_val).astype(int)
            if missing_before > 0:
                print(f"{col}: imputed {missing_before} missing values with mode ({mode_val}), then binary encoding confirmed")
            else:
                print(f"{col}: binary encoding confirmed (no missing values)")

        elif pd.api.types.is_numeric_dtype(values) and not pd.api.types.is_bool_dtype(values):
            # Continuous variables (age, BMI): median imputation.
            median_val = values.median()
            clinical_subset[col] = values.fillna(median_val)
            if missing_before > 0:
                print(f"Imputed {col} ({missing_before} values) with median: {median_val:.2f}")

        else:
            # Any other (categorical) variable: mode imputation.
            modes = values.mode()
            mode_val = modes.iloc[0] if len(modes) > 0 else 0
            clinical_subset[col] = values.fillna(mode_val)
            if missing_before > 0:
                print(f"Imputed {col} ({missing_before} values) with mode: {mode_val}")

    print("\nFinal feature summary:")
    for feature in clinical_features:
        values = clinical_subset[feature]
        if pd.api.types.is_numeric_dtype(values) and not pd.api.types.is_bool_dtype(values):
            val_range = f"[{values.min():.2f}, {values.max():.2f}]"
        else:
            # Non-numeric columns cannot be formatted with :.2f.
            val_range = "[non-numeric]"
        print(f"  {feature}: {values.nunique()} unique values, range {val_range}")

    remaining_missing = clinical_subset[clinical_features].isnull().sum().sum()
    print(f"\nTotal remaining missing values: {remaining_missing}")

    return clinical_subset, clinical_features

def data_driven_feature_selection(methylation_data, clinical_data, train_indices, n_features=20,
                                  p_threshold=0.05, min_effect_size=0.3):
    """Rank methylation features by LB-vs-NB separation on the training rows only.

    Each column of the transposed methylation table gets a two-sided
    Mann-Whitney U test and |Cohen's d| between the LB and NB training samples.
    The composite score is ``|d| * -log10(p)``. Features passing both
    thresholds are ranked by that score and the top ``n_features`` are kept.

    Parameters
    ----------
    methylation_data : pandas.DataFrame
        ``Gene_Symbol`` column plus one column per sample (named by DNA ID).
    clinical_data : pandas.DataFrame
        Rows addressed positionally by ``train_indices``; must contain
        ``DNA_ID`` and ``B12_status`` ('LB' / 'NB').
    train_indices : sequence of int
        Positional row indices of the training split.
    n_features : int, default 20
        Maximum number of features returned.
    p_threshold : float, default 0.05
        Keep features with Mann-Whitney p strictly below this value.
    min_effect_size : float, default 0.3
        Keep features with |Cohen's d| strictly above this value.

    Returns
    -------
    tuple
        ``(selected_features, stats_df)``: the chosen feature names, and the
        per-feature statistics table sorted by composite score. The table is
        empty, but still has its usual columns, when no feature could be tested.

    Notes
    -----
    p-values are not corrected for multiple testing.
    """
    print("\n" + "="*50)
    print("DATA-DRIVEN FEATURE SELECTION")
    print("="*50)

    # Only training rows are used, so the test split never influences selection.
    train_clinical = clinical_data.iloc[train_indices]
    train_dna_ids = train_clinical['DNA_ID'].tolist()

    # Samples become rows and Gene_Symbol entries become columns.
    # NOTE(review): assumes Gene_Symbol values are unique; a duplicated symbol
    # would make lb_data[gene] below a DataFrame instead of a Series — confirm.
    methylation_df = methylation_data.set_index('Gene_Symbol').T
    methylation_df = methylation_df.loc[train_dna_ids]

    # Boolean masks indexed by DNA_ID, aligned with methylation_df's rows.
    train_clinical_aligned = train_clinical.set_index('DNA_ID').loc[train_dna_ids]
    lb_mask = train_clinical_aligned['B12_status'] == 'LB'
    nb_mask = train_clinical_aligned['B12_status'] == 'NB'

    lb_data = methylation_df[lb_mask]
    nb_data = methylation_df[nb_mask]

    print(f"LB samples: {lb_data.shape[0]}, NB samples: {nb_data.shape[0]}")

    feature_stats = []
    for gene in methylation_df.columns:
        lb_values = lb_data[gene].dropna()
        nb_values = nb_data[gene].dropna()
        if len(lb_values) == 0 or len(nb_values) == 0:
            continue  # cannot compare without observations in both groups

        _, p_value = mannwhitneyu(lb_values, nb_values, alternative='two-sided')
        cohens_d = abs(calculate_cohens_d(lb_values, nb_values))

        # |d| weighted by -log10(p); an underflowed p of exactly 0 gets a fixed weight of 10.
        if p_value > 0:
            composite_score = cohens_d * (-np.log10(p_value))
        else:
            composite_score = cohens_d * 10

        feature_stats.append({
            'gene': gene,
            'p_value': p_value,
            'cohens_d': cohens_d,
            'composite_score': composite_score,
            'lb_mean': np.mean(lb_values),
            'nb_mean': np.mean(nb_values)
        })

    # Explicit columns keep sort/filter working even when no feature was tested.
    stat_columns = ['gene', 'p_value', 'cohens_d', 'composite_score', 'lb_mean', 'nb_mean']
    stats_df = pd.DataFrame(feature_stats, columns=stat_columns)
    stats_df = stats_df.sort_values('composite_score', ascending=False)

    significant_features = stats_df[
        (stats_df['p_value'] < p_threshold) &
        (stats_df['cohens_d'] > min_effect_size)
    ]

    print(f"Features with p < {p_threshold} and Cohen's d > {min_effect_size}: {len(significant_features)}")

    selected_features = significant_features.head(n_features)['gene'].tolist()

    print(f"Selected top {len(selected_features)} features:")
    for i, feature in enumerate(selected_features[:10], 1):  # Print first 10
        stats = significant_features[significant_features['gene'] == feature].iloc[0]
        print(f"  {i:2d}. {feature}: p={stats['p_value']:.3e}, d={stats['cohens_d']:.3f}, score={stats['composite_score']:.3f}")

    if len(selected_features) > 10:
        print(f"  ... and {len(selected_features) - 10} more features")

    return selected_features, stats_df

def mechanism_driven_feature_selection():
    """Return the fixed list of mechanism-driven methylation features.

    The features are the genes of ten predefined CpG-mRNA negative regulatory
    pairs, ordered by increasing FDR. The regression beta, FDR p-value and R²
    of each pair are printed as a table before the gene list is returned.

    Returns
    -------
    list of str
        Gene symbols, in the table's order.
    """
    print("\n" + "="*50)
    print("MECHANISM-DRIVEN FEATURE SELECTION")
    print("="*50)

    # One record per CpG-mRNA pair: (gene, beta, FDR p-value, R²).
    # The returned feature list is derived from this single table.
    regulatory_pairs = (
        ('IQCG', -0.889, 6.11e-11, 0.785),
        ('ZNF154', -0.923, 3.97e-8, 0.727),
        ('TSTD1', -0.771, 3.97e-8, 0.702),
        ('GTSF1', -0.804, 3.97e-8, 0.735),
        ('SEPTIN7P11', -0.852, 7.70e-8, 0.704),
        ('NUDT10', -0.920, 2.82e-7, 0.717),
        ('NAA11', -0.722, 6.75e-7, 0.684),
        ('SYCP3', -0.692, 2.35e-6, 0.751),
        ('PSMA8', -0.716, 2.26e-5, 0.591),
        ('LINC00667', -0.797, 2.26e-5, 0.550),
    )
    selected = [pair[0] for pair in regulatory_pairs]

    divider = "-" * 70
    print(f"Pre-defined CpG-mRNA negative regulatory pairs ({len(selected)}):")
    print(divider)
    print(f"{'Rank':<4} {'Gene':<12} {'Beta':<8} {'FDR p-value':<12} {'R²':<8}")
    print(divider)

    for rank, (symbol, beta, fdr, r_squared) in enumerate(regulatory_pairs, start=1):
        print(f"{rank:2d}.  {symbol:<12} {beta:<8.3f} {fdr:<12.2e} {r_squared:<8.3f}")

    print("\nAll regulatory pairs show large effect sizes and significant negative correlation.")

    return selected

def prepare_feature_matrix(clinical_data, methylation_data, selected_methylation_features, clinical_features):
    """Build the model input matrix and binary target for every clinical sample.

    Parameters
    ----------
    clinical_data : pandas.DataFrame
        Must contain ``DNA_ID``, ``B12_status`` and the ``clinical_features`` columns.
    methylation_data : pandas.DataFrame
        ``Gene_Symbol`` column plus one column per sample (named by DNA ID).
    selected_methylation_features : list of str
        Gene_Symbol entries to include as features.
    clinical_features : list of str
        Clinical covariate columns to include.

    Returns
    -------
    tuple
        ``(feature_matrix, target_encoded, dna_ids)``:
        - ``feature_matrix``: a DataFrame indexed by DNA ID, with the clinical
          columns followed by the methylation columns.
        - ``target_encoded``: a NumPy int array with LB=1 and NB=0.
        - ``dna_ids``: the DNA IDs in row order.

    Raises
    ------
    ValueError
        If ``B12_status`` contains a value other than 'LB' or 'NB'.
    """
    print("\n" + "="*50)
    print("PREPARING FEATURE MATRIX")
    print("="*50)

    dna_ids = clinical_data['DNA_ID'].tolist()

    # Samples as rows, Gene_Symbol entries as columns, restricted to the selection.
    methylation_df = methylation_data.set_index('Gene_Symbol').T
    methylation_subset = methylation_df.loc[dna_ids, selected_methylation_features]

    # NOTE(review): medians are computed over every sample passed in, including
    # rows that later form the test split (mild leakage); consider fitting on
    # training rows only.
    methylation_subset = methylation_subset.fillna(methylation_subset.median())

    clinical_subset = clinical_data.set_index('DNA_ID')[clinical_features].loc[dna_ids]

    feature_matrix = pd.concat([clinical_subset, methylation_subset], axis=1)

    target = clinical_data.set_index('DNA_ID')['B12_status'].loc[dna_ids]

    # Explicit mapping (LB=1, NB=0). LabelEncoder orders classes alphabetically,
    # which would give LB=0 / NB=1. That contradicts the "LB=" count printed below
    # and the NB/LB names used when reporting class-1 precision and recall.
    target_map = {'NB': 0, 'LB': 1}
    unexpected = target[~target.isin(list(target_map))].unique()
    if len(unexpected) > 0:
        raise ValueError(
            f"Unexpected B12_status values {list(unexpected)}; expected 'LB' or 'NB'"
        )
    target_encoded = target.map(target_map).to_numpy(dtype=int)

    print(f"Final feature matrix shape: {feature_matrix.shape}")
    print(f"Clinical features: {len(clinical_features)}")
    print(f"Methylation features: {len(selected_methylation_features)}")
    print(f"Target distribution: LB={sum(target_encoded)}, NB={len(target_encoded)-sum(target_encoded)}")

    return feature_matrix, target_encoded, dna_ids

def train_random_forest(X_train, y_train, cv_splits=3):
    """Tune a class-balanced Random Forest by exhaustive grid search.

    Parameters
    ----------
    X_train : array-like or pandas.DataFrame
        Training features.
    y_train : array-like
        Binary training labels.
    cv_splits : int, default 3
        Number of stratified, shuffled cross-validation folds used to score
        each hyperparameter combination (ROC AUC).

    Returns
    -------
    RandomForestClassifier
        ``best_estimator_`` of the grid search (refit on all of ``X_train``
        under GridSearchCV's default ``refit=True``).
    """
    print("\n" + "="*40)
    print("RANDOM FOREST TRAINING")
    print("="*40)

    # Compact grid so the exhaustive search stays affordable on a small cohort.
    param_grid = {
        'n_estimators': [100, 200, 300],
        'max_depth': [3, 5, 10, None],
        'min_samples_split': [2, 5],
        'min_samples_leaf': [1, 2],
        'max_features': ['sqrt', 'log2'],
    }

    total_combinations = int(np.prod([len(values) for values in param_grid.values()]))

    print(f"Optimized hyperparameter search space ({total_combinations} combinations):")
    for param, values in param_grid.items():
        print(f"  {param}: {values}")

    rf = RandomForestClassifier(
        random_state=42,
        class_weight='balanced',
        bootstrap=True,
        n_jobs=1  # single process to avoid resource conflicts
    )

    print(f"\nPerforming grid search with {cv_splits}-fold CV ({total_combinations * cv_splits} total fits)...")
    cv = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=42)

    grid_search = GridSearchCV(
        rf, param_grid, cv=cv, scoring='roc_auc',
        n_jobs=1, verbose=0
    )

    print("Training Random Forest model...")
    grid_search.fit(X_train, y_train)

    print(f"Best parameters: {grid_search.best_params_}")
    print(f"Best cross-validation AUC: {grid_search.best_score_:.4f}")

    return grid_search.best_estimator_

def evaluate_model(model, X_test, y_test):
    """Evaluate a fitted binary classifier on a held-out set.

    Labels are expected to be encoded as 0 = NB and 1 = LB; the printed class
    names follow that order. AUC uses the predicted probability of class 1.
    Precision, recall and F1 are reported for class 1.

    Parameters
    ----------
    model : fitted classifier with ``predict`` and ``predict_proba``
    X_test : array-like or pandas.DataFrame
        Held-out features.
    y_test : array-like of 0/1
        Held-out labels.

    Returns
    -------
    tuple
        ``(metrics, y_pred, y_pred_proba)``. ``metrics`` maps 'auc',
        'accuracy', 'precision', 'recall' and 'f1' to floats; 'auc' is NaN
        when it cannot be computed.
    """
    print("\n" + "="*40)
    print("MODEL EVALUATION")
    print("="*40)

    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, zero_division=0)
    recall = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)

    # AUC is undefined when the test set holds only one class.
    try:
        auc = roc_auc_score(y_test, y_pred_proba)
    except ValueError:
        auc = np.nan
        print("Warning: AUC calculation failed (only one class in test set)")

    metrics = {
        'auc': auc,          # Primary metric
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

    print("Model performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score:")
    print("-" * 50)
    for metric, value in metrics.items():
        if not np.isnan(value):
            print(f"  {metric.upper()}: {value:.4f}")
        else:
            print(f"  {metric.upper()}: N/A")

    # Fix the label set so the report and the 2x2 matrix stay well-formed even
    # when a small test set (or the predictions) contains only one class.
    class_labels = [0, 1]

    print("\nDetailed Classification Report:")
    print(classification_report(
        y_test, y_pred, labels=class_labels, target_names=['NB', 'LB'], zero_division=0
    ))

    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    print("\nConfusion Matrix:")
    print("     Pred NB  Pred LB")
    print(f"True NB   {cm[0,0]:2d}      {cm[0,1]:2d}")
    print(f"True LB   {cm[1,0]:2d}      {cm[1,1]:2d}")

    return metrics, y_pred, y_pred_proba

def analyze_feature_importance(model, feature_names, X_test, y_test, top_n=10,
                               n_repeats=10, n_jobs=1, scoring=None):
    """Tabulate Gini and permutation importance for a fitted tree ensemble.

    Parameters
    ----------
    model : fitted estimator exposing ``feature_importances_``
    feature_names : list of str
        Column names, in the same order as the model's input features.
    X_test, y_test : array-like
        Held-out data used for permutation importance.
    top_n : int, default 10
        Number of rows printed.
    n_repeats : int, default 10
        Shuffles per feature for permutation importance.
    n_jobs : int, default 1
        Parallel jobs for permutation importance. Single-process by default,
        matching the ``n_jobs=1`` used for model fitting in this file.
    scoring : str, callable or None, default None
        Scorer for permutation importance; None uses the estimator's ``score``.

    Returns
    -------
    pandas.DataFrame
        One row per feature with 'feature', 'gini_importance',
        'perm_importance_mean' and 'perm_importance_std', sorted by Gini
        importance in descending order.
    """
    print("\n" + "="*40)
    print("FEATURE IMPORTANCE ANALYSIS")
    print("="*40)

    gini_importance = model.feature_importances_

    # Small test sets make permutation importance noisy; the std column shows this.
    print("Calculating permutation importance...")
    perm_importance = permutation_importance(
        model, X_test, y_test, n_repeats=n_repeats, random_state=42,
        n_jobs=n_jobs, scoring=scoring
    )

    importance_df = pd.DataFrame({
        'feature': feature_names,
        'gini_importance': gini_importance,
        'perm_importance_mean': perm_importance.importances_mean,
        'perm_importance_std': perm_importance.importances_std
    })
    importance_df = importance_df.sort_values('gini_importance', ascending=False)

    print(f"\nTop {top_n} most important features:")
    print("-" * 70)
    print(f"{'Rank':<4} {'Feature':<20} {'Gini Importance':<15} {'Perm Mean':<10} {'Perm Std':<8}")
    print("-" * 70)

    for rank, row in enumerate(importance_df.head(top_n).itertuples(index=False), 1):
        print(f"{rank:<4} {row.feature:<20} {row.gini_importance:<15.4f} "
              f"{row.perm_importance_mean:<10.4f} {row.perm_importance_std:<8.4f}")

    return importance_df

def run_multiple_splits(clinical_data, methylation_data, feature_selection_method='data_driven', n_splits=20, fast_mode=True):
    """Repeat stratified 80/20 train/test splits and collect per-split results.

    Parameters
    ----------
    clinical_data : pandas.DataFrame
        Clinical table with ``DNA_ID`` and ``B12_status``.
    methylation_data : pandas.DataFrame
        ``Gene_Symbol`` column plus one column per sample (DNA ID).
    feature_selection_method : {'data_driven', 'mechanism_driven'}
        'data_driven' selects methylation features on each split's training
        rows. 'mechanism_driven' uses the fixed predefined gene list.
    n_splits : int, default 20
        Number of random splits; split ``i`` uses ``random_state=i``.
    fast_mode : bool, default True
        If True, fit a fixed-parameter forest instead of running the grid
        search, and select 15 rather than 20 data-driven features.

    Returns
    -------
    dict
        Per-split lists under 'auc', 'accuracy', 'precision', 'recall', 'f1',
        'selected_features' and 'feature_importance'.

    Raises
    ------
    ValueError
        If ``feature_selection_method`` is not one of the supported values
        (previously any other string silently ran the mechanism-driven path).
    """
    valid_methods = ('data_driven', 'mechanism_driven')
    if feature_selection_method not in valid_methods:
        raise ValueError(
            f"feature_selection_method must be one of {valid_methods}, "
            f"got {feature_selection_method!r}"
        )

    print("\n" + "="*60)
    print(f"RUNNING {n_splits} RANDOM SPLITS - {feature_selection_method.upper()}")
    if fast_mode:
        print("FAST MODE: Optimized for speed")
    print("="*60)

    clinical_subset, clinical_features = prepare_clinical_features(clinical_data)

    # Gene symbols present in the methylation table, built once for O(1) lookups.
    available_genes = set(methylation_data['Gene_Symbol'])

    # The mechanism-driven list does not depend on the split, so fetch it once.
    fixed_features = (mechanism_driven_feature_selection()
                      if feature_selection_method == 'mechanism_driven' else None)

    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']
    results = {name: [] for name in metric_names}
    results['selected_features'] = []
    results['feature_importance'] = []

    for split_idx in range(n_splits):
        print(f"\n{'='*20} SPLIT {split_idx + 1}/{n_splits} {'='*20}")

        # Positional indices; used with .iloc / NumPy indexing below.
        train_indices, test_indices = train_test_split(
            range(len(clinical_subset)),
            test_size=0.2,
            stratify=clinical_subset['B12_status'],
            random_state=split_idx
        )

        print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

        if fixed_features is None:
            n_features = 15 if fast_mode else 20  # Reduce features in fast mode
            selected_features, _ = data_driven_feature_selection(
                methylation_data, clinical_subset, train_indices, n_features=n_features
            )
        else:
            selected_features = list(fixed_features)

        available_features = [f for f in selected_features if f in available_genes]
        if len(available_features) != len(selected_features):
            missing_count = len(selected_features) - len(available_features)
            print(f"Note: {missing_count} features not found in methylation data")

        selected_features = available_features
        results['selected_features'].append(selected_features)

        X, y, _ = prepare_feature_matrix(
            clinical_subset, methylation_data, selected_features, clinical_features
        )

        X_train, X_test = X.iloc[train_indices], X.iloc[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        print(f"Training set: {X_train.shape}, Test set: {X_test.shape}")

        if fast_mode:
            # Fixed hyperparameters instead of a grid search.
            model = RandomForestClassifier(
                n_estimators=200, max_depth=5, min_samples_split=2,
                min_samples_leaf=1, max_features='sqrt',
                random_state=42, class_weight='balanced', n_jobs=1
            )
            print("Using fast mode: fixed parameters (n_estimators=200, max_depth=5)")
            model.fit(X_train, y_train)
        else:
            model = train_random_forest(X_train, y_train)

        metrics, _, _ = evaluate_model(model, X_test, y_test)

        for metric in metric_names:
            results[metric].append(metrics[metric])

        importance_df = analyze_feature_importance(
            model, X.columns.tolist(), X_test, y_test, top_n=10
        )
        results['feature_importance'].append(importance_df)

        auc_str = f"{metrics['auc']:.4f}" if not np.isnan(metrics['auc']) else 'N/A'
        print(f"Split {split_idx + 1} completed - Primary metric (AUC): {auc_str}")

    return results

def analyze_stability_and_performance(results, method_name, output_dir):
    """Summarize metrics, feature-selection stability and importance across splits.

    Prints the summary, writes ``RF_<method_name>_results_summary.txt`` (UTF-8)
    into ``output_dir`` and returns the summary as a dict.

    Parameters
    ----------
    results : dict
        Per-split results with lists under 'auc', 'accuracy', 'precision',
        'recall', 'f1', 'selected_features' (lists of feature names) and
        'feature_importance' (DataFrames with 'feature' and 'gini_importance').
    method_name : str
        Label used in headings and the output file name. 'MechanismDriven'
        (case-insensitive) also prints the regulatory-beta comparison.
    output_dir : str
        Existing directory for the summary file.

    Returns
    -------
    dict
        Keys 'method', 'performance_stats', 'feature_stability',
        'average_importance' and 'top_10_features'.
    """
    print("\n" + "="*60)
    print(f"STABILITY AND PERFORMANCE ANALYSIS - {method_name}")
    print("="*60)

    # Regression betas of the predefined CpG-mRNA pairs (mechanism-driven comparison).
    original_regulatory_effects = {
        'IQCG': -0.889,
        'ZNF154': -0.923,
        'TSTD1': -0.771,
        'GTSF1': -0.804,
        'SEPTIN7P11': -0.852,
        'NUDT10': -0.920,
        'NAA11': -0.722,
        'SYCP3': -0.692,
        'PSMA8': -0.716,
        'LINC00667': -0.797
    }

    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']
    # Derive the split count from the results instead of hard-coding 20.
    n_splits = len(results['auc'])

    print("\nModel performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score.")
    print(f"\nPerformance Statistics Across {n_splits} Splits:")
    print("-" * 60)
    print(f"{'Metric':<12} {'Mean':<8} {'Std':<8} {'Min':<8} {'Max':<8}")
    print("-" * 60)

    performance_stats = {}
    for metric in metric_names:
        # NaN entries (e.g. undefined AUC on a single-class test set) are skipped.
        values = [v for v in results[metric] if not np.isnan(v)]
        if values:
            mean_val = np.mean(values)
            std_val = np.std(values)
            min_val = np.min(values)
            max_val = np.max(values)

            performance_stats[metric] = {
                'mean': mean_val, 'std': std_val, 'min': min_val, 'max': max_val
            }

            print(f"{metric.upper():<12} {mean_val:<8.4f} {std_val:<8.4f} "
                  f"{min_val:<8.4f} {max_val:<8.4f}")

    # Feature selection stability: % of splits in which each feature was selected.
    print("\nFeature Selection Stability:")
    print("-" * 40)

    all_features = [f for features in results['selected_features'] for f in features]
    feature_counts = pd.Series(all_features).value_counts()
    selection_frequency = (feature_counts / len(results['selected_features']) * 100).round(2)

    high_stability = selection_frequency[selection_frequency >= 70]
    medium_stability = selection_frequency[(selection_frequency >= 50) & (selection_frequency < 70)]

    print(f"High Stability Features (≥70% selection frequency): {len(high_stability)}")
    for feature, freq in high_stability.head(10).items():
        print(f"  {feature}: {freq}%")

    if len(high_stability) > 10:
        print(f"  ... and {len(high_stability) - 10} more")

    print(f"\nMedium Stability Features (50-69% selection frequency): {len(medium_stability)}")
    for feature, freq in medium_stability.head(5).items():
        print(f"  {feature}: {freq}%")

    print("\nAverage Feature Importance Across All Splits:")
    print("-" * 50)

    # Gini importances per feature, gathered over the splits where it was present.
    all_importance = defaultdict(list)
    for importance_df in results['feature_importance']:
        for feature, importance in zip(importance_df['feature'], importance_df['gini_importance']):
            all_importance[feature].append(importance)

    avg_importance = {
        feature: {
            'mean': np.mean(importances),
            'std': np.std(importances),
            'frequency': len(importances)
        }
        for feature, importances in all_importance.items()
    }

    sorted_importance = sorted(avg_importance.items(), key=lambda x: x[1]['mean'], reverse=True)

    print(f"\nTOP 10 MOST IMPORTANT FEATURES - {method_name}:")
    print("=" * 60)
    print(f"{'Rank':<4} {'Feature':<20} {'Avg Importance':<15} {'Std':<8} {'Frequency':<10}")
    print("-" * 60)
    for i, (feature, stats) in enumerate(sorted_importance[:10], 1):
        print(f"{i:<4} {feature:<20} {stats['mean']:<15.4f} {stats['std']:<8.4f} "
              f"{stats['frequency']:<10}")

    print("\nComplete Feature Importance Ranking:")
    print(f"{'Feature':<20} {'Avg Importance':<15} {'Std':<8} {'Frequency':<10}")
    print("-" * 50)
    for feature, stats in sorted_importance[:15]:
        print(f"{feature:<20} {stats['mean']:<15.4f} {stats['std']:<8.4f} "
              f"{stats['frequency']:<10}")

    if method_name.lower() == 'mechanismdriven':
        print("\nMechanism-Driven Specific Analysis:")
        print("-" * 50)
        print("Original Regulatory Effect vs Model Importance Consistency:")
        print(f"{'Gene':<12} {'Original Beta':<12} {'RF Importance':<12} {'Rank':<6}")
        print("-" * 50)

        mechanism_genes = list(original_regulatory_effects.keys())
        for gene in mechanism_genes:
            if gene in avg_importance:
                orig_beta = original_regulatory_effects[gene]
                rf_importance = avg_importance[gene]['mean']
                rank = next((i+1 for i, (f, _) in enumerate(sorted_importance) if f == gene), 'N/A')
                print(f"{gene:<12} {orig_beta:<12.3f} {rf_importance:<12.4f} {rank:<6}")

        available_mechanism_genes = [g for g in mechanism_genes if g in avg_importance]
        print("\nMechanism Gene Analysis Summary:")
        print(f"  Total mechanism genes: {len(mechanism_genes)}")
        print(f"  Available in model: {len(available_mechanism_genes)}")

        if available_mechanism_genes:
            # Do the most important mechanism genes coincide with the strongest betas?
            top_rf_genes = [f for f, _ in sorted_importance[:5] if f in mechanism_genes]
            strong_regulatory_genes = [g for g, beta in original_regulatory_effects.items()
                                       if abs(beta) > 0.8 and g in available_mechanism_genes]

            print(f"  Top 5 RF importance (mechanism genes): {top_rf_genes}")
            print(f"  Strong regulatory effects (|β|>0.8): {strong_regulatory_genes}")

            overlap = set(top_rf_genes) & set(strong_regulatory_genes)
            print(f"  Overlap between top RF and strong regulatory: {list(overlap)}")

    results_summary = {
        'method': method_name,
        'performance_stats': performance_stats,
        'feature_stability': selection_frequency.to_dict(),
        'average_importance': dict(sorted_importance),
        'top_10_features': sorted_importance[:10]
    }

    # Explicit UTF-8: the report contains '±' and '≥', which the platform default
    # encoding (e.g. cp1252 on Windows) cannot always represent.
    output_file = os.path.join(output_dir, f"RF_{method_name}_results_summary.txt")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Random Forest Results Summary - {method_name}\n")
        f.write("=" * 60 + "\n\n")

        f.write("Model performance was evaluated using AUC as the primary metric,\n")
        f.write("complemented by accuracy, precision, recall, and F1 score.\n\n")

        f.write("Performance Statistics:\n")
        f.write("-" * 30 + "\n")
        for metric in metric_names:
            if metric in performance_stats:
                stats = performance_stats[metric]
                f.write(f"{metric.upper()}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        f.write(f"\nTOP 10 MOST IMPORTANT FEATURES - {method_name}:\n")
        f.write("=" * 50 + "\n")
        for i, (feature, stats) in enumerate(sorted_importance[:10], 1):
            f.write(f"{i:2d}. {feature}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        f.write("\nHigh Stability Features (≥70%):\n")
        f.write("-" * 30 + "\n")
        for feature, freq in high_stability.items():
            f.write(f"{feature}: {freq}%\n")

    print(f"\nResults saved to: {output_file}")

    return results_summary

def _run_approach(clinical_data, methylation_data, output_dir, method, label,
                  title, display_name, n_splits, fast_mode):
    """Run one feature-selection approach end to end and summarise it.

    Args:
        clinical_data, methylation_data: inputs forwarded to
            ``run_multiple_splits``.
        output_dir: directory passed to ``analyze_stability_and_performance``.
        method: ``feature_selection_method`` value for ``run_multiple_splits``.
        label: method name used by ``analyze_stability_and_performance``.
        title: upper-case banner text (e.g. ``'DATA-DRIVEN'``).
        display_name: name used in the timing message (e.g. ``'Data-driven'``).
        n_splits, fast_mode: forwarded to ``run_multiple_splits``.

    Returns:
        tuple: ``(summary, elapsed_seconds)``. The elapsed time covers only the
        repeated-split modelling, not the stability analysis.
    """
    import time

    print("\n" + "=" * 20)
    print(f"RUNNING {title} APPROACH")
    print("=" * 20)

    start = time.time()
    results = run_multiple_splits(
        clinical_data, methylation_data,
        feature_selection_method=method,
        n_splits=n_splits, fast_mode=fast_mode
    )
    elapsed = time.time() - start
    print(f"\n{display_name} analysis completed in {elapsed/60:.1f} minutes")

    summary = analyze_stability_and_performance(results, label, output_dir)
    return summary, elapsed


def _shared_metric_means(dd_summary, md_summary):
    """Return ``[(metric, dd_mean, md_mean), ...]`` for metrics both summaries report."""
    dd_stats = dd_summary['performance_stats']
    md_stats = md_summary['performance_stats']
    return [
        (metric, dd_stats[metric]['mean'], md_stats[metric]['mean'])
        for metric in ('auc', 'accuracy', 'precision', 'recall', 'f1')
        if metric in dd_stats and metric in md_stats
    ]


def _top10_cell(top_features, rank):
    """Format one 'feature  importance' cell of the top-10 table (1-based rank)."""
    if rank <= len(top_features):
        feature, stats = top_features[rank - 1]
        return f"{feature:<25} {stats['mean']:<12.4f}"
    # Fewer than `rank` features: leave the cell blank instead of printing a
    # fabricated 0.0000 importance.
    return f"{'':<25} {'':<12}"


def main(n_splits=20, fast_mode=True):
    """Run the data-driven and mechanism-driven RF pipelines and compare them.

    Prints a performance comparison and a side-by-side top-10 feature table,
    then writes ``RF_execution_summary.txt`` into the output directory
    returned by ``load_and_prepare_data``.

    Args:
        n_splits: random train/test splits per approach (default 20).
        fast_mode: forwarded to ``run_multiple_splits`` (default True).
    """
    import time
    start_time = time.time()

    print("RANDOM FOREST MODEL FOR B12 DEFICIENCY PREDICTION")
    print("=" * 60)
    if fast_mode:
        print("Running in optimized mode for faster execution")

    # Load data (the DMR table is not used by this pipeline)
    clinical_data, methylation_data, _dmr_data, output_dir = load_and_prepare_data()

    data_driven_summary, dd_time = _run_approach(
        clinical_data, methylation_data, output_dir,
        method='data_driven', label='DataDriven',
        title='DATA-DRIVEN', display_name='Data-driven',
        n_splits=n_splits, fast_mode=fast_mode
    )
    mechanism_driven_summary, md_time = _run_approach(
        clinical_data, methylation_data, output_dir,
        method='mechanism_driven', label='MechanismDriven',
        title='MECHANISM-DRIVEN', display_name='Mechanism-driven',
        n_splits=n_splits, fast_mode=fast_mode
    )

    # Final comparison with TOP 10 features for both methods
    print("\n" + "=" * 20)
    print("FINAL COMPARISON")
    print("=" * 20)

    print("\nModel performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score.")
    print("\nData-Driven vs Mechanism-Driven Performance:")
    print("-" * 60)
    # 'Mechanism-Driven' is 16 characters, so its column needs more than 15.
    print(f"{'Metric':<12} {'Data-Driven':<15} {'Mechanism-Driven':<18}")
    print("-" * 60)

    shared_metrics = _shared_metric_means(data_driven_summary, mechanism_driven_summary)
    for metric, dd_mean, md_mean in shared_metrics:
        print(f"{metric.upper():<12} {dd_mean:<15.4f} {md_mean:<18.4f}")

    print("\nTOP 10 MOST IMPORTANT FEATURES COMPARISON:")
    print("=" * 80)
    print(f"{'Rank':<4} {'Data-Driven Method':<25} {'Importance':<12} {'Mechanism-Driven Method':<25} {'Importance':<12}")
    print("-" * 80)

    dd_top10 = data_driven_summary['top_10_features']
    md_top10 = mechanism_driven_summary['top_10_features']
    for rank in range(1, 11):
        print(f"{rank:<4} {_top10_cell(dd_top10, rank)} {_top10_cell(md_top10, rank)}")

    total_time = time.time() - start_time
    print(f"\nTotal execution time: {total_time/60:.1f} minutes")
    print(f"All results saved to: {output_dir}")
    print("\nAnalysis complete!")

    # Save execution summary
    summary_file = os.path.join(output_dir, "RF_execution_summary.txt")
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("Random Forest B12 Prediction Model - Execution Summary\n")
        f.write("=" * 60 + "\n\n")
        f.write("Model performance was evaluated using AUC as the primary metric,\n")
        f.write("complemented by accuracy, precision, recall, and F1 score.\n\n")
        f.write(f"Total execution time: {total_time/60:.1f} minutes\n")
        f.write(f"Data-driven analysis time: {dd_time/60:.1f} minutes\n")
        f.write(f"Mechanism-driven analysis time: {md_time/60:.1f} minutes\n\n")

        f.write("Performance Comparison:\n")
        f.write("-" * 30 + "\n")
        for metric, dd_mean, md_mean in shared_metrics:
            f.write(f"{metric.upper()}: DD={dd_mean:.4f}, MD={md_mean:.4f}\n")

        f.write("\nTOP 10 FEATURES - DATA-DRIVEN METHOD:\n")
        f.write("-" * 40 + "\n")
        for i, (feature, stats) in enumerate(dd_top10, 1):
            f.write(f"{i:2d}. {feature}: {stats['mean']:.4f}\n")

        f.write("\nTOP 10 FEATURES - MECHANISM-DRIVEN METHOD:\n")
        f.write("-" * 40 + "\n")
        for i, (feature, stats) in enumerate(md_top10, 1):
            f.write(f"{i:2d}. {feature}: {stats['mean']:.4f}\n")

    print(f"Execution summary saved to: {summary_file}")

if __name__ == "__main__":
    main()

RANDOM FOREST MODEL FOR B12 DEFICIENCY PREDICTION
Running in optimized mode for faster execution
LOADING AND PREPARING DATA
Output directory created: /Users/heweilin/Desktop/P056_Code_3/Data

1. Loading clinical data...
Clinical data shape: (50, 21)
B12 status distribution:
B12_status
NB    25
LB    25
Name: count, dtype: int64

2. Loading methylation data...
Methylation data shape: (17584, 51)

3. Loading DMR data...
DMR data shape: (493648, 25)
Number of DNA samples: 50

4. Aligning methylation data with clinical samples...
Available DNA IDs in methylation data: 50
Final clinical data shape after alignment: (50, 21)
Final methylation data shape: (17584, 51)

RUNNING DATA-DRIVEN APPROACH

RUNNING 20 RANDOM SPLITS - DATA_DRIVEN
FAST MODE: Optimized for speed

PREPARING CLINICAL FEATURES
Selected clinical features (using actual column names): ['age', 'BMI', 'parity', 'B12_supplementation', 'Folic_acid_use', 'Multivitamin_use', 'smoking']

Excluded variables:
  - BMI_cat: Removed (redund

In [7]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
from scipy.stats import mannwhitneyu, ttest_rel
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including pandas deprecation warnings and sklearn metric/convergence
# warnings that could point at real problems; consider narrowing it.
warnings.filterwarnings('ignore')

# Set random seed for reproducibility.
# The RandomForestClassifier and train_test_split calls in this cell pass
# explicit random_state values, so this global seed only affects any
# remaining unseeded use of NumPy's legacy global RNG.
np.random.seed(42)

def calculate_cohens_d(group1, group2):
    """Calculate Cohen's d effect size between two samples.

    Uses the pooled standard deviation with Bessel's correction:
    ``sqrt((SS1 + SS2) / (n1 + n2 - 2))``, where SS is each group's sum of
    squared deviations from its own mean. This equals the textbook
    ``(n-1)*var(ddof=1)`` form but stays finite for a group of size 1.

    Args:
        group1, group2: 1-D array-likes of numeric observations.

    Returns:
        float: ``(mean(group1) - mean(group2)) / pooled_std``. NaN when the
        effect size is undefined: an empty group, fewer than 3 observations
        in total (no degrees of freedom), or zero pooled standard deviation.
    """
    a = np.asarray(group1, dtype=float).ravel()
    b = np.asarray(group2, dtype=float).ravel()
    n1, n2 = a.size, b.size
    dof = n1 + n2 - 2
    if n1 == 0 or n2 == 0 or dof <= 0:
        return np.nan

    sum_sq = np.sum((a - a.mean()) ** 2) + np.sum((b - b.mean()) ** 2)
    pooled_std = np.sqrt(sum_sq / dof)
    if pooled_std == 0:
        # Both groups constant: d is undefined (would be 0/0 or +/-inf).
        return np.nan
    return float((a.mean() - b.mean()) / pooled_std)

def load_enhanced_data(
    clinical_path="/Users/heweilin/Desktop/P056/7Clinical_data50.csv",
    methylation_path="/Users/heweilin/Desktop/P056_Code_2/Processed_Data/1_PD_PromoterRegion_CpGs.csv",
    mirna_path="/Users/heweilin/Desktop/P056_Code_2/Processed_Data/3_PD_miRNA_normalized.csv",
    demirs_path="/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv",
    output_dir="/Users/heweilin/Desktop/P056_Code_3/Data",
):
    """Load all datasets including miRNA data for enhanced modeling.

    Every path defaults to the project's original location, so calling
    ``load_enhanced_data()`` with no arguments behaves as before. Pass
    explicit paths to run elsewhere.

    Samples are aligned as follows:

    * If the clinical table has an ``sRNA_ID`` column and the miRNA table
      loaded, keep only clinical rows whose ``DNA_ID`` is a methylation column
      AND whose ``sRNA_ID`` is a miRNA column.
    * Otherwise keep rows whose ``DNA_ID`` is a methylation column.

    Returns:
        tuple: ``(clinical_aligned, methylation_aligned, mirna_aligned,
        demirs_data, output_dir)``.

        * ``mirna_aligned`` is None when the miRNA file is missing or there is
          no ``sRNA_ID`` column.
        * ``demirs_data`` is None when the DEmiRs file is missing.
    """
    print("="*60)
    print("LOADING ENHANCED DATASETS WITH miRNA")
    print("="*60)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory created: {output_dir}")

    # Load clinical data
    print("\n1. Loading clinical data...")
    clinical_data = pd.read_csv(clinical_path)
    print(f"Clinical data shape: {clinical_data.shape}")
    print(f"B12 status distribution:")
    print(clinical_data['B12_status'].value_counts())

    # Load methylation data
    print("\n2. Loading methylation data...")
    methylation_data = pd.read_csv(methylation_path)
    print(f"Methylation data shape: {methylation_data.shape}")

    # Load miRNA expression data
    print("\n3. Loading miRNA expression data...")
    try:
        mirna_data = pd.read_csv(mirna_path)
    except FileNotFoundError:
        print("Warning: miRNA expression data not found")
        mirna_data = None
    else:
        print(f"miRNA expression data shape: {mirna_data.shape}")
        print(f"miRNA columns: {list(mirna_data.columns[:5])}...")

    # Load DEmiRs results
    print("\n4. Loading differential expression miRNA results...")
    try:
        demirs_data = pd.read_csv(demirs_path)
    except FileNotFoundError:
        print("Warning: DEmiRs file not found")
        demirs_data = None
    else:
        print(f"DEmiRs data shape: {demirs_data.shape}")

        # Show top significant DEmiRs
        print("\nTop 10 significant DEmiRs:")
        print("-" * 80)
        print(f"{'miRNA':<20} {'log2FC':<10} {'pvalue':<12} {'padj':<12}")
        print("-" * 80)

        # Handle different column naming conventions (decided once, not per row)
        if 'Unnamed: 0' in demirs_data.columns:
            name_col = 'Unnamed: 0'
        elif 'miRNA_ID' in demirs_data.columns:
            name_col = 'miRNA_ID'
        else:
            name_col = demirs_data.columns[0]

        for _, row in demirs_data.head(10).iterrows():
            mirna_name = row[name_col]
            log2fc = row.get('log2FoldChange', row.get('logFC', 'N/A'))
            pval = row.get('pvalue', row.get('P.Value', 'N/A'))
            padj = row.get('padj', row.get('adj.P.Val', 'N/A'))
            print(f"{str(mirna_name):<20} {str(log2fc):<10} {str(pval):<12} {str(padj):<12}")

    # Sample alignment - handle both DNA and miRNA IDs
    print("\n5. Aligning datasets...")
    # Sets make the membership tests O(1); column order is still taken from
    # the data tables.
    dna_id_set = set(clinical_data['DNA_ID'])
    available_dna_ids = [col for col in methylation_data.columns if col in dna_id_set]

    if 'sRNA_ID' in clinical_data.columns and mirna_data is not None:
        # Full alignment: samples must have both DNA and miRNA data
        mirna_id_set = set(clinical_data['sRNA_ID'])
        available_mirna_ids = [col for col in mirna_data.columns if col in mirna_id_set]

        clinical_aligned = clinical_data[
            clinical_data['DNA_ID'].isin(available_dna_ids)
            & clinical_data['sRNA_ID'].isin(available_mirna_ids)
        ]

        aligned_dna_ids = clinical_aligned['DNA_ID'].tolist()
        aligned_mirna_ids = clinical_aligned['sRNA_ID'].tolist()

        methylation_aligned = methylation_data[['Gene_Symbol'] + aligned_dna_ids]
        mirna_aligned = mirna_data[['miRNA_ID'] + aligned_mirna_ids]

        print(f"Final aligned samples (both DNA & miRNA): {len(clinical_aligned)}")
        print(f"Aligned methylation shape: {methylation_aligned.shape}")
        print(f"Aligned miRNA shape: {mirna_aligned.shape}")
    else:
        # DNA-only alignment
        clinical_aligned = clinical_data[clinical_data['DNA_ID'].isin(available_dna_ids)]
        methylation_aligned = methylation_data[['Gene_Symbol'] + available_dna_ids]
        mirna_aligned = None

        print(f"Final aligned samples (DNA only): {len(clinical_aligned)}")
        print(f"Aligned methylation shape: {methylation_aligned.shape}")
        print("No miRNA data available or no sRNA_ID column")

    return clinical_aligned, methylation_aligned, mirna_aligned, demirs_data, output_dir

def prepare_enhanced_clinical_features(clinical_data):
    """Select the clinical covariates and impute their missing values.

    Candidate columns that are absent from ``clinical_data`` are reported and
    dropped. Imputation rules per column:

    * ``parity``: median fill, then recoded to 0/1 (``> 0`` -> 1).
    * Binary indicators (``B12supplem``, ``v1p_MultivitTab``,
      ``v1p_FolicAcid``, ``smoking``): mode fill (0 if no mode), cast to int.
    * Other float64/int64 columns: median fill.
    * Anything else: mode fill.

    Parity and the binary indicators are recoded/cast even when nothing is
    missing, so their encoding does not depend on missingness.

    Args:
        clinical_data: DataFrame with at least 'DNA_ID' and 'B12_status';
            'sRNA_ID' is carried through when present.

    Returns:
        tuple: ``(clinical_subset, clinical_features)``.

        * ``clinical_subset`` is an imputed copy holding the used features
          plus the ID/target columns.
        * ``clinical_features`` is the list of feature names actually used.

        ``clinical_data`` itself is not modified.

    NOTE(review): medians/modes are computed on the full cohort before any
    train/test split, so held-out samples influence the imputed values.
    """
    print("\n" + "="*60)
    print("PREPARING ENHANCED CLINICAL FEATURES")
    print("="*60)

    # Updated clinical features based on actual column names
    candidate_features = [
        'age', 'BMI', 'parity',
        'B12supplem', 'v1p_MultivitTab', 'v1p_FolicAcid',
        'smoking'
    ]
    binary_features = {'B12supplem', 'v1p_MultivitTab', 'v1p_FolicAcid', 'smoking'}

    # Check which features actually exist
    clinical_features = [f for f in candidate_features if f in clinical_data.columns]
    missing_features = [f for f in candidate_features if f not in clinical_data.columns]

    if missing_features:
        print(f"Warning: Missing clinical features: {missing_features}")
        print("Available clinical columns:")
        clinical_cols = [col for col in clinical_data.columns
                         if col not in ['DNA_ID', 'sRNA_ID', 'B12_status']]
        print(f"  {clinical_cols}")

    print(f"Using clinical features: {clinical_features}")

    # Create subset with safety checks
    required_cols = clinical_features + ['DNA_ID', 'B12_status']
    if 'sRNA_ID' in clinical_data.columns:
        required_cols.append('sRNA_ID')

    clinical_subset = clinical_data[required_cols].copy()

    print(f"\nMissing values before imputation:")
    missing_counts = clinical_subset.isnull().sum()
    for feature in clinical_features:
        if missing_counts[feature] > 0:
            print(f"  {feature}: {missing_counts[feature]} missing")

    def column_mode(series):
        # First mode, or 0 when the column has no non-null values.
        modes = series.mode()
        return modes[0] if len(modes) > 0 else 0

    # Results are assigned back with `df[col] = ...`. The previous
    # `df[col].fillna(..., inplace=True)` is chained assignment: pandas
    # deprecates it, and under copy-on-write it silently does nothing. The
    # NaNs would then survive and make the following astype(int) fail.
    for col in clinical_features:
        column = clinical_subset[col]
        missing_count = missing_counts[col]

        if missing_count > 0:
            if col == 'parity':
                # For parity, use median then convert to binary
                median_val = column.median()
                clinical_subset[col] = (column.fillna(median_val) > 0).astype(int)
                print(f"Parity: imputed {missing_count} values with median, then binary coded")
            elif col in binary_features:
                # Binary variables: use mode
                mode_val = column_mode(column)
                clinical_subset[col] = column.fillna(mode_val).astype(int)
                print(f"{col}: imputed {missing_count} values with mode ({mode_val})")
            elif column.dtype in ['float64', 'int64']:
                # Continuous variables
                median_val = column.median()
                clinical_subset[col] = column.fillna(median_val)
                print(f"{col}: imputed {missing_count} values with median ({median_val:.2f})")
            else:
                # Other categorical
                mode_val = column_mode(column)
                clinical_subset[col] = column.fillna(mode_val)
                print(f"{col}: imputed {missing_count} values with mode ({mode_val})")
        elif col in binary_features:
            # No missing values, but ensure proper type for binary variables
            clinical_subset[col] = column.astype(int)
        elif col == 'parity':
            clinical_subset[col] = (column > 0).astype(int)

    # Final verification
    remaining_missing = clinical_subset[clinical_features].isnull().sum().sum()
    print(f"\nFinal verification: {remaining_missing} missing values remaining")

    return clinical_subset, clinical_features

def mechanism_driven_base_features():
    """Return the curated gene symbols used as mechanism-driven methylation features.

    A new list is built on every call, so callers may mutate the result freely.
    """
    gene_panel = (
        "IQCG ZNF154 TSTD1 GTSF1 SEPTIN7P11 "
        "NUDT10 NAA11 SYCP3 PSMA8 LINC00667"
    )
    return gene_panel.split()

def select_significant_demirs(demirs_data, fdr_threshold=0.05, log2fc_threshold=0.5, max_features=15):
    """Select significant DEmiRs, relaxing the criteria when nothing passes.

    When the table has ``padj`` and ``log2FoldChange``, three tiers are tried
    in order and the first non-empty one wins:

      1. ``padj < fdr_threshold`` and ``|log2FoldChange| > log2fc_threshold``
      2. ``padj < 0.1`` and ``|log2FoldChange| > log2fc_threshold / 2``
      3. ``pvalue < 0.01`` (skipped when there is no ``pvalue`` column)

    Other tables:

    * With only ``pvalue``, tier 3 alone is used.
    * With neither, the first ``max_features`` rows are taken.

    Every tier is capped at ``max_features`` rows. Tiers 1-2 keep the rows
    with the lowest ``padj``; tier 3 keeps the lowest ``pvalue``.

    Args:
        demirs_data: differential-expression table, or None.
        fdr_threshold: adjusted p-value cutoff for the strict tier.
        log2fc_threshold: absolute log2 fold-change cutoff for the strict tier.
        max_features: maximum number of miRNAs returned.

    Returns:
        list: miRNA identifiers. Empty if there is no data or nothing qualifies.
    """
    if demirs_data is None:
        print("No DEmiRs data available")
        return []

    print(f"\nSelecting significant DEmiRs:")
    print(f"Criteria: FDR < {fdr_threshold}, |log2FC| > {log2fc_threshold}")

    # Handle column naming variations
    columns = demirs_data.columns
    if 'Unnamed: 0' in columns:
        mirna_col = 'Unnamed: 0'
    elif 'miRNA_ID' in columns:
        mirna_col = 'miRNA_ID'
    else:
        mirna_col = columns[0]

    print(f"Using miRNA column: {mirna_col}")

    def most_significant(frame, rank_col):
        # Stable sort so ties keep file order; NaN ranks sort last.
        return frame.sort_values(rank_col, kind='mergesort').head(max_features)

    def pvalue_tier(label):
        hits = most_significant(demirs_data[demirs_data['pvalue'] < 0.01], 'pvalue')
        print(f"{label}: {len(hits)} miRNAs")
        return hits

    has_pvalue = 'pvalue' in columns

    if 'padj' in columns and 'log2FoldChange' in columns:
        abs_fc = demirs_data['log2FoldChange'].abs()

        # Try strict criteria first
        significant = demirs_data[(demirs_data['padj'] < fdr_threshold) & (abs_fc > log2fc_threshold)]
        print(f"Strict criteria: {len(significant)} miRNAs")

        if len(significant) == 0:
            # Relax FDR threshold
            significant = demirs_data[(demirs_data['padj'] < 0.1) & (abs_fc > log2fc_threshold / 2)]
            print(f"Relaxed criteria (FDR<0.1, |FC|>{log2fc_threshold/2}): {len(significant)} miRNAs")

        if len(significant) > 0:
            significant = most_significant(significant, 'padj')
        elif has_pvalue:
            # Use p-value instead
            significant = pvalue_tier("P-value criteria (p<0.01)")
        else:
            print("No 'pvalue' column; p-value criteria cannot be applied")

    elif has_pvalue:
        # Use p-value only
        significant = pvalue_tier("P-value only criteria")
    else:
        # Take top features by file order (no ranking statistic available)
        significant = demirs_data.head(max_features)
        print(f"Taking top {max_features} miRNAs (no p-values available)")

    if len(significant) == 0:
        print("No significant miRNAs found, returning empty list")
        return []

    selected_mirnas = significant[mirna_col].tolist()
    print(f"Final selection: {len(selected_mirnas)} miRNAs")

    # Display selected miRNAs
    for i, mirna in enumerate(selected_mirnas[:10], 1):
        print(f"  {i:2d}. {mirna}")
    if len(selected_mirnas) > 10:
        print(f"  ... and {len(selected_mirnas) - 10} more")

    return selected_mirnas

def prepare_feature_matrix_base(clinical_data, methylation_data, base_methylation_features, clinical_features):
    """Prepare base feature matrix (clinical + methylation only).

    Args:
        clinical_data: DataFrame with 'DNA_ID', 'B12_status' and the columns
            listed in ``clinical_features``.
        methylation_data: DataFrame with a 'Gene_Symbol' column plus one
            column per DNA sample ID.
        base_methylation_features: gene symbols to include. Symbols absent
            from ``methylation_data`` are reported and skipped.
        clinical_features: clinical column names to include.

    Returns:
        tuple: ``(feature_matrix, target_encoded, dna_ids)``.

        * Rows follow the order of ``clinical_data['DNA_ID']``.
        * Missing methylation values are filled with each gene's median.
        * ``target_encoded`` comes from ``LabelEncoder``, which numbers
          classes in sorted order.

    NOTE(review): the methylation medians use every sample, including ones
    that later land in a test fold.
    """
    print(f"\nPreparing BASE feature matrix...")

    dna_ids = clinical_data['DNA_ID'].tolist()

    # Methylation features: transpose so rows are samples and columns are genes
    methylation_df = methylation_data.set_index('Gene_Symbol').T
    available_methylation = [f for f in base_methylation_features if f in methylation_df.columns]
    missing_methylation = [f for f in base_methylation_features if f not in methylation_df.columns]

    if missing_methylation:
        print(f"Missing methylation features: {missing_methylation}")

    print(f"Available methylation features: {len(available_methylation)}")

    methylation_subset = methylation_df.loc[dna_ids, available_methylation]
    methylation_subset = methylation_subset.fillna(methylation_subset.median())

    # Clinical features
    clinical_subset = clinical_data.set_index('DNA_ID')[clinical_features].loc[dna_ids]

    # Combine (both blocks are indexed by DNA_ID, so concat aligns row-wise)
    feature_matrix = pd.concat([clinical_subset, methylation_subset], axis=1)

    # Target
    target = clinical_data.set_index('DNA_ID')['B12_status'].loc[dna_ids]
    le = LabelEncoder()
    target_encoded = le.fit_transform(target)

    print(f"Base matrix shape: {feature_matrix.shape}")
    # Report counts per actual class label. LabelEncoder sorts the classes,
    # so with labels 'LB'/'NB' the encoded 1s are NB. The previous
    # "LB={sum(target_encoded)}" therefore printed the NB count under LB.
    class_counts = np.bincount(target_encoded, minlength=len(le.classes_))
    distribution = ", ".join(f"{label}={count}" for label, count in zip(le.classes_, class_counts))
    print(f"Target distribution: {distribution}")

    return feature_matrix, target_encoded, dna_ids

def prepare_enhanced_feature_matrix(clinical_data, methylation_data, mirna_data,
                                   base_methylation_features, selected_mirnas, clinical_features):
    """Prepare enhanced feature matrix (clinical + methylation + miRNA).

    Args:
        clinical_data: DataFrame with 'DNA_ID', 'sRNA_ID', 'B12_status' and
            the ``clinical_features`` columns. Row i's DNA_ID and sRNA_ID
            describe the same subject.
        methylation_data: DataFrame with 'Gene_Symbol' plus one column per
            DNA sample ID.
        mirna_data: DataFrame with 'miRNA_ID' plus one column per sRNA sample
            ID, or None.
        base_methylation_features: gene symbols to include.
        selected_mirnas: miRNA identifiers to include.
        clinical_features: clinical column names to include.

    Returns:
        tuple: ``(feature_matrix, target_encoded, dna_ids, available_mirnas)``.
        Rows are indexed by DNA_ID.

        If there is no miRNA data, no selected miRNAs, or none of them is
        present, the function falls back to ``prepare_feature_matrix_base``.
        That fallback returns only the first three values; callers must
        check the length.

    NOTE(review): median fills use every sample, including ones that later
    land in a test fold.
    """
    # If no miRNA data or features, fall back to base
    if mirna_data is None or len(selected_mirnas) == 0:
        print("Falling back to base feature matrix (no miRNA data)")
        return prepare_feature_matrix_base(clinical_data, methylation_data, base_methylation_features, clinical_features)

    print(f"\nPreparing ENHANCED feature matrix...")

    dna_ids = clinical_data['DNA_ID'].tolist()
    mirna_ids = clinical_data['sRNA_ID'].tolist()

    # Methylation features (rows = DNA samples)
    methylation_df = methylation_data.set_index('Gene_Symbol').T
    available_methylation = [f for f in base_methylation_features if f in methylation_df.columns]
    methylation_subset = methylation_df.loc[dna_ids, available_methylation]
    methylation_subset = methylation_subset.fillna(methylation_subset.median())

    # miRNA features (rows = sRNA samples)
    mirna_df = mirna_data.set_index('miRNA_ID').T
    available_mirnas = [f for f in selected_mirnas if f in mirna_df.columns]
    missing_mirnas = [f for f in selected_mirnas if f not in mirna_df.columns]

    if missing_mirnas:
        print(f"Missing miRNA features: {len(missing_mirnas)}")

    print(f"Available miRNA features: {len(available_mirnas)}")

    if len(available_mirnas) == 0:
        print("No miRNA features available, using base features")
        return prepare_feature_matrix_base(clinical_data, methylation_data, base_methylation_features, clinical_features)

    mirna_subset = mirna_df.loc[mirna_ids, available_mirnas]
    mirna_subset = mirna_subset.fillna(mirna_subset.median())
    # The miRNA rows are labelled by sRNA_ID, the other blocks by DNA_ID.
    # pd.concat(axis=1) aligns on index labels, so without re-keying the
    # miRNA values would land on separate rows, with NaN for every other
    # feature. Both ID lists come from the same clinical rows, so position i
    # is the same subject.
    mirna_subset.index = methylation_subset.index

    # Clinical features
    clinical_subset = clinical_data.set_index('DNA_ID')[clinical_features].loc[dna_ids]

    # Combine all features (all three blocks now share the DNA_ID index)
    feature_matrix = pd.concat([clinical_subset, methylation_subset, mirna_subset], axis=1)

    # Target
    target = clinical_data.set_index('DNA_ID')['B12_status'].loc[dna_ids]
    le = LabelEncoder()
    target_encoded = le.fit_transform(target)

    print(f"Enhanced matrix shape: {feature_matrix.shape}")
    print(f"Clinical features: {len(clinical_features)}")
    print(f"Methylation features: {len(available_methylation)}")
    print(f"miRNA features: {len(available_mirnas)}")
    # Count per real class label: LabelEncoder sorts classes, so summing the
    # encoded vector counts the second label (NB), not LB.
    class_counts = np.bincount(target_encoded, minlength=len(le.classes_))
    distribution = ", ".join(f"{label}={count}" for label, count in zip(le.classes_, class_counts))
    print(f"Target distribution: {distribution}")

    return feature_matrix, target_encoded, dna_ids, available_mirnas

def train_enhanced_rf_model(X_train, y_train, model_type="enhanced"):
    """Fit a class-balanced RandomForestClassifier whose settings depend on data size.

    The "complex" configuration is used when there are more than 25 features
    or fewer than 30 samples:

    * 300 trees, depth 7, leaves of at least 2 samples;
    * ``min_samples_split = max(2, n_samples // 10)``.

    Otherwise the "standard" configuration is used: 200 trees, depth 5.

    Both configurations use ``max_features='sqrt'``, ``random_state=42``,
    ``class_weight='balanced'`` and a single job.

    Because the branch depends on feature count, models trained on different
    feature sets can end up with different hyper-parameters.

    Returns:
        The fitted RandomForestClassifier.
    """
    row_count = X_train.shape[0]
    column_count = X_train.shape[1]

    print(f"Training {model_type} RF model: {row_count} samples, {column_count} features")

    wants_complex = column_count > 25 or row_count < 30
    if wants_complex:
        tree_settings = dict(
            n_estimators=300,
            max_depth=7,
            min_samples_split=max(2, row_count // 10),
            min_samples_leaf=2,
        )
        banner = "Using complex model parameters"
    else:
        tree_settings = dict(
            n_estimators=200,
            max_depth=5,
            min_samples_split=2,
            min_samples_leaf=1,
        )
        banner = "Using standard model parameters"

    forest = RandomForestClassifier(
        max_features='sqrt',
        random_state=42,
        class_weight='balanced',
        n_jobs=1,
        **tree_settings,
    )
    print(banner)

    forest.fit(X_train, y_train)
    return forest

def evaluate_enhanced_model(model, X_test, y_test, model_type="enhanced"):
    """Score a fitted binary classifier on held-out data.

    Returns:
        tuple: ``(metrics, y_pred, y_pred_proba)``.

        * ``metrics`` maps 'accuracy', 'precision', 'recall', 'f1' and 'auc'
          to floats.
        * Precision, recall and F1 use ``zero_division=0``.
        * 'auc' is NaN when ``roc_auc_score`` raises ValueError, e.g. when
          ``y_test`` holds a single class.
        * ``y_pred_proba`` is the predicted probability of the
          second class (column 1 of ``predict_proba``).
    """
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    # Threshold-based metrics first, then AUC (the primary metric)
    metrics = {'accuracy': accuracy_score(y_test, y_pred)}
    for name, scorer in (('precision', precision_score),
                         ('recall', recall_score),
                         ('f1', f1_score)):
        metrics[name] = scorer(y_test, y_pred, zero_division=0)

    try:
        auc_value = roc_auc_score(y_test, y_pred_proba)
    except ValueError:
        auc_value = np.nan
        print(f"Warning: AUC calculation failed for {model_type} model")
    metrics['auc'] = auc_value

    print(f"{model_type.capitalize()} model performance:")
    for name in ('auc', 'accuracy', 'precision', 'recall', 'f1'):
        score = metrics[name]
        if np.isnan(score):
            continue
        print(f"  {name.upper()}: {score:.4f}")

    return metrics, y_pred, y_pred_proba

def analyze_enhanced_feature_importance(model, feature_names, clinical_features, base_methylation_features, mirna_features=None):
    """Tabulate a fitted model's feature importances by feature category.

    Each feature is labelled by the first list that contains it:
    'Clinical', then 'Methylation', then 'miRNA'; anything else is 'Other'.
    A per-category summary (total, mean, count) is printed.

    Args:
        model: fitted estimator exposing ``feature_importances_``.
        feature_names: names aligned with ``model.feature_importances_``.
            They are paired with ``zip``, so any surplus on either side is
            ignored.
        clinical_features, base_methylation_features: category membership lists.
        mirna_features: optional miRNA names. None is treated as empty.

    Returns:
        DataFrame with columns 'feature', 'importance', 'category', sorted by
        importance in descending order. It is empty, with those columns, when
        there are no features.
    """
    importances = model.feature_importances_

    # Sets give O(1) membership tests
    clinical_set = set(clinical_features)
    methylation_set = set(base_methylation_features)
    mirna_set = set(mirna_features) if mirna_features is not None else set()

    def categorize(feature):
        if feature in clinical_set:
            return 'Clinical'
        if feature in methylation_set:
            return 'Methylation'
        if feature in mirna_set:
            return 'miRNA'
        return 'Other'

    importance_data = [
        {'feature': feature, 'importance': importance, 'category': categorize(feature)}
        for feature, importance in zip(feature_names, importances)
    ]
    # Explicit columns keep sort_values working when importance_data is empty
    importance_df = pd.DataFrame(
        importance_data, columns=['feature', 'importance', 'category']
    ).sort_values('importance', ascending=False)

    # Category summary
    category_summary = importance_df.groupby('category')['importance'].agg(['sum', 'mean', 'count'])
    print(f"\nFeature importance by category:")
    print("-" * 50)
    print(f"{'Category':<12} {'Total':<8} {'Mean':<8} {'Count':<6}")
    print("-" * 50)
    for category, stats in category_summary.iterrows():
        # iterrows() upcasts the row to float, so cast count back to int
        # (otherwise it prints as e.g. "3.0").
        print(f"{category:<12} {stats['sum']:<8.4f} {stats['mean']:<8.4f} {int(stats['count']):<6}")

    return importance_df

def run_enhanced_modeling_comparison(clinical_data, methylation_data, mirna_data, demirs_data, n_splits=20):
    """Compare a base RF (clinical + methylation) with an enhanced RF (+ miRNA).

    For each of ``n_splits`` seeds (``random_state=split_idx``), one
    stratified 80/20 split is drawn. Both models are trained and tested on
    that same split, so the per-split metrics are paired.

    Returns:
        tuple: ``(base_results, enhanced_results, selected_mirnas)``.

        * Each results dict maps 'auc', 'accuracy', 'precision', 'recall'
          and 'f1' to a list with one value per split.
        * 'feature_importance' maps to a list of per-split importance
          DataFrames.

    NOTE(review): clinical imputation and median fills are done once on the
    full cohort before splitting, so test samples influence the imputed
    training values.
    """
    print("\n" + "="*60)
    print(f"ENHANCED MODELING COMPARISON - {n_splits} SPLITS")
    print("="*60)

    # Prepare features
    clinical_subset, clinical_features = prepare_enhanced_clinical_features(clinical_data)
    base_methylation_features = mechanism_driven_base_features()
    selected_mirnas = select_significant_demirs(demirs_data)

    # The feature matrices do not depend on the split. Build them once instead
    # of rebuilding identical matrices on every iteration.
    print("Preparing base feature matrix...")
    base_X, y, _ = prepare_feature_matrix_base(
        clinical_subset, methylation_data, base_methylation_features, clinical_features
    )

    print("Preparing enhanced feature matrix...")
    enhanced_result = prepare_enhanced_feature_matrix(
        clinical_subset, methylation_data, mirna_data,
        base_methylation_features, selected_mirnas, clinical_features
    )
    # 4 values when miRNA features were added; 3 when it fell back to base
    if len(enhanced_result) == 4:
        enhanced_X, _, _, available_mirnas = enhanced_result
    else:
        enhanced_X, _, _ = enhanced_result
        available_mirnas = []

    base_feature_names = base_X.columns.tolist()
    enhanced_feature_names = enhanced_X.columns.tolist()

    # Results storage
    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']
    base_results = {metric: [] for metric in metric_names}
    base_results['feature_importance'] = []
    enhanced_results = {metric: [] for metric in metric_names}
    enhanced_results['feature_importance'] = []

    for split_idx in range(n_splits):
        print(f"\n{'='*20} SPLIT {split_idx + 1}/{n_splits} {'='*20}")

        # Stratified split
        train_indices, test_indices = train_test_split(
            range(len(clinical_subset)),
            test_size=0.2,
            stratify=clinical_subset['B12_status'],
            random_state=split_idx
        )

        print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

        # Train/test splits (rows of base_X, enhanced_X and y share one order)
        base_X_train, base_X_test = base_X.iloc[train_indices], base_X.iloc[test_indices]
        enhanced_X_train, enhanced_X_test = enhanced_X.iloc[train_indices], enhanced_X.iloc[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        print(f"Feature comparison: Base={base_X_train.shape[1]}, Enhanced={enhanced_X_train.shape[1]}")

        # Train models
        base_model = train_enhanced_rf_model(base_X_train, y_train, "base")
        enhanced_model = train_enhanced_rf_model(enhanced_X_train, y_train, "enhanced")

        # Evaluate models
        base_metrics, _, _ = evaluate_enhanced_model(base_model, base_X_test, y_test, "base")
        enhanced_metrics, _, _ = evaluate_enhanced_model(enhanced_model, enhanced_X_test, y_test, "enhanced")

        # Store results (AUC first as primary metric)
        for metric in metric_names:
            base_results[metric].append(base_metrics[metric])
            enhanced_results[metric].append(enhanced_metrics[metric])

        # Feature importance analysis
        base_importance = analyze_enhanced_feature_importance(
            base_model, base_feature_names, clinical_features, base_methylation_features
        )
        enhanced_importance = analyze_enhanced_feature_importance(
            enhanced_model, enhanced_feature_names, clinical_features, base_methylation_features, available_mirnas
        )

        base_results['feature_importance'].append(base_importance)
        enhanced_results['feature_importance'].append(enhanced_importance)

        # Progress summary
        base_auc = f"{base_metrics['auc']:.4f}" if not np.isnan(base_metrics['auc']) else 'N/A'
        enhanced_auc = f"{enhanced_metrics['auc']:.4f}" if not np.isnan(enhanced_metrics['auc']) else 'N/A'
        print(f"Split {split_idx + 1} AUC - Base: {base_auc}, Enhanced: {enhanced_auc}")

    return base_results, enhanced_results, selected_mirnas

def _summarize_split_importances(importance_frames):
    """Average per-split importance tables into features ranked by mean importance.

    Args:
        importance_frames: iterable of DataFrames with 'feature' and
            'importance' columns (one table per split).

    Returns:
        list of (feature, {'mean', 'std', 'frequency'}) tuples, highest mean
        first. 'std' is the population std (ddof=0); 'frequency' is the number
        of tables the feature appeared in.
    """
    per_feature = defaultdict(list)
    for importance_df in importance_frames:
        for feature, importance in zip(importance_df['feature'], importance_df['importance']):
            per_feature[feature].append(importance)

    summary = {
        feature: {'mean': np.mean(values), 'std': np.std(values), 'frequency': len(values)}
        for feature, values in per_feature.items()
    }
    return sorted(summary.items(), key=lambda item: item[1]['mean'], reverse=True)


def analyze_model_comparison(base_results, enhanced_results, selected_mirnas, output_dir):
    """Summarize a base-vs-enhanced (miRNA) Random Forest comparison across splits.

    Prints a metric table with paired t-test p-values, the top-10 features of
    each model (importances averaged over splits), a side-by-side top-10 view,
    and a significance summary; also writes RF_Enhanced_Comparison_Results.txt.

    Args:
        base_results, enhanced_results: dicts holding one value per split under
            'auc', 'accuracy', 'precision', 'recall', 'f1' (NaN allowed), and a
            list of per-split importance DataFrames (columns 'feature',
            'importance') under 'feature_importance'.
        selected_mirnas: miRNA feature names, used to label miRNA features.
        output_dir: directory the results text file is written to.

    Returns:
        dict with 'comparison_stats' (per metric: base/enhanced mean and std,
        'improvement' = enhanced_mean - base_mean, 'p_value'), 'base_top10' and
        'enhanced_top10' ([(feature, {'mean', 'std', 'frequency'}), ...]),
        'selected_mirnas', and 'significant_improvements' (number of metrics
        where the enhanced model is better with p < 0.05).
    """
    # Local import: the scipy.stats import at the top of this cell only brings in
    # mannwhitneyu, and a bare `except:` used to hide the resulting NameError,
    # silently turning every p-value into N/A.
    from scipy.stats import ttest_rel

    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']

    print("\n" + "="*60)
    print("COMPREHENSIVE MODEL COMPARISON ANALYSIS")
    print("="*60)

    print("\nModel performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score.")

    print(f"\nPerformance Comparison:")
    print("-" * 70)
    print(f"{'Metric':<12} {'Base Model':<12} {'Enhanced Model':<15} {'Improvement':<12} {'P-value':<10}")
    print("-" * 70)

    comparison_stats = {}
    for metric in metric_names:
        base_raw = list(base_results[metric])
        enhanced_raw = list(enhanced_results[metric])
        base_vals = [v for v in base_raw if not np.isnan(v)]
        enhanced_vals = [v for v in enhanced_raw if not np.isnan(v)]
        if not base_vals or not enhanced_vals:
            continue

        base_mean, base_std = np.mean(base_vals), np.std(base_vals)
        enhanced_mean, enhanced_std = np.mean(enhanced_vals), np.std(enhanced_vals)
        improvement = enhanced_mean - base_mean

        # Paired t-test over splits. A split is kept only when BOTH models have a
        # value, so the pairs stay aligned (filtering each list separately could
        # shift one list relative to the other).
        p_value = np.nan
        if len(base_raw) == len(enhanced_raw):
            pairs = [(e, b) for e, b in zip(enhanced_raw, base_raw)
                     if not (np.isnan(e) or np.isnan(b))]
            if len(pairs) >= 2:
                enhanced_paired, base_paired = zip(*pairs)
                try:
                    _, p_value = ttest_rel(enhanced_paired, base_paired)
                    p_value = float(p_value)
                except ValueError:
                    p_value = np.nan

        comparison_stats[metric] = {
            'base_mean': base_mean,
            'base_std': base_std,
            'enhanced_mean': enhanced_mean,
            'enhanced_std': enhanced_std,
            'improvement': improvement,
            'p_value': p_value
        }

        p_val_str = f"{p_value:.3f}" if not np.isnan(p_value) else 'N/A'
        print(f"{metric.upper():<12} {base_mean:<12.4f} {enhanced_mean:<15.4f} "
              f"{improvement:<+12.4f} {p_val_str:<10}")

    clinical_features = ['age', 'BMI', 'parity', 'B12supplem', 'v1p_MultivitTab', 'v1p_FolicAcid', 'smoking']
    base_methylation_features = mechanism_driven_base_features()

    def feature_category(feature, include_mirna):
        # First match wins: Clinical, then Methylation, then (enhanced only) miRNA.
        if feature in clinical_features:
            return 'Clinical'
        if feature in base_methylation_features:
            return 'Methylation'
        if include_mirna and feature in selected_mirnas:
            return 'miRNA'
        return 'Other'

    def print_top_features(title, ranked, include_mirna):
        print(f"\n{title}")
        print("-" * 50)
        print(f"{'Rank':<4} {'Feature':<20} {'Avg Importance':<15} {'Std':<8} {'Category':<12}")
        print("-" * 60)
        for rank, (feature, stats) in enumerate(ranked[:10], 1):
            category = feature_category(feature, include_mirna)
            print(f"{rank:<4} {feature:<20} {stats['mean']:<15.4f} {stats['std']:<8.4f} {category:<12}")

    base_sorted = _summarize_split_importances(base_results['feature_importance'])
    enhanced_sorted = _summarize_split_importances(enhanced_results['feature_importance'])

    print(f"\nTOP 10 MOST IMPORTANT FEATURES ANALYSIS:")
    print("="*70)
    print_top_features("TOP 10 FEATURES - BASE MODEL:", base_sorted, include_mirna=False)
    print_top_features("TOP 10 FEATURES - ENHANCED MODEL:", enhanced_sorted, include_mirna=True)

    # Side-by-side comparison of top 10 features (blank rows when a model has fewer).
    print(f"\nTOP 10 FEATURES COMPARISON:")
    print("="*80)
    print(f"{'Rank':<4} {'Base Model':<25} {'Importance':<12} {'Enhanced Model':<25} {'Importance':<12}")
    print("-" * 80)
    for i in range(10):
        base_feature, base_stats = base_sorted[i] if i < len(base_sorted) else ("", {"mean": 0})
        enhanced_feature, enhanced_stats = enhanced_sorted[i] if i < len(enhanced_sorted) else ("", {"mean": 0})
        print(f"{i+1:<4} {base_feature:<25} {base_stats['mean']:<12.4f} "
              f"{enhanced_feature:<25} {enhanced_stats['mean']:<12.4f}")

    # Statistical significance summary. Significant differences are flagged in
    # either direction, but only significant GAINS count as improvements.
    print(f"\nStatistical Significance Summary:")
    print("-" * 40)
    significant_improvements = 0
    for metric, stats in comparison_stats.items():
        p_value = stats['p_value']
        if not np.isnan(p_value) and p_value < 0.05:
            if stats['improvement'] > 0:
                significant_improvements += 1
            significance = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*"
            print(f"{metric.upper()}: {stats['improvement']:+.4f} {significance}")
        else:
            print(f"{metric.upper()}: {stats['improvement']:+.4f} (ns)")

    # Explicit UTF-8: the report contains "±", which the platform default may not encode.
    output_file = os.path.join(output_dir, "RF_Enhanced_Comparison_Results.txt")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Enhanced Random Forest Model Comparison Results\n")
        f.write("=" * 60 + "\n\n")

        f.write("Model performance was evaluated using AUC as the primary metric,\n")
        f.write("complemented by accuracy, precision, recall, and F1 score.\n\n")

        f.write("Performance Comparison:\n")
        f.write("-" * 50 + "\n")
        for metric in metric_names:
            if metric in comparison_stats:
                stats = comparison_stats[metric]
                p_val_str = f"(p={stats['p_value']:.3f})" if not np.isnan(stats['p_value']) else ""
                f.write(f"{metric.upper()}: Base={stats['base_mean']:.4f}, Enhanced={stats['enhanced_mean']:.4f}, "
                        f"Improvement={stats['improvement']:+.4f} {p_val_str}\n")

        f.write(f"\nTOP 10 FEATURES - BASE MODEL:\n")
        f.write("-" * 40 + "\n")
        for i, (feature, stats) in enumerate(base_sorted[:10], 1):
            f.write(f"{i:2d}. {feature}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        f.write(f"\nTOP 10 FEATURES - ENHANCED MODEL:\n")
        f.write("-" * 40 + "\n")
        for i, (feature, stats) in enumerate(enhanced_sorted[:10], 1):
            f.write(f"{i:2d}. {feature}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        f.write(f"\nSelected miRNA Features ({len(selected_mirnas)}):\n")
        for i, mirna in enumerate(selected_mirnas, 1):
            f.write(f"{i:2d}. {mirna}\n")

    print(f"\nResults saved to: {output_file}")

    return {
        'comparison_stats': comparison_stats,
        'base_top10': base_sorted[:10],
        'enhanced_top10': enhanced_sorted[:10],
        'selected_mirnas': selected_mirnas,
        'significant_improvements': significant_improvements
    }

def main():
    """Run the base vs. enhanced (miRNA) Random Forest comparison end to end.

    Loads the data with load_enhanced_data(), evaluates both models over 20
    splits with run_enhanced_modeling_comparison(), summarizes the results with
    analyze_model_comparison(), then prints a verdict driven by the AUC gain and
    the number of significant improvements, plus the wall-clock runtime.
    """
    import time
    t0 = time.time()

    for banner_line in (
        "ENHANCED RANDOM FOREST MODEL WITH miRNA FEATURES",
        "=" * 60,
        "Running comprehensive base vs enhanced model comparison",
    ):
        print(banner_line)

    clinical_data, methylation_data, mirna_data, demirs_data, output_dir = load_enhanced_data()

    print("\nRunning enhanced modeling comparison...")
    base_results, enhanced_results, selected_mirnas = run_enhanced_modeling_comparison(
        clinical_data, methylation_data, mirna_data, demirs_data, n_splits=20
    )

    analysis = analyze_model_comparison(base_results, enhanced_results, selected_mirnas, output_dir)

    print("\nFINAL SUMMARY:")
    print("=" * 40)

    auc_gain = analysis['comparison_stats'].get('auc', {}).get('improvement', 0)
    n_significant = analysis['significant_improvements']

    print(f"AUC improvement: {auc_gain:+.4f}")
    print(f"Significant improvements: {n_significant}/5 metrics")

    # Verdict thresholds: AUC gain plus how many metrics improved significantly.
    if auc_gain > 0.02 and n_significant >= 2:
        verdict = "Substantial improvement with miRNA features"
    elif auc_gain > 0.005 and n_significant >= 1:
        verdict = "Moderate improvement with miRNA features"
    else:
        verdict = "Minimal improvement with miRNA features"
    print(f"Conclusion: {verdict}")

    mirna_hits = len([name for name, _ in analysis['enhanced_top10'] if name in selected_mirnas])
    print(f"miRNA features in top 10: {mirna_hits}/10")

    elapsed_minutes = (time.time() - t0) / 60
    print(f"\nTotal execution time: {elapsed_minutes:.1f} minutes")
    print(f"All results saved to: {output_dir}")
    print("\nAnalysis complete!")

if __name__ == "__main__":
    main()

ENHANCED RANDOM FOREST MODEL WITH miRNA FEATURES
Running comprehensive base vs enhanced model comparison
LOADING ENHANCED DATASETS WITH miRNA
Output directory created: /Users/heweilin/Desktop/P056_Code_3/Data

1. Loading clinical data...
Clinical data shape: (50, 21)
B12 status distribution:
B12_status
NB    25
LB    25
Name: count, dtype: int64

2. Loading methylation data...
Methylation data shape: (17584, 51)

3. Loading miRNA expression data...

4. Loading differential expression miRNA results...
DEmiRs data shape: (2201, 7)

Top 10 significant DEmiRs:
--------------------------------------------------------------------------------
miRNA                log2FC     pvalue       padj        
--------------------------------------------------------------------------------
hsa-miR-6877-3p      30.0       8.66e-14     1.91e-10    
hsa-miR-1269a        3.199437988 1.32e-05     0.014516677 
novel_594            -15.87845051 7.79e-05     0.057098973 
hsa-miR-223-3p       0.58831944 0.006138

In [17]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
from scipy.stats import mannwhitneyu, ttest_rel
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's global RNG for reproducibility. The splits and forests below also
# pass their own explicit random_state (train_test_split, RandomForestClassifier),
# so those do not depend on this global seed.
np.random.seed(42)

def calculate_cohens_d(group1, group2):
    """Calculate Cohen's d effect size between two samples.

    Uses the pooled (ddof=1) standard deviation:
        d = (mean1 - mean2) / sqrt((SS1 + SS2) / (n1 + n2 - 2))
    where SSi is the sum of squared deviations of group i.

    Args:
        group1, group2: 1-D array-likes (lists, NumPy arrays, pandas Series)
            of numeric values. NaNs are not removed; drop them beforehand.

    Returns:
        float: the signed effect size. Degenerate cases are handled explicitly
        rather than through NumPy divide-by-zero warnings:
        NaN if either group is empty or n1 + n2 < 3 (no degrees of freedom);
        0.0 if both groups are constant with equal means;
        +/-inf if both groups are constant with different means.
    """
    a = np.asarray(group1, dtype=float).ravel()
    b = np.asarray(group2, dtype=float).ravel()
    n1, n2 = a.size, b.size
    dof = n1 + n2 - 2
    if n1 == 0 or n2 == 0 or dof <= 0:
        return np.nan

    # SS == (n - 1) * var(ddof=1), but it is 0 (not NaN) for a single-element
    # group, so a singleton group no longer turns the pooled SD into NaN.
    ss = np.sum((a - a.mean()) ** 2) + np.sum((b - b.mean()) ** 2)
    pooled_std = np.sqrt(ss / dof)
    mean_diff = a.mean() - b.mean()

    if pooled_std == 0:
        return 0.0 if mean_diff == 0 else float(np.copysign(np.inf, mean_diff))
    return float(mean_diff / pooled_std)

_DEFAULT_MULTI_OMICS_PATHS = {
    'clinical': "/Users/heweilin/Desktop/P056/7Clinical_data50.csv",
    'methylation': "/Users/heweilin/Desktop/P056_Code_2/Processed_Data/1_PD_PromoterRegion_CpGs.csv",
    'mirna': "/Users/heweilin/Desktop/P056_Code_2/Processed_Data/3_PD_miRNA_normalized.csv",
    'demirs': "/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv",
    'mrna_tpm': "/Users/heweilin/Desktop/P056/5mRNA_TPM.csv",
    'mrna_degs': "/Users/heweilin/Desktop/P056/1mRNA_DEGs_proteincoding.csv",
    'output_dir': "/Users/heweilin/Desktop/P056_Code_3/Data",
}


def load_multi_omics_data(paths=None):
    """Load all datasets for multi-omics modeling and align them by sample.

    Args:
        paths: optional mapping overriding any of the keys in
            _DEFAULT_MULTI_OMICS_PATHS ('clinical', 'methylation', 'mirna',
            'demirs', 'mrna_tpm', 'mrna_degs', 'output_dir'). Omitted keys keep
            their defaults, so calling with no argument behaves as before.

    Returns:
        tuple (clinical_aligned, methylation_aligned, mirna_aligned,
        mrna_aligned, demirs_data, mrna_degs_data, output_dir).
        clinical_aligned keeps only samples present in every usable modality.
        methylation_aligned is 'Gene_Symbol' + aligned DNA_ID columns;
        mirna_aligned is 'miRNA_ID' + aligned sRNA_ID columns; mrna_aligned is
        'Gene_ID' (renamed from 'Unnamed: 0') + aligned mRNA_ID columns.
        mirna_aligned / mrna_aligned are None when that modality is unusable:
        file missing, ID column absent from the clinical table, or no sample
        IDs in common with the clinical table.

    Raises:
        FileNotFoundError: if the clinical or methylation file is missing
        (the miRNA and mRNA files are optional).
    """
    cfg = {**_DEFAULT_MULTI_OMICS_PATHS, **(paths or {})}
    output_dir = cfg['output_dir']

    print("="*60)
    print("LOADING MULTI-OMICS DATASETS (DNA + miRNA + mRNA)")
    print("="*60)

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory created: {output_dir}")

    print("\n1. Loading clinical data...")
    clinical_data = pd.read_csv(cfg['clinical'])
    print(f"Clinical data shape: {clinical_data.shape}")
    print(f"B12 status distribution:")
    print(clinical_data['B12_status'].value_counts())

    print("\n2. Loading methylation data...")
    methylation_data = pd.read_csv(cfg['methylation'])
    print(f"Methylation data shape: {methylation_data.shape}")

    print("\n3. Loading miRNA data...")
    try:
        mirna_data = pd.read_csv(cfg['mirna'])
        demirs_data = pd.read_csv(cfg['demirs'])
        print(f"miRNA expression data shape: {mirna_data.shape}")
        print(f"DEmiRs data shape: {demirs_data.shape}")
    except FileNotFoundError:
        print("Warning: miRNA data not found")
        mirna_data = None
        demirs_data = None

    print("\n4. Loading mRNA expression data...")
    try:
        mrna_tpm_data = pd.read_csv(cfg['mrna_tpm'])
        mrna_degs_data = pd.read_csv(cfg['mrna_degs'])
        print(f"mRNA TPM data shape: {mrna_tpm_data.shape}")
        print(f"mRNA DEGs data shape: {mrna_degs_data.shape}")
    except FileNotFoundError:
        print("Warning: mRNA data not found")
        mrna_tpm_data = None
        mrna_degs_data = None

    print("\n5. Multi-omics data alignment...")
    has_mirna = 'sRNA_ID' in clinical_data.columns and mirna_data is not None
    has_mrna = 'mRNA_ID' in clinical_data.columns and mrna_tpm_data is not None

    dna_ids = clinical_data['DNA_ID'].tolist()
    mirna_ids = clinical_data['sRNA_ID'].tolist() if has_mirna else []
    mrna_ids = clinical_data['mRNA_ID'].tolist() if has_mrna else []

    print(f"Sample counts:")
    print(f"  DNA samples: {len(dna_ids)}")
    print(f"  miRNA samples: {len(mirna_ids)}")
    print(f"  mRNA samples: {len(mrna_ids)}")

    # Sample columns of each omics table that appear in the clinical table.
    dna_id_set, mirna_id_set, mrna_id_set = set(dna_ids), set(mirna_ids), set(mrna_ids)
    available_dna_ids = [col for col in methylation_data.columns if col in dna_id_set]
    available_mirna_ids = [col for col in mirna_data.columns if col in mirna_id_set] if has_mirna else []
    available_mrna_ids = [col for col in mrna_tpm_data.columns if col in mrna_id_set] if has_mrna else []

    # A modality with zero overlapping sample IDs used to skip the clinical filter
    # and then crash with KeyError when its columns were selected; drop it instead.
    if has_mirna and not available_mirna_ids:
        print("Warning: no miRNA sample columns match clinical sRNA_ID; skipping miRNA data")
        has_mirna = False
    if has_mrna and not available_mrna_ids:
        print("Warning: no mRNA sample columns match clinical mRNA_ID; skipping mRNA data")
        has_mrna = False

    clinical_aligned = clinical_data[clinical_data['DNA_ID'].isin(available_dna_ids)]
    if has_mirna:
        clinical_aligned = clinical_aligned[clinical_aligned['sRNA_ID'].isin(available_mirna_ids)]
    if has_mrna:
        clinical_aligned = clinical_aligned[clinical_aligned['mRNA_ID'].isin(available_mrna_ids)]

    print(f"Final aligned samples: {len(clinical_aligned)}")

    aligned_dna_ids = clinical_aligned['DNA_ID'].tolist()
    methylation_aligned = methylation_data[['Gene_Symbol'] + aligned_dna_ids]

    mirna_aligned = None
    if has_mirna and len(clinical_aligned) > 0:
        aligned_mirna_ids = clinical_aligned['sRNA_ID'].tolist()
        mirna_aligned = mirna_data[['miRNA_ID'] + aligned_mirna_ids]

    mrna_aligned = None
    if has_mrna and len(clinical_aligned) > 0:
        aligned_mrna_ids = clinical_aligned['mRNA_ID'].tolist()
        # The TPM file's unnamed first column holds gene IDs.
        mrna_aligned = mrna_tpm_data[['Unnamed: 0'] + aligned_mrna_ids]
        mrna_aligned = mrna_aligned.rename(columns={'Unnamed: 0': 'Gene_ID'})

    return clinical_aligned, methylation_aligned, mirna_aligned, mrna_aligned, demirs_data, mrna_degs_data, output_dir

def prepare_clinical_features(clinical_data):
    """Select, impute and encode the clinical covariates used by the models.

    Args:
        clinical_data: DataFrame with 'DNA_ID', 'B12_status' and (some of) the
            clinical columns listed below; 'sRNA_ID' / 'mRNA_ID' are carried
            along when present.

    Returns:
        (clinical_subset, clinical_features): a copy restricted to the used
        columns, and the list of clinical feature names actually present.

    Encoding:
        - parity: median-imputed, then binarized (parity > 0 -> 1).
        - B12supplem, v1p_FolicAcid, v1p_MultivitTab, smoking: mode-imputed
          (0 if there is no mode), cast to int.
        - any other numeric column (e.g. age, BMI): median-imputed.
    """
    print("\n" + "="*60)
    print("PREPARING CLINICAL FEATURES")
    print("="*60)

    candidate_features = [
        'age', 'BMI', 'parity', 'B12supplem',
        'v1p_FolicAcid', 'v1p_MultivitTab', 'smoking'
    ]
    binary_features = {'B12supplem', 'v1p_FolicAcid', 'v1p_MultivitTab', 'smoking'}

    clinical_features = [f for f in candidate_features if f in clinical_data.columns]
    missing_features = [f for f in candidate_features if f not in clinical_data.columns]

    if missing_features:
        print(f"Warning: Missing features: {missing_features}")
        print("Available columns:")
        clinical_cols = [col for col in clinical_data.columns
                         if col not in ['DNA_ID', 'sRNA_ID', 'mRNA_ID', 'B12_status']]
        print(f"  {clinical_cols}")

    print(f"Using clinical features: {clinical_features}")

    required_cols = clinical_features + ['DNA_ID', 'B12_status']
    if 'sRNA_ID' in clinical_data.columns:
        required_cols.append('sRNA_ID')
    if 'mRNA_ID' in clinical_data.columns:
        required_cols.append('mRNA_ID')

    clinical_subset = clinical_data[required_cols].copy()

    # Results are always assigned back to the frame: `clinical_subset[col].fillna(
    # ..., inplace=True)` is chained assignment, which under pandas Copy-on-Write
    # does not update clinical_subset and leaves NaNs for the int casts below.
    for col in clinical_features:
        missing_count = clinical_subset[col].isnull().sum()

        if missing_count > 0:
            if col == 'parity':
                median_val = clinical_subset[col].median()
                filled = clinical_subset[col].fillna(median_val)
                clinical_subset[col] = (filled > 0).astype(int)
                print(f"Parity: imputed {missing_count} values, binary coded")
            elif col in binary_features:
                modes = clinical_subset[col].mode()
                mode_val = modes.iloc[0] if len(modes) > 0 else 0
                clinical_subset[col] = clinical_subset[col].fillna(mode_val).astype(int)
                print(f"{col}: imputed {missing_count} values with mode ({mode_val})")
            elif pd.api.types.is_numeric_dtype(clinical_subset[col]):
                median_val = clinical_subset[col].median()
                clinical_subset[col] = clinical_subset[col].fillna(median_val)
                print(f"{col}: imputed {missing_count} values with median ({median_val:.2f})")
        elif col in binary_features:
            clinical_subset[col] = clinical_subset[col].astype(int)
        elif col == 'parity':
            clinical_subset[col] = (clinical_subset[col] > 0).astype(int)

    remaining_missing = clinical_subset[clinical_features].isnull().sum().sum()
    print(f"Total remaining missing values: {remaining_missing}")

    return clinical_subset, clinical_features

def data_driven_feature_selection(methylation_data, clinical_data, train_indices, n_features=15):
    """Rank methylation genes by LB-vs-NB separation using training samples only.

    For each gene, a two-sided Mann-Whitney U test and |Cohen's d| are computed
    between LB and NB training samples (NaNs dropped per gene) and combined as
    composite_score = |d| * -log10(p), or |d| * 10 when p == 0.

    Args:
        methylation_data: DataFrame with a 'Gene_Symbol' column and one column
            per DNA sample ID.
        clinical_data: DataFrame with 'DNA_ID' and 'B12_status' ('LB' / 'NB').
        train_indices: positional row indices of clinical_data to use, so the
            test split never influences feature selection.
        n_features: maximum number of genes to return.

    Returns:
        (selected_features, stats_df): up to n_features genes with p < 0.05 and
        |d| > 0.3, best composite score first; and per-gene statistics
        ('gene', 'p_value', 'cohens_d', 'composite_score') sorted by
        composite_score. stats_df is an empty frame with those columns when no
        gene could be tested.
    """
    print("\n" + "="*50)
    print("DATA-DRIVEN FEATURE SELECTION")
    print("="*50)

    train_clinical = clinical_data.iloc[train_indices]
    train_dna_ids = train_clinical['DNA_ID'].tolist()

    methylation_df = methylation_data.set_index('Gene_Symbol').T
    methylation_df = methylation_df.loc[train_dna_ids]

    status = train_clinical.set_index('DNA_ID').loc[train_dna_ids, 'B12_status']
    lb_data = methylation_df[(status == 'LB').to_numpy()]
    nb_data = methylation_df[(status == 'NB').to_numpy()]

    print(f"LB samples: {lb_data.shape[0]}, NB samples: {nb_data.shape[0]}")

    feature_stats = []
    # Positional access keeps duplicate gene symbols as separate 1-D columns
    # instead of returning a DataFrame for lb_data[gene].
    for pos, gene in enumerate(methylation_df.columns):
        lb_values = lb_data.iloc[:, pos].dropna()
        nb_values = nb_data.iloc[:, pos].dropna()
        if len(lb_values) == 0 or len(nb_values) == 0:
            continue

        try:
            _, p_value = mannwhitneyu(lb_values, nb_values, alternative='two-sided')
        except ValueError:
            # Some SciPy versions raise here (e.g. all values identical); such a
            # gene cannot be ranked, so it is skipped.
            continue

        cohens_d = abs(calculate_cohens_d(lb_values, nb_values))
        if p_value > 0:
            composite_score = cohens_d * (-np.log10(p_value))
        else:
            composite_score = cohens_d * 10

        feature_stats.append({
            'gene': gene,
            'p_value': p_value,
            'cohens_d': cohens_d,
            'composite_score': composite_score
        })

    # Explicit columns so an empty result still sorts/filters instead of raising KeyError.
    stats_df = pd.DataFrame(feature_stats, columns=['gene', 'p_value', 'cohens_d', 'composite_score'])
    stats_df = stats_df.sort_values('composite_score', ascending=False)

    significant_features = stats_df[
        (stats_df['p_value'] < 0.05) &
        (stats_df['cohens_d'] > 0.3)
    ]

    selected_features = significant_features.head(n_features)['gene'].tolist()

    print(f"Selected top {len(selected_features)} features")

    return selected_features, stats_df

def mechanism_driven_feature_selection():
    """Return the hard-coded list of mechanism-driven methylation genes.

    No data is inspected. A fresh list is built on every call, so callers may
    mutate the result without affecting later calls.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("MECHANISM-DRIVEN FEATURE SELECTION")
    print(rule)

    genes = "IQCG ZNF154 TSTD1 GTSF1 SEPTIN7P11 NUDT10 NAA11 SYCP3 PSMA8 LINC00667".split()

    print(f"Selected {len(genes)} mechanism-driven features")
    return genes

def select_significant_demirs(demirs_data, fdr_threshold=0.05, log2fc_threshold=0.5, max_features=15,
                              fallback_pvalue=0.01):
    """Pick differentially expressed miRNAs to use as model features.

    Primary rule (needs 'padj' and 'log2FoldChange' columns):
    padj < fdr_threshold and |log2FoldChange| > log2fc_threshold. If more than
    max_features pass, the max_features with the smallest padj are kept;
    otherwise file order is kept.

    Fallback (primary columns absent, or nothing passes): the max_features
    miRNAs with the smallest raw pvalue below fallback_pvalue, most significant
    first.

    The miRNA name column is 'Unnamed: 0' when present, else the first column.

    Returns:
        list of miRNA names; empty when there is no data, no usable 'pvalue'
        column for the fallback, or nothing qualifies.
    """
    if demirs_data is None or demirs_data.empty:
        print("No DEmiRs data available")
        return []

    mirna_col = 'Unnamed: 0' if 'Unnamed: 0' in demirs_data.columns else demirs_data.columns[0]

    print(f"\nSelecting significant DEmiRs")

    significant = demirs_data.iloc[0:0]
    if 'padj' in demirs_data.columns and 'log2FoldChange' in demirs_data.columns:
        significant = demirs_data[
            (demirs_data['padj'] < fdr_threshold) &
            (demirs_data['log2FoldChange'].abs() > log2fc_threshold)
        ]
        if len(significant) > max_features:
            significant = significant.sort_values('padj', kind='mergesort').head(max_features)

    if len(significant) == 0:
        if 'pvalue' not in demirs_data.columns:
            print("Warning: no 'pvalue' column available for fallback DEmiR selection")
            return []
        # Sort before truncating so head() keeps the most significant miRNAs,
        # not whichever rows happen to come first in the file.
        significant = (demirs_data[demirs_data['pvalue'] < fallback_pvalue]
                       .sort_values('pvalue', kind='mergesort')
                       .head(max_features))

    if len(significant) == 0:
        return []

    selected_mirnas = significant[mirna_col].tolist()
    print(f"Selected {len(selected_mirnas)} significant DEmiRs")

    return selected_mirnas

def select_significant_degs(mrna_degs_data, fdr_threshold=0.05, log2fc_threshold=0.5, max_features=20,
                            fallback_pvalue=0.001):
    """Pick differentially expressed mRNA genes to use as model features.

    Primary rule (needs 'padj' and 'log2FoldChange' columns):
    padj < fdr_threshold and |log2FoldChange| > log2fc_threshold. If more than
    max_features pass, the max_features with the smallest padj are kept;
    otherwise file order is kept.

    Fallback (primary columns absent, or nothing passes): the max_features
    genes with the smallest raw pvalue below fallback_pvalue, most significant
    first.

    The gene name column is 'Row.names' when present, else the first column.

    Returns:
        list of gene identifiers; empty when there is no data, no usable
        'pvalue' column for the fallback, or nothing qualifies.
    """
    if mrna_degs_data is None or mrna_degs_data.empty:
        print("No mRNA DEGs data available")
        return []

    gene_col = 'Row.names' if 'Row.names' in mrna_degs_data.columns else mrna_degs_data.columns[0]

    print(f"\nSelecting significant mRNA DEGs")

    significant_degs = mrna_degs_data.iloc[0:0]
    if 'padj' in mrna_degs_data.columns and 'log2FoldChange' in mrna_degs_data.columns:
        significant_degs = mrna_degs_data[
            (mrna_degs_data['padj'] < fdr_threshold) &
            (mrna_degs_data['log2FoldChange'].abs() > log2fc_threshold)
        ]
        if len(significant_degs) > max_features:
            significant_degs = significant_degs.sort_values('padj', kind='mergesort').head(max_features)

    if len(significant_degs) == 0:
        if 'pvalue' not in mrna_degs_data.columns:
            print("Warning: no 'pvalue' column available for fallback DEG selection")
            return []
        # Sort before truncating so head() keeps the most significant genes,
        # not whichever rows happen to come first in the file.
        significant_degs = (mrna_degs_data[mrna_degs_data['pvalue'] < fallback_pvalue]
                            .sort_values('pvalue', kind='mergesort')
                            .head(max_features))

    if len(significant_degs) == 0:
        return []

    selected_genes = significant_degs[gene_col].tolist()
    print(f"Selected {len(selected_genes)} significant mRNA DEGs")

    return selected_genes

def prepare_feature_matrix(clinical_data, methylation_data, selected_methylation_features, clinical_features):
    """Build the clinical + methylation design matrix and the encoded target.

    Args:
        clinical_data: DataFrame with 'DNA_ID', 'B12_status' and the columns in
            clinical_features.
        methylation_data: DataFrame with 'Gene_Symbol' and one column per DNA_ID.
        selected_methylation_features: gene symbols to include; symbols missing
            from methylation_data are silently skipped.
        clinical_features: clinical column names to include (first, in order).

    Returns:
        (feature_matrix, target_encoded, dna_ids): a DataFrame indexed by DNA_ID
        (clinical columns, then available methylation genes); an int array with
        labels encoded in sorted order, as sklearn's LabelEncoder does (so
        'LB' -> 0, 'NB' -> 1); and the DNA IDs in row order.

    NOTE(review): methylation NaNs are filled with the median over ALL rows
    passed in (train and test together); imputing within the training split
    would avoid leakage.
    """
    print(f"\nPreparing feature matrix...")

    dna_ids = clinical_data['DNA_ID'].tolist()
    clinical_indexed = clinical_data.set_index('DNA_ID')

    methylation_df = methylation_data.set_index('Gene_Symbol').T
    available_features = [f for f in selected_methylation_features if f in methylation_df.columns]
    methylation_subset = methylation_df.loc[dna_ids, available_features]
    methylation_subset = methylation_subset.fillna(methylation_subset.median())

    clinical_subset = clinical_indexed[clinical_features].loc[dna_ids]

    feature_matrix = pd.concat([clinical_subset, methylation_subset], axis=1)

    target = clinical_indexed['B12_status'].loc[dna_ids]
    # np.unique gives sorted classes plus each sample's class code: the same
    # encoding LabelEncoder.fit_transform produces.
    classes, target_encoded = np.unique(target.to_numpy(), return_inverse=True)
    target_encoded = target_encoded.ravel()

    print(f"Feature matrix shape: {feature_matrix.shape}")
    # Count per class label. Code 1 is 'NB' in sorted order, so summing the codes
    # would count NB samples, not LB.
    distribution = ", ".join(
        f"{label}={int((target_encoded == code).sum())}" for code, label in enumerate(classes)
    )
    print(f"Target distribution: {distribution}")

    return feature_matrix, target_encoded, dna_ids

def prepare_multi_omics_matrix(clinical_data, methylation_data, mirna_data, mrna_data,
                              base_features, selected_mirnas, selected_mrnas, clinical_features):
    """Build the clinical + methylation (+ miRNA) (+ mRNA) design matrix.

    Args:
        clinical_data: DataFrame with 'DNA_ID', 'B12_status', the clinical
            columns, and optionally 'sRNA_ID' / 'mRNA_ID' linking each row to its
            miRNA / mRNA sample.
        methylation_data: 'Gene_Symbol' + one column per DNA_ID.
        mirna_data: 'miRNA_ID' + one column per sRNA_ID, or None.
        mrna_data: 'Gene_ID' + one column per mRNA_ID (TPM), or None.
        base_features, selected_mirnas, selected_mrnas: feature names to pull
            from each table; names absent from a table are skipped.
        clinical_features: clinical columns to include.

    Returns:
        (feature_matrix, target_encoded, dna_ids): one row per clinical row,
        indexed by DNA_ID; labels encoded in sorted order, as LabelEncoder does
        ('LB' -> 0, 'NB' -> 1).

    miRNA/mRNA rows are looked up by sRNA_ID/mRNA_ID and then re-keyed to the
    DNA_ID of the same clinical row. Without that, pd.concat(axis=1) aligns on
    index labels and stacks the disjoint ID sets into extra NaN-filled rows.
    mRNA TPM values are log2(x + 1) transformed. NaNs in each omics block are
    median-filled over the rows passed in.
    """
    print(f"\nPreparing multi-omics feature matrix...")

    dna_ids = clinical_data['DNA_ID'].tolist()
    mirna_ids = clinical_data['sRNA_ID'].tolist() if 'sRNA_ID' in clinical_data.columns else []
    mrna_ids = clinical_data['mRNA_ID'].tolist() if 'mRNA_ID' in clinical_data.columns else []
    row_index = pd.Index(dna_ids, name='DNA_ID')

    clinical_subset = clinical_data.set_index('DNA_ID')[clinical_features].loc[dna_ids]

    methylation_df = methylation_data.set_index('Gene_Symbol').T
    available_methylation = [f for f in base_features if f in methylation_df.columns]
    methylation_subset = methylation_df.loc[dna_ids, available_methylation]
    methylation_subset = methylation_subset.fillna(methylation_subset.median())

    feature_components = [clinical_subset, methylation_subset]
    feature_counts = [len(clinical_features), len(available_methylation)]
    feature_types = ['Clinical', 'Methylation']

    if mirna_data is not None and len(selected_mirnas) > 0 and len(mirna_ids) > 0:
        mirna_df = mirna_data.set_index('miRNA_ID').T
        available_mirnas = [f for f in selected_mirnas if f in mirna_df.columns]

        if len(available_mirnas) > 0:
            mirna_subset = mirna_df.loc[mirna_ids, available_mirnas]
            mirna_subset = mirna_subset.fillna(mirna_subset.median())
            # Same clinical row order as dna_ids, so re-key positionally.
            mirna_subset.index = row_index
            feature_components.append(mirna_subset)
            feature_counts.append(len(available_mirnas))
            feature_types.append('miRNA')

    if mrna_data is not None and len(selected_mrnas) > 0 and len(mrna_ids) > 0:
        mrna_df = mrna_data.set_index('Gene_ID').T
        available_mrnas = [f for f in selected_mrnas if f in mrna_df.columns]

        if len(available_mrnas) > 0:
            mrna_subset = mrna_df.loc[mrna_ids, available_mrnas]
            mrna_subset = np.log2(mrna_subset + 1)
            mrna_subset = mrna_subset.fillna(mrna_subset.median())
            mrna_subset.index = row_index
            feature_components.append(mrna_subset)
            feature_counts.append(len(available_mrnas))
            feature_types.append('mRNA')

    feature_matrix = pd.concat(feature_components, axis=1)

    target = clinical_data.set_index('DNA_ID')['B12_status'].loc[dna_ids]
    # Sorted-class encoding, identical to LabelEncoder.fit_transform.
    _, target_encoded = np.unique(target.to_numpy(), return_inverse=True)
    target_encoded = target_encoded.ravel()

    print(f"Multi-omics matrix shape: {feature_matrix.shape}")
    for ftype, count in zip(feature_types, feature_counts):
        print(f"  {ftype} features: {count}")

    return feature_matrix, target_encoded, dna_ids

def train_random_forest(X_train, y_train):
    """Fit and return the Random Forest used by every model variant.

    Hyperparameters are fixed: 200 trees capped at depth 5, sqrt feature
    sampling per split, balanced class weights, random_state=42, and n_jobs=1
    (single-process fitting).
    """
    print("\nTraining Random Forest model...")

    rf_params = dict(
        n_estimators=200,
        max_depth=5,
        min_samples_split=2,
        min_samples_leaf=1,
        max_features='sqrt',
        random_state=42,
        class_weight='balanced',
        n_jobs=1,
    )
    # fit() returns the fitted estimator itself.
    return RandomForestClassifier(**rf_params).fit(X_train, y_train)

def evaluate_model(model, X_test, y_test, model_name="model"):
    """Score a fitted binary classifier on a held-out split and print the metrics.

    Args:
        model: fitted estimator with predict() and predict_proba(); column 1 of
            predict_proba is used as the positive-class score.
        X_test, y_test: held-out features and encoded labels.
        model_name: label used only in the progress message.

    Returns:
        (metrics, y_pred, y_pred_proba). metrics maps 'auc', 'accuracy',
        'precision', 'recall' and 'f1' to scores. AUC is NaN when
        roc_auc_score raises ValueError. Precision, recall and F1 use
        zero_division=0.
    """
    print(f"\nEvaluating {model_name} model...")

    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    # 'auc' is inserted first so it leads the report; it stays NaN unless scoring succeeds.
    metrics = {'auc': np.nan}
    metrics['accuracy'] = accuracy_score(y_test, y_pred)
    metrics['precision'] = precision_score(y_test, y_pred, zero_division=0)
    metrics['recall'] = recall_score(y_test, y_pred, zero_division=0)
    metrics['f1'] = f1_score(y_test, y_pred, zero_division=0)

    try:
        metrics['auc'] = roc_auc_score(y_test, y_pred_proba)
    except ValueError:
        print("Warning: AUC calculation failed")

    print("Model performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score:")
    for name, value in metrics.items():
        shown = "N/A" if np.isnan(value) else f"{value:.4f}"
        print(f"  {name.upper()}: {shown}")

    return metrics, y_pred, y_pred_proba

def analyze_feature_importance(model, feature_names, clinical_features, methylation_features,
                              mirna_features=None, mrna_features=None, top_n=10):
    """Tabulate a fitted model's feature_importances_ by omics category and print the top_n.

    Each feature is labelled by the first list it belongs to, checked in the
    order Clinical, Methylation, miRNA, mRNA; unmatched features are 'Other'.

    Returns:
        DataFrame with columns 'feature', 'importance', 'category', sorted by
        importance (descending). The original positional index labels are kept.
    """
    print(f"\nAnalyzing feature importance...")

    scores = model.feature_importances_
    category_lists = (
        ('Clinical', clinical_features),
        ('Methylation', methylation_features),
        ('miRNA', [] if mirna_features is None else mirna_features),
        ('mRNA', [] if mrna_features is None else mrna_features),
    )

    def categorize(name):
        for label, members in category_lists:
            if name in members:
                return label
        return 'Other'

    table = pd.DataFrame({
        'feature': feature_names,
        'importance': scores,
        'category': list(map(categorize, feature_names)),
    }).sort_values('importance', ascending=False)

    print(f"\nTOP {top_n} MOST IMPORTANT FEATURES:")
    print("=" * 60)
    print(f"{'Rank':<4} {'Feature':<20} {'Importance':<12} {'Category':<12}")
    print("-" * 60)

    for rank, (_, row) in enumerate(table.iloc[:max(top_n, 0)].iterrows(), start=1):
        print(f"{rank:<4} {row['feature']:<20} {row['importance']:<12.4f} {row['category']:<12}")

    return table

def run_model_comparison(clinical_data, methylation_data, mirna_data, mrna_data, 
                        demirs_data, mrna_degs_data, n_splits=20):
    """Compare base, miRNA-enhanced and multi-omics random forests over repeated splits.

    For each of ``n_splits`` iterations a stratified 80/20 train/test split is
    drawn (``random_state`` = split index, so splits are reproducible). The same
    partition is used for all three models:

    * Base        - clinical + methylation features
    * Enhanced    - Base + selected miRNAs
    * Multi-omics - Enhanced + selected mRNAs

    Returns:
        tuple: ``(base_results, enhanced_results, multi_omics_results,
        selected_mirnas, selected_mrnas)``. Each ``*_results`` dict maps 'auc',
        'accuracy', 'precision', 'recall' and 'f1' to one value per split, and
        'feature_importance' to one importance DataFrame per split.

    Raises:
        ValueError: if a feature matrix or the label vector does not have exactly
            one row per clinical sample. Otherwise the positional split indices
            would select the wrong rows or fail with an IndexError.
    """
    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']

    def select_rows(values, positions):
        # Positional selection for pandas objects and plain array-likes alike.
        # ``y[list]`` on a Series may be label-based rather than positional.
        if hasattr(values, 'iloc'):
            return values.iloc[positions]
        return np.asarray(values)[positions]

    print("\n" + "="*60)
    print(f"RUNNING MODEL COMPARISON - {n_splits} SPLITS")
    print("="*60)

    clinical_subset, clinical_features = prepare_clinical_features(clinical_data)
    base_features = mechanism_driven_feature_selection()
    # NOTE(review): the miRNA/mRNA lists are chosen once, outside the split loop.
    # If the DE tables were computed on all samples, test samples influence the
    # feature choice and hold-out metrics may be optimistic. Confirm upstream.
    selected_mirnas = select_significant_demirs(demirs_data)
    selected_mrnas = select_significant_degs(mrna_degs_data)

    print("\nModel configuration:")
    print(f"  Base model: {len(clinical_features)} clinical + {len(base_features)} methylation")
    print(f"  Enhanced model: Base + {len(selected_mirnas)} miRNA")
    print(f"  Multi-omics model: Enhanced + {len(selected_mrnas)} mRNA")

    # None of the feature matrices depend on the split, so each is built once
    # instead of being rebuilt on every iteration.
    base_X, y, _ = prepare_feature_matrix(
        clinical_subset, methylation_data, base_features, clinical_features
    )
    enhanced_X, _, _ = prepare_multi_omics_matrix(
        clinical_subset, methylation_data, mirna_data, None,
        base_features, selected_mirnas, [], clinical_features
    )
    multi_omics_X, _, _ = prepare_multi_omics_matrix(
        clinical_subset, methylation_data, mirna_data, mrna_data,
        base_features, selected_mirnas, selected_mrnas, clinical_features
    )

    # Split indices are positions into clinical_subset, so every matrix (and y)
    # must have exactly one row per clinical sample.
    n_samples = len(clinical_subset)
    for label, rows in (('base', base_X), ('enhanced', enhanced_X),
                        ('multi-omics', multi_omics_X), ('label vector', y)):
        if len(rows) != n_samples:
            raise ValueError(
                f"{label} data has {len(rows)} rows but the clinical subset has "
                f"{n_samples} samples; cannot apply positional train/test splits"
            )

    base_results = {name: [] for name in metric_names + ['feature_importance']}
    enhanced_results = {name: [] for name in metric_names + ['feature_importance']}
    multi_omics_results = {name: [] for name in metric_names + ['feature_importance']}

    # (evaluate_model tag, feature matrix, result store, feature lists used to
    # categorise importances).
    model_specs = [
        ('base', base_X, base_results,
         (clinical_features, base_features)),
        ('enhanced', enhanced_X, enhanced_results,
         (clinical_features, base_features, selected_mirnas)),
        ('multi-omics', multi_omics_X, multi_omics_results,
         (clinical_features, base_features, selected_mirnas, selected_mrnas)),
    ]

    for split_idx in range(n_splits):
        print(f"\nSplit {split_idx + 1}/{n_splits}")

        train_indices, test_indices = train_test_split(
            range(n_samples), test_size=0.2,
            stratify=clinical_subset['B12_status'], random_state=split_idx
        )
        y_train = select_rows(y, train_indices)
        y_test = select_rows(y, test_indices)

        split_aucs = []
        for model_tag, X, results, category_lists in model_specs:
            X_train, X_test = X.iloc[train_indices], X.iloc[test_indices]

            model = train_random_forest(X_train, y_train)
            metrics, _, _ = evaluate_model(model, X_test, y_test, model_tag)

            for metric in metric_names:
                results[metric].append(metrics[metric])

            results['feature_importance'].append(
                analyze_feature_importance(model, X.columns.tolist(), *category_lists)
            )
            split_aucs.append(metrics['auc'])

        # Progress report
        base_auc, enhanced_auc, multi_auc = (
            'N/A' if np.isnan(auc) else f"{auc:.4f}" for auc in split_aucs
        )
        print(f"AUC - Base: {base_auc}, Enhanced: {enhanced_auc}, Multi-omics: {multi_auc}")

    return base_results, enhanced_results, multi_omics_results, selected_mirnas, selected_mrnas

def analyze_results(base_results, enhanced_results, multi_omics_results,
                   selected_mirnas, selected_mrnas, output_dir):
    """Summarise per-split results, print top-10 feature tables and write a report.

    Args:
        base_results, enhanced_results, multi_omics_results: dicts mapping
            'auc', 'accuracy', 'precision', 'recall', 'f1' to per-split values
            (NaN allowed), and 'feature_importance' to per-split DataFrames with
            'feature' and 'importance' columns (plus 'category' when available).
        selected_mirnas, selected_mrnas: iterables of the omics feature names
            used. Their counts and first five names are written to the report.
        output_dir: existing directory that receives
            ``RF_Multi_Omics_Comprehensive_Results.txt``.

    Returns:
        tuple: ``(comparison_stats, model_top10_features)``.
        ``comparison_stats`` maps 'base' / 'enhanced' / 'multi_omics' -> metric
        -> {'mean', 'std', 'values'}. NaN splits are excluded, and std is the
        population std. ``model_top10_features`` maps the model display name to
        a list of ``(feature, {'mean', 'std'})`` for the ten highest mean
        importances.
    """
    metric_names = ['auc', 'accuracy', 'precision', 'recall', 'f1']

    print("\n" + "="*60)
    print("COMPREHENSIVE RESULTS ANALYSIS")
    print("="*60)

    print("\nModel performance was evaluated using AUC as the primary metric,")
    print("complemented by accuracy, precision, recall, and F1 score.")

    models = [
        ('Base', base_results),
        ('Enhanced', enhanced_results),
        ('Multi-omics', multi_omics_results)
    ]

    # Report the number of splits actually run instead of assuming 20.
    n_splits = max((len(results['auc']) for _, results in models), default=0)

    print(f"\nPerformance Statistics Across {n_splits} Splits:")
    print("-"*70)
    print(f"{'Model':<12} {'AUC':<8} {'Accuracy':<9} {'Precision':<10} {'Recall':<8} {'F1':<8}")
    print("-"*70)

    comparison_stats = {}
    for model_name, results in models:
        model_stats = {}
        performance_line = f"{model_name:<12}"

        for metric in metric_names:
            # Splits where a metric could not be computed are stored as NaN; skip them.
            values = [v for v in results[metric] if not np.isnan(v)]

            if values:
                mean_val = np.mean(values)
                std_val = np.std(values)
                model_stats[metric] = {'mean': mean_val, 'std': std_val, 'values': values}
                performance_line += f" {mean_val:<7.4f}"
            else:
                model_stats[metric] = {'mean': np.nan, 'std': np.nan, 'values': []}
                performance_line += f" {'N/A':<7}"

        comparison_stats[model_name.lower().replace('-', '_')] = model_stats
        print(performance_line)

    # TOP 10 FEATURES ANALYSIS for each model
    print("\nTOP 10 MOST IMPORTANT FEATURES ANALYSIS:")
    print("="*80)

    model_top10_features = {}

    for model_name, results in models:
        if not results['feature_importance']:
            continue

        all_importance = defaultdict(list)
        # Categories come from the per-split importance tables, which were
        # labelled with the feature lists actually used for training. This
        # avoids a second hard-coded copy that could drift out of sync.
        feature_category = {}
        for importance_df in results['feature_importance']:
            categories = importance_df.get('category')
            if categories is None:
                categories = ['Other'] * len(importance_df)
            for feature, importance, category in zip(
                importance_df['feature'], importance_df['importance'], categories
            ):
                all_importance[feature].append(importance)
                feature_category.setdefault(feature, category)

        avg_importance = {
            feature: {'mean': np.mean(importances), 'std': np.std(importances)}
            for feature, importances in all_importance.items()
        }

        # sorted() is stable (also with reverse=True), so ties keep first-seen order.
        sorted_importance = sorted(avg_importance.items(), key=lambda x: x[1]['mean'], reverse=True)
        model_top10_features[model_name] = sorted_importance[:10]

        print(f"\nTOP 10 FEATURES - {model_name.upper()} MODEL:")
        print("-"*60)
        print(f"{'Rank':<4} {'Feature':<20} {'Importance':<12} {'Std':<8} {'Category':<10}")
        print("-"*60)

        for i, (feature, stats) in enumerate(sorted_importance[:10], 1):
            category = feature_category.get(feature, 'Other')
            print(f"{i:<4} {feature:<20} {stats['mean']:<12.4f} {stats['std']:<8.4f} {category:<10}")

    # Side-by-side comparison of TOP 10 features
    if len(model_top10_features) >= 2:
        print("\nTOP 10 MOST IMPORTANT FEATURES COMPARISON:")
        print("="*100)

        if 'Multi-omics' in model_top10_features:
            columns, rule_width = ['Base', 'Enhanced', 'Multi-omics'], 100
        else:
            columns, rule_width = ['Base', 'Enhanced'], 70

        print(f"{'Rank':<4}" + "".join(f" {name + ' Model':<25} {'Importance':<12}" for name in columns))
        print("-"*rule_width)

        padding = ("", {"mean": 0})  # filler when a model has fewer than 10 features
        for i in range(10):
            cells = []
            for name in columns:
                top10 = model_top10_features.get(name, [])
                feature, stats = top10[i] if i < len(top10) else padding
                cells.append(f" {feature:<25} {stats['mean']:<12.4f}")
            print(f"{i+1:<4}" + "".join(cells))

    # Performance improvement analysis
    print("\nPerformance Improvement Analysis:")
    print("-"*50)

    comparisons = [
        ("Enhanced vs Base:", 'base', 'enhanced'),
        ("\nMulti-omics vs Enhanced:", 'enhanced', 'multi_omics'),
    ]
    for heading, reference_key, candidate_key in comparisons:
        if reference_key not in comparison_stats or candidate_key not in comparison_stats:
            continue
        print(heading)
        for metric in metric_names:
            reference_mean = comparison_stats[reference_key][metric]['mean']
            candidate_mean = comparison_stats[candidate_key][metric]['mean']
            if not np.isnan(reference_mean) and not np.isnan(candidate_mean):
                print(f"  {metric.upper()}: {candidate_mean - reference_mean:+.4f}")

    # Save comprehensive results
    output_file = os.path.join(output_dir, "RF_Multi_Omics_Comprehensive_Results.txt")
    mirna_names = [str(name) for name in selected_mirnas]
    mrna_names = [str(name) for name in selected_mrnas]
    # Explicit UTF-8: the report contains "±", which non-UTF-8 locale
    # encodings (e.g. ASCII under the POSIX C locale) cannot represent.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Multi-omics Random Forest B12 Prediction - Comprehensive Results\n")
        f.write("="*60 + "\n\n")

        f.write("Model performance was evaluated using AUC as the primary metric,\n")
        f.write("complemented by accuracy, precision, recall, and F1 score.\n\n")

        f.write("Performance Statistics:\n")
        f.write("-"*30 + "\n")
        for model_name, _ in models:
            f.write(f"\n{model_name} Model:\n")
            model_key = model_name.lower().replace('-', '_')
            if model_key in comparison_stats:
                for metric in metric_names:
                    stats = comparison_stats[model_key][metric]
                    if not np.isnan(stats['mean']):
                        f.write(f"  {metric.upper()}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        for model_name, top10 in model_top10_features.items():
            f.write(f"\nTOP 10 FEATURES - {model_name.upper()} MODEL:\n")
            f.write("-"*40 + "\n")
            for i, (feature, stats) in enumerate(top10, 1):
                f.write(f"{i:2d}. {feature}: {stats['mean']:.4f} ± {stats['std']:.4f}\n")

        f.write("\nSelected Features Summary:\n")
        f.write(f"miRNA features ({len(mirna_names)}): {', '.join(mirna_names[:5])}{'...' if len(mirna_names) > 5 else ''}\n")
        f.write(f"mRNA features ({len(mrna_names)}): {', '.join(mrna_names[:5])}{'...' if len(mrna_names) > 5 else ''}\n")

    print(f"\nComprehensive results saved to: {output_file}")

    return comparison_stats, model_top10_features

def main():
    """Run the multi-omics B12 pipeline end to end and save text summaries.

    Steps:
    1. Load all datasets.
    2. Run the 20-split three-model comparison.
    3. Print and save per-model statistics.
    4. Print an overall assessment and write ``RF_Multi_Omics_Final_Summary.txt``
       to the output directory.

    Failures are reported rather than raised. A loading error prints a message
    and returns early; an analysis error prints a message plus a traceback.
    """
    import time
    start_time = time.time()

    def mean_auc(stats, model_key):
        # NaN when the model has no valid AUC. Substituting 0 would fabricate
        # large "improvements" for the other models.
        return stats.get(model_key, {}).get('auc', {}).get('mean', np.nan)

    def fmt(value, signed=False):
        if np.isnan(value):
            return 'N/A'
        return f"{value:+.4f}" if signed else f"{value:.4f}"

    print("MULTI-OMICS RANDOM FOREST MODEL FOR B12 DEFICIENCY PREDICTION")
    print("="*60)
    print("Comprehensive comparison: Base + Enhanced + Multi-omics models")

    try:
        clinical_data, methylation_data, mirna_data, mrna_data, demirs_data, mrna_degs_data, output_dir = load_multi_omics_data()
    except Exception as e:
        print(f"Error loading data: {e}")
        return

    try:
        base_results, enhanced_results, multi_omics_results, selected_mirnas, selected_mrnas = run_model_comparison(
            clinical_data, methylation_data, mirna_data, mrna_data, 
            demirs_data, mrna_degs_data, n_splits=20
        )

        comparison_stats, model_top10_features = analyze_results(
            base_results, enhanced_results, multi_omics_results,
            selected_mirnas, selected_mrnas, output_dir
        )

        # Final assessment
        print("\n" + "="*60)
        print("FINAL MULTI-OMICS INTEGRATION ASSESSMENT")
        print("="*60)

        base_auc = mean_auc(comparison_stats, 'base')
        enhanced_auc = mean_auc(comparison_stats, 'enhanced')
        multi_auc = mean_auc(comparison_stats, 'multi_omics')

        # NaN propagates, so a difference involving a missing AUC is reported as N/A.
        mirna_improvement = enhanced_auc - base_auc
        mrna_improvement = multi_auc - enhanced_auc
        total_improvement = multi_auc - base_auc

        print("Performance Evolution:")
        print(f"  Base model: AUC = {fmt(base_auc)}")
        print(f"  + miRNA features: AUC = {fmt(enhanced_auc)} (Δ = {fmt(mirna_improvement, signed=True)})")
        print(f"  + mRNA features: AUC = {fmt(multi_auc)} (Δ = {fmt(mrna_improvement, signed=True)})")
        print(f"  Total improvement: {fmt(total_improvement, signed=True)}")

        if np.isnan(total_improvement):
            assessment = "Insufficient valid AUC values to assess multi-omics integration"
        elif total_improvement > 0.03:
            assessment = "Substantial benefit from multi-omics integration"
        elif total_improvement > 0.01:
            assessment = "Moderate benefit from multi-omics integration"
        elif total_improvement > -0.01:
            assessment = "Minimal impact from multi-omics integration"
        else:
            assessment = "Multi-omics integration decreased performance"

        print(f"\nOverall Assessment: {assessment}")

        aucs = [base_auc, enhanced_auc, multi_auc]
        model_names = ['Base', 'Enhanced (miRNA)', 'Multi-omics (miRNA + mRNA)']

        valid = [(auc, name) for auc, name in zip(aucs, model_names) if not np.isnan(auc)]
        if valid:
            # max() returns the first maximal pair, so ties favour the simpler model.
            best_auc, best_name = max(valid, key=lambda pair: pair[0])
            print(f"Best performing model: {best_name} (AUC = {best_auc:.4f})")

        total_time = time.time() - start_time
        print(f"\nTotal execution time: {total_time/60:.1f} minutes")
        print(f"All results saved to: {output_dir}")

        # Save final summary. Explicit UTF-8 because the text contains "→",
        # which locale encodings such as cp1252 or ASCII cannot represent.
        summary_file = os.path.join(output_dir, "RF_Multi_Omics_Final_Summary.txt")
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("Multi-omics Random Forest B12 Prediction - Final Summary\n")
            f.write("="*60 + "\n\n")
            f.write("Model performance was evaluated using AUC as the primary metric,\n")
            f.write("complemented by accuracy, precision, recall, and F1 score.\n\n")
            f.write(f"Execution time: {total_time/60:.1f} minutes\n")
            f.write(f"Assessment: {assessment}\n\n")
            f.write("Performance Summary:\n")
            f.write(f"Base → Enhanced: {fmt(mirna_improvement, signed=True)}\n")
            f.write(f"Enhanced → Multi-omics: {fmt(mrna_improvement, signed=True)}\n")
            f.write(f"Base → Multi-omics: {fmt(total_improvement, signed=True)}\n")

            for model_name in ['Base', 'Enhanced', 'Multi-omics']:
                if model_name in model_top10_features:
                    f.write(f"\nTOP 10 FEATURES - {model_name.upper()}:\n")
                    for i, (feature, stats) in enumerate(model_top10_features[model_name], 1):
                        f.write(f"{i:2d}. {feature}: {stats['mean']:.4f}\n")

        print(f"Final summary saved to: {summary_file}")
        print("\nMulti-omics modeling analysis complete!")

    except Exception as e:
        # Top-level boundary: report with full traceback instead of crashing the session.
        print(f"Error during analysis: {e}")
        import traceback
        traceback.print_exc()

# Entry point: run the full pipeline when executed directly. In a notebook
# cell __name__ is also "__main__", so executing the cell runs main() too.
if __name__ == "__main__":
    main()

MULTI-OMICS RANDOM FOREST MODEL FOR B12 DEFICIENCY PREDICTION
Comprehensive comparison: Base + Enhanced + Multi-omics models
LOADING MULTI-OMICS DATASETS (DNA + miRNA + mRNA)
Output directory created: /Users/heweilin/Desktop/P056_Code_3/Data

1. Loading clinical data...
Clinical data shape: (50, 21)
B12 status distribution:
B12_status
NB    25
LB    25
Name: count, dtype: int64

2. Loading methylation data...
Methylation data shape: (17584, 51)

3. Loading miRNA data...

4. Loading mRNA expression data...
mRNA TPM data shape: (58735, 51)
mRNA DEGs data shape: (19853, 13)

5. Multi-omics data alignment...
Sample counts:
  DNA samples: 50
  miRNA samples: 0
  mRNA samples: 50
Final aligned samples: 50

RUNNING MODEL COMPARISON - 20 SPLITS

PREPARING CLINICAL FEATURES
Using clinical features: ['age', 'BMI', 'parity', 'B12supplem', 'v1p_FolicAcid', 'v1p_MultivitTab', 'smoking']
B12supplem: imputed 2 values with mode (1.0)
v1p_FolicAcid: imputed 2 values with mode (1.0)
v1p_MultivitTab: imp