In [2]:
!pip install mne



In [3]:
import tensorflow as tf
import glob
import os
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.io import loadmat

In [4]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
data_folder = '/content/drive/My Drive/Data'

Mounted at /content/drive


In [5]:
all_files_path = glob.glob(os.path.join(data_folder, '*.edf'))
len(all_files_path)

28

In [6]:
all_files_path[0]

'/content/drive/My Drive/Data/h01.edf'

In [7]:
healthy_file_path = [i for i in all_files_path if 'h' in i.split('/')[-1]]
patient_file_path = [i for i in all_files_path if 's' in i.split('/')[-1]]

In [8]:
healthy_file_path

['/content/drive/My Drive/Data/h01.edf',
 '/content/drive/My Drive/Data/h02.edf',
 '/content/drive/My Drive/Data/h03.edf',
 '/content/drive/My Drive/Data/h04.edf',
 '/content/drive/My Drive/Data/h05.edf',
 '/content/drive/My Drive/Data/h06.edf',
 '/content/drive/My Drive/Data/h07.edf',
 '/content/drive/My Drive/Data/h08.edf',
 '/content/drive/My Drive/Data/h09.edf',
 '/content/drive/My Drive/Data/h10.edf',
 '/content/drive/My Drive/Data/h11.edf',
 '/content/drive/My Drive/Data/h12.edf',
 '/content/drive/My Drive/Data/h13.edf',
 '/content/drive/My Drive/Data/h14.edf']

In [9]:
patient_file_path

['/content/drive/My Drive/Data/s01.edf',
 '/content/drive/My Drive/Data/s02.edf',
 '/content/drive/My Drive/Data/s03.edf',
 '/content/drive/My Drive/Data/s04.edf',
 '/content/drive/My Drive/Data/s05.edf',
 '/content/drive/My Drive/Data/s06.edf',
 '/content/drive/My Drive/Data/s07.edf',
 '/content/drive/My Drive/Data/s08.edf',
 '/content/drive/My Drive/Data/s09.edf',
 '/content/drive/My Drive/Data/s10.edf',
 '/content/drive/My Drive/Data/s11.edf',
 '/content/drive/My Drive/Data/s12.edf',
 '/content/drive/My Drive/Data/s13.edf',
 '/content/drive/My Drive/Data/s14.edf']

In [10]:
def read_data(file_path):
    datax=mne.io.read_raw_edf(file_path,preload=True)
    datax.set_eeg_reference()
    datax.filter(l_freq=1,h_freq=45)
    epochs=mne.make_fixed_length_epochs(datax,duration=25,overlap=0)
    epochs=epochs.get_data()
    return epochs #trials,channel,length

In [11]:
data=read_data(healthy_file_path[0])

Extracting EDF parameters from /content/drive/My Drive/Data/h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 825 samples (3.300 s)

Not setting metadata
37 matching events found
No baseline correction applied
0 projection items activated
Using data 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


In [12]:
data.shape

(37, 19, 6250)

In [13]:
%%capture
control_epochs_array=[read_data(subject) for subject in healthy_file_path]
patients_epochs_array=[read_data(subject) for subject in patient_file_path]

In [14]:
control_epochs_labels=[len(i)*[0] for i in control_epochs_array]
patients_epochs_labels=[len(i)*[1] for i in patients_epochs_array]
print(len(control_epochs_labels),len(patients_epochs_labels))

14 14


In [15]:
data_list=control_epochs_array+patients_epochs_array
label_list=control_epochs_labels+patients_epochs_labels
print(len(data_list),len(label_list))

28 28


In [16]:
groups_list=[[i]*len(j) for i, j in enumerate(data_list)]

In [17]:
data_array=np.vstack(data_list)
label_array=np.hstack(label_list)
group_array=np.hstack(groups_list)
data_array=np.moveaxis(data_array,1,2)

print(data_array.shape,label_array.shape,group_array.shape)

(1142, 6250, 19) (1142,) (1142,)


In [19]:
from tensorflow.keras.layers import Conv1D, BatchNormalization, LeakyReLU, MaxPool1D, GlobalAveragePooling1D, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.backend import clear_session
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

# Define CNN model with more filters and layers
def cnnmodel():
    clear_session()
    model = Sequential()
    model.add(Conv1D(filters=16, kernel_size=3, strides=1, input_shape=(6250, 19)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))

    model.add(Conv1D(filters=32, kernel_size=3, strides=1))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))
    model.add(Dropout(0.5))

    model.add(Conv1D(filters=64, kernel_size=3, strides=1))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(GlobalAveragePooling1D())

    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

model=cnnmodel()
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv1d (Conv1D)             (None, 6248, 16)          928       
                                                                 
 batch_normalization (Batch  (None, 6248, 16)          64        
 Normalization)                                                  
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 6248, 16)          0         
                                                                 
 max_pooling1d (MaxPooling1  (None, 3124, 16)          0         
 D)                                                              
                                                                 
 conv1d_1 (Conv1D)           (None, 3122, 32)          1568      
                                                                 
 batch_normalization_1 (Bat  (None, 3122, 32)          1

In [20]:
gkf = GroupKFold(n_splits=5)
all_fold_metrics = {'accuracy': [], 'f1_score': [], 'precision': [], 'recall': []}

# Group K-Fold cross-validation
for fold, (train_index, val_index) in enumerate(gkf.split(data_array, label_array, groups=group_array)):
    print(f"Training fold {fold+1}")
    train_features, train_labels = data_array[train_index], label_array[train_index]
    val_features, val_labels = data_array[val_index], label_array[val_index]

    # Scale features
    scaler = StandardScaler()
    train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)

    # Initialize and train model
    model = cnnmodel()
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    history = model.fit(train_features, train_labels, epochs=30, batch_size=120,
                        validation_data=(val_features, val_labels), callbacks=[early_stopping], verbose=1)

    # Get predictions and calculate metrics
    y_pred = (model.predict(val_features) > 0.5).astype("int32")
    all_fold_metrics['accuracy'].append(accuracy_score(val_labels, y_pred))
    all_fold_metrics['f1_score'].append(f1_score(val_labels, y_pred))
    all_fold_metrics['precision'].append(precision_score(val_labels, y_pred))
    all_fold_metrics['recall'].append(recall_score(val_labels, y_pred))

    print(f"Fold {fold+1} - Accuracy: {all_fold_metrics['accuracy'][-1]:.4f}, F1 Score: {all_fold_metrics['f1_score'][-1]:.4f}")

# Compute mean metrics across folds
mean_metrics = {metric: np.mean(values) for metric, values in all_fold_metrics.items()}
print("\nMean metrics across folds:")
print(mean_metrics)

Training fold 1
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Fold 1 - Accuracy: 0.3723, F1 Score: 0.5426
Training fold 2
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Fold 2 - Accuracy: 0.4709, F1 Score: 0.6403
Training fold 3
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Fold 3 - Accuracy: 0.5000, F1 Score: 0.6667
Training fold 4
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Fold 4 - Accuracy: 0.6822, F1 Score: 0.7851
Training fold 5
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Fold 5 - Accuracy: 0.6936, F1 Score: 0.7989

Mean metrics across folds:
{'accuracy': 0.5437977139618486, 'f1_score': 0.6867000823910295, 'precision': 0.5626121263640839, 'recall': 0.943558282208589}
