In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.metrics import roc_auc_score, roc_curve, auc, accuracy_score, recall_score, precision_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import label_binarize
from sklearn.utils import resample
from scipy.stats import sem, t
from scipy import io
from collections import defaultdict
import os
import re
import itertools
import random

In [None]:
def find_optimal_cutoff(tpr, fpr, thresholds):
    """Return the ROC threshold with the largest Youden J statistic (TPR - FPR).

    The (J, threshold) pairs are ordered lexicographically, so when several
    thresholds share the maximal J the largest threshold is returned.
    """
    youden_j = tpr - fpr
    ranked_pairs = list(zip(youden_j, thresholds))
    ranked_pairs.sort()
    _, best_threshold = ranked_pairs[-1]
    return best_threshold
def bootstrap_auc(y_true, y_pred, n_bootstraps=2000, rng_seed=42):
    """Percentile-bootstrap 95% confidence interval for a binary ROC AUC.

    Parameters
    ----------
    y_true : array-like
        Binary ground-truth labels.
    y_pred : array-like
        Scores for the positive class, aligned with ``y_true``.
    n_bootstraps : int
        Number of resamples to draw.
    rng_seed : int
        Seed for the private ``RandomState`` used for resampling.

    Returns
    -------
    (confidence_lower, confidence_upper)
        The 2.5% and 97.5% order statistics of the bootstrapped AUCs.

    Raises
    ------
    ValueError
        If the input is empty or no resample contained both classes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_samples = len(y_pred)
    if n_samples == 0:
        raise ValueError("bootstrap_auc requires at least one sample")

    rng = np.random.RandomState(rng_seed)
    bootstrapped_scores = []
    for _ in range(n_bootstraps):
        # randint's upper bound is exclusive: passing n_samples (not
        # n_samples - 1) lets the last observation be drawn as well.
        indices = rng.randint(0, n_samples, n_samples)
        # AUC is undefined when a resample contains a single class; skip it.
        if len(np.unique(y_true[indices])) < 2:
            continue
        bootstrapped_scores.append(roc_auc_score(y_true[indices], y_pred[indices]))

    if not bootstrapped_scores:
        raise ValueError("No bootstrap resample contained both classes; AUC CI is undefined")

    sorted_scores = np.sort(np.array(bootstrapped_scores))
    confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))]
    confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))]
    return confidence_lower, confidence_upper
def evaluate_model(model, data_loader):
    """Run ``model`` over ``data_loader`` in eval mode and collect its predictions.

    The batches are moved to the device the model's parameters live on, so the
    function does not depend on a module-level ``device`` variable.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier returning one logit per class.
    data_loader : iterable
        Yields ``(inputs, label, patient_ids)`` batches.

    Returns
    -------
    (test_labels, predicted_labels, predicted_probabilities, patient_ids_list)
        Lists of true labels, arg-max predictions and patient ids, plus an
        ``(n_samples, n_classes)`` NumPy array of softmax probabilities.
    """
    model.eval()
    # Use the model's own device; fall back to CPU for parameter-free modules.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    predicted_labels = []
    test_labels = []
    probability_batches = []
    patient_ids_list = []
    with torch.no_grad():
        for inputs, label, patient_ids in data_loader:
            inputs = inputs.float().to(target_device)
            label = label.long().to(target_device)
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            probability_batches.append(probabilities.cpu().numpy())
            predicted_labels += predicted.tolist()
            test_labels += label.tolist()
            patient_ids_list += list(patient_ids)
    predicted_probabilities = np.concatenate(probability_batches, axis=0)
    return test_labels, predicted_labels, predicted_probabilities, patient_ids_list
def find_multi_class_optimal_cutoffs(tpr, fpr, thresholds):
    """Return, as a one-element list, the threshold where TPR is closest to 1 - FPR.

    That is the ROC point at which sensitivity and specificity are most nearly
    equal (|tpr - (1 - fpr)| is minimal).
    """
    positions = np.arange(len(tpr))
    balance = pd.Series(tpr - (1 - fpr), index=positions)
    cutoff_values = pd.Series(thresholds, index=positions)
    roc = pd.DataFrame({'tf': balance, 'thresholds': cutoff_values})
    # Positions sorted by distance from zero; keep only the closest one.
    nearest = (roc.tf - 0).abs().argsort()[:1]
    closest_row = roc.iloc[nearest]
    return list(closest_row['thresholds'])
def get_multi_class_predicted_labels_based_on_cutoff(probabilities, cutoffs):
    """Assign each sample the first of the three classes whose probability exceeds its cutoff.

    ``np.argmax`` over the boolean flags returns the first ``True`` position,
    and 0 when no class exceeds its cutoff.
    """
    assigned = []
    for row in probabilities:
        exceeds_cutoff = [row[k] > cutoffs[k] for k in range(3)]
        assigned.append(np.argmax(exceeds_cutoff))
    return assigned
def plot_roc_curve(y_true, y_pred_proba, model_name):
    """Plot one-vs-rest ROC curves for the three classes and save them as EPS.

    Parameters
    ----------
    y_true : array-like
        Integer class labels in {0, 1, 2}.
    y_pred_proba : array-like, shape (n_samples, 3)
        Per-class scores; column ``i`` is scored against ``y_true == i``.
    model_name : str
        Used in the figure title and in the output file name
        ``coh_ROC_{model_name}.eps``.
    """
    # Coerce to arrays: with a plain list, `y_true == i` would be a scalar False
    # and `y_pred_proba[:, i]` would raise.
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)

    class_names = ('NCSE', 'ME', 'BI')
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(len(class_names)):
        fpr[i], tpr[i], _ = roc_curve(y_true == i, y_pred_proba[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    plt.figure(figsize=(8, 6))
    for i, class_name in enumerate(class_names):
        plt.plot(fpr[i], tpr[i], label=f'{class_name} (area = {roc_auc[i]:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {model_name}')
    plt.legend(loc='lower right')
    plt.savefig(f'coh_ROC_{model_name}.eps', format='eps')
    plt.show()
def plot_confusion_matrix(y_true, y_pred, model_name):
    """Draw the confusion matrix as an annotated heatmap and save it as EPS.

    The colour scale is fixed to the range 0-50; the figure is written to
    ``coh_CM_{model_name}.eps`` and then shown.
    """
    counts = confusion_matrix(y_true, y_pred)
    heatmap_options = dict(annot=True, fmt='d', cmap='Blues', cbar=False, vmin=0, vmax=50)

    plt.figure(figsize=(6, 5))
    sns.heatmap(counts, **heatmap_options)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.savefig(f'coh_CM_{model_name}.eps', format='eps')
    plt.show()
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from scipy.stats import sem
def compute_corrected_metrics_from_cm(y_true, y_pred, y_pred_proba):
    """Compute overall (macro) and one-vs-rest class-wise metrics for three classes.

    Class-wise TP/FN/FP/TN are read from the confusion matrix; a ratio whose
    denominator is zero is reported as 0.

    Returns
    -------
    dict
        ``{'overall': {accuracy, f1, precision, recall, auc},
        'classwise': [one dict per class with accuracy, f1, precision, recall,
        tn, tp, fp, fn]}``.
    """
    cm = confusion_matrix(y_true, y_pred)
    n_samples = len(y_true)

    classwise_metrics = []
    for class_idx in range(3):
        tp = cm[class_idx, class_idx]
        fn = sum(cm[class_idx, :]) - tp   # rest of the true-class row
        fp = sum(cm[:, class_idx]) - tp   # rest of the predicted-class column
        tn = n_samples - (tp + fn + fp)

        accuracy = (tp + tn) / (tp + tn + fp + fn)
        if (tp + fp) != 0:
            precision = tp / (tp + fp)
        else:
            precision = 0
        if (tp + fn) != 0:
            recall = tp / (tp + fn)
        else:
            recall = 0
        if (precision + recall) != 0:
            f1 = (2 * precision * recall) / (precision + recall)
        else:
            f1 = 0

        classwise_metrics.append(dict(
            accuracy=accuracy,
            f1=f1,
            precision=precision,
            recall=recall,
            tn=tn,
            tp=tp,
            fp=fp,
            fn=fn,
        ))

    overall = dict(
        accuracy=accuracy_score(y_true, y_pred),
        f1=f1_score(y_true, y_pred, average='macro'),
        precision=precision_score(y_true, y_pred, average='macro'),
        recall=recall_score(y_true, y_pred, average='macro'),
        auc=roc_auc_score(y_true, y_pred_proba, average='macro', multi_class='ovr'),
    )
    return {'overall': overall, 'classwise': classwise_metrics}
def confidence_interval(data):
    """Two-sided 95% Student-t confidence interval for the mean of ``data``.

    Parameters
    ----------
    data : sequence of numbers
        At least two observations.

    Returns
    -------
    (lower, upper)
        ``mean -/+ sem * t_{0.975, n-1}``.

    Raises
    ------
    ValueError
        If fewer than two observations are supplied (the standard error is
        undefined).
    """
    n = len(data)
    if n < 2:
        raise ValueError("confidence_interval requires at least two observations")
    # `mean` was never imported in this notebook; use NumPy's.
    m = np.mean(data)
    std_err = sem(data)
    ci = std_err * t.ppf((1 + 0.95) / 2, n - 1)
    return (m - ci, m + ci)
from sklearn.utils import resample
def bootstrap_ci(y_true, y_pred, y_pred_proba, metric_function, label=None, n_bootstrap=1000, alpha=0.05):
    """Compute the (1-alpha) percentile-bootstrap confidence interval of a metric.

    Three modes, chosen in this order:

    * ``label`` given: ``metric_function(binary_true, y_pred_proba[:, label])``
      where ``binary_true`` is 1 for samples of class ``label``.
    * ``y_pred`` given: ``metric_function(y_true, y_pred)``.
    * otherwise: ``metric_function(y_true, y_pred_proba)``.

    All inputs are coerced to NumPy arrays so that plain lists can be resampled.
    Resampling uses the global ``np.random`` state.

    Returns
    -------
    (lower, upper)
        The ``alpha/2`` and ``1 - alpha/2`` percentiles of the resampled metric.
    """
    bootstrap_samples = np.random.choice(len(y_true), size=(n_bootstrap, len(y_true)), replace=True)
    y_true_array = np.asarray(y_true)
    if label is not None:
        binary_true = (y_true_array == label).astype(int)
        class_scores = np.asarray(y_pred_proba)[:, label]
        stats = [metric_function(binary_true[indices], class_scores[indices]) for indices in bootstrap_samples]
    elif y_pred is not None:
        y_pred_array = np.asarray(y_pred)
        stats = [metric_function(y_true_array[indices], y_pred_array[indices]) for indices in bootstrap_samples]
    else:
        proba_array = np.asarray(y_pred_proba)
        stats = [metric_function(y_true_array[indices], proba_array[indices]) for indices in bootstrap_samples]
    return (np.percentile(stats, 100 * (alpha / 2.)), np.percentile(stats, 100 * (1 - alpha / 2.)))
def bootstrap_ci_for_auc(y_true, y_pred_proba, label, n_bootstrap=1000, alpha=0.05):
    """Compute the (1-alpha) confidence interval of the AUC using bootstrap for a specific class.

    The class ``label`` is scored one-vs-rest against column ``label`` of
    ``y_pred_proba``. Resamples that contain only one of the two binary classes
    are skipped because AUC is undefined for them (same policy as
    ``bootstrap_auc``). Resampling uses the global ``np.random`` state.

    Raises
    ------
    ValueError
        If no resample contained both classes.
    """
    bootstrap_samples = np.random.choice(len(y_true), size=(n_bootstrap, len(y_true)), replace=True)
    binary_true = (np.asarray(y_true) == label).astype(int)
    class_scores = np.asarray(y_pred_proba)[:, label]
    auc_stats = [
        roc_auc_score(binary_true[indices], class_scores[indices])
        for indices in bootstrap_samples
        if np.unique(binary_true[indices]).size == 2
    ]
    if not auc_stats:
        raise ValueError("No bootstrap resample contained both classes; AUC CI is undefined")
    return (np.percentile(auc_stats, 100 * (alpha / 2.)), np.percentile(auc_stats, 100 * (1 - alpha / 2.)))
def bootstrap_ci_classwise_metric(y_true, y_pred, metric_function, label, n_bootstrap=1000, alpha=0.05):
    """Percentile-bootstrap (1-alpha) interval of a one-vs-rest metric for class ``label``.

    Both label vectors are binarised against ``label`` and
    ``metric_function(y_true_binary, y_pred_binary)`` is evaluated on
    ``n_bootstrap`` resamples drawn with ``sklearn.utils.resample``.
    """
    true_is_label = (np.array(y_true) == label).astype(int)
    pred_is_label = (np.array(y_pred) == label).astype(int)
    all_positions = np.arange(len(y_true))

    def _metric_on_resample():
        # One resample of the positions, applied to both binary vectors.
        drawn = resample(all_positions)
        return metric_function(true_is_label[drawn], pred_is_label[drawn])

    stats = [_metric_on_resample() for _ in range(n_bootstrap)]
    lower = np.percentile(stats, 100 * (alpha / 2.))
    upper = np.percentile(stats, 100 * (1 - alpha / 2.))
    return (lower, upper)
def bootstrap_ci_classwise_auc(y_true, y_pred_proba, label, n_bootstrap=1000, alpha=0.05):
    """Compute the (1-alpha) confidence interval of the AUC using bootstrap for a specific class.

    Class ``label`` is scored one-vs-rest against column ``label`` of
    ``y_pred_proba`` on resamples drawn with ``sklearn.utils.resample``.
    Resamples containing a single binary class are skipped because AUC is
    undefined for them (same policy as ``bootstrap_auc``).

    Raises
    ------
    ValueError
        If no resample contained both classes.
    """
    indices = np.arange(len(y_true))
    y_true_binary = (np.asarray(y_true) == label).astype(int)
    class_scores = np.asarray(y_pred_proba)[:, label]
    auc_stats = []
    for _ in range(n_bootstrap):
        resampled_indices = resample(indices)
        resampled_truth = y_true_binary[resampled_indices]
        if np.unique(resampled_truth).size < 2:
            continue
        auc_stats.append(roc_auc_score(resampled_truth, class_scores[resampled_indices]))
    if not auc_stats:
        raise ValueError("No bootstrap resample contained both classes; AUC CI is undefined")
    return (np.percentile(auc_stats, 100 * (alpha / 2.)), np.percentile(auc_stats, 100 * (1 - alpha / 2.)))
def compute_classwise_auc(y_true, y_pred_proba):
    """Return the one-vs-rest ROC AUC of each of the three classes as a list.

    Inputs are coerced to arrays: with a plain list, ``y_true == i`` would be a
    single ``False`` rather than an element-wise mask.
    """
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    class_aucs = []
    for i in range(3):
        binary_y_true = np.where(y_true == i, 1, 0)
        class_aucs.append(roc_auc_score(binary_y_true, y_pred_proba[:, i]))
    return class_aucs

In [None]:
# Load the retrospective cohort: one .mat file per (frequency band, patient),
# grouped in one sub-folder per class.
# NOTE(review): absolute local path; consider a configurable data directory.
folder_path = "D:/LOC_matrix/"
dataset = []
# patient id -> list of five per-band matrices (None until the file is read)
temp_data_epilepsy = defaultdict(lambda: [None]*5)
temp_data_tme = defaultdict(lambda: [None]*5)
temp_data_drug = defaultdict(lambda: [None]*5)
temp_data_control = defaultdict(lambda: [None]*5)  # not populated in this cell
frequency_bands = ["delta", "theta", "alpha", "beta", "gamma"]
frequency_bands_dict = {band: i for i, band in enumerate(frequency_bands)}

epilepsy_path = os.path.join(folder_path, "epilepsy_rev")
epilepsy_label = 0
tme_path = os.path.join(folder_path, "tme_rev")
tme_label = 1
drug_path = os.path.join(folder_path, "drug_rev")
drug_label = 2

# (directory, class label, per-patient store), processed in this order.
class_sources = [
    (epilepsy_path, epilepsy_label, temp_data_epilepsy),
    (tme_path, tme_label, temp_data_tme),
    (drug_path, drug_label, temp_data_drug),
]

# Pass 1: read every "<band>_<patient>.mat" file into its patient's band slot.
for class_path, class_label, class_store in class_sources:
    for file in os.listdir(class_path):
        if not file.endswith(".mat"):
            continue
        match = re.search(r"(\w+)_([a-z\d]+)\.mat", file)
        if match is None:
            continue
        frequency_band, patient_number = match.group(1, 2)
        file_path = os.path.join(class_path, file)
        mat_data = io.loadmat(file_path)
        data = np.array(mat_data['data'])
        class_store[patient_number][frequency_bands_dict[frequency_band]] = data

# Pass 2: mask each band matrix and stack the five bands into one sample.
for class_path, class_label, class_store in class_sources:
    for patient_number, data in class_store.items():
        missing_bands = [band for band, matrix in zip(frequency_bands, data) if matrix is None]
        if missing_bands:
            raise ValueError(
                f"Patient {patient_number!r} in {class_path!r} is missing band file(s): {missing_bands}"
            )
        for band_matrix in data:
            # Zero every entry with row <= column (diagonal and above) among
            # the first 19 rows, leaving only the strictly-lower triangle.
            for i in range(19):
                band_matrix[:i+1, i:] = 0
        dataset.append((np.stack(data, axis=0), class_label, patient_number))

In [None]:
# Load the prospective cohort: same layout as the retrospective one, read from
# the "*_rev_pro" sub-folders into `dataset_prospective`.
# NOTE(review): absolute local path; consider a configurable data directory.
folder_path = "D:/LOC_matrix/"
dataset_prospective = []
# patient id -> list of five per-band matrices (None until the file is read)
temp_data_epilepsy = defaultdict(lambda: [None]*5)
temp_data_tme = defaultdict(lambda: [None]*5)
temp_data_drug = defaultdict(lambda: [None]*5)
temp_data_control = defaultdict(lambda: [None]*5)  # not populated in this cell
frequency_bands = ["delta", "theta", "alpha", "beta", "gamma"]
frequency_bands_dict = {band: i for i, band in enumerate(frequency_bands)}

epilepsy_path = os.path.join(folder_path, "epilepsy_rev_pro")
epilepsy_label = 0
tme_path = os.path.join(folder_path, "tme_rev_pro")
tme_label = 1
drug_path = os.path.join(folder_path, "drug_rev_pro")
drug_label = 2

# (directory, class label, per-patient store), processed in this order.
class_sources = [
    (epilepsy_path, epilepsy_label, temp_data_epilepsy),
    (tme_path, tme_label, temp_data_tme),
    (drug_path, drug_label, temp_data_drug),
]

# Pass 1: read every "<band>_<patient>.mat" file into its patient's band slot.
for class_path, class_label, class_store in class_sources:
    for file in os.listdir(class_path):
        if not file.endswith(".mat"):
            continue
        match = re.search(r"(\w+)_([a-z\d]+)\.mat", file)
        if match is None:
            continue
        frequency_band, patient_number = match.group(1, 2)
        file_path = os.path.join(class_path, file)
        mat_data = io.loadmat(file_path)
        data = np.array(mat_data['data'])
        class_store[patient_number][frequency_bands_dict[frequency_band]] = data

# Pass 2: mask each band matrix and stack the five bands into one sample.
for class_path, class_label, class_store in class_sources:
    for patient_number, data in class_store.items():
        missing_bands = [band for band, matrix in zip(frequency_bands, data) if matrix is None]
        if missing_bands:
            raise ValueError(
                f"Patient {patient_number!r} in {class_path!r} is missing band file(s): {missing_bands}"
            )
        for band_matrix in data:
            # Zero every entry with row <= column (diagonal and above) among
            # the first 19 rows, leaving only the strictly-lower triangle.
            for i in range(19):
                band_matrix[:i+1, i:] = 0
        dataset_prospective.append((np.stack(data, axis=0), class_label, patient_number))


In [None]:
class EEGDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(sample, label, patient_id)`` triples.

    The base class is referenced through ``torch.utils.data`` because only
    ``DataLoader`` (not ``Dataset``) is imported by name in this notebook.

    Parameters
    ----------
    data : indexable
        Samples; ``data[idx]`` is returned as-is.
    labels : indexable
        Class label of each sample.
    patient_ids : indexable
        Patient identifier of each sample.
    """

    def __init__(self, data, labels, patient_ids):
        self.data = data
        self.labels = labels
        self.patient_ids = patient_ids

    def __len__(self):
        """Number of samples (length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(sample, label, patient_id)`` triple."""
        return self.data[idx], self.labels[idx], self.patient_ids[idx]


In [None]:
def output_size_after_conv(input_size, kernel_size, stride=1, padding=1):
    """Spatial size produced by a convolution along one dimension.

    Evaluates ``(input_size - kernel_size + 2 * padding) / stride + 1`` and
    truncates the result to an int.
    """
    padded_span = input_size - kernel_size + 2 * padding
    return int(padded_span / stride + 1)
def output_size_after_maxpool(input_size, kernel_size, stride=None):
    """Spatial size produced by an unpadded max-pool along one dimension.

    ``stride`` defaults to ``kernel_size``. The result of
    ``(input_size - kernel_size) / stride + 1`` is truncated to an int.
    """
    step = kernel_size if stride is None else stride
    return int((input_size - kernel_size) / step + 1)
class DynamicEEGCNN(nn.Module):
    """CNN with a configurable number of conv blocks followed by a 3-way FC head.

    Each block is Conv-BN-LeakyReLU-Conv-BN-LeakyReLU-Dropout-MaxPool(2), with
    channel widths 5 -> 32 -> 64 -> 128 -> 256. The spatial size is tracked
    from an input side of 19, and blocks stop being added once the pooled size
    would drop below 1 or the channel table is exhausted.

    Parameters
    ----------
    num_blocks : int
        Requested number of conv blocks (at most four are built).
    dropout_rate : float
        Dropout probability used in the conv blocks and the FC head.
    kernel_size : int
        Kernel size of every convolution (padding is fixed at 1).
    """

    def __init__(self, num_blocks=2, dropout_rate=0.20, kernel_size=3):
        super().__init__()
        channels = [5, 32, 64, 128, 256]
        self.blocks = nn.ModuleList()
        input_size = 19
        # Never index past the channel table, whatever num_blocks is requested.
        for i in range(min(num_blocks, len(channels) - 1)):
            # Each block applies TWO convolutions before pooling, so the size
            # formula is applied twice (a no-op for kernel_size=3, padding=1).
            conv_size = output_size_after_conv(input_size, kernel_size)
            conv_size = output_size_after_conv(conv_size, kernel_size)
            pool_size = output_size_after_maxpool(conv_size, 2)
            if pool_size < 1:
                break
            block = nn.Sequential(
                nn.Conv2d(channels[i], channels[i+1], kernel_size=kernel_size, stride=1, padding=1),
                nn.BatchNorm2d(channels[i+1]),
                nn.LeakyReLU(),
                nn.Conv2d(channels[i+1], channels[i+1], kernel_size=kernel_size, stride=1, padding=1),
                nn.BatchNorm2d(channels[i+1]),
                nn.LeakyReLU(),
                nn.Dropout(dropout_rate),
                nn.MaxPool2d(kernel_size=2),
            )
            self.blocks.append(block)
            input_size = pool_size
        final_size = int(input_size)
        final_channel = channels[len(self.blocks)]
        fc_input_size = final_channel * final_size * final_size
        # The head is built for every kernel size; previously it only existed
        # for kernel_size == 3 and forward() failed otherwise.
        self.fc1 = nn.Sequential(
            nn.Linear(fc_input_size, 1024),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate)
        )
        self.fc3 = nn.Linear(512, 3)

    def forward(self, x):
        """Return the three class logits for a batch ``x``."""
        for block in self.blocks:
            x = block(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [None]:
# Hyper-parameter search space; every combination of the listed values is tried.
param_grid = {
    'learning_rate': [0.001],
    'weight_decay': [1e-6],
    'dropout_rate': [0.10],
    'kernel_size': [3],
    'num_blocks': [4]
}
all_params = [
    dict(zip(param_grid, combination))
    for combination in itertools.product(*param_grid.values())
]

# Per-model lists of prediction records, filled by the training loop.
predictions = defaultdict(list)
external_predictions = defaultdict(list)

# Seed every random number generator in use and force deterministic cuDNN.
seed_value = 777
random.seed(seed_value)
np.random.seed(seed_value)
# NOTE(review): this call seeds with 77 while every other generator uses
# seed_value (777). Confirm whether that is intentional before changing it,
# since it alters the results.
torch.manual_seed(77)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# For every hyper-parameter combination: 10-fold cross-validation on `dataset`,
# pooled ROC / confusion-matrix figures, then fine-tuning of the best fold's
# checkpoint on all of `dataset` and evaluation on `dataset_prospective`.
for params in all_params:
    best_results = {
        'predicted_probabilities': None,
        'test_labels': None,
        'predicted_labels': None
    }
    best_results_per_fold = defaultdict(lambda: {
        'predicted_probabilities': None,
        'test_labels': None,
        'predicted_labels': None,
        'patient_ids': None
    })
    learning_rate = params['learning_rate']
    weight_decay = params['weight_decay']
    dropout_rate = params['dropout_rate']
    kernel_size = params['kernel_size']
    num_blocks = params['num_blocks']
    print(f"Training with params: {params}")
    train_batch_size = 256
    train_shuffle = True
    test_batch_size = 256
    test_shuffle = False
    num_epochs = 1000
    best_model_path = None
    best_accuracy = 0.0
    best_fold = None
    k = 10
    n_classes = 3
    # One device for training, evaluation and checkpoint loading (CPU fallback).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=777)
    fold_results = []
    all_test_labels = []
    all_predicted_labels = []
    all_predicted_probabilities = []
    data, labels, patient_ids = zip(*dataset)
    data = np.array(data)
    labels = np.array(labels)

    # ------------------------------------------------------------------
    # Cross-validation
    # NOTE(review): the "best" epoch of each fold is chosen by accuracy on that
    # fold's own test split, so the reported CV metrics are optimistic.
    # ------------------------------------------------------------------
    for fold, (train_indices, test_indices) in enumerate(skf.split(data, labels)):
        fold_patient_ids = [patient_ids[i] for i in test_indices]
        fold_train_data, fold_train_labels = data[train_indices], labels[train_indices]
        fold_test_data, fold_test_labels = data[test_indices], labels[test_indices]
        fold_train_dataset = EEGDataset(fold_train_data, fold_train_labels, [patient_ids[i] for i in train_indices])
        fold_test_dataset = EEGDataset(fold_test_data, fold_test_labels, [patient_ids[i] for i in test_indices])
        fold_train_data_loader = DataLoader(fold_train_dataset, batch_size=train_batch_size, shuffle=train_shuffle)
        fold_test_data_loader = DataLoader(fold_test_dataset, batch_size=test_batch_size, shuffle=test_shuffle)
        model = DynamicEEGCNN(num_blocks=params['num_blocks'], dropout_rate=dropout_rate, kernel_size=kernel_size).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        best_test_accuracy = 0.0
        # Reset per fold so a fold can never reuse the previous fold's results.
        best_predicted_probabilities = None
        best_predicted_labels = None
        best_test_labels = None
        for epoch in range(num_epochs):
            # ---- train for one epoch ----
            total = 0
            correct = 0
            model.train()
            for inputs, label, patient_id_batch in fold_train_data_loader:
                inputs = inputs.float().to(device)
                label = label.long().to(device)
                outputs = model(inputs)
                loss = criterion(outputs, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                _, predicted = torch.max(outputs, 1)
                total += label.size(0)
                correct += (predicted == label).sum().item()
            train_accuracy = correct / total

            # ---- evaluate on the held-out split ----
            total = 0
            correct = 0
            model.eval()
            fold_predicted_labels = []
            fold_test_labels = []
            fold_predicted_probabilities = []
            with torch.no_grad():
                for inputs, label, patient_id_batch in fold_test_data_loader:
                    inputs = inputs.float().to(device)
                    label = label.long().to(device)
                    outputs = model(inputs)
                    probabilities = torch.softmax(outputs, dim=1)
                    _, predicted = torch.max(outputs, 1)
                    total += label.size(0)
                    correct += (predicted == label).sum().item()
                    fold_predicted_probabilities.append(probabilities.cpu().numpy())
                    fold_predicted_labels += predicted.cpu().tolist()
                    fold_test_labels += label.cpu().tolist()
            fold_predicted_probabilities = np.concatenate(fold_predicted_probabilities, axis=0)
            test_accuracy = correct / total

            # Record the first epoch unconditionally so every fold has a
            # checkpoint and a set of predictions, then only strict improvements.
            if test_accuracy > best_test_accuracy or best_predicted_probabilities is None:
                best_test_accuracy = test_accuracy
                model_path = f'./LOC_slice_1_model_fold_{fold + 1}.pth'
                torch.save(model.state_dict(), model_path)
                if best_model_path is None or best_test_accuracy > best_accuracy:
                    best_accuracy = best_test_accuracy
                    best_model_path = model_path
                    best_fold = fold
                    best_results = {
                        'predicted_probabilities': fold_predicted_probabilities,
                        'test_labels': fold_test_labels,
                        'predicted_labels': fold_predicted_labels
                    }
                best_results_per_fold[fold] = {
                    'predicted_probabilities': fold_predicted_probabilities,
                    'test_labels': fold_test_labels,
                    'predicted_labels': fold_predicted_labels,
                    'patient_ids': fold_patient_ids
                }
                best_predicted_probabilities = fold_predicted_probabilities
                best_predicted_labels = fold_predicted_labels
                best_test_labels = fold_test_labels
        all_predicted_probabilities.append(best_predicted_probabilities)
        all_predicted_labels += best_predicted_labels
        all_test_labels += best_test_labels
        fold_results.append(best_test_accuracy)
        print(f"Fold {fold + 1}, Best Test Acc: {best_test_accuracy:.4f}")

    # ------------------------------------------------------------------
    # Export the per-sample CV predictions
    # ------------------------------------------------------------------
    for fold, fold_data in best_results_per_fold.items():
        fold_predicted_probabilities = fold_data['predicted_probabilities']
        fold_test_labels = fold_data['test_labels']
        fold_predicted_labels = fold_data['predicted_labels']
        fold_patient_ids = fold_data['patient_ids']
        for idx in range(len(fold_test_labels)):
            predictions['DeepLearningModel'].append({
                'ID': fold_patient_ids[idx],
                'Fold': fold + 1,
                'True_Label': fold_test_labels[idx],
                'NCSE_Prob': fold_predicted_probabilities[idx][0],
                'ME_Prob': fold_predicted_probabilities[idx][1],
                'BI_Prob': fold_predicted_probabilities[idx][2]
            })
    df = pd.DataFrame(predictions['DeepLearningModel'])
    csv_filename = "dl_predictions_150_original_dataset.csv"
    df.to_csv(csv_filename, index=False)
    mean_accuracy = sum(fold_results) / len(fold_results)
    print(f"Mean Accuracy: {mean_accuracy:.4f}")

    # ------------------------------------------------------------------
    # Pooled CV ROC curves and confusion matrix
    # ------------------------------------------------------------------
    # Concatenate rather than np.array(): folds may differ in size, and a
    # ragged list cannot be turned into a regular array.
    all_predicted_probabilities = np.concatenate(all_predicted_probabilities, axis=0)
    all_test_labels = np.array(all_test_labels)
    # Fixed class list keeps one column per class even if a class is absent.
    y_test = label_binarize(all_test_labels, classes=list(range(n_classes)))
    y_score = all_predicted_probabilities.reshape(-1, n_classes)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    plt.rcParams['font.size'] = 12
    for i, label in enumerate(['NCSE', 'ME', 'BI']):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        confidence_lower, confidence_upper = bootstrap_auc(y_test[:, i], y_score[:, i])
        plt.plot(fpr[i], tpr[i], label=f'ROC curve of class {label} (AUC = {roc_auc[i]:.3f}, 95% CI: [{confidence_lower:.3f}, {confidence_upper:.3f}])')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic to Multi-Class')
    plt.legend(loc="lower right")
    plt.savefig("dl_ROC.eps", format='eps')
    plt.show()
    plt.rcParams['font.size'] = 18
    labels_class = ['NCSE', 'ME', 'BI']
    conf_mat = confusion_matrix(all_test_labels, all_predicted_labels)
    plt.figure(figsize=(10,10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=labels_class, yticklabels=labels_class, vmin=0, vmax=50)
    plt.xlabel('Predicted Label', fontsize=22)
    plt.ylabel('True Label', fontsize=22)
    plt.savefig("dl_confusion.eps", format='eps')
    plt.show()
    # Youden-optimal thresholds on the pooled CV predictions.
    optimal_thresholds = dict()
    for i, label in enumerate(['NCSE', 'ME', 'BI']):
        fpr[i], tpr[i], thresholds = roc_curve(y_test[:, i], y_score[:, i])
        optimal_thresholds[i] = find_optimal_cutoff(tpr[i], fpr[i], thresholds)

    # ------------------------------------------------------------------
    # Prospective cohort: fine-tune the best checkpoint on all of `dataset`
    # ------------------------------------------------------------------
    data_prospective, labels_prospective, patient_ids_prospective = zip(*dataset_prospective)
    data_prospective = np.array(data_prospective)
    labels_prospective = np.array(labels_prospective)
    prospective_dataset = EEGDataset(data_prospective, labels_prospective, patient_ids_prospective)
    prospective_data_loader = DataLoader(prospective_dataset, batch_size=test_batch_size, shuffle=test_shuffle)
    best_model = DynamicEEGCNN(num_blocks=num_blocks, dropout_rate=dropout_rate, kernel_size=kernel_size).to(device)
    best_model.load_state_dict(torch.load(best_model_path, map_location=device))
    full_train_dataset = EEGDataset(data, labels, patient_ids)
    full_train_data_loader = DataLoader(full_train_dataset, batch_size=train_batch_size, shuffle=train_shuffle)
    criterion = nn.CrossEntropyLoss()
    additional_epochs = [100]
    best_prospective_accuracy = 0.0
    best_prospective_epoch = 0
    results = {}
    for epochs in additional_epochs:
        # Always restart from the saved checkpoint so runs are independent.
        model = DynamicEEGCNN(num_blocks=num_blocks, dropout_rate=dropout_rate, kernel_size=kernel_size).to(device)
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        for epoch in range(epochs):
            model.train()
            for inputs, label, _ in full_train_data_loader:
                inputs = inputs.float().to(device)
                label = label.long().to(device)
                outputs = model(inputs)
                loss = criterion(outputs, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # The returned ids are not bound to `patient_ids`, which still holds the
        # retrospective ids used by `full_train_dataset`.
        test_labels, predicted_labels, predicted_probabilities, _evaluated_ids = evaluate_model(model, prospective_data_loader)
        prospective_accuracy = np.mean(np.array(test_labels) == np.array(predicted_labels))
        if prospective_accuracy > best_prospective_accuracy:
            best_prospective_accuracy = prospective_accuracy
            best_prospective_epoch = epochs
        results[epochs] = {
            'test_labels': test_labels,
            'predicted_labels': predicted_labels,
            'predicted_probabilities': predicted_probabilities
        }
        # The loader is not shuffled, so row idx matches patient_ids_prospective[idx].
        for idx in range(len(test_labels)):
            external_predictions['DeepLearningModel'].append({
                'ID': patient_ids_prospective[idx],
                'Fold': 'ExternalValidation',
                'True_Label': test_labels[idx],
                'NCSE_Prob': predicted_probabilities[idx][0],
                'ME_Prob': predicted_probabilities[idx][1],
                'BI_Prob': predicted_probabilities[idx][2]
            })
        df = pd.DataFrame(external_predictions['DeepLearningModel'])
        csv_filename = f"dl_predictions_after_{epochs}_epochs.csv"
        df.to_csv(csv_filename, index=False)
        external_predictions['DeepLearningModel'].clear()

    # ------------------------------------------------------------------
    # Prospective ROC curves and confusion matrix
    # ------------------------------------------------------------------
    labels_class = ['NCSE', 'ME', 'BI']
    for epochs, epoch_results in results.items():
        predicted_probabilities = epoch_results['predicted_probabilities']
        test_labels = epoch_results['test_labels']
        y_test = label_binarize(test_labels, classes=list(range(n_classes)))
        y_score = predicted_probabilities.reshape(-1, n_classes)
        optimal_thresholds = dict()
        plt.figure(figsize=(25, 8))
        plt.subplot(1, 3, 1)
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i, label in enumerate(['NCSE', 'ME', 'BI']):
            fpr[i], tpr[i], thresholds = roc_curve(y_test[:, i], y_score[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
            confidence_lower, confidence_upper = bootstrap_auc(y_test[:, i], y_score[:, i])
            plt.plot(fpr[i], tpr[i], label=f'ROC curve of class {label} (AUC = {roc_auc[i]:.3f}, 95% CI: [{confidence_lower:.3f}, {confidence_upper:.3f}])')
            optimal_thresholds[i] = find_optimal_cutoff(tpr[i], fpr[i], thresholds)
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Receiver Operating Characteristic after {epochs} Epochs')
        plt.legend(loc="lower right")
        # NOTE(review): the threshold-based labels below are computed but never
        # plotted, and panel (1, 3, 2) of the figure stays empty.
        optimal_predicted_labels = np.zeros_like(y_score)
        for i in range(n_classes):
            optimal_predicted_labels[:, i] = y_score[:, i] > optimal_thresholds[i]
        optimal_predicted_labels = np.argmax(optimal_predicted_labels, axis=1)
        general_predicted_labels = np.argmax(y_score, axis=1)
        conf_mat_general = confusion_matrix(test_labels, general_predicted_labels)
        plt.subplot(1, 3, 3)
        sns.heatmap(conf_mat_general, annot=True, fmt='d', cmap='Blues', xticklabels=labels_class, yticklabels=labels_class, vmin=0, vmax=10)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.ylabel('True Label', fontsize=12)
        plt.title(f'General Confusion Matrix after {epochs} Epochs', fontsize=14)
        plt.tight_layout()
        plt.savefig(f"dl_combined_after_{epochs}_epochs.eps", format='eps')
        plt.show()


In [None]:
# Per-class, per-band saliency maps: for every CV fold, load that fold's
# checkpoint, back-propagate the true-class logit to a single-band input for
# each test sample, and average |gradient| over samples and then over folds.
# NOTE(review): `params`, `dropout_rate` and `kernel_size` are left over from
# the training cell, so this cell depends on it having run.
vmin = 0
vmax = 1.5
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = DynamicEEGCNN(num_blocks=params['num_blocks'], dropout_rate=dropout_rate, kernel_size=kernel_size).to(device)
data, labels, patient_ids = zip(*dataset)
train_batch_size = 256
train_shuffle = True
test_batch_size = 256
test_shuffle = False
data = np.array(data)
labels = np.array(labels)
# Same splitter settings as the training cell, so folds match the checkpoints.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=777)
class_names = ['NCSE', 'ME', 'BI']
original_labels = labels.copy()
saliency_sums_all_folds = {(class_idx, band_idx): [] for class_idx in range(3) for band_idx in range(5)}
saliency_values_all_folds = {(class_idx, band_idx): [] for class_idx in range(3) for band_idx in range(5)}
upper_incl_diag = np.triu_indices(19)        # entries with row <= column
upper_excl_diag = np.triu_indices(19, k=1)   # entries with row <  column
for fold, (train_indices, test_indices) in enumerate(skf.split(data, original_labels)):
    print(f"Processing fold {fold + 1}...")
    fold_test_data, fold_test_labels = data[test_indices], original_labels[test_indices]
    unique, counts = np.unique(fold_test_labels, return_counts=True)
    print("Class distribution in this fold's test set:")
    for u, c in zip(unique, counts):
        print(f"Class {u}: {c} samples")
    saliency_sums = {(class_idx, band_idx): np.zeros_like(data[0][0]) for class_idx in range(3) for band_idx in range(5)}
    sample_counts = {(class_idx, band_idx): 0 for class_idx in range(3) for band_idx in range(5)}
    model.load_state_dict(torch.load(f'./LOC_slice_1_model_fold_{fold + 1}.pth', map_location=device))
    fold_patient_ids = [patient_ids[i] for i in test_indices]
    fold_test_dataset = EEGDataset(fold_test_data, fold_test_labels, fold_patient_ids)
    fold_test_data_loader = DataLoader(fold_test_dataset, batch_size=test_batch_size, shuffle=test_shuffle)
    model.eval()
    model.requires_grad_()
    # Distinct loop names so the `labels` array is not overwritten by a batch.
    for inputs, batch_labels, _ in fold_test_data_loader:
        inputs = inputs.to(device)
        batch_labels = batch_labels.to(device)
        for i in range(inputs.size(0)):
            class_idx = batch_labels[i].item()
            sample = inputs[i].unsqueeze(0).float()
            for band_idx, band_name in enumerate(frequency_bands):
                # Keep only the current band. The product is out-of-place so
                # `inputs` is never modified, even if it is already float32.
                mask = torch.zeros_like(sample)
                mask[0, band_idx] = 1
                input_tensor = (sample * mask).requires_grad_()
                outputs = model(input_tensor)
                score = outputs[0, class_idx]
                score.backward()
                saliency = input_tensor.grad.detach().abs().squeeze().cpu().numpy()
                saliency_band = saliency[band_idx]
                saliency_sums[(class_idx, band_idx)] += saliency_band
                # NOTE(review): this averages entries with row < column, while
                # the loading cells zero those input entries and the plot below
                # shows only row > column. Confirm the intended triangle.
                saliency_values = saliency_band[upper_excl_diag]
                avg_saliency = np.mean(saliency_values)
                saliency_values_all_folds[(class_idx, band_idx)].append(avg_saliency)
                sample_counts[(class_idx, band_idx)] += 1
    for class_idx in range(3):
        for band_idx in range(5):
            if sample_counts[(class_idx, band_idx)] != 0:
                average_saliency_band = saliency_sums[(class_idx, band_idx)] / sample_counts[(class_idx, band_idx)]
                saliency_sums_all_folds[(class_idx, band_idx)].append(average_saliency_band)
            else:
                print(f"No samples for class {class_idx} and band {band_idx}.")
average_saliencies = {(class_idx, band_idx): np.mean(saliency_sums_all_folds[(class_idx, band_idx)], axis=0) for class_idx in range(3) for band_idx in range(5)}
all_saliencies_flat = np.concatenate([saliency.ravel() for saliency in average_saliencies.values()])
min_val, max_val = all_saliencies_flat.min(), all_saliencies_flat.max()
# One row per class, one column per band, plus a narrow column for the colorbar.
fig = plt.figure(figsize=(17, 15))
gs = gridspec.GridSpec(3, len(frequency_bands) + 1, width_ratios=[1] * len(frequency_bands) + [0.05])
gs.update(wspace=0.5)
axes = []
for class_idx in range(3):
    for band_idx, band_name in enumerate(frequency_bands):
        # Blank the diagonal and everything above it before display.
        average_saliencies[(class_idx, band_idx)][upper_incl_diag] = 0
        ax = plt.subplot(gs[class_idx, band_idx])
        im = ax.imshow(average_saliencies[(class_idx, band_idx)], cmap='jet', vmin=vmin, vmax=vmax)
        ax.set_title(f"{class_names[class_idx]}, {band_name}")
        axes.append(ax)
cbar_ax = plt.subplot(gs[:, -1])
fig.colorbar(im, cax=cbar_ax, orientation='vertical', ticks=[0, 0.5, 1, 1.5])
plt.tight_layout()
plt.savefig("dl_all_saliency_plots_with_colorbar.eps", format='eps')
plt.show()
