In [9]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix,
                            precision_score, recall_score, roc_curve, auc,
                            matthews_corrcoef, balanced_accuracy_score, cohen_kappa_score)
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datetime import datetime
import time
import os

os.makedirs('visualization', exist_ok=True)
os.makedirs('models', exist_ok=True)

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

BATCH_SIZE = 16
EPOCHS = 150
LEARNING_RATE = 0.0005
WEIGHT_DECAY = 5e-4
SELECTED_FEATURES = 150
NUM_BANDS = 5

CHANNEL_COORDS = {
    1: (0.75, 0.15), 2: (0.65, 0.2), 3: (0.55, 0.25), 4: (0.45, 0.3), 5: (0.35, 0.2),
    6: (0.25, 0.2), 7: (0.4, 0.35), 8: (0.5, 0.15), 9: (0.45, 0.25), 10: (0.35, 0.15),
    11: (0.3, 0.25), 12: (0.4, 0.3), 13: (0.25, 0.25), 14: (0.35, 0.3), 15: (0.3, 0.35),
    16: (0.35, 0.4), 17: (0.2, 0.2), 18: (0.25, 0.3), 19: (0.2, 0.3), 20: (0.3, 0.4),
    21: (0.4, 0.45), 22: (0.25, 0.4), 23: (0.15, 0.35), 24: (0.2, 0.4), 25: (0.25, 0.45),
    26: (0.3, 0.5), 27: (0.2, 0.5), 28: (0.25, 0.55), 29: (0.15, 0.55), 30: (0.2, 0.6),
    31: (0.35, 0.55), 32: (0.15, 0.65), 33: (0.3, 0.65), 34: (0.45, 0.35), 35: (0.4, 0.7),
    36: (0.35, 0.6), 37: (0.4, 0.65), 38: (0.35, 0.65), 39: (0.45, 0.7), 40: (0.45, 0.6),
    41: (0.55, 0.45), 42: (0.55, 0.6), 43: (0.6, 0.65), 44: (0.65, 0.65), 45: (0.6, 0.55),
    46: (0.55, 0.5), 47: (0.65, 0.6), 48: (0.55, 0.55), 49: (0.6, 0.4), 50: (0.6, 0.35),
    51: (0.55, 0.4), 52: (0.65, 0.45), 53: (0.65, 0.35), 54: (0.55, 0.35), 55: (0.7, 0.4),
    56: (0.65, 0.3), 57: (0.6, 0.3), 58: (0.7, 0.25), 59: (0.65, 0.25), 60: (0.55, 0.3),
    61: (0.75, 0.2), 62: (0.5, 0.1), 63: (0.45, 0.05), 64: (0.15, 0.2)
}

class MetricsTracker:
    def __init__(self, fold, epochs):
        self.fold = fold
        self.epochs = epochs
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.train_f1s = []
        self.val_f1s = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.epoch_nums = []
        self.overfitting_ratios = []
        self.mccs = []
        self.balanced_accs = []
        self.kappas = []

    def update(self, epoch, train_metrics, val_metrics):
        self.epoch_nums.append(epoch)
        self.train_losses.append(train_metrics['loss'])
        self.val_losses.append(val_metrics['loss'])
        self.train_accs.append(train_metrics['accuracy'])
        self.val_accs.append(val_metrics['accuracy'])
        self.train_f1s.append(train_metrics['f1'])
        self.val_f1s.append(val_metrics['f1'])
        self.train_precisions.append(train_metrics['precision'])
        self.val_precisions.append(val_metrics['precision'])
        self.train_recalls.append(train_metrics['recall'])
        self.val_recalls.append(val_metrics['recall'])
        self.mccs.append(val_metrics['mcc'])
        self.balanced_accs.append(val_metrics['balanced_acc'])
        self.kappas.append(val_metrics['kappa'])

        ratio = val_metrics['loss'] / train_metrics['loss'] if train_metrics['loss'] > 0 else 1.0
        self.overfitting_ratios.append(ratio)

    def plot_metrics(self):
        fig, axes = plt.subplots(3, 2, figsize=(16, 18))

        axes[0, 0].plot(self.epoch_nums, self.train_losses, 'b-', label='Training Loss')
        axes[0, 0].plot(self.epoch_nums, self.val_losses, 'r-', label='Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title(f'Fold {self.fold}: Loss Curves')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        axes[0, 1].plot(self.epoch_nums, self.train_accs, 'b-', label='Training Accuracy')
        axes[0, 1].plot(self.epoch_nums, self.val_accs, 'r-', label='Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title(f'Fold {self.fold}: Accuracy Curves')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        axes[1, 0].plot(self.epoch_nums, self.train_f1s, 'b-', label='Training F1')
        axes[1, 0].plot(self.epoch_nums, self.val_f1s, 'r-', label='Validation F1')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title(f'Fold {self.fold}: F1 Score Curves')
        axes[1, 0].legend()
        axes[1, 0].grid(True)

        axes[1, 1].plot(self.epoch_nums, self.train_precisions, 'b-', label='Training Precision')
        axes[1, 1].plot(self.epoch_nums, self.val_precisions, 'r-', label='Validation Precision')
        axes[1, 1].plot(self.epoch_nums, self.train_recalls, 'g-', label='Training Recall')
        axes[1, 1].plot(self.epoch_nums, self.val_recalls, 'y-', label='Validation Recall')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title(f'Fold {self.fold}: Precision & Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        axes[2, 0].plot(self.epoch_nums, self.overfitting_ratios, 'b-', label='Overfitting Ratio')
        axes[2, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal Ratio = 1.0')
        axes[2, 0].axhline(y=1.5, color='orange', linestyle='--', label='Warning Threshold = 1.5')
        axes[2, 0].set_xlabel('Epoch')
        axes[2, 0].set_ylabel('Ratio')
        axes[2, 0].set_title(f'Fold {self.fold}: Overfitting Ratio')
        axes[2, 0].legend()
        axes[2, 0].grid(True)

        axes[2, 1].plot(self.epoch_nums, self.mccs, 'b-', label='MCC')
        axes[2, 1].plot(self.epoch_nums, self.balanced_accs, 'r-', label='Balanced Accuracy')
        axes[2, 1].plot(self.epoch_nums, self.kappas, 'g-', label='Cohen\'s Kappa')
        axes[2, 1].set_xlabel('Epoch')
        axes[2, 1].set_ylabel('Score')
        axes[2, 1].set_title(f'Fold {self.fold}: Advanced Metrics')
        axes[2, 1].legend()
        axes[2, 1].grid(True)

        plt.tight_layout()
        plt.savefig(f'visualization/fold_{self.fold}_metrics.png')
        plt.close()

    def analyze_overfitting(self):
        final_ratio = self.overfitting_ratios[-1] if self.overfitting_ratios else 1.0
        avg_final_train_acc = np.mean(self.train_accs[-5:]) if len(self.train_accs) >= 5 else np.mean(self.train_accs)
        avg_final_val_acc = np.mean(self.val_accs[-5:]) if len(self.val_accs) >= 5 else np.mean(self.val_accs)
        generalization_gap = avg_final_train_acc - avg_final_val_acc

        overfitting_score = 0
        if final_ratio > 1.5:
            overfitting_score += 1
        if generalization_gap > 0.05:
            overfitting_score += 1
        if avg_final_train_acc > 0.98 and avg_final_val_acc < 0.95:
            overfitting_score += 1

        if overfitting_score == 0:
            return "No signs of overfitting detected"
        elif overfitting_score == 1:
            return "Mild signs of overfitting detected"
        elif overfitting_score == 2:
            return "Moderate overfitting detected"
        else:
            return "Severe overfitting detected"

def calculate_metrics(model, data_loader, criterion, device):
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            probs = outputs.cpu().numpy()
            preds = (outputs > 0.5).float()

            all_probs.extend(probs)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds)
    recall = recall_score(all_targets, all_preds)
    mcc = matthews_corrcoef(all_targets, all_preds)
    balanced_acc = balanced_accuracy_score(all_targets, all_preds)
    kappa = cohen_kappa_score(all_targets, all_preds)

    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'balanced_acc': balanced_acc,
        'kappa': kappa,
        'predictions': all_preds,
        'targets': all_targets,
        'probabilities': all_probs
    }

    return metrics

def plot_cross_val_metrics(fold_monitors):
    max_epoch = max([max(monitor.epoch_nums) for monitor in fold_monitors]) if fold_monitors else 0
    valid_epochs = []
    metrics_avg = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': [],
        'train_precision': [], 'val_precision': [],
        'train_recall': [], 'val_recall': [],
        'mcc': [], 'balanced_acc': [], 'kappa': [],
        'overfitting_ratio': []
    }

    counts = np.zeros(max_epoch)
    for i in range(1, max_epoch + 1):
        count = 0
        temp_metrics = {k: 0 for k in metrics_avg.keys()}

        for monitor in fold_monitors:
            if i in monitor.epoch_nums:
                idx = monitor.epoch_nums.index(i)
                temp_metrics['train_loss'] += monitor.train_losses[idx]
                temp_metrics['val_loss'] += monitor.val_losses[idx]
                temp_metrics['train_acc'] += monitor.train_accs[idx]
                temp_metrics['val_acc'] += monitor.val_accs[idx]
                temp_metrics['train_f1'] += monitor.train_f1s[idx]
                temp_metrics['val_f1'] += monitor.val_f1s[idx]
                temp_metrics['train_precision'] += monitor.train_precisions[idx]
                temp_metrics['val_precision'] += monitor.val_precisions[idx]
                temp_metrics['train_recall'] += monitor.train_recalls[idx]
                temp_metrics['val_recall'] += monitor.val_recalls[idx]
                temp_metrics['mcc'] += monitor.mccs[idx]
                temp_metrics['balanced_acc'] += monitor.balanced_accs[idx]
                temp_metrics['kappa'] += monitor.kappas[idx]
                temp_metrics['overfitting_ratio'] += monitor.overfitting_ratios[idx]
                count += 1

        if count > 0:
            valid_epochs.append(i)
            for k in metrics_avg.keys():
                metrics_avg[k].append(temp_metrics[k] / count)
            counts[i-1] = count

    if valid_epochs:
        fig, axes = plt.subplots(3, 2, figsize=(16, 18))

        axes[0, 0].plot(valid_epochs, metrics_avg['train_loss'], 'b-', label='Avg Training Loss')
        axes[0, 0].plot(valid_epochs, metrics_avg['val_loss'], 'r-', label='Avg Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Average Loss Across Folds')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        axes[0, 1].plot(valid_epochs, metrics_avg['train_acc'], 'b-', label='Avg Training Accuracy')
        axes[0, 1].plot(valid_epochs, metrics_avg['val_acc'], 'r-', label='Avg Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title('Average Accuracy Across Folds')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        axes[1, 0].plot(valid_epochs, metrics_avg['train_f1'], 'b-', label='Avg Training F1')
        axes[1, 0].plot(valid_epochs, metrics_avg['val_f1'], 'r-', label='Avg Validation F1')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title('Average F1 Score Across Folds')
        axes[1, 0].legend()
        axes[1, 0].grid(True)

        axes[1, 1].plot(valid_epochs, metrics_avg['train_precision'], 'b-', label='Avg Training Precision')
        axes[1, 1].plot(valid_epochs, metrics_avg['val_precision'], 'r-', label='Avg Validation Precision')
        axes[1, 1].plot(valid_epochs, metrics_avg['train_recall'], 'g-', label='Avg Training Recall')
        axes[1, 1].plot(valid_epochs, metrics_avg['val_recall'], 'y-', label='Avg Validation Recall')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title('Average Precision & Recall Across Folds')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        axes[2, 0].plot(valid_epochs, metrics_avg['overfitting_ratio'], 'b-', label='Avg Overfitting Ratio')
        axes[2, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal Ratio = 1.0')
        axes[2, 0].axhline(y=1.5, color='orange', linestyle='--', label='Warning Threshold = 1.5')
        axes[2, 0].set_xlabel('Epoch')
        axes[2, 0].set_ylabel('Ratio')
        axes[2, 0].set_title('Average Overfitting Ratio Across Folds')
        axes[2, 0].legend()
        axes[2, 0].grid(True)

        axes[2, 1].plot(valid_epochs, metrics_avg['mcc'], 'b-', label='Avg MCC')
        axes[2, 1].plot(valid_epochs, metrics_avg['balanced_acc'], 'r-', label='Avg Balanced Accuracy')
        axes[2, 1].plot(valid_epochs, metrics_avg['kappa'], 'g-', label='Avg Cohen\'s Kappa')
        axes[2, 1].set_xlabel('Epoch')
        axes[2, 1].set_ylabel('Score')
        axes[2, 1].set_title('Average Advanced Metrics Across Folds')
        axes[2, 1].legend()
        axes[2, 1].grid(True)

        plt.tight_layout()
        plt.savefig('visualization/cross_validation_metrics.png')
        plt.close()

def plot_confusion_matrices(fold_cms, final_cm, class_names=['Negative', 'Positive']):
    n_folds = len(fold_cms)
    fig, axes = plt.subplots(1, n_folds + 1, figsize=(5 * (n_folds + 1), 5))

    for i, cm in enumerate(fold_cms):
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        sns.heatmap(cm_normalized, annot=cm, fmt="d", cmap="Blues", ax=axes[i],
                   xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 16})
        axes[i].set_title(f'Fold {i+1} Confusion Matrix')
        axes[i].set_xlabel('Predicted Label')
        axes[i].set_ylabel('True Label')

    final_cm_normalized = final_cm.astype('float') / final_cm.sum(axis=1)[:, np.newaxis]
    sns.heatmap(final_cm_normalized, annot=final_cm, fmt="d", cmap="Blues", ax=axes[-1],
               xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 16})
    axes[-1].set_title('Ensemble Model Confusion Matrix')
    axes[-1].set_xlabel('Predicted Label')
    axes[-1].set_ylabel('True Label')

    plt.tight_layout()
    plt.savefig('visualization/confusion_matrices.png')
    plt.close()

def visualize_feature_importance(X, selected_indices, feature_scores, band_names=['Alpha', 'Beta', 'Delta', 'Theta', 'Gamma']):
    num_features = len(feature_scores)

    plt.figure(figsize=(12, 8))
    plt.bar(range(num_features), feature_scores)
    plt.xlabel('Feature Index')
    plt.ylabel('Importance Score')
    plt.title('Feature Importance (All Features)')
    plt.tight_layout()
    plt.savefig('visualization/feature_importance_all.png')
    plt.close()

    plt.figure(figsize=(12, 8))
    plt.bar(range(len(selected_indices)), feature_scores[selected_indices])
    plt.xlabel('Selected Feature Index')
    plt.ylabel('Importance Score')
    plt.title('Feature Importance (Selected Features)')
    plt.tight_layout()
    plt.savefig('visualization/feature_importance_selected.png')
    plt.close()

    selected_features_df = pd.DataFrame({
        'Feature_Index': selected_indices,
        'Feature_Score': feature_scores[selected_indices]
    })

    selected_features_df = selected_features_df.sort_values('Feature_Score', ascending=False)

    if len(selected_indices) >= NUM_BANDS:
        band_indices = np.array([i % NUM_BANDS for i in range(len(selected_indices))])

        band_importance = np.zeros(NUM_BANDS)
        for i in range(NUM_BANDS):
            band_mask = (band_indices == i)
            if np.any(band_mask):
                band_importance[i] = np.mean(feature_scores[selected_indices][band_mask])

        plt.figure(figsize=(10, 6))
        plt.bar(band_names, band_importance)
        plt.xlabel('Frequency Band')
        plt.ylabel('Average Importance')
        plt.title('Frequency Band Importance')
        plt.tight_layout()
        plt.savefig('visualization/band_importance.png')
        plt.close()

    return selected_features_df

def visualize_dataset(X, y, feature_names=None):
    n_samples, n_features = X.shape

    data_stats = pd.DataFrame({
        'Feature': feature_names if feature_names else [f'Feature_{i}' for i in range(n_features)],
        'Mean': np.mean(X, axis=0),
        'Std': np.std(X, axis=0),
        'Min': np.min(X, axis=0),
        'Max': np.max(X, axis=0)
    })

    class_distribution = pd.Series(y).value_counts().reset_index()
    class_distribution.columns = ['Class', 'Count']

    plt.figure(figsize=(10, 6))
    sns.countplot(x=y)
    plt.title('Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.savefig('visualization/class_distribution.png')
    plt.close()

    if n_features > 1:
        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(X)

        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(pca_result[:, 0], pca_result[:, 1], c=y, cmap='viridis', alpha=0.8)
        plt.colorbar(scatter)
        plt.title('PCA Visualization of Dataset')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.savefig('visualization/pca_visualization.png')
        plt.close()

        tsne = TSNE(n_components=2, random_state=SEED)
        tsne_result = tsne.fit_transform(X)

        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=y, cmap='viridis', alpha=0.8)
        plt.colorbar(scatter)
        plt.title('t-SNE Visualization of Dataset')
        plt.xlabel('t-SNE Component 1')
        plt.ylabel('t-SNE Component 2')
        plt.savefig('visualization/tsne_visualization.png')
        plt.close()

    return data_stats, class_distribution

def plot_learning_curve(train_sizes, train_scores, val_scores):
    plt.figure(figsize=(10, 6))
    plt.plot(train_sizes, train_scores, 'o-', color='blue', label='Training Accuracy')
    plt.plot(train_sizes, val_scores, 'o-', color='red', label='Validation Accuracy')
    plt.xlabel('Training Set Size')
    plt.ylabel('Accuracy')
    plt.title('Learning Curve: Accuracy vs Training Set Size')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('visualization/learning_curve.png')
    plt.close()

def plot_roc_curve(y_true, y_prob, fold=None):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic{" - Fold " + str(fold) if fold else ""}')
    plt.legend(loc="lower right")
    plt.savefig(f'visualization/roc_curve{"_fold_" + str(fold) if fold else ""}.png')
    plt.close()

    return roc_auc

def visualize_electrode_map(importance_values=None, title='EEG Electrode Map'):
    fig, ax = plt.subplots(figsize=(10, 10))

    for ch_idx, (x, y) in CHANNEL_COORDS.items():
        x_scaled = x * 100
        y_scaled = y * 100

        if importance_values is not None and ch_idx-1 < len(importance_values):
            importance = importance_values[ch_idx-1]
            size = 20 + 100 * importance
            color = plt.cm.viridis(importance)
        else:
            size = 50
            color = 'blue'

        ax.scatter(x_scaled, y_scaled, s=size, color=color, alpha=0.7)
        ax.text(x_scaled, y_scaled, str(ch_idx), ha='center', va='center', fontsize=8, color='white')

    if importance_values is not None:
        sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(vmin=0, vmax=1))
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax)
        cbar.set_label('Importance')

    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.set_title(title)
    ax.set_aspect('equal')
    ax.set_axis_off()

    plt.tight_layout()
    plt.savefig('visualization/electrode_map.png')
    plt.close()

class EEGDataset(Dataset):
    def __init__(self, features, targets):
        self.features = features
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

class EEGChannelAttention(nn.Module):
    def __init__(self, num_channels):
        super(EEGChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(num_channels, num_channels // 2),
            nn.ReLU(),
            nn.Linear(num_channels // 2, num_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x.mean(dim=1))
        return x * attention_weights.unsqueeze(1), attention_weights

class EEGBandAttention(nn.Module):
    def __init__(self, num_bands):
        super(EEGBandAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(num_bands, num_bands),
            nn.ReLU(),
            nn.Linear(num_bands, num_bands),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x.mean(dim=2))
        return x * attention_weights.unsqueeze(2), attention_weights

class SpatialGNN(nn.Module):
    def __init__(self, num_channels, adjacency_matrix):
        super(SpatialGNN, self).__init__()
        self.adjacency_matrix = adjacency_matrix.to(device)
        self.weight = nn.Parameter(torch.Tensor(num_channels, num_channels))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        batch_size, num_bands, _ = x.shape
        output = torch.zeros_like(x)

        for b in range(num_bands):
            features = x[:, b, :]
            neighbor_agg = torch.matmul(self.adjacency_matrix, features.unsqueeze(2))
            transformed = torch.matmul(neighbor_agg.squeeze(2), self.weight)
            output[:, b, :] = transformed

        return output + x

class EEGClassifier(nn.Module):
    def __init__(self, num_bands, num_channels, adjacency_matrix):
        super(EEGClassifier, self).__init__()

        self.spatial_gnn = SpatialGNN(num_channels, adjacency_matrix)
        self.channel_attention = EEGChannelAttention(num_channels)
        self.band_attention = EEGBandAttention(num_bands)

        self.feature_extraction = nn.Sequential(
            nn.Linear(num_bands * num_channels, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.4)
        )

        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        self.attention_weights = {
            'channel': None,
            'band': None
        }

    def forward(self, x):
        x = self.spatial_gnn(x)
        x, channel_weights = self.channel_attention(x)
        x, band_weights = self.band_attention(x)

        self.attention_weights['channel'] = channel_weights
        self.attention_weights['band'] = band_weights

        x = x.reshape(x.size(0), -1)
        x = self.feature_extraction(x)
        x = self.classifier(x)

        return x.squeeze(1)

    def get_attention_weights(self):
        return self.attention_weights

class SimplifiedEEGClassifier(nn.Module):
    def __init__(self, num_bands, num_channels, adjacency_matrix):
        super(SimplifiedEEGClassifier, self).__init__()

        self.spatial_gnn = SpatialGNN(num_channels, adjacency_matrix)
        self.channel_attention = EEGChannelAttention(num_channels)

        self.feature_extraction = nn.Sequential(
            nn.Linear(num_bands * num_channels, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.4)
        )

        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        self.attention_weights = {
            'channel': None
        }

    def forward(self, x):
        x = self.spatial_gnn(x)
        x, channel_weights = self.channel_attention(x)

        self.attention_weights['channel'] = channel_weights

        x = x.reshape(x.size(0), -1)
        x = self.feature_extraction(x)
        x = self.classifier(x)

        return x.squeeze(1)

    def get_attention_weights(self):
        return self.attention_weights

def create_adjacency_matrix(channel_coords, num_channels_per_band, threshold=0.2):
    adj_matrix = torch.eye(num_channels_per_band)

    for i in range(num_channels_per_band):
        for j in range(num_channels_per_band):
            if i != j and abs(i - j) <= 2:
                adj_matrix[i, j] = 0.5

    row_sum = adj_matrix.sum(dim=1, keepdim=True)
    row_sum[row_sum == 0] = 1.0
    adj_matrix = adj_matrix / row_sum

    return adj_matrix

def augment_data(X, y, noise_levels=[0.03, 0.05, 0.07], num_augmentations=3):
    X_aug, y_aug = X.copy(), y.copy()

    for noise_level in noise_levels:
        for _ in range(num_augmentations):
            noise = np.random.normal(0, noise_level, X.shape)
            X_noisy = X + noise
            X_aug = np.vstack((X_aug, X_noisy))
            y_aug = np.append(y_aug, y)

    for scale in [0.9, 1.1]:
        X_scaled = X * scale
        X_aug = np.vstack((X_aug, X_scaled))
        y_aug = np.append(y_aug, y)

    n_features = X.shape[1]
    for _ in range(num_augmentations):
        X_channel_drop = X.copy()
        drop_idx = np.random.choice(n_features, size=int(0.05 * n_features), replace=False)
        X_channel_drop[:, drop_idx] = 0
        X_aug = np.vstack((X_aug, X_channel_drop))
        y_aug = np.append(y_aug, y)

    features_per_band = n_features // 5
    for _ in range(num_augmentations):
        X_perm = X.copy()
        for band in range(5):
            band_start = band * features_per_band
            band_end = (band + 1) * features_per_band
            perm_indices = np.random.permutation(features_per_band)[:int(0.2 * features_per_band)]
            for idx in perm_indices:
                if band_start + idx < n_features:
                    source_idx = band_start + idx
                    target_idx = band_start + np.random.randint(0, features_per_band)
                    if target_idx < n_features:
                        X_perm[:, [source_idx, target_idx]] = X_perm[:, [target_idx, source_idx]]
        X_aug = np.vstack((X_aug, X_perm))
        y_aug = np.append(y_aug, y)

    print(f"Original dataset: {X.shape[0]} samples, Augmented: {X_aug.shape[0]} samples")
    return X_aug, y_aug

def reshape_data_to_bands_channels(X, feature_indices, num_bands=5):
    n_samples = X.shape[0]
    n_selected_features = X.shape[1]
    channels_per_band = n_selected_features // num_bands
    reshaped = np.zeros((n_samples, num_bands, channels_per_band))

    feature_idx = 0
    for b in range(num_bands):
        for c in range(channels_per_band):
            if feature_idx < n_selected_features:
                reshaped[:, b, c] = X[:, feature_idx]
                feature_idx += 1

    return reshaped

def create_pytorch_ensemble(X_scaled, y, selected_indices, adjacency_matrix, device, n_splits=5):
    print("\nCreating ensemble of PyTorch models...")
    ensemble_preds = np.zeros(len(y))
    ensemble_probs = np.zeros(len(y))

    X_reshaped = reshape_data_to_bands_channels(X_scaled, selected_indices)
    full_dataset = EEGDataset(torch.FloatTensor(X_reshaped), torch.FloatTensor(y))
    full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)

    for fold in range(n_splits):
        model = EEGClassifier(
            num_bands=5,
            num_channels=len(selected_indices) // 5,
            adjacency_matrix=adjacency_matrix
        ).to(device)
        model.load_state_dict(torch.load(f'models/best_model_fold{fold+1}.pt'))
        model.eval()

        fold_preds = []
        fold_probs = []
        with torch.no_grad():
            for inputs, _ in full_loader:
                inputs = inputs.to(device)
                outputs = model(inputs)
                preds = (outputs > 0.5).float()
                fold_preds.extend(preds.cpu().numpy())
                fold_probs.extend(outputs.cpu().numpy())

        ensemble_preds += np.array(fold_preds)
        ensemble_probs += np.array(fold_probs)

    simplified_model = SimplifiedEEGClassifier(
        num_bands=5,
        num_channels=len(selected_indices) // 5,
        adjacency_matrix=adjacency_matrix
    ).to(device)

    train_dataset = EEGDataset(torch.FloatTensor(X_reshaped), torch.FloatTensor(y))
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    criterion = nn.BCELoss()
    optimizer = optim.AdamW(simplified_model.parameters(), lr=0.0005, weight_decay=5e-4)

    for epoch in range(50):
        simplified_model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = simplified_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    torch.save(simplified_model.state_dict(), 'models/simplified_model.pt')

    simplified_model.eval()
    simple_preds = []
    simple_probs = []
    with torch.no_grad():
        for inputs, _ in full_loader:
            inputs = inputs.to(device)
            outputs = simplified_model(inputs)
            preds = (outputs > 0.5).float()
            simple_preds.extend(preds.cpu().numpy())
            simple_probs.extend(outputs.cpu().numpy())

    ensemble_preds += np.array(simple_preds)
    ensemble_probs += np.array(simple_probs)

    ensemble_probs /= (n_splits + 1)
    ensemble_preds_binary = (ensemble_probs > 0.5).astype(float)

    accuracy = accuracy_score(y, ensemble_preds_binary)
    f1 = f1_score(y, ensemble_preds_binary)
    precision = precision_score(y, ensemble_preds_binary)
    recall = recall_score(y, ensemble_preds_binary)
    mcc = matthews_corrcoef(y, ensemble_preds_binary)
    balanced_acc = balanced_accuracy_score(y, ensemble_preds_binary)
    kappa = cohen_kappa_score(y, ensemble_preds_binary)
    conf_matrix = confusion_matrix(y, ensemble_preds_binary)

    auc_score = plot_roc_curve(y, ensemble_probs)

    print("\nEnsemble Model Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"MCC: {mcc:.4f}")
    print(f"Balanced Accuracy: {balanced_acc:.4f}")
    print(f"Cohen's Kappa: {kappa:.4f}")
    print(f"AUC: {auc_score:.4f}")
    print(f"Confusion Matrix:\n{conf_matrix}")

    metrics = {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'balanced_acc': balanced_acc,
        'kappa': kappa,
        'auc': auc_score
    }

    return metrics, conf_matrix, ensemble_probs

def visualize_feature_activation(model, X, fold_num=None):
    model.eval()

    with torch.no_grad():
        if isinstance(X, np.ndarray):
            X_tensor = torch.FloatTensor(X).to(device)
        else:
            X_tensor = X.to(device)

        outputs = model(X_tensor)

        attention_weights = model.get_attention_weights()

        if 'channel' in attention_weights:
            channel_weights = attention_weights['channel'].mean(dim=0).cpu().numpy()

            plt.figure(figsize=(12, 6))
            plt.bar(range(len(channel_weights)), channel_weights)
            plt.xlabel('Channel Index')
            plt.ylabel('Attention Weight')
            plt.title(f'Channel Attention Weights{" - Fold " + str(fold_num) if fold_num else ""}')
            plt.tight_layout()
            plt.savefig(f'visualization/channel_attention{"_fold_" + str(fold_num) if fold_num else ""}.png')
            plt.close()

            visualize_electrode_map(
                channel_weights,
                title=f'Electrode Importance Map{" - Fold " + str(fold_num) if fold_num else ""}'
            )

        if 'band' in attention_weights:
            band_weights = attention_weights['band'].mean(dim=0).cpu().numpy()

            plt.figure(figsize=(10, 6))
            plt.bar(['Alpha', 'Beta', 'Delta', 'Theta', 'Gamma'], band_weights)
            plt.xlabel('Frequency Band')
            plt.ylabel('Attention Weight')
            plt.title(f'Band Attention Weights{" - Fold " + str(fold_num) if fold_num else ""}')
            plt.tight_layout()
            plt.savefig(f'visualization/band_attention{"_fold_" + str(fold_num) if fold_num else ""}.png')
            plt.close()

def perform_learning_curve_analysis(X, y, selected_indices, adjacency_matrix, device, n_splits=5):
    X_reshaped = reshape_data_to_bands_channels(X, selected_indices)

    train_sizes = np.linspace(0.2, 1.0, 5)

    train_sizes_abs = []
    train_scores = []
    val_scores = []

    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    for train_idx, val_idx in kf.split(X_reshaped, y):
        X_train_full, X_val = X_reshaped[train_idx], X_reshaped[val_idx]
        y_train_full, y_val = y[train_idx], y[val_idx]

        fold_train_sizes = []
        fold_train_scores = []
        fold_val_scores = []

        for train_size in train_sizes:
            n_train_samples = int(train_size * len(X_train_full))
            fold_train_sizes.append(n_train_samples)

            indices = np.random.choice(len(X_train_full), n_train_samples, replace=False)
            X_train_subset = X_train_full[indices]
            y_train_subset = y_train_full[indices]

            train_dataset = EEGDataset(torch.FloatTensor(X_train_subset), torch.FloatTensor(y_train_subset))
            val_dataset = EEGDataset(torch.FloatTensor(X_val), torch.FloatTensor(y_val))

            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

            model = EEGClassifier(
                num_bands=5,
                num_channels=len(selected_indices) // 5,
                adjacency_matrix=adjacency_matrix
            ).to(device)

            criterion = nn.BCELoss()
            optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

            for epoch in range(30):
                model.train()
                for inputs, targets in train_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

            model.eval()
            train_metrics = calculate_metrics(model, train_loader, criterion, device)
            val_metrics = calculate_metrics(model, val_loader, criterion, device)

            fold_train_scores.append(train_metrics['accuracy'])
            fold_val_scores.append(val_metrics['accuracy'])

        train_sizes_abs.append(fold_train_sizes)
        train_scores.append(fold_train_scores)
        val_scores.append(fold_val_scores)

    train_sizes_abs = np.mean(train_sizes_abs, axis=0)
    train_scores = np.mean(train_scores, axis=0)
    val_scores = np.mean(val_scores, axis=0)

    plot_learning_curve(train_sizes_abs, train_scores, val_scores)

    return train_sizes_abs, train_scores, val_scores

def train_and_evaluate():
    results_dir = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(results_dir, exist_ok=True)

    print("Loading data...")
    df = pd.read_csv('/eeg data.csv')

    print("\nDataset Information:")
    print(f"Shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")

    data_summary = df.describe()
    data_summary.to_csv(f'{results_dir}/data_summary.csv')

    X = df.iloc[:, 1:-1].values
    y = df.iloc[:, -1].values
    y = y.astype(np.float32)

    feature_names = df.columns[1:-1].tolist()

    data_stats, class_distribution = visualize_dataset(X, y, feature_names)
    data_stats.to_csv(f'{results_dir}/feature_stats.csv')
    class_distribution.to_csv(f'{results_dir}/class_distribution.csv')

    print("\nClass Distribution:")
    print(class_distribution)

    print("\nPerforming feature selection...")
    selector = SelectKBest(f_classif, k=SELECTED_FEATURES)
    X_selected = selector.fit_transform(X, y)
    selected_indices = selector.get_support(indices=True)

    feature_scores = selector.scores_
    selected_features_df = visualize_feature_importance(X, selected_indices, feature_scores)
    selected_features_df.to_csv(f'{results_dir}/selected_features.csv')

    print(f"Top 10 selected features:")
    print(selected_features_df.head(10))

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_selected)

    channels_per_band = len(selected_indices) // 5

    adjacency_matrix = create_adjacency_matrix(CHANNEL_COORDS, channels_per_band)

    X_aug, y_aug = augment_data(X_scaled, y)

    n_splits = 5
    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    fold_results = []

    fold_monitors = []
    fold_confusion_matrices = []
    all_fold_metrics = []
    all_val_probs = []
    all_val_targets = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X_aug, y_aug)):
        print(f"\nTraining fold {fold+1}/{n_splits}")

        monitor = MetricsTracker(fold+1, EPOCHS)
        fold_monitors.append(monitor)

        X_train, X_val = X_aug[train_idx], X_aug[val_idx]
        y_train, y_val = y_aug[train_idx], y_aug[val_idx]

        X_train_reshaped = reshape_data_to_bands_channels(X_train, selected_indices)
        X_val_reshaped = reshape_data_to_bands_channels(X_val, selected_indices)

        train_dataset = EEGDataset(torch.FloatTensor(X_train_reshaped), torch.FloatTensor(y_train))
        val_dataset = EEGDataset(torch.FloatTensor(X_val_reshaped), torch.FloatTensor(y_val))

        class_counts = np.bincount(y_train.astype(int))
        class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
        sample_weights = class_weights[y_train.astype(int)]
        sampler = WeightedRandomSampler(weights=sample_weights,
                                        num_samples=len(sample_weights),
                                        replacement=True)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

        model = EEGClassifier(
            num_bands=5,
            num_channels=len(selected_indices) // 5,
            adjacency_matrix=adjacency_matrix
        ).to(device)

        criterion = nn.BCELoss()
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=30,
            eta_min=1e-6
        )

        best_val_loss = float('inf')
        early_stop_counter = 0
        early_stop_patience = 10
        min_epochs = 30
        loss_improvement_threshold = 0.001

        for epoch in range(EPOCHS):
            model.train()
            train_loss = 0.0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_loss += loss.item()

            scheduler.step()

            if (epoch + 1) % 10 == 0 or epoch == EPOCHS - 1:
                train_metrics = calculate_metrics(model, train_loader, criterion, device)
                val_metrics = calculate_metrics(model, val_loader, criterion, device)

                monitor.update(epoch+1, train_metrics, val_metrics)

                overfitting_ratio = val_metrics['loss'] / train_metrics['loss'] if train_metrics['loss'] > 0 else 1.0

                print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_metrics['loss']:.4f}, "
                      f"Val Loss: {val_metrics['loss']:.4f}, "
                      f"Train Acc: {train_metrics['accuracy']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, "
                      f"Train F1: {train_metrics['f1']:.4f}, Val F1: {val_metrics['f1']:.4f}, "
                      f"Overfitting ratio: {overfitting_ratio:.4f}")

            val_loss = calculate_metrics(model, val_loader, criterion, device)['loss']

            if epoch >= min_epochs:
                if val_loss < best_val_loss - loss_improvement_threshold:
                    best_val_loss = val_loss
                    early_stop_counter = 0
                    torch.save(model.state_dict(), f'models/best_model_fold{fold+1}.pt')
                else:
                    early_stop_counter += 1
                    if early_stop_counter >= early_stop_patience:
                        print(f"Early stopping at epoch {epoch+1}")
                        break
            elif val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), f'models/best_model_fold{fold+1}.pt')

        model.load_state_dict(torch.load(f'models/best_model_fold{fold+1}.pt'))
        model.eval()

        final_val_metrics = calculate_metrics(model, val_loader, criterion, device)
        all_fold_metrics.append(final_val_metrics)
        all_val_probs.extend(final_val_metrics['probabilities'])
        all_val_targets.extend(final_val_metrics['targets'])

        final_train_metrics = calculate_metrics(model, train_loader, criterion, device)

        conf_matrix = confusion_matrix(final_val_metrics['targets'], final_val_metrics['predictions'])
        fold_confusion_matrices.append(conf_matrix)

        auc_score = plot_roc_curve(final_val_metrics['targets'], final_val_metrics['probabilities'], fold+1)

        visualize_feature_activation(model, X_val_reshaped, fold+1)

        print(f"\nFold {fold+1} Results:")
        print(f"Accuracy: {final_val_metrics['accuracy']:.4f}")
        print(f"F1 Score: {final_val_metrics['f1']:.4f}")
        print(f"Precision: {final_val_metrics['precision']:.4f}")
        print(f"Recall: {final_val_metrics['recall']:.4f}")
        print(f"MCC: {final_val_metrics['mcc']:.4f}")
        print(f"Balanced Accuracy: {final_val_metrics['balanced_acc']:.4f}")
        print(f"Cohen's Kappa: {final_val_metrics['kappa']:.4f}")
        print(f"AUC: {auc_score:.4f}")
        print(f"Confusion Matrix:\n{conf_matrix}")

        overfitting_ratio = final_val_metrics['loss'] / final_train_metrics['loss']
        print(f"Overfitting ratio (val_loss/train_loss): {overfitting_ratio:.4f}")
        print(f"Target ratio should be close to 1.0. Ratio > 1.5 indicates overfitting")

        monitor.plot_metrics()
        overfitting_status = monitor.analyze_overfitting()
        print(f"Overfitting analysis for fold {fold+1}: {overfitting_status}")

        fold_results.append({
            'fold': fold + 1,
            'accuracy': final_val_metrics['accuracy'],
            'f1': final_val_metrics['f1'],
            'precision': final_val_metrics['precision'],
            'recall': final_val_metrics['recall'],
            'mcc': final_val_metrics['mcc'],
            'balanced_accuracy': final_val_metrics['balanced_acc'],
            'kappa': final_val_metrics['kappa'],
            'auc': auc_score,
            'overfitting_ratio': overfitting_ratio,
            'overfitting_status': overfitting_status
        })

    overall_auc = plot_roc_curve(np.array(all_val_targets), np.array(all_val_probs))

    fold_results_df = pd.DataFrame(fold_results)
    fold_results_df.to_csv(f'{results_dir}/fold_results.csv', index=False)
    numeric_cols = fold_results_df.select_dtypes(include=['number']).columns

    print("\nOverall Results:")
    print(f"Mean Accuracy: {fold_results_df['accuracy'].mean():.4f} ± {fold_results_df['accuracy'].std():.4f}")
    print(f"Mean F1 Score: {fold_results_df['f1'].mean():.4f} ± {fold_results_df['f1'].std():.4f}")
    print(f"Mean AUC: {fold_results_df['auc'].mean():.4f} ± {fold_results_df['auc'].std():.4f}")

    status_counts = fold_results_df['overfitting_status'].value_counts()
    print("\nOverfitting Status Counts:")
    for status, count in status_counts.items():
        print(f"{status}: {count}")

    plot_cross_val_metrics(fold_monitors)

    print("\nPerforming learning curve analysis...")
    train_sizes, train_scores, val_scores = perform_learning_curve_analysis(
        X_scaled, y, selected_indices, adjacency_matrix, device)

    ensemble_metrics, ensemble_conf_matrix, ensemble_probs = create_pytorch_ensemble(
        X_scaled, y, selected_indices, adjacency_matrix, device, n_splits
    )

    plot_confusion_matrices(fold_confusion_matrices, ensemble_conf_matrix)

    print("\nOverfitting Analysis Summary:")
    for i, monitor in enumerate(fold_monitors):
        status = monitor.analyze_overfitting()
        print(f"Fold {i+1}: {status}")

    avg_final_train_acc = np.mean([m.train_accs[-1] for m in fold_monitors])
    avg_final_val_acc = np.mean([m.val_accs[-1] for m in fold_monitors])
    generalization_gap = avg_final_train_acc - avg_final_val_acc

    print(f"\nFinal Generalization Gap (Train Acc - Val Acc): {generalization_gap:.4f}")
    if generalization_gap > 0.05:
        print("Warning: Model shows signs of overfitting (Train accuracy significantly higher than validation)")
    elif avg_final_val_acc > 0.99 and avg_final_train_acc > 0.99:
        print("Warning: Perfect accuracy on both training and validation may indicate:")
        print("1. Model is memorizing the dataset (especially if the dataset is small)")
        print("2. Possible data leakage between training and validation sets")
        print("3. The task may be too easy or the dataset too simple")
        print("Consider testing on a completely separate dataset to verify generalization")
    else:
        print("No significant signs of overfitting detected based on the accuracy gap")

    avg_final_train_loss = np.mean([m.train_losses[-1] for m in fold_monitors])
    avg_final_val_loss = np.mean([m.val_losses[-1] for m in fold_monitors])
    loss_gap = avg_final_val_loss - avg_final_train_loss
    avg_overfitting_ratio = avg_final_val_loss / avg_final_train_loss if avg_final_train_loss > 0 else 1.0

    print(f"Final Loss Gap (Val Loss - Train Loss): {loss_gap:.4f}")
    print(f"Average Overfitting Ratio (Val Loss / Train Loss): {avg_overfitting_ratio:.4f}")

    if avg_overfitting_ratio > 1.5:
        print("Warning: Average overfitting ratio above 1.5 indicates overfitting")
    elif avg_overfitting_ratio > 1.2:
        print("Mild overfitting detected (ratio between 1.2 and 1.5)")
    else:
        print("Good generalization (ratio close to 1.0)")

    results = {
        'cross_val': {
            'accuracy': fold_results_df['accuracy'].mean(),
            'accuracy_std': fold_results_df['accuracy'].std(),
            'f1': fold_results_df['f1'].mean(),
            'f1_std': fold_results_df['f1'].std(),
            'precision': fold_results_df['precision'].mean(),
            'precision_std': fold_results_df['precision'].std(),
            'recall': fold_results_df['recall'].mean(),
            'recall_std': fold_results_df['recall'].std(),
            'mcc': fold_results_df['mcc'].mean(),
            'mcc_std': fold_results_df['mcc'].std(),
            'balanced_accuracy': fold_results_df['balanced_accuracy'].mean(),
            'balanced_accuracy_std': fold_results_df['balanced_accuracy'].std(),
            'kappa': fold_results_df['kappa'].mean(),
            'kappa_std': fold_results_df['kappa'].std(),
            'auc': fold_results_df['auc'].mean(),
            'auc_std': fold_results_df['auc'].std(),
            'overfitting_ratio': fold_results_df['overfitting_ratio'].mean(),
            'overfitting_ratio_std': fold_results_df['overfitting_ratio'].std()
        },
        'ensemble': ensemble_metrics
    }

    results_df = pd.DataFrame({
        'Metric': ['Accuracy', 'F1 Score', 'Precision', 'Recall', 'MCC', 'Balanced Accuracy', 'Cohen\'s Kappa', 'AUC'],
        'Cross-Validation (Mean)': [
            results['cross_val']['accuracy'],
            results['cross_val']['f1'],
            results['cross_val']['precision'],
            results['cross_val']['recall'],
            results['cross_val']['mcc'],
            results['cross_val']['balanced_accuracy'],
            results['cross_val']['kappa'],
            results['cross_val']['auc']
        ],
        'Cross-Validation (Std)': [
            results['cross_val']['accuracy_std'],
            results['cross_val']['f1_std'],
            results['cross_val']['precision_std'],
            results['cross_val']['recall_std'],
            results['cross_val']['mcc_std'],
            results['cross_val']['balanced_accuracy_std'],
            results['cross_val']['kappa_std'],
            results['cross_val']['auc_std']
        ],
        'Ensemble': [
            results['ensemble']['accuracy'],
            results['ensemble']['f1'],
            results['ensemble']['precision'],
            results['ensemble']['recall'],
            results['ensemble']['mcc'],
            results['ensemble']['balanced_acc'],
            results['ensemble']['kappa'],
            results['ensemble']['auc']
        ]
    })

    results_df.to_csv(f'{results_dir}/final_results.csv', index=False)

    print("\nFinal Results Summary:")
    print(results_df)

    return results

if __name__ == "__main__":
    start_time = time.time()
    results = train_and_evaluate()
    end_time = time.time()

    print(f"\nTotal execution time: {(end_time - start_time) / 60:.2f} minutes")

Using device: cuda
Loading data...

Dataset Information:
Shape: (40, 322)
Columns: ['Unnamed: 0', 'alpha1', 'alpha2', 'alpha3', 'alpha4', 'alpha5', 'alpha6', 'alpha7', 'alpha8', 'alpha9', 'alpha10', 'alpha11', 'alpha12', 'alpha13', 'alpha14', 'alpha15', 'alpha16', 'alpha17', 'alpha18', 'alpha19', 'alpha20', 'alpha21', 'alpha22', 'alpha23', 'alpha24', 'alpha25', 'alpha26', 'alpha27', 'alpha28', 'alpha29', 'alpha30', 'alpha31', 'alpha32', 'alpha33', 'alpha34', 'alpha35', 'alpha36', 'alpha37', 'alpha38', 'alpha39', 'alpha40', 'alpha41', 'alpha42', 'alpha43', 'alpha44', 'alpha45', 'alpha46', 'alpha47', 'alpha48', 'alpha49', 'alpha50', 'alpha51', 'alpha52', 'alpha53', 'alpha54', 'alpha55', 'alpha56', 'alpha57', 'alpha58', 'alpha59', 'alpha60', 'alpha61', 'alpha62', 'alpha63', 'alpha64', 'beta1', 'beta2', 'beta3', 'beta4', 'beta5', 'beta6', 'beta7', 'beta8', 'beta9', 'beta10', 'beta11', 'beta12', 'beta13', 'beta14', 'beta15', 'beta16', 'beta17', 'beta18', 'beta19', 'beta20', 'beta21', 'beta2


Recall is ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.


y_pred contains classes not in y_true


Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.




Creating ensemble of PyTorch models...

Ensemble Model Results:
Accuracy: 1.0000
F1 Score: 1.0000
Precision: 1.0000
Recall: 1.0000
MCC: 1.0000
Balanced Accuracy: 1.0000
Cohen's Kappa: 1.0000
AUC: 1.0000
Confusion Matrix:
[[20  0]
 [ 0 20]]

Overfitting Analysis Summary:
Fold 1: Mild signs of overfitting detected
Fold 2: Mild signs of overfitting detected
Fold 3: No signs of overfitting detected
Fold 4: Mild signs of overfitting detected
Fold 5: Moderate overfitting detected

Final Generalization Gap (Train Acc - Val Acc): 0.0198
No significant signs of overfitting detected based on the accuracy gap
Final Loss Gap (Val Loss - Train Loss): 0.0533
Average Overfitting Ratio (Val Loss / Train Loss): 4.2750

Final Results Summary:
              Metric  Cross-Validation (Mean)  Cross-Validation (Std)  \
0           Accuracy                 0.980556                0.017975   
1           F1 Score                 0.980341                0.018185   
2          Precision                 0.988652