# In [None]:
def _generate_executive_summary(self):
        """Generate executive summary"""
        print("\nEXECUTIVE SUMMARY")
        print("="*50)
        
        total_tests = 0
        excellent_tests = 0
        good_tests = 0
        
        all_aucs = []
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    total_tests += 1
                    auc = result['auc']
                    all_aucs.append(auc)
                    
                    if auc >= 0.85:
                        excellent_tests += 1
                    elif auc >= 0.75:
                        good_tests += 1
        
        if all_aucs:
            mean_auc = np.mean(all_aucs)
            max_auc = np.max(all_aucs)
            
            print(f"PERFORMANCE OVERVIEW:")
            print(f"   Total algorithm-task combinations: {total_tests}")
            print(f"   Mean AUC across all tests: {mean_auc:.3f}")
            print(f"   Best AUC achieved: {max_auc:.3f}")
            print(f"   Excellent performance (AUC >= 0.85): {excellent_tests}/{total_tests} ({excellent_tests/total_tests*100:.1f}%)")
            print(f"   Good+ performance (AUC >= 0.75): {good_tests+excellent_tests}/{total_tests} ({(good_tests+excellent_tests)/total_tests*100:.1f}%)")
            
            # Clinical readiness assessment
            if excellent_tests > 0:
                print(f"   CLINICAL DEPLOYMENT: {excellent_tests} combinations ready for validation")
            if max_auc >= 0.90:
                print(f"   PUBLICATION READY: Exceptional results achieved")
            elif max_auc >= 0.80:
                print(f"   PUBLICATION READY: Strong results achieved")

def _generate_detailed_results_table(self):
        """Generate detailed results table"""
        print(f"\nDETAILED RESULTS TABLE")
        print("="*50)
        
        # Header
        print(f"{'CNN':<20} {'Task':<25} {'Algorithm':<15} {'AUC':<8} {'Acc':<8} {'Sens':<8} {'Spec':<8} {'Status':<15}")
        print("-" * 120)
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                for alg_name, result in task_data['results'].items():
                    auc = result['auc']
                    acc = result['accuracy']
                    sens = result['sensitivity']
                    spec = result['specificity']
                    
                    # Status based on AUC
                    if auc >= 0.85:
                        status = "EXCELLENT"
                    elif auc >= 0.75:
                        status = "STRONG"
                    elif auc >= 0.65:
                        status = "GOOD"
                    else:
                        status = "MODERATE"
                    
                    print(f"{cnn_name:<20} {task_name:<25} {alg_name:<15} {auc:<8.3f} {acc:<8.3f} {sens:<8.3f} {spec:<8.3f} {status:<15}")

def _generate_best_performers_analysis(self):
        """Generate best performers analysis"""
        print(f"\nBEST PERFORMERS BY TASK")
        print("="*50)
        
        # Find best performer for each task across all CNNs
        task_best = {}
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                if task_name not in task_best:
                    task_best[task_name] = {'auc': 0, 'cnn': '', 'algorithm': '', 'result': None}
                
                for alg_name, result in task_data['results'].items():
                    if result['auc'] > task_best[task_name]['auc']:
                        task_best[task_name] = {
                            'auc': result['auc'],
                            'cnn': cnn_name,
                            'algorithm': alg_name,
                            'result': result
                        }
        
        for task_name, best in task_best.items():
            auc = best['auc']
            status = "DEPLOYMENT READY" if auc >= 0.85 else "PROMISING" if auc >= 0.75 else "NEEDS WORK"
            print(f"{task_name:<30}: {best['cnn']} + {best['algorithm']} (AUC = {auc:.3f}) {status}")

def _generate_validation_summary(self):
        """Generate validation summary"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler, RobustScaler
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score, 
                           accuracy_score, roc_curve, precision_recall_curve, auc)
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.dummy import DummyClassifier
from tabpfn import TabPFNClassifier
import warnings
warnings.filterwarnings('ignore')

# Check for optional dependencies
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("⚠️ XGBoost not available. Install with: pip install xgboost")

try:
    from pytorch_tabnet.tab_model import TabNetClassifier
    import torch
    TABNET_AVAILABLE = True
except ImportError:
    TABNET_AVAILABLE = False
    print("⚠️ TabNet not available. Install with: pip install pytorch-tabnet torch")

class NeurosurgicalAIAnalyzer:
    """Comprehensive AI analysis system for neurosurgical outcome prediction"""
    
    def __init__(self):
        # Updated paths to match your actual file names
        self.datasets = {
            'ConvNext': '/Users/joi263/Documents/MultimodalTabData/data/convnext_data/convnext_cleaned_patient_features_separate_256d.csv',
            'ViT': '/Users/joi263/Documents/MultimodalTabData/data/vit_base_data/vit_base_cleaned_patient_features_separate_256d.csv',
            'ResNet50_Pretrained': '/Users/joi263/Documents/MultimodalTabData/data/pretrained_resnet50_data/pretrained_resnet50_cleaned_patient_features_separate_256d.csv',
            'ResNet50_ImageNet': '/Users/joi263/Documents/MultimodalTabData/data/imagenet_resnet50_data/imagenet_resnet50_cleaned_patient_features_separate_256d.csv',
            'EfficientNet': '/Users/joi263/Documents/MultimodalTabData/data/efficientnet_data/efficientnet_cleaned_patient_features_separate_256d.csv'
        }
        self.results = {}
        self.validation_results = {}
        
        # Print file paths for verification
        print("CHECKING DATA FILE PATHS:")
        print("="*50)
        import os
        for cnn_name, file_path in self.datasets.items():
            exists = os.path.exists(file_path)
            status = "EXISTS" if exists else "NOT FOUND"
            print(f"{cnn_name:<20}: {status}")
            if not exists:
                print(f"  Expected: {file_path}")
        print("="*50)
        print()
        
        # Count how many files exist
        existing_files = sum(1 for path in self.datasets.values() if os.path.exists(path))
        print(f"Found {existing_files}/{len(self.datasets)} data files")
        
        if existing_files == 0:
            print("ERROR: No data files found!")
            print("Please verify the file paths match your actual file locations.")
        elif existing_files < len(self.datasets):
            print(f"WARNING: Only {existing_files} out of {len(self.datasets)} files found.")
            print("Analysis will proceed with available datasets.")
        else:
            print("SUCCESS: All data files found!")
        print()
        
    def get_ml_algorithms(self):
        """Create fresh, unfitted instances of every available classifier.

        Returns:
            dict mapping algorithm name to a config dict with keys 'model'
            (unfitted estimator), 'needs_scaling' (whether the caller should scale
            features before fitting) and 'description' (display text). Entries are
            inserted in this order: TabPFN, XGBoost and TabNet (each only when its
            optional import succeeded), RandomForest, LogisticRegression, SVM.
        """
        def make_entry(model, needs_scaling, description):
            return {'model': model, 'needs_scaling': needs_scaling, 'description': description}

        registry = {}

        # Always registered; runs on CPU.
        registry['TabPFN'] = make_entry(
            TabPFNClassifier(device='cpu'),
            False,
            'Transformer-based Few-Shot Learning',
        )

        if XGBOOST_AVAILABLE:
            # Shallow trees, slow learning rate, row/column subsampling and L1/L2
            # penalties, chosen to limit overfitting on small cohorts.
            registry['XGBoost'] = make_entry(
                xgb.XGBClassifier(
                    n_estimators=300,
                    max_depth=4,
                    learning_rate=0.05,
                    subsample=0.8,
                    colsample_bytree=0.8,
                    min_child_weight=3,
                    reg_alpha=1,
                    reg_lambda=1,
                    random_state=42,
                    eval_metric='logloss',
                    use_label_encoder=False,
                ),
                False,
                'Optimized Gradient Boosting',
            )

        if TABNET_AVAILABLE:
            registry['TabNet'] = make_entry(
                TabNetClassifier(
                    n_d=64,
                    n_a=64,
                    n_steps=5,
                    gamma=1.5,
                    lambda_sparse=1e-4,
                    optimizer_fn=torch.optim.Adam,
                    optimizer_params=dict(lr=0.01, weight_decay=1e-5),
                    mask_type="entmax",
                    scheduler_params={"step_size": 20, "gamma": 0.8},
                    scheduler_fn=torch.optim.lr_scheduler.StepLR,
                    verbose=0,
                    seed=42,
                ),
                True,
                'Optimized Attention-based Neural Network',
            )

        # Depth/leaf limits plus balanced class weights; oob_score requests an
        # out-of-bag estimate during fitting.
        registry['RandomForest'] = make_entry(
            RandomForestClassifier(
                n_estimators=500,
                max_depth=8,
                min_samples_split=10,
                min_samples_leaf=5,
                max_features='sqrt',
                bootstrap=True,
                oob_score=True,
                random_state=42,
                class_weight='balanced',
                n_jobs=-1,
            ),
            False,
            'Optimized Ensemble Decision Trees',
        )

        # Elastic-net penalty (saga is the solver that supports it) with strong
        # regularization (C=0.1).
        registry['LogisticRegression'] = make_entry(
            LogisticRegression(
                penalty='elasticnet',
                l1_ratio=0.5,
                C=0.1,
                solver='saga',
                max_iter=2000,
                random_state=42,
                class_weight='balanced',
                n_jobs=-1,
            ),
            True,
            'Regularized Linear Model with ElasticNet',
        )

        # probability=True so predict_proba (and therefore AUC) is available.
        registry['SVM'] = make_entry(
            SVC(
                kernel='rbf',
                C=1.0,
                gamma='scale',
                probability=True,
                random_state=42,
                class_weight='balanced',
            ),
            True,
            'Support Vector Machine with RBF Kernel',
        )

        return registry

    def create_all_targets(self, df):
        """Create all prediction targets: mortality, tumor classification, IDH, MGMT"""
        print("="*60)
        print("CREATING ALL PREDICTION TARGETS")
        print("="*60)
        
        targets_data = {}
        
        # ============================================================
        # MORTALITY TARGETS
        # ============================================================
        print("MORTALITY TARGETS:")
        survival_data = df[df['survival'].notna() & df['patient_status'].notna()].copy()
        
        if len(survival_data) > 0:
            survival_data['mortality_6mo'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 6)).astype(int)
            survival_data['mortality_1yr'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 12)).astype(int)
            survival_data['mortality_2yr'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 24)).astype(int)
            
            targets_data['mortality'] = {
                'data': survival_data,
                'targets': ['mortality_6mo', 'mortality_1yr', 'mortality_2yr'],
                'descriptions': ['6-Month Mortality', '1-Year Mortality', '2-Year Mortality']
            }
            
            print(f"   Patients: {len(survival_data)}")
            print(f"   6-month: {survival_data['mortality_6mo'].sum()}/{len(survival_data)} ({survival_data['mortality_6mo'].mean()*100:.1f}%)")
            print(f"   1-year: {survival_data['mortality_1yr'].sum()}/{len(survival_data)} ({survival_data['mortality_1yr'].mean()*100:.1f}%)")
            print(f"   2-year: {survival_data['mortality_2yr'].sum()}/{len(survival_data)} ({survival_data['mortality_2yr'].mean()*100:.1f}%)")
        
        # ============================================================
        # TUMOR CLASSIFICATION TARGETS
        # ============================================================
        print("\nTUMOR CLASSIFICATION TARGETS:")
        tumor_data = df[df['methylation_class'].notna()].copy()
        
        if len(tumor_data) > 0:
            # Binary high-grade vs low-grade
            high_grade_terms = ['glioblastoma', 'anaplastic', 'high grade', 'grade iv', 'grade 4', 'gbm']
            tumor_data['high_grade'] = tumor_data['methylation_class'].str.lower().str.contains(
                '|'.join(high_grade_terms), na=False
            ).astype(int)
            
            targets_data['tumor'] = {
                'data': tumor_data,
                'targets': ['high_grade'],
                'descriptions': ['High-Grade vs Low-Grade']
            }
            
            print(f"   Patients: {len(tumor_data)}")
            print(f"   High-grade: {tumor_data['high_grade'].sum()}/{len(tumor_data)} ({tumor_data['high_grade'].mean()*100:.1f}%)")
        
        # ============================================================
        # IDH MUTATION TARGETS
        # ============================================================
        print("\nIDH MUTATION TARGETS:")
        idh_data = self._create_idh_targets(df)
        
        if idh_data is not None and len(idh_data) > 0:
            targets_data['idh'] = {
                'data': idh_data,
                'targets': ['idh_binary'],
                'descriptions': ['IDH Mutation Status']
            }
            
            print(f"   Patients: {len(idh_data)}")
            print(f"   IDH Mutant: {idh_data['idh_binary'].sum()}/{len(idh_data)} ({idh_data['idh_binary'].mean()*100:.1f}%)")
        
        # ============================================================
        # MGMT METHYLATION TARGETS
        # ============================================================
        print("\nMGMT METHYLATION TARGETS:")
        mgmt_data = self._create_mgmt_targets(df)
        
        if mgmt_data is not None and len(mgmt_data) > 0:
            targets_data['mgmt'] = {
                'data': mgmt_data,
                'targets': ['mgmt_binary'],
                'descriptions': ['MGMT Promoter Methylation']
            }
            
            print(f"   Patients: {len(mgmt_data)}")
            print(f"   MGMT Methylated: {mgmt_data['mgmt_binary'].sum()}/{len(mgmt_data)} ({mgmt_data['mgmt_binary'].mean()*100:.1f}%)")
        
        return targets_data

    def _create_idh_targets(self, df):
        """Create IDH mutation targets with proper decoding"""
        if 'idh_1_r132h' not in df.columns:
            return None
            
        idh_data = df.copy()
        idh_data['idh_binary'] = np.nan
        
        # Cross-reference with text data if available
        if 'idh1' in df.columns:
            text_idh = df['idh1'].astype(str).str.lower()
            mutant_patterns = ['r132h', 'r132s', 'arg132his', 'arg132ser', 'missense', 'p.arg132']
            is_mutant_text = text_idh.str.contains('|'.join(mutant_patterns), na=False)
            idh_data.loc[is_mutant_text, 'idh_binary'] = 1  # Mutant
        
        # Apply numerical encoding (2 = mutant based on cross-reference analysis)
        remaining_mask = idh_data['idh_binary'].isna() & idh_data['idh_1_r132h'].notna()
        idh_data.loc[remaining_mask & (idh_data['idh_1_r132h'] == 2), 'idh_binary'] = 1  # Mutant
        idh_data.loc[remaining_mask & (idh_data['idh_1_r132h'] == 1), 'idh_binary'] = 0  # Wildtype
        
        # Exclude unknown cases
        idh_data.loc[idh_data['idh_1_r132h'] == 3, 'idh_binary'] = np.nan
        
        return idh_data[idh_data['idh_binary'].notna()].copy()

    def _create_mgmt_targets(self, df):
        """Create MGMT methylation targets with correct encoding"""
        if 'mgmt' not in df.columns:
            return None
            
        mgmt_data = df[df['mgmt'].notna()].copy()
        
        if len(mgmt_data) == 0:
            return None
        
        # Correct encoding based on data dictionary:
        # 1 = Positive (methylated), 2 = Negative (unmethylated), 3 = Non-informative
        mgmt_data['mgmt_binary'] = np.nan
        
        # Set methylated cases (value = 1)
        mgmt_data.loc[mgmt_data['mgmt'] == 1, 'mgmt_binary'] = 1  # Methylated
        
        # Set unmethylated cases (value = 2) 
        mgmt_data.loc[mgmt_data['mgmt'] == 2, 'mgmt_binary'] = 0  # Unmethylated
        
        # Exclude non-informative cases (value = 3)
        mgmt_data.loc[mgmt_data['mgmt'] == 3, 'mgmt_binary'] = np.nan
        
        # Return only cases with definitive results
        return mgmt_data[mgmt_data['mgmt_binary'].notna()].copy()

    def select_features(self, df, exclude=None):
        """Return the candidate predictor columns present in ``df``.

        Candidate order: clinical columns, molecular columns, then every column
        whose name starts with 'feature_' (CNN imaging features, in ``df`` order).

        Args:
            df: patient-level DataFrame.
            exclude: optional column name or iterable of names to leave out, e.g.
                markers that encode the current prediction target.

        Returns:
            list of column names present in ``df`` and not excluded.
        """
        # Clinical features
        clinical_features = ['age', 'sex', 'race', 'ethnicity', 'gtr']

        # Molecular features (the target columns themselves are not listed).
        # NOTE(review): 'mgmt_pyro' (presumably MGMT pyrosequencing) and
        # 'tumor' / 'hg_glioma' may encode the MGMT and high-grade targets. If so,
        # pass them via ``exclude`` for those tasks to avoid label leakage.
        molecular_features = ['mgmt_pyro', 'atrx', 'p53', 'braf_v600', 'h3k27m', 'gfap', 'tumor', 'hg_glioma']

        # CNN-extracted imaging features
        image_features = [col for col in df.columns if col.startswith('feature_')]

        if exclude is None:
            excluded = set()
        elif isinstance(exclude, str):
            excluded = {exclude}  # a bare name, not an iterable of characters
        else:
            excluded = set(exclude)

        return [
            f for f in clinical_features + molecular_features + image_features
            if f in df.columns and f not in excluded
        ]

    def preprocess_data(self, df, features, target_col, max_missing_fraction=0.5, k_best=100):
        """Turn a patient DataFrame into a numeric (X, y) pair for one binary target.

        Steps:
          1. Keep rows with a non-null target.
          2. Drop features missing in more than ``max_missing_fraction`` of those rows.
          3. Label-encode object columns (missing values become their own 'nan' code).
          4. Impute remaining numeric gaps: mean for 'feature_*' imaging columns,
             median otherwise.
          5. Keep the ``k_best`` highest ANOVA-F features when there are more.

        Args:
            df: source DataFrame.
            features: candidate feature column names.
            target_col: binary target column.
            max_missing_fraction: features with a larger missing fraction are dropped.
            k_best: maximum number of feature columns returned.

        Returns:
            (X, y, None) on success, or (None, None, reason) when the data cannot
            support a model (too few rows, no usable features, a single class, or a
            class with fewer than 3 samples).
        """
        data = df[features + [target_col]].copy()
        data = data[data[target_col].notna()]

        if len(data) < 15:  # Minimum viable sample size
            return None, None, f"Insufficient data: {len(data)} samples"

        # Measure missingness BEFORE encoding/imputation. Afterwards no NaNs remain,
        # so a later check would never drop anything.
        missing_pct = data[features].isnull().mean()
        features = missing_pct[missing_pct <= max_missing_fraction].index.tolist()
        if not features:
            return None, None, f"No usable features: all exceed {max_missing_fraction:.0%} missing"
        data = data[features + [target_col]].copy()

        # Handle categorical features
        categorical_features = [
            col for col in data.select_dtypes(include=['object']).columns if col != target_col
        ]
        for col in categorical_features:
            data[col] = LabelEncoder().fit_transform(data[col].astype(str))

        # Impute the remaining numeric gaps
        numeric_columns = set(data.select_dtypes(include=[np.number]).columns)
        for col in features:
            if col in numeric_columns and data[col].isnull().any():
                fill_value = data[col].mean() if col.startswith('feature_') else data[col].median()
                data[col] = data[col].fillna(fill_value)

        # NOTE(review): imputation statistics and SelectKBest below are fit on all rows
        # before any holdout/CV split, so held-out labels influence feature selection
        # and metrics may be optimistic. Consider moving both into the per-fold pipeline.
        X = data[features].values
        y = data[target_col].values

        # Check class balance
        unique_classes, class_counts = np.unique(y, return_counts=True)
        if len(unique_classes) < 2:
            return None, None, f"Only one class present in {target_col}"
        min_class_size = min(class_counts)
        if min_class_size < 3:
            return None, None, f"Class too small: minimum class has {min_class_size} samples"

        # Feature selection for computational efficiency
        if X.shape[1] > k_best:
            selector = SelectKBest(score_func=f_classif, k=k_best)
            X = selector.fit_transform(X, y)

        return X, y, None

    def train_and_evaluate_algorithm(self, X_train, X_test, y_train, y_test, algorithm_name, algorithm_config):
        """Fit one configured algorithm on a training split and score it on a held-out split.

        Args:
            X_train, X_test: feature matrices (NumPy arrays from ``preprocess_data``).
            y_train, y_test: binary label vectors.
            algorithm_name: registry key from ``get_ml_algorithms``; 'TabNet' uses
                its own fit routine.
            algorithm_config: dict with 'model' (the estimator, fitted in place)
                and 'needs_scaling'.

        Returns:
            dict with accuracy, balanced_accuracy, auc, sensitivity, specificity,
            ppv, npv, f1_score, confusion_matrix, n_test and scaling_used; or None
            if anything raised (the error is printed).

        ``X_test`` / ``y_test`` are used only for scoring. Early stopping monitors
        a slice of the training data, so test labels never influence model selection.
        """
        try:
            model = algorithm_config['model']
            needs_scaling = algorithm_config['needs_scaling']

            if needs_scaling:
                # Fit on the training split only, then apply to the test split. The
                # 10-90 quantile range makes RobustScaler less sensitive to outliers.
                scaler = RobustScaler(quantile_range=(10.0, 90.0))
                X_train_processed = scaler.fit_transform(X_train)
                X_test_processed = scaler.transform(X_test)

                if np.isnan(X_train_processed).any() or np.isnan(X_test_processed).any():
                    # Fallback to StandardScaler if RobustScaler produced NaNs
                    scaler = StandardScaler()
                    X_train_processed = scaler.fit_transform(X_train)
                    X_test_processed = scaler.transform(X_test)
            else:
                X_train_processed, X_test_processed = X_train, X_test

            if algorithm_name == 'TabNet' and TABNET_AVAILABLE:
                self._fit_tabnet(model, X_train_processed, y_train)
                y_pred_proba = model.predict_proba(X_test_processed)[:, 1]
                y_pred = (y_pred_proba > 0.5).astype(int)
            else:
                # Standard scikit-learn interface. XGBoost goes through this path as
                # well: it has no early stopping configured, so an eval_set would only log.
                model.fit(X_train_processed, y_train)
                y_pred = model.predict(X_test_processed)

                if hasattr(model, 'predict_proba'):
                    y_pred_proba = model.predict_proba(X_test_processed)[:, 1]
                else:
                    y_pred_proba = y_pred.astype(float)

            accuracy = accuracy_score(y_test, y_pred)

            try:
                auc = roc_auc_score(y_test, y_pred_proba)
            except ValueError:
                # e.g. only one class present in y_test: AUC undefined, report chance level.
                auc = 0.5

            # Pin the label set so a split where y_test and y_pred share a single class
            # still yields a 2x2 matrix. A 1x1 matrix would zero every clinical metric below.
            class_labels = np.unique(np.concatenate([np.asarray(y_train), np.asarray(y_test)]))
            cm = confusion_matrix(y_test, y_pred, labels=class_labels)

            if cm.shape == (2, 2):
                tn, fp, fn, tp = cm.ravel()
                sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
                npv = tn / (tn + fn) if (tn + fn) > 0 else 0
            else:
                sensitivity = specificity = ppv = npv = 0

            balanced_accuracy = (sensitivity + specificity) / 2
            f1_score = 2 * (ppv * sensitivity) / (ppv + sensitivity) if (ppv + sensitivity) > 0 else 0

            return {
                'accuracy': accuracy,
                'balanced_accuracy': balanced_accuracy,
                'auc': auc,
                'sensitivity': sensitivity,
                'specificity': specificity,
                'ppv': ppv,
                'npv': npv,
                'f1_score': f1_score,
                'confusion_matrix': cm,
                'n_test': len(y_test),
                'scaling_used': needs_scaling
            }

        except Exception as e:
            # Best-effort: one failing algorithm must not abort the whole comparison.
            print(f"   ❌ {algorithm_name} failed: {str(e)}")
            return None

    def _fit_tabnet(self, model, X_fit, y_fit):
        """Fit TabNet, early-stopping on a stratified 20% slice of the TRAINING data.

        Falls back to a plain fit without early stopping when that slice cannot be
        stratified, or when it or the remaining training part would lack a class
        (the 'auc' stopping metric needs both classes).
        """
        fit_kwargs = {}
        try:
            X_tr, X_es, y_tr, y_es = train_test_split(
                X_fit, y_fit, test_size=0.2, random_state=42, stratify=y_fit
            )
        except ValueError:
            X_tr = None  # too few samples per class to carve out a stratified slice
        if X_tr is not None and len(np.unique(y_es)) > 1 and len(np.unique(y_tr)) > 1:
            X_fit, y_fit = X_tr, y_tr
            fit_kwargs = {'eval_set': [(X_es, y_es)], 'eval_metric': ['auc'], 'patience': 20}

        model.fit(
            X_fit, y_fit,
            max_epochs=100,
            # len // 4 can be 0 or 1 on tiny folds; keep at least 2 rows per batch
            # (presumably required by TabNet's batch-normalisation layers).
            batch_size=max(2, min(256, len(X_fit) // 4)),
            **fit_kwargs,
        )

    def run_prediction_task(self, X, y, task_name, cnn_name, algorithms):
        """Evaluate every algorithm on one (task, CNN) dataset with a holdout split plus CV.

        Args:
            X, y: preprocessed feature matrix and binary labels (NumPy arrays).
            task_name, cnn_name: labels used only in the printed headers.
            algorithms: registry from ``get_ml_algorithms``.

        Returns:
            dict mapping algorithm name -> holdout metrics merged with the
            cross-validation summary (``cv_*`` keys). An algorithm whose holdout run
            fails is omitted. If CV fails, the holdout numbers stand in for the CV keys.
        """
        print(f"\n{'='*50}")
        print(f"{task_name} - {cnn_name}")
        print(f"{'='*50}")

        # 75/25 holdout for detailed metrics, stratified when possible.
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.25, random_state=42, stratify=y
            )
        except ValueError:
            # train_test_split raises ValueError when a class is too small to stratify.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.25, random_state=42
            )

        print("DATA SPLIT:")
        print(f"   Training: {len(X_train)} samples")
        print(f"   Testing: {len(X_test)} samples")
        print(f"   Positive rate: {y_train.mean()*100:.1f}% (train), {y_test.mean()*100:.1f}% (test)")

        results = {}

        for alg_name, alg_config in algorithms.items():
            print(f"\nTESTING {alg_name}...")

            holdout_result = self.train_and_evaluate_algorithm(X_train, X_test, y_train, y_test, alg_name, alg_config)

            if holdout_result is None:
                print(f"   ERROR {alg_name}: FAILED")
                continue

            cv_result = self.cross_validate_algorithm(X, y, alg_name, alg_config)

            if cv_result is None:
                print(f"   WARNING {alg_name}: Cross-validation failed, using holdout only")
                cv_result = {
                    'cv_auc_mean': holdout_result['auc'],
                    'cv_auc_std': 0.0,
                    'cv_auc_ci_lower': holdout_result['auc'],
                    'cv_auc_ci_upper': holdout_result['auc'],
                    'cv_accuracy_mean': holdout_result['accuracy'],
                    'cv_accuracy_std': 0.0,
                    'cv_folds': 1,
                    'cv_stability': 'SINGLE_SPLIT'
                }

            # CV keys are all 'cv_'-prefixed, so nothing from the holdout is overwritten.
            results[alg_name] = {**holdout_result, **cv_result}

            auc_mean = cv_result['cv_auc_mean']
            auc_ci_lower = cv_result['cv_auc_ci_lower']
            auc_ci_upper = cv_result['cv_auc_ci_upper']
            stability = cv_result['cv_stability']

            print(f"   HOLDOUT: Accuracy={holdout_result['accuracy']:.3f}, AUC={holdout_result['auc']:.3f}")
            print(f"   CROSS-VAL: AUC={auc_mean:.3f} (95% CI: {auc_ci_lower:.3f}-{auc_ci_upper:.3f})")
            print(f"   STABILITY: {stability}")

            # A label is "robust" when the CI lower bound clears the threshold, and
            # "some variability" when only the mean does.
            if auc_ci_lower >= 0.85:
                print("       EXCELLENT clinical performance (robust across CV)")
            elif auc_mean >= 0.85 and auc_ci_lower >= 0.75:
                print("       EXCELLENT clinical performance (some variability)")
            elif auc_ci_lower >= 0.75:
                print("       STRONG clinical performance (robust across CV)")
            elif auc_mean >= 0.75 and auc_ci_lower >= 0.65:
                print("       STRONG clinical performance (some variability)")
            elif auc_ci_lower >= 0.65:
                print("       GOOD performance (robust across CV)")
            else:
                print("       MODERATE performance (consider more data/optimization)")

        return results

    def cross_validate_algorithm(self, X, y, algorithm_name, algorithm_config, cv_folds=5):
        """Run stratified k-fold cross-validation and summarise AUC with a 95% CI.

        Each fold is trained and scored through ``self.train_and_evaluate_algorithm``.
        A fold for which that call returns ``None`` is imputed with chance-level
        metrics (0.5 for AUC, accuracy, sensitivity and specificity), so a failing
        fold penalises the summary instead of silently dropping out of it. The
        number of such folds is reported as ``cv_failed_folds``.

        Args:
            X: Feature matrix indexable by integer index arrays (``X[idx]``),
                e.g. a NumPy array.
            y: Label vector indexable the same way.
            algorithm_name: Name forwarded to ``train_and_evaluate_algorithm``.
            algorithm_config: Configuration forwarded to ``train_and_evaluate_algorithm``.
            cv_folds: Number of stratified folds (``StratifiedKFold`` requires >= 2).

        Returns:
            dict of cross-validation statistics (mean/sample-SD/95% t-interval of
            AUC, mean/sample-SD of accuracy, mean sensitivity/specificity, a
            stability label based on the AUC coefficient of variation, and the
            per-fold AUCs), or ``None`` if anything in the procedure raises (the
            error is printed).
        """
        try:
            from scipy import stats

            # Fixed seed so fold assignment is reproducible across runs.
            cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)

            cv_aucs = []
            cv_accuracies = []
            cv_sensitivities = []
            cv_specificities = []
            failed_folds = 0

            for train_idx, val_idx in cv.split(X, y):
                X_train_cv, X_val_cv = X[train_idx], X[val_idx]
                y_train_cv, y_val_cv = y[train_idx], y[val_idx]

                fold_result = self.train_and_evaluate_algorithm(
                    X_train_cv, X_val_cv, y_train_cv, y_val_cv,
                    algorithm_name, algorithm_config
                )

                if fold_result is None:
                    # Conservative imputation: a failed fold counts as chance performance.
                    failed_folds += 1
                    fold_result = {'auc': 0.5, 'accuracy': 0.5,
                                   'sensitivity': 0.5, 'specificity': 0.5}

                cv_aucs.append(fold_result['auc'])
                cv_accuracies.append(fold_result['accuracy'])
                cv_sensitivities.append(fold_result['sensitivity'])
                cv_specificities.append(fold_result['specificity'])

            cv_aucs = np.array(cv_aucs)
            cv_accuracies = np.array(cv_accuracies)
            n_folds = len(cv_aucs)

            # Sample standard deviation (ddof=1). The Student-t interval below is
            # built on the sample estimate of fold-to-fold spread; the population
            # SD (numpy's default ddof=0) would make the interval too narrow.
            auc_mean = np.mean(cv_aucs)
            auc_std = np.std(cv_aucs, ddof=1)
            acc_mean = np.mean(cv_accuracies)
            acc_std = np.std(cv_accuracies, ddof=1)

            # 95% confidence interval for the mean AUC (t-distribution, n-1 dof).
            t_critical = stats.t.ppf(0.975, df=n_folds - 1)
            auc_margin = t_critical * (auc_std / np.sqrt(n_folds))

            # Clip to the valid AUC range.
            auc_ci_lower = max(0.0, auc_mean - auc_margin)
            auc_ci_upper = min(1.0, auc_mean + auc_margin)

            # Stability label from the coefficient of variation of the fold AUCs.
            cv_of_variation = auc_std / auc_mean if auc_mean > 0 else 1.0

            if cv_of_variation < 0.05:
                stability = "HIGHLY STABLE"
            elif cv_of_variation < 0.10:
                stability = "STABLE"
            elif cv_of_variation < 0.15:
                stability = "MODERATE VARIABILITY"
            else:
                stability = "HIGH VARIABILITY"

            return {
                'cv_auc_mean': auc_mean,
                'cv_auc_std': auc_std,
                'cv_auc_ci_lower': auc_ci_lower,
                'cv_auc_ci_upper': auc_ci_upper,
                'cv_accuracy_mean': acc_mean,
                'cv_accuracy_std': acc_std,
                'cv_sensitivity_mean': np.mean(cv_sensitivities),
                'cv_specificity_mean': np.mean(cv_specificities),
                'cv_folds': cv_folds,
                'cv_failed_folds': failed_folds,
                'cv_stability': stability,
                'cv_coefficient_variation': cv_of_variation,
                'cv_individual_aucs': cv_aucs.tolist()
            }

        except Exception as e:
            print(f"   Cross-validation failed for {algorithm_name}: {e}")
            return None

    def _check_feature_quality(self, df):
        """Check feature quality and completeness"""
        try:
            image_features = [col for col in df.columns if col.startswith('feature_')]
            clinical_features = ['age', 'sex', 'race', 'ethnicity']
            
            image_quality = len(image_features) >= 50  # Sufficient image features
            clinical_completeness = sum(col in df.columns for col in clinical_features) >= 2
            
            score = (image_quality + clinical_completeness) / 2
            
            return {
                'status': 'PASS' if score >= 0.5 else 'WARN',
                'score': score,
                'details': f"Image features: {len(image_features)}, Clinical completeness: {clinical_completeness}"
            }
        except:
            return {'status': 'FAIL', 'score': 0, 'details': 'Feature quality check failed'}

    def run_validation_checks(self, cnn_name, file_path):
        """Load a cohort CSV and run the data-validation checks on it.

        Prints a header naming ``cnn_name``, reads ``file_path`` with pandas and
        runs, in order, the data-integrity, class-balance, feature-quality and
        sample-size checks. An ``'overall'`` entry is added that PASSes when at
        least three of the individual checks report ``'PASS'``.

        Args:
            cnn_name: Label used only in the printed header.
            file_path: Path of the CSV to validate.

        Returns:
            dict mapping check name to its result dict (plus ``'overall'``), or
            ``{'error': message}`` if reading the file or any check raises.
        """
        print(f"\n🔍 VALIDATION CHECKS FOR {cnn_name}")
        print("=" * 50)

        try:
            frame = pd.read_csv(file_path)

            checks = {}
            checks['data_integrity'] = self._check_data_integrity(frame)
            checks['class_balance'] = self._check_class_balance(frame)
            checks['feature_quality'] = self._check_feature_quality(frame)
            checks['sample_size'] = self._check_sample_size(frame)

            n_total = len(checks)
            n_passed = len([name for name, outcome in checks.items() if outcome['status'] == 'PASS'])

            checks['overall'] = {
                'status': 'WARN' if n_passed < 3 else 'PASS',
                'score': n_passed / n_total,
                'summary': f"{n_passed}/{n_total} validation checks passed"
            }
        except Exception as exc:
            return {'error': str(exc)}

        return checks

    def _check_data_integrity(self, df):
        """Check basic data integrity"""
        try:
            has_survival = df['survival'].notna().sum() > 10
            has_molecular = any(col in df.columns for col in ['mgmt', 'idh_1_r132h', 'methylation_class'])
            has_images = any(col.startswith('feature_') for col in df.columns)
            
            score = sum([has_survival, has_molecular, has_images]) / 3
            
            return {
                'status': 'PASS' if score >= 0.67 else 'WARN',
                'score': score,
                'details': f"Survival: {has_survival}, Molecular: {has_molecular}, Images: {has_images}"
            }
        except:
            return {'status': 'FAIL', 'score': 0, 'details': 'Data integrity check failed'}

    def _check_class_balance(self, df):
        """Check class balance across targets"""
        try:
            balances = []
            
            # Check mortality balance
            if 'survival' in df.columns and 'patient_status' in df.columns:
                survival_data = df[df['survival'].notna() & df['patient_status'].notna()]
                if len(survival_data) > 0:
                    mortality_1yr = ((survival_data['patient_status'] == 2) & 
                                   (survival_data['survival'] <= 12)).mean()
                    balances.append(min(mortality_1yr, 1-mortality_1yr))
            
            # Check tumor grade balance
            if 'methylation_class' in df.columns:
                tumor_data = df[df['methylation_class'].notna()]
                if len(tumor_data) > 0:
                    high_grade_terms = ['glioblastoma', 'anaplastic', 'high grade', 'grade iv', 'grade 4', 'gbm']
                    high_grade_rate = tumor_data['methylation_class'].str.lower().str.contains(
                        '|'.join(high_grade_terms), na=False
                    ).mean()
                    balances.append(min(high_grade_rate, 1-high_grade_rate))
            
            avg_balance = np.mean(balances) if balances else 0
            
            return {
                'status': 'PASS' if avg_balance >= 0.15 else 'WARN',
                'score': avg_balance,
                'details': f"Average minority class rate: {avg_balance:.3f}"
            }
        except:
            return {'status': 'FAIL', 'score': 0, 'details': 'Class balance check failed'}

    def _check_confounding_factors(self, df):
        """Check for potential confounding factors in clinical predictions"""
        try:
            confounding_issues = []
            severity_scores = []
            
            # Check for age-outcome confounding
            age_confounding = self._check_age_confounding(df)
            if age_confounding['severity'] > 0:
                confounding_issues.append(age_confounding)
                severity_scores.append(age_confounding['severity'])
            
            # Check for center/batch effects (if institutional data available)
            batch_confounding = self._check_batch_effects(df)
            if batch_confounding['severity'] > 0:
                confounding_issues.append(batch_confounding)
                severity_scores.append(batch_confounding['severity'])
            
            # Check for molecular marker interdependence
            molecular_confounding = self._check_molecular_confounding(df)
            if molecular_confounding['severity'] > 0:
                confounding_issues.append(molecular_confounding)
                severity_scores.append(molecular_confounding['severity'])
            
            # Check for survival bias in molecular markers
            survival_bias = self._check_survival_bias(df)
            if survival_bias['severity'] > 0:
                confounding_issues.append(survival_bias) 
                severity_scores.append(survival_bias['severity'])
            
            # Overall assessment
            if not severity_scores:
                status = 'PASS'
                score = 1.0
                details = "No major confounding factors detected"
            else:
                max_severity = max(severity_scores)
                if max_severity >= 0.8:
                    status = 'FAIL'
                    score = 0.2
                    details = f"Critical confounding detected: {len(confounding_issues)} issues"
                elif max_severity >= 0.5:
                    status = 'WARN'
                    score = 0.6
                    details = f"Moderate confounding detected: {len(confounding_issues)} issues"
                else:
                    status = 'PASS'
                    score = 0.8
                    details = f"Minor confounding detected: {len(confounding_issues)} issues"
            
            return {
                'status': status,
                'score': score,
                'details': details,
                'confounding_issues': confounding_issues,
                'n_issues': len(confounding_issues)
            }
            
        except Exception as e:
            return {
                'status': 'WARN',
                'score': 0.5,
                'details': f'Confounding check incomplete: {str(e)}',
                'confounding_issues': [],
                'n_issues': 0
            }

    def _check_age_confounding(self, df):
        """Check if age is confounded with outcomes"""
        try:
            if 'age' not in df.columns:
                return {'type': 'age', 'severity': 0, 'description': 'Age data not available'}
            
            issues = []
            max_severity = 0
            
            # Check age-mortality confounding
            if 'survival' in df.columns and 'patient_status' in df.columns:
                survival_data = df[df['survival'].notna() & df['patient_status'].notna() & df['age'].notna()]
                if len(survival_data) > 10:
                    deceased = survival_data[survival_data['patient_status'] == 2]['age']
                    alive = survival_data[survival_data['patient_status'] != 2]['age']
                    
                    if len(deceased) > 5 and len(alive) > 5:
                        age_diff = abs(deceased.mean() - alive.mean())
                        pooled_std = np.sqrt(((deceased.std()**2 + alive.std()**2) / 2))
                        effect_size = age_diff / pooled_std if pooled_std > 0 else 0
                        
                        if effect_size > 0.8:  # Large effect
                            severity = 0.9
                            issues.append(f"Large age difference between deceased ({deceased.mean():.1f}) and alive ({alive.mean():.1f})")
                        elif effect_size > 0.5:  # Medium effect
                            severity = 0.6
                            issues.append(f"Moderate age difference between outcomes")
                        
                        max_severity = max(max_severity, severity if 'severity' in locals() else 0)
            
            # Check age-tumor grade confounding  
            if 'methylation_class' in df.columns:
                tumor_data = df[df['methylation_class'].notna() & df['age'].notna()]
                if len(tumor_data) > 10:
                    high_grade_terms = ['glioblastoma', 'anaplastic', 'high grade', 'grade iv', 'grade 4', 'gbm']
                    high_grade_mask = tumor_data['methylation_class'].str.lower().str.contains('|'.join(high_grade_terms), na=False)
                    
                    high_grade_ages = tumor_data[high_grade_mask]['age']
                    low_grade_ages = tumor_data[~high_grade_mask]['age']
                    
                    if len(high_grade_ages) > 5 and len(low_grade_ages) > 5:
                        age_diff = abs(high_grade_ages.mean() - low_grade_ages.mean())
                        pooled_std = np.sqrt(((high_grade_ages.std()**2 + low_grade_ages.std()**2) / 2))
                        effect_size = age_diff / pooled_std if pooled_std > 0 else 0
                        
                        if effect_size > 0.8:
                            severity = 0.7  # Slightly less critical than mortality
                            issues.append(f"Age strongly associated with tumor grade")
                            max_severity = max(max_severity, severity)
            
            return {
                'type': 'age_confounding',
                'severity': max_severity,
                'description': '; '.join(issues) if issues else 'No significant age confounding detected'
            }
            
        except:
            return {'type': 'age_confounding', 'severity': 0, 'description': 'Age confounding check failed'}

    def _check_batch_effects(self, df):
        """Check for potential batch/center effects"""
        try:
            # Look for institutional or batch identifiers
            batch_columns = [col for col in df.columns if any(term in col.lower() 
                           for term in ['institution', 'center', 'batch', 'site', 'hospital'])]
            
            if not batch_columns:
                return {'type': 'batch_effects', 'severity': 0, 'description': 'No batch identifiers found'}
            
            # Check if outcomes vary significantly by batch
            severity = 0
            issues = []
            
            for batch_col in batch_columns:
                unique_batches = df[batch_col].nunique()
                if unique_batches > 1 and unique_batches < len(df) * 0.5:  # Reasonable number of batches
                    # Check mortality rates by batch
                    if 'survival' in df.columns and 'patient_status' in df.columns:
                        batch_mortality = df.groupby(batch_col).apply(
                            lambda x: ((x['patient_status'] == 2) & (x['survival'] <= 12)).mean()
                        )
                        if batch_mortality.std() > 0.15:  # >15% variation in mortality rates
                            severity = max(severity, 0.6)
                            issues.append(f"Mortality rates vary by {batch_col}")
            
            return {
                'type': 'batch_effects',
                'severity': severity,
                'description': '; '.join(issues) if issues else 'No significant batch effects detected'
            }
            
        except:
            return {'type': 'batch_effects', 'severity': 0, 'description': 'Batch effects check failed'}

    def _check_molecular_confounding(self, df):
        """Check for confounding between molecular markers"""
        try:
            molecular_cols = ['mgmt', 'idh_1_r132h', 'atrx', 'p53']
            available_molecular = [col for col in molecular_cols if col in df.columns]
            
            if len(available_molecular) < 2:
                return {'type': 'molecular_confounding', 'severity': 0, 'description': 'Insufficient molecular data'}
            
            issues = []
            max_severity = 0
            
            # Check IDH-MGMT association (known biological confounding)
            if 'idh_1_r132h' in df.columns and 'mgmt' in df.columns:
                idh_mgmt_data = df[(df['idh_1_r132h'].isin([1, 2])) & (df['mgmt'].isin([1, 2]))]
                
                if len(idh_mgmt_data) > 20:
                    # Create contingency table
                    idh_mutant = (idh_mgmt_data['idh_1_r132h'] == 2)  # Assuming 2 = mutant
                    mgmt_methylated = (idh_mgmt_data['mgmt'] == 1)  # 1 = methylated per data dictionary
                    
                    # Calculate association strength (Cramér's V)
                    from scipy.stats import chi2_contingency
                    try:
                        contingency = pd.crosstab(idh_mutant, mgmt_methylated)
                        chi2, p_value, dof, expected = chi2_contingency(contingency)
                        n = contingency.sum().sum()
                        cramers_v = np.sqrt(chi2 / (n * (min(contingency.shape) - 1)))
                        
                        if cramers_v > 0.5 and p_value < 0.05:
                            max_severity = 0.8
                            issues.append("Strong IDH-MGMT association detected (biological confounding)")
                        elif cramers_v > 0.3 and p_value < 0.05:
                            max_severity = 0.5
                            issues.append("Moderate IDH-MGMT association detected")
                    except:
                        pass
            
            return {
                'type': 'molecular_confounding',
                'severity': max_severity,
                'description': '; '.join(issues) if issues else 'No significant molecular confounding detected'
            }
            
        except:
            return {'type': 'molecular_confounding', 'severity': 0, 'description': 'Molecular confounding check failed'}

    def _check_survival_bias(self, df):
        """Check for survival bias in molecular marker availability"""
        try:
            if not all(col in df.columns for col in ['survival', 'patient_status']):
                return {'type': 'survival_bias', 'severity': 0, 'description': 'Survival data not available'}
            
            issues = []
            max_severity = 0
            
            molecular_cols = ['mgmt', 'idh_1_r132h', 'atrx', 'p53']
            
            for mol_col in molecular_cols:
                if mol_col in df.columns:
                    # Compare survival times between patients with/without molecular data
                    has_molecular = df[df[mol_col].notna() & df['survival'].notna()]
                    no_molecular = df[df[mol_col].isna() & df['survival'].notna()]
                    
                    if len(has_molecular) > 10 and len(no_molecular) > 10:
                        survival_diff = abs(has_molecular['survival'].mean() - no_molecular['survival'].mean())
                        pooled_std = np.sqrt((has_molecular['survival'].std()**2 + no_molecular['survival'].std()**2) / 2)
                        
                        if pooled_std > 0:
                            effect_size = survival_diff / pooled_std
                            
                            if effect_size > 0.5:  # Medium to large effect
                                severity = 0.6
                                issues.append(f"Survival bias detected for {mol_col} availability")
                                max_severity = max(max_severity, severity)
            
            return {
                'type': 'survival_bias',
                'severity': max_severity,
                'description': '; '.join(issues) if issues else 'No significant survival bias detected'
            }
            
        except:
            return {'type': 'survival_bias', 'severity': 0, 'description': 'Survival bias check failed'}
        """Check feature quality and completeness"""
        try:
            image_features = [col for col in df.columns if col.startswith('feature_')]
            clinical_features = ['age', 'sex', 'race', 'ethnicity']
            
            image_quality = len(image_features) >= 50  # Sufficient image features
            clinical_completeness = sum(col in df.columns for col in clinical_features) >= 2
            
            score = (image_quality + clinical_completeness) / 2
            
            return {
                'status': 'PASS' if score >= 0.5 else 'WARN',
                'score': score,
                'details': f"Image features: {len(image_features)}, Clinical completeness: {clinical_completeness}"
            }
        except:
            return {'status': 'FAIL', 'score': 0, 'details': 'Feature quality check failed'}

    def _check_sample_size(self, df):
        """Check sample size adequacy"""
        try:
            total_samples = len(df)
            
            # Check samples for different tasks
            survival_samples = df[df['survival'].notna() & df['patient_status'].notna()].shape[0]
            tumor_samples = df[df['methylation_class'].notna()].shape[0]
            
            min_samples = min(survival_samples, tumor_samples) if tumor_samples > 0 else survival_samples
            
            if min_samples >= 50:
                status = 'PASS'
                score = 1.0
            elif min_samples >= 30:
                status = 'WARN'
                score = 0.7
            else:
                status = 'FAIL'
                score = 0.3
            
            return {
                'status': status,
                'score': score,
                'details': f"Min task samples: {min_samples}, Total: {total_samples}"
            }
        except:
            return {'status': 'FAIL', 'score': 0, 'details': 'Sample size check failed'}

    def generate_publication_document(self):
        """Generate a comprehensive publication-ready document"""
        
        if not self.results:
            print("No results available for document generation")
            return
        
        # Create comprehensive document content
        doc_content = []
        
        # Title and Header
        doc_content.append("COMPREHENSIVE NEUROSURGICAL AI ANALYSIS")
        doc_content.append("=" * 80)
        doc_content.append("")
        doc_content.append("EXECUTIVE SUMMARY")
        doc_content.append("-" * 40)
        
        # Calculate summary statistics
        total_tests = 0
        excellent_tests = 0
        good_tests = 0
        all_aucs = []
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    total_tests += 1
                    auc = result['auc']
                    all_aucs.append(auc)
                    
                    if auc >= 0.85:
                        excellent_tests += 1
                    elif auc >= 0.75:
                        good_tests += 1
        
        if all_aucs:
            mean_auc = np.mean(all_aucs)
            max_auc = np.max(all_aucs)
            
            doc_content.append(f"Total algorithm-task combinations tested: {total_tests}")
            doc_content.append(f"Mean AUC across all tests: {mean_auc:.3f}")
            doc_content.append(f"Best AUC achieved: {max_auc:.3f}")
            doc_content.append(f"Excellent performance (AUC >= 0.85): {excellent_tests}/{total_tests} ({excellent_tests/total_tests*100:.1f}%)")
            doc_content.append(f"Good+ performance (AUC >= 0.75): {good_tests+excellent_tests}/{total_tests} ({(good_tests+excellent_tests)/total_tests*100:.1f}%)")
            doc_content.append("")
            
            if excellent_tests > 0:
                doc_content.append(f"CLINICAL DEPLOYMENT: {excellent_tests} combinations ready for validation")
            if max_auc >= 0.90:
                doc_content.append("PUBLICATION STATUS: Exceptional results achieved - ready for top-tier journals")
            elif max_auc >= 0.80:
                doc_content.append("PUBLICATION STATUS: Strong results achieved - ready for clinical journals")
        
        doc_content.append("")
        doc_content.append("")
        
        # Detailed Results Table
        doc_content.append("COMPREHENSIVE RESULTS TABLE")
        doc_content.append("-" * 80)
        doc_content.append("")
        
        # Create detailed table
        header = f"{'CNN':<20} {'Task':<25} {'Algorithm':<15} {'AUC':<8} {'Accuracy':<9} {'Sensitivity':<11} {'Specificity':<11} {'Status':<15}"
        doc_content.append(header)
        doc_content.append("-" * len(header))
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                for alg_name, result in task_data['results'].items():
                    auc = result['auc']
                    acc = result['accuracy']
                    sens = result['sensitivity']
                    spec = result['specificity']
                    
                    # Status based on AUC without emojis
                    if auc >= 0.85:
                        status = "EXCELLENT"
                    elif auc >= 0.75:
                        status = "STRONG"
                    elif auc >= 0.65:
                        status = "GOOD"
                    else:
                        status = "MODERATE"
                    
                    row = f"{cnn_name:<20} {task_name:<25} {alg_name:<15} {auc:<8.3f} {acc:<9.3f} {sens:<11.3f} {spec:<11.3f} {status:<15}"
                    doc_content.append(row)
        
        doc_content.append("")
        doc_content.append("")
        
        # Best Performers Analysis
        doc_content.append("BEST PERFORMERS BY CLINICAL TASK")
        doc_content.append("-" * 40)
        doc_content.append("")
        
        # Find best performer for each task
        task_best = {}
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                if task_name not in task_best:
                    task_best[task_name] = {'auc': 0, 'cnn': '', 'algorithm': '', 'result': None}
                
                for alg_name, result in task_data['results'].items():
                    if result['auc'] > task_best[task_name]['auc']:
                        task_best[task_name] = {
                            'auc': result['auc'],
                            'cnn': cnn_name,
                            'algorithm': alg_name,
                            'result': result
                        }
        
        for task_name, best in task_best.items():
            auc = best['auc']
            acc = best['result']['accuracy']
            sens = best['result']['sensitivity']
            spec = best['result']['specificity']
            
            status = "DEPLOYMENT READY" if auc >= 0.85 else "PROMISING" if auc >= 0.75 else "NEEDS OPTIMIZATION"
            
            doc_content.append(f"Task: {task_name}")
            doc_content.append(f"  Best Combination: {best['cnn']} + {best['algorithm']}")
            doc_content.append(f"  Performance: AUC = {auc:.3f}, Accuracy = {acc:.3f}")
            doc_content.append(f"  Clinical Metrics: Sensitivity = {sens:.3f}, Specificity = {spec:.3f}")
            doc_content.append(f"  Status: {status}")
            doc_content.append("")
        
        # Algorithm Performance Ranking
        doc_content.append("ALGORITHM PERFORMANCE RANKING")
        doc_content.append("-" * 40)
        doc_content.append("")
        
        algorithm_stats = {}
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    if alg_name not in algorithm_stats:
                        algorithm_stats[alg_name] = []
                    algorithm_stats[alg_name].append(result['auc'])
        
        if algorithm_stats:
            sorted_algorithms = sorted(algorithm_stats.items(), key=lambda x: np.mean(x[1]), reverse=True)
            
            for i, (alg_name, aucs) in enumerate(sorted_algorithms, 1):
                mean_auc = np.mean(aucs)
                std_auc = np.std(aucs)
                max_auc = np.max(aucs)
                n_tests = len(aucs)
                
                doc_content.append(f"{i}. {alg_name}")
                doc_content.append(f"   Mean AUC: {mean_auc:.3f} (±{std_auc:.3f})")
                doc_content.append(f"   Best AUC: {max_auc:.3f}")
                doc_content.append(f"   Tests: {n_tests}")
                doc_content.append("")
        
        # CNN Architecture Ranking
        doc_content.append("CNN ARCHITECTURE RANKING")
        doc_content.append("-" * 40)
        doc_content.append("")
        
        cnn_stats = {}
        for cnn_name, cnn_results in self.results.items():
            aucs = []
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    aucs.append(result['auc'])
            if aucs:
                cnn_stats[cnn_name] = aucs
        
        if cnn_stats:
            sorted_cnns = sorted(cnn_stats.items(), key=lambda x: np.mean(x[1]), reverse=True)
            
            for i, (cnn_name, aucs) in enumerate(sorted_cnns, 1):
                mean_auc = np.mean(aucs)
                std_auc = np.std(aucs)
                max_auc = np.max(aucs)
                n_tests = len(aucs)
                
                doc_content.append(f"{i}. {cnn_name}")
                doc_content.append(f"   Mean AUC: {mean_auc:.3f} (±{std_auc:.3f})")
                doc_content.append(f"   Best AUC: {max_auc:.3f}")
                doc_content.append(f"   Tests: {n_tests}")
                doc_content.append("")
        
        # Clinical Recommendations
        doc_content.append("CLINICAL IMPLEMENTATION RECOMMENDATIONS")
        doc_content.append("-" * 40)
        doc_content.append("")
        
        # Find deployment-ready combinations
        deployment_ready = []
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                for alg_name, result in task_data['results'].items():
                    if result['auc'] >= 0.80:  # Clinical deployment threshold
                        deployment_ready.append({
                            'task': task_name,
                            'cnn': cnn_name,
                            'algorithm': alg_name,
                            'auc': result['auc'],
                            'accuracy': result['accuracy']
                        })
        
        deployment_ready.sort(key=lambda x: x['auc'], reverse=True)
        
        if deployment_ready:
            doc_content.append(f"DEPLOYMENT-READY COMBINATIONS (AUC >= 0.80): {len(deployment_ready)}")
            doc_content.append("")
            
            for i, combo in enumerate(deployment_ready[:10], 1):  # Top 10
                doc_content.append(f"{i}. {combo['task']}")
                doc_content.append(f"   Model: {combo['cnn']} + {combo['algorithm']}")
                doc_content.append(f"   Performance: {combo['auc']:.1%} AUC, {combo['accuracy']:.1%} Accuracy")
                doc_content.append("")
                
            doc_content.append("PRIORITY IMPLEMENTATION:")
            top_combo = deployment_ready[0]
            doc_content.append(f"Task: {top_combo['task']}")
            doc_content.append(f"Architecture: {top_combo['cnn']} + {top_combo['algorithm']}")
            doc_content.append(f"Expected Clinical Performance: {top_combo['auc']:.1%} discrimination accuracy")
            doc_content.append("")
        else:
            doc_content.append("No combinations reached clinical deployment threshold (AUC >= 0.80)")
            doc_content.append("Focus on methodology optimization for best performing approaches")
            doc_content.append("")
        
        # Publication Strategy
        doc_content.append("PUBLICATION STRATEGY")
        doc_content.append("-" * 40)
        doc_content.append("")
        
        # Count publication-ready results
        tier1_results = []  # AUC >= 0.85
        tier2_results = []  # AUC >= 0.75
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                for alg_name, result in task_data['results'].items():
                    if result['auc'] >= 0.85:
                        tier1_results.append((task_name, cnn_name, alg_name, result['auc']))
                    elif result['auc'] >= 0.75:
                        tier2_results.append((task_name, cnn_name, alg_name, result['auc']))
        
        doc_content.append("PUBLICATION READINESS ASSESSMENT:")
        doc_content.append(f"Tier 1 Results (AUC >= 0.85): {len(tier1_results)} - Suitable for top-tier journals")
        doc_content.append(f"Tier 2 Results (AUC >= 0.75): {len(tier2_results)} - Suitable for clinical journals")
        doc_content.append("")
        
        if tier1_results:
            doc_content.append("TOP-TIER JOURNAL STRATEGY:")
            doc_content.append("Target Journals: Nature Medicine, Lancet Digital Health, Nature Biomedical Engineering")
            best_result = max(tier1_results, key=lambda x: x[3])
            doc_content.append(f"Lead Finding: {best_result[0]} ({best_result[1]} + {best_result[2]}, AUC = {best_result[3]:.3f})")
            doc_content.append("Narrative: 'Deep Learning Achieves Clinical-Grade Performance in Neurosurgical Prediction'")
            doc_content.append("")
            
        if tier2_results:
            doc_content.append("CLINICAL JOURNAL STRATEGY:")
            doc_content.append("Target Journals: Neuro-Oncology, Journal of Neurosurgery, Academic Radiology")
            doc_content.append("Focus: Clinical validation studies and comparative effectiveness research")
            doc_content.append("")
        
        doc_content.append("MANUSCRIPT DEVELOPMENT PRIORITIES:")
        doc_content.append("1. Primary Research Paper: Best performing clinical task for high-impact publication")
        doc_content.append("2. Methodology Paper: Comprehensive multi-architecture comparison study")
        doc_content.append("3. Clinical Implementation Paper: Validation study and cost-effectiveness analysis")
        doc_content.append("4. Technical Paper: Algorithm optimization and feature engineering methods")
        doc_content.append("")
        
        # Validation Summary
        if self.validation_results:
            doc_content.append("DATA VALIDATION SUMMARY")
            doc_content.append("-" * 40)
            doc_content.append("")
            
            validation_header = f"{'CNN Architecture':<20} {'Overall Status':<15} {'Data Quality':<12} {'Class Balance':<12} {'Sample Size':<12}"
            doc_content.append(validation_header)
            doc_content.append("-" * len(validation_header))
            
            for cnn_name, validation in self.validation_results.items():
                if 'error' in validation:
                    doc_content.append(f"{cnn_name:<20} {'ERROR':<15} {'N/A':<12} {'N/A':<12} {'N/A':<12}")
                else:
                    overall = validation.get('overall', {}).get('status', 'FAIL')
                    data_quality = validation.get('data_integrity', {}).get('status', 'FAIL')
                    class_balance = validation.get('class_balance', {}).get('status', 'FAIL')
                    sample_size = validation.get('sample_size', {}).get('status', 'FAIL')
                    
                    doc_content.append(f"{cnn_name:<20} {overall:<15} {data_quality:<12} {class_balance:<12} {sample_size:<12}")
            
            doc_content.append("")
        
        # Technical Specifications
        doc_content.append("TECHNICAL SPECIFICATIONS")
        doc_content.append("-" * 40)
        doc_content.append("")
        doc_content.append("Machine Learning Algorithms Tested:")
        
        algorithms = self.get_ml_algorithms()
        for i, (alg_name, alg_config) in enumerate(algorithms.items(), 1):
            doc_content.append(f"{i}. {alg_name}: {alg_config['description']}")
            doc_content.append(f"   Preprocessing: {'Robust Scaling Applied' if alg_config['needs_scaling'] else 'No Scaling Required'}")
        
        doc_content.append("")
        doc_content.append("CNN Architectures Evaluated:")
        for i, cnn_name in enumerate(self.datasets.keys(), 1):
            doc_content.append(f"{i}. {cnn_name}")
        
        doc_content.append("")
        doc_content.append("Clinical Tasks Assessed:")
        tasks = set()
        for cnn_results in self.results.values():
            for task_data in cnn_results.values():
                tasks.add(task_data['task_name'])
        
        for i, task in enumerate(sorted(tasks), 1):
            doc_content.append(f"{i}. {task}")
        
        doc_content.append("")
        doc_content.append("=" * 80)
        doc_content.append("ANALYSIS COMPLETE")
        doc_content.append(f"Generated on: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
        doc_content.append("=" * 80)
        
        # Write to file
        filename = f"neurosurgical_ai_analysis_report_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.txt"
        
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                for line in doc_content:
                    f.write(line + '\n')
            
            # Calculate file size properly
            doc_text = '\n'.join(doc_content)
            file_size = len(doc_text)
            
            print(f"\nPublication document generated successfully!")
            print(f"Filename: {filename}")
            print(f"Lines written: {len(doc_content)}")
            print(f"File size: {file_size} characters")
            
            return filename
            
        except Exception as e:
            print(f"Error writing document: {e}")
            return None

    def run_comprehensive_analysis(self):
        """Run the complete analysis over every configured CNN feature file.

        For each ``cnn_name -> file_path`` in ``self.datasets``:

        1. Skip the dataset if ``file_path`` does not exist.
        2. Run ``run_validation_checks`` and store its outcome in
           ``self.validation_results[cnn_name]``; skip the dataset when the
           outcome contains ``'error'`` or its overall status is ``'FAIL'``.
        3. Load the CSV, build targets with ``create_all_targets`` and the
           candidate feature list with ``select_features``.
        4. For every target, drop leakage-prone features via
           ``_get_safe_features``, preprocess with ``preprocess_data`` and run
           all algorithms through ``run_prediction_task``.

        Tasks that produced results are stored in ``self.results[cnn_name]``.
        An exception while processing one dataset is printed with its
        traceback and the loop continues with the next dataset.

        Afterwards the console report and the publication document are
        generated. A failure in either step is printed but does not prevent
        the computed results from being returned.

        Returns:
            dict: ``self.results`` (CNN name -> task key -> task data).
        """
        import os
        import traceback

        print("COMPREHENSIVE NEUROSURGICAL AI ANALYSIS")
        print("="*70)
        # Report the actual number of configured datasets, not a fixed count.
        print(f"Testing {len(self.datasets)} CNNs × Multiple ML Algorithms × 6 Clinical Tasks")
        print("Target: Clinical-grade performance (AUC >= 0.80)")
        print("="*70)

        # Initialize ML algorithms
        algorithms = self.get_ml_algorithms()

        print(f"\nAVAILABLE ALGORITHMS ({len(algorithms)}):")
        for alg_name, alg_config in algorithms.items():
            print(f"   {alg_name}: {alg_config['description']}")

        # Test each CNN dataset
        for cnn_name, file_path in self.datasets.items():
            print(f"\n{'='*70}")
            print(f"ANALYZING {cnn_name} DATASET")
            print(f"{'='*70}")

            try:
                # Check if file exists before processing
                if not os.path.exists(file_path):
                    print(f"ERROR {cnn_name}: File not found - {file_path}")
                    continue

                # Run validation checks first
                validation = self.run_validation_checks(cnn_name, file_path)
                self.validation_results[cnn_name] = validation

                if 'error' in validation:
                    print(f"ERROR {cnn_name}: Validation failed - {validation['error']}")
                    continue

                overall_status = validation.get('overall', {}).get('status', 'FAIL')
                if overall_status == 'FAIL':
                    print(f"ERROR {cnn_name}: Failed validation checks")
                    continue

                # Load and process data
                print(f"Loading data from: {file_path}")
                df = pd.read_csv(file_path)
                print(f"Dataset shape: {df.shape}")

                targets_data = self.create_all_targets(df)

                if not targets_data:
                    print(f"ERROR {cnn_name}: No valid targets created")
                    continue

                # Feature selection
                features = self.select_features(df)
                print(f"Available features: {len(features)}")

                cnn_results = {}

                # Test each target category
                for category, target_info in targets_data.items():
                    category_data = target_info['data']

                    for i, target_col in enumerate(target_info['targets']):
                        task_name = target_info['descriptions'][i]

                        print(f"\n{'-'*40}")
                        print(f"TASK: {task_name}")
                        print(f"{'-'*40}")

                        # Exclude target-related features to prevent leakage
                        safe_features = self._get_safe_features(features, target_col)

                        X, y, error = self.preprocess_data(category_data, safe_features, target_col)

                        if X is None:
                            print(f"ERROR {task_name}: {error}")
                            continue

                        # Run all algorithms for this task
                        task_results = self.run_prediction_task(X, y, task_name, cnn_name, algorithms)

                        if task_results:
                            task_key = f"{category}_{target_col}"
                            cnn_results[task_key] = {
                                'task_name': task_name,
                                'results': task_results,
                                'n_samples': len(X),
                                'n_features': X.shape[1]
                            }

                if cnn_results:
                    self.results[cnn_name] = cnn_results
                    print(f"\nSUCCESS {cnn_name}: {len(cnn_results)} tasks completed successfully")
                else:
                    print(f"ERROR {cnn_name}: No tasks completed successfully")

            except Exception as e:
                # Per-dataset boundary: report the failure and move on.
                print(f"ERROR {cnn_name}: Complete failure - {e}")
                traceback.print_exc()

        # Model training above can take a long time; a bug in a reporting
        # step must not discard the results that were already computed.
        try:
            self.generate_comprehensive_report()
        except Exception as e:
            print(f"ERROR generating comprehensive report: {e}")
            traceback.print_exc()

        try:
            self.generate_publication_document()
        except Exception as e:
            print(f"ERROR generating publication document: {e}")
            traceback.print_exc()

        return self.results

    def _get_safe_features(self, features, target_col):
        """Get features safe from data leakage"""
        # Remove features that might leak information about the target
        unsafe_patterns = {
            'idh_binary': ['idh'],
            'mgmt_binary': ['mgmt'],
            'high_grade': [],  # Tumor grade can use all molecular features
            'mortality_6mo': [],
            'mortality_1yr': [],
            'mortality_2yr': []
        }
        
        patterns_to_exclude = unsafe_patterns.get(target_col, [])
        
        safe_features = []
        for feature in features:
            is_safe = True
            for pattern in patterns_to_exclude:
                if pattern.lower() in feature.lower():
                    is_safe = False
                    break
            if is_safe:
                safe_features.append(feature)
        
        return safe_features

    def generate_comprehensive_report(self):
        """Print the full console report for ``self.results``.

        Sections are printed in a fixed order: executive summary, detailed
        results table, best performers, validation summary, clinical
        recommendations and publication strategy.

        Each section is isolated:
          * a section method that does not exist on this object is skipped
            with a notice instead of raising AttributeError;
          * a section that raises is reported (with traceback) without
            preventing the remaining sections from printing.

        Prints a short notice and returns early when there are no results.
        """
        import traceback

        if not self.results:
            print("\nNo results to report")
            return

        print(f"\n{'='*80}")
        print("COMPREHENSIVE ANALYSIS REPORT")
        print(f"{'='*80}")

        sections = (
            '_generate_executive_summary',
            '_generate_detailed_results_table',
            '_generate_best_performers_analysis',
            '_generate_validation_summary',
            '_generate_clinical_recommendations',
            '_generate_publication_strategy',
        )
        for section_name in sections:
            section = getattr(self, section_name, None)
            if section is None:
                # e.g. a section that is not defined in this version of the class
                print(f"\nSkipping report section {section_name}: not implemented")
                continue
            try:
                section()
            except Exception as e:
                print(f"\nERROR in report section {section_name}: {e}")
                traceback.print_exc()

    def _generate_executive_summary(self):
        """Generate executive summary"""
        print("\n🎯 EXECUTIVE SUMMARY")
        print("="*50)
        
        total_tests = 0
        excellent_tests = 0
        good_tests = 0
        
        all_aucs = []
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    total_tests += 1
                    auc = result['auc']
                    all_aucs.append(auc)
                    
                    if auc >= 0.85:
                        excellent_tests += 1
                    elif auc >= 0.75:
                        good_tests += 1
        
        if all_aucs:
            mean_auc = np.mean(all_aucs)
            max_auc = np.max(all_aucs)
            
            print(f" PERFORMANCE OVERVIEW:")
            print(f"   Total algorithm-task combinations: {total_tests}")
            print(f"   Mean AUC across all tests: {mean_auc:.3f}")
            print(f"   Best AUC achieved: {max_auc:.3f}")
            print(f"   Excellent performance (AUC ≥ 0.85): {excellent_tests}/{total_tests} ({excellent_tests/total_tests*100:.1f}%)")
            print(f"   Good+ performance (AUC ≥ 0.75): {good_tests+excellent_tests}/{total_tests} ({(good_tests+excellent_tests)/total_tests*100:.1f}%)")
            
            # Clinical readiness assessment
            if excellent_tests > 0:
                print(f"   🚀 CLINICAL DEPLOYMENT: {excellent_tests} combinations ready for validation")
            if max_auc >= 0.90:
                print(f"   🏆 PUBLICATION READY: Exceptional results achieved")
            elif max_auc >= 0.80:
                print(f"   📝 PUBLICATION READY: Strong results achieved")

    def _generate_detailed_results_table(self):
        """Generate detailed results table"""
        print(f"\n📋 DETAILED RESULTS TABLE")
        print("="*50)
        
        # Header
        print(f"{'CNN':<20} {'Task':<25} {'Algorithm':<15} {'AUC':<8} {'Acc':<8} {'Sens':<8} {'Spec':<8} {'Status':<15}")
        print("-" * 120)
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                for alg_name, result in task_data['results'].items():
                    auc = result['auc']
                    acc = result['accuracy']
                    sens = result['sensitivity']
                    spec = result['specificity']
                    
                    # Status based on AUC
                    if auc >= 0.85:
                        status = "🏆 EXCELLENT"
                    elif auc >= 0.75:
                        status = "✅ STRONG"
                    elif auc >= 0.65:
                        status = "📈 GOOD"
                    else:
                        status = "⚠️ MODERATE"
                    
                    print(f"{cnn_name:<20} {task_name:<25} {alg_name:<15} {auc:<8.3f} {acc:<8.3f} {sens:<8.3f} {spec:<8.3f} {status:<15}")

    def _generate_best_performers_analysis(self):
        """Generate best performers analysis"""
        print(f"\n🏆 BEST PERFORMERS BY TASK")
        print("="*50)
        
        # Find best performer for each task across all CNNs
        task_best = {}
        
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                
                if task_name not in task_best:
                    task_best[task_name] = {'auc': 0, 'cnn': '', 'algorithm': '', 'result': None}
                
                for alg_name, result in task_data['results'].items():
                    if result['auc'] > task_best[task_name]['auc']:
                        task_best[task_name] = {
                            'auc': result['auc'],
                            'cnn': cnn_name,
                            'algorithm': alg_name,
                            'result': result
                        }
        
        for task_name, best in task_best.items():
            auc = best['auc']
            status = "🚀 DEPLOYMENT READY" if auc >= 0.85 else "📈 PROMISING" if auc >= 0.75 else "⚠️ NEEDS WORK"
            print(f"{task_name:<30}: {best['cnn']} + {best['algorithm']} (AUC = {auc:.3f}) {status}")

    def _generate_validation_summary(self):
        """Generate validation summary"""
        print(f"\nVALIDATION SUMMARY")
        print("="*50)
        
        if not self.validation_results:
            print("No validation results available")
            return
        
        print(f"{'CNN':<20} {'Overall':<10} {'Data':<10} {'Balance':<10} {'Features':<10} {'Samples':<10}")
        print("-" * 75)
        
        for cnn_name, validation in self.validation_results.items():
            if 'error' in validation:
                print(f"{cnn_name:<20} {'ERROR':<10} {'N/A':<10} {'N/A':<10} {'N/A':<10} {'N/A':<10}")
                continue
            
            overall = validation.get('overall', {}).get('status', 'FAIL')
            data_integrity = validation.get('data_integrity', {}).get('status', 'FAIL')
            class_balance = validation.get('class_balance', {}).get('status', 'FAIL')
            feature_quality = validation.get('feature_quality', {}).get('status', 'FAIL')
            sample_size = validation.get('sample_size', {}).get('status', 'FAIL')
            
            print(f"{cnn_name:<20} {overall:<10} {data_integrity:<10} {class_balance:<10} {feature_quality:<10} {sample_size:<10}")

    def _generate_clinical_recommendations(self):
        """Generate clinical recommendations"""
        print(f"\nCLINICAL RECOMMENDATIONS")
        print("="*50)
        
        # Algorithm performance ranking
        algorithm_stats = {}
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    if alg_name not in algorithm_stats:
                        algorithm_stats[alg_name] = []
                    algorithm_stats[alg_name].append(result['auc'])
        
        print("ALGORITHM PERFORMANCE RANKING:")
        if algorithm_stats:
            for alg_name, aucs in sorted(algorithm_stats.items(), key=lambda x: np.mean(x[1]), reverse=True):
                mean_auc = np.mean(aucs)
                max_auc = np.max(aucs)
                n_tests = len(aucs)
                print(f"   {alg_name}: {mean_auc:.3f} mean AUC, {max_auc:.3f} max AUC ({n_tests} tests)")
        
        # CNN performance ranking
        cnn_stats = {}
        for cnn_name, cnn_results in self.results.items():
            aucs = []
            for task_key, task_data in cnn_results.items():
                for alg_name, result in task_data['results'].items():
                    aucs.append(result['auc'])
            if aucs:
                cnn_stats[cnn_name] = aucs
        
        print(f"\nCNN ARCHITECTURE RANKING:")
        if cnn_stats:
            for cnn_name, aucs in sorted(cnn_stats.items(), key=lambda x: np.mean(x[1]), reverse=True):
                mean_auc = np.mean(aucs)
                max_auc = np.max(aucs)
                n_tests = len(aucs)
                print(f"   {cnn_name}: {mean_auc:.3f} mean AUC, {max_auc:.3f} max AUC ({n_tests} tests)")
        
        # Implementation recommendations
        print(f"\nIMPLEMENTATION RECOMMENDATIONS:")
        
        best_combinations = []
        for cnn_name, cnn_results in self.results.items():
            for task_key, task_data in cnn_results.items():
                task_name = task_data['task_name']
                for alg_name, result in task_data['results'].items():
                    if result['auc'] >= 0.80:
                        best_combinations.append({
                            'cnn': cnn_name,
                            'task': task_name,
                            'algorithm': alg_name,
                            'auc': result['auc']
                        })
        
        best_combinations.sort(key=lambda x: x['auc'], reverse=True)
        
        if best_combinations:
            print(f"   {len(best_combinations)} CNN-algorithm combinations ready for clinical validation")
            print(f"   Priority implementation: {best_combinations[0]['task']} using {best_combinations[0]['cnn']} + {best_combinations[0]['algorithm']}")
            print(f"   Expected performance: {best_combinations[0]['auc']:.1%} discrimination accuracy")
        else:
            print(f"   No combinations reached clinical deployment threshold (AUC >= 0.80)")
            print(f"   Focus on methodology optimization for best performing approaches")

    #def _generate_publication_strategy(self):
       #"""Generate publication strategy"""
        #print(f"\nPUBLICATION STRATEGY")
        #print("="*50)
        
        # Count publication-ready results
        #excellent_results = []
        #good_results = []
        
        #for cnn_name, cnn_results in self.results.items():
            #for task_key, task_data in cnn_results.items():
                #task_name = task_data['task_name']
                #for alg_name, result in task_data['results'].items():
                    #if result['auc'] >= 0.85:
                        #excellent_results.append((task_name, cnn_name, alg_name, result['auc']))
                    #elif result['auc'] >= 0.75:
                        #good_results.append((task_name, cnn_name, alg_name, result['auc']))

def main():
    """Run the full analysis and print a short run summary.

    Builds a ``NeurosurgicalAIAnalyzer`` with its default data paths and
    calls ``run_comprehensive_analysis``. If any results came back, it
    prints how many CNNs, tasks and algorithm-task combinations were covered.

    Returns:
        NeurosurgicalAIAnalyzer: the analyzer, so its ``results`` and
        ``validation_results`` can be inspected afterwards.
    """
    banner = "="*70
    intro = (
        "COMPREHENSIVE NEUROSURGICAL AI ANALYSIS SYSTEM",
        banner,
        "GOAL: Comprehensive evaluation of CNN architectures and ML algorithms",
        "SCOPE: 5 CNNs × Multiple Algorithms × 6 Clinical Tasks",
        "OUTPUT: Clinical-ready recommendations for your team and PI",
        banner,
    )
    for line in intro:
        print(line)

    analyzer = NeurosurgicalAIAnalyzer()
    results = analyzer.run_comprehensive_analysis()

    print(f"\n{banner}")
    print("COMPREHENSIVE ANALYSIS COMPLETE!")
    print(banner)

    if not results:
        print("No results generated. Check data file paths and formats.")
        return analyzer

    # results: CNN name -> task key -> {'results': {algorithm: metrics}, ...}
    n_cnns = len(results)
    total_tasks = sum(map(len, results.values()))
    total_tests = sum(
        len(task['results'])
        for per_cnn in results.values()
        for task in per_cnn.values()
    )

    summary = (
        "ANALYSIS SUMMARY:",
        f"   • {n_cnns} CNN architectures analyzed",
        f"   • {total_tasks} clinical tasks evaluated",
        f"   • {total_tests} algorithm-task combinations tested",
        "   • Comprehensive validation and recommendations generated",
        "   • Publication-ready document created",
    )
    print("\n".join(summary))

    return analyzer

# Execute the comprehensive analysis when this file runs as a script.
# When this file runs as a notebook cell, __name__ is also "__main__", so
# running the cell starts the full analysis. The returned analyzer is kept
# in `analyzer` for interactive inspection.
if __name__ == "__main__":
    analyzer = main()

*Working code (version 2)*

In [4]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler, RobustScaler
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score, 
                           accuracy_score, roc_curve, precision_recall_curve, auc,
                           balanced_accuracy_score, matthews_corrcoef, 
                           average_precision_score, cohen_kappa_score)
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.dummy import DummyClassifier
from tabpfn import TabPFNClassifier
import warnings
warnings.filterwarnings('ignore')

# Check for optional dependencies
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Install with: pip install xgboost")

try:
    from pytorch_tabnet.tab_model import TabNetClassifier
    import torch
    TABNET_AVAILABLE = True
except ImportError:
    TABNET_AVAILABLE = False
    print("TabNet not available. Install with: pip install pytorch-tabnet torch")

class NeurosurgicalAIAnalyzer:
    """AI analysis system for neurosurgical outcome prediction"""
    
    def __init__(self, data_paths=None):
        # Default paths - update these to match your file locations
        if data_paths is None:
            self.datasets = {
            'ConvNext': '/Users/joi263/Documents/MultimodalTabData/data/convnext_data/convnext_cleaned_patient_features_separate_256d.csv',
            'ViT': '/Users/joi263/Documents/MultimodalTabData/data/vit_base_data/vit_base_cleaned_patient_features_separate_256d.csv',
            'ResNet50_Pretrained': '/Users/joi263/Documents/MultimodalTabData/data/pretrained_resnet50_data/pretrained_resnet50_cleaned_patient_features_separate_256d.csv',
            'ResNet50_ImageNet': '/Users/joi263/Documents/MultimodalTabData/data/imagenet_resnet50_data/imagenet_resnet50_cleaned_patient_features_separate_256d.csv',
            'EfficientNet': '/Users/joi263/Documents/MultimodalTabData/data/efficientnet_data/efficientnet_cleaned_patient_features_separate_256d.csv'
        }
        else:
            self.datasets = data_paths
            
        self.results = {}
        self.validation_results = {}
        
        # Verify file paths
        self._verify_data_files()
        
    def _verify_data_files(self):
        """Check which data files are available"""
        import os
        print("CHECKING DATA FILE PATHS:")
        print("="*50)
        
        existing_files = 0
        for cnn_name, file_path in self.datasets.items():
            exists = os.path.exists(file_path)
            status = "EXISTS" if exists else "NOT FOUND"
            print(f"{cnn_name:<20}: {status}")
            if exists:
                existing_files += 1
            else:
                print(f"  Expected: {file_path}")
        
        print("="*50)
        print(f"Found {existing_files}/{len(self.datasets)} data files")
        
        if existing_files == 0:
            print("ERROR: No data files found!")
            print("Please update the file paths to match your actual file locations.")
        elif existing_files < len(self.datasets):
            print(f"WARNING: Only {existing_files} out of {len(self.datasets)} files found.")
        else:
            print("SUCCESS: All data files found!")
        print()
        
    def get_ml_algorithms(self):
        """Build the registry of algorithms evaluated on every prediction task.

        Returns:
            dict mapping algorithm name -> ``{'model': unfitted estimator,
            'needs_scaling': bool, 'description': str}``. XGBoost and TabNet
            are included only when their packages imported successfully.
            Fresh estimator instances are created on every call; a caller that
            reuses the returned dict refits the same instances for each task.
        """
        algorithms = {}

        # TabPFN - Transformer-based Few-Shot Learning
        algorithms['TabPFN'] = {
            'model': TabPFNClassifier(device='cpu'),
            'needs_scaling': False,
            'description': 'Transformer-based Few-Shot Learning'
        }

        # XGBoost with enhanced imbalance handling
        if XGBOOST_AVAILABLE:
            algorithms['XGBoost'] = {
                'model': xgb.XGBClassifier(
                    n_estimators=300,
                    max_depth=4,
                    learning_rate=0.05,
                    subsample=0.8,
                    colsample_bytree=0.8,
                    min_child_weight=3,
                    reg_alpha=1,
                    reg_lambda=1,
                    random_state=42,
                    eval_metric='logloss',
                    # `use_label_encoder` is intentionally not passed: it was
                    # deprecated in XGBoost 1.6 and removed in 2.0, where it
                    # only produces an "unused parameter" warning.
                    scale_pos_weight=None,  # set per task in train_and_evaluate_algorithm
                    objective='binary:logistic'
                ),
                'needs_scaling': False,
                'description': 'Gradient Boosting with Auto-Balancing'
            }

        # TabNet (attention-based network)
        if TABNET_AVAILABLE:
            algorithms['TabNet'] = {
                'model': TabNetClassifier(
                    n_d=64, n_a=64,
                    n_steps=5,
                    gamma=1.5,
                    lambda_sparse=1e-4,
                    optimizer_fn=torch.optim.Adam,
                    optimizer_params=dict(lr=0.01, weight_decay=1e-5),
                    mask_type="entmax",
                    scheduler_params={"step_size": 20, "gamma": 0.8},
                    scheduler_fn=torch.optim.lr_scheduler.StepLR,
                    verbose=0,
                    seed=42
                    # No class-weight option is set here; class balancing has
                    # to be requested when fit() is called (pytorch-tabnet's
                    # `weights` fit argument).
                ),
                'needs_scaling': True,
                'description': 'Attention-based Neural Network with Class Balancing'
            }

        # Random Forest; per-tree bootstrap samples are re-weighted by class
        algorithms['RandomForest'] = {
            'model': RandomForestClassifier(
                n_estimators=500,
                max_depth=8,
                min_samples_split=10,
                min_samples_leaf=5,
                max_features='sqrt',
                bootstrap=True,
                oob_score=True,
                random_state=42,
                class_weight='balanced_subsample',
                n_jobs=-1
            ),
            'needs_scaling': False,
            'description': 'Ensemble Trees with Balanced Subsampling'
        }

        # Elastic-net logistic regression with inverse-frequency class weights
        algorithms['LogisticRegression'] = {
            'model': LogisticRegression(
                penalty='elasticnet',
                l1_ratio=0.5,
                C=0.1,
                solver='saga',  # the only sklearn solver supporting elasticnet
                max_iter=2000,
                random_state=42,
                class_weight='balanced',
                n_jobs=-1
            ),
            'needs_scaling': True,
            'description': 'Regularized Linear Model with Auto-Balancing'
        }

        # RBF SVM; probability=True is needed for predict_proba / AUC
        algorithms['SVM'] = {
            'model': SVC(
                kernel='rbf',
                C=1.0,
                gamma='scale',
                probability=True,
                random_state=42,
                class_weight='balanced'
            ),
            'needs_scaling': True,
            'description': 'Support Vector Machine with Auto-Balancing'
        }

        return algorithms

    def calculate_class_imbalance_metrics(self, y):
        """Summarize how skewed a binary label vector is.

        Args:
            y: Sized iterable of hashable class labels.

        Returns:
            ``None`` unless exactly two distinct labels occur. Otherwise a dict
            with the minority/majority counts, total sample count,
            majority:minority ratio, minority share in percent, a severity
            label, and the per-class counts (in first-seen label order).
        """
        from collections import Counter

        counts = Counter(y)
        n_total = len(y)

        if len(counts) != 2:
            return None  # Skip non-binary tasks

        small = min(counts.values())
        large = max(counts.values())
        ratio = large / small

        # Inclusive upper bound on the majority:minority ratio for each label.
        severity_bands = (
            (2, "BALANCED"),
            (5, "MILD IMBALANCE"),
            (10, "MODERATE IMBALANCE"),
            (20, "HIGH IMBALANCE"),
        )
        severity = next(
            (label for limit, label in severity_bands if ratio <= limit),
            "SEVERE IMBALANCE",
        )

        return {
            'minority_count': small,
            'majority_count': large,
            'total_samples': n_total,
            'imbalance_ratio': ratio,
            'minority_percentage': small / n_total * 100,
            'severity': severity,
            'class_distribution': dict(counts),
        }

    def calculate_balanced_metrics(self, y_true, y_pred, y_pred_proba):
        """Compute evaluation metrics that stay informative on imbalanced data.

        Args:
            y_true: Ground-truth labels; for 0/1 targets the positive class is 1.
            y_pred: Hard class predictions aligned with ``y_true``.
            y_pred_proba: Positive-class scores aligned with ``y_true``.

        Returns:
            dict with ``accuracy``, ``balanced_accuracy``, ``auc_roc`` (also
            exposed as ``auc`` for backward compatibility), ``auc_pr``,
            ``sensitivity``, ``specificity``, ``ppv``, ``npv``, ``youden_j``,
            ``f1_score``, ``mcc``, ``kappa``, ``geometric_mean`` and the raw
            ``confusion_matrix``. ROC-AUC / PR-AUC fall back to 0.5 when
            scikit-learn rejects the input (e.g. ``y_true`` has one class).
        """
        # Basic metrics
        accuracy = accuracy_score(y_true, y_pred)
        balanced_accuracy = balanced_accuracy_score(y_true, y_pred)

        # AUC metrics. Catch only the ValueError sklearn raises for undefined
        # AUCs; a bare `except:` would also hide KeyboardInterrupt and bugs.
        try:
            roc_auc = roc_auc_score(y_true, y_pred_proba)
        except ValueError:
            roc_auc = 0.5

        try:
            pr_auc = average_precision_score(y_true, y_pred_proba)
        except ValueError:
            pr_auc = 0.5

        # For 0/1 targets pin the label set so the matrix is always 2x2.
        # Otherwise a split where y_true and y_pred share a single class gives
        # a 1x1 matrix and every derived metric below is reported as 0.
        observed = np.union1d(np.asarray(y_true), np.asarray(y_pred))
        labels = [0, 1] if observed.size and np.isin(observed, [0, 1]).all() else None
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()

            # Rates guarded against empty denominators
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
            npv = tn / (tn + fn) if (tn + fn) > 0 else 0

            # Imbalance-aware summary metrics
            youden_j = sensitivity + specificity - 1
            f1_score = 2 * (ppv * sensitivity) / (ppv + sensitivity) if (ppv + sensitivity) > 0 else 0
            mcc = matthews_corrcoef(y_true, y_pred)
            kappa = cohen_kappa_score(y_true, y_pred)
            geometric_mean = np.sqrt(sensitivity * specificity)

        else:
            # Non-binary label sets: threshold metrics are not defined here.
            sensitivity = specificity = ppv = npv = 0
            youden_j = f1_score = mcc = kappa = geometric_mean = 0

        return {
            'accuracy': accuracy,
            'balanced_accuracy': balanced_accuracy,
            'auc_roc': roc_auc,
            'auc_pr': pr_auc,
            'auc': roc_auc,  # For backward compatibility
            'sensitivity': sensitivity,
            'specificity': specificity,
            'ppv': ppv,
            'npv': npv,
            'youden_j': youden_j,
            'f1_score': f1_score,
            'mcc': mcc,
            'kappa': kappa,
            'geometric_mean': geometric_mean,
            'confusion_matrix': cm
        }

    def create_prediction_targets(self, df):
        """Create prediction targets: mortality, tumor classification, IDH, MGMT"""
        print("="*60)
        print("CREATING PREDICTION TARGETS")
        print("="*60)
        
        targets_data = {}
        
        # MORTALITY TARGETS
        print("MORTALITY TARGETS:")
        survival_data = df[df['survival'].notna() & df['patient_status'].notna()].copy()
        
        if len(survival_data) > 0:
            survival_data['mortality_6mo'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 6)).astype(int)
            survival_data['mortality_1yr'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 12)).astype(int)
            survival_data['mortality_2yr'] = ((survival_data['patient_status'] == 2) & 
                                              (survival_data['survival'] <= 24)).astype(int)
            
            targets_data['mortality'] = {
                'data': survival_data,
                'targets': ['mortality_6mo', 'mortality_1yr', 'mortality_2yr'],
                'descriptions': ['6-Month Mortality', '1-Year Mortality', '2-Year Mortality']
            }
            
            print(f"   Patients: {len(survival_data)}")
            for target, desc in zip(['mortality_6mo', 'mortality_1yr', 'mortality_2yr'], 
                                   ['6-month', '1-year', '2-year']):
                count = survival_data[target].sum()
                pct = survival_data[target].mean() * 100
                print(f"   {desc}: {count}/{len(survival_data)} ({pct:.1f}%)")
        
        # TUMOR CLASSIFICATION TARGETS
        print("\nTUMOR CLASSIFICATION TARGETS:")
        tumor_data = df[df['methylation_class'].notna()].copy()
        
        if len(tumor_data) > 0:
            high_grade_terms = ['glioblastoma', 'anaplastic', 'high grade', 'grade iv', 'grade 4', 'gbm']
            tumor_data['high_grade'] = tumor_data['methylation_class'].str.lower().str.contains(
                '|'.join(high_grade_terms), na=False
            ).astype(int)
            
            targets_data['tumor'] = {
                'data': tumor_data,
                'targets': ['high_grade'],
                'descriptions': ['High-Grade vs Low-Grade']
            }
            
            print(f"   Patients: {len(tumor_data)}")
            count = tumor_data['high_grade'].sum()
            pct = tumor_data['high_grade'].mean() * 100
            print(f"   High-grade: {count}/{len(tumor_data)} ({pct:.1f}%)")
        
        # IDH MUTATION TARGETS
        print("\nIDH MUTATION TARGETS:")
        idh_data = self._create_idh_targets(df)
        
        if idh_data is not None and len(idh_data) > 0:
            targets_data['idh'] = {
                'data': idh_data,
                'targets': ['idh_binary'],
                'descriptions': ['IDH Mutation Status']
            }
            
            print(f"   Patients: {len(idh_data)}")
            count = idh_data['idh_binary'].sum()
            pct = idh_data['idh_binary'].mean() * 100
            print(f"   IDH Mutant: {count}/{len(idh_data)} ({pct:.1f}%)")
        
        # MGMT METHYLATION TARGETS
        print("\nMGMT METHYLATION TARGETS:")
        mgmt_data = self._create_mgmt_targets(df)
        
        if mgmt_data is not None and len(mgmt_data) > 0:
            targets_data['mgmt'] = {
                'data': mgmt_data,
                'targets': ['mgmt_binary'],
                'descriptions': ['MGMT Promoter Methylation']
            }
            
            print(f"   Patients: {len(mgmt_data)}")
            count = mgmt_data['mgmt_binary'].sum()
            pct = mgmt_data['mgmt_binary'].mean() * 100
            print(f"   MGMT Methylated: {count}/{len(mgmt_data)} ({pct:.1f}%)")
        
        return targets_data

    def _create_idh_targets(self, df):
        """Create IDH mutation targets with proper decoding"""
        if 'idh_1_r132h' not in df.columns:
            return None
            
        idh_data = df.copy()
        idh_data['idh_binary'] = np.nan
        
        # Cross-reference with text data if available
        if 'idh1' in df.columns:
            text_idh = df['idh1'].astype(str).str.lower()
            mutant_patterns = ['r132h', 'r132s', 'arg132his', 'arg132ser', 'missense', 'p.arg132']
            is_mutant_text = text_idh.str.contains('|'.join(mutant_patterns), na=False)
            idh_data.loc[is_mutant_text, 'idh_binary'] = 1  # Mutant
        
        # Apply numerical encoding (2 = mutant based on cross-reference analysis)
        remaining_mask = idh_data['idh_binary'].isna() & idh_data['idh_1_r132h'].notna()
        idh_data.loc[remaining_mask & (idh_data['idh_1_r132h'] == 2), 'idh_binary'] = 1  # Mutant
        idh_data.loc[remaining_mask & (idh_data['idh_1_r132h'] == 1), 'idh_binary'] = 0  # Wildtype
        
        # Exclude unknown cases
        idh_data.loc[idh_data['idh_1_r132h'] == 3, 'idh_binary'] = np.nan
        
        return idh_data[idh_data['idh_binary'].notna()].copy()

    def _create_mgmt_targets(self, df):
        """Create MGMT methylation targets with correct encoding"""
        if 'mgmt' not in df.columns:
            return None
            
        mgmt_data = df[df['mgmt'].notna()].copy()
        
        if len(mgmt_data) == 0:
            return None
        
        # Correct encoding: 1 = Positive (methylated), 2 = Negative (unmethylated), 3 = Non-informative
        mgmt_data['mgmt_binary'] = np.nan
        mgmt_data.loc[mgmt_data['mgmt'] == 1, 'mgmt_binary'] = 1  # Methylated
        mgmt_data.loc[mgmt_data['mgmt'] == 2, 'mgmt_binary'] = 0  # Unmethylated
        mgmt_data.loc[mgmt_data['mgmt'] == 3, 'mgmt_binary'] = np.nan  # Exclude non-informative
        
        return mgmt_data[mgmt_data['mgmt_binary'].notna()].copy()

    def select_features(self, df):
        """Return the model input columns that ``df`` actually contains.

        Order: clinical columns, then molecular markers (both in the fixed
        order listed below), then every CNN-derived ``feature_*`` column in
        DataFrame order. Candidates missing from ``df`` are skipped silently.
        """
        clinical = ['age', 'sex', 'race', 'ethnicity', 'gtr']
        # The raw target columns (idh_1_r132h, mgmt) are deliberately absent.
        # NOTE(review): 'mgmt_pyro', 'tumor' and 'hg_glioma' may encode the
        # MGMT / high-grade targets themselves (label leakage) -- confirm
        # against the data dictionary.
        molecular = ['mgmt_pyro', 'atrx', 'p53', 'braf_v600', 'h3k27m', 'gfap', 'tumor', 'hg_glioma']

        present = set(df.columns)
        chosen = [name for name in clinical + molecular if name in present]
        chosen.extend(name for name in df.columns if name.startswith('feature_'))
        return chosen

    def preprocess_data(self, df, features, target_col):
        """Advanced preprocessing for multiple ML algorithms"""
        data = df[features + [target_col]].copy()
        data = data[data[target_col].notna()]
        
        if len(data) < 15:  # Minimum viable sample size
            return None, None, f"Insufficient data: {len(data)} samples"
        
        # Handle categorical features
        categorical_features = data.select_dtypes(include=['object']).columns.tolist()
        if target_col in categorical_features:
            categorical_features.remove(target_col)
        
        for col in categorical_features:
            if col in features:
                le = LabelEncoder()
                data[col] = data[col].astype(str)
                data[col] = le.fit_transform(data[col])
        
        # Handle missing values
        numerical_features = [f for f in features if f in data.select_dtypes(include=[np.number]).columns]
        
        for col in numerical_features:
            if data[col].isnull().sum() > 0:
                if col.startswith('feature_'):
                    data[col] = data[col].fillna(data[col].mean())
                else:
                    data[col] = data[col].fillna(data[col].median())
        
        # Remove features with >50% missing
        missing_pct = data[features].isnull().mean()
        good_features = missing_pct[missing_pct <= 0.5].index.tolist()
        
        if len(good_features) < len(features):
            features = good_features
            data = data[features + [target_col]]
        
        # Feature selection for computational efficiency
        X = data[features].values
        y = data[target_col].values
        
        # Check class balance
        unique_classes, class_counts = np.unique(y, return_counts=True)
        min_class_size = min(class_counts)
        
        if min_class_size < 3:
            return None, None, f"Class too small: minimum class has {min_class_size} samples"
        
        # Feature selection (limit to 100 for computational efficiency)
        if X.shape[1] > 100:
            selector = SelectKBest(score_func=f_classif, k=100)
            X = selector.fit_transform(X, y)
        
        return X, y, None

    def train_and_evaluate_algorithm(self, X_train, X_test, y_train, y_test, algorithm_name, algorithm_config):
        """Fit one algorithm on the training split and score it on the test split.

        Args:
            X_train, X_test: Feature matrices of the two splits.
            y_train, y_test: Binary labels of the two splits.
            algorithm_name: Key from ``get_ml_algorithms``; selects the fit path.
            algorithm_config: ``{'model': estimator, 'needs_scaling': bool, ...}``.
                The estimator is fitted (and for XGBoost re-parameterised) in place.

        Returns:
            Metrics dict from ``calculate_balanced_metrics`` plus ``n_test``,
            ``scaling_used``, ``train_imbalance`` and ``test_imbalance``, or
            ``None`` if anything fails (the error is printed).
        """
        try:
            model = algorithm_config['model']
            needs_scaling = algorithm_config['needs_scaling']

            # Imbalance summary of the training labels (None unless binary).
            imbalance_info = self.calculate_class_imbalance_metrics(y_train)

            # XGBoost is balanced via scale_pos_weight = n_negative / n_positive.
            # This mutates the shared estimator; it is recomputed for each binary task.
            if algorithm_name == 'XGBoost' and XGBOOST_AVAILABLE and imbalance_info:
                neg_count = imbalance_info['class_distribution'][0]
                pos_count = imbalance_info['class_distribution'][1]
                model.set_params(scale_pos_weight=neg_count / pos_count)

            X_train_processed, X_test_processed = self._scale_for_model(
                X_train, X_test, needs_scaling
            )

            if algorithm_name == 'TabNet' and TABNET_AVAILABLE:
                y_pred_proba = self._fit_tabnet_predict_proba(
                    model, X_train_processed, y_train, X_test_processed
                )
                y_pred = (y_pred_proba > 0.5).astype(int)
            else:
                # scikit-learn interface (XGBoost included). The test split is
                # not passed to fit(), so it cannot influence training.
                model.fit(X_train_processed, y_train)
                y_pred = model.predict(X_test_processed)

                if hasattr(model, 'predict_proba'):
                    y_pred_proba = model.predict_proba(X_test_processed)[:, 1]
                else:
                    y_pred_proba = y_pred.astype(float)

            metrics = self.calculate_balanced_metrics(y_test, y_pred, y_pred_proba)

            return {
                **metrics,
                'n_test': len(y_test),
                'scaling_used': needs_scaling,
                'train_imbalance': imbalance_info,
                'test_imbalance': self.calculate_class_imbalance_metrics(y_test),
            }

        except Exception as e:
            # Per-algorithm boundary: report and let the caller skip this algorithm.
            print(f"   ERROR {algorithm_name} failed: {str(e)}")
            return None

    @staticmethod
    def _scale_for_model(X_train, X_test, needs_scaling):
        """Scale both splits with a scaler fitted on the training split only.

        Uses RobustScaler (10th-90th percentile range) and falls back to
        StandardScaler if that yields NaNs. Returns the inputs unchanged when
        ``needs_scaling`` is false.
        """
        if not needs_scaling:
            return X_train, X_test

        scaler = RobustScaler(quantile_range=(10.0, 90.0))
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        if np.any(np.isnan(X_train_scaled)) or np.any(np.isnan(X_test_scaled)):
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

        return X_train_scaled, X_test_scaled

    def _fit_tabnet_predict_proba(self, model, X_train, y_train, X_test):
        """Fit TabNet with class-balanced sampling; return P(class 1) for X_test.

        Early stopping monitors a stratified 20% validation split carved out of
        the training data. The test split is never passed to fit(): using it as
        the eval set would let early stopping choose the epoch on the same data
        the model is scored on. If no usable validation split exists (too few
        minority samples, or a single-class validation fold), TabNet trains for
        the full epoch budget without early stopping.
        """
        from sklearn.model_selection import train_test_split

        X_fit, y_fit = X_train, y_train
        early_stopping = {}
        try:
            X_inner, X_val, y_inner, y_val = train_test_split(
                X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
            )
        except ValueError:
            pass  # class too small to stratify; train without a validation set
        else:
            if len(np.unique(y_val)) > 1:  # AUC needs both classes
                X_fit, y_fit = X_inner, y_inner
                early_stopping = {'eval_set': [(X_val, y_val)], 'eval_metric': ['auc']}

        # weights=1: pytorch-tabnet samples batches with inverse class frequency.
        fit_kwargs = dict(patience=20, max_epochs=100, weights=1, **early_stopping)
        try:
            model.fit(X_fit, y_fit, batch_size=min(256, len(X_fit) // 4), **fit_kwargs)
        except Exception:
            # Retry with TabNet's default batch size (e.g. if the computed one is 0).
            model.fit(X_fit, y_fit, **fit_kwargs)

        return model.predict_proba(X_test)[:, 1]

    def run_prediction_task(self, X, y, task_name, cnn_name, algorithms):
        """Split one task's data 75/25 and evaluate every algorithm on it.

        Args:
            X: Feature matrix from ``preprocess_data``.
            y: Binary label vector aligned with ``X``.
            task_name: Human-readable task label (printouts only).
            cnn_name: Feature-extractor name (printouts only).
            algorithms: Mapping produced by ``get_ml_algorithms``.

        Returns:
            dict mapping algorithm name -> metrics dict for each algorithm that
            trained successfully; failed algorithms are omitted.
        """
        print(f"\n{'='*50}")
        print(f"{task_name} - {cnn_name}")
        print(f"{'='*50}")

        # Analyze class imbalance
        imbalance_info = self.calculate_class_imbalance_metrics(y)
        if imbalance_info:
            print("CLASS IMBALANCE ANALYSIS:")
            print(f"   Total samples: {imbalance_info['total_samples']}")
            print(f"   Class distribution: {imbalance_info['class_distribution']}")
            print(f"   Minority class: {imbalance_info['minority_count']} ({imbalance_info['minority_percentage']:.1f}%)")
            print(f"   Imbalance ratio: {imbalance_info['imbalance_ratio']:.2f}:1")
            print(f"   Severity: {imbalance_info['severity']}")

        # Stratify so both splits keep the class ratio. train_test_split raises
        # ValueError when a class is too small to stratify; only that is caught
        # (a bare `except:` would also swallow KeyboardInterrupt and bugs).
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.25, random_state=42, stratify=y
            )
        except ValueError:
            print("   Note: stratified split not possible; using an unstratified split")
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.25, random_state=42
            )

        print("\nDATA SPLIT:")
        print(f"   Training: {len(X_train)} samples")
        print(f"   Testing: {len(X_test)} samples")
        print(f"   Positive rate: {y_train.mean()*100:.1f}% (train), {y_test.mean()*100:.1f}% (test)")

        # Train and evaluate each algorithm
        task_results = {}
        print("\nALGORITHM EVALUATION:")

        for algo_name, algo_config in algorithms.items():
            print(f"\n   {algo_name}:")
            result = self.train_and_evaluate_algorithm(
                X_train, X_test, y_train, y_test, algo_name, algo_config
            )
            if result:
                task_results[algo_name] = result
                print(f"      AUC-ROC: {result['auc_roc']:.3f}")
                print(f"      Balanced Accuracy: {result['balanced_accuracy']:.3f}")
                print(f"      F1 Score: {result['f1_score']:.3f}")
                print(f"      MCC: {result['mcc']:.3f}")

        return task_results

    def run_analysis(self):
        """Run every prediction task for every configured CNN feature file.

        For each CSV in ``self.datasets``: load it, derive the targets, pick
        the features, and evaluate all available algorithms on each target.
        Non-empty per-CNN results are stored as
        ``self.results[cnn_name][category][target_name]``. A CNN whose
        processing raises is reported and skipped.

        Returns:
            ``self.results``.
        """
        print("NEUROSURGICAL AI OUTCOME PREDICTION ANALYSIS")
        print("="*80)

        algorithms = self.get_ml_algorithms()
        print(f"Available algorithms: {list(algorithms)}")

        for cnn_name, file_path in self.datasets.items():
            try:
                print(f"\nLoading {cnn_name} dataset...")
                df = pd.read_csv(file_path)
                print(f"Loaded {len(df)} patients with {len(df.columns)} features")

                targets_data = self.create_prediction_targets(df)

                features = self.select_features(df)
                print(f"Selected {len(features)} features for analysis")

                cnn_results = {}
                for category, target_info in targets_data.items():
                    category_results = self._run_target_category(
                        target_info, features, cnn_name, algorithms
                    )
                    if category_results:
                        cnn_results[category] = category_results

                if cnn_results:
                    self.results[cnn_name] = cnn_results

            except Exception as e:
                # Per-dataset boundary: one bad CSV must not stop the others.
                print(f"Error processing {cnn_name}: {str(e)}")

        return self.results

    def _run_target_category(self, target_info, features, cnn_name, algorithms):
        """Evaluate every target of one category; return {target_name: task summary}."""
        data = target_info['data']
        category_results = {}

        for target_name, description in zip(target_info['targets'], target_info['descriptions']):
            X, y, error = self.preprocess_data(data, features, target_name)
            if X is None:
                print(f"Skipping {description}: {error}")
                continue

            task_results = self.run_prediction_task(X, y, description, cnn_name, algorithms)
            if not task_results:
                continue

            category_results[target_name] = {
                'task_name': description,
                'description': description,
                'results': task_results,
                'n_samples': len(y),
                'n_features': X.shape[1],
            }

        return category_results

    def print_summary_results(self):
        """Print a summary of all results"""
        if not self.results:
            print("No results to display. Run analysis first.")
            return
        
        print("\n" + "="*80)
        print("ANALYSIS SUMMARY")
        print("="*80)
        
        for cnn_name, cnn_results in self.results.items():
            print(f"\n{cnn_name.upper()} RESULTS:")
            print("-" * 40)
            
            for target_category, category_results in cnn_results.items():
                print(f"\n{target_category.upper()} PREDICTION:")
                
                for target_name, target_info in category_results.items():
                    description = target_info['description']
                    results = target_info['results']
                    n_samples = target_info['n_samples']
                    n_features = target_info['n_features']
                    
                    print(f"\n  {description}:")
                    print(f"    Samples: {n_samples}, Features: {n_features}")
                    
                    # Sort algorithms by AUC-ROC
                    sorted_results = sorted(results.items(), 
                                          key=lambda x: x[1]['auc_roc'], 
                                          reverse=True)
                    
                    print("    Algorithm Performance (sorted by AUC-ROC):")
                    for algo_name, metrics in sorted_results:
                        print(f"      {algo_name:15} - AUC: {metrics['auc_roc']:.3f}, "
                              f"Bal.Acc: {metrics['balanced_accuracy']:.3f}, "
                              f"F1: {metrics['f1_score']:.3f}, "
                              f"MCC: {metrics['mcc']:.3f}")

    def generate_comprehensive_report(self):
        """Generate comprehensive analysis report"""
        if not self.results:
            print("\nNo results to report")
            return
        
        print(f"\n{'='*80}")
        print("COMPREHENSIVE ANALYSIS REPORT")
        print(f"{'='*80}")
        
        # EXECUTIVE SUMMARY
        self._generate_executive_summary()
        
        # DETAILED RESULTS TABLE
        self._generate_detailed_results_table()
        
        # BEST PERFORMERS ANALYSIS
        self._generate_best_performers_analysis()
        
        # CLINICAL RECOMMENDATIONS
        self._generate_clinical_recommendations()

    def _generate_executive_summary(self):
        """Print headline AUC statistics over every CNN/task/algorithm result.

        Reads ``result['auc_roc']`` from every entry of
        ``self.results[cnn][category][target]['results']``. It then prints:

        * the count, mean and maximum AUC;
        * the share of results with AUC >= 0.85 ("excellent");
        * the share with AUC >= 0.75 ("good+");
        * readiness notes based on the excellent count and the best AUC.

        Only the section header is printed when there are no results.
        """
        print("\nEXECUTIVE SUMMARY")
        print("="*50)

        # Flatten the nested results into one list, preserving traversal order.
        aucs = [
            metrics['auc_roc']
            for cnn_results in self.results.values()
            for category_results in cnn_results.values()
            for target_info in category_results.values()
            for metrics in target_info['results'].values()
        ]
        if not aucs:
            return

        n_total = len(aucs)
        n_excellent = sum(1 for value in aucs if value >= 0.85)
        # "Good+" includes the excellent results as well.
        n_good_plus = sum(1 for value in aucs if value >= 0.75)
        best = np.max(aucs)

        print(f"PERFORMANCE OVERVIEW:")
        print(f"   Total algorithm-task combinations: {n_total}")
        print(f"   Mean AUC across all tests: {np.mean(aucs):.3f}")
        print(f"   Best AUC achieved: {best:.3f}")
        print(f"   Excellent performance (AUC >= 0.85): {n_excellent}/{n_total} ({n_excellent/n_total*100:.1f}%)")
        print(f"   Good+ performance (AUC >= 0.75): {n_good_plus}/{n_total} ({n_good_plus/n_total*100:.1f}%)")

        # Clinical readiness assessment
        if n_excellent > 0:
            print(f"   CLINICAL DEPLOYMENT: {n_excellent} combinations ready for validation")
        if best >= 0.90:
            print(f"   OUTSTANDING: Exceptional results achieved")
        elif best >= 0.80:
            print(f"   STRONG: Clinical-grade results achieved")

    def _generate_detailed_results_table(self):
        """Print one fixed-width row per CNN/task/algorithm result.

        The columns are CNN, task name, algorithm, AUC-ROC, accuracy,
        sensitivity, specificity and a status label. The status is bucketed
        by AUC:

        * >= 0.85 gives EXCELLENT;
        * >= 0.75 gives STRONG;
        * >= 0.65 gives GOOD;
        * anything else gives MODERATE.
        """
        print(f"\nDETAILED RESULTS TABLE")
        print("="*50)

        # Header
        print(f"{'CNN':<20} {'Task':<25} {'Algorithm':<15} {'AUC':<8} {'Acc':<8} {'Sens':<8} {'Spec':<8} {'Status':<15}")
        print("-" * 120)

        # (lower bound, label) pairs, checked from the best band downwards.
        status_bands = ((0.85, "EXCELLENT"), (0.75, "STRONG"), (0.65, "GOOD"))

        for cnn, per_category in self.results.items():
            for per_target in per_category.values():
                for info in per_target.values():
                    task = info['task_name']

                    for algorithm, metrics in info['results'].items():
                        auc_roc = metrics['auc_roc']
                        acc = metrics['accuracy']
                        sens = metrics['sensitivity']
                        spec = metrics['specificity']

                        # Status based on AUC: first band whose floor is met.
                        label = next(
                            (name for floor, name in status_bands if auc_roc >= floor),
                            "MODERATE",
                        )

                        print(f"{cnn:<20} {task:<25} {algorithm:<15} {auc_roc:<8.3f} "
                              f"{acc:<8.3f} {sens:<8.3f} "
                              f"{spec:<8.3f} {label:<15}")

    def _generate_best_performers_analysis(self):
        """Print the highest-AUC CNN + algorithm combination for each task.

        Tasks are keyed by ``target_info['task_name']`` and compared across
        all CNNs in ``self.results``. Only finite ``auc_roc`` values take
        part in the comparison; on ties the first combination encountered is
        kept. A task with no usable AUC is reported as "no valid results"
        instead of being printed with blank CNN/algorithm names. This covers
        an empty results dict and the case where every AUC is None/NaN.
        """
        print(f"\nBEST PERFORMERS BY TASK")
        print("="*50)

        # task name -> best record so far, or None while nothing usable is seen.
        # Registering each task up front keeps first-seen task order and lets
        # tasks without usable results still appear in the report.
        task_best = {}

        for cnn_name, cnn_results in self.results.items():
            for category_results in cnn_results.values():
                for target_info in category_results.values():
                    task_name = target_info['task_name']
                    task_best.setdefault(task_name, None)

                    for alg_name, result in target_info['results'].items():
                        auc = result['auc_roc']
                        # Skip undefined AUCs (None/NaN/inf) so they can never be "best".
                        if auc is None or not np.isfinite(auc):
                            continue
                        current = task_best[task_name]
                        if current is None or auc > current['auc']:
                            task_best[task_name] = {
                                'auc': auc,
                                'cnn': cnn_name,
                                'algorithm': alg_name,
                                'result': result
                            }

        for task_name, best in task_best.items():
            if best is None:
                print(f"{task_name:<30}: no valid results")
                continue
            auc = best['auc']
            status = "DEPLOYMENT READY" if auc >= 0.85 else "PROMISING" if auc >= 0.75 else "NEEDS WORK"
            print(f"{task_name:<30}: {best['cnn']} + {best['algorithm']} (AUC = {auc:.3f}) {status}")

    def _generate_clinical_recommendations(self, deployment_threshold=0.80):
        """Print algorithm/CNN rankings and implementation recommendations.

        Walks ``self.results`` once. Its layout is CNN -> target category ->
        target -> info dict; the info dict holds ``'task_name'`` and
        ``'results'``, which maps algorithm name -> metrics dict with
        ``'auc_roc'``. It then prints:

        * algorithms ranked by mean AUC-ROC over every CNN and task;
        * CNN backbones ranked by mean AUC-ROC over every task and algorithm;
        * how many combinations reach ``deployment_threshold``, and the
          single best one (ties keep traversal order).

        Args:
            deployment_threshold: Minimum AUC-ROC for a combination to count
                as ready for clinical validation. Defaults to 0.80, the
                previously hard-coded value.
        """
        print(f"\nCLINICAL RECOMMENDATIONS")
        print("="*50)

        # Single pass over the nested results. Dict insertion order matches
        # the order in which each algorithm / CNN is first encountered.
        algorithm_stats = {}    # algorithm name -> list of AUCs
        cnn_stats = {}          # CNN name -> list of AUCs (only CNNs with results)
        best_combinations = []  # combinations with AUC >= deployment_threshold
        for cnn_name, cnn_results in self.results.items():
            for category_results in cnn_results.values():
                for target_info in category_results.values():
                    task_name = target_info['task_name']
                    for alg_name, result in target_info['results'].items():
                        auc = result['auc_roc']
                        algorithm_stats.setdefault(alg_name, []).append(auc)
                        cnn_stats.setdefault(cnn_name, []).append(auc)
                        if auc >= deployment_threshold:
                            best_combinations.append({
                                'cnn': cnn_name,
                                'task': task_name,
                                'algorithm': alg_name,
                                'auc': auc
                            })

        def print_ranking(stats):
            # Highest mean AUC first; sorted() stays stable with reverse=True,
            # so ties keep their first-seen order.
            for name, aucs in sorted(stats.items(), key=lambda item: np.mean(item[1]), reverse=True):
                print(f"   {name}: {np.mean(aucs):.3f} mean AUC, {np.max(aucs):.3f} max AUC ({len(aucs)} tests)")

        print("ALGORITHM PERFORMANCE RANKING:")
        print_ranking(algorithm_stats)

        print(f"\nCNN ARCHITECTURE RANKING:")
        print_ranking(cnn_stats)

        # Implementation recommendations
        print(f"\nIMPLEMENTATION RECOMMENDATIONS:")

        best_combinations.sort(key=lambda combo: combo['auc'], reverse=True)

        if best_combinations:
            top = best_combinations[0]
            print(f"   {len(best_combinations)} CNN-algorithm combinations ready for clinical validation")
            print(f"   Priority implementation: {top['task']} using {top['cnn']} + {top['algorithm']}")
            # Report AUC as AUC (same .3f format as the rest of the report);
            # it is a ranking metric, not an accuracy percentage.
            print(f"   Expected performance: AUC = {top['auc']:.3f}")
        else:
            print(f"   No combinations reached clinical deployment threshold (AUC >= {deployment_threshold:.2f})")
            print(f"   Focus on methodology optimization for best performing approaches")

    def export_results_to_csv(self, filename='neurosurgical_ai_results_256_separate.csv'):
        """Flatten ``self.results`` into one CSV row per CNN/target/algorithm.

        Each row carries the CNN name, target category and name, description,
        sample/feature counts and every stored metric. When a metrics dict
        has a truthy ``'test_imbalance'`` entry, its ratio, minority
        percentage and severity are added as extra columns.

        Args:
            filename: Destination handed to ``DataFrame.to_csv``. When it is a
                path (``str`` or ``os.PathLike``), missing parent directories
                are created first, so a long analysis run is not lost to a
                missing output folder. File-like buffers are passed through
                unchanged.

        Returns:
            The exported ``pandas.DataFrame``. When there are no results, a
            notice is printed and ``None`` is returned.
        """
        import os
        from pathlib import Path

        if not self.results:
            print("No results to export. Run analysis first.")
            return

        rows = []

        for cnn_name, cnn_results in self.results.items():
            for target_category, category_results in cnn_results.items():
                for target_name, target_info in category_results.items():
                    description = target_info['description']
                    results = target_info['results']
                    n_samples = target_info['n_samples']
                    n_features = target_info['n_features']

                    for algo_name, metrics in results.items():
                        row = {
                            'CNN_Model': cnn_name,
                            'Target_Category': target_category,
                            'Target_Name': target_name,
                            'Description': description,
                            'Algorithm': algo_name,
                            'N_Samples': n_samples,
                            'N_Features': n_features,
                            'AUC_ROC': metrics['auc_roc'],
                            'AUC_PR': metrics['auc_pr'],
                            'Accuracy': metrics['accuracy'],
                            'Balanced_Accuracy': metrics['balanced_accuracy'],
                            'Sensitivity': metrics['sensitivity'],
                            'Specificity': metrics['specificity'],
                            'PPV': metrics['ppv'],
                            'NPV': metrics['npv'],
                            'F1_Score': metrics['f1_score'],
                            'MCC': metrics['mcc'],
                            'Kappa': metrics['kappa'],
                            'Geometric_Mean': metrics['geometric_mean'],
                            'Youden_J': metrics['youden_j']
                        }

                        # Add imbalance information if available
                        imb = metrics.get('test_imbalance')
                        if imb:
                            row.update({
                                'Imbalance_Ratio': imb['imbalance_ratio'],
                                'Minority_Percentage': imb['minority_percentage'],
                                'Imbalance_Severity': imb['severity']
                            })

                        rows.append(row)

        # Create DataFrame and export
        results_df = pd.DataFrame(rows)
        if isinstance(filename, (str, os.PathLike)):
            # to_csv does not create missing folders; do it here.
            Path(filename).parent.mkdir(parents=True, exist_ok=True)
        results_df.to_csv(filename, index=False)
        print(f"Results exported to {filename}")
        return results_df


# Example usage
if __name__ == "__main__":
    import os

    # Root folder holding one sub-folder per backbone's exported feature CSV.
    # Set the MULTIMODAL_DATA_DIR environment variable to run on another
    # machine; the default reproduces the original hard-coded paths exactly.
    data_root = os.environ.get(
        'MULTIMODAL_DATA_DIR',
        '/Users/joi263/Documents/MultimodalTabData/data',
    )

    # Backbone name -> (sub-folder under data_root, feature CSV file name)
    feature_files = {
        'ConvNext': ('convnext_data', 'convnext_cleaned_patient_features_separate_256d.csv'),
        'ViT': ('vit_base_data', 'vit_base_cleaned_patient_features_separate_256d.csv'),
        'ResNet50_Pretrained': ('pretrained_resnet50_data', 'pretrained_resnet50_cleaned_patient_features_separate_256d.csv'),
        'ResNet50_ImageNet': ('imagenet_resnet50_data', 'imagenet_resnet50_cleaned_patient_features_separate_256d.csv'),
        'EfficientNet': ('efficientnet_data', 'efficientnet_cleaned_patient_features_separate_256d.csv'),
    }

    # Initialize analyzer with custom paths
    custom_paths = {
        name: os.path.join(data_root, folder, csv_name)
        for name, (folder, csv_name) in feature_files.items()
    }

    # Initialize analyzer
    analyzer = NeurosurgicalAIAnalyzer(data_paths=custom_paths)

    # Run analysis
    results = analyzer.run_analysis()

    # Generate comprehensive report
    analyzer.generate_comprehensive_report()

    # Print summary
    analyzer.print_summary_results()

    # Export results
    analyzer.export_results_to_csv()

CHECKING DATA FILE PATHS:
ConvNext            : EXISTS
ViT                 : EXISTS
ResNet50_Pretrained : EXISTS
ResNet50_ImageNet   : EXISTS
EfficientNet        : EXISTS
Found 5/5 data files
SUCCESS: All data files found!

NEUROSURGICAL AI OUTCOME PREDICTION ANALYSIS
Available algorithms: ['TabPFN', 'XGBoost', 'TabNet', 'RandomForest', 'LogisticRegression', 'SVM']

Loading ConvNext dataset...
Loaded 532 patients with 356 features
CREATING PREDICTION TARGETS
MORTALITY TARGETS:
   Patients: 86
   6-month: 19/86 (22.1%)
   1-year: 38/86 (44.2%)
   2-year: 70/86 (81.4%)

TUMOR CLASSIFICATION TARGETS:
   Patients: 241
   High-grade: 129/241 (53.5%)

IDH MUTATION TARGETS:
   Patients: 198
   IDH Mutant: 174.0/198 (87.9%)

MGMT METHYLATION TARGETS:
   Patients: 212
   MGMT Methylated: 84.0/212 (39.6%)
Selected 13 features for analysis

6-Month Mortality - ConvNext
CLASS IMBALANCE ANALYSIS:
   Total samples: 86
   Class distribution: {np.int64(1): 19, np.int64(0): 67}
   Minority class: 19 (2

*working code 3*

Added Cox regression and TXT file generation.
The code now also saves the TXT and CSV files into their appropriate folders.