# In [2]:
import gc
import json
import logging
import os
import sys
import warnings
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field, asdict
from enum import Enum
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score, confusion_matrix, roc_auc_score, 
    classification_report, roc_curve
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA

# DEAP imports for genetic algorithm
from deap import base, creator, tools, algorithms
import random

# Matplotlib imports for window plotting
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

# Suppress warnings
warnings.filterwarnings('ignore')

# Conditional import for TabNet
try:
    from pytorch_tabnet.tab_model import TabNetClassifier
    _HAS_TABNET = True
except ImportError:
    TabNetClassifier = None
    _HAS_TABNET = False

# =============================================================================
# PLOT STYLE AND CONSTANTS
# =============================================================================

# Plot styling. The seaborn-flavoured matplotlib styles were renamed to
# "seaborn-v0_8-*" in matplotlib 3.6 (the old "seaborn-*" aliases were later
# removed), so try the new name first, then the legacy one, and otherwise keep
# matplotlib's default style instead of failing at import time.
for _style_name in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid"):
    try:
        plt.style.use(_style_name)
    except OSError:
        continue
    break
del _style_name
sns.set_palette("husl")

# Identifier-like columns removed from the feature matrix when present.
DEFAULT_DROP_COLUMNS = [
    "id",
    "ID",
    "Id",
    "patient_id",
    "Unnamed: 0",
]
# Target column names tried, in this order, when no target is configured.
TARGET_CANDIDATES = [
    "diagnosis", "Diagnosis",
    "target", "Target",
    "class", "Class",
]

# =============================================================================
# GENETIC ALGORITHM HYPERPARAMETER BOUNDS
# =============================================================================

# Search space for the genetic algorithm as (low, high) pairs. GeneticOptimizer
# samples integer pairs with random.randint and float pairs with random.uniform;
# batch_size is drawn from the multiples of 64 inside its range.
HYPERPARAMETER_BOUNDS = dict(
    n_d=(8, 64),                    # Network width for decision layer
    n_a=(8, 64),                    # Network width for attention layer
    n_steps=(3, 10),                # Number of decision steps
    gamma=(1.0, 2.0),               # Relaxation parameter
    lambda_sparse=(0.0001, 0.01),   # Sparsity regularization
    lr=(0.005, 0.05),               # Learning rate
    momentum=(0.01, 0.3),           # Momentum for batch normalization
    batch_size=(128, 1024),         # Batch size (GA samples multiples of 64 in this range)
    n_independent=(1, 4),           # Number of independent GLU layers
    n_shared=(1, 4),                # Number of shared GLU layers
)

# =============================================================================
# CONFIGURATION CLASSES
# =============================================================================

@dataclass
class TabNetConfig:
    """Configuration for the TabNet model, training, CV, GA search and plotting.

    Every field is validated in ``__post_init__``: out-of-range values raise
    ``ValueError`` and a missing ``dataset_path`` raises ``FileNotFoundError``.
    """
    
    # Model architecture
    n_d: int = 16
    n_a: int = 16
    n_steps: int = 5
    gamma: float = 1.5
    lambda_sparse: float = 0.001
    lr: float = 0.02
    momentum: float = 0.02
    n_independent: int = 2
    n_shared: int = 2
    
    # Training parameters
    max_epochs: int = 50
    patience: int = 10
    batch_size: int = 512
    virtual_batch_size: int = 64
    
    # Cross-validation parameters
    use_cross_validation: bool = True
    n_folds: int = 5
    cv_random_state: int = 42
    
    # Genetic Algorithm parameters
    use_ga_optimization: bool = True
    ga_population_size: int = 20
    ga_generations: int = 15
    ga_mutation_rate: float = 0.3
    ga_crossover_rate: float = 0.7
    ga_tournament_size: int = 3
    
    # Visualization
    display_plots: bool = True
    save_plots: bool = True
    plot_dpi: int = 100
    block_on_plot: bool = False
    
    # Environment
    random_seed: int = 42
    deterministic: bool = True
    
    # Paths
    dataset_path: str = "data.csv"
    results_dir: str = "tabnet_results"
    
    def __post_init__(self):
        """Validate configuration parameters."""
        self._validate()
    
    def _validate(self):
        """Validate all configuration parameters.

        Raises:
            ValueError: If a numeric parameter is out of range. CV and GA
                settings are only checked when the corresponding feature
                (``use_cross_validation`` / ``use_ga_optimization``) is enabled.
            FileNotFoundError: If ``dataset_path`` does not exist. Checked last,
                so parameter errors are reported first.
        """
        if self.n_d <= 0 or self.n_a <= 0:
            raise ValueError(f"n_d and n_a must be positive (got n_d={self.n_d}, n_a={self.n_a})")
        
        if self.n_steps <= 0:
            raise ValueError(f"n_steps must be positive (got {self.n_steps})")
        
        if not 0 <= self.gamma <= 10:
            raise ValueError(f"gamma should be in [0, 10] (got {self.gamma})")
        
        if not 0 <= self.lambda_sparse <= 1:
            raise ValueError(f"lambda_sparse should be in [0, 1] (got {self.lambda_sparse})")
        
        if self.lr <= 0:
            raise ValueError(f"lr must be positive (got {self.lr})")
        
        # Batch-norm momentum is a mixing fraction; values outside [0, 1] are nonsensical.
        if not 0 <= self.momentum <= 1:
            raise ValueError(f"momentum should be in [0, 1] (got {self.momentum})")
        
        if self.max_epochs <= 0:
            raise ValueError(f"max_epochs must be positive (got {self.max_epochs})")
        
        if self.patience <= 0:
            raise ValueError(f"patience must be positive (got {self.patience})")
        
        if self.batch_size <= 0 or self.virtual_batch_size <= 0:
            raise ValueError("batch_size and virtual_batch_size must be positive")
        
        if self.virtual_batch_size > self.batch_size:
            raise ValueError("virtual_batch_size must be <= batch_size")
        
        if self.use_cross_validation and self.n_folds < 2:
            raise ValueError(f"n_folds must be >= 2 (got {self.n_folds})")
        
        # GA settings: mirror the CV check and only validate when the GA is used.
        if self.use_ga_optimization:
            if self.ga_population_size < 1:
                raise ValueError(f"ga_population_size must be >= 1 (got {self.ga_population_size})")
            if self.ga_generations < 1:
                raise ValueError(f"ga_generations must be >= 1 (got {self.ga_generations})")
            if not 0 <= self.ga_mutation_rate <= 1:
                raise ValueError(f"ga_mutation_rate should be in [0, 1] (got {self.ga_mutation_rate})")
            if not 0 <= self.ga_crossover_rate <= 1:
                raise ValueError(f"ga_crossover_rate should be in [0, 1] (got {self.ga_crossover_rate})")
            if self.ga_tournament_size < 1:
                raise ValueError(f"ga_tournament_size must be >= 1 (got {self.ga_tournament_size})")
        
        if self.plot_dpi <= 0:
            raise ValueError(f"plot_dpi must be positive (got {self.plot_dpi})")
        
        if not Path(self.dataset_path).exists():
            raise FileNotFoundError(f"Dataset not found: {self.dataset_path}")

@dataclass
class DataConfig:
    """Settings for loading, cleaning and splitting the dataset.

    Validated on construction; the first invalid setting (checked in the
    order test_size, validation_size, their sum, handle_missing) raises
    ``ValueError``.
    """

    # Explicit target column; None means auto-detect.
    target_column: Optional[str] = None
    # Columns removed from the features when present (a fresh copy per instance).
    drop_columns: List[str] = field(default_factory=lambda: list(DEFAULT_DROP_COLUMNS))
    # Fraction of all rows held out for testing.
    test_size: float = 0.2
    # Fraction of all rows held out for validation.
    validation_size: float = 0.1
    scale_features: bool = True
    # One of "mean", "median", "drop".
    handle_missing: str = "mean"

    def __post_init__(self):
        self._validate()

    def _validate(self):
        def reject_if(bad, message):
            if bad:
                raise ValueError(message)

        reject_if(not 0 < self.test_size < 1,
                  f"test_size must be in (0, 1) (got {self.test_size})")
        reject_if(not 0 < self.validation_size < 1,
                  f"validation_size must be in (0, 1) (got {self.validation_size})")
        reject_if(self.test_size + self.validation_size >= 1,
                  "test_size + validation_size must be < 1")
        reject_if(self.handle_missing not in ("mean", "median", "drop"),
                  f"handle_missing must be 'mean', 'median', or 'drop' (got {self.handle_missing})")

# =============================================================================
# CROSS-VALIDATION
# =============================================================================

@dataclass
class CVResults:
    """Per-fold cross-validation results plus their aggregate statistics.

    Folds are appended with ``add_fold_result``. ``compute_statistics`` then
    fills ``mean_metrics`` / ``std_metrics`` (``np.mean`` / ``np.std``) and sets
    ``best_fold`` to the index of the highest-AUC fold (it stays -1 when no
    fold reports an AUC).
    """
    fold_metrics: List[Dict[str, float]] = field(default_factory=list)
    fold_histories: List[Dict] = field(default_factory=list)
    best_fold: int = -1
    mean_metrics: Dict[str, float] = field(default_factory=dict)
    std_metrics: Dict[str, float] = field(default_factory=dict)

    def add_fold_result(self, fold_idx: int, metrics: Dict, history: Dict):
        """Append one fold's metrics and training history.

        ``fold_idx`` is not stored; folds are identified by insertion order.
        """
        self.fold_metrics.append(metrics)
        self.fold_histories.append(history)

    def compute_statistics(self):
        """Aggregate the known scalar metrics across folds (no-op with no folds)."""
        if not self.fold_metrics:
            return

        for name in ('accuracy', 'auc', 'sensitivity', 'specificity', 'precision', 'f1', 'npv'):
            samples = [fold[name] for fold in self.fold_metrics if name in fold]
            if not samples:
                continue
            self.mean_metrics[name] = np.mean(samples)
            self.std_metrics[name] = np.std(samples)

        if 'auc' not in self.mean_metrics:
            return
        # Folds without an AUC count as 0 when choosing the best one.
        per_fold_auc = [fold.get('auc', 0) for fold in self.fold_metrics]
        self.best_fold = int(np.argmax(per_fold_auc))

    def get_overfitting_analysis(self) -> Dict[str, Any]:
        """Summarize how stable each metric is across folds.

        This does not compare training against validation scores; it looks
        only at the spread of the per-fold results:

        * ``variance``: coefficient of variation (std / mean) for each metric
          whose mean is positive; values above 0.1 are also reported as
          strings in ``overfitting_indicators``.
        * ``consistency``: accuracy / AUC / F1 std graded as 'Good' (< 0.05),
          'Moderate' (< 0.10) or 'Poor'.
        """
        variation: Dict[str, float] = {}
        flags: List[str] = []
        for name, average in self.mean_metrics.items():
            if name not in self.std_metrics or not average > 0:
                continue
            ratio = self.std_metrics[name] / average
            variation[name] = ratio
            if ratio > 0.1:
                flags.append(f"High variance in {name}: CV={ratio:.3f}")

        grades: Dict[str, str] = {}
        for name in ('accuracy', 'auc', 'f1'):
            if name not in self.std_metrics:
                continue
            spread = self.std_metrics[name]
            if spread < 0.05:
                grades[name] = 'Good'
            elif spread < 0.10:
                grades[name] = 'Moderate'
            else:
                grades[name] = 'Poor'

        return {
            'variance': variation,
            'consistency': grades,
            'overfitting_indicators': flags,
        }

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

class Logger:
    """Small wrapper over :mod:`logging` that writes to stdout and can
    print banner-delimited sections."""

    def __init__(self, name: str, level: int = logging.INFO):
        self.logger = logging.getLogger(name)
        # Loggers are shared per name: only the first wrapper created for a
        # given name installs the stdout handler and sets the level.
        if self.logger.handlers:
            return
        formatter = logging.Formatter(
            "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(formatter)
        self.logger.addHandler(stdout_handler)
        self.logger.setLevel(level)

    @contextmanager
    def log_section(self, title: str):
        """Log a ruled header for *title*, run the body, then log a blank line."""
        rule = "=" * 70
        for line in (rule, f" {title}", rule):
            self.logger.info(line)
        try:
            yield
        finally:
            # Emitted even if the body raises.
            self.logger.info("")

    def info(self, msg: str):
        """Log *msg* at INFO level."""
        self.logger.info(msg)

    def warning(self, msg: str):
        """Log *msg* at WARNING level."""
        self.logger.warning(msg)

    def error(self, msg: str):
        """Log *msg* at ERROR level."""
        self.logger.error(msg)

    def debug(self, msg: str):
        """Log *msg* at DEBUG level."""
        self.logger.debug(msg)

@contextmanager
def timer(logger: "Logger", operation: str):
    """Log the start of *operation* and, on exit, how long it took.

    The completion line is logged from a ``finally`` clause, so it appears
    even when the wrapped block raises (the exception still propagates).

    Args:
        logger: Any object with an ``info(str)`` method, such as ``Logger``.
        operation: Label used in both log lines.
    """
    import time

    # perf_counter is monotonic, so the elapsed time is not distorted by
    # system clock adjustments the way datetime.now() differences can be.
    start = time.perf_counter()
    logger.info(f"Starting: {operation}")
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        # Check mark was previously mis-encoded as "âœ“".
        logger.info(f"✓ Completed: {operation} ({elapsed:.2f}s)")

def set_random_seeds(seed: int, *, deterministic: bool = False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: Seed applied to every generator.
        deterministic: If True, additionally set
            ``torch.backends.cudnn.deterministic = True`` and
            ``torch.backends.cudnn.benchmark = False`` so cuDNN picks
            reproducible kernels (possibly slower). The default (False) leaves
            the cuDNN flags untouched, as before.
    """
    import random
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def clear_memory():
    """Free unreferenced Python objects and, on CUDA machines, cached GPU memory.

    Always runs ``gc.collect()``. When CUDA is available it then calls
    ``torch.cuda.empty_cache()`` followed by a best-effort
    ``torch.cuda.synchronize()`` whose errors are deliberately ignored.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    try:
        torch.cuda.synchronize()
    except Exception:
        # Cleanup is best-effort; a failed sync must never abort the caller.
        pass

# =============================================================================
# DATA PROCESSING CLASS
# =============================================================================

class DataProcessor:
    """Load a CSV dataset and turn it into train/validation/test arrays.

    ``load_and_preprocess`` runs the full pipeline: read the CSV, choose the
    target column, build a numeric feature frame, handle missing values,
    binarize the target, split with stratification and (optionally)
    standardize the features.
    """
    
    def __init__(self, config: "DataConfig", logger: "Logger"):
        self.config = config
        self.logger = logger
        # Set by scale_features() when scaling is enabled; otherwise stays None.
        self.scaler = None
        # Feature column names, set by prepare_features().
        self.feature_names = []
    
    def load_dataset(self, filepath: Path) -> pd.DataFrame:
        """Read *filepath* as a CSV file.

        Raises:
            FileNotFoundError: If *filepath* does not exist.
            RuntimeError: If reading fails; the underlying error is chained.
        """
        if not filepath.exists():
            raise FileNotFoundError(f"Dataset not found: {filepath}")
        
        try:
            with timer(self.logger, f"Loading dataset from {filepath}"):
                df = pd.read_csv(filepath)
                self.logger.info(f"   Shape: {df.shape}")
                return df
        except Exception as e:
            # ``from e`` keeps the original parser traceback for debugging.
            raise RuntimeError(f"Failed to load dataset: {e}") from e
    
    def identify_target_column(self, df: pd.DataFrame) -> str:
        """Return the name of the target column.

        Resolution order: ``config.target_column`` if set (it must exist), then
        the first ``TARGET_CANDIDATES`` entry present in *df*, then the last
        column of *df* (logged as a warning).

        Raises:
            ValueError: If ``config.target_column`` is set but not in *df*.
        """
        if self.config.target_column is not None:
            if self.config.target_column in df.columns:
                return self.config.target_column
            else:
                raise ValueError(f"Specified target column '{self.config.target_column}' not found")
        
        for candidate in TARGET_CANDIDATES:
            if candidate in df.columns:
                self.logger.info(f"Auto-detected target column: '{candidate}'")
                return candidate
        
        target_col = df.columns[-1]
        self.logger.warning(f"Using last column as target: '{target_col}'")
        return target_col
    
    def preprocess_target(self, target_series: pd.Series) -> np.ndarray:
        """Convert the target column to a 0/1 integer array.

        Non-numeric series (object, category, pandas ``string``, ...) go
        through label mapping; numeric series are interpreted by their set of
        distinct values.
        """
        target_series = target_series.copy()
        
        # Testing "not numeric" (instead of only object/category) also routes
        # pandas' dedicated ``string`` dtype to the label-mapping path; on the
        # numeric path every text label would fail to parse.
        if not pd.api.types.is_numeric_dtype(target_series):
            return self._preprocess_categorical_target(target_series)
        
        return self._preprocess_numeric_target(target_series)
    
    def _preprocess_categorical_target(self, target_series: pd.Series) -> np.ndarray:
        """Map textual labels to 0/1.

        Tries, in order: a fixed label dictionary (after stripping whitespace
        and lower-casing), numeric parsing of the labels, and finally
        ``LabelEncoder``, which must find exactly two classes.

        Raises:
            ValueError: If the label-encoded target does not have exactly two
                classes, or (via the numeric path) the values cannot be
                interpreted as binary.
        """
        normalized = target_series.astype(str).str.strip().str.lower()
        
        mapping = {
            'm': 1, 'malignant': 1, '1': 1, '4': 1, 'positive': 1, 'yes': 1,
            'b': 0, 'benign': 0, '0': 0, '2': 0, 'negative': 0, 'no': 0
        }
        
        mapped = normalized.map(mapping)
        
        if not mapped.isna().any():
            return mapped.astype(int).values
        
        numeric_vals = pd.to_numeric(normalized, errors='coerce')
        if numeric_vals.notna().all():
            return self._preprocess_numeric_target(numeric_vals)
        
        self.logger.warning("Using LabelEncoder for target conversion")
        encoder = LabelEncoder()
        encoded = encoder.fit_transform(normalized)
        
        if len(encoder.classes_) != 2:
            raise ValueError(f"Expected binary target, got {len(encoder.classes_)} classes")
        
        return encoded.astype(int)
    
    def _preprocess_numeric_target(self, target_series: pd.Series) -> np.ndarray:
        """Map numeric labels to 0/1.

        ``{0, 1}`` is kept as-is, ``{2, 4}`` maps 4 -> 1 and 2 -> 0, and any
        other two-value set maps its larger value to 1.

        Raises:
            ValueError: If values are non-numeric or not interpretable as binary.
        """
        numeric_vals = pd.to_numeric(target_series, errors='coerce')
        
        if numeric_vals.isna().any():
            raise ValueError("Target contains non-numeric values that cannot be converted")
        
        unique_vals = set(numeric_vals.dropna().unique())
        
        if unique_vals.issubset({0, 1}):
            return numeric_vals.astype(int).values
        elif unique_vals.issubset({2, 4}):
            return (numeric_vals == 4).astype(int).values
        elif len(unique_vals) == 2:
            sorted_vals = sorted(unique_vals)
            return (numeric_vals == sorted_vals[1]).astype(int).values
        else:
            raise ValueError(f"Cannot interpret target values: {unique_vals}")
    
    def prepare_features(self, df: pd.DataFrame, target_col: str) -> pd.DataFrame:
        """Build the numeric feature frame.

        Drops the target and any configured ``drop_columns``, coerces every
        remaining column to numeric (unparseable cells become NaN), removes
        columns that are entirely NaN, and records ``self.feature_names``.

        Raises:
            ValueError: If *target_col* is missing or no feature column remains.
        """
        drop_cols = [col for col in self.config.drop_columns 
                    if col in df.columns and col != target_col]
        
        if target_col not in df.columns:
            raise ValueError(f"Target column '{target_col}' not found in dataframe")
        
        X_df = df.drop(columns=[target_col] + drop_cols, errors='ignore')
        
        for col in X_df.columns:
            X_df[col] = pd.to_numeric(X_df[col], errors='coerce')
        
        X_df = X_df.dropna(axis=1, how='all')
        
        if X_df.shape[1] == 0:
            raise ValueError("No valid features remaining after preprocessing")
        
        self.feature_names = list(X_df.columns)
        return X_df
    
    def handle_missing_values(self, X_df: pd.DataFrame) -> pd.DataFrame:
        """Fill or drop missing feature values according to ``config.handle_missing``.

        "mean"/"median" fill each column with its own statistic; "drop"
        removes every row containing a NaN (the caller must then realign the
        target, as ``_extract_features_and_target`` does).

        NOTE(review): the fill statistics are computed on the full dataset,
        before the train/test split, so test rows influence the imputed
        values (mild leakage). Imputing with training-set statistics after
        the split would avoid this.
        """
        missing_cols = X_df.columns[X_df.isna().any()].tolist()
        
        if not missing_cols:
            return X_df
        
        self.logger.warning(f"⚠ Missing values detected in {len(missing_cols)} columns")
        
        if self.config.handle_missing == "mean":
            X_df = X_df.fillna(X_df.mean())
        elif self.config.handle_missing == "median":
            X_df = X_df.fillna(X_df.median())
        elif self.config.handle_missing == "drop":
            X_df = X_df.dropna()
        
        return X_df
    
    def split_data(self, X: np.ndarray, y: np.ndarray, 
                   random_state: int) -> Tuple[np.ndarray, ...]:
        """Stratified split into train, validation and test sets.

        ``validation_size`` is a fraction of the whole dataset, so it is
        rescaled to a fraction of the non-test remainder for the second split.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
        """
        X_temp, X_test, y_temp, y_test = train_test_split(
            X, y, test_size=self.config.test_size, 
            stratify=y, random_state=random_state
        )
        
        val_size = self.config.validation_size / (1 - self.config.test_size)
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=val_size,
            stratify=y_temp, random_state=random_state
        )
        
        self.logger.info(f"   Train set: {X_train.shape[0]} samples")
        self.logger.info(f"   Validation set: {X_val.shape[0]} samples")
        self.logger.info(f"   Test set: {X_test.shape[0]} samples")
        
        return X_train, X_val, X_test, y_train, y_val, y_test
    
    def scale_features(self, X_train: np.ndarray, X_val: np.ndarray, 
                      X_test: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Standardize features with a StandardScaler fitted on the training set only.

        Returns the three arrays as float32, or unchanged when
        ``config.scale_features`` is False.
        """
        if not self.config.scale_features:
            return X_train, X_val, X_test
        
        with timer(self.logger, "Feature scaling"):
            self.scaler = StandardScaler()
            X_train = self.scaler.fit_transform(X_train).astype(np.float32)
            X_val = self.scaler.transform(X_val).astype(np.float32)
            X_test = self.scaler.transform(X_test).astype(np.float32)
        
        return X_train, X_val, X_test
    
    def _extract_features_and_target(self, df: pd.DataFrame,
                                     target_col: str) -> Tuple[pd.DataFrame, np.ndarray]:
        """Build the cleaned feature frame and the 0/1 target for the same rows.

        Raises:
            ValueError: From the feature/target helpers, or if the row counts disagree.
        """
        # A unique positional index makes the row realignment below unambiguous.
        df = df.reset_index(drop=True)
        X_df = self.prepare_features(df, target_col)
        X_df = self.handle_missing_values(X_df)
        # handle_missing="drop" removes rows from X_df; take the target from the
        # surviving rows only, otherwise X and y lengths would always disagree.
        y = self.preprocess_target(df.loc[X_df.index, target_col])
        
        if X_df.shape[0] != len(y):
            raise ValueError(f"Shape mismatch: X has {X_df.shape[0]} samples but y has {len(y)}")
        return X_df, y
    
    def load_and_preprocess(self, filepath: Path, random_state: int) -> Tuple[np.ndarray, ...]:
        """Run the complete loading and preprocessing pipeline.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)`` where the X
            arrays are float32 (scaled if enabled) and y holds 0/1 labels.
        """
        df = self.load_dataset(filepath)
        target_col = self.identify_target_column(df)
        X_df, y = self._extract_features_and_target(df, target_col)
        
        benign_count = np.sum(y == 0)
        malignant_count = np.sum(y == 1)
        self.logger.info(f"   Class distribution: Benign={benign_count} ({benign_count/len(y)*100:.1f}%), "
                        f"Malignant={malignant_count} ({malignant_count/len(y)*100:.1f}%)")
        self.logger.info(f"   Features: {len(self.feature_names)}")
        
        X = X_df.values.astype(np.float32)
        X_train, X_val, X_test, y_train, y_val, y_test = self.split_data(X, y, random_state)
        X_train, X_val, X_test = self.scale_features(X_train, X_val, X_test)
        
        clear_memory()
        return X_train, X_val, X_test, y_train, y_val, y_test

# =============================================================================
# METRICS COMPUTATION 
# =============================================================================

class MetricsComputer:
    """Stateless helpers that score a binary classifier and render the scores."""

    @staticmethod
    def compute_all_metrics(y_true: np.ndarray, y_pred: np.ndarray, 
                           y_proba: np.ndarray) -> Dict[str, Any]:
        """Compute threshold-based and ranking metrics for a binary classifier.

        Args:
            y_true: Ground-truth labels.
            y_pred: Hard predictions (used for accuracy and the confusion matrix).
            y_proba: Positive-class scores (used for AUC and the ROC curve).

        Returns:
            Dict with float "accuracy", "auc", "sensitivity", "specificity",
            "precision", "npv", "f1"; the raw "confusion_matrix"; the
            "roc_curve" ``(fpr, tpr)`` pair; and int "tn", "fp", "fn", "tp".
            Any ratio whose denominator is zero is reported as 0.0.
        """
        def ratio(numerator, denominator):
            return numerator / denominator if denominator > 0 else 0.0

        acc = accuracy_score(y_true, y_pred)
        roc_auc = roc_auc_score(y_true, y_proba)
        matrix = confusion_matrix(y_true, y_pred)

        # Unpacking assumes the 2x2 layout [[tn, fp], [fn, tp]].
        tn, fp, fn, tp = matrix.ravel()

        recall = ratio(tp, tp + fn)
        true_negative_rate = ratio(tn, tn + fp)
        ppv = ratio(tp, tp + fp)
        neg_pred_value = ratio(tn, tn + fn)
        f_score = ratio(2 * (ppv * recall), ppv + recall)

        fpr, tpr, _ = roc_curve(y_true, y_proba)

        scores = {
            "accuracy": acc,
            "auc": roc_auc,
            "sensitivity": recall,
            "specificity": true_negative_rate,
            "precision": ppv,
            "npv": neg_pred_value,
            "f1": f_score,
        }
        result = {name: float(value) for name, value in scores.items()}
        result["confusion_matrix"] = matrix
        result["roc_curve"] = (fpr, tpr)
        for cell_name, count in (("tn", tn), ("fp", fp), ("fn", fn), ("tp", tp)):
            result[cell_name] = int(count)
        return result

    @staticmethod
    def format_metrics_summary(metrics: Dict[str, Any], prefix: str = "") -> str:
        """Render *metrics* (as produced by ``compute_all_metrics``) as text.

        Args:
            metrics: Mapping with the seven float scores and the integer
                "tn"/"fp"/"fn"/"tp" counts.
            prefix: Text placed before the report title, e.g. "CV ".

        Returns:
            A multi-line report that begins with a newline.
        """
        # Labels are pre-padded so the values line up in one column.
        score_rows = (
            ("Accuracy:    ", "accuracy"),
            ("AUC-ROC:     ", "auc"),
            ("Sensitivity: ", "sensitivity"),
            ("Specificity: ", "specificity"),
            ("Precision:   ", "precision"),
            ("NPV:         ", "npv"),
            ("F1 Score:    ", "f1"),
        )
        report = [f"\n {prefix}MODEL PERFORMANCE METRICS:"]
        report.extend(f"   {label}{metrics[key]:.4f}" for label, key in score_rows)
        report.append("\n CONFUSION MATRIX:")
        report.append(f"   TN: {metrics['tn']:4d}  FP: {metrics['fp']:4d}")
        report.append(f"   FN: {metrics['fn']:4d}  TP: {metrics['tp']:4d}")
        return "\n".join(report)

# =============================================================================
# GENETIC ALGORITHM FOR HYPERPARAMETER OPTIMIZATION
# =============================================================================

class GeneticOptimizer:
    """Genetic algorithm optimizer for TabNet hyperparameters using DEAP."""
    
    def __init__(self, config: TabNetConfig, data_config: DataConfig,
                 X_train: np.ndarray, X_val: np.ndarray,
                 y_train: np.ndarray, y_val: np.ndarray, logger: Logger):
        """Keep the settings and data splits, pick a device, then build the DEAP toolbox.

        Individuals are trained on ``X_train``/``y_train`` and scored on
        ``X_val``/``y_val``. CUDA is used when available, otherwise CPU.
        """
        self.config, self.data_config = config, data_config
        self.X_train, self.y_train = X_train, y_train
        self.X_val, self.y_val = X_val, y_val
        self.logger = logger
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Evolution-tracking containers; nothing is added to them here.
        self.generation_stats = []
        self.best_individuals = []
        # Fitness memo keyed by the sorted hyperparameter tuple.
        self.evaluation_cache = {}

        # Must run last: it reads self.config and self.logger.
        self._setup_deap()
    
    def _setup_deap(self):
        """Create the DEAP fitness/individual types and register all operators.

        ``creator`` classes are module-global, so existing ``FitnessMax`` /
        ``Individual`` definitions are removed first and then recreated.
        Individuals are lists of 10 genes in the fixed order n_d, n_a,
        n_steps, gamma, lambda_sparse, lr, momentum, batch_size,
        n_independent, n_shared.
        """
        for stale in ("FitnessMax", "Individual"):
            if hasattr(creator, stale):
                delattr(creator, stale)

        creator.create("FitnessMax", base.Fitness, weights=(1.0,))  # single objective, maximized
        creator.create("Individual", list, fitness=creator.FitnessMax)

        toolbox = base.Toolbox()
        self.toolbox = toolbox

        bounds = HYPERPARAMETER_BOUNDS
        # (gene name, sampler, sampler args) in gene order.
        gene_specs = (
            ("n_d", random.randint, bounds['n_d']),
            ("n_a", random.randint, bounds['n_a']),
            ("n_steps", random.randint, bounds['n_steps']),
            ("gamma", random.uniform, bounds['gamma']),
            ("lambda_sparse", random.uniform, bounds['lambda_sparse']),
            ("lr", random.uniform, bounds['lr']),
            ("momentum", random.uniform, bounds['momentum']),
            ("batch_size", self._generate_batch_size, ()),
            ("n_independent", random.randint, bounds['n_independent']),
            ("n_shared", random.randint, bounds['n_shared']),
        )
        for gene, sampler, sampler_args in gene_specs:
            toolbox.register(f"attr_{gene}", sampler, *sampler_args)

        gene_generators = tuple(getattr(toolbox, f"attr_{gene}") for gene, _, _ in gene_specs)
        toolbox.register("individual", tools.initCycle, creator.Individual,
                         gene_generators, n=1)
        toolbox.register("population", tools.initRepeat, list, toolbox.individual)

        toolbox.register("evaluate", self._evaluate_individual)
        toolbox.register("mate", self._crossover)
        toolbox.register("mutate", self._mutate)
        toolbox.register("select", tools.selTournament, tournsize=self.config.ga_tournament_size)

        self.logger.info("DEAP genetic algorithm framework initialized")
    
    def _generate_batch_size(self) -> int:
        """Draw a batch size uniformly from the multiples of 64 within the bounds.

        Candidates are all multiples of 64 inside the inclusive range
        ``HYPERPARAMETER_BOUNDS['batch_size']``; for (128, 1024) that is
        128, 192, ..., 1024.

        Raises:
            ValueError: If the configured range contains no multiple of 64.
        """
        min_bs, max_bs = HYPERPARAMETER_BOUNDS['batch_size']
        step = 64
        # Round the lower bound *up* (ceil division) so e.g. a minimum of 100
        # cannot produce 64, which would lie below the configured range.
        first = -(-min_bs // step)
        last = max_bs // step
        multiples = [step * k for k in range(first, last + 1)]
        if not multiples:
            raise ValueError(
                f"batch_size bounds {HYPERPARAMETER_BOUNDS['batch_size']} contain no multiple of {step}"
            )
        return random.choice(multiples)
    
    def _individual_to_hyperparams(self, individual: List) -> Dict:
        """Decode a DEAP individual into a named hyperparameter dict.

        Genes are read positionally in the order n_d, n_a, n_steps, gamma,
        lambda_sparse, lr, momentum, batch_size, n_independent, n_shared;
        integer genes are cast with ``int`` and continuous ones with ``float``.
        Genes past the tenth are ignored; fewer than ten raises ``IndexError``.
        """
        gene_layout = (
            ('n_d', int),
            ('n_a', int),
            ('n_steps', int),
            ('gamma', float),
            ('lambda_sparse', float),
            ('lr', float),
            ('momentum', float),
            ('batch_size', int),
            ('n_independent', int),
            ('n_shared', int),
        )
        return {
            name: cast(individual[position])
            for position, (name, cast) in enumerate(gene_layout)
        }
    
    def _hyperparams_to_tuple(self, hyperparams: Dict) -> Tuple:
        """Return a hashable cache key for *hyperparams*.

        The ``(name, value)`` pairs are sorted (names are unique, so by name),
        making the key independent of dict insertion order.
        """
        pairs = list(hyperparams.items())
        pairs.sort()
        return tuple(pairs)
    
    def _evaluate_individual(self, individual: List) -> Tuple[float]:
        """Train TabNet with the individual's hyperparameters and score it on validation.

        Fitness is a weighted sum of validation metrics::

            0.30*AUC + 0.25*F1 + 0.20*accuracy + 0.15*sensitivity + 0.10*specificity

        Successful scores are memoized in ``self.evaluation_cache`` keyed by the
        decoded hyperparameters. Any exception during training or scoring is
        logged and scored ``(0.0,)``; failures are not cached.

        Returns:
            One-element fitness tuple, matching the single-objective fitness.
        """
        hyperparams = self._individual_to_hyperparams(individual)
        
        # Identical hyperparameter sets are only trained once.
        cache_key = self._hyperparams_to_tuple(hyperparams)
        if cache_key in self.evaluation_cache:
            return self.evaluation_cache[cache_key]
        
        model = None
        try:
            # Virtual batch size: 1/8 of the batch, capped at 128.
            virtual_batch_size = min(hyperparams['batch_size'] // 8, 128)
            
            model = TabNetClassifier(
                n_d=hyperparams['n_d'],
                n_a=hyperparams['n_a'],
                n_steps=hyperparams['n_steps'],
                gamma=hyperparams['gamma'],
                lambda_sparse=hyperparams['lambda_sparse'],
                momentum=hyperparams['momentum'],
                n_independent=hyperparams['n_independent'],
                n_shared=hyperparams['n_shared'],
                optimizer_fn=torch.optim.Adam,
                optimizer_params={"lr": hyperparams['lr'], "weight_decay": 1e-5},
                scheduler_fn=torch.optim.lr_scheduler.StepLR,
                scheduler_params={"step_size": 10, "gamma": 0.9},
                device_name=str(self.device),
                verbose=0,
                seed=self.config.random_seed
            )
            
            # Shortened training budget for the GA search.
            model.fit(
                self.X_train, self.y_train,
                eval_set=[(self.X_val, self.y_val)],
                eval_name=["val"],
                eval_metric=["auc"],
                max_epochs=min(30, self.config.max_epochs),
                patience=8,
                batch_size=hyperparams['batch_size'],
                virtual_batch_size=virtual_batch_size,
                drop_last=False,
                num_workers=0,
            )
            
            y_pred = model.predict(self.X_val)
            y_proba = model.predict_proba(self.X_val)[:, 1]
            
            accuracy = accuracy_score(self.y_val, y_pred)
            auc = roc_auc_score(self.y_val, y_proba)
            
            cm = confusion_matrix(self.y_val, y_pred)
            tn, fp, fn, tp = cm.ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            f1 = 2 * (precision * sensitivity) / (precision + sensitivity) if (precision + sensitivity) > 0 else 0.0
            
            # Weighted fitness function emphasizing AUC and F1
            fitness = (0.30 * auc + 0.25 * f1 + 0.20 * accuracy + 
                      0.15 * sensitivity + 0.10 * specificity)
            
            self.evaluation_cache[cache_key] = (fitness,)
            return (fitness,)
            
        except Exception as e:
            self.logger.warning(f"   Evaluation failed for individual: {e}")
            return (0.0,)  # Poor fitness for failed evaluations
        finally:
            # Release the model on every path, including failures such as a
            # CUDA out-of-memory error, so the next evaluation starts clean.
            del model
            clear_memory()
    
    def _crossover(self, ind1: List, ind2: List) -> Tuple[List, List]:
        """Recombine two individuals in place and return them.

        Each gene position is recombined with probability 0.5: discrete genes
        (n_d, n_a, n_steps, batch_size, n_independent, n_shared) are swapped,
        continuous genes (gamma, lambda_sparse, lr, momentum) are blended as
        ``alpha*a + (1-alpha)*b`` and ``alpha*b + (1-alpha)*a``.

        The arguments are modified in place, following DEAP's operator
        convention, so they keep their ``creator.Individual`` type and
        ``fitness`` attribute; clone them first if the parents must survive.
        If any gene was recombined, their now-stale fitness values are deleted.
        """
        # Copying via ``ind[:]`` would yield plain lists without ``fitness``,
        # which fail as soon as ``ind.fitness.values`` is assigned.
        discrete_genes = {0, 1, 2, 7, 8, 9}
        alpha = 0.5  # Blend factor; at 0.5 both children receive the midpoint.
        changed = False

        for i in range(len(ind1)):
            if random.random() >= 0.5:
                continue
            a, b = ind1[i], ind2[i]
            if i in discrete_genes:
                ind1[i], ind2[i] = b, a
            else:
                ind1[i] = alpha * a + (1 - alpha) * b
                ind2[i] = alpha * b + (1 - alpha) * a
            changed = True

        if changed:
            for ind in (ind1, ind2):
                if hasattr(ind, "fitness"):
                    del ind.fitness.values

        return ind1, ind2
    
    def _mutate(self, individual: List) -> Tuple[List]:
        """Resample genes in place and return ``(individual,)``.

        Each gene is independently redrawn from its full range in
        ``HYPERPARAMETER_BOUNDS`` with probability ``config.ga_mutation_rate``
        (batch size via ``_generate_batch_size``). The individual is modified
        in place, DEAP-style, so it keeps its ``creator.Individual`` type and
        ``fitness`` attribute; if any gene was redrawn, the stale fitness
        values are deleted.
        """
        bounds = HYPERPARAMETER_BOUNDS
        # One zero-argument sampler per gene position, in gene order.
        samplers = (
            lambda: random.randint(*bounds['n_d']),
            lambda: random.randint(*bounds['n_a']),
            lambda: random.randint(*bounds['n_steps']),
            lambda: random.uniform(*bounds['gamma']),
            lambda: random.uniform(*bounds['lambda_sparse']),
            lambda: random.uniform(*bounds['lr']),
            lambda: random.uniform(*bounds['momentum']),
            self._generate_batch_size,
            lambda: random.randint(*bounds['n_independent']),
            lambda: random.randint(*bounds['n_shared']),
        )

        changed = False
        for i in range(len(individual)):
            # Draw the mutation coin for every position (even unknown ones) so
            # the random stream is consumed one draw per gene.
            if random.random() < self.config.ga_mutation_rate and i < len(samplers):
                individual[i] = samplers[i]()
                changed = True

        if changed and hasattr(individual, "fitness"):
            del individual.fitness.values

        return (individual,)
    
    def optimize(self) -> Dict:
        """Run genetic algorithm optimization.

        Evolves ``self.config.ga_population_size`` individuals for
        ``self.config.ga_generations`` generations using the toolbox's
        ``select`` / ``mate`` / ``mutate`` / ``evaluate`` operators.

        Each genome is evaluated once. The initial population is evaluated up
        front; after that, only offspring whose fitness was invalidated by
        crossover or mutation are re-evaluated. The hall of fame is refreshed
        after every evaluation round, so the final generation's offspring are
        also candidates for the returned optimum.

        Returns:
            Dict with ``best_hyperparams``, ``best_fitness``,
            ``best_individual``, ``hall_of_fame`` (hyperparameter dicts of the
            top individuals), ``generation_stats`` (one avg/std/min/max record
            per generation) and ``evolution_history`` (list of
            ``(generation, genes, fitness)`` for each generation's best).
        """
        with self.logger.log_section("🧬 GENETIC ALGORITHM OPTIMIZATION"):
            self.logger.info(f"Population Size: {self.config.ga_population_size}")
            self.logger.info(f"Generations: {self.config.ga_generations}")
            self.logger.info(f"Mutation Rate: {self.config.ga_mutation_rate}")
            self.logger.info(f"Crossover Rate: {self.config.ga_crossover_rate}")

            # Create initial population
            population = self.toolbox.population(n=self.config.ga_population_size)

            # Per-generation fitness statistics
            stats = tools.Statistics(lambda ind: ind.fitness.values)
            stats.register("avg", np.mean)
            stats.register("std", np.std)
            stats.register("min", np.min)
            stats.register("max", np.max)

            # Best 5 individuals seen across the whole run
            hof = tools.HallOfFame(5)

            self.logger.info("\n🔄 Starting evolution...")

            # Evaluate the initial population once; later generations reuse the
            # cached fitness of individuals whose genes did not change.
            self._evaluate_invalid_individuals(population)
            hof.update(population)

            for gen in range(self.config.ga_generations):
                self.logger.info(f"\n{'='*60}")
                self.logger.info(f"  GENERATION {gen + 1}/{self.config.ga_generations}")
                self.logger.info(f"{'='*60}")

                # Record statistics of the (fully evaluated) current population
                record = stats.compile(population)
                self.generation_stats.append(record)

                self.logger.info(f"   Fitness - Max: {record['max']:.6f}, "
                               f"Avg: {record['avg']:.6f}, "
                               f"Min: {record['min']:.6f}, "
                               f"Std: {record['std']:.6f}")

                # Store best individual of this generation
                best_ind = tools.selBest(population, 1)[0]
                self.best_individuals.append((gen, best_ind[:], best_ind.fitness.values[0]))

                # Selection; clone so the parents in `population` stay untouched
                offspring = self.toolbox.select(population, len(population))
                offspring = list(map(self.toolbox.clone, offspring))

                # Crossover. The operator's result is written back into the
                # individuals so it takes effect whether the operator modifies
                # its arguments in place or returns new gene lists.
                for child1, child2 in zip(offspring[::2], offspring[1::2]):
                    if random.random() < self.config.ga_crossover_rate:
                        new1, new2 = self.toolbox.mate(child1, child2)
                        child1[:] = new1
                        child2[:] = new2
                        del child1.fitness.values
                        del child2.fitness.values

                # Mutation (same write-back as for crossover)
                for mutant in offspring:
                    if random.random() < self.config.ga_mutation_rate:
                        mutant[:] = self.toolbox.mutate(mutant)[0]
                        del mutant.fitness.values

                # Evaluate only the offspring whose genes changed
                self._evaluate_invalid_individuals(offspring)

                # Replace population and record any new elites
                population[:] = offspring
                hof.update(population)

                # Clear some memory periodically
                if (gen + 1) % 5 == 0:
                    clear_memory()

            # Get best hyperparameters
            best_individual = hof[0]
            best_hyperparams = self._individual_to_hyperparams(best_individual)
            best_fitness = best_individual.fitness.values[0]

            self.logger.info("\n" + "="*70)
            self.logger.info("   OPTIMIZATION COMPLETE")
            self.logger.info("="*70)
            self.logger.info(f"\n Best Fitness Achieved: {best_fitness:.6f}")
            self.logger.info(f"\n Optimal Hyperparameters:")
            for param, value in best_hyperparams.items():
                self.logger.info(f"   {param:15s}: {value}")

            return {
                'best_hyperparams': best_hyperparams,
                'best_fitness': best_fitness,
                'best_individual': best_individual,
                'hall_of_fame': [self._individual_to_hyperparams(ind) for ind in hof],
                'generation_stats': self.generation_stats,
                'evolution_history': self.best_individuals
            }

    def _evaluate_invalid_individuals(self, individuals: List) -> int:
        """Evaluate every individual whose fitness is not currently valid.

        Args:
            individuals: DEAP individuals. Only those with
                ``fitness.valid == False`` are passed to
                ``self.toolbox.evaluate``.

        Returns:
            Number of individuals that were evaluated.
        """
        invalid = [ind for ind in individuals if not ind.fitness.valid]
        for ind, fit in zip(invalid, map(self.toolbox.evaluate, invalid)):
            ind.fitness.values = fit
        return len(invalid)
    
    def plot_evolution_progress(self, save_dir: Path):
        """Plot per-generation fitness statistics and save them as a PNG.

        Draws a 2x2 grid (max fitness, mean ± std band, min/avg/max range,
        std dev) from ``self.generation_stats``. The figure is written to
        ``save_dir / ga_evolution_<timestamp>.png``. It is then shown
        non-blocking when ``self.config.display_plots`` is true (or absent);
        otherwise it is closed. Nothing is plotted when no generation has been
        recorded. Failures are logged rather than raised.

        Args:
            save_dir: Output directory; created if it does not exist.
        """
        if not self.generation_stats:
            return

        fig = None
        try:
            fig, axes = plt.subplots(2, 2, figsize=(16, 12))
            fig.suptitle('Genetic Algorithm Evolution Progress', fontsize=16, fontweight='bold')

            generations = list(range(1, len(self.generation_stats) + 1))
            max_fitness = np.array([stat['max'] for stat in self.generation_stats])
            avg_fitness = np.array([stat['avg'] for stat in self.generation_stats])
            std_fitness = np.array([stat['std'] for stat in self.generation_stats])
            min_fitness = np.array([stat['min'] for stat in self.generation_stats])

            # Max fitness
            ax = axes[0, 0]
            ax.plot(generations, max_fitness, 'o-', linewidth=2, markersize=6, color='green', label='Max Fitness')
            ax.set_xlabel('Generation', fontsize=11, fontweight='bold')
            ax.set_ylabel('Maximum Fitness', fontsize=11, fontweight='bold')
            ax.set_title('Best Fitness per Generation', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Average fitness with a ±1 std band
            ax = axes[0, 1]
            ax.plot(generations, avg_fitness, 'o-', linewidth=2, markersize=6, color='blue', label='Avg Fitness')
            ax.fill_between(generations,
                            avg_fitness - std_fitness,
                            avg_fitness + std_fitness,
                            alpha=0.2, color='blue')
            ax.set_xlabel('Generation', fontsize=11, fontweight='bold')
            ax.set_ylabel('Average Fitness', fontsize=11, fontweight='bold')
            ax.set_title('Average Fitness with ±1 Std Dev', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Fitness range
            ax = axes[1, 0]
            ax.fill_between(generations, min_fitness, max_fitness, alpha=0.3, color='purple')
            ax.plot(generations, max_fitness, 'o-', linewidth=2, markersize=4, color='green', label='Max')
            ax.plot(generations, avg_fitness, 's-', linewidth=2, markersize=4, color='blue', label='Avg')
            ax.plot(generations, min_fitness, '^-', linewidth=2, markersize=4, color='red', label='Min')
            ax.set_xlabel('Generation', fontsize=11, fontweight='bold')
            ax.set_ylabel('Fitness', fontsize=11, fontweight='bold')
            ax.set_title('Fitness Range Evolution', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Standard deviation (population diversity)
            ax = axes[1, 1]
            ax.plot(generations, std_fitness, 'o-', linewidth=2, markersize=6, color='orange')
            ax.set_xlabel('Generation', fontsize=11, fontweight='bold')
            ax.set_ylabel('Standard Deviation', fontsize=11, fontweight='bold')
            ax.set_title('Population Diversity (Std Dev of Fitness)', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)

            fig.tight_layout()

            # Save plot (make sure the target directory exists first)
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = save_dir / f"ga_evolution_{timestamp}.png"
            fig.savefig(str(save_path), dpi=100, bbox_inches='tight')
            self.logger.info(f" Evolution plot saved: {save_path.name}")

            # Respect the display setting; close the figure if it will never be shown.
            if getattr(self.config, 'display_plots', True):
                plt.show(block=False)
            else:
                plt.close(fig)

        except Exception as e:
            self.logger.error(f"Failed to create evolution plot: {e}")
            if fig is not None:
                plt.close(fig)

# =============================================================================
# MATPLOTLIB WINDOW PLOT (keeping original visualization code)
# =============================================================================

class WindowPlotManager:
    """Manages creation of matplotlib window-based plots including CV results.

    Every figure built by a ``create_*`` method is appended to
    ``self.figures`` so that :meth:`show_all_plots` can open them together.
    """

    def __init__(self, config: TabNetConfig, logger: Logger):
        self.config = config
        self.logger = logger
        self.figures = []

    @staticmethod
    def _as_history_dict(history: Any) -> Dict:
        """Return ``history`` as a plain ``{metric_name: per-epoch values}`` dict.

        pytorch-tabnet exposes ``model.history`` as a ``History`` callback
        object that keeps the per-epoch lists in a ``.history`` dict. That
        object has no ``.get`` and does not support ``key in history``, so it
        is unwrapped here. Plain dicts are returned unchanged; anything else
        (including ``None``) becomes ``{}``.
        """
        inner = getattr(history, 'history', history)
        return inner if isinstance(inner, dict) else {}

    @staticmethod
    def _set_window_title(fig, title: str) -> None:
        """Set the window title when the backend provides a window manager."""
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)

    @staticmethod
    def _plot_series(ax, history: Dict, key: str, style: str, label: str, **kwargs) -> None:
        """Plot ``history[key]`` against 1-based epoch numbers if the key exists.

        Each series uses its own length, so series of different lengths (or a
        missing ``'loss'`` entry) cannot cause an x/y length mismatch.
        """
        values = history.get(key)
        if values is None:
            return
        epochs = list(range(1, len(values) + 1))
        ax.plot(epochs, values, style, label=label, linewidth=2, markersize=4, **kwargs)

    @staticmethod
    def _finish_history_axis(ax, ylabel: str, title: str) -> None:
        """Apply the shared labels/title/grid, adding a legend only if something was plotted."""
        ax.set_xlabel('Epoch', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=13, fontweight='bold')
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def create_training_history_plot(self, history: Dict) -> plt.Figure:
        """Create a 2x2 training-history figure (loss, accuracy, AUC, learning rate).

        Args:
            history: Per-epoch metrics, either a plain dict or an object that
                stores such a dict in a ``.history`` attribute. Recognised
                keys: ``loss``, ``val_loss``, ``accuracy``, ``val_accuracy``,
                ``auc``, ``val_auc``, ``lr``. Missing keys are skipped.

        Returns:
            The figure, or ``None`` when the history is empty or plotting failed.
        """
        history = self._as_history_dict(history)
        if not history:
            return None

        fig = None
        try:
            fig = plt.figure(figsize=(16, 10))
            self._set_window_title(fig, 'TabNet Training History')
            gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

            # Loss plot
            ax1 = fig.add_subplot(gs[0, 0])
            self._plot_series(ax1, history, 'loss', 'o-', 'Training Loss')
            self._plot_series(ax1, history, 'val_loss', 's-', 'Validation Loss')
            self._finish_history_axis(ax1, 'Loss', 'Training & Validation Loss')

            # Accuracy plot
            ax2 = fig.add_subplot(gs[0, 1])
            self._plot_series(ax2, history, 'accuracy', 'o-', 'Training Accuracy')
            self._plot_series(ax2, history, 'val_accuracy', 's-', 'Validation Accuracy')
            self._finish_history_axis(ax2, 'Accuracy', 'Training & Validation Accuracy')

            # AUC plot
            ax3 = fig.add_subplot(gs[1, 0])
            self._plot_series(ax3, history, 'auc', 'o-', 'Training AUC')
            self._plot_series(ax3, history, 'val_auc', 's-', 'Validation AUC')
            self._finish_history_axis(ax3, 'AUC', 'Training & Validation AUC')

            # Learning rate plot
            ax4 = fig.add_subplot(gs[1, 1])
            self._plot_series(ax4, history, 'lr', 'o-', 'Learning Rate', color='purple')
            self._finish_history_axis(ax4, 'Learning Rate', 'Learning Rate Schedule')

            plt.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create training history plot: {e}")
            if fig is not None:
                plt.close(fig)
            return None

    def create_metrics_bar_chart(self, metrics: Dict) -> plt.Figure:
        """Create a bar chart of the main classification metrics.

        Args:
            metrics: Dict that may contain ``accuracy``, ``auc``,
                ``sensitivity``, ``specificity``, ``precision`` and ``f1``;
                missing entries are drawn as 0.

        Returns:
            The figure, or ``None`` if plotting failed.
        """
        fig = None
        try:
            categories = ['Accuracy', 'AUC', 'Sensitivity', 'Specificity', 'Precision', 'F1 Score']
            values = [
                metrics.get('accuracy', 0),
                metrics.get('auc', 0),
                metrics.get('sensitivity', 0),
                metrics.get('specificity', 0),
                metrics.get('precision', 0),
                metrics.get('f1', 0)
            ]

            fig, ax = plt.subplots(figsize=(12, 8))
            self._set_window_title(fig, 'Performance Metrics')

            # Colour each bar by its score (red = low, green = high)
            colors = plt.cm.RdYlGn(np.array(values))
            bars = ax.bar(categories, values, color=colors, edgecolor='black', linewidth=1.5, alpha=0.8)

            ax.set_ylabel('Score', fontsize=13, fontweight='bold')
            ax.set_xlabel('Metric', fontsize=13, fontweight='bold')
            ax.set_title('Model Performance Metrics', fontsize=15, fontweight='bold', pad=20)
            ax.set_ylim(0, 1.05)
            ax.grid(True, alpha=0.3, axis='y', linestyle='--')
            ax.set_axisbelow(True)

            for bar, value in zip(bars, values):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                       f'{value:.4f}',
                       ha='center', va='bottom', fontsize=11, fontweight='bold')

            ax.axhline(y=0.8, color='gray', linestyle='--', linewidth=1, alpha=0.5, label='0.8 threshold')
            ax.axhline(y=0.9, color='gray', linestyle='--', linewidth=1, alpha=0.5, label='0.9 threshold')

            plt.setp(ax.get_xticklabels(), rotation=0, ha="center", fontsize=11)
            ax.legend(loc='lower right', fontsize=10)

            plt.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create metrics bar chart: {e}")
            if fig is not None:
                plt.close(fig)
            return None

    def save_and_display_plot(self, fig: plt.Figure, filepath: Path, plot_name: str) -> None:
        """Save ``fig`` as ``<filepath>.png`` when ``config.save_plots`` is set.

        Display is deferred to :meth:`show_all_plots`; this only logs that the
        plot is ready. Errors are logged as warnings, not raised.
        """
        if fig is None:
            return

        try:
            if self.config.save_plots:
                png_path = filepath.with_suffix('.png')
                fig.savefig(str(png_path), dpi=self.config.plot_dpi, bbox_inches='tight')
                self.logger.debug(f"   Saved: {png_path.name}")

            if self.config.display_plots:
                self.logger.debug(f"   Prepared: {plot_name}")

        except Exception as e:
            self.logger.warning(f"Failed to save/display plot {filepath.name}: {e}")

    def show_all_plots(self):
        """Display all created plots in separate windows."""
        if self.config.display_plots and self.figures:
            self.logger.info(f"\n Opening {len(self.figures)} plot windows...")
            self.logger.info("   Close plot windows to continue...")
            plt.show(block=self.config.block_on_plot)

# =============================================================================
# MAIN TRAINER CLASS WITH GA OPTIMIZATION
# =============================================================================

class TabNetTrainer:
    """Comprehensive TabNet trainer with GA optimization and cross-validation.

    Orchestrates data loading, optional genetic-algorithm hyperparameter
    search, final training, test-set evaluation, plotting, and persistence of
    results (JSON summary, pickled GA results, model weights).
    """

    def __init__(self, model_config: TabNetConfig, data_config: DataConfig):
        self.config = model_config
        self.data_config = data_config
        self.logger = Logger("TabNetTrainer")
        self.device = None
        self.model = None
        self.feature_names = []
        self.results = {}
        self.cv_results = None
        self.ga_results = None

        self.data_processor = DataProcessor(data_config, self.logger)
        self.metrics_computer = MetricsComputer()
        self.plot_manager = WindowPlotManager(model_config, self.logger)

        self._setup_environment()

    def _setup_environment(self):
        """Seed RNGs and select the compute device.

        Uses CUDA when available (cuDNN determinism follows
        ``config.deterministic``). Otherwise uses the CPU and caps
        ``config.batch_size`` at 64 and ``config.virtual_batch_size`` at 32.
        Note that this mutates the shared config object.
        """
        set_random_seeds(self.config.random_seed)
        self.logger.info(f" Random seeds set to {self.config.random_seed}")

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.backends.cudnn.deterministic = self.config.deterministic
            torch.backends.cudnn.benchmark = not self.config.deterministic
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
            self.logger.info(f" Using GPU: {gpu_name} ({gpu_memory:.1f} GB)")
        else:
            self.device = torch.device("cpu")
            self.logger.info(" Using CPU")

            self.config.batch_size = min(self.config.batch_size, 64)
            self.config.virtual_batch_size = min(self.config.virtual_batch_size, 32)
            self.logger.info(f"   Adjusted batch sizes for CPU: batch={self.config.batch_size}, virtual={self.config.virtual_batch_size}")

    def _create_model(self, hyperparams: Optional[Dict] = None) -> TabNetClassifier:
        """Create and configure a TabNet classifier.

        Args:
            hyperparams: Dict with ``n_d``, ``n_a``, ``n_steps``, ``gamma``,
                ``lambda_sparse``, ``lr``, ``momentum``, ``batch_size``,
                ``n_independent`` and ``n_shared``. When ``None``, the values
                from ``self.config`` are used.

        Raises:
            RuntimeError: If pytorch-tabnet is not installed.
        """
        if not _HAS_TABNET:
            raise RuntimeError("pytorch-tabnet is not available. Install with: pip install pytorch-tabnet")

        # Use provided hyperparams or config defaults
        if hyperparams is None:
            hyperparams = {
                'n_d': self.config.n_d,
                'n_a': self.config.n_a,
                'n_steps': self.config.n_steps,
                'gamma': self.config.gamma,
                'lambda_sparse': self.config.lambda_sparse,
                'lr': self.config.lr,
                'momentum': self.config.momentum,
                'batch_size': self.config.batch_size,
                'n_independent': self.config.n_independent,
                'n_shared': self.config.n_shared
            }

        model = TabNetClassifier(
            n_d=hyperparams['n_d'],
            n_a=hyperparams['n_a'],
            n_steps=hyperparams['n_steps'],
            gamma=hyperparams['gamma'],
            lambda_sparse=hyperparams['lambda_sparse'],
            momentum=hyperparams['momentum'],
            n_independent=hyperparams['n_independent'],
            n_shared=hyperparams['n_shared'],
            optimizer_fn=torch.optim.Adam,
            optimizer_params={"lr": hyperparams['lr'], "weight_decay": 1e-5},
            scheduler_fn=torch.optim.lr_scheduler.StepLR,
            scheduler_params={"step_size": 10, "gamma": 0.9},
            device_name=str(self.device),
            verbose=0,
            seed=self.config.random_seed
        )

        return model

    @staticmethod
    def _history_to_dict(history: Any) -> Dict[str, Any]:
        """Convert a training history into a plain ``{metric: per-epoch list}`` dict.

        pytorch-tabnet's ``model.history`` is a ``History`` callback object
        whose data lives in its ``.history`` dict; that object is not a dict
        and breaks dict-based consumers such as the history plot. Plain dicts
        are shallow-copied. ``None`` or unrecognised objects yield ``{}``.
        """
        inner = getattr(history, 'history', history)
        return dict(inner) if isinstance(inner, dict) else {}

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dump`` fallback for NumPy scalars/arrays, paths and enums.

        Raises:
            TypeError: For any other non-serializable object, like ``json`` does.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if hasattr(obj, 'tolist'):  # np.ndarray and array-likes such as pandas Index
            return obj.tolist()
        if isinstance(obj, Path):
            return str(obj)
        if isinstance(obj, Enum):
            return obj.value
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def train(self, X_train: np.ndarray, X_val: np.ndarray,
             y_train: np.ndarray, y_val: np.ndarray,
             hyperparams: Optional[Dict] = None) -> Dict[str, Any]:
        """Train the TabNet model, optionally with GA-optimized hyperparameters.

        Args:
            X_train, y_train: Training data.
            X_val, y_val: Validation data used for early stopping.
            hyperparams: Optional hyperparameter dict (see ``_create_model``).
                When given, its ``batch_size`` is used and the virtual batch
                size is derived from it.

        Returns:
            Per-epoch training history as a plain dict (empty if unavailable).

        Raises:
            Any exception raised by ``TabNetClassifier.fit`` (logged first).
        """
        with self.logger.log_section("MODEL TRAINING"):

            # Use optimized hyperparameters if available
            if hyperparams is not None:
                self.logger.info("Using optimized hyperparameters from genetic algorithm")
                for param, value in hyperparams.items():
                    self.logger.info(f"   {param:15s}: {value}")
                self.model = self._create_model(hyperparams)
                batch_size = hyperparams['batch_size']
                # 1/8 of the batch, capped at 128. Floored at 1 so batch sizes
                # below 8 cannot yield a zero virtual batch size.
                virtual_batch_size = max(1, min(batch_size // 8, 128))
            else:
                self.logger.info("Using default configuration hyperparameters")
                self.model = self._create_model()
                batch_size = self.config.batch_size
                virtual_batch_size = self.config.virtual_batch_size

            self.logger.info(f"\n Training Configuration:")
            self.logger.info(f"   Max epochs: {self.config.max_epochs}")
            self.logger.info(f"   Patience: {self.config.patience}")
            self.logger.info(f"   Batch size: {batch_size}")
            self.logger.info(f"   Virtual batch size: {virtual_batch_size}")

            try:
                with timer(self.logger, "Model training"):
                    self.model.fit(
                        X_train, y_train,
                        eval_set=[(X_val, y_val)],
                        eval_name=["val"],
                        eval_metric=["auc", "accuracy"],
                        max_epochs=self.config.max_epochs,
                        patience=self.config.patience,
                        batch_size=batch_size,
                        virtual_batch_size=virtual_batch_size,
                        drop_last=False,
                        num_workers=0,
                    )

                self.logger.info(f"✓ Training completed at epoch {self.model.best_epoch}")
                return self._history_to_dict(getattr(self.model, 'history', None))

            except Exception as e:
                self.logger.error(f"Training failed: {e}")
                raise

    def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, Any]:
        """Evaluate the trained model on the test set and print a classification report.

        Raises:
            RuntimeError: If called before ``train``.
        """
        if self.model is None:
            raise RuntimeError("Model must be trained before evaluation")

        with self.logger.log_section(" MODEL EVALUATION ON TEST SET"):
            with timer(self.logger, "Model evaluation"):
                y_pred = self.model.predict(X_test)
                # Probability of the positive class (column 1)
                y_proba = self.model.predict_proba(X_test)[:, 1]
                metrics = self.metrics_computer.compute_all_metrics(y_test, y_pred, y_proba)

            self.logger.info(self.metrics_computer.format_metrics_summary(metrics, "TEST SET "))

            self.logger.info("\n CLASSIFICATION REPORT:")
            print(classification_report(y_test, y_pred,
                                      target_names=['Benign', 'Malignant'],
                                      digits=4))

            return metrics

    def create_visualizations(self, metrics: Dict, history: Dict,
                            X_train: np.ndarray, y_train: np.ndarray,
                            save_dir: Path) -> None:
        """Create, save and show the training-history and metrics plots."""
        with self.logger.log_section("GENERATING VISUALIZATIONS"):
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Training History
            history_fig = self.plot_manager.create_training_history_plot(history)
            if history_fig:
                self.plot_manager.save_and_display_plot(
                    history_fig,
                    save_dir / f"training_history_{timestamp}",
                    "Training History"
                )

            # Metrics Bar Chart
            metrics_fig = self.plot_manager.create_metrics_bar_chart(metrics)
            if metrics_fig:
                self.plot_manager.save_and_display_plot(
                    metrics_fig,
                    save_dir / f"metrics_barchart_{timestamp}",
                    "Metrics Bar Chart"
                )

            # Show all plots
            self.plot_manager.show_all_plots()

            self.logger.info("✓ All visualizations generated and displayed")

    def save_experiment(self, metrics: Dict, history: Dict, save_dir: Path) -> None:
        """Save complete experiment results including GA results.

        Writes ``experiment_results_<ts>.json``. When GA results exist, also
        writes ``ga_optimization_<ts>.pkl``; when a model exists, also writes
        ``tabnet_model_<ts>.pth`` (network ``state_dict``). NumPy values,
        paths and enums in the config/metrics are converted to JSON-friendly
        types.
        """
        save_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model = self.model

        experiment_data = {
            'timestamp': datetime.now().isoformat(),
            'config': {
                'model': asdict(self.config),
                'data': asdict(self.data_config)
            },
            'metrics': {k: v for k, v in metrics.items()
                       if k not in ['confusion_matrix', 'roc_curve']},
            # np.asarray accepts both an ndarray and a nested list
            'confusion_matrix': np.asarray(metrics.get('confusion_matrix', [])).tolist(),
            'feature_names': self.feature_names,
            'training_info': {
                'best_epoch': getattr(model, 'best_epoch', -1) if model is not None else -1,
                'total_params': sum(p.numel() for p in model.network.parameters()) if model is not None else 0,
                'trainable_params': sum(p.numel() for p in model.network.parameters()
                                      if p.requires_grad) if model is not None else 0
            }
        }

        # Add GA results if available
        if self.ga_results is not None:
            experiment_data['genetic_algorithm'] = {
                'best_hyperparams': self.ga_results['best_hyperparams'],
                'best_fitness': self.ga_results['best_fitness'],
                'hall_of_fame': self.ga_results['hall_of_fame'],
                'final_generation_stats': self.ga_results['generation_stats'][-1] if self.ga_results['generation_stats'] else {}
            }

        results_path = save_dir / f"experiment_results_{timestamp}.json"
        with open(results_path, 'w') as f:
            json.dump(experiment_data, f, indent=2, default=self._json_default)

        # Save GA results separately if available
        if self.ga_results is not None:
            ga_path = save_dir / f"ga_optimization_{timestamp}.pkl"
            with open(ga_path, 'wb') as f:
                pickle.dump(self.ga_results, f)
            self.logger.info(f"GA results saved: {ga_path.name}")

        if model is not None:
            model_path = save_dir / f"tabnet_model_{timestamp}.pth"
            torch.save(model.network.state_dict(), model_path)
            self.logger.info(f"Model saved: {model_path.name}")

        self.logger.info(f"Results saved: {results_path.name}")

    def run_experiment(self) -> Dict[str, Any]:
        """Run the full pipeline: load data, optional GA, train, evaluate, plot, save.

        Returns:
            Dict with ``metrics``, ``history``, ``feature_names``,
            ``ga_results`` and ``optimized_hyperparams``.
        """
        with self.logger.log_section("TABNET TRAINING EXPERIMENT WITH GA OPTIMIZATION"):
            try:
                # 1. Load and preprocess data
                with self.logger.log_section("DATA LOADING & PREPROCESSING"):
                    X_train, X_val, X_test, y_train, y_val, y_test = \
                        self.data_processor.load_and_preprocess(
                            Path(self.config.dataset_path),
                            self.config.random_seed
                        )
                    self.feature_names = self.data_processor.feature_names

                # 2. Run genetic algorithm optimization (if enabled)
                optimized_hyperparams = None
                if self.config.use_ga_optimization:
                    ga_optimizer = GeneticOptimizer(
                        self.config, self.data_config,
                        X_train, X_val, y_train, y_val,
                        self.logger
                    )

                    self.ga_results = ga_optimizer.optimize()
                    optimized_hyperparams = self.ga_results['best_hyperparams']

                    # Plot evolution progress
                    results_dir = Path(self.config.results_dir)
                    results_dir.mkdir(parents=True, exist_ok=True)
                    ga_optimizer.plot_evolution_progress(results_dir)

                # 3. Train final model with optimized hyperparameters
                history = self.train(X_train, X_val, y_train, y_val, optimized_hyperparams)

                # 4. Evaluate model on test set
                metrics = self.evaluate(X_test, y_test)

                # 5. Create visualizations
                results_dir = Path(self.config.results_dir)
                self.create_visualizations(metrics, history, X_train, y_train, results_dir)

                # 6. Save experiment
                self.save_experiment(metrics, history, results_dir)

                self.results = {
                    'metrics': metrics,
                    'history': history,
                    'feature_names': self.feature_names,
                    'ga_results': self.ga_results,
                    'optimized_hyperparams': optimized_hyperparams
                }

                with self.logger.log_section("EXPERIMENT COMPLETED SUCCESSFULLY"):
                    if self.ga_results:
                        self.logger.info("Genetic Algorithm Results:")
                        self.logger.info(f"   Best Fitness: {self.ga_results['best_fitness']:.6f}")
                        self.logger.info(f"   Generations: {len(self.ga_results['generation_stats'])}")

                    self.logger.info(f"\n Final Test Performance:")
                    self.logger.info(f"   Accuracy:    {metrics['accuracy']:.4f}")
                    self.logger.info(f"   AUC-ROC:     {metrics['auc']:.4f}")
                    self.logger.info(f"   Sensitivity: {metrics['sensitivity']:.4f}")
                    self.logger.info(f"   Specificity: {metrics['specificity']:.4f}")
                    self.logger.info(f"   F1 Score:    {metrics['f1']:.4f}")
                    self.logger.info(f"\n Results saved in: {results_dir}/")

                return self.results

            except Exception as e:
                self.logger.error(f" Experiment failed: {e}")
                raise
            finally:
                clear_memory()

# =============================================================================
# MAIN EXECUTION
# =============================================================================

def run_tabnet_training(dataset_path: str = "C:/Users/awwal/Desktop/MLEA_experiments/data.csv",
                       results_dir: str = "tabnet_results_ga",
                       use_ga_optimization: bool = True,
                       ga_population_size: int = 20,
                       ga_generations: int = 15,
                       display_plots: bool = True,
                       block_on_plot: bool = False) -> Dict[str, Any]:
    """
    Configure and run a single TabNet experiment, optionally with GA tuning.

    Builds a ``TabNetConfig`` (50 epochs, patience 10, fixed GA mutation,
    crossover and tournament settings, plots saved at 100 dpi, no
    cross-validation) and a ``DataConfig`` for the ``diagnosis`` target. It
    then runs ``TabNetTrainer.run_experiment``.

    Args:
        dataset_path: CSV file to load.
        results_dir: Where plots, JSON results and model weights are written.
        use_ga_optimization: Tune hyperparameters with the genetic algorithm.
        ga_population_size: Individuals per GA generation.
        ga_generations: Number of GA generations.
        display_plots: Open plot windows.
        block_on_plot: Wait for plot windows to be closed before continuing.

    Returns:
        The results dict from ``TabNetTrainer.run_experiment`` (metrics,
        history, feature names, GA results, optimized hyperparameters).

    Raises:
        ImportError: If pytorch-tabnet could not be imported.
    """
    separator = "=" * 70
    print("\n" + separator)
    print("   TABNET TRAINING WITH GENETIC ALGORITHM OPTIMIZATION")
    print(separator)

    # Fail fast, with install instructions, when TabNet is missing.
    if not _HAS_TABNET:
        print("\n pytorch-tabnet is required but not installed.")
        print("   Install with: pip install pytorch-tabnet")
        raise ImportError("pytorch-tabnet not available")

    # GA search settings (rates and tournament size are fixed here).
    ga_options = {
        'use_ga_optimization': use_ga_optimization,
        'ga_population_size': ga_population_size,
        'ga_generations': ga_generations,
        'ga_mutation_rate': 0.3,
        'ga_crossover_rate': 0.7,
        'ga_tournament_size': 3,
    }
    plot_options = {
        'display_plots': display_plots,
        'save_plots': True,
        'plot_dpi': 100,
        'block_on_plot': block_on_plot,
    }
    model_config = TabNetConfig(
        dataset_path=dataset_path,
        results_dir=results_dir,
        max_epochs=50,
        patience=10,
        use_cross_validation=False,  # the GA already validates each candidate
        **ga_options,
        **plot_options,
    )

    # A larger (15%) validation split is used for the GA's fitness evaluation.
    data_config = DataConfig(
        target_column="diagnosis",
        test_size=0.2,
        validation_size=0.15,
        scale_features=True,
        handle_missing="mean",
    )

    return TabNetTrainer(model_config, data_config).run_experiment()


def main():
    """Command-line entry point: run the default TabNet experiment.

    Returns:
        Process exit status: 0 on success, 1 if the experiment raised. On
        failure the error message and full traceback are printed.
    """
    import traceback

    try:
        run_tabnet_training()
    except Exception as e:
        print(f"\n Experiment failed: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    # Detect an IPython/Jupyter session, where ``get_ipython`` is injected into
    # the namespace. Only the probe is inside the ``try``, so a NameError
    # raised by the experiment itself is not mistaken for "plain Python". That
    # mistake would rerun the whole experiment via main() and exit the kernel.
    try:
        get_ipython()  # noqa: F821
        _running_in_ipython = True
    except NameError:
        _running_in_ipython = False

    if not _running_in_ipython:
        sys.exit(main())

    print("\n Running in Jupyter/IPython environment")
    results = run_tabnet_training(block_on_plot=False)
    print("\n Experiment completed! Results stored in 'results' variable")

    # Display summary
    if results.get('ga_results'):
        print("\n GA Optimization Summary:")
        print(f"   Best Fitness: {results['ga_results']['best_fitness']:.6f}")
        print(f"\n Optimized Hyperparameters:")
        for param, value in results['optimized_hyperparams'].items():
            print(f"   {param:15s}: {value}")


 Running in Jupyter/IPython environment

   TABNET TRAINING WITH GENETIC ALGORITHM OPTIMIZATION
2025-11-13 18:46:53 [INFO] TabNetTrainer:  Random seeds set to 42
2025-11-13 18:46:53 [INFO] TabNetTrainer:  Using CPU
2025-11-13 18:46:53 [INFO] TabNetTrainer:    Adjusted batch sizes for CPU: batch=64, virtual=32
2025-11-13 18:46:53 [INFO] TabNetTrainer:  TABNET TRAINING EXPERIMENT WITH GA OPTIMIZATION
2025-11-13 18:46:53 [INFO] TabNetTrainer:  DATA LOADING & PREPROCESSING
2025-11-13 18:46:54 [INFO] TabNetTrainer: Starting: Loading dataset from C:\Users\awwal\Desktop\MLEA_experiments\data.csv
2025-11-13 18:46:54 [INFO] TabNetTrainer:    Shape: (569, 33)
2025-11-13 18:46:54 [INFO] TabNetTrainer: ✓ Completed: Loading dataset from C:\Users\awwal\Desktop\MLEA_experiments\data.csv (0.01s)
2025-11-13 18:46:54 [INFO] TabNetTrainer:    Class distribution: Benign=357 (62.7%), Malignant=212 (37.3%)
2025-11-13 18:46:54 [INFO] TabNetTrainer:    Features: 30
2025-11-13 18:46:54 [INFO] TabNetTrainer: