In [2]:
import importlib
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, TensorDataset

import EpilepsyNet_model 
from preprocessing_utils import *

# Length of one EEG epoch; presumably in seconds, matching the 5 s segments
# described in the segmentation section of this notebook — TODO confirm.
# Not referenced again in the visible cells.
epoch_duration = 5.0

# Step 1 : Load Data

### Read raw data

In [None]:
# Load the per-recording metadata table; every row must carry an `edf_path`.
df = pd.read_excel('eeg_metadata.xlsx')

# Binary label derived from the file path: 0 when the path contains
# 'no_epilepsy', 1 otherwise.
df['epilepsy'] = df['edf_path'].apply(lambda x: 0 if 'no_epilepsy' in x else 1)

# Remove columns that are not needed downstream. `errors='ignore'` skips only
# the labels that are absent (so the cell can be re-run safely) while any other
# failure still surfaces, unlike the previous bare `except: pass`, which also
# left *both* columns in place whenever just one of them was missing.
df = df.drop(columns=['Unnamed: 0', 'ethnicity'], errors='ignore')
df

### Make balanced train/test datasets

In [None]:
# Patient-level train/test split.
#
# Author's notes: each label contains 100 patients (with different numbers and
# times of acquisition). The epileptic patients have, on average, more
# recordings (closer attention is paid to them). What matters most, however, is
# to avoid data leakage, so the datasets are split on patient ids rather than
# on individual recordings.

TRAIN_FRACTION = 0.8


def split_subject_ids(frame, train_fraction=TRAIN_FRACTION):
    """Split the distinct subject ids of `frame` into a train and a test part.

    Ids are taken in order of first appearance (`Series.unique()`); no
    shuffling is applied, so the split is deterministic for a given table.

    Returns:
        (n_train, train_ids, test_ids) where `n_train` is
        int(train_fraction * number of distinct ids).
    """
    subject_ids = frame['subject_id'].unique()
    n_train = int(train_fraction * len(subject_ids))
    return n_train, subject_ids[:n_train], subject_ids[n_train:]


df_pos = df[df['epilepsy'] == 1]
df_neg = df[df['epilepsy'] == 0]

print(df_pos['subject_id'].nunique(), df_neg['subject_id'].nunique())

split_pos, split_train_pos, split_test_pos = split_subject_ids(df_pos)
print(split_pos)
print(len(split_train_pos), len(split_test_pos))
##############################
split_neg, split_train_neg, split_test_neg = split_subject_ids(df_neg)
print(split_neg)
print(len(split_train_neg), len(split_test_neg))

# Leakage guard: the label is derived per recording path, so one subject could
# in principle appear under both labels and land in train via one class and in
# test via the other. Report it instead of silently leaking.
leaked_subjects = (
    (set(split_train_pos) | set(split_train_neg))
    & (set(split_test_pos) | set(split_test_neg))
)
if leaked_subjects:
    print(f"WARNING: {len(leaked_subjects)} subject id(s) appear in both train and test splits")


In [None]:
# Select the recordings of the training subjects, positives first, then negatives.
train_rows_pos = df_pos[df_pos['subject_id'].isin(split_train_pos)]
train_rows_neg = df_neg[df_neg['subject_id'].isin(split_train_neg)]
df_train = pd.concat([train_rows_pos, train_rows_neg])
print(len(df_train), df_train['epilepsy'].value_counts())
display(df_train)

# Same selection for the held-out subjects.
test_rows_pos = df_pos[df_pos['subject_id'].isin(split_test_pos)]
test_rows_neg = df_neg[df_neg['subject_id'].isin(split_test_neg)]
df_test = pd.concat([test_rows_pos, test_rows_neg])
print(len(df_test), df_test['epilepsy'].value_counts())
display(df_test)

In [None]:
# Upper bound on the number of metadata rows kept per subject before loading.
MAX_RECORDINGS_PER_SUBJECT = 25


def first_recordings_per_subject(metadata, limit=MAX_RECORDINGS_PER_SUBJECT):
    """Keep the rows at positions 0..limit-1 within each `subject_id` group.

    `GroupBy.nth` is documented to take an int, a slice, or a list of
    ints/slices, so the positions are passed as a list rather than a `range`.
    """
    return metadata.groupby('subject_id').nth(list(range(limit)))


# Load_raw_labeled comes from preprocessing_utils; it returns the loaded
# recordings together with their labels.
X_train_list, y_train = Load_raw_labeled(first_recordings_per_subject(df_train))
X_test_list, y_test = Load_raw_labeled(first_recordings_per_subject(df_test))

In [None]:
# Sanity check per split: number of loaded items, sum of labels, number of labels.
(
    len(X_train_list), sum(y_train), len(y_train),
    len(X_test_list), sum(y_test), len(y_test),
)

# Step 2 : Preprocess Data

## Segmentation: 1 minute split into 12 segments of 5 seconds

In [11]:
# EEG channels used for prediction: every electrode name below is turned into
# the corresponding 'EEG <name>' column label, in this exact order.
_electrode_names = (
    'FP1 FP2 F3 F4 '
    'C3 C4 P3 P4 '
    'O1 O2 F7 F8 '
    'T3 T4 T5 T6 '
    'T1 T2 FZ CZ '
    'PZ'
).split()
eeg_cols = [f'EEG {electrode}' for electrode in _electrode_names]

In [None]:
# Segmentation settings shared by both splits, so train and test are always
# processed identically.
segmentation_kwargs = dict(
    eeg_cols=eeg_cols,
    segment_duration=60.0,       # 60 second segments
    n_segments_per_file=12,      # split into 12 epochs (5 sec each)
    samples_per_segment=1250,    # 1250 samples per segment (250 Hz sampling rate)
    random_state=42,             # for reproducibility
)

# Per the author's note, the result has shape
# (len(raw_files), 12, len(eeg_cols), 1250).
X_train = process_raw_files(X_train_list, **segmentation_kwargs)
X_test = process_raw_files(X_test_list, **segmentation_kwargs)

print(f"Output shape: train->{X_train.shape}")

In [None]:
# Report the array shapes of both splits after segmentation.
print(f"Output shape: train->{X_train.shape}, test->{X_test.shape}")

In [None]:
def _all_zero_rows(recordings):
    """Return the `np.where` tuple of axis-0 entries whose values are all exactly 0."""
    is_empty = np.all(recordings == 0, axis=(1, 2, 3))
    return np.where(is_empty)


# Locate the entries of each split that contain nothing but zeros.
zero_elements_train = _all_zero_rows(X_train)
print(f"Number of all-zero elements in train: {len(zero_elements_train[0])}")
zero_elements_test = _all_zero_rows(X_test)
print(f"Number of all-zero elements in test: {len(zero_elements_test[0])}")

empty_train_idx = zero_elements_train[0]
empty_test_idx = zero_elements_test[0]

# Remove those entries from the features and, at the same positions, from the
# labels so both stay aligned.
X_train, y_train = (
    np.delete(X_train, empty_train_idx, axis=0),
    np.delete(y_train, empty_train_idx, axis=0),
)
X_test, y_test = (
    np.delete(X_test, empty_test_idx, axis=0),
    np.delete(y_test, empty_test_idx, axis=0),
)
print(f"Output shape: train->{X_train.shape}, test->{X_test.shape}")

## Preprocessing 

##### Standardization -> Correlation Matrices -> flattening upper triangle

In [None]:
# Standardize the training data (helper from preprocessing_utils).
X_train_standardized = standardize_data(X_train)
print(f"Standardized output shape: {X_train_standardized.shape}")
# Same NaN diagnostic as the test-split cell: NaNs in the training features
# would otherwise go unnoticed until the loss turns into NaN.
print(f"# of nan values : {np.isnan(X_train_standardized).sum()}")

# Channel-by-channel correlation matrices. The result is indexed below as
# [recording][segment] and each entry is plotted against len(eeg_cols) ticks on
# both axes, i.e. one (channels x channels) matrix per recording and segment
# (not a single (21, 21) matrix as the earlier comment suggested).
corr_matrix_train = compute_correlation_matrix(X_train_standardized)
print(f"Correlation matrix shape: {corr_matrix_train.shape}")
print(f"# of nan values : {np.isnan(corr_matrix_train).sum()}")

# Display one example matrix: first recording, last (12th) segment.
example_recording, example_segment = 0, 11
fig, ax = plt.subplots(figsize=(10, 8))
image = ax.imshow(
    corr_matrix_train[example_recording][example_segment],
    cmap='coolwarm',
    aspect='auto',
)
fig.colorbar(image, ax=ax)
ax.set_title('Correlation Matrix')
ax.set_xlabel('Channels')
ax.set_ylabel('Channels')
ax.set_xticks(range(len(eeg_cols)))
ax.set_xticklabels(eeg_cols, rotation=90)
ax.set_yticks(range(len(eeg_cols)))
ax.set_yticklabels(eeg_cols)
fig.tight_layout()
plt.show()

# Flatten the upper triangle of every correlation matrix into a feature vector.
upper_triangle_matrix_train = extract_upper_triangle(corr_matrix_train)
print(upper_triangle_matrix_train.shape)

In [None]:
# Standardize the test data with the same helper used for the training split.
# NOTE(review): whether the statistics are computed per recording or over the
# whole array is decided inside standardize_data — confirm it does not need the
# training statistics.
X_test_standardized = standardize_data(X_test)
print(f"Standardized output shape: {X_test_standardized.shape}")
print(f"# of nan values : {np.isnan(X_test_standardized).sum()}")
# Author's expected shape for the standardized array: (len(raw_files), 12, 21, 1250)
# Next: correlation between channels, computed for each recording and segment.

corr_matrix_test = compute_correlation_matrix(X_test_standardized)
print(f"Correlation matrix shape: {corr_matrix_test.shape}")
print(f"# of nan values : {np.isnan(corr_matrix_test).sum()}")
# Presumably one (21, 21) matrix per recording and segment — TODO confirm
# against the shape printed above.

# Flatten the upper triangle of every correlation matrix into a feature vector.
upper_triangle_matrix_test = extract_upper_triangle(corr_matrix_test)

print(upper_triangle_matrix_test.shape)

In [21]:
# Features become float32 tensors, labels int64 (torch.long) tensors.
X_train_tensor = torch.tensor(upper_triangle_matrix_train, dtype=torch.float32)
X_test_tensor = torch.tensor(upper_triangle_matrix_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Pair every feature tensor with its label.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Training batches are reshuffled on every pass; the test loader keeps the
# dataset order so its outputs stay aligned with y_test.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
# One shape per line: train features, test features, train labels, test labels.
print(
    X_train_tensor.shape,
    X_test_tensor.shape,
    y_train_tensor.shape,
    y_test_tensor.shape,
    sep="\n",
)

# Step 3 : Model Training

### Model Definition

In [6]:
# Pick up edits made to the model module without restarting the kernel.
importlib.reload(EpilepsyNet_model)

# Model parameters
n_channels = 21
input_dim = n_channels * (n_channels - 1) // 2   # 210: size of the flattened upper triangle (21*20/2)
embed_dim = 56                                   # embedding dimension
num_heads = 7                                    # number of attention heads

model = EpilepsyNet_model.TimeSeriesAttentionClassifier(
    input_dim,
    embed_dim,
    num_heads,
    dropout=0.2,
)

### Training

In [None]:
# Optimisation settings forwarded unchanged to train_model.
training_config = dict(
    num_epochs=1000,
    learning_rate=1e-5,
    weight_decay=1e-5,
    patience=20,
    scheduler_factor=0.5,
    min_lr=1e-6,
)

# NOTE(review): test_loader is passed in the position whose results come back
# as val_losses / val_accuracies, so the "validation" curves below are computed
# on the test split — confirm this is intended.
train_losses, val_losses, val_accuracies = EpilepsyNet_model.train_model(
    model,
    train_loader,
    test_loader,
    **training_config,
)

# Training curves: losses on the left, validation accuracy on the right.
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

accuracy_ax.plot(val_accuracies)
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Validation Accuracy (%)')

fig.tight_layout()
plt.show()

# TODO: visualize attention for a sample, e.g.
# visualize_attention(model, X_standardized)

In [None]:
# Save the trained weights locally. The target directory is created first so a
# long training run is not lost to a missing folder.
checkpoint_path = Path('Weights_model/EpilepsyNet_7Heads.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Interactive version of the loss curves. The epoch axis is derived from the
# length of each history, because the number of epochs actually run is only
# known after training (it is not a fixed 149).
fig = go.Figure()
for trace_name, losses in (('Train Loss', train_losses), ('Validation Loss', val_losses)):
    fig.add_trace(go.Scatter(
        x=np.arange(1, len(losses) + 1),   # epochs are numbered from 1
        y=losses,
        mode='lines+markers',
        name=trace_name
    ))
fig.update_layout(
    title='Training and Validation Loss',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    legend_title='Legend'
)
fig.show()

In [None]:
# Evaluate on the whole test split. The previous version wrapped a single
# (features, labels) batch in torch.tensor(...), which cannot work, and compared
# raw model outputs for one batch against the labels of the full split.
model.eval()
model_device = next(model.parameters()).device   # train_model may have moved the model

batch_predictions = []
with torch.no_grad():
    for batch_features, _ in test_loader:
        # The model returns a pair; only its first element is used here.
        batch_outputs, _ = model(batch_features.to(model_device))
        # Assumes outputs are (batch, n_classes) scores, as suggested by the
        # `.argmax(dim=1)` left in the original cell — TODO confirm.
        batch_predictions.append(batch_outputs.argmax(dim=1).cpu())

# test_loader is built with shuffle=False, so predictions line up with y_test.
y_pred = torch.cat(batch_predictions).numpy()
metrics = calculate_metrics(y_test, y_pred)
print(metrics)

In [None]:
# TODO: add to inference: list of attention-score matrices (210x210) giving correlation p (note left unfinished)

In [None]:
# Load the model: rebuild the architecture, then restore the saved weights.
model = EpilepsyNet_model.TimeSeriesAttentionClassifier(input_dim, embed_dim, num_heads)
# Read the checkpoint written by this notebook's save cell
# ('Weights_model/EpilepsyNet_7Heads.pth'); the previous path pointed at a
# different file. map_location='cpu' lets weights saved from a GPU be restored
# into this freshly built CPU model.
model.load_state_dict(torch.load('Weights_model/EpilepsyNet_7Heads.pth', map_location='cpu'))
model.eval()
