In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import random
import time
import json
import os
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import dataclass
from scipy import stats
from joblib import Parallel, delayed
import multiprocessing

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    accuracy_score, roc_auc_score, precision_score, recall_score,
    f1_score, confusion_matrix, matthews_corrcoef, cohen_kappa_score,
    average_precision_score, make_scorer
)
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

import matplotlib
import matplotlib.pyplot as plt

# Prefer the TkAgg backend for interactive windows (e.g. when run from VSCode).
# If TkAgg doesn't work, try 'Qt5Agg' or remove this block entirely.
# pyplot is imported *before* calling matplotlib.use() on purpose: with pyplot
# already loaded, use() switches backends immediately and raises ImportError
# when TkAgg cannot be set up, so the fallback below really takes effect.
# (Called before pyplot is imported, use() only records the choice and any
# failure would surface later, at the first plot.)
try:
    matplotlib.use('TkAgg')
except (ImportError, ValueError):
    pass  # Fall back to the default backend

import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Interactive mode so figures display without blocking the script (VSCode).
plt.ion()

try:
    from pytorch_tabnet.tab_model import TabNetClassifier
    import torch
    TABNET_AVAILABLE = True
except ImportError:
    TABNET_AVAILABLE = False
    print("[INFO] TabNet not available")

try:
    from xgboost import XGBClassifier
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("[INFO] XGBoost not available")

try:
    from imblearn.over_sampling import SMOTE
    SMOTE_AVAILABLE = True
except ImportError:
    SMOTE_AVAILABLE = False
    print("[INFO] SMOTE not available")

# Worker count for parallel jobs: every core but one, and never fewer than one.
N_JOBS = max(multiprocessing.cpu_count() - 1, 1)

# Start-up banner (one print call; the newline separator reproduces the
# original four-line output exactly).
print(
    "\n" + "=" * 70,
    "GE-based HPO Framework - Modified Version",
    f"Available CPU cores for parallel processing: {N_JOBS}",
    "=" * 70,
    sep="\n",
)


def compute_confidence_interval(data, confidence=0.95):
    """Return a two-sided Student-t confidence interval for the mean of ``data``.

    Args:
        data: Sequence of numeric observations (must support ``len``).
        confidence: Coverage probability of the interval (default 0.95).

    Returns:
        ``(lower, upper)``. With fewer than two observations the spread cannot
        be estimated, so both bounds equal the sample mean (NaN when empty).
    """
    sample_count = len(data)
    if sample_count < 2:
        point = np.mean(data)
        return (point, point)

    centre = np.mean(data)
    critical_t = stats.t.ppf((1 + confidence) / 2, sample_count - 1)
    # Half-width = standard error of the mean (ddof=1) times the t quantile.
    half_width = stats.sem(data) * critical_t
    return (centre - half_width, centre + half_width)


def compute_statistics(data):
    """Summarise a sample of scores.

    Returns a dict with keys ``mean``, ``std`` (NumPy default ddof=0),
    ``min``, ``max``, ``median``, ``ci_lower`` and ``ci_upper`` (interval from
    :func:`compute_confidence_interval` at its default confidence), all as
    floats. An empty sample yields an all-zero dict.
    """
    summary_keys = ('mean', 'std', 'min', 'max', 'median', 'ci_lower', 'ci_upper')
    if not len(data):
        return dict.fromkeys(summary_keys, 0)

    lower, upper = compute_confidence_interval(data)
    values = (
        np.mean(data),
        np.std(data),
        np.min(data),
        np.max(data),
        np.median(data),
        lower,
        upper,
    )
    return {key: float(value) for key, value in zip(summary_keys, values)}


def paired_ttest(scores1, scores2):
    """Paired (related-samples) t-test between two equally long score lists.

    Returns:
        ``(t_statistic, p_value)`` as floats, or ``(nan, nan)`` when the lists
        differ in length or contain fewer than two pairs.
    """
    same_length = len(scores1) == len(scores2)
    if same_length and len(scores1) >= 2:
        t_value, p_val = stats.ttest_rel(scores1, scores2)
        return (float(t_value), float(p_val))
    return (np.nan, np.nan)


def wilcoxon_test(scores1, scores2):
    """Wilcoxon signed-rank test between two equally long score lists.

    Returns:
        ``(statistic, p_value)`` as floats. ``(nan, nan)`` is returned when
        the lists differ in length, hold fewer than two pairs, or scipy
        raises for the given input.
    """
    no_result = (np.nan, np.nan)
    comparable = len(scores1) == len(scores2) and len(scores1) >= 2
    if not comparable:
        return no_result
    try:
        statistic, p_val = stats.wilcoxon(scores1, scores2)
        return (float(statistic), float(p_val))
    except Exception:
        # Best-effort: any scipy failure is reported as "no result".
        return no_result


def compute_sensitivity_specificity(y_true, y_pred, labels=None):
    """Compute binary confusion-matrix counts and the rates derived from them.

    The second label of the confusion matrix is the positive class (1 for 0/1
    data; otherwise the larger of the two sorted labels, as sklearn orders them).

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        labels: Optional ``[negative, positive]`` order passed to sklearn's
            ``confusion_matrix``. When omitted and every observed label is 0
            or 1, ``[0, 1]`` is used so a batch containing a single class still
            yields a 2x2 matrix; otherwise the labels present in the data are
            used (the previous behaviour).

    Returns:
        Dict with float ``sensitivity``, ``specificity``, ``ppv``, ``npv`` and
        int ``tp``, ``tn``, ``fp``, ``fn``. Everything is 0 when the confusion
        matrix is not 2x2 (e.g. more than two classes). A rate whose
        denominator is zero is reported as 0.
    """
    if labels is None:
        observed = set(np.ravel(y_true).tolist()) | set(np.ravel(y_pred).tolist())
        if observed <= {0, 1}:
            # Without explicit labels a single-class batch gives a 1x1 matrix
            # and every metric would collapse to 0 (e.g. specificity 0 on a
            # perfectly predicted all-negative batch).
            labels = [0, 1]
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    else:
        sensitivity = specificity = ppv = npv = 0
        tn = fp = fn = tp = 0

    return {
        'sensitivity': float(sensitivity),
        'specificity': float(specificity),
        'ppv': float(ppv),
        'npv': float(npv),
        'tp': int(tp),
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn)
    }


class HyperparameterPruningAnalyzer:
    """Compare default hyperparameters with GE-optimised ones, per model.

    For each analysed model it records how much of the discrete search space
    (``SEARCH_SPACE``) the optimised configurations actually used, which
    parameters moved away from ``DEFAULT_HYPERPARAMS``, and the fitness gain
    of the best configuration over the defaults.
    """

    DEFAULT_HYPERPARAMS = {
        'TabNet': {
            'n_d': 8, 'n_a': 8, 'n_steps': 3, 'gamma': 1.3, 'lambda_sparse': 1e-3
        },
        'RandomForest': {
            'n_estimators': 100, 'max_depth': None, 'max_features': 'sqrt', 'min_samples_split': 2
        },
        'XGBoost': {
            'n_estimators': 100, 'learning_rate': 0.1, 'max_depth': 3, 
            'subsample': 0.8, 'colsample_bytree': 0.8
        },
        'SVM': {'C': 1.0, 'kernel': 'rbf', 'gamma': 'scale'},
        'LogisticRegression': {'C': 1.0, 'penalty': 'l2', 'max_iter': 1000},
        'GradientBoosting': {
            'n_estimators': 100, 'learning_rate': 0.1, 'max_depth': 3, 'subsample': 0.8
        }
    }
    
    SEARCH_SPACE = {
        'TabNet': {
            'n_d': [8, 16, 32, 64],
            'n_a': [8, 16, 32, 64],
            'n_steps': [3, 5, 7],
            'lambda_sparse': [1e-4, 1e-3, 1e-2]
        },
        'RandomForest': {
            'n_estimators': [100, 200, 300],
            'max_depth': [5, 10, 15, None]
        },
        'XGBoost': {
            'n_estimators': [100, 200, 300],
            'learning_rate': [0.01, 0.05, 0.1],
            'max_depth': [3, 5, 7]
        },
        'SVM': {
            'C': [0.1, 1.0, 10.0],
            'kernel': ['linear', 'rbf']
        },
        'LogisticRegression': {
            'C': [0.1, 1.0, 10.0],
            'penalty': ['l1', 'l2']
        },
        'GradientBoosting': {
            'n_estimators': [100, 200, 300],
            'learning_rate': [0.05, 0.1, 0.2],
            'max_depth': [3, 5]
        }
    }
    
    def __init__(self):
        # model_name -> result dict produced by analyze_model()
        self.pruning_results = {}
    
    def analyze_model(self, model_name, optimized_configs, best_config, 
                      default_fitness=None, optimized_fitness=None):
        """Summarise how optimisation narrowed ``model_name``'s search space.

        Args:
            model_name: Key into ``SEARCH_SPACE``; unknown models return None.
            optimized_configs: Iterable of config dicts (presumably the best
                config of each independent GE run). Values are matched against
                the search-space choices by their ``str()`` form.
            best_config: The single best configuration. May be None (treated
                as empty for the default-vs-optimised comparison).
            default_fitness: Fitness of the default hyperparameters, or None.
            optimized_fitness: Fitness of ``best_config``, or None.

        Returns:
            Result dict (also stored in ``self.pruning_results``) with
            ``pre_prune``, ``post_prune``, ``reduction_metrics``,
            ``param_changes`` and ``fitness_improvement`` sections, or None
            for an unknown model.
        """
        if model_name not in self.SEARCH_SPACE:
            return None
        
        search_space = self.SEARCH_SPACE[model_name]
        default_params = self.DEFAULT_HYPERPARAMS.get(model_name, {})
        # Used for lookups only; the stored result keeps the caller's value.
        best_lookup = best_config or {}
        
        total_combinations = 1
        for values in search_space.values():
            total_combinations *= len(values)
        
        # Count how often each search-space choice was selected, keyed by str().
        param_frequency = {
            param: {str(value): 0 for value in values}
            for param, values in search_space.items()
        }
        for config in optimized_configs:
            for param, value in config.items():
                counts = param_frequency.get(param)
                str_value = str(value)
                if counts is not None and str_value in counts:
                    counts[str_value] += 1
        
        effective_space = 1
        pruned_params = {}
        for param, freq_dict in param_frequency.items():
            selected_values = [v for v, count in freq_dict.items() if count > 0]
            pruned_params[param] = {
                'original_choices': search_space[param],
                'selected_values': selected_values,
                'selection_frequency': freq_dict,
                'reduction_ratio': len(selected_values) / len(search_space[param])
            }
            # A parameter that no config selected still contributes a factor of 1.
            effective_space *= max(1, len(selected_values))
        
        space_reduction = 1 - (effective_space / total_combinations)
        
        param_changes = {}
        for param in search_space:
            default_val = default_params.get(param)
            optimized_val = best_lookup.get(param)
            param_changes[param] = {
                'default': default_val,
                'optimized': optimized_val,
                'changed': str(default_val) != str(optimized_val)
            }
        
        # Compare against None explicitly: a fitness of 0.0 is a real value,
        # not a missing one.
        have_both = default_fitness is not None and optimized_fitness is not None
        absolute_gain = (optimized_fitness - default_fitness) if have_both else None
        relative_gain = (
            (optimized_fitness - default_fitness) / default_fitness * 100
            if have_both and default_fitness > 0 else None
        )
        
        result = {
            'model_name': model_name,
            'pre_prune': {
                'search_space': search_space,
                'total_combinations': total_combinations,
                'default_params': default_params,
                'default_fitness': default_fitness,
                'n_hyperparameters': len(search_space)
            },
            'post_prune': {
                'best_config': best_config,
                'optimized_fitness': optimized_fitness,
                'effective_combinations': effective_space,
                'pruned_params': pruned_params,
                'n_selected_hyperparameters': sum(1 for p in pruned_params.values() 
                                                   if len(p['selected_values']) > 0)
            },
            'reduction_metrics': {
                'space_reduction_ratio': space_reduction,
                'space_reduction_percent': space_reduction * 100,
                'original_combinations': total_combinations,
                'effective_combinations': effective_space,
                'params_changed': sum(1 for p in param_changes.values() if p['changed']),
                'total_params': len(param_changes)
            },
            'param_changes': param_changes,
            'fitness_improvement': {
                'default': default_fitness,
                'optimized': optimized_fitness,
                'absolute_gain': absolute_gain,
                'relative_gain_percent': relative_gain
            }
        }
        
        self.pruning_results[model_name] = result
        return result
    
    @staticmethod
    def _format_or_na(value, spec):
        """Format ``value`` with ``spec``; 'N/A' only when it is None (0.0 is shown)."""
        return format(value, spec) if value is not None else 'N/A'
    
    def get_summary_table(self):
        """Return one DataFrame row per analysed model with display-formatted columns."""
        rows = []
        for model_name, result in self.pruning_results.items():
            fitness = result['fitness_improvement']
            rows.append({
                'Model': model_name,
                'Pre-Prune HP': result['pre_prune']['n_hyperparameters'],
                'Post-Prune HP': result['post_prune']['n_selected_hyperparameters'],
                'Pre-Prune Combos': result['pre_prune']['total_combinations'],
                'Post-Prune Combos': result['post_prune']['effective_combinations'],
                'Reduction (%)': f"{result['reduction_metrics']['space_reduction_percent']:.1f}",
                'Default AUC': self._format_or_na(fitness['default'], '.4f'),
                'Optimized AUC': self._format_or_na(fitness['optimized'], '.4f'),
                'AUC Gain (%)': self._format_or_na(fitness['relative_gain_percent'], '.2f')
            })
        return pd.DataFrame(rows)
    
    def get_all_results(self):
        """Return the model_name -> result mapping collected so far."""
        return self.pruning_results


class MetricsCalculator:
    """Stateless helpers that compute and compare classification metrics."""
    
    @staticmethod
    def _roc_auc_defined(y_true):
        """Return True when ``y_true`` holds at least two distinct labels.

        ROC-AUC is undefined for a single class (sklearn's ``roc_auc_score``
        raises ValueError in that case).
        """
        return np.unique(np.asarray(y_true)).size >= 2
    
    @staticmethod
    def compute_all_metrics(y_true, y_pred, y_proba=None):
        """Compute label-based and (optionally) score-based metrics.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            y_proba: Positive-class scores, or None to skip ``roc_auc`` and
                ``avg_precision`` (both reported as NaN).

        Returns:
            Dict of float/int metrics, including the confusion-matrix entries
            from ``compute_sensitivity_specificity``. ``roc_auc`` is NaN
            when ``y_true`` contains a single class instead of raising.
        """
        metrics = {
            'accuracy': float(accuracy_score(y_true, y_pred)),
            'precision': float(precision_score(y_true, y_pred, zero_division=0)),
            'recall': float(recall_score(y_true, y_pred, zero_division=0)),
            'f1': float(f1_score(y_true, y_pred, zero_division=0)),
            'mcc': float(matthews_corrcoef(y_true, y_pred)),
            'kappa': float(cohen_kappa_score(y_true, y_pred)),
        }
        
        metrics.update(compute_sensitivity_specificity(y_true, y_pred))
        
        if y_proba is None:
            metrics['roc_auc'] = np.nan
            metrics['avg_precision'] = np.nan
        else:
            if MetricsCalculator._roc_auc_defined(y_true):
                metrics['roc_auc'] = float(roc_auc_score(y_true, y_proba))
            else:
                # Record "undefined" rather than aborting every other metric.
                metrics['roc_auc'] = np.nan
            metrics['avg_precision'] = float(average_precision_score(y_true, y_proba))
        
        return metrics
    
    @staticmethod
    def compute_train_val_gap(train_metrics, val_metrics):
        """Return ``{'<key>_gap': train - val}`` for numeric keys present in both.

        Keys whose train value is not an int/float, or where either side is
        NaN, are skipped.
        """
        gap = {}
        for key, train_value in train_metrics.items():
            if key not in val_metrics or not isinstance(train_value, (int, float)):
                continue
            val_value = val_metrics[key]
            if np.isnan(train_value) or np.isnan(val_value):
                continue
            gap[f'{key}_gap'] = float(train_value - val_value)
        return gap


class HyperparameterGrammar:
    """Minimal BNF-style grammar mapping each hyperparameter to its choices.

    Each non-empty, non-``#`` line has the form ``<name> ::= a | b | c``.
    Names are normalised (outer whitespace and angle brackets stripped, ``-``
    replaced by ``_``, spaces removed); choices are kept as raw strings.
    """
    
    def __init__(self, grammar_str=None):
        self.params = {}       # normalised name -> list of choice strings
        self.param_order = []  # names in first-definition order (decode order)
        if grammar_str:
            self._parse_string(grammar_str)
    
    @staticmethod
    def _normalize(name):
        """Canonical key for a parameter name, shared by parsing and lookup."""
        return name.strip().strip('<>').replace('-', '_').replace(' ', '')
    
    def _parse_string(self, grammar_str):
        """Add every ``<name> ::= ...`` production in ``grammar_str``."""
        for raw_line in grammar_str.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith('#') or '::=' not in line:
                continue
            lhs, rhs = line.split('::=', 1)
            param = self._normalize(lhs)
            choices = [c.strip() for c in rhs.strip().split('|') if c.strip()]
            # A redefinition replaces the choices but must not add a second
            # decode slot in param_order.
            if param not in self.params:
                self.param_order.append(param)
            self.params[param] = choices
    
    def get_choices(self, param):
        """Return the choice strings for ``param`` (any spelling accepted by the parser), or []."""
        return self.params.get(self._normalize(param), [])
    
    def total_search_space(self):
        """Number of distinct configurations (product of choice counts; 1 when empty)."""
        space = 1
        for choices in self.params.values():
            space *= len(choices)
        return space


def convert_value(value_str):
    """Convert a grammar choice string into a Python value.

    ``'true'``, ``'false'`` and ``'none'`` (case-insensitive) become True,
    False and None. Text containing ``'.'`` or an ``e``/``E`` is parsed as a
    float, other text as an int; anything unparsable is returned as the
    whitespace-stripped string.
    """
    text = str(value_str).strip()
    keyword = text.lower()
    literals = {'true': True, 'false': False, 'none': None}
    if keyword in literals:
        return literals[keyword]
    parse = float if ('.' in text or 'e' in keyword) else int
    try:
        return parse(text)
    except ValueError:
        return text


class GEMapper:
    """Genotype-to-phenotype mapper for grammatical evolution.

    Each parameter in ``grammar.param_order`` that has choices consumes one
    codon, which selects a choice via ``codon % len(choices)``. When the
    chromosome is exhausted, reading wraps to its start, at most
    ``max_wraps`` times.
    """
    
    def __init__(self, grammar, max_wraps=2):
        self.grammar = grammar
        self.max_wraps = max_wraps
    
    def decode(self, chromosome):
        """Map ``chromosome`` (a sequence of ints) to a configuration dict.

        Returns:
            ``(config, valid)``. ``valid`` is False when the wrap budget is
            exceeded, or when the chromosome is empty but some parameter needs
            a codon; ``config`` then holds only the parameters decoded so far.
        """
        config = {}
        codon_idx = 0
        wraps = 0
        
        for param in self.grammar.param_order:
            choices = self.grammar.get_choices(param)
            if not choices:
                continue  # nothing to choose, so no codon is consumed
            if codon_idx >= len(chromosome):
                # An empty chromosome can never supply a codon; report invalid
                # instead of raising IndexError on chromosome[0].
                if not chromosome:
                    return config, False
                codon_idx = 0
                wraps += 1
                if wraps > self.max_wraps:
                    return config, False
            choice_idx = chromosome[codon_idx] % len(choices)
            config[param] = convert_value(choices[choice_idx])
            codon_idx += 1
        return config, True


@dataclass
class GEResult:
    """Outcome of a single ``GEOptimizer.evolve()`` call."""
    # Decoded config of the best individual seen; None if no individual was ever valid.
    best_config: Dict[str, Any]
    # Fitness of that individual; left at -inf (maximising) / +inf (minimising) if none was valid.
    best_fitness: float
    # Codon list of that individual; None if none was valid.
    best_chromosome: List[int]
    # Per-generation lists under 'best_fitness' and 'avg_fitness' (computed over valid individuals).
    history: Dict[str, List]
    # Number of entries in history['best_fitness'].
    generations: int
    # Total fitness evaluations (population size summed over generations).
    evaluations: int
    # Wall-clock duration measured with time.time().
    runtime_seconds: float


class GEOptimizer:
    """Generational grammatical-evolution optimiser over a HyperparameterGrammar.

    Individuals are integer chromosomes (codons in ``[0, codon_max]``) decoded
    by ``GEMapper``. Each generation evaluates the whole population, copies the
    ``elitism`` best individuals unchanged, and fills the remainder with
    tournament selection, one-point crossover and per-codon random-reset
    mutation. Invalid or failing individuals get the worst possible fitness
    (``-inf`` when maximising, ``+inf`` when minimising).
    """
    
    def __init__(self, grammar, fitness_fn, pop_size=20, generations=8,
                 chromosome_length=None, codon_max=255, crossover_rate=0.8,
                 mutation_rate=0.1, tournament_size=3, elitism=2,
                 maximize=True, seed=None, verbose=True):
        
        self.grammar = grammar
        self.mapper = GEMapper(grammar)
        self.fitness_fn = fitness_fn
        self.pop_size = pop_size
        self.generations = generations
        self.codon_max = codon_max
        self.crossover_rate = crossover_rate
        self.mutation_rate = mutation_rate
        self.tournament_size = tournament_size
        self.elitism = elitism
        self.maximize = maximize
        self.verbose = verbose
        # Default length: three codons per parameter, never fewer than 20.
        self.chromosome_length = chromosome_length or max(len(grammar.param_order) * 3, 20)
        
        if seed is not None:
            # NOTE: seeds the process-global ``random`` and NumPy RNGs.
            random.seed(seed)
            np.random.seed(seed)
        
        # Per-generation statistics of the most recent evolve() call.
        self.history = {'best_fitness': [], 'avg_fitness': []}
    
    def _worst(self):
        """Fitness assigned to invalid or failing individuals."""
        return -float('inf') if self.maximize else float('inf')
    
    def _init_population(self):
        """Return ``pop_size`` random chromosomes of ``chromosome_length`` codons."""
        return [[random.randint(0, self.codon_max) for _ in range(self.chromosome_length)]
                for _ in range(self.pop_size)]
    
    def _evaluate(self, chrom):
        """Decode and score ``chrom``; return ``(fitness, config, valid)``.

        Decoding failures and exceptions raised by ``fitness_fn`` both yield
        the worst fitness with ``valid=False``.
        """
        config, valid = self.mapper.decode(chrom)
        if not valid:
            return self._worst(), config, False
        try:
            fitness = self.fitness_fn(config)
            return fitness, config, True
        except Exception:
            return self._worst(), config, False
    
    def _tournament(self, pop, scores):
        """Return a copy of the best of ``tournament_size`` distinct random individuals.

        The tournament is capped at the population size so a small population
        does not make ``random.sample`` raise.
        """
        k = min(self.tournament_size, len(pop))
        indices = random.sample(range(len(pop)), k)
        pick = max if self.maximize else min
        winner = pick(indices, key=lambda i: scores[i])
        return list(pop[winner])
    
    def _crossover(self, p1, p2):
        """One-point crossover applied with probability ``crossover_rate``.

        Chromosomes shorter than two codons have no interior cut point and are
        returned as copies.
        """
        if random.random() > self.crossover_rate or len(p1) < 2:
            return p1[:], p2[:]
        pt = random.randint(1, len(p1) - 1)
        return p1[:pt] + p2[pt:], p2[:pt] + p1[pt:]
    
    def _mutate(self, chrom):
        """Reset each codon to a random value with probability ``mutation_rate`` (in place)."""
        for i in range(len(chrom)):
            if random.random() < self.mutation_rate:
                chrom[i] = random.randint(0, self.codon_max)
        return chrom
    
    def evolve(self):
        """Run ``self.generations`` generations and return a GEResult."""
        start_time = time.time()
        # Fresh history per call: repeated evolve() calls must not accumulate,
        # since GEResult.generations is derived from its length.
        self.history = {'best_fitness': [], 'avg_fitness': []}
        population = self._init_population()
        best_ever = (None, self._worst(), None)
        total_evals = 0
        n_elite = max(0, self.elitism)
        
        for gen in range(self.generations):
            scores = []
            configs = []
            for chrom in population:
                fit, cfg, _ = self._evaluate(chrom)
                scores.append(fit)
                configs.append(cfg)
            total_evals += len(population)
            
            valid_scores = [s for s in scores if s != -float('inf') and s != float('inf')]
            if valid_scores:
                best_idx = int(np.argmax(scores) if self.maximize else np.argmin(scores))
                if self.maximize:
                    improved = scores[best_idx] > best_ever[1]
                else:
                    improved = scores[best_idx] < best_ever[1]
                if improved:
                    best_ever = (list(population[best_idx]), scores[best_idx], configs[best_idx])
                
                best_fit = max(valid_scores) if self.maximize else min(valid_scores)
                avg_fit = np.mean(valid_scores)
            else:
                best_fit = avg_fit = 0
            
            self.history['best_fitness'].append(best_fit)
            self.history['avg_fitness'].append(avg_fit)
            
            if self.verbose:
                print(f"    Gen {gen+1:2d}/{self.generations}: Best={best_fit:.4f}, Avg={avg_fit:.4f}")
            
            # Elitism: slicing caps the elite count at the population size.
            ranked = sorted(zip(population, scores), key=lambda x: x[1], reverse=self.maximize)
            new_pop = [list(chrom) for chrom, _ in ranked[:n_elite]]
            
            while len(new_pop) < self.pop_size:
                p1 = self._tournament(population, scores)
                p2 = self._tournament(population, scores)
                c1, c2 = self._crossover(p1, p2)
                new_pop.append(self._mutate(c1))
                if len(new_pop) < self.pop_size:
                    new_pop.append(self._mutate(c2))
            
            population = new_pop
        
        runtime = time.time() - start_time
        return GEResult(
            best_config=best_ever[2],
            best_fitness=best_ever[1],
            best_chromosome=best_ever[0],
            history=self.history,
            generations=len(self.history['best_fitness']),
            evaluations=total_evals,
            runtime_seconds=runtime
        )


def create_tabnet_grammar():
    """Build the TabNet search-space grammar (4 x 4 x 3 x 3 = 144 configurations)."""
    tabnet_bnf = """
<n_d> ::= 8 | 16 | 32 | 64
<n_a> ::= 8 | 16 | 32 | 64
<n_steps> ::= 3 | 5 | 7
<lambda_sparse> ::= 1e-4 | 1e-3 | 1e-2
"""
    return HyperparameterGrammar(tabnet_bnf)


def create_random_forest_grammar():
    """Build the RandomForest search-space grammar (3 x 4 = 12 configurations)."""
    forest_bnf = """
<n_estimators> ::= 100 | 200 | 300
<max_depth> ::= 5 | 10 | 15 | None
"""
    return HyperparameterGrammar(forest_bnf)


def create_xgboost_grammar():
    """Build the XGBoost search-space grammar (3 x 3 x 3 = 27 configurations)."""
    xgb_bnf = """
<n_estimators> ::= 100 | 200 | 300
<learning_rate> ::= 0.01 | 0.05 | 0.1
<max_depth> ::= 3 | 5 | 7
"""
    return HyperparameterGrammar(xgb_bnf)


def create_svm_grammar():
    """Build the SVM search-space grammar (3 x 2 = 6 configurations)."""
    svm_bnf = """
<C> ::= 0.1 | 1.0 | 10.0
<kernel> ::= linear | rbf
"""
    return HyperparameterGrammar(svm_bnf)


def create_logistic_regression_grammar():
    """Build the LogisticRegression search-space grammar (3 x 2 = 6 configurations).

    The 'l1' penalty needs a compatible solver; BaselineOptimizer._create_model
    switches to 'liblinear' for it.
    """
    logreg_bnf = """
<C> ::= 0.1 | 1.0 | 10.0
<penalty> ::= l1 | l2
"""
    return HyperparameterGrammar(logreg_bnf)


def create_gradient_boosting_grammar():
    """Build the GradientBoosting search-space grammar (3 x 3 x 2 = 18 configurations)."""
    gbm_bnf = """
<n_estimators> ::= 100 | 200 | 300
<learning_rate> ::= 0.05 | 0.1 | 0.2
<max_depth> ::= 3 | 5
"""
    return HyperparameterGrammar(gbm_bnf)


@dataclass
class Config:
    """Experiment settings (plain dataclass of defaults).

    Override individual fields with keyword arguments, e.g.
    ``Config(DATA_CSV="path/to/data.csv", N_INDEPENDENT_RUNS=5)``.
    """
    # NOTE(review): machine-specific absolute Windows path; override per environment.
    DATA_CSV: str = r"C:\Users\awwal\Desktop\MLEA_experiments\data.csv"
    TARGET_COLUMN: str = "diagnosis"  # presumably the label column in DATA_CSV — loader not shown
    TEST_SIZE: float = 0.1
    # NOTE(review): presumably a fraction of the non-test remainder (~10% overall); split code not shown — confirm.
    VAL_SIZE: float = 0.111
    
    # =====================================================
    # Grammatical-evolution search settings (passed to GEOptimizer).
    # This modified version raised the generation counts from 10 to 30.
    # =====================================================
    GE_POP_SIZE: int = 20  # individuals per generation
    GE_GENERATIONS: int = 30  # generations per TabNet GE run (TabNetOptimizer)
    GE_GENERATIONS_BASELINE: int = 30  # NOTE(review): presumably generations for baseline-model runs; usage not shown
    GE_CROSSOVER_RATE: float = 0.85  # probability that two parents undergo one-point crossover
    GE_MUTATION_RATE: float = 0.15  # per-codon probability of being reset to a random value
    GE_TOURNAMENT_SIZE: int = 3  # individuals compared in each tournament selection
    GE_ELITISM: int = 2  # best individuals copied unchanged into the next generation
    
    # Number of independent GE runs (each with its own seed) for statistical robustness.
    N_INDEPENDENT_RUNS: int = 30
    
    # Parallel execution settings
    N_PARALLEL_JOBS: int = -1  # -1 uses all available cores (joblib convention); usage not shown here
    
    # =====================================================
    # TabNet training settings (used when building/fitting TabNetClassifier).
    # Values changed in this modified version: max epochs 25 -> 100,
    # patience 6 -> 15, batch size 256 -> 128, gamma 1.3 -> 1.5.
    # =====================================================
    TABNET_MAX_EPOCHS: int = 100  # max_epochs passed to fit()
    TABNET_PATIENCE: int = 15  # early-stopping patience (epochs) passed to fit()
    TABNET_BATCH_SIZE: int = 128  # batch_size passed to fit()
    TABNET_GAMMA: float = 1.5  # gamma applied to GE candidates; gamma is not part of the search grammar
    
    # Additional TabNet settings
    TABNET_N_INDEPENDENT_FINAL: int = 5  # NOTE(review): presumably number of final models to train/ensemble; usage not shown
    TABNET_VIRTUAL_BATCH_SIZE: int = 64  # virtual_batch_size (ghost batch normalization) passed to fit()
    TABNET_MOMENTUM: float = 0.02  # batch-normalization momentum passed to TabNetClassifier
    TABNET_MASK_TYPE: str = 'entmax'  # attention mask: 'sparsemax' or 'entmax'
    
    # SMOTE oversampling of training data (only applied when imblearn is installed).
    USE_SMOTE: bool = True
    
    RESULTS_DIR: str = "ge_results_modified"  # NOTE(review): output directory; usage not shown
    RANDOM_SEED: int = 42  # base seed; per-run seeds are derived from it by offsets
    VERBOSE: bool = False  # Disabled for parallel execution; NOTE(review): GE runs in TabNetOptimizer hard-code verbose=False
    
    # =====================================================
    # Plot output (usage not shown in this section).
    # =====================================================
    SAVE_PLOTS: bool = True  # Still save plots
    SHOW_PLOTS: bool = True  # show plots interactively (e.g. in VSCode); was False originally
    PLOT_DPI: int = 150


class TabNetOptimizer:
    """GE hyperparameter search and final training for TabNet.

    Holds the train/validation/test splits and a ``Config``. GE candidates are
    scored by validation ROC-AUC (``_fitness_holdout``). The chosen
    configuration is retrained on train+validation by ``train_final_model``
    and evaluated once on the test split.
    """
    
    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test, 
                 config, feature_names=None):
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.X_test = X_test
        self.y_test = y_test
        self.config = config
        # Explicit None/empty test: ``feature_names or [...]`` raises for NumPy
        # arrays and pandas Index objects (ambiguous truth value).
        if feature_names is None or len(feature_names) == 0:
            self.feature_names = [f"feature_{i}" for i in range(X_train.shape[1])]
        else:
            self.feature_names = list(feature_names)
    
    def _build_model(self, params, seed):
        """Create an unfitted TabNetClassifier.

        ``n_d``, ``n_a``, ``n_steps`` and ``lambda_sparse`` come from ``params``.
        ``gamma`` comes from ``params`` when present, else ``config.TABNET_GAMMA``.
        Momentum and mask type always come from the config.
        """
        return TabNetClassifier(
            n_d=params['n_d'],
            n_a=params['n_a'],
            n_steps=params['n_steps'],
            gamma=params.get('gamma', self.config.TABNET_GAMMA),
            lambda_sparse=params['lambda_sparse'],
            momentum=self.config.TABNET_MOMENTUM,
            mask_type=self.config.TABNET_MASK_TYPE,
            verbose=0,
            seed=seed
        )
    
    def _fit(self, model, X_fit, y_fit, X_eval, y_eval):
        """Fit ``model`` on ``(X_fit, y_fit)`` with AUC early stopping on ``(X_eval, y_eval)``."""
        model.fit(
            X_fit, y_fit,
            eval_set=[(X_eval, y_eval)],
            eval_metric=['auc'],
            max_epochs=self.config.TABNET_MAX_EPOCHS,
            patience=self.config.TABNET_PATIENCE,
            batch_size=self.config.TABNET_BATCH_SIZE,
            virtual_batch_size=self.config.TABNET_VIRTUAL_BATCH_SIZE,
            drop_last=False
        )
        return model
    
    def _fitness_holdout(self, params, seed_offset=0):
        """Evaluate fitness using holdout validation with SMOTE.

        Trains on the training split (SMOTE-resampled when enabled and
        installed) and returns ROC-AUC on the untouched validation split.
        Any exception is printed and scored 0.0, so one failed candidate does
        not abort the GE run.
        """
        try:
            X_tr = self.X_train.copy()
            y_tr = self.y_train.copy()
            
            # Oversample the training split only; validation stays real data.
            if self.config.USE_SMOTE and SMOTE_AVAILABLE:
                smote = SMOTE(random_state=self.config.RANDOM_SEED + seed_offset)
                X_tr, y_tr = smote.fit_resample(X_tr, y_tr)
            
            model = self._build_model(params, seed=self.config.RANDOM_SEED + seed_offset)
            self._fit(model, X_tr, y_tr, self.X_val, self.y_val)
            
            y_pred_proba = model.predict_proba(self.X_val)[:, 1]
            return roc_auc_score(self.y_val, y_pred_proba)
        except Exception as e:
            print(f"    [Warning] TabNet training error: {e}")
            return 0.0
    
    def evaluate_default_params(self):
        """Validation AUC of TabNet with ``DEFAULT_HYPERPARAMS['TabNet']``.

        The defaults' own ``gamma`` is honoured (the same rule as
        ``train_final_model``), so the baseline is not silently evaluated with
        ``config.TABNET_GAMMA``.
        """
        default_params = HyperparameterPruningAnalyzer.DEFAULT_HYPERPARAMS.get('TabNet', {})
        return self._fitness_holdout(default_params, seed_offset=9999)
    
    def _run_single_optimization(self, run_idx):
        """Run a single GE optimization - designed for parallel execution.

        GE and the fitness function get seeds derived from ``run_idx`` so
        independent runs differ but are reproducible. Returns a plain dict
        (picklable for joblib) with the best config, fitness and history.
        """
        grammar = create_tabnet_grammar()
        seed = self.config.RANDOM_SEED + run_idx * 1000
        
        fitness_fn = lambda p: self._fitness_holdout(p, seed_offset=run_idx * 100)
        
        optimizer = GEOptimizer(
            grammar=grammar,
            fitness_fn=fitness_fn,
            pop_size=self.config.GE_POP_SIZE,
            generations=self.config.GE_GENERATIONS,
            crossover_rate=self.config.GE_CROSSOVER_RATE,
            mutation_rate=self.config.GE_MUTATION_RATE,
            tournament_size=self.config.GE_TOURNAMENT_SIZE,
            elitism=self.config.GE_ELITISM,
            maximize=True,
            seed=seed,
            verbose=False  # Disable verbose for parallel
        )
        
        result = optimizer.evolve()
        # gamma is not searched; record the fixed value actually used.
        result.best_config['gamma'] = self.config.TABNET_GAMMA
        
        return {
            'run_idx': run_idx,
            'best_config': result.best_config,
            'best_fitness': result.best_fitness,
            'generations': result.generations,
            'evaluations': result.evaluations,
            'runtime': result.runtime_seconds,
            'history': {
                'best_fitness': result.history['best_fitness'],
                'avg_fitness': result.history['avg_fitness']
            }
        }
    
    def optimize_parallel(self, n_jobs=-1):
        """Run all independent optimizations in parallel using joblib.

        Returns the list of per-run dicts from ``_run_single_optimization``,
        in run order.
        """
        print(f"  Running {self.config.N_INDEPENDENT_RUNS} optimizations in parallel...")
        
        results = Parallel(n_jobs=n_jobs, verbose=10)(
            delayed(self._run_single_optimization)(run_idx) 
            for run_idx in range(self.config.N_INDEPENDENT_RUNS)
        )
        
        return results
    
    def optimize_single_run(self, run_idx):
        """Run a single optimization (for sequential execution)."""
        return self._run_single_optimization(run_idx)
    
    def train_final_model(self, params, seed=42):
        """Retrain on train+validation with ``params`` and evaluate on the test split.

        10% of the combined data (stratified) is held out for early stopping.
        That split is made *before* SMOTE, so no synthetic samples (which are
        interpolated from training rows) leak into the early-stopping set.
        Train metrics are computed on the real, non-synthetic training rows.

        Returns:
            ``(model, results)`` where ``results`` has ``train_metrics``,
            ``test_metrics``, ``train_val_gap`` (train minus test) and
            ``feature_importance``.
        """
        X_train_full = np.vstack([self.X_train, self.X_val])
        y_train_full = np.concatenate([self.y_train, self.y_val])
        
        X_tr, X_es, y_tr, y_es = train_test_split(
            X_train_full, y_train_full, test_size=0.1, 
            stratify=y_train_full, random_state=seed
        )
        
        X_fit, y_fit = X_tr, y_tr
        if self.config.USE_SMOTE and SMOTE_AVAILABLE:
            smote = SMOTE(random_state=seed)
            X_fit, y_fit = smote.fit_resample(X_tr, y_tr)
        
        model = self._build_model(params, seed=seed)
        self._fit(model, X_fit, y_fit, X_es, y_es)
        
        results = {}
        
        y_train_pred = model.predict(X_tr)
        y_train_proba = model.predict_proba(X_tr)[:, 1]
        results['train_metrics'] = MetricsCalculator.compute_all_metrics(y_tr, y_train_pred, y_train_proba)
        
        y_test_pred = model.predict(self.X_test)
        y_test_proba = model.predict_proba(self.X_test)[:, 1]
        results['test_metrics'] = MetricsCalculator.compute_all_metrics(self.y_test, y_test_pred, y_test_proba)
        
        results['train_val_gap'] = MetricsCalculator.compute_train_val_gap(
            results['train_metrics'], results['test_metrics']
        )
        
        results['feature_importance'] = self._get_feature_importance(model)
        
        return model, results
    
    def _get_feature_importance(self, model):
        """Return TabNet feature importances, raw and ranked by name.

        Any failure (e.g. missing attribute or a name/length mismatch) is
        returned as ``{'error': message}`` instead of raising.
        """
        try:
            importance = model.feature_importances_
            importance_dict = {}
            for i, imp in enumerate(importance):
                importance_dict[self.feature_names[i]] = float(imp)
            
            sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)
            
            return {
                'raw_importance': importance.tolist(),
                'feature_ranking': sorted_importance,
                'top_10_features': sorted_importance[:10],
                'feature_names': self.feature_names
            }
        except Exception as e:
            return {'error': str(e)}
    
class BaselineOptimizer:
    """Grammatical-Evolution hyperparameter search for classical baselines.

    For each supported ``model_type`` a grammar factory (``MODEL_GRAMMARS``)
    defines the searched hyperparameters, and ``FIXED_PARAMS`` supplies
    values merged underneath every candidate (searched values win).
    Candidates are scored on a single train/validation holdout; the final
    model is refit on train+validation and evaluated on the test split.
    """

    MODEL_GRAMMARS = {
        'RandomForest': create_random_forest_grammar,
        'LogisticRegression': create_logistic_regression_grammar,
        'SVM': create_svm_grammar,
        'GradientBoosting': create_gradient_boosting_grammar,
        'XGBoost': create_xgboost_grammar,
    }

    FIXED_PARAMS = {
        'RandomForest': {'max_features': 'sqrt', 'min_samples_split': 2},
        'LogisticRegression': {'max_iter': 1000},
        'SVM': {'gamma': 'scale'},
        'GradientBoosting': {'subsample': 0.8},
        'XGBoost': {'subsample': 0.8, 'colsample_bytree': 0.8},
    }

    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test,
                 config, feature_names=None):
        """Store the data splits and run configuration.

        Args:
            X_train, y_train: data candidate models are fitted on.
            X_val, y_val: holdout used as the GE fitness signal.
            X_test, y_test: split used only by ``train_final_model``.
            config: settings object read for RANDOM_SEED, USE_SMOTE and the
                GE_* / N_INDEPENDENT_RUNS parameters.
            feature_names: optional names for the columns of ``X_train``;
                ``feature_{i}`` names are generated when omitted or empty.
        """
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.X_test = X_test
        self.y_test = y_test
        self.config = config
        # Explicit None/length check: ``feature_names or [...]`` raises
        # "truth value ... is ambiguous" for numpy arrays and pandas Index.
        if feature_names is None or len(feature_names) == 0:
            feature_names = [f"feature_{i}" for i in range(X_train.shape[1])]
        self.feature_names = list(feature_names)

    def _create_model(self, model_type, params, seed):
        """Instantiate an unfitted estimator of ``model_type``.

        ``params`` (the searched values) override ``FIXED_PARAMS``.

        Raises:
            ImportError: XGBoost requested but not installed.
            ValueError: unknown ``model_type``.
            KeyError: a required hyperparameter is missing from ``params``.
        """
        full_params = {**self.FIXED_PARAMS.get(model_type, {}), **params}

        if model_type == 'RandomForest':
            max_depth = full_params.get('max_depth')
            # Grammars may encode "no depth limit" as the string 'None'.
            if max_depth == 'None':
                max_depth = None
            return RandomForestClassifier(
                n_estimators=full_params['n_estimators'],
                max_depth=max_depth,
                max_features=full_params.get('max_features', 'sqrt'),
                # Forward the fixed value instead of silently ignoring it.
                min_samples_split=full_params.get('min_samples_split', 2),
                random_state=seed,
                n_jobs=-1
            )

        elif model_type == 'LogisticRegression':
            penalty = full_params.get('penalty', 'l2')
            # lbfgs does not support the L1 penalty; liblinear does.
            solver = 'liblinear' if penalty == 'l1' else 'lbfgs'
            return LogisticRegression(
                C=full_params['C'],
                penalty=penalty,
                solver=solver,
                max_iter=full_params.get('max_iter', 1000),
                random_state=seed
            )

        elif model_type == 'SVM':
            return SVC(
                C=full_params['C'],
                kernel=full_params['kernel'],
                gamma=full_params.get('gamma', 'scale'),
                probability=True,  # required for predict_proba / ROC-AUC
                random_state=seed
            )

        elif model_type == 'GradientBoosting':
            return GradientBoostingClassifier(
                n_estimators=full_params['n_estimators'],
                learning_rate=full_params['learning_rate'],
                max_depth=full_params['max_depth'],
                subsample=full_params.get('subsample', 0.8),
                random_state=seed
            )

        elif model_type == 'XGBoost':
            if not XGBOOST_AVAILABLE:
                raise ImportError("XGBoost not available")
            return XGBClassifier(
                n_estimators=full_params['n_estimators'],
                learning_rate=full_params['learning_rate'],
                max_depth=full_params['max_depth'],
                subsample=full_params.get('subsample', 0.8),
                colsample_bytree=full_params.get('colsample_bytree', 0.8),
                random_state=seed,
                use_label_encoder=False,
                eval_metric='logloss',
                verbosity=0
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def _fitness_holdout(self, model_type, params, seed_offset=0):
        """Evaluate fitness using holdout validation (no CV for speed).

        Fits on the training split (optionally SMOTE-resampled) and returns
        validation ROC-AUC, or accuracy when the model has no
        ``predict_proba``.  Any exception during model construction,
        fitting or scoring is mapped to a fitness of 0.0 so that invalid
        candidates are penalised rather than aborting the evolutionary run.
        """
        try:
            seed = self.config.RANDOM_SEED + seed_offset
            model = self._create_model(model_type, params, seed)

            X_tr = self.X_train.copy()
            y_tr = self.y_train.copy()

            if self.config.USE_SMOTE and SMOTE_AVAILABLE:
                smote = SMOTE(random_state=seed)
                X_tr, y_tr = smote.fit_resample(X_tr, y_tr)

            model.fit(X_tr, y_tr)

            if hasattr(model, 'predict_proba'):
                y_pred_proba = model.predict_proba(self.X_val)[:, 1]
                return roc_auc_score(self.y_val, y_pred_proba)
            else:
                y_pred = model.predict(self.X_val)
                return accuracy_score(self.y_val, y_pred)
        except Exception:
            # Deliberate best-effort: a failing candidate gets the worst score.
            return 0.0

    def evaluate_default_params(self, model_type):
        """Return the holdout fitness of ``model_type`` at its default hyperparameters.

        Only defaults for parameters that are part of the model's search
        space are passed; a fixed seed offset (9999) keeps this evaluation
        separate from the GE runs.
        """
        default_params = HyperparameterPruningAnalyzer.DEFAULT_HYPERPARAMS.get(model_type, {})
        search_space = HyperparameterPruningAnalyzer.SEARCH_SPACE.get(model_type, {})
        search_params = {
            param: default_params[param]
            for param in search_space.keys()
            if param in default_params
        }
        return self._fitness_holdout(model_type, search_params, seed_offset=9999)

    def train_final_model(self, model_type, params, seed=42):
        """Refit ``model_type`` on train+validation and evaluate it.

        When SMOTE is enabled it is applied only to the data the model is
        fitted on; train metrics are computed on the real (non-synthetic)
        train+validation samples so the train/test gap compares like with
        like.

        Returns:
            ``(model, results)`` where ``results`` has 'train_metrics',
            'test_metrics', 'train_test_gap' and 'feature_importance'.
        """
        X_train_full = np.vstack([self.X_train, self.X_val])
        y_train_full = np.concatenate([self.y_train, self.y_val])

        X_fit, y_fit = X_train_full, y_train_full
        if self.config.USE_SMOTE and SMOTE_AVAILABLE:
            smote = SMOTE(random_state=seed)
            X_fit, y_fit = smote.fit_resample(X_train_full, y_train_full)

        model = self._create_model(model_type, params, seed)
        model.fit(X_fit, y_fit)

        has_proba = hasattr(model, 'predict_proba')
        results = {}

        y_train_pred = model.predict(X_train_full)
        y_train_proba = model.predict_proba(X_train_full)[:, 1] if has_proba else None
        # The ground truth here is the label vector, not the feature matrix.
        results['train_metrics'] = MetricsCalculator.compute_all_metrics(
            y_train_full, y_train_pred, y_train_proba
        )

        y_test_pred = model.predict(self.X_test)
        y_test_proba = model.predict_proba(self.X_test)[:, 1] if has_proba else None
        results['test_metrics'] = MetricsCalculator.compute_all_metrics(
            self.y_test, y_test_pred, y_test_proba
        )

        results['train_test_gap'] = MetricsCalculator.compute_train_val_gap(
            results['train_metrics'], results['test_metrics']
        )

        results['feature_importance'] = self._get_feature_importance(model, model_type)

        return model, results

    def _get_feature_importance(self, model, model_type):
        """Return normalised feature importances for a fitted model.

        Tree ensembles use ``feature_importances_``; LogisticRegression and
        linear-kernel SVM use ``|coef_[0]|``.  Importances are scaled to sum
        to 1 when their sum is positive.  Non-linear SVMs and unsupported
        types yield a ``{'note': ...}`` dict; failures yield
        ``{'error': message}``.
        """
        try:
            importance = None

            if model_type in ['RandomForest', 'GradientBoosting', 'XGBoost']:
                importance = model.feature_importances_
            elif model_type == 'LogisticRegression':
                importance = np.abs(model.coef_[0])
            elif model_type == 'SVM':
                if model.kernel == 'linear':
                    importance = np.abs(model.coef_[0])
                else:
                    return {'note': 'Feature importance not available for non-linear SVM'}

            if importance is not None:
                if importance.sum() > 0:
                    importance = importance / importance.sum()

                importance_dict = {
                    self.feature_names[i]: float(imp)
                    for i, imp in enumerate(importance)
                }
                sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)

                return {
                    'raw_importance': importance.tolist(),
                    'feature_ranking': sorted_importance,
                    'top_10_features': sorted_importance[:10]
                }

            return {'note': 'Feature importance not available'}
        except Exception as e:
            return {'error': str(e)}

    def _run_single_optimization(self, model_type, run_idx):
        """Run a single GE optimization - designed for parallel execution.

        Each run gets its own GE seed (``RANDOM_SEED + run_idx * 1000``) and
        fitness seed offset (``run_idx * 100``) so runs are independent but
        reproducible.  The returned ``best_config`` includes ``FIXED_PARAMS``.
        """
        if model_type not in self.MODEL_GRAMMARS:
            raise ValueError(f"Unknown model: {model_type}")

        grammar = self.MODEL_GRAMMARS[model_type]()
        seed = self.config.RANDOM_SEED + run_idx * 1000

        def fitness_fn(p):
            return self._fitness_holdout(model_type, p, seed_offset=run_idx * 100)

        optimizer = GEOptimizer(
            grammar=grammar,
            fitness_fn=fitness_fn,
            pop_size=self.config.GE_POP_SIZE,
            generations=self.config.GE_GENERATIONS_BASELINE,
            crossover_rate=self.config.GE_CROSSOVER_RATE,
            mutation_rate=self.config.GE_MUTATION_RATE,
            tournament_size=self.config.GE_TOURNAMENT_SIZE,
            elitism=self.config.GE_ELITISM,
            maximize=True,
            seed=seed,
            verbose=False  # Disable verbose for parallel
        )

        result = optimizer.evolve()
        full_config = {**self.FIXED_PARAMS.get(model_type, {}), **result.best_config}

        return {
            'run_idx': run_idx,
            'best_config': full_config,
            'best_fitness': result.best_fitness,
            'generations': result.generations,
            'evaluations': result.evaluations,
            'runtime': result.runtime_seconds,
            'history': {
                'best_fitness': result.history['best_fitness'],
                'avg_fitness': result.history['avg_fitness']
            }
        }

    def optimize_parallel(self, model_type, n_jobs=-1):
        """Run all independent optimizations in parallel using joblib.

        Returns one result dict per run, in run order.
        """
        print(f"  Running {self.config.N_INDEPENDENT_RUNS} optimizations in parallel...")

        results = Parallel(n_jobs=n_jobs, verbose=10)(
            delayed(self._run_single_optimization)(model_type, run_idx)
            for run_idx in range(self.config.N_INDEPENDENT_RUNS)
        )

        return results

    def optimize_single_run(self, model_type, run_idx):
        """Run a single optimization (for sequential execution)."""
        return self._run_single_optimization(model_type, run_idx)


class ResultsPlotter:
    
    def __init__(self, results, config, output_dir=None):
        self.results = results
        self.config = config
        self.output_dir = output_dir or config.RESULTS_DIR
        os.makedirs(self.output_dir, exist_ok=True)
        
        self.colors = {
            'TabNet': '#2ecc71',
            'RandomForest': '#3498db',
            'XGBoost': '#e74c3c',
            'SVM': '#9b59b6',
            'LogisticRegression': '#f39c12',
            'GradientBoosting': '#1abc9c'
        }
    
    def _get_color(self, name):
        return self.colors.get(name, '#95a5a6')
    
    def _save_and_show(self, filename):
        """Persist the current figure (when SAVE_PLOTS) and then display or close it.

        With SHOW_PLOTS the figure is handed to ``plt.show()``; otherwise it
        is closed so figures do not accumulate.
        """
        cfg = self.config
        if cfg.SAVE_PLOTS:
            target = os.path.join(self.output_dir, filename)
            plt.savefig(target, dpi=cfg.PLOT_DPI, bbox_inches='tight')
        if not cfg.SHOW_PLOTS:
            plt.close()
            return
        plt.show()
    
    def plot_all(self):
        """Render every summary figure, isolating failures per plot.

        Each plotting method runs independently.  If one raises, its error
        is printed and only the figures it created are closed, so earlier
        figures still open for display (SHOW_PLOTS) are preserved.  When
        SHOW_PLOTS is enabled, interactive mode is switched off at the end
        and ``plt.show()`` blocks until all windows are closed.
        """
        print("\n" + "=" * 70)
        print("GENERATING PLOTS")
        print("=" * 70)

        plot_methods = [
            ('Test performance comparison', self.plot_test_performance_comparison),
            ('Sensitivity/Specificity', self.plot_sensitivity_specificity),
            ('Train-Test gap', self.plot_train_test_gap),
            ('Boxplot (runs)', self.plot_runs_boxplot),
            ('Convergence curves', self.plot_convergence_summary),
            ('Best vs Average fitness', self.plot_best_vs_average_fitness),
            ('Feature importance (all models)', self.plot_all_feature_importance),
            ('Feature importance heatmap', self.plot_feature_importance_heatmap),
            ('HP pruning analysis', self.plot_hyperparameter_pruning),
            ('HP count comparison', self.plot_hp_count_comparison),
            ('Pruning tables', self.plot_pruning_comparison_table),
        ]

        for idx, (name, method) in enumerate(plot_methods, 1):
            print(f"  [{idx}/{len(plot_methods)}] {name}...", end=" ", flush=True)
            figs_before = set(plt.get_fignums())
            try:
                method()
                print("done")
            except Exception as e:
                print(f"failed ({e})")
                # Close only the half-built figures from this method; a
                # blanket plt.close('all') would also discard earlier figures
                # that were left open for display.
                for num in set(plt.get_fignums()) - figs_before:
                    plt.close(num)

        print(f"\nAll plots saved to: {self.output_dir}")

        if self.config.SHOW_PLOTS:
            print("\nPlots are being displayed. Close plot windows to continue...")
            plt.ioff()  # Turn off interactive mode
            plt.show()  # This will block until all windows are closed
    
    def plot_test_performance_comparison(self):
        """Three bar charts of test ROC-AUC, accuracy and F1 for every model.

        Only models that carry ``best_run_test_metrics`` are shown; missing
        metric keys are drawn as 0.  Nothing is plotted when no model
        qualifies.
        """
        model_names, test_aucs, test_accs, test_f1s = [], [], [], []
        for name, data in self.results.get('models', {}).items():
            if 'best_run_test_metrics' not in data:
                continue
            model_names.append(name)
            scores = data['best_run_test_metrics']
            test_aucs.append(scores.get('roc_auc', 0))
            test_accs.append(scores.get('accuracy', 0))
            test_f1s.append(scores.get('f1', 0))

        if not model_names:
            return

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle('Test Set Performance Comparison', fontsize=14, fontweight='bold')

        positions = np.arange(len(model_names))
        bar_colors = [self._get_color(n) for n in model_names]
        panels = (
            (test_aucs, 'ROC-AUC', 'AUC'),
            (test_accs, 'Accuracy', 'Accuracy'),
            (test_f1s, 'F1 Score', 'F1'),
        )

        for ax, (heights, panel_title, y_label) in zip(axes, panels):
            bars = ax.bar(positions, heights, color=bar_colors, edgecolor='black', linewidth=1)
            for bar, value in zip(bars, heights):
                x_mid = bar.get_x() + bar.get_width() / 2
                ax.text(x_mid, bar.get_height() + 0.01, f'{value:.4f}',
                        ha='center', va='bottom', fontsize=9)
            ax.set_xticks(positions)
            ax.set_xticklabels(model_names, rotation=45, ha='right')
            ax.set_ylabel(y_label)
            ax.set_title(panel_title)
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        self._save_and_show('test_performance.png')
    
    def plot_sensitivity_specificity(self):
        """Grouped bars of test sensitivity (TPR) and specificity (TNR) per model.

        Models without ``best_run_test_metrics`` are skipped; missing keys are
        drawn as 0.  Each bar is annotated with its value.
        """
        rows = [
            (name, data['best_run_test_metrics'])
            for name, data in self.results.get('models', {}).items()
            if 'best_run_test_metrics' in data
        ]
        if not rows:
            return

        model_names = [name for name, _ in rows]
        sensitivities = [m.get('sensitivity', 0) for _, m in rows]
        specificities = [m.get('specificity', 0) for _, m in rows]

        fig, ax = plt.subplots(figsize=(12, 6))

        positions = np.arange(len(model_names))
        bar_w = 0.35

        groups = [
            ax.bar(positions - bar_w/2, sensitivities, bar_w, label='Sensitivity (TPR)',
                   color='#e74c3c', edgecolor='black'),
            ax.bar(positions + bar_w/2, specificities, bar_w, label='Specificity (TNR)',
                   color='#3498db', edgecolor='black'),
        ]

        for bar in (b for group in groups for b in group):
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.01,
                    f'{h:.3f}', ha='center', va='bottom', fontsize=9)

        ax.set_xticks(positions)
        ax.set_xticklabels(model_names, rotation=45, ha='right')
        ax.set_ylabel('Score')
        ax.set_title('Sensitivity vs Specificity (Test Set)', fontsize=12, fontweight='bold')
        ax.set_ylim(0, 1.15)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        self._save_and_show('sensitivity_specificity.png')
    
    def plot_train_test_gap(self):
        """Compare train vs test ROC-AUC and chart the overfitting gap.

        Left panel: paired train/test AUC bars.  Right panel: train - test
        AUC per model, coloured red above 0.05, orange above 0.02 and green
        otherwise, with those two thresholds drawn as reference lines.
        """
        model_names, train_aucs, test_aucs, gaps = [], [], [], []
        for name, data in self.results.get('models', {}).items():
            has_both = 'best_run_train_metrics' in data and 'best_run_test_metrics' in data
            if not has_both:
                continue
            model_names.append(name)
            tr = data['best_run_train_metrics'].get('roc_auc', 0)
            te = data['best_run_test_metrics'].get('roc_auc', 0)
            train_aucs.append(tr)
            test_aucs.append(te)
            gaps.append(tr - te)

        if not model_names:
            return

        fig, (ax_auc, ax_gap) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle('Train vs Test Performance', fontsize=14, fontweight='bold')

        pos = np.arange(len(model_names))
        w = 0.35

        ax_auc.bar(pos - w/2, train_aucs, w, label='Train AUC', color='#2ecc71', edgecolor='black')
        ax_auc.bar(pos + w/2, test_aucs, w, label='Test AUC', color='#e74c3c', edgecolor='black')
        ax_auc.set_xticks(pos)
        ax_auc.set_xticklabels(model_names, rotation=45, ha='right')
        ax_auc.set_ylabel('AUC')
        ax_auc.set_title('Train vs Test AUC')
        ax_auc.set_ylim(0, 1.1)
        ax_auc.legend()
        ax_auc.grid(True, alpha=0.3, axis='y')

        def _gap_colour(g):
            if g > 0.05:
                return '#e74c3c'
            if g > 0.02:
                return '#f39c12'
            return '#2ecc71'

        gap_bars = ax_gap.bar(pos, gaps, color=[_gap_colour(g) for g in gaps], edgecolor='black')
        for bar, g in zip(gap_bars, gaps):
            ax_gap.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
                        f'{g:.3f}', ha='center', va='bottom', fontsize=9)

        ax_gap.axhline(y=0.05, color='red', linestyle='--', label='High overfit')
        ax_gap.axhline(y=0.02, color='orange', linestyle='--', label='Moderate overfit')
        ax_gap.set_xticks(pos)
        ax_gap.set_xticklabels(model_names, rotation=45, ha='right')
        ax_gap.set_ylabel('AUC Gap (Train - Test)')
        ax_gap.set_title('Overfitting Gap')
        ax_gap.legend(loc='upper right')
        ax_gap.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        self._save_and_show('train_test_gap.png')
    
    def plot_runs_boxplot(self):
        """Box plot of best validation fitness across independent GE runs.

        Runs containing an ``'error'`` key are excluded and models with no
        valid runs are omitted.  Tick labels are applied explicitly rather
        than via ``boxplot(labels=...)``, which Matplotlib 3.9 renamed to
        ``tick_labels`` (the old name is deprecated), so this works across
        Matplotlib versions.
        """
        model_names = []
        all_fitness = []

        for name, data in self.results.get('models', {}).items():
            if 'all_runs' not in data:
                continue
            valid_runs = [r['best_fitness'] for r in data['all_runs'] if 'error' not in r]
            if valid_runs:
                model_names.append(name)
                all_fitness.append(valid_runs)

        if not model_names:
            return

        fig, ax = plt.subplots(figsize=(12, 6))

        # Same positions boxplot uses by default (1..n), set explicitly so the
        # tick labels can be attached without the version-dependent kwarg.
        positions = np.arange(1, len(model_names) + 1)
        bp = ax.boxplot(all_fitness, positions=positions, patch_artist=True)
        ax.set_xticks(positions)
        ax.set_xticklabels(model_names)

        for patch, name in zip(bp['boxes'], model_names):
            patch.set_facecolor(self._get_color(name))
            patch.set_alpha(0.7)

        ax.set_ylabel('Best Fitness (Validation AUC)', fontsize=11)
        ax.set_title(f'Distribution Across {self.config.N_INDEPENDENT_RUNS} Independent Runs',
                     fontsize=12, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        self._save_and_show('boxplot_runs.png')
    
    def plot_convergence_summary(self):
        """Overlay each model's mean best-fitness curve with a +/-1 std band.

        Reads ``convergence_summary['best_mean']`` / ``['best_std']`` per
        model, falling back to the ``'mean'`` / ``'std'`` keys.  A missing
        std series is drawn as a zero-width band.  When no model has data the
        figure is closed without being saved.
        """
        def _first_nonempty(conv, *keys):
            # Explicit length test works for lists and numpy arrays alike
            # (``array or other`` raises "truth value ... is ambiguous").
            for key in keys:
                values = conv.get(key)
                if values is not None and len(values) > 0:
                    return values
            return []

        models_data = self.results.get('models', {})

        fig, ax = plt.subplots(figsize=(12, 6))

        has_data = False
        for name, data in models_data.items():
            if 'convergence_summary' not in data:
                continue
            conv = data['convergence_summary']
            best_mean = _first_nonempty(conv, 'best_mean', 'mean')
            if len(best_mean) == 0:
                continue
            best_std = _first_nonempty(conv, 'best_std', 'std')

            has_data = True
            generations = range(len(best_mean))
            mean = np.array(best_mean)
            std = np.array(best_std) if len(best_std) > 0 else np.zeros_like(mean)

            color = self._get_color(name)
            ax.plot(generations, mean, label=name, color=color, linewidth=2)
            ax.fill_between(generations, mean - std, mean + std, color=color, alpha=0.2)

        if not has_data:
            plt.close()
            return

        ax.set_xlabel('Generation', fontsize=11)
        ax.set_ylabel('Best Fitness (AUC)', fontsize=11)
        ax.set_title(f'Mean Convergence ({self.config.N_INDEPENDENT_RUNS} Runs) - {self.config.GE_GENERATIONS} Generations',
                     fontsize=12, fontweight='bold')
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_and_show('convergence.png')
    
    def plot_best_vs_average_fitness(self):
        """Per-model panels contrasting best-individual and population-average fitness.

        Each panel shows the mean best-fitness curve (+/-1 std band) and, when
        available, the mean population-average curve (+/-1 std), plus a box
        with the final values and their gap.  The best-fitness series falls
        back to the ``'mean'`` / ``'std'`` keys when ``'best_mean'`` /
        ``'best_std'`` are absent.  The y-axis starts at 0.5 unless plotted
        values go lower (e.g. population averages dragged down by candidates
        whose evaluation failed and scored 0.0), in which case it extends
        down so those curves stay visible.  Grid cells beyond the number of
        models are hidden.
        """
        def _series(conv, *keys):
            # First non-empty series among ``keys``; explicit length test so
            # numpy arrays are handled as well as lists.
            for key in keys:
                values = conv.get(key)
                if values is not None and len(values) > 0:
                    return np.asarray(values)
            return np.array([])

        models_data = self.results.get('models', {})
        n_models = len(models_data)

        if n_models == 0:
            return

        n_cols = min(3, n_models)
        n_rows = (n_models + n_cols - 1) // n_cols

        # squeeze=False always yields a 2-D array, so one code path covers
        # the single-model, single-row and multi-row layouts.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows), squeeze=False)
        fig.suptitle('Best Individual vs Population Average Fitness', fontsize=14, fontweight='bold')
        axes = axes.ravel()

        for ax, (name, data) in zip(axes, models_data.items()):
            color = self._get_color(name)
            conv = data.get('convergence_summary') or {}

            best_mean = _series(conv, 'best_mean', 'mean')
            best_std = _series(conv, 'best_std', 'std')
            avg_mean = _series(conv, 'avg_mean')
            avg_std = _series(conv, 'avg_std')

            if len(best_mean) == 0:
                ax.text(0.5, 0.5, 'No convergence data', ha='center', va='center',
                        transform=ax.transAxes, fontsize=11)
                ax.set_title(name, fontweight='bold')
                continue

            generations = np.arange(len(best_mean))
            marker_step = max(1, len(generations) // 8)
            lower_envelope = [best_mean]  # every plotted value, for the y-limit

            ax.plot(generations, best_mean, color=color, linewidth=2.5,
                    label='Best Individual', marker='o', markersize=5,
                    markevery=marker_step)
            if len(best_std) > 0:
                ax.fill_between(generations, best_mean - best_std, best_mean + best_std,
                                color=color, alpha=0.2)
                lower_envelope.append(best_mean - best_std)

            has_avg = len(avg_mean) > 0
            if has_avg:
                ax.plot(generations, avg_mean, color='#7f8c8d', linewidth=2,
                        linestyle='--', label='Population Average', marker='s',
                        markersize=4, markevery=marker_step)
                lower_envelope.append(avg_mean)
                if len(avg_std) > 0:
                    ax.fill_between(generations, avg_mean - avg_std, avg_mean + avg_std,
                                    color='#7f8c8d', alpha=0.15)
                    lower_envelope.append(avg_mean - avg_std)

            final_best = best_mean[-1]
            stats_text = f'Final Best: {final_best:.4f}'
            if has_avg:
                final_avg = avg_mean[-1]
                gap = final_best - final_avg
                stats_text += f'\nFinal Avg: {final_avg:.4f}\nGap: {gap:.4f}'

            ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontsize=9,
                    verticalalignment='top', family='monospace',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                              alpha=0.9, edgecolor='gray'))

            ax.set_xlabel('Generation', fontsize=10)
            ax.set_ylabel('Fitness (AUC)', fontsize=10)
            ax.set_title(name, fontweight='bold', fontsize=12)
            ax.legend(loc='lower right', fontsize=9)
            ax.grid(True, alpha=0.3)

            # Keep the familiar 0.5 floor, but never clip curves below it.
            lowest = float(np.nanmin(np.concatenate([np.ravel(v) for v in lower_envelope])))
            ax.set_ylim(min(0.5, lowest - 0.02), 1.02)

        for ax in axes[n_models:]:
            ax.set_visible(False)

        plt.tight_layout()
        self._save_and_show('best_vs_average_fitness.png')
    
    def plot_all_feature_importance(self):
        """Plot feature importance for ALL models in a grid layout.

        One horizontal bar chart per model that has a non-empty
        ``top_10_features`` list; at most three columns, unused grid cells
        hidden, feature names truncated to 20 characters.
        """
        top_by_model = {}
        for model_name, data in self.results.get('models', {}).items():
            fi = data.get('feature_importance', {})
            if 'top_10_features' in fi and fi['top_10_features']:
                top_by_model[model_name] = fi['top_10_features']

        if not top_by_model:
            return

        n_panels = len(top_by_model)
        n_cols = min(3, n_panels)
        n_rows = -(-n_panels // n_cols)  # ceiling division

        fig, grid = plt.subplots(n_rows, n_cols, figsize=(7*n_cols, 5*n_rows), squeeze=False)
        fig.suptitle('Feature Importance by Model (Top 10 Features)', fontsize=14, fontweight='bold')
        cells = grid.ravel()

        for ax, (model_name, top_features) in zip(cells, top_by_model.items()):
            labels = [feat[0][:20] for feat in top_features]
            values = [feat[1] for feat in top_features]
            ypos = np.arange(len(labels))

            bars = ax.barh(ypos, values, color=self._get_color(model_name),
                           edgecolor='black', alpha=0.8)

            ax.set_yticks(ypos)
            ax.set_yticklabels(labels, fontsize=9)
            ax.invert_yaxis()  # most important feature on top
            ax.set_xlabel('Importance', fontsize=10)
            ax.set_title(model_name, fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='x')

            for bar, value in zip(bars, values):
                ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2,
                        f'{value:.4f}', va='center', fontsize=8)

        for ax in cells[n_panels:]:
            ax.set_visible(False)

        plt.tight_layout()
        self._save_and_show('feature_importance_all_models.png')
    
    def plot_feature_importance_heatmap(self):
        """Plot comparative feature importance heatmap across all models.

        Uses each model's full ``feature_ranking`` when present, otherwise
        its ``top_10_features``; a feature a model does not report counts as
        0.  Rows are the 15 features with the highest mean importance across
        models, ties broken by feature name so the figure is reproducible
        (iteration order of a ``set`` of strings depends on the hash seed).
        Requires at least two models with importance data.
        """
        models_data = self.results.get('models', {})

        all_importance = {}
        for name, data in models_data.items():
            fi = data.get('feature_importance', {})
            if 'feature_ranking' in fi:
                pairs = fi['feature_ranking']
            elif 'top_10_features' in fi:
                pairs = fi['top_10_features']
            else:
                continue
            all_importance[name] = dict(pairs)

        if len(all_importance) < 2:
            return

        all_features = set()
        for ranking in all_importance.values():
            all_features.update(ranking.keys())

        avg_importance = {
            feat: np.mean([ranking.get(feat, 0) for ranking in all_importance.values()])
            for feat in all_features
        }

        # Descending importance, then name: deterministic ordering of ties.
        top_features = sorted(avg_importance.items(),
                              key=lambda kv: (-kv[1], str(kv[0])))[:15]
        top_feature_names = [feat for feat, _ in top_features]
        if not top_feature_names:
            return

        model_names = list(all_importance.keys())
        heatmap_data = np.zeros((len(top_feature_names), len(model_names)))

        for j, model in enumerate(model_names):
            for i, feat in enumerate(top_feature_names):
                heatmap_data[i, j] = all_importance[model].get(feat, 0)

        fig, ax = plt.subplots(figsize=(12, 10))

        im = ax.imshow(heatmap_data, cmap='YlOrRd', aspect='auto')

        ax.set_xticks(np.arange(len(model_names)))
        ax.set_yticks(np.arange(len(top_feature_names)))
        ax.set_xticklabels(model_names, rotation=45, ha='right', fontsize=11)
        ax.set_yticklabels([f[:25] for f in top_feature_names], fontsize=10)

        # White text on the darker half of the colour scale for legibility.
        contrast_threshold = heatmap_data.max() * 0.5
        for i in range(len(top_feature_names)):
            for j in range(len(model_names)):
                val = heatmap_data[i, j]
                color = 'white' if val > contrast_threshold else 'black'
                ax.text(j, i, f'{val:.3f}', ha='center', va='center',
                        color=color, fontsize=8, fontweight='bold')

        ax.set_title('Comparative Feature Importance Across Models', fontsize=14, fontweight='bold')
        plt.colorbar(im, ax=ax, shrink=0.8, label='Importance')

        plt.tight_layout()
        self._save_and_show('feature_importance_heatmap.png')
    
    def plot_hyperparameter_pruning(self):
        """Plot search space reduction analysis.

        Four panels per the ``results['hyperparameter_pruning']`` entries:
        (1) total vs effective hyperparameter combinations (log scale),
        (2) percentage search-space reduction, (3) default vs GE-optimised
        validation AUC and (4) relative AUC gain in percent.  Panels 3-4
        include only models whose default and optimised AUCs are both
        positive and show a placeholder message when none qualify.

        Each entry is read for ``pre_prune.total_combinations``,
        ``post_prune.effective_combinations``,
        ``reduction_metrics.space_reduction_percent`` and
        ``fitness_improvement.{default,optimized}``.
        """
        pruning_data = self.results.get('hyperparameter_pruning', {})

        if not pruning_data:
            return

        model_names = list(pruning_data.keys())
        pre_prune_combos = []
        post_prune_combos = []
        reduction_pct = []
        default_aucs = []
        optimized_aucs = []

        for name in model_names:
            data = pruning_data[name]
            pre_prune_combos.append(data['pre_prune']['total_combinations'])
            post_prune_combos.append(data['post_prune']['effective_combinations'])
            reduction_pct.append(data['reduction_metrics']['space_reduction_percent'])
            default_aucs.append(data['fitness_improvement'].get('default', 0))
            optimized_aucs.append(data['fitness_improvement'].get('optimized', 0))

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle('GE Hyperparameter Pruning Analysis', fontsize=14, fontweight='bold')

        x = np.arange(len(model_names))
        width = 0.35

        # (1) Search-space size before and after pruning.
        ax1 = axes[0, 0]
        ax1.bar(x - width/2, pre_prune_combos, width, label='Pre-Prune',
                color='#e74c3c', edgecolor='black')
        ax1.bar(x + width/2, post_prune_combos, width, label='Post-Prune',
                color='#2ecc71', edgecolor='black')
        ax1.set_xticks(x)
        ax1.set_xticklabels(model_names, rotation=45, ha='right')
        ax1.set_ylabel('Number of Combinations')
        ax1.set_title('Search Space Size (log scale)')
        ax1.legend()
        ax1.grid(True, alpha=0.3, axis='y')
        ax1.set_yscale('log')

        # (2) Relative reduction of the search space.
        ax2 = axes[0, 1]
        colors = [self._get_color(name) for name in model_names]
        bars = ax2.bar(x, reduction_pct, color=colors, edgecolor='black')
        for bar, pct in zip(bars, reduction_pct):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                     f'{pct:.1f}%', ha='center', va='bottom', fontsize=9, fontweight='bold')
        ax2.set_xticks(x)
        ax2.set_xticklabels(model_names, rotation=45, ha='right')
        ax2.set_ylabel('Reduction (%)')
        ax2.set_title('Search Space Reduction')
        ax2.grid(True, alpha=0.3, axis='y')
        # Headroom above 100% so labels of near-total reductions stay inside
        # the axes instead of colliding with the panel title.
        ax2.set_ylim(0, 110)

        ax3 = axes[1, 0]
        ax4 = axes[1, 1]
        valid_idx = [i for i, (d, o) in enumerate(zip(default_aucs, optimized_aucs))
                     if d and o and d > 0 and o > 0]

        if not valid_idx:
            for ax in (ax3, ax4):
                ax.text(0.5, 0.5, 'No valid default/optimized AUC pairs',
                        ha='center', va='center', transform=ax.transAxes, fontsize=11)
            ax3.set_title('Default vs Optimized Performance')
            ax4.set_title('Performance Gain from GE Optimization')
        else:
            valid_models = [model_names[i] for i in valid_idx]
            valid_default = [default_aucs[i] for i in valid_idx]
            valid_optimized = [optimized_aucs[i] for i in valid_idx]
            x_valid = np.arange(len(valid_models))

            # (3) Default vs optimised validation AUC.
            ax3.bar(x_valid - width/2, valid_default, width, label='Default HP',
                    color='#95a5a6', edgecolor='black')
            ax3.bar(x_valid + width/2, valid_optimized, width, label='GE-Optimized',
                    color='#2ecc71', edgecolor='black')
            ax3.set_xticks(x_valid)
            ax3.set_xticklabels(valid_models, rotation=45, ha='right')
            ax3.set_ylabel('Validation AUC')
            ax3.set_title('Default vs Optimized Performance')
            ax3.legend()
            ax3.grid(True, alpha=0.3, axis='y')
            ax3.set_ylim(0.5, 1.05)

            # (4) Relative gain (%) of the optimised over the default AUC.
            gains = [(o - d) * 100 / d for d, o in zip(valid_default, valid_optimized)]
            colors_gain = ['#2ecc71' if g > 0 else '#e74c3c' for g in gains]
            gain_bars = ax4.bar(x_valid, gains, color=colors_gain, edgecolor='black')
            for bar, g in zip(gain_bars, gains):
                above = g > 0
                # Labels sit outside the bar end: above positive bars, below
                # negative ones (va='top' keeps them from overlapping the bar).
                ax4.text(bar.get_x() + bar.get_width()/2,
                         bar.get_height() + (0.1 if above else -0.1),
                         f'{g:.2f}%', ha='center', va='bottom' if above else 'top',
                         fontsize=9, fontweight='bold')
            ax4.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
            ax4.set_xticks(x_valid)
            ax4.set_xticklabels(valid_models, rotation=45, ha='right')
            ax4.set_ylabel('AUC Improvement (%)')
            ax4.set_title('Performance Gain from GE Optimization')
            ax4.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        self._save_and_show('hyperparameter_pruning.png')
    
    def plot_hp_count_comparison(self):
        """Bar chart of hyperparameter counts before and after GE pruning.

        For every model in ``self.results['hyperparameter_pruning']`` this
        places the full search-space size (``pre_prune.n_hyperparameters``)
        next to the number GE kept (``post_prune.n_selected_hyperparameters``),
        writes the value above each bar and marks the drop as ``-N``.
        Returns immediately when no pruning data was recorded. The figure is
        saved as ``hp_count_comparison.png``.
        """
        pruning = self.results.get('hyperparameter_pruning', {})
        if not pruning:
            return

        names = list(pruning.keys())
        # One (pre, post) pair per model, read in model order.
        count_pairs = [
            (pruning[n]['pre_prune']['n_hyperparameters'],
             pruning[n]['post_prune']['n_selected_hyperparameters'])
            for n in names
        ]
        pre_counts, post_counts = map(list, zip(*count_pairs))

        fig, ax = plt.subplots(figsize=(12, 6))

        positions = np.arange(len(names))
        bar_w = 0.35

        pre_bars = ax.bar(positions - bar_w / 2, pre_counts, bar_w,
                          label='Pre-Prune (Full Search Space)',
                          color='#e74c3c', edgecolor='black')
        post_bars = ax.bar(positions + bar_w / 2, post_counts, bar_w,
                           label='Post-Prune (Selected by GE)',
                           color='#2ecc71', edgecolor='black')

        # Integer value label just above every bar (pre bars first, then post).
        for rect in [*pre_bars, *post_bars]:
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2, height + 0.1,
                    f'{int(height)}', ha='center', va='bottom',
                    fontsize=11, fontweight='bold')

        # Mark how many hyperparameters GE removed, above the taller bar.
        for idx, (before, after) in enumerate(zip(pre_counts, post_counts)):
            dropped = before - after
            if dropped > 0:
                ax.annotate(f'-{dropped}', xy=(idx, max(before, after) + 0.5),
                            ha='center', fontsize=10, color='darkred',
                            fontweight='bold')

        ax.set_xticks(positions)
        ax.set_xticklabels(names, rotation=45, ha='right', fontsize=11)
        ax.set_ylabel('Number of Hyperparameters', fontsize=12)
        ax.set_title('Pre-Prune vs Post-Prune Hyperparameter Count',
                     fontsize=14, fontweight='bold')
        ax.legend(loc='upper right', fontsize=10)
        ax.grid(True, alpha=0.3, axis='y')
        ax.set_ylim(0, max(pre_counts) + 2)

        plt.tight_layout()
        self._save_and_show('hp_count_comparison.png')
    
    def plot_pruning_comparison_table(self):
        """Render one table per model listing default vs GE-selected hyperparameters.

        For each model in ``self.results['hyperparameter_pruning']`` that has
        ``param_changes``, draw a table with the columns parameter, default
        value, GE value and whether it changed. Changed rows are highlighted.
        A footer line summarises the search-space reduction and the relative
        AUC gain. Each figure is saved as ``pruning_table_<model>.png``.

        Models without ``param_changes`` are skipped. Nothing is drawn when no
        pruning data was recorded.
        """
        pruning_data = self.results.get('hyperparameter_pruning', {})

        if not pruning_data:
            return

        col_labels = ['Parameter', 'Pre-Prune (Default)', 'Post-Prune (GE)', 'Changed']
        highlight = '#d5f5e3'

        for model_name, data in pruning_data.items():
            param_changes = data.get('param_changes', {})
            if not param_changes:
                continue

            fig, ax = plt.subplots(figsize=(10, max(3, len(param_changes) * 0.8)))
            ax.axis('off')

            table_data = [
                [param, str(change['default']), str(change['optimized']),
                 'Yes' if change['changed'] else 'No']
                for param, change in param_changes.items()
            ]

            table = ax.table(
                cellText=table_data,
                colLabels=col_labels,
                cellLoc='center',
                loc='center',
                colWidths=[0.3, 0.25, 0.25, 0.2]
            )

            table.auto_set_font_size(False)
            table.set_fontsize(10)
            table.scale(1.2, 1.5)

            # Style the header row (row 0).
            for col in range(len(col_labels)):
                table[(0, col)].set_facecolor('#3498db')
                table[(0, col)].set_text_props(color='white', fontweight='bold')

            # Highlight the GE value and the flag of every changed parameter;
            # data rows start at 1 because row 0 is the header.
            for row_idx, row in enumerate(table_data, start=1):
                if row[3] == 'Yes':
                    table[(row_idx, 2)].set_facecolor(highlight)
                    table[(row_idx, 3)].set_facecolor(highlight)

            metrics = data.get('reduction_metrics') or {}
            fitness = data.get('fitness_improvement') or {}

            # Fall back to the pre_prune / post_prune counts (the keys the
            # search-space chart in this plotter reads) when reduction_metrics
            # does not carry them, instead of printing "N/A".
            original = metrics.get(
                'original_combinations',
                (data.get('pre_prune') or {}).get('total_combinations', 'N/A'))
            effective = metrics.get(
                'effective_combinations',
                (data.get('post_prune') or {}).get('effective_combinations', 'N/A'))

            summary = f"Combinations: {original} -> {effective} "
            summary += f"(Reduction: {metrics.get('space_reduction_percent', 0):.1f}%)"

            # `is not None` so a 0.00% gain is reported just like a negative
            # or positive one (a truthiness check silently hid it).
            gain = fitness.get('relative_gain_percent')
            if gain is not None:
                summary += f" | AUC Gain: {gain:.2f}%"

            ax.set_title(f'{model_name}: Hyperparameter Changes by GE',
                         fontsize=12, fontweight='bold', pad=20)
            fig.text(0.5, 0.02, summary, ha='center', fontsize=10, style='italic')

            # tight_layout ignores fig.text, so reserve a strip at the bottom
            # for the summary line rather than letting the axes cover it.
            plt.tight_layout(rect=(0, 0.05, 1, 1))
            self._save_and_show(f'pruning_table_{model_name}.png')


class GEExperiment:
    """End-to-end GE hyperparameter-optimisation experiment.

    The experiment proceeds as follows:

    1. Load and split the dataset.
    2. Run GE-based HPO for TabNet (when the ``pytorch_tabnet`` import
       succeeded) and for every baseline model.
    3. Record the hyperparameter-pruning analysis and pairwise Wilcoxon tests.
    4. Print a test-set summary.
    5. Save all results as JSON under ``config.RESULTS_DIR``.
    6. Plot them with ``ResultsPlotter``.
    """

    def __init__(self, config=None):
        """Create the experiment and seed the global RNGs.

        Args:
            config: Experiment configuration. A fresh ``Config()`` is used when
                omitted.
        """
        self.config = config or Config()
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        self.feature_names = None

        # Populated by prepare_splits().
        self.X_train = None
        self.X_val = None
        self.X_test = None
        self.y_train = None
        self.y_val = None
        self.y_test = None

        self.results = {}
        self.pruning_analyzer = HyperparameterPruningAnalyzer()

        # Seed numpy, Python and (when TabNet is importable) torch RNGs.
        np.random.seed(self.config.RANDOM_SEED)
        random.seed(self.config.RANDOM_SEED)
        if TABNET_AVAILABLE:
            torch.manual_seed(self.config.RANDOM_SEED)

    def load_data(self):
        """Load ``config.DATA_CSV`` and return a numeric feature matrix and labels.

        Preprocessing steps:

        - ``Unnamed*`` columns and any column named ``id`` (case-insensitive)
          are dropped.
        - The target is ``config.TARGET_COLUMN``, falling back to the last
          remaining column when that name is absent.
        - Non-numeric features are label-encoded.
        - Missing values are median-filled.
        - Any remaining NaN/inf is replaced with 0.

        Returns:
            tuple: ``(X, y)``. ``X`` is a float32 array of shape
            ``(n_samples, n_features)``. ``y`` is a 1-D array of labels,
            encoded with ``self.label_encoder`` when the target column is
            non-numeric.
        """
        print("\n" + "=" * 70)
        print("LOADING DATA")
        print("=" * 70)

        df = pd.read_csv(self.config.DATA_CSV)
        print(f"  Loaded: {df.shape[0]} samples, {df.shape[1]} columns")

        # Drop pandas index residue ("Unnamed: 0") and row identifiers.
        drop_cols = [col for col in df.columns
                     if 'Unnamed' in str(col) or str(col).lower() == 'id']
        if drop_cols:
            df = df.drop(columns=drop_cols)

        target_col = self.config.TARGET_COLUMN
        if target_col not in df.columns:
            target_col = df.columns[-1]

        print(f"  Target column: {target_col}")

        X = df.drop(columns=[target_col])
        y = df[target_col]

        self.feature_names = X.columns.tolist()

        # Encode every non-numeric feature (object, pandas string, category,
        # datetime, ...). A plain `dtype == 'object'` check let the other
        # non-numeric dtypes through to the float32 cast below, which fails.
        for col in X.columns:
            if not pd.api.types.is_numeric_dtype(X[col]):
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))

        # Median-fill only the columns that have gaps. A column that is
        # entirely missing has no median and stays NaN here (nan_to_num zeroes
        # it below); SimpleImputer would drop such a column and break the
        # DataFrame rebuild.
        # NOTE(review): medians are computed on the full dataset before the
        # train/val/test split, so val/test rows influence imputation; fitting
        # on the training split only would avoid that leakage.
        na_cols = X.columns[X.isnull().any()]
        if len(na_cols) > 0:
            X[na_cols] = X[na_cols].fillna(X[na_cols].median())

        if pd.api.types.is_numeric_dtype(y):
            y = y.values
        else:
            y = self.label_encoder.fit_transform(y)

        X = X.values.astype(np.float32)
        X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

        print(f"  Features: {X.shape[1]}")
        # np.unique works for any label values; np.bincount raised on float
        # or negative labels (e.g. 0.0/1.0 or -1/1 targets).
        classes, counts = np.unique(y, return_counts=True)
        print(f"  Class distribution: {dict(zip(classes.tolist(), counts.tolist()))}")

        return X, y

    def prepare_splits(self, X, y):
        """Create stratified train/val/test splits and standardise the features.

        The test split (``config.TEST_SIZE``) is taken first. ``config.VAL_SIZE``
        is then applied to the remaining rows, so it is a fraction of the
        non-test data, not of the whole dataset. The scaler is fit on the
        training split only and applied to val/test. Sets ``self.X_*`` and
        ``self.y_*``.
        """
        print("\n" + "=" * 70)
        print("DATA SPLITS (80/10/10)")
        print("=" * 70)

        X_temp, self.X_test, y_temp, self.y_test = train_test_split(
            X, y, test_size=self.config.TEST_SIZE, stratify=y,
            random_state=self.config.RANDOM_SEED
        )

        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
            X_temp, y_temp, test_size=self.config.VAL_SIZE, stratify=y_temp,
            random_state=self.config.RANDOM_SEED
        )

        total = len(y)
        print(f"  Train: {len(self.y_train)} samples ({100*len(self.y_train)/total:.1f}%)")
        print(f"  Val:   {len(self.y_val)} samples ({100*len(self.y_val)/total:.1f}%)")
        print(f"  Test:  {len(self.y_test)} samples ({100*len(self.y_test)/total:.1f}%)")

        self.X_train = self.scaler.fit_transform(self.X_train)
        self.X_val = self.scaler.transform(self.X_val)
        self.X_test = self.scaler.transform(self.X_test)

    def _process_convergence_data(self, all_convergence):
        """Aggregate per-run fitness histories into mean/std curves.

        Args:
            all_convergence: One dict per successful run, each holding
                ``'best_fitness'`` and ``'avg_fitness'`` sequences (lists or
                arrays) indexed by generation.

        Returns:
            dict: ``best_mean``/``best_std`` and ``avg_mean``/``avg_std`` as
            lists with one entry per generation. ``mean``/``std`` are copies
            of the ``best_*`` curves. Shorter histories are padded with their
            last value. Runs with an empty history are left out, and all lists
            are empty when there is nothing to aggregate.
        """
        keys = ('best_mean', 'best_std', 'avg_mean', 'avg_std', 'mean', 'std')

        # Normalise to lists so padding concatenates instead of broadcasting
        # element-wise (which happens when a history is an ndarray), and drop
        # empty histories, which have no last value to carry forward.
        best_histories = [list(c['best_fitness']) for c in all_convergence]
        avg_histories = [list(c['avg_fitness']) for c in all_convergence]
        best_histories = [h for h in best_histories if h]
        avg_histories = [h for h in avg_histories if h]

        # One common length keeps the best and avg curves aligned and avoids a
        # ragged array when avg histories are longer than best histories.
        max_gen = max((len(h) for h in best_histories + avg_histories), default=0)
        if max_gen == 0:
            return {key: [] for key in keys}

        def _summarise(histories):
            # Pad each history with its final value, then reduce over runs.
            if not histories:
                return [], []
            padded = np.array(
                [h + [h[-1]] * (max_gen - len(h)) for h in histories], dtype=float
            )
            return padded.mean(axis=0).tolist(), padded.std(axis=0).tolist()

        best_mean, best_std = _summarise(best_histories)
        avg_mean, avg_std = _summarise(avg_histories)

        return {
            'best_mean': best_mean,
            'best_std': best_std,
            'avg_mean': avg_mean,
            'avg_std': avg_std,
            'mean': list(best_mean),
            'std': list(best_std),
        }

    @staticmethod
    def _collect_runs(parallel_results):
        """Split raw parallel run results into the pieces the experiment records.

        Failed runs are the ones whose dict carries an ``'error'`` key.

        Returns:
            tuple: ``(all_runs, valid_runs, all_convergence, all_configs,
            fitness_values)``. ``all_runs`` keeps every run, including failed
            ones, in order. The other items describe successful runs only.
            ``fitness_values`` is an ndarray of their ``best_fitness``.
        """
        all_runs = list(parallel_results)
        valid_runs = [r for r in all_runs if 'error' not in r]
        all_convergence = [r['history'] for r in valid_runs]
        all_configs = [r['best_config'] for r in valid_runs]
        fitness_values = np.array([r['best_fitness'] for r in valid_runs])
        return all_runs, valid_runs, all_convergence, all_configs, fitness_values

    def _optimize_model(self, model_name, results, all_model_fitness,
                        evaluate_default, optimize, train_final):
        """Run GE optimisation for one model and record its results in place.

        Args:
            model_name: Key used in ``results['models']``,
                ``results['hyperparameter_pruning']`` and ``all_model_fitness``.
            results: Experiment results dict. It is updated in place.
            all_model_fitness: Maps model name to the array of best fitness
                per successful run. It is updated in place.
            evaluate_default: Zero-argument callable returning the validation
                AUC of the default hyperparameters.
            optimize: Zero-argument callable returning the per-run result
                dicts. Failed runs carry an ``'error'`` key.
            train_final: Callable taking the best config and returning a
                2-tuple whose second item is the evaluation dict.
        """
        print("\n  Evaluating default hyperparameters...")
        default_fitness = evaluate_default()
        print(f"    Default AUC: {default_fitness:.4f}")

        run_start = time.time()
        parallel_results = optimize()
        parallel_time = time.time() - run_start

        all_runs, valid_runs, all_convergence, all_configs, fitness_values = \
            self._collect_runs(parallel_results)

        # Report successes out of all runs; counting only successes hid failures.
        print(f"\n  Completed {len(valid_runs)}/{len(all_runs)} runs successfully "
              f"in {parallel_time:.1f}s")

        all_model_fitness[model_name] = fitness_values

        best_run = None
        final_eval = {'train_metrics': {}, 'test_metrics': {}, 'feature_importance': {}}
        if valid_runs:
            best_run = valid_runs[int(np.argmax([r['best_fitness'] for r in valid_runs]))]

            print(f"  Best validation AUC: {best_run['best_fitness']:.4f}")
            print(f"\n  Training final {model_name} model...")
            _, final_eval = train_final(best_run['best_config'])

            results['hyperparameter_pruning'][model_name] = self.pruning_analyzer.analyze_model(
                model_name, all_configs, best_run['best_config'],
                default_fitness=default_fitness, optimized_fitness=best_run['best_fitness']
            )
        else:
            print(f"  [WARN] No successful runs for {model_name}; "
                  "failed runs are still recorded in the results.")

        # Always record the model, so failed runs (and their errors) are saved
        # for every model.
        results['models'][model_name] = {
            'all_runs': all_runs,
            'runs_summary': {
                'best_fitness': compute_statistics(fitness_values) if len(fitness_values) > 0 else {},
                'n_successful_runs': len(fitness_values)
            },
            'convergence_summary': self._process_convergence_data(all_convergence),
            'best_run': best_run,
            'best_run_train_metrics': final_eval.get('train_metrics', {}),
            'best_run_test_metrics': final_eval.get('test_metrics', {}),
            'feature_importance': final_eval.get('feature_importance', {}),
            'default_fitness': default_fitness
        }

        if final_eval.get('test_metrics'):
            print(f"  {model_name}: Test AUC={final_eval['test_metrics'].get('roc_auc', 0):.4f}")

    def _print_run_header(self):
        """Print the experiment configuration banner."""
        print("\n" + "=" * 70)
        print("GE-BASED HYPERPARAMETER OPTIMIZATION - MODIFIED VERSION")
        print("=" * 70)
        print(f"  Population Size:      {self.config.GE_POP_SIZE}")
        print(f"  Generations:          {self.config.GE_GENERATIONS}  <-- INCREASED")
        print(f"  Independent Runs:     {self.config.N_INDEPENDENT_RUNS}")
        print(f"  Parallel Jobs:        {self.config.N_PARALLEL_JOBS}")
        print(f"  SMOTE Enabled:        {self.config.USE_SMOTE}")
        print(f"  TabNet Max Epochs:    {self.config.TABNET_MAX_EPOCHS}  <-- INCREASED")
        print(f"  TabNet Patience:      {self.config.TABNET_PATIENCE}  <-- INCREASED")
        print(f"  TabNet Batch Size:    {self.config.TABNET_BATCH_SIZE}")
        print(f"  TabNet Gamma:         {self.config.TABNET_GAMMA}")
        print(f"  Show Plots:           {self.config.SHOW_PLOTS}  <-- ENABLED")
        print(f"  Validation:           Holdout (no CV)")
        print(f"  Data Split:           80% train / 10% val / 10% test")
        print("=" * 70)

    @staticmethod
    def _run_statistical_tests(all_model_fitness, results):
        """Run pairwise Wilcoxon tests between models and store them in ``results``.

        Only pairs whose score arrays have equal, non-empty length are
        compared. This presumably reflects a paired test over runs;
        ``wilcoxon_test`` is defined elsewhere.
        """
        print("\n" + "=" * 70)
        print("STATISTICAL TESTS")
        print("=" * 70)

        model_names = list(all_model_fitness.keys())
        for i, m1 in enumerate(model_names):
            for m2 in model_names[i + 1:]:
                scores1, scores2 = all_model_fitness[m1], all_model_fitness[m2]
                if len(scores1) != len(scores2) or len(scores1) == 0:
                    continue

                w_stat, w_p = wilcoxon_test(scores1, scores2)

                results['statistical_tests'][f'{m1}_vs_{m2}'] = {
                    'wilcoxon_statistic': w_stat,
                    'wilcoxon_p': w_p,
                    'mean_diff': float(scores1.mean() - scores2.mean())
                }

                sig = "***" if w_p < 0.001 else "**" if w_p < 0.01 else "*" if w_p < 0.05 else ""
                print(f"  {m1} vs {m2}: p={w_p:.4f} {sig}")

    @staticmethod
    def _print_test_summary(results):
        """Print a table of test-set metrics for every model that has them."""
        print("\n" + "=" * 70)
        print("TEST SET RESULTS SUMMARY")
        print("=" * 70)
        print(f"{'Model':<20} {'AUC':>8} {'Acc':>8} {'Sens':>8} {'Spec':>8} {'F1':>8}")
        print("-" * 70)

        for name, data in results['models'].items():
            m = data.get('best_run_test_metrics')
            if m:
                print(f"{name:<20} {m.get('roc_auc', 0):>8.4f} {m.get('accuracy', 0):>8.4f} "
                      f"{m.get('sensitivity', 0):>8.4f} {m.get('specificity', 0):>8.4f} "
                      f"{m.get('f1', 0):>8.4f}")

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback for numpy values and other non-native objects.

        Conversions:

        - numpy arrays become lists.
        - numpy scalars of every kind, including ``np.bool_``, which the
          previous converter missed, become native Python values.
        - Anything else is written as ``str(obj)``, on purpose.

        The ``str`` fallback exists because these results are written after a
        long run. Returning the object unchanged makes ``json.dump`` abort with
        "Circular reference detected", which loses every result.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    def _save_results(self, results):
        """Write ``results`` to a timestamped JSON file and return its path."""
        os.makedirs(self.config.RESULTS_DIR, exist_ok=True)
        results_file = os.path.join(
            self.config.RESULTS_DIR,
            f"ge_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        )
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=self._json_default)
        return results_file

    def run(self):
        """Run the whole experiment and return the results dictionary.

        The steps are:

        1. Load and split the data.
        2. Optimise TabNet (when available) and each baseline model.
        3. Print the pruning summary, pairwise statistical tests and
           test-set metrics.
        4. Save everything as JSON and plot it.

        Returns:
            dict: Keys ``config``, ``models``, ``statistical_tests``,
            ``hyperparameter_pruning`` and ``total_runtime_seconds``.
            The same dict is stored on ``self.results``.
        """
        from functools import partial

        experiment_start = time.time()

        self._print_run_header()

        X, y = self.load_data()
        self.prepare_splits(X, y)

        results = {
            'config': {
                'n_independent_runs': self.config.N_INDEPENDENT_RUNS,
                'ge_pop_size': self.config.GE_POP_SIZE,
                'ge_generations': self.config.GE_GENERATIONS,
                'tabnet_max_epochs': self.config.TABNET_MAX_EPOCHS,
                'tabnet_patience': self.config.TABNET_PATIENCE,
                'tabnet_batch_size': self.config.TABNET_BATCH_SIZE,
                'validation': 'holdout',
                'timestamp': datetime.now().isoformat()
            },
            'models': {},
            'statistical_tests': {},
            'hyperparameter_pruning': {}
        }

        all_model_fitness = {}

        if TABNET_AVAILABLE:
            print("\n" + "=" * 70)
            print(f"TabNet - {self.config.N_INDEPENDENT_RUNS} Independent Runs (Parallel)")
            print(f"  -> {self.config.GE_GENERATIONS} generations, {self.config.TABNET_MAX_EPOCHS} max epochs")
            print("=" * 70)

            tabnet_opt = TabNetOptimizer(
                self.X_train, self.y_train, self.X_val, self.y_val,
                self.X_test, self.y_test, self.config, self.feature_names
            )
            self._optimize_model(
                'TabNet', results, all_model_fitness,
                evaluate_default=tabnet_opt.evaluate_default_params,
                optimize=partial(tabnet_opt.optimize_parallel,
                                 n_jobs=self.config.N_PARALLEL_JOBS),
                train_final=partial(tabnet_opt.train_final_model,
                                    seed=self.config.RANDOM_SEED),
            )

        baseline_models = ['RandomForest', 'LogisticRegression', 'SVM', 'GradientBoosting']
        if XGBOOST_AVAILABLE:
            baseline_models.append('XGBoost')

        baseline_opt = BaselineOptimizer(
            self.X_train, self.y_train, self.X_val, self.y_val,
            self.X_test, self.y_test, self.config, self.feature_names
        )

        for model_name in baseline_models:
            print("\n" + "=" * 70)
            print(f"{model_name} - {self.config.N_INDEPENDENT_RUNS} Independent Runs (Parallel)")
            print(f"  -> {self.config.GE_GENERATIONS_BASELINE} generations")
            print("=" * 70)

            self._optimize_model(
                model_name, results, all_model_fitness,
                evaluate_default=partial(baseline_opt.evaluate_default_params, model_name),
                optimize=partial(baseline_opt.optimize_parallel, model_name,
                                 n_jobs=self.config.N_PARALLEL_JOBS),
                train_final=partial(baseline_opt.train_final_model, model_name,
                                    seed=self.config.RANDOM_SEED),
            )

        print("\n" + "=" * 70)
        print("HYPERPARAMETER PRUNING SUMMARY")
        print("=" * 70)

        pruning_summary = self.pruning_analyzer.get_summary_table()
        print(pruning_summary.to_string(index=False))

        self._run_statistical_tests(all_model_fitness, results)

        total_time = time.time() - experiment_start
        results['total_runtime_seconds'] = total_time

        self._print_test_summary(results)
        print(f"\nTotal Runtime: {total_time:.1f}s ({total_time/60:.1f} min)")

        results_file = self._save_results(results)
        print(f"\nResults saved: {results_file}")

        self.results = results

        plotter = ResultsPlotter(results, self.config)
        plotter.plot_all()

        return results


def run_experiment():
    """Run a ``GEExperiment`` built from a default ``Config`` and return its results dict."""
    return GEExperiment(Config()).run()


if __name__ == "__main__":
    # Describe the models GEExperiment.run() will actually train: the four
    # sklearn baselines, plus XGBoost and TabNet only when their optional
    # imports succeeded. A hard-coded "6 models" was wrong otherwise.
    n_baselines = 4 + int(XGBOOST_AVAILABLE)
    n_models = n_baselines + int(TABNET_AVAILABLE)
    tabnet_part = 'TabNet + ' if TABNET_AVAILABLE else ''

    print("\n" + "=" * 70)
    print("GE-based Hyperparameter Optimization - MODIFIED VERSION")
    print("=" * 70)
    print("\nKEY MODIFICATIONS:")
    print("  [1] Generations:        10 -> 30  (3x more evolution)")
    print("  [2] TabNet Max Epochs:  25 -> 100 (4x more training)")
    print("  [3] TabNet Patience:    6  -> 15  (better convergence)")
    print("  [4] TabNet Batch Size:  256 -> 128 (finer gradients)")
    print("  [5] TabNet Gamma:       1.3 -> 1.5 (feature reusage)")
    print("  [6] Show Plots:         False -> True (interactive)")
    print("  [7] Population Size:    15 -> 20  (better exploration)")
    print("\nConfiguration Summary:")
    print(f"  Population Size:      {Config.GE_POP_SIZE}")
    print(f"  Generations:          {Config.GE_GENERATIONS}")
    print(f"  Independent Runs:     {Config.N_INDEPENDENT_RUNS}")
    print(f"  Parallel Jobs:        {Config.N_PARALLEL_JOBS} (-1 = all cores)")
    print(f"  Validation:           Holdout (no CV)")
    print(f"  SMOTE:                {Config.USE_SMOTE}")
    print(f"  TabNet Max Epochs:    {Config.TABNET_MAX_EPOCHS}")
    print(f"  TabNet Patience:      {Config.TABNET_PATIENCE}")
    print(f"  TabNet Batch Size:    {Config.TABNET_BATCH_SIZE}")
    print(f"  TabNet Gamma:         {Config.TABNET_GAMMA}")
    print(f"  Show Plots:           {Config.SHOW_PLOTS}")
    print(f"  Models:               {n_models} ({tabnet_part}{n_baselines} baselines)")
    # TabNet-specific expectations are only relevant when TabNet will run.
    if TABNET_AVAILABLE:
        print("\nExpected improvements for TabNet:")
        print("  - More generations allow better HP exploration")
        print("  - More epochs allow model to fully converge")
        print("  - Higher patience prevents premature stopping")
        print("  - Smaller batch size can improve generalization")
        print("  - entmax mask provides sharper feature selection")
    print("\nEstimated Runtime: 30-60 minutes (with parallel execution)")
    print("=" * 70)

    results = run_experiment()