# In [None]:

import gc
import json
import logging
import os
import sys
import warnings
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field, asdict
from enum import Enum

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score, confusion_matrix, roc_auc_score, 
    classification_report, roc_curve
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA

# Matplotlib imports for window plotting
import matplotlib
matplotlib.use('TkAgg')  # Use TkAgg backend for Windows GUI windows
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

# Silence every warning process-wide (library deprecation notices included).
warnings.filterwarnings('ignore')

# TabNet is an optional dependency: bind ``TabNetClassifier`` to the real
# class when importable (else None) and record the outcome in a flag.
try:
    from pytorch_tabnet.tab_model import TabNetClassifier
except ImportError:
    TabNetClassifier = None
    _HAS_TABNET = False
else:
    _HAS_TABNET = True

# =============================================================================
# ENUMS AND CONSTANTS
# =============================================================================

# Set style for better-looking plots. Matplotlib 3.6 renamed its bundled
# seaborn styles to 'seaborn-v0_8-*' (the old names were later removed), so
# try the new name first and fall back to the pre-3.6 one. If neither exists,
# the default style is kept instead of crashing at import time
# (plt.style.use raises OSError for unknown style names).
for _style in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    try:
        plt.style.use(_style)
    except OSError:
        continue
    break
del _style
sns.set_palette("husl")

# Default column names to drop
DEFAULT_DROP_COLUMNS = ["id", "ID", "Id", "patient_id", "Unnamed: 0"]

# Target column candidates
TARGET_CANDIDATES = ["diagnosis", "Diagnosis", "target", "Target", "class", "Class"]

# =============================================================================
# CONFIGURATION CLASSES
# =============================================================================

@dataclass
class TabNetConfig:
    """Settings for the TabNet model, its training loop, CV, plots and paths.

    Every field is checked in ``__post_init__``. An out-of-range value raises
    ``ValueError``. A ``dataset_path`` that is not an existing regular file
    raises ``FileNotFoundError``. A successfully constructed instance has
    therefore already been vetted.
    """

    # Model architecture
    n_d: int = 16
    n_a: int = 16
    n_steps: int = 5
    gamma: float = 1.5
    lambda_sparse: float = 0.001
    lr: float = 0.02
    momentum: float = 0.02
    n_independent: int = 2
    n_shared: int = 2

    # Training parameters
    max_epochs: int = 50
    patience: int = 10
    batch_size: int = 512
    virtual_batch_size: int = 64

    # Cross-validation parameters
    use_cross_validation: bool = True
    n_folds: int = 5
    cv_random_state: int = 42

    # Visualization
    display_plots: bool = True  # Display plots in windows
    save_plots: bool = True     # Save plots as PNG
    plot_dpi: int = 100         # DPI for saved plots
    block_on_plot: bool = False # Whether to block execution when showing plots

    # Environment
    random_seed: int = 42
    deterministic: bool = True

    # Paths
    dataset_path: str = "data.csv"
    results_dir: str = "tabnet_results"

    def __post_init__(self):
        """Validate the configuration immediately after construction."""
        self._validate()

    def _validate(self):
        """Raise on the first invalid setting.

        Raises:
            ValueError: For any out-of-range numeric parameter.
            FileNotFoundError: If ``dataset_path`` is not an existing file.
        """
        if self.n_d <= 0 or self.n_a <= 0:
            raise ValueError(f"n_d and n_a must be positive (got n_d={self.n_d}, n_a={self.n_a})")

        if self.n_steps <= 0:
            raise ValueError(f"n_steps must be positive (got {self.n_steps})")

        if not 0 <= self.gamma <= 10:
            raise ValueError(f"gamma should be in [0, 10] (got {self.gamma})")

        if not 0 <= self.lambda_sparse <= 1:
            raise ValueError(f"lambda_sparse should be in [0, 1] (got {self.lambda_sparse})")

        if self.lr <= 0:
            raise ValueError(f"lr must be positive (got {self.lr})")

        # NOTE(review): passed to TabNet as its (batch-norm) momentum, which
        # must be a fraction; confirm against the pytorch-tabnet docs.
        if not 0 <= self.momentum <= 1:
            raise ValueError(f"momentum must be in [0, 1] (got {self.momentum})")

        if self.n_independent < 0 or self.n_shared < 0:
            raise ValueError(
                "n_independent and n_shared must be >= 0 "
                f"(got n_independent={self.n_independent}, n_shared={self.n_shared})"
            )

        # TabNet needs at least one GLU layer per feature transformer.
        if self.n_independent == 0 and self.n_shared == 0:
            raise ValueError("n_independent and n_shared cannot both be zero")

        if self.max_epochs <= 0:
            raise ValueError(f"max_epochs must be positive (got {self.max_epochs})")

        if self.patience <= 0:
            raise ValueError(f"patience must be positive (got {self.patience})")

        if self.batch_size <= 0 or self.virtual_batch_size <= 0:
            raise ValueError("batch_size and virtual_batch_size must be positive")

        if self.virtual_batch_size > self.batch_size:
            raise ValueError("virtual_batch_size must be <= batch_size")

        if self.use_cross_validation and self.n_folds < 2:
            raise ValueError(f"n_folds must be >= 2 (got {self.n_folds})")

        if self.plot_dpi <= 0:
            raise ValueError(f"plot_dpi must be positive (got {self.plot_dpi})")

        # is_file() (unlike exists()) also rejects directories. A directory
        # would otherwise pass here and fail only later when read as CSV.
        if not Path(self.dataset_path).is_file():
            raise FileNotFoundError(f"Dataset not found: {self.dataset_path}")

@dataclass
class DataConfig:
    """Settings for loading, cleaning and splitting the tabular dataset.

    The instance is validated on construction. The first invalid setting
    raises ``ValueError``.

    Attributes:
        target_column: Explicit label column. ``None`` lets the loader choose one.
        drop_columns: Columns discarded before training. Each instance gets
            its own copy of the module default.
        test_size: Fraction of rows held out for testing, strictly in (0, 1).
        validation_size: Fraction of all rows used for validation, strictly in (0, 1).
        scale_features: Whether the features are standardised.
        handle_missing: Missing-value strategy: "mean", "median" or "drop".
    """

    target_column: Optional[str] = None
    drop_columns: List[str] = field(default_factory=lambda: DEFAULT_DROP_COLUMNS.copy())
    test_size: float = 0.2
    validation_size: float = 0.1
    scale_features: bool = True
    handle_missing: str = "mean"  # "mean", "median", "drop"

    def __post_init__(self):
        """Run validation as soon as the dataclass is populated."""
        self._validate()

    def _validate(self):
        """Check each rule in order and raise ValueError on the first failure.

        Predicates are lambdas so that later rules are only evaluated once
        the earlier ones have passed.
        """
        rules = (
            (lambda: not 0 < self.test_size < 1,
             lambda: f"test_size must be in (0, 1) (got {self.test_size})"),
            (lambda: not 0 < self.validation_size < 1,
             lambda: f"validation_size must be in (0, 1) (got {self.validation_size})"),
            (lambda: self.test_size + self.validation_size >= 1,
             lambda: "test_size + validation_size must be < 1"),
            (lambda: self.handle_missing not in ("mean", "median", "drop"),
             lambda: f"handle_missing must be 'mean', 'median', or 'drop' (got {self.handle_missing})"),
        )
        for violated, describe in rules:
            if violated():
                raise ValueError(describe())

# =============================================================================
# CROSS-VALIDATION
# =============================================================================

@dataclass
class CVResults:
    """Per-fold cross-validation outcomes plus their summary statistics.

    Attributes:
        fold_metrics: One metrics dict per fold, in the order folds were added.
        fold_histories: One training-history dict per fold, same order.
        best_fold: 0-based position (insertion order) of the fold with the
            highest ``'auc'``. It stays -1 until ``compute_statistics`` sees AUC values.
        mean_metrics: Per-metric mean across folds.
        std_metrics: Per-metric population standard deviation (``np.std``, ddof=0).
    """
    fold_metrics: List[Dict[str, float]] = field(default_factory=list)
    fold_histories: List[Dict] = field(default_factory=list)
    best_fold: int = -1
    mean_metrics: Dict[str, float] = field(default_factory=dict)
    std_metrics: Dict[str, float] = field(default_factory=dict)

    def add_fold_result(self, fold_idx: int, metrics: Dict, history: Dict):
        """Append one fold's metrics and history.

        ``fold_idx`` is currently unused. Results are stored in call order.
        """
        self.fold_metrics.append(metrics)
        self.fold_histories.append(history)

    def compute_statistics(self):
        """Fill ``mean_metrics``/``std_metrics`` and choose ``best_fold`` by AUC.

        Only the scalar metrics named below are summarised. A fold lacking a
        metric is skipped for that metric. Nothing happens when no folds exist.
        """
        if not self.fold_metrics:
            return

        summarised = ('accuracy', 'auc', 'sensitivity', 'specificity', 'precision', 'f1', 'npv')
        for name in summarised:
            observed = [m[name] for m in self.fold_metrics if name in m]
            if not observed:
                continue
            self.mean_metrics[name] = np.mean(observed)
            self.std_metrics[name] = np.std(observed)

        if 'auc' not in self.mean_metrics:
            return
        # Folds without an AUC are scored as 0 for this ranking.
        self.best_fold = int(np.argmax([m.get('auc', 0) for m in self.fold_metrics]))

    def get_overfitting_analysis(self) -> Dict[str, Any]:
        """Summarise fold-to-fold variability of the cross-validation metrics.

        Only ``mean_metrics``/``std_metrics`` are used; no training-set scores
        are involved. Call ``compute_statistics`` first.

        Returns:
            A dict with three keys:

            - ``'variance'``: metric -> coefficient of variation (std / mean),
              only for metrics with a positive mean.
            - ``'consistency'``: a rating for accuracy, auc and f1. It is
              'Good' when std < 0.05, 'Moderate' when std < 0.10, else 'Poor'.
            - ``'overfitting_indicators'``: a message for each metric whose
              coefficient of variation exceeds 0.1.
        """
        variance: Dict[str, float] = {}
        indicators: List[str] = []
        for name, mean_value in self.mean_metrics.items():
            if name not in self.std_metrics or not mean_value > 0:
                continue
            ratio = self.std_metrics[name] / mean_value
            variance[name] = ratio
            # A ratio above 0.1 means more than 10% relative variation between folds.
            if ratio > 0.1:
                indicators.append(f"High variance in {name}: CV={ratio:.3f}")

        consistency: Dict[str, str] = {}
        for name in ('accuracy', 'auc', 'f1'):
            if name not in self.std_metrics:
                continue
            spread = self.std_metrics[name]
            if spread < 0.05:
                consistency[name] = 'Good'
            elif spread < 0.10:
                consistency[name] = 'Moderate'
            else:
                consistency[name] = 'Poor'

        return {
            'variance': variance,
            'consistency': consistency,
            'overfitting_indicators': indicators,
        }

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

class Logger:
    """Small wrapper around a named ``logging.Logger`` that prints to stdout.

    The first ``Logger`` created for a given ``name`` attaches one stdout
    handler, so re-running a notebook cell does not stack duplicate handlers.
    It also stops propagation to ancestor loggers, so records are not printed
    a second time by a handler on the root logger.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Level applied when this class installs the handler. It is also
            applied when the logger already has handlers but no level of its own.
    """

    def __init__(self, name: str, level: int = logging.INFO):
        self.logger = logging.getLogger(name)
        if not self.logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(
                "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            self.logger.setLevel(level)
            # This logger now owns its output. Without this, records also
            # reach handlers configured on the root logger and appear twice.
            self.logger.propagate = False
        elif self.logger.level == logging.NOTSET:
            # Handlers were attached elsewhere but no level was chosen. The
            # logger would then inherit the root level (WARNING by default)
            # and silently drop INFO messages.
            self.logger.setLevel(level)

    @contextmanager
    def log_section(self, title: str):
        """Log a banner around ``title`` on entry and a blank line on exit."""
        self.logger.info("=" * 70)
        self.logger.info(f" {title}")
        self.logger.info("=" * 70)
        try:
            yield
        finally:
            self.logger.info("")

    def info(self, msg: str, *args, **kwargs):
        """Log ``msg`` at INFO. Extra arguments are forwarded to ``logging``."""
        self.logger.info(msg, *args, **kwargs)

    def warning(self, msg: str, *args, **kwargs):
        """Log ``msg`` at WARNING. Extra arguments are forwarded to ``logging``."""
        self.logger.warning(msg, *args, **kwargs)

    def error(self, msg: str, *args, **kwargs):
        """Log ``msg`` at ERROR. Extra arguments are forwarded to ``logging``."""
        self.logger.error(msg, *args, **kwargs)

    def debug(self, msg: str, *args, **kwargs):
        """Log ``msg`` at DEBUG. Extra arguments are forwarded to ``logging``."""
        self.logger.debug(msg, *args, **kwargs)

    def exception(self, msg: str, *args, **kwargs):
        """Log ``msg`` at ERROR with the active exception's traceback attached."""
        self.logger.exception(msg, *args, **kwargs)

@contextmanager
def timer(logger: "Logger", operation: str):
    """Log the start of ``operation`` and, on exit, how long it took.

    The completion line is written from ``finally``, so it is logged even
    when the wrapped code raises. The exception still propagates.

    Args:
        logger: Any object with an ``info(str)`` method (normally ``Logger``).
        operation: Human-readable description used in both messages.
    """
    import time

    # perf_counter is monotonic, so wall-clock adjustments cannot distort
    # (or make negative) the measured duration.
    start = time.perf_counter()
    logger.info(f"Starting: {operation}")
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.info(f"✓ Completed: {operation} ({elapsed:.2f}s)")

def set_random_seeds(seed: int, deterministic: bool = False):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and CUDA).

    Args:
        seed: Seed value applied to every RNG.
        deterministic: If True, also sets ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False, so cuDNN
            kernel selection is repeatable across GPU runs. This may cost speed.
            Defaults to False, which keeps the previous behaviour for existing
            callers.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def clear_memory():
    """Run the garbage collector and, if CUDA is present, release cached GPU memory.

    This is best-effort: a failing ``torch.cuda.synchronize()`` is deliberately ignored.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    try:
        torch.cuda.synchronize()
    except Exception:
        # Synchronisation is only a courtesy here; cleanup must never crash the run.
        pass

# =============================================================================
# DATA PROCESSING CLASS
# =============================================================================

class DataProcessor:
    """Load a CSV dataset and turn it into train/validation/test arrays.

    ``load_and_preprocess`` runs the whole pipeline:

    1. Read the CSV and pick the target column.
    2. Keep numeric feature columns and binarise the target.
    3. Make a stratified train/val/test split.
    4. Fill missing values using training-split statistics.
    5. Optionally standard-scale the features.

    Attributes:
        config: Data options (``DataConfig``-like): target column, drop list,
            split sizes, scaling flag and missing-value strategy.
        logger: Object with ``info``/``warning`` methods for progress output.
        scaler: ``StandardScaler`` fitted on the training split, or ``None``
            when scaling is disabled or has not run yet.
        imputation_values: Per-feature fill values learned from the training
            split by the "mean"/"median" strategies. ``None`` if no
            imputation was needed.
        feature_names: Feature columns kept by ``prepare_features``.
    """

    def __init__(self, config: "DataConfig", logger: "Logger"):
        self.config = config
        self.logger = logger
        self.scaler = None
        self.imputation_values: Optional[np.ndarray] = None
        self.feature_names: List[str] = []

    def load_dataset(self, filepath: Path) -> pd.DataFrame:
        """Read ``filepath`` as CSV and log its shape.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            RuntimeError: If reading fails. The original error is chained as the cause.
        """
        if not filepath.exists():
            raise FileNotFoundError(f"Dataset not found: {filepath}")

        try:
            with timer(self.logger, f"Loading dataset from {filepath}"):
                df = pd.read_csv(filepath)
                self.logger.info(f"   Shape: {df.shape}")
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset: {e}") from e
        return df

    def identify_target_column(self, df: pd.DataFrame) -> str:
        """Return the name of the label column.

        Candidates are tried in this order:

        1. ``config.target_column``, which must exist in ``df``.
        2. The first name from ``TARGET_CANDIDATES`` present in ``df``.
        3. The last column, with a warning.

        Raises:
            ValueError: If an explicitly configured target column is missing.
        """
        if self.config.target_column is not None:
            if self.config.target_column in df.columns:
                return self.config.target_column
            raise ValueError(f"Specified target column '{self.config.target_column}' not found")

        for candidate in TARGET_CANDIDATES:
            if candidate in df.columns:
                self.logger.info(f"Auto-detected target column: '{candidate}'")
                return candidate

        target_col = df.columns[-1]
        self.logger.warning(f"Using last column as target: '{target_col}'")
        return target_col

    def preprocess_target(self, target_series: pd.Series) -> np.ndarray:
        """Convert the target to a 0/1 integer array.

        Object/category dtypes go through the categorical path; everything
        else goes through the numeric path.
        """
        target_series = target_series.copy()

        if target_series.dtype == 'object' or target_series.dtype.name == 'category':
            return self._preprocess_categorical_target(target_series)

        return self._preprocess_numeric_target(target_series)

    def _preprocess_categorical_target(self, target_series: pd.Series) -> np.ndarray:
        """Map string/categorical labels to 0/1.

        The strategies are tried in order:

        1. If every value, lower-cased and stripped, is a known token, map it
           with a fixed table ('m'/'malignant'/'1'/'4'/... -> 1,
           'b'/'benign'/'0'/'2'/... -> 0).
        2. If every value parses as a number, defer to the numeric path.
        3. Otherwise use ``LabelEncoder``, which sorts the labels, so the
           alphabetically later of the two becomes 1.

        Raises:
            ValueError: If the LabelEncoder fallback finds other than two classes.
        """
        normalized = target_series.astype(str).str.strip().str.lower()

        mapping = {
            'm': 1, 'malignant': 1, '1': 1, '4': 1, 'positive': 1, 'yes': 1,
            'b': 0, 'benign': 0, '0': 0, '2': 0, 'negative': 0, 'no': 0
        }

        mapped = normalized.map(mapping)

        if not mapped.isna().any():
            return mapped.astype(int).values

        numeric_vals = pd.to_numeric(normalized, errors='coerce')
        if numeric_vals.notna().all():
            return self._preprocess_numeric_target(numeric_vals)

        self.logger.warning("Using LabelEncoder for target conversion")
        encoder = LabelEncoder()
        encoded = encoder.fit_transform(normalized)

        if len(encoder.classes_) != 2:
            raise ValueError(f"Expected binary target, got {len(encoder.classes_)} classes")

        return encoded.astype(int)

    def _preprocess_numeric_target(self, target_series: pd.Series) -> np.ndarray:
        """Map numeric labels to 0/1.

        Supported encodings:

        - {0, 1} is used as is.
        - {2, 4} becomes 4 -> 1.
        - Any other two distinct values map the larger one to 1.

        Raises:
            ValueError: On unparseable values, or on more than two distinct values.
        """
        numeric_vals = pd.to_numeric(target_series, errors='coerce')

        if numeric_vals.isna().any():
            raise ValueError("Target contains non-numeric values that cannot be converted")

        unique_vals = set(numeric_vals.dropna().unique())

        if unique_vals.issubset({0, 1}):
            return numeric_vals.astype(int).values
        elif unique_vals.issubset({2, 4}):
            return (numeric_vals == 4).astype(int).values
        elif len(unique_vals) == 2:
            sorted_vals = sorted(unique_vals)
            return (numeric_vals == sorted_vals[1]).astype(int).values
        else:
            raise ValueError(f"Cannot interpret target values: {unique_vals}")

    def prepare_features(self, df: pd.DataFrame, target_col: str) -> pd.DataFrame:
        """Return the numeric feature frame, without the target and ``drop_columns``.

        Each remaining column is coerced with ``pd.to_numeric(errors='coerce')``.
        Columns that end up entirely NaN (e.g. free-text columns) are dropped
        and reported via ``logger.warning``. Sets ``self.feature_names`` as a
        side effect.

        Raises:
            ValueError: If ``target_col`` is absent or no feature column survives.
        """
        if target_col not in df.columns:
            raise ValueError(f"Target column '{target_col}' not found in dataframe")

        drop_cols = [col for col in self.config.drop_columns
                     if col in df.columns and col != target_col]

        X_df = df.drop(columns=[target_col] + drop_cols, errors='ignore')

        for col in X_df.columns:
            X_df[col] = pd.to_numeric(X_df[col], errors='coerce')

        candidate_cols = list(X_df.columns)
        X_df = X_df.dropna(axis=1, how='all')
        discarded = [col for col in candidate_cols if col not in X_df.columns]
        if discarded:
            self.logger.warning(
                f"Dropped {len(discarded)} non-numeric or empty column(s): {discarded}"
            )

        if X_df.shape[1] == 0:
            raise ValueError("No valid features remaining after preprocessing")

        self.feature_names = list(X_df.columns)
        return X_df

    def handle_missing_values(self, X_df: pd.DataFrame) -> pd.DataFrame:
        """Handle missing values in ``X_df`` according to ``config.handle_missing``.

        With "mean"/"median", statistics are computed from ``X_df`` as given.
        ``load_and_preprocess`` only uses this method for the "drop" strategy;
        for mean/median it imputes from the training split instead, to avoid
        leakage. Note that "drop" removes rows, so any target aligned with
        ``X_df`` must be filtered by the returned index.
        """
        missing_cols = X_df.columns[X_df.isna().any()].tolist()

        if not missing_cols:
            return X_df

        self.logger.warning(f" Missing values detected in {len(missing_cols)} columns")

        if self.config.handle_missing == "mean":
            X_df = X_df.fillna(X_df.mean())
        elif self.config.handle_missing == "median":
            X_df = X_df.fillna(X_df.median())
        elif self.config.handle_missing == "drop":
            X_df = X_df.dropna()

        return X_df

    def split_data(self, X: np.ndarray, y: np.ndarray,
                   random_state: int) -> Tuple[np.ndarray, ...]:
        """Make a stratified split into train, validation and test sets.

        ``validation_size`` is a fraction of all rows. It is rescaled for the
        second split, which only sees the non-test remainder.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
        """
        X_temp, X_test, y_temp, y_test = train_test_split(
            X, y, test_size=self.config.test_size,
            stratify=y, random_state=random_state
        )

        val_size = self.config.validation_size / (1 - self.config.test_size)
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=val_size,
            stratify=y_temp, random_state=random_state
        )

        self.logger.info(f"   Train set: {X_train.shape[0]} samples")
        self.logger.info(f"   Validation set: {X_val.shape[0]} samples")
        self.logger.info(f"   Test set: {X_test.shape[0]} samples")

        return X_train, X_val, X_test, y_train, y_val, y_test

    def scale_features(self, X_train: np.ndarray, X_val: np.ndarray,
                       X_test: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Standardise all splits with a ``StandardScaler`` fitted on ``X_train`` only.

        Returns the inputs unchanged when ``config.scale_features`` is False.
        """
        if not self.config.scale_features:
            return X_train, X_val, X_test

        with timer(self.logger, "Feature scaling"):
            self.scaler = StandardScaler()
            X_train = self.scaler.fit_transform(X_train).astype(np.float32)
            X_val = self.scaler.transform(X_val).astype(np.float32)
            X_test = self.scaler.transform(X_test).astype(np.float32)

        return X_train, X_val, X_test

    def _build_feature_target(self, df: pd.DataFrame,
                              target_col: str) -> Tuple[pd.DataFrame, np.ndarray]:
        """Build the feature frame and the 0/1 target with their rows kept aligned.

        With the "drop" strategy, incomplete feature rows are removed together
        with the matching target entries. Learning no statistics, this step
        can safely run before the split.

        Raises:
            ValueError: On a feature/target length mismatch or if no rows remain.
        """
        X_df = self.prepare_features(df, target_col)
        target = df[target_col]

        if self.config.handle_missing == "drop":
            X_df = self.handle_missing_values(X_df)
            # prepare_features only drops columns, so X_df keeps df's row labels.
            target = target.loc[X_df.index]

        y = self.preprocess_target(target)

        if X_df.shape[0] != len(y):
            raise ValueError(f"Shape mismatch: X has {X_df.shape[0]} samples but y has {len(y)}")
        if len(y) == 0:
            raise ValueError("No samples remaining after preprocessing")

        return X_df, y

    def _impute_missing(self, X_train: np.ndarray, X_val: np.ndarray,
                        X_test: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Fill NaNs in all splits using per-column statistics of ``X_train`` only.

        The statistic is the mean or median, following ``config.handle_missing``.
        Learning it from the training split keeps validation/test information
        out of preprocessing. A column with no observed training value is
        filled with 0.0. The inputs are not modified; copies are returned.
        With the "drop" strategy, or when no NaNs are present, the inputs are
        returned unchanged.
        """
        if self.config.handle_missing == "drop":
            return X_train, X_val, X_test

        missing_cols = np.zeros(X_train.shape[1], dtype=bool)
        for split in (X_train, X_val, X_test):
            missing_cols |= np.isnan(split).any(axis=0)
        if not missing_cols.any():
            return X_train, X_val, X_test

        self.logger.warning(f" Missing values detected in {int(missing_cols.sum())} columns")

        stat_fn = np.nanmedian if self.config.handle_missing == "median" else np.nanmean
        with warnings.catch_warnings():
            # nanmean/nanmedian warn on columns that are all-NaN in training.
            warnings.simplefilter("ignore", RuntimeWarning)
            fill = stat_fn(X_train, axis=0)
        fill = np.where(np.isnan(fill), 0.0, fill).astype(X_train.dtype)
        self.imputation_values = fill

        def _fill(split: np.ndarray) -> np.ndarray:
            filled = split.copy()
            rows, cols = np.nonzero(np.isnan(filled))
            filled[rows, cols] = fill[cols]
            return filled

        return _fill(X_train), _fill(X_val), _fill(X_test)

    def load_and_preprocess(self, filepath: Path, random_state: int) -> Tuple[np.ndarray, ...]:
        """Run the full pipeline and return the six split arrays.

        Missing values are handled without leaking held-out data:

        - "drop" removes incomplete rows, and their labels, before splitting.
        - "mean"/"median" learn fill values from the training split only and
          apply them to all three splits.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``. Feature
            arrays are float32 and labels are 0/1 ints.
        """
        df = self.load_dataset(filepath)
        target_col = self.identify_target_column(df)
        X_df, y = self._build_feature_target(df, target_col)

        benign_count = np.sum(y == 0)
        malignant_count = np.sum(y == 1)
        self.logger.info(f"   Class distribution: Benign={benign_count} ({benign_count/len(y)*100:.1f}%), "
                         f"Malignant={malignant_count} ({malignant_count/len(y)*100:.1f}%)")
        self.logger.info(f"   Features: {len(self.feature_names)}")

        X = X_df.values.astype(np.float32)
        X_train, X_val, X_test, y_train, y_val, y_test = self.split_data(X, y, random_state)
        X_train, X_val, X_test = self._impute_missing(X_train, X_val, X_test)
        X_train, X_val, X_test = self.scale_features(X_train, X_val, X_test)

        clear_memory()
        return X_train, X_val, X_test, y_train, y_val, y_test

# =============================================================================
# METRICS COMPUTATION 
# =============================================================================

class MetricsComputer:
    """Compute and pretty-print binary-classification evaluation metrics."""

    @staticmethod
    def compute_all_metrics(y_true: np.ndarray, y_pred: np.ndarray,
                            y_proba: np.ndarray) -> Dict[str, Any]:
        """Compute threshold-based and ranking metrics for a binary classifier.

        Args:
            y_true: Ground-truth binary labels.
            y_pred: Hard predicted labels.
            y_proba: Scores for the positive class, used for AUC and the ROC curve.

        Returns:
            A dict with:

            - floats ``accuracy``, ``auc``, ``sensitivity`` (recall),
              ``specificity``, ``precision``, ``npv`` and ``f1``;
            - the raw ``confusion_matrix`` array;
            - ``roc_curve`` as ``(fpr, tpr)``;
            - integer counts ``tn``, ``fp``, ``fn`` and ``tp``.

            A ratio whose denominator is zero is reported as 0.0.

        Raises:
            ValueError: If ``confusion_matrix`` is not 2x2, so it cannot be
                unpacked into tn/fp/fn/tp. ``roc_auc_score`` also raises
                ValueError when ``y_true`` holds a single class.
        """
        accuracy = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_proba)
        cm = confusion_matrix(y_true, y_pred)

        tn, fp, fn, tp = cm.ravel()

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
        f1 = 2 * (precision * sensitivity) / (precision + sensitivity) if (precision + sensitivity) > 0 else 0.0

        # Thresholds are not needed by callers; only (fpr, tpr) is kept.
        fpr, tpr, _ = roc_curve(y_true, y_proba)

        return {
            "accuracy": float(accuracy),
            "auc": float(auc),
            "sensitivity": float(sensitivity),
            "specificity": float(specificity),
            "precision": float(precision),
            "npv": float(npv),
            "f1": float(f1),
            "confusion_matrix": cm,
            "roc_curve": (fpr, tpr),
            "tn": int(tn),
            "fp": int(fp),
            "fn": int(fn),
            "tp": int(tp)
        }

    @staticmethod
    def format_metrics_summary(metrics: Dict[str, Any], prefix: str = "") -> str:
        """Render the scalar metrics and confusion-matrix counts as a text block.

        Args:
            metrics: Dict shaped like the output of ``compute_all_metrics``.
            prefix: Text inserted before the "MODEL PERFORMANCE METRICS" heading.
        """
        lines = [
            f"\n {prefix}MODEL PERFORMANCE METRICS:",
            f"   Accuracy:    {metrics['accuracy']:.4f}",
            f"   AUC-ROC:     {metrics['auc']:.4f}",
            f"   Sensitivity: {metrics['sensitivity']:.4f}",
            f"   Specificity: {metrics['specificity']:.4f}",
            f"   Precision:   {metrics['precision']:.4f}",
            f"   NPV:         {metrics['npv']:.4f}",
            f"   F1 Score:    {metrics['f1']:.4f}",
            "\n📋 CONFUSION MATRIX:",
            f"   TN: {metrics['tn']:4d}  FP: {metrics['fp']:4d}",
            f"   FN: {metrics['fn']:4d}  TP: {metrics['tp']:4d}"
        ]
        return "\n".join(lines)

# =============================================================================
# MATPLOTLIB WINDOW PLOT
# =============================================================================

class WindowPlotManager:
    """Manages creation of matplotlib window-based plots including CV results."""
    
    def __init__(self, config: TabNetConfig, logger: Logger):
        # Run-wide settings, plus the logger that the create_* methods use to
        # report plotting failures.
        self.config, self.logger = config, logger
        # Every figure successfully built by a create_* method is appended here.
        self.figures = []
    
    def create_training_history_plot(self, history: Dict) -> Optional[plt.Figure]:
        """Draw a 2x2 figure of loss, accuracy, AUC and learning-rate curves.

        Each series is plotted against its own 1-based epoch index. A history
        that lacks ``'loss'``, or whose series differ in length, therefore
        still plots. Keys that are absent simply leave their line out, and a
        panel with no lines gets no legend.

        Args:
            history: Mapping of series name to per-epoch values. Recognised
                keys are 'loss', 'val_loss', 'accuracy', 'val_accuracy',
                'auc', 'val_auc' and 'lr'.

        Returns:
            The figure, which is also appended to ``self.figures``. Returns
            ``None`` when ``history`` is empty or plotting fails; in that case
            the error is logged and the partial figure is closed.
        """
        if not history:
            return None

        fig = None
        try:
            fig = plt.figure(figsize=(16, 10))
            fig.canvas.manager.set_window_title('TabNet Training History')
            gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

            # Each entry: (grid cell, [(history key, marker, label, extra style)],
            # y-axis label, title).
            panels = [
                (gs[0, 0], [('loss', 'o-', 'Training Loss', {}),
                            ('val_loss', 's-', 'Validation Loss', {})],
                 'Loss', 'Training & Validation Loss'),
                (gs[0, 1], [('accuracy', 'o-', 'Training Accuracy', {}),
                            ('val_accuracy', 's-', 'Validation Accuracy', {})],
                 'Accuracy', 'Training & Validation Accuracy'),
                (gs[1, 0], [('auc', 'o-', 'Training AUC', {}),
                            ('val_auc', 's-', 'Validation AUC', {})],
                 'AUC', 'Training & Validation AUC'),
                (gs[1, 1], [('lr', 'o-', 'Learning Rate', {'color': 'purple'})],
                 'Learning Rate', 'Learning Rate Schedule'),
            ]

            for cell, series, ylabel, title in panels:
                ax = fig.add_subplot(cell)
                for key, marker, label, style in series:
                    if key not in history:
                        continue
                    values = list(history[key])
                    epochs = list(range(1, len(values) + 1))
                    ax.plot(epochs, values, marker, label=label,
                            linewidth=2, markersize=4, **style)
                ax.set_xlabel('Epoch', fontsize=11)
                ax.set_ylabel(ylabel, fontsize=11)
                ax.set_title(title, fontsize=13, fontweight='bold')
                # Calling legend() on an axis with no labelled artists only
                # emits a "No artists" warning, so draw it only when useful.
                if ax.get_legend_handles_labels()[0]:
                    ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create training history plot: {e}")
            if fig is not None:
                plt.close(fig)
            return None
    
    def create_cv_fold_comparison_plot(self, cv_results: CVResults) -> Optional[plt.Figure]:
        """Bar-chart each CV metric per fold, with its mean line and ±1 std band.

        The fold at ``cv_results.best_fold`` is highlighted in gold, but only
        when that index refers to an existing fold. Means and stds are read
        from ``cv_results`` (0 when absent), so call ``compute_statistics``
        first. A fold missing a metric is drawn as 0.

        Returns:
            The figure, which is also appended to ``self.figures``. Returns
            ``None`` when there are no folds or plotting fails; in that case
            the error is logged and the partial figure is closed.
        """
        if not cv_results.fold_metrics:
            return None

        fig = None
        try:
            metrics_to_plot = ['accuracy', 'auc', 'sensitivity', 'specificity', 'precision', 'f1']
            n_folds = len(cv_results.fold_metrics)

            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            fig.canvas.manager.set_window_title('Cross-Validation Fold Comparison')
            axes = axes.flatten()

            fold_numbers = list(range(1, n_folds + 1))
            best = cv_results.best_fold
            # Guard against a stale/out-of-range best_fold, which would raise
            # IndexError when indexing the bar container.
            highlight_best = 0 <= best < n_folds

            for ax, metric in zip(axes, metrics_to_plot):
                values = [fold.get(metric, 0) for fold in cv_results.fold_metrics]
                mean_val = cv_results.mean_metrics.get(metric, 0)
                std_val = cv_results.std_metrics.get(metric, 0)

                bars = ax.bar(fold_numbers, values, alpha=0.7, color='steelblue', edgecolor='black')

                if highlight_best:
                    bars[best].set_color('gold')
                    bars[best].set_edgecolor('darkgoldenrod')
                    bars[best].set_linewidth(2)

                ax.axhline(y=mean_val, color='red', linestyle='--', linewidth=2,
                           label=f'Mean: {mean_val:.4f}')

                # Shade ±1 standard deviation around the mean.
                ax.axhspan(mean_val - std_val, mean_val + std_val,
                           alpha=0.2, color='red', label=f'±1 STD: {std_val:.4f}')

                ax.set_xlabel('Fold', fontsize=11, fontweight='bold')
                ax.set_ylabel(metric.capitalize(), fontsize=11, fontweight='bold')
                ax.set_title(f'{metric.capitalize()} Across Folds', fontsize=12, fontweight='bold')
                ax.set_xticks(fold_numbers)
                ax.legend(fontsize=9)
                ax.grid(True, alpha=0.3, axis='y')

                # Print each fold's value just above its bar.
                for bar, val in zip(bars, values):
                    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                            f'{val:.3f}', ha='center', va='bottom', fontsize=9)

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create CV fold comparison plot: {e}")
            if fig is not None:
                plt.close(fig)
            return None
    
    def create_cv_boxplot(self, cv_results: CVResults) -> Optional[plt.Figure]:
        """Box-plot the per-fold distribution of each CV metric.

        One box per metric (accuracy, AUC, sensitivity, specificity,
        precision, F1) shows the median (red) and mean (green, dashed). A fold
        missing a metric contributes 0 to it. The y-axis is fixed to [0, 1.05].

        Returns:
            The figure, which is also appended to ``self.figures``. Returns
            ``None`` when there are no folds or plotting fails; in that case
            the error is logged and the partial figure is closed.
        """
        if not cv_results.fold_metrics:
            return None

        fig = None
        try:
            from matplotlib.lines import Line2D

            metrics_to_plot = ['accuracy', 'auc', 'sensitivity', 'specificity', 'precision', 'f1']
            data_for_boxplot = [
                [fold.get(metric, 0) for fold in cv_results.fold_metrics]
                for metric in metrics_to_plot
            ]
            labels = [metric.capitalize() for metric in metrics_to_plot]

            fig, ax = plt.subplots(figsize=(14, 8))
            fig.canvas.manager.set_window_title('Cross-Validation Metrics Distribution')

            ax.boxplot(data_for_boxplot, patch_artist=True,
                       showmeans=True, meanline=True,
                       boxprops=dict(facecolor='lightblue', alpha=0.7),
                       medianprops=dict(color='red', linewidth=2),
                       meanprops=dict(color='green', linewidth=2, linestyle='--'),
                       whiskerprops=dict(linewidth=1.5),
                       capprops=dict(linewidth=1.5))
            # boxplot's ``labels=`` keyword was renamed ``tick_labels`` in
            # matplotlib 3.9 (the old name is deprecated). Labelling the ticks
            # directly works on every version. Boxes sit at x = 1..n by default.
            ax.set_xticks(list(range(1, len(labels) + 1)))
            ax.set_xticklabels(labels)

            ax.set_ylabel('Score', fontsize=13, fontweight='bold')
            ax.set_xlabel('Metric', fontsize=13, fontweight='bold')
            ax.set_title('Distribution of Metrics Across CV Folds\n(Red=Median, Green=Mean)',
                         fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='y')
            ax.set_ylim(0, 1.05)

            # Proxy artists, because the box-plot lines carry no legend labels.
            legend_elements = [
                Line2D([0], [0], color='red', linewidth=2, label='Median'),
                Line2D([0], [0], color='green', linewidth=2, linestyle='--', label='Mean')
            ]
            ax.legend(handles=legend_elements, loc='lower right', fontsize=11)

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create CV boxplot: {e}")
            if fig is not None:
                plt.close(fig)
            return None
    
    def create_overfitting_analysis_plot(self, cv_results: CVResults) -> plt.Figure:
        """Create a 2x2 figure visualising overfitting indicators from CV.

        Panels:
            1. Coefficient of variation per metric (5% / 10% colour thresholds).
            2. Consistency rating per metric (Good / Moderate / Poor).
            3. Mean vs. standard deviation of the headline metrics.
            4. Text summary of detected issues plus an overall assessment.

        Args:
            cv_results: Cross-validation results. Reads ``fold_metrics``,
                ``mean_metrics``, ``std_metrics`` and the dict returned by
                ``get_overfitting_analysis()`` (keys ``'variance'``,
                ``'consistency'`` and ``'overfitting_indicators'``; a missing
                key is treated as empty).

        Returns:
            The created figure (also appended to ``self.figures``), or ``None``
            when there are no fold results or plotting fails.
        """
        if not cv_results.fold_metrics:
            return None

        fig = None
        try:
            analysis = cv_results.get_overfitting_analysis()
            variance = analysis.get('variance') or {}
            consistency_ratings = analysis.get('consistency') or {}
            indicators = analysis.get('overfitting_indicators') or []

            fig = plt.figure(figsize=(16, 10))
            fig.canvas.manager.set_window_title('Overfitting Analysis')
            gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

            # 1. Coefficient of Variation plot
            ax1 = fig.add_subplot(gs[0, 0])
            if variance:
                metrics = list(variance.keys())
                cv_values = list(variance.values())

                colors = ['red' if cv > 0.1 else 'orange' if cv > 0.05 else 'green'
                          for cv in cv_values]

                ax1.barh(metrics, cv_values, color=colors, alpha=0.7, edgecolor='black')
                ax1.set_xlabel('Coefficient of Variation', fontsize=11, fontweight='bold')
                ax1.set_title('Metric Variance Across Folds\n(Green=Low, Orange=Moderate, Red=High)',
                              fontsize=12, fontweight='bold')
                ax1.axvline(x=0.05, color='orange', linestyle='--', alpha=0.5, label='5% threshold')
                ax1.axvline(x=0.1, color='red', linestyle='--', alpha=0.5, label='10% threshold')
                ax1.legend(fontsize=9)
                ax1.grid(True, alpha=0.3, axis='x')

            # 2. Consistency ratings, mapped to an ordinal scale for bar length
            ax2 = fig.add_subplot(gs[0, 1])
            if consistency_ratings:
                metrics = list(consistency_ratings.keys())
                consistency = list(consistency_ratings.values())

                consistency_map = {'Good': 3, 'Moderate': 2, 'Poor': 1}
                numeric_consistency = [consistency_map.get(c, 0) for c in consistency]
                colors = ['green' if c == 'Good' else 'orange' if c == 'Moderate' else 'red'
                          for c in consistency]

                ax2.barh(metrics, numeric_consistency, color=colors, alpha=0.7, edgecolor='black')
                ax2.set_xlabel('Consistency Level', fontsize=11, fontweight='bold')
                ax2.set_xticks([1, 2, 3])
                ax2.set_xticklabels(['Poor', 'Moderate', 'Good'])
                ax2.set_title('Cross-Validation Consistency', fontsize=12, fontweight='bold')
                ax2.grid(True, alpha=0.3, axis='x')

            # 3. Mean vs Std comparison (missing metrics are shown as 0)
            ax3 = fig.add_subplot(gs[1, 0])
            metrics = ['accuracy', 'auc', 'sensitivity', 'specificity', 'precision', 'f1']
            means = [cv_results.mean_metrics.get(m, 0) for m in metrics]
            stds = [cv_results.std_metrics.get(m, 0) for m in metrics]

            x = np.arange(len(metrics))
            width = 0.35

            ax3.bar(x - width / 2, means, width, label='Mean', alpha=0.8, color='steelblue')
            ax3.bar(x + width / 2, stds, width, label='Std Dev', alpha=0.8, color='coral')

            ax3.set_ylabel('Value', fontsize=11, fontweight='bold')
            ax3.set_xlabel('Metric', fontsize=11, fontweight='bold')
            ax3.set_title('Mean Performance vs Variability', fontsize=12, fontweight='bold')
            ax3.set_xticks(x)
            ax3.set_xticklabels([m.capitalize() for m in metrics], rotation=45, ha='right')
            ax3.legend(fontsize=10)
            ax3.grid(True, alpha=0.3, axis='y')

            # 4. Text summary of overfitting indicators
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.axis('off')

            summary_lines = [" OVERFITTING ANALYSIS SUMMARY", ""]
            if indicators:
                summary_lines.append(" Potential Issues Detected:")
                # Only the first 5 indicators fit in the panel
                summary_lines.extend(f"  • {indicator}" for indicator in indicators[:5])
            else:
                summary_lines.append(" No major overfitting indicators detected")

            summary_lines += ["", " Overall Assessment:"]
            good_consistency = sum(1 for c in consistency_ratings.values() if c == 'Good')
            total_consistency = len(consistency_ratings)

            if total_consistency == 0:
                # Without ratings there is nothing to assess; don't claim stability.
                summary_lines.append("  • No consistency ratings available")
            elif good_consistency == total_consistency:
                summary_lines += ["  • Excellent model stability",
                                  "  • Low variance across folds",
                                  "  • Model generalizes well"]
            elif good_consistency >= total_consistency * 0.5:
                summary_lines += ["  • Moderate model stability",
                                  "  • Some variance present",
                                  "  • Consider regularization"]
            else:
                summary_lines += ["  • High variance detected",
                                  "  • Possible overfitting",
                                  "  • Review model complexity"]

            summary_text = "\n".join(summary_lines) + "\n"

            ax4.text(0.1, 0.9, summary_text, transform=ax4.transAxes,
                     fontsize=11, verticalalignment='top', family='monospace',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create overfitting analysis plot: {e}")
            if fig is not None:
                # Don't leave a half-drawn figure registered with pyplot,
                # where a later plt.show() would still open it.
                plt.close(fig)
            return None
    
    def create_feature_importance_plot(self, feature_names: List[str],
                                     importances: np.ndarray,
                                     top_k: int = 15) -> plt.Figure:
        """Create a horizontal bar chart of the ``top_k`` most important features.

        Args:
            feature_names: Names aligned index-by-index with ``importances``;
                indices beyond its length are labelled ``Feature_<i>``.
            importances: One importance score per feature (any 1-D array-like).
            top_k: Number of features to show, capped at the number available.

        Returns:
            The figure (also appended to ``self.figures``), or ``None`` when
            there is nothing to plot (no importances or ``top_k < 1``) or
            plotting fails.
        """
        # Guard top_k < 1 explicitly: argsort(...)[-0:] would select ALL features.
        if importances is None or len(importances) == 0 or top_k < 1:
            return None

        fig = None
        try:
            # Accept lists as well as arrays (fancy indexing below needs ndarray)
            scores = np.asarray(importances, dtype=float).ravel()
            top_k = min(top_k, scores.size)
            # Ascending order, so the most important feature ends up at the top of barh
            indices = np.argsort(scores)[-top_k:]
            names = list(feature_names or [])
            top_features = [names[i] if i < len(names) else f'Feature_{i}'
                            for i in indices]
            top_importances = scores[indices]

            fig, ax = plt.subplots(figsize=(12, max(8, top_k * 0.4)))
            fig.canvas.manager.set_window_title('Feature Importance')

            colors = plt.cm.viridis(np.linspace(0, 1, len(top_importances)))
            bars = ax.barh(range(len(top_features)), top_importances, color=colors)

            ax.set_yticks(range(len(top_features)))
            ax.set_yticklabels(top_features, fontsize=10)
            ax.set_xlabel('Feature Importance', fontsize=12, fontweight='bold')
            ax.set_ylabel('Features', fontsize=12, fontweight='bold')
            ax.set_title(f'Top {top_k} Feature Importances', fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='x')

            # Add value labels at the end of each bar
            for i, (_bar, val) in enumerate(zip(bars, top_importances)):
                ax.text(val, i, f' {val:.4f}', va='center', fontsize=9)

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create feature importance plot: {e}")
            if fig is not None:
                plt.close(fig)  # avoid showing a half-drawn window later
            return None
    
    def create_confusion_matrix_plot(self, cm: np.ndarray,
                                     class_names: Optional[List[str]] = None) -> plt.Figure:
        """Create a row-normalised confusion-matrix heatmap in a window.

        Every cell is annotated with its raw count and its share of the
        true-label row.

        Args:
            cm: Confusion matrix with true labels on rows and predicted labels
                on columns (matching the axis labels drawn below).
            class_names: Tick labels. Defaults to ``['Benign', 'Malignant']``
                for a 2x2 matrix and ``Class <i>`` otherwise.

        Returns:
            The figure (also appended to ``self.figures``), or ``None`` on failure.
        """
        fig = None
        try:
            counts = np.asarray(cm)
            n_rows, n_cols = counts.shape  # expects a 2-D matrix
            if class_names is None:
                class_names = (['Benign', 'Malignant'] if counts.shape == (2, 2)
                               else [f'Class {i}' for i in range(max(n_rows, n_cols))])

            fig, ax = plt.subplots(figsize=(10, 8))
            fig.canvas.manager.set_window_title('Confusion Matrix')

            # Normalise each row by its total. Rows with no samples stay 0
            # instead of becoming NaN (which would also poison `thresh` below).
            row_sums = counts.sum(axis=1, keepdims=True).astype(float)
            cm_normalized = np.divide(counts.astype(float), row_sums,
                                      out=np.zeros(counts.shape, dtype=float),
                                      where=row_sums != 0)

            # Create heatmap
            im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues)
            ax.figure.colorbar(im, ax=ax)

            # Labels
            ax.set(xticks=np.arange(n_cols),
                   yticks=np.arange(n_rows),
                   xticklabels=class_names[:n_cols], yticklabels=class_names[:n_rows],
                   ylabel='True Label',
                   xlabel='Predicted Label')

            ax.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold', pad=20)

            # Rotate the tick labels
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

            # Text annotations; switch to white text on dark cells for contrast
            thresh = cm_normalized.max() / 2.
            for i in range(n_rows):
                for j in range(n_cols):
                    ax.text(j, i, f'{counts[i, j]}\n({cm_normalized[i, j]:.2%})',
                            ha="center", va="center",
                            color="white" if cm_normalized[i, j] > thresh else "black",
                            fontsize=12, fontweight='bold')

            fig.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as e:
            self.logger.error(f"Failed to create confusion matrix plot: {e}")
            if fig is not None:
                plt.close(fig)  # avoid showing a half-drawn window later
            return None
    
    def create_roc_curve_plot(self, fpr: np.ndarray, tpr: np.ndarray, 
                            auc_score: float) -> plt.Figure:
        """Plot a ROC curve against the chance diagonal in its own window.

        Args:
            fpr: False-positive rates (x coordinates of the curve).
            tpr: True-positive rates (y coordinates of the curve).
            auc_score: Area under the curve, shown in the legend.

        Returns:
            The figure (also appended to ``self.figures``), or ``None`` if
            anything goes wrong while drawing.
        """
        try:
            figure, axis = plt.subplots(figsize=(10, 8))
            figure.canvas.manager.set_window_title('ROC Curve')

            curve_label = f'ROC curve (AUC = {auc_score:.4f})'
            axis.plot(fpr, tpr, color='darkorange', lw=3, label=curve_label)

            # Reference line for a classifier that guesses at random
            axis.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
                      label='Random Classifier')

            # Small margin so the curve is not drawn on the frame
            axis.set_xlim((-0.05, 1.05))
            axis.set_ylim((-0.05, 1.05))

            bold_12 = dict(fontsize=12, fontweight='bold')
            axis.set_xlabel('False Positive Rate', **bold_12)
            axis.set_ylabel('True Positive Rate', **bold_12)
            axis.set_title('Receiver Operating Characteristic (ROC) Curve',
                           fontsize=14, fontweight='bold')
            axis.legend(loc="lower right", fontsize=11)
            axis.grid(True, alpha=0.3)

            plt.tight_layout()
            self.figures.append(figure)
            return figure

        except Exception as err:
            self.logger.error(f"Failed to create ROC curve plot: {err}")
            return None
    
    def create_metrics_bar_chart(self, metrics: Dict) -> plt.Figure:
        """Draw a colour-coded bar chart of the headline classification metrics.

        Metrics missing from ``metrics`` are drawn as 0. Bar colour follows the
        red-yellow-green colormap of the value, and dashed reference lines mark
        the 0.8 and 0.9 levels.

        Args:
            metrics: Mapping with (some of) the keys ``accuracy``, ``auc``,
                ``sensitivity``, ``specificity``, ``precision`` and ``f1``.

        Returns:
            The figure (also appended to ``self.figures``), or ``None`` on failure.
        """
        try:
            # (display label, metrics key) pairs, in plotting order
            metric_specs = [
                ('Accuracy', 'accuracy'),
                ('AUC', 'auc'),
                ('Sensitivity', 'sensitivity'),
                ('Specificity', 'specificity'),
                ('Precision', 'precision'),
                ('F1 Score', 'f1'),
            ]
            categories = [label for label, _ in metric_specs]
            values = [metrics.get(key, 0) for _, key in metric_specs]

            fig, ax = plt.subplots(figsize=(12, 8))
            fig.canvas.manager.set_window_title('Performance Metrics')

            bar_colors = plt.cm.RdYlGn(np.array(values))
            bars = ax.bar(categories, values, color=bar_colors, edgecolor='black',
                          linewidth=1.5, alpha=0.8)

            ax.set_ylabel('Score', fontsize=13, fontweight='bold')
            ax.set_xlabel('Metric', fontsize=13, fontweight='bold')
            ax.set_title('Model Performance Metrics', fontsize=15, fontweight='bold', pad=20)
            ax.set_ylim(0, 1.05)

            ax.grid(True, alpha=0.3, axis='y', linestyle='--')
            ax.set_axisbelow(True)

            # Value label just above each bar
            for bar, value in zip(bars, values):
                x_center = bar.get_x() + bar.get_width() / 2.
                ax.text(x_center, bar.get_height() + 0.01, f'{value:.4f}',
                        ha='center', va='bottom', fontsize=11, fontweight='bold')

            for level in (0.8, 0.9):
                ax.axhline(y=level, color='gray', linestyle='--', linewidth=1,
                           alpha=0.5, label=f'{level} threshold')

            plt.setp(ax.get_xticklabels(), rotation=0, ha="center", fontsize=11)
            ax.legend(loc='lower right', fontsize=10)

            plt.tight_layout()
            self.figures.append(fig)
            return fig

        except Exception as err:
            self.logger.error(f"Failed to create metrics bar chart: {err}")
            return None
    
    def save_and_display_plot(self, fig: plt.Figure, filepath: Path, plot_name: str) -> None:
        """Save ``fig`` as a PNG (when enabled) and register it for display.

        The figure is not shown here; windows open when ``show_all_plots()``
        is called. Failures are logged as warnings rather than raised, so one
        bad plot does not abort the experiment.

        Args:
            fig: Figure to save; ``None`` is ignored.
            filepath: Target path; its suffix is replaced by ``.png``. Missing
                parent directories are created.
            plot_name: Human-readable name used in debug logging.
        """
        if fig is None:
            return

        filepath = Path(filepath)
        try:
            if self.config.save_plots:
                png_path = filepath.with_suffix('.png')
                # The results directory may not exist yet (e.g. on a first run,
                # plots are created before the experiment results are saved).
                png_path.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(str(png_path), dpi=self.config.plot_dpi, bbox_inches='tight')
                self.logger.debug(f"   Saved: {png_path.name}")

            # Display plot - windows will show when show_all_plots() is called
            if self.config.display_plots:
                self.logger.debug(f"   Prepared: {plot_name}")

        except Exception as e:
            self.logger.warning(f"Failed to save/display plot {filepath.name}: {e}")
    
    def show_all_plots(self):
        """Open every collected figure in its own window.

        No-op unless plot display is enabled and at least one figure exists.
        ``config.block_on_plot`` decides whether this call waits until the
        windows are closed.
        """
        if not (self.config.display_plots and self.figures):
            return

        count = len(self.figures)
        self.logger.info(f"\nOpening {count} plot windows...")
        self.logger.info("   Close plot windows to continue...")
        plt.show(block=self.config.block_on_plot)

# =============================================================================
# MAIN TRAINER CLASS WITH CROSS-VALIDATION
# =============================================================================

class TabNetTrainer:
    """Comprehensive TabNet trainer with cross-validation capabilities.

    Drives a complete experiment:
      * data loading/preprocessing through ``DataProcessor``,
      * optional stratified k-fold cross-validation,
      * training of the final model,
      * test-set evaluation,
      * plotting through ``WindowPlotManager``,
      * saving of results (JSON) and model weights.
    """

    def __init__(self, model_config: TabNetConfig, data_config: DataConfig):
        """Create the trainer and prepare the compute environment.

        Args:
            model_config: Model, training and experiment settings. It may be
                modified in place by ``_setup_environment``, which caps the
                batch sizes when running on CPU.
            data_config: Data loading/preprocessing settings handed to
                ``DataProcessor``.
        """
        self.config = model_config
        self.data_config = data_config
        self.logger = Logger("TabNetTrainer")
        self.device = None
        self.model = None          # final model; set by train()
        self.feature_names = []
        self.results = {}
        self.cv_results = None     # set by train() only when CV is enabled

        self.data_processor = DataProcessor(data_config, self.logger)
        self.metrics_computer = MetricsComputer()
        self.plot_manager = WindowPlotManager(model_config, self.logger)

        self._setup_environment()

    def _setup_environment(self):
        """Seed RNGs and select the device (CUDA if available, else CPU).

        On CPU the configured batch sizes are capped (batch <= 64,
        virtual batch <= 32) by mutating ``self.config``.
        """
        set_random_seeds(self.config.random_seed)
        self.logger.info(f" Random seeds set to {self.config.random_seed}")

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            # Deterministic cuDNN kernels trade speed for reproducibility
            torch.backends.cudnn.deterministic = self.config.deterministic
            torch.backends.cudnn.benchmark = not self.config.deterministic
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
            self.logger.info(f" Using GPU: {gpu_name} ({gpu_memory:.1f} GB)")
        else:
            self.device = torch.device("cpu")
            self.logger.info("Using CPU")

            self.config.batch_size = min(self.config.batch_size, 64)
            self.config.virtual_batch_size = min(self.config.virtual_batch_size, 32)
            self.logger.info(f"   Adjusted batch sizes for CPU: batch={self.config.batch_size}, virtual={self.config.virtual_batch_size}")

    def _create_model(self) -> TabNetClassifier:
        """Create an untrained ``TabNetClassifier`` from ``self.config``.

        Raises:
            RuntimeError: If pytorch-tabnet is not installed.
        """
        if not _HAS_TABNET:
            raise RuntimeError("pytorch-tabnet is not available. Install with: pip install pytorch-tabnet")

        model = TabNetClassifier(
            n_d=self.config.n_d,
            n_a=self.config.n_a,
            n_steps=self.config.n_steps,
            gamma=self.config.gamma,
            lambda_sparse=self.config.lambda_sparse,
            momentum=self.config.momentum,
            n_independent=self.config.n_independent,
            n_shared=self.config.n_shared,
            optimizer_fn=torch.optim.Adam,
            optimizer_params={"lr": self.config.lr, "weight_decay": 1e-5},
            scheduler_fn=torch.optim.lr_scheduler.StepLR,
            scheduler_params={"step_size": 10, "gamma": 0.9},
            device_name=str(self.device),
            verbose=0,  # Reduce verbosity for CV
            seed=self.config.random_seed
        )

        return model

    def _fit_model(self, model: TabNetClassifier,
                   X_train: np.ndarray, y_train: np.ndarray,
                   X_val: np.ndarray, y_val: np.ndarray,
                   timer_label: str) -> None:
        """Fit ``model`` with early stopping on (X_val, y_val), timing the call.

        Shared by the CV folds and the standard training path so both use
        exactly the same fit settings.
        """
        with timer(self.logger, timer_label):
            model.fit(
                X_train, y_train,
                eval_set=[(X_val, y_val)],
                eval_name=["val"],
                eval_metric=["auc", "accuracy"],
                max_epochs=self.config.max_epochs,
                patience=self.config.patience,
                batch_size=self.config.batch_size,
                virtual_batch_size=self.config.virtual_batch_size,
                drop_last=False,
                num_workers=0,
            )

    def train_single_fold(self, X_train: np.ndarray, X_val: np.ndarray, 
                         y_train: np.ndarray, y_val: np.ndarray,
                         fold_idx: Optional[int] = None) -> Tuple[TabNetClassifier, Dict]:
        """Train a fresh model on one train/validation split.

        Args:
            X_train, y_train: Training data.
            X_val, y_val: Validation data used for early stopping.
            fold_idx: Zero-based fold index used only in log messages;
                ``None`` labels the run as "Single".

        Returns:
            ``(model, history)`` where history is ``model.history`` (or ``{}``).

        Raises:
            Exception: Any training error is logged and re-raised.
        """
        model = self._create_model()

        fold_str = f"Fold {fold_idx + 1}" if fold_idx is not None else "Single"

        try:
            self._fit_model(model, X_train, y_train, X_val, y_val,
                            f"{fold_str} training")

            history = model.history or {}
            self.logger.info(f"   {fold_str} completed at epoch {model.best_epoch}")

            return model, history

        except Exception as e:
            self.logger.error(f"   {fold_str} training failed: {e}")
            raise

    def run_cross_validation(self, X: np.ndarray, y: np.ndarray) -> CVResults:
        """Perform stratified k-fold cross-validation on (X, y).

        Each fold fits its own ``StandardScaler`` on the fold's training part
        only, so no statistics leak from the validation part. Per-fold metrics
        are collected in a ``CVResults`` object, summarised and logged along
        with the overfitting analysis.

        Args:
            X: Feature matrix (rows are samples).
            y: Labels aligned with ``X``.

        Returns:
            The populated ``CVResults`` (statistics already computed).
        """
        with self.logger.log_section(f"{self.config.n_folds}-FOLD CROSS-VALIDATION"):
            cv_results = CVResults()

            skf = StratifiedKFold(
                n_splits=self.config.n_folds,
                shuffle=True,
                random_state=self.config.cv_random_state
            )

            for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X, y)):
                self.logger.info(f"\n{'='*60}")
                self.logger.info(f"  FOLD {fold_idx + 1}/{self.config.n_folds}")
                self.logger.info(f"{'='*60}")

                # Split data
                X_train_fold, X_val_fold = X[train_idx], X[val_idx]
                y_train_fold, y_val_fold = y[train_idx], y[val_idx]

                # Scale features (fit on the training part of this fold only)
                scaler = StandardScaler()
                X_train_fold = scaler.fit_transform(X_train_fold).astype(np.float32)
                X_val_fold = scaler.transform(X_val_fold).astype(np.float32)

                self.logger.info(f"   Train: {len(y_train_fold)} samples")
                self.logger.info(f"   Val:   {len(y_val_fold)} samples")

                model, history = self.train_single_fold(
                    X_train_fold, X_val_fold, y_train_fold, y_val_fold, fold_idx
                )

                # Evaluate on validation set; column 1 = positive-class probability
                y_pred = model.predict(X_val_fold)
                y_proba = model.predict_proba(X_val_fold)[:, 1]

                metrics = self.metrics_computer.compute_all_metrics(
                    y_val_fold, y_pred, y_proba
                )

                self.logger.info(f"\n   Fold {fold_idx + 1} Results:")
                self.logger.info(f"   Accuracy: {metrics['accuracy']:.4f}")
                self.logger.info(f"   AUC:      {metrics['auc']:.4f}")
                self.logger.info(f"   F1:       {metrics['f1']:.4f}")

                cv_results.add_fold_result(fold_idx, metrics, history)

                # Release the fold model/data before the next fold
                del model, X_train_fold, X_val_fold
                clear_memory()

            cv_results.compute_statistics()

            # Log summary
            self.logger.info("\n" + "="*70)
            self.logger.info("  CROSS-VALIDATION SUMMARY")
            self.logger.info("="*70)

            for metric, mean_val in cv_results.mean_metrics.items():
                std_val = cv_results.std_metrics.get(metric, 0)
                self.logger.info(f"   {metric.capitalize():12s}: {mean_val:.4f} ± {std_val:.4f}")

            self.logger.info(f"\n   Best Fold: {cv_results.best_fold + 1} "
                           f"(AUC: {cv_results.fold_metrics[cv_results.best_fold]['auc']:.4f})")

            # Overfitting analysis
            analysis = cv_results.get_overfitting_analysis()

            self.logger.info("\n OVERFITTING ANALYSIS:")
            if analysis['overfitting_indicators']:
                self.logger.warning("Potential issues detected:")
                for indicator in analysis['overfitting_indicators']:
                    self.logger.warning(f"      • {indicator}")
            else:
                self.logger.info("No major overfitting indicators detected")

            self.logger.info("\n   Consistency Ratings:")
            for metric, rating in analysis['consistency'].items():
                self.logger.info(f"       {metric.capitalize()}: {rating}")

            return cv_results

    def train(self, X_train: np.ndarray, X_val: np.ndarray, 
             y_train: np.ndarray, y_val: np.ndarray) -> Dict[str, Any]:
        """Train the TabNet model (standard or with CV).

        With ``config.use_cross_validation`` the train and validation sets are
        combined for k-fold CV (results stored in ``self.cv_results``). The
        final model is then fitted on ``X_train`` with ``X_val`` used for
        early stopping. Without CV only that final fit is performed.

        Returns:
            The final model's training history (``{}`` if unavailable).
        """
        with self.logger.log_section("MODEL TRAINING"):

            if self.config.use_cross_validation:
                # Combine train and val for CV
                X_combined = np.vstack([X_train, X_val])
                y_combined = np.concatenate([y_train, y_val])

                self.logger.info("Using Cross-Validation mode")
                self.cv_results = self.run_cross_validation(X_combined, y_combined)

                # Final model: fit on X_train, early-stop on X_val (the
                # validation split is still held out for early stopping).
                self.logger.info("\n" + "="*70)
                self.logger.info("  TRAINING FINAL MODEL ON FULL TRAINING SET")
                self.logger.info("="*70)

                self.model, history = self.train_single_fold(X_train, X_val, y_train, y_val)

                return history

            self.logger.info(" Using standard train/val split")
            self.model = self._create_model()

            self.logger.info("Training Configuration:")
            self.logger.info(f"   Max epochs: {self.config.max_epochs}")
            self.logger.info(f"   Patience: {self.config.patience}")
            self.logger.info(f"   Batch size: {self.config.batch_size}")
            self.logger.info(f"   Virtual batch size: {self.config.virtual_batch_size}")

            try:
                self._fit_model(self.model, X_train, y_train, X_val, y_val,
                                "Model training")

                self.logger.info(f"Training completed at epoch {self.model.best_epoch}")
                return self.model.history or {}

            except Exception as e:
                self.logger.error(f"Training failed: {e}")
                raise

    def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, Any]:
        """Evaluate the trained model on the test set.

        Logs a metrics summary and prints sklearn's classification report.

        Returns:
            The metrics dict produced by ``MetricsComputer.compute_all_metrics``.

        Raises:
            RuntimeError: If called before ``train()``.
        """
        if self.model is None:
            raise RuntimeError("Model must be trained before evaluation")

        with self.logger.log_section("MODEL EVALUATION ON TEST SET"):
            with timer(self.logger, "Model evaluation"):
                y_pred = self.model.predict(X_test)
                y_proba = self.model.predict_proba(X_test)[:, 1]
                metrics = self.metrics_computer.compute_all_metrics(y_test, y_pred, y_proba)

            self.logger.info(self.metrics_computer.format_metrics_summary(metrics, "TEST SET "))

            self.logger.info("\n CLASSIFICATION REPORT:")
            print(classification_report(y_test, y_pred, 
                                      target_names=['Benign', 'Malignant'], 
                                      digits=4))

            return metrics

    def create_visualizations(self, metrics: Dict, history: Dict, 
                            X_train: np.ndarray, y_train: np.ndarray, 
                            save_dir: Path) -> None:
        """Create, save and display all plots (including CV plots when available).

        ``X_train`` and ``y_train`` are accepted for interface compatibility
        but are not used by any of the current plots.
        """
        with self.logger.log_section("GENERATING VISUALIZATIONS"):
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Plots are created before save_experiment(), so make sure the
            # output directory exists or every savefig() would fail.
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

            # Cross-validation plots (if CV was used)
            if self.cv_results is not None:
                self.logger.info("Creating Cross-Validation plots...")

                cv_comparison_fig = self.plot_manager.create_cv_fold_comparison_plot(self.cv_results)
                if cv_comparison_fig:
                    self.plot_manager.save_and_display_plot(
                        cv_comparison_fig,
                        save_dir / f"cv_fold_comparison_{timestamp}",
                        "CV Fold Comparison"
                    )

                cv_boxplot_fig = self.plot_manager.create_cv_boxplot(self.cv_results)
                if cv_boxplot_fig:
                    self.plot_manager.save_and_display_plot(
                        cv_boxplot_fig,
                        save_dir / f"cv_metrics_boxplot_{timestamp}",
                        "CV Metrics Boxplot"
                    )

                overfitting_fig = self.plot_manager.create_overfitting_analysis_plot(self.cv_results)
                if overfitting_fig:
                    self.plot_manager.save_and_display_plot(
                        overfitting_fig,
                        save_dir / f"overfitting_analysis_{timestamp}",
                        "Overfitting Analysis"
                    )

            self.logger.info("Creating standard training plots...")

            # 1. Training History
            history_fig = self.plot_manager.create_training_history_plot(history)
            if history_fig:
                self.plot_manager.save_and_display_plot(
                    history_fig, 
                    save_dir / f"training_history_{timestamp}",
                    "Training History"
                )

            # 2. Feature Importance
            if self.model and hasattr(self.model, 'feature_importances_'):
                feature_fig = self.plot_manager.create_feature_importance_plot(
                    self.feature_names, self.model.feature_importances_
                )
                if feature_fig:
                    self.plot_manager.save_and_display_plot(
                        feature_fig, 
                        save_dir / f"feature_importance_{timestamp}",
                        "Feature Importance"
                    )

            # 3. Confusion Matrix
            if 'confusion_matrix' in metrics:
                cm_fig = self.plot_manager.create_confusion_matrix_plot(metrics['confusion_matrix'])
                if cm_fig:
                    self.plot_manager.save_and_display_plot(
                        cm_fig, 
                        save_dir / f"confusion_matrix_{timestamp}",
                        "Confusion Matrix"
                    )

            # 4. ROC Curve
            if 'roc_curve' in metrics:
                fpr, tpr = metrics['roc_curve']
                roc_fig = self.plot_manager.create_roc_curve_plot(fpr, tpr, metrics['auc'])
                if roc_fig:
                    self.plot_manager.save_and_display_plot(
                        roc_fig, 
                        save_dir / f"roc_curve_{timestamp}",
                        "ROC Curve"
                    )

            # 5. Metrics Bar Chart
            metrics_fig = self.plot_manager.create_metrics_bar_chart(metrics)
            if metrics_fig:
                self.plot_manager.save_and_display_plot(
                    metrics_fig, 
                    save_dir / f"metrics_barchart_{timestamp}",
                    "Metrics Bar Chart"
                )

            # Show all plots in windows
            self.plot_manager.show_all_plots()

            self.logger.info("All visualizations generated and displayed")

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Convert NumPy/pathlib values that ``json`` cannot serialise natively.

        Raises:
            TypeError: For any other unsupported type, as ``json`` itself would.
        """
        if isinstance(obj, np.generic):      # np.int64, np.float32, np.bool_, ...
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, Path):
            return str(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_experiment(self, metrics: Dict, history: Dict, save_dir: Path) -> None:
        """Save experiment results (JSON) and, if trained, the model weights.

        The JSON holds configs, test metrics (without the confusion matrix and
        ROC arrays, which are stored/omitted separately), feature names,
        parameter counts and, when CV was run, the CV summary. ``history`` is
        currently not persisted.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        experiment_data = {
            'timestamp': datetime.now().isoformat(),
            'config': {
                'model': asdict(self.config),
                'data': asdict(self.data_config)
            },
            'metrics': {k: v for k, v in metrics.items() 
                       if k not in ['confusion_matrix', 'roc_curve']},
            'confusion_matrix': np.asarray(metrics.get('confusion_matrix', [])).tolist(),
            'feature_names': self.feature_names,
            'training_info': {
                'best_epoch': getattr(self.model, 'best_epoch', -1) if self.model else -1,
                'total_params': sum(p.numel() for p in self.model.network.parameters()) if self.model else 0,
                'trainable_params': sum(p.numel() for p in self.model.network.parameters() 
                                      if p.requires_grad) if self.model else 0
            }
        }

        if self.cv_results is not None:
            experiment_data['cross_validation'] = {
                'n_folds': self.config.n_folds,
                'mean_metrics': self.cv_results.mean_metrics,
                'std_metrics': self.cv_results.std_metrics,
                'best_fold': self.cv_results.best_fold,
                'fold_metrics': [{k: v for k, v in fold.items() 
                                 if k not in ['confusion_matrix', 'roc_curve']} 
                                for fold in self.cv_results.fold_metrics],
                'overfitting_analysis': self.cv_results.get_overfitting_analysis()
            }

        results_path = save_dir / f"experiment_results_{timestamp}.json"
        with open(results_path, 'w', encoding='utf-8') as f:
            # Metric values may be NumPy scalars (e.g. np.int64), which the
            # stdlib encoder rejects; convert them explicitly.
            json.dump(experiment_data, f, indent=2, default=self._json_default)

        if self.model:
            model_path = save_dir / f"tabnet_model_{timestamp}.pth"
            torch.save(self.model.network.state_dict(), model_path)
            self.logger.info(f" Model saved: {model_path.name}")

        self.logger.info(f"Results saved: {results_path.name}")

    def run_experiment(self) -> Dict[str, Any]:
        """Run the full pipeline: load, train (optionally with CV), evaluate,
        plot and save.

        Returns:
            Dict with ``metrics``, ``history``, ``feature_names`` and
            ``cv_results`` (``None`` when CV is disabled).

        Raises:
            Exception: Any failure is logged and re-raised; memory is cleared
                in every case.
        """
        with self.logger.log_section("TABNET TRAINING EXPERIMENT"):
            try:
                # 1. Load and preprocess data
                with self.logger.log_section("DATA LOADING & PREPROCESSING"):
                    X_train, X_val, X_test, y_train, y_val, y_test = \
                        self.data_processor.load_and_preprocess(
                            Path(self.config.dataset_path),
                            self.config.random_seed
                        )
                    self.feature_names = self.data_processor.feature_names

                # 2. Train model (with or without CV)
                history = self.train(X_train, X_val, y_train, y_val)

                # 3. Evaluate model on test set
                metrics = self.evaluate(X_test, y_test)

                # 4. Create visualizations
                results_dir = Path(self.config.results_dir)
                self.create_visualizations(metrics, history, X_train, y_train, results_dir)

                # 5. Save experiment
                self.save_experiment(metrics, history, results_dir)

                self.results = {
                    'metrics': metrics,
                    'history': history,
                    'feature_names': self.feature_names,
                    'cv_results': self.cv_results
                }

                with self.logger.log_section("EXPERIMENT COMPLETED SUCCESSFULLY"):
                    if self.cv_results:
                        self.logger.info(" Cross-Validation Results:")
                        self.logger.info(f"   Mean AUC:      {self.cv_results.mean_metrics.get('auc', 0):.4f} "
                                       f"± {self.cv_results.std_metrics.get('auc', 0):.4f}")
                        self.logger.info(f"   Mean Accuracy: {self.cv_results.mean_metrics.get('accuracy', 0):.4f} "
                                       f"± {self.cv_results.std_metrics.get('accuracy', 0):.4f}")

                    self.logger.info(f" Final Test Accuracy: {metrics['accuracy']:.4f}")
                    self.logger.info(f" Final Test AUC-ROC:  {metrics['auc']:.4f}")
                    self.logger.info(f" Results saved in: {results_dir}/")

                return self.results

            except Exception as e:
                self.logger.error(f" Experiment failed: {e}")
                raise
            finally:
                clear_memory()

# =============================================================================
# MAIN EXECUTION
# =============================================================================

def run_tabnet_training(dataset_path: str = "C:/Users/awwal/Desktop/MLEA_experiments/data.csv",
                       results_dir: str = "tabnet_results_cv",
                       use_cross_validation: bool = True,
                       n_folds: int = 5,
                       display_plots: bool = True,
                       block_on_plot: bool = False,
                       *,
                       target_column: str = "diagnosis",
                       **model_overrides: Any) -> Dict[str, Any]:
    """
    Run TabNet training experiment with optional cross-validation.
    
    Args:
        dataset_path: Path to CSV dataset. NOTE(review): the default is an
            absolute path on one developer's machine; pass an explicit path
            elsewhere.
        results_dir: Directory to save results
        use_cross_validation: Whether to use k-fold cross-validation
        n_folds: Number of folds for cross-validation
        display_plots: Whether to display plots in windows
        block_on_plot: Whether to pause execution until windows are closed
        target_column: Name of the label column passed to ``DataConfig``
            (keyword-only; defaults to ``"diagnosis"`` as before).
        **model_overrides: Extra ``TabNetConfig`` fields that replace the
            built-in defaults below (e.g. ``max_epochs=100, lr=0.01``).
            Unknown field names are rejected by ``TabNetConfig`` itself.
        
    Returns:
        Dictionary containing metrics, history, feature names, and CV results
    
    Raises:
        ImportError: If pytorch-tabnet is not installed.
    """
    print("\n" + "=" * 70)
    print("  TABNET TRAINING WITH CROSS-VALIDATION & OVERFITTING DETECTION")
    print("=" * 70)
    
    if not _HAS_TABNET:
        print("\n pytorch-tabnet is required but not installed.")
        print("   Install with: pip install pytorch-tabnet")
        raise ImportError("pytorch-tabnet not available")
    
    # Default hyperparameters; callers may override any of them through
    # **model_overrides without editing this function.
    model_kwargs: Dict[str, Any] = dict(
        max_epochs=50,
        patience=10,
        n_d=16,
        n_a=16,
        n_steps=5,
        gamma=1.5,
        lambda_sparse=0.001,
        lr=0.02,
        save_plots=True,
        plot_dpi=100,
    )
    model_kwargs.update(model_overrides)
    
    # Explicit parameters are bound by name above, so they can never collide
    # with keys in model_kwargs.
    model_config = TabNetConfig(
        dataset_path=dataset_path,
        results_dir=results_dir,
        use_cross_validation=use_cross_validation,
        n_folds=n_folds,
        display_plots=display_plots,
        block_on_plot=block_on_plot,
        **model_kwargs,
    )
    
    data_config = DataConfig(
        target_column=target_column,
        test_size=0.2,
        validation_size=0.1,
        scale_features=True,
        handle_missing="mean"
    )
    
    # Create and run trainer
    trainer = TabNetTrainer(model_config, data_config)
    return trainer.run_experiment()


def main() -> int:
    """Run the default TabNet experiment for command-line execution.

    Returns:
        Process exit status: ``0`` if the experiment finished, ``1`` if it
        raised. Both the failure message and the full traceback are written
        to stderr so they stay together in the output.
    """
    try:
        run_tabnet_training()
    except Exception as e:  # top-level boundary: report and map to exit code
        import traceback
        print(f"\n Experiment failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    # Detect IPython/Jupyter on its own: `get_ipython` exists only when
    # IPython injects it. Keeping this probe separate from the experiment
    # ensures a NameError raised *inside* training is not mistaken for
    # "not running in IPython". That mistake would re-run the experiment via
    # main() and call sys.exit inside the notebook kernel.
    try:
        get_ipython()  # noqa: F821 - provided by IPython at runtime
        _in_ipython = True
    except NameError:
        _in_ipython = False

    if not _in_ipython:
        sys.exit(main())
    else:
        print("\n🔬 Running in Jupyter/IPython environment")
        results = run_tabnet_training(block_on_plot=False)
        print("\n Experiment completed! Results stored in 'results' variable")

        # Display CV summary if available
        cv = results.get('cv_results')
        if cv:
            print("\nCross-Validation Summary:")
            for metric in ['accuracy', 'auc', 'f1']:
                if metric in cv.mean_metrics:
                    print(f"   {metric.capitalize()}: {cv.mean_metrics[metric]:.4f} "
                          f"± {cv.std_metrics.get(metric, 0):.4f}")


🔬 Running in Jupyter/IPython environment

  TABNET TRAINING WITH CROSS-VALIDATION & OVERFITTING DETECTION
2025-11-12 17:38:04 [INFO] TabNetTrainer:  Random seeds set to 42
2025-11-12 17:38:04 [INFO] TabNetTrainer: Using CPU
2025-11-12 17:38:04 [INFO] TabNetTrainer:    Adjusted batch sizes for CPU: batch=64, virtual=32
2025-11-12 17:38:04 [INFO] TabNetTrainer:  TABNET TRAINING EXPERIMENT
2025-11-12 17:38:04 [INFO] TabNetTrainer:  DATA LOADING & PREPROCESSING
2025-11-12 17:38:04 [INFO] TabNetTrainer: Starting: Loading dataset from C:\Users\awwal\Desktop\MLEA_experiments\data.csv
2025-11-12 17:38:04 [INFO] TabNetTrainer:    Shape: (569, 33)
2025-11-12 17:38:04 [INFO] TabNetTrainer: ✓ Completed: Loading dataset from C:\Users\awwal\Desktop\MLEA_experiments\data.csv (0.02s)
2025-11-12 17:38:04 [INFO] TabNetTrainer:    Class distribution: Benign=357 (62.7%), Malignant=212 (37.3%)
2025-11-12 17:38:04 [INFO] TabNetTrainer:    Features: 30
2025-11-12 17:38:04 [INFO] TabNetTrainer:    Train