In [1]:
!pip install mne

Collecting mne
  Downloading mne-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading mne-1.8.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.8.0


In [2]:
import tensorflow as tf
import glob
import os
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.io import loadmat

In [3]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
data_folder = '/content/drive/My Drive/Data'

Mounted at /content/drive


In [4]:
all_files_path = glob.glob(os.path.join(data_folder, '*.edf'))
len(all_files_path)

28

In [5]:
all_files_path[0]

'/content/drive/My Drive/Data/h01.edf'

In [6]:
healthy_file_path = [i for i in all_files_path if 'h' in i.split('/')[-1]]
patient_file_path = [i for i in all_files_path if 's' in i.split('/')[-1]]

In [7]:
healthy_file_path

['/content/drive/My Drive/Data/h01.edf',
 '/content/drive/My Drive/Data/h02.edf',
 '/content/drive/My Drive/Data/h03.edf',
 '/content/drive/My Drive/Data/h04.edf',
 '/content/drive/My Drive/Data/h05.edf',
 '/content/drive/My Drive/Data/h06.edf',
 '/content/drive/My Drive/Data/h07.edf',
 '/content/drive/My Drive/Data/h08.edf',
 '/content/drive/My Drive/Data/h09.edf',
 '/content/drive/My Drive/Data/h10.edf',
 '/content/drive/My Drive/Data/h11.edf',
 '/content/drive/My Drive/Data/h12.edf',
 '/content/drive/My Drive/Data/h13.edf',
 '/content/drive/My Drive/Data/h14.edf']

In [8]:
patient_file_path

['/content/drive/My Drive/Data/s01.edf',
 '/content/drive/My Drive/Data/s02.edf',
 '/content/drive/My Drive/Data/s03.edf',
 '/content/drive/My Drive/Data/s04.edf',
 '/content/drive/My Drive/Data/s05.edf',
 '/content/drive/My Drive/Data/s06.edf',
 '/content/drive/My Drive/Data/s07.edf',
 '/content/drive/My Drive/Data/s08.edf',
 '/content/drive/My Drive/Data/s09.edf',
 '/content/drive/My Drive/Data/s10.edf',
 '/content/drive/My Drive/Data/s11.edf',
 '/content/drive/My Drive/Data/s12.edf',
 '/content/drive/My Drive/Data/s13.edf',
 '/content/drive/My Drive/Data/s14.edf']

In [9]:
def read_data(file_path):
    datax=mne.io.read_raw_edf(file_path,preload=True)
    datax.set_eeg_reference()
    datax.filter(l_freq=1,h_freq=45)
    epochs=mne.make_fixed_length_epochs(datax,duration=25,overlap=0)
    epochs=epochs.get_data()
    return epochs #trials,channel,length

In [10]:
data=read_data(healthy_file_path[0])

Extracting EDF parameters from /content/drive/My Drive/Data/h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 825 samples (3.300 s)

Not setting metadata
37 matching events found
No baseline correction applied
0 projection items activated
Using data 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


In [11]:
data.shape

(37, 19, 6250)

In [12]:
%%capture
control_epochs_array=[read_data(subject) for subject in healthy_file_path]
patients_epochs_array=[read_data(subject) for subject in patient_file_path]

In [16]:
control_epochs_labels=[len(i)*[0] for i in control_epochs_array]
patients_epochs_labels=[len(i)*[1] for i in patients_epochs_array]

In [17]:
data_list=control_epochs_array+patients_epochs_array
label_list=control_epochs_labels+patients_epochs_labels

In [18]:
groups_list=[[i]*len(j) for i, j in enumerate(data_list)]

In [19]:
data_array=np.vstack(data_list)
label_array=np.hstack(label_list)
group_array=np.hstack(groups_list)
data_array=np.moveaxis(data_array,1,2)

print(data_array.shape,label_array.shape,group_array.shape)

(1142, 6250, 19) (1142,) (1142,)


In [22]:
from tensorflow.keras.layers import Conv1D, BatchNormalization, LeakyReLU, MaxPool1D, GlobalAveragePooling1D, Dense, Dropout, Add, Input
from tensorflow.keras.models import Model
from tensorflow.keras.backend import clear_session
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt


def cnnmodel():
    clear_session()

    input_layer = Input(shape=(6250, 19))


    x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same')(input_layer)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = MaxPool1D(pool_size=2, strides=2)(x)


    x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = MaxPool1D(pool_size=2, strides=2)(x)
    x = Dropout(0.3)(x)

    x = Conv1D(filters=128, kernel_size=5, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = MaxPool1D(pool_size=2, strides=2)(x)
    x = Dropout(0.4)(x)


    shortcut = x
    x = Conv1D(filters=128, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv1D(filters=128, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([x, shortcut])
    x = LeakyReLU()(x)
    x = MaxPool1D(pool_size=2, strides=2)(x)
    x = Dropout(0.4)(x)


    x = GlobalAveragePooling1D()(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.5)(x)
    output_layer = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=input_layer, outputs=output_layer)

    model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

    return model

model = cnnmodel()
model.summary()


In [24]:
gkf = GroupKFold(n_splits=10)
all_fold_metrics = {'accuracy': [], 'f1_score': [], 'precision': [], 'recall': []}

for fold, (train_index, val_index) in enumerate(gkf.split(data_array, label_array, groups=group_array)):
    print(f"Training fold {fold+1}")
    train_features, train_labels = data_array[train_index], label_array[train_index]
    val_features, val_labels = data_array[val_index], label_array[val_index]


    scaler = StandardScaler()
    train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)

    model = cnnmodel()
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    history = model.fit(train_features, train_labels, epochs=30, batch_size=120,
                        validation_data=(val_features, val_labels), callbacks=[early_stopping], verbose=1)

    y_pred = (model.predict(val_features) > 0.5).astype("int32")
    all_fold_metrics['accuracy'].append(accuracy_score(val_labels, y_pred))
    all_fold_metrics['f1_score'].append(f1_score(val_labels, y_pred))
    all_fold_metrics['precision'].append(precision_score(val_labels, y_pred))
    all_fold_metrics['recall'].append(recall_score(val_labels, y_pred))

    print(f"Fold {fold+1} - Accuracy: {all_fold_metrics['accuracy'][-1]:.4f}, F1 Score: {all_fold_metrics['f1_score'][-1]:.4f}")

mean_metrics = {metric: np.mean(values) for metric, values in all_fold_metrics.items()}
print("\nMean metrics across folds:")
print(mean_metrics)

Training fold 1
Epoch 1/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 1s/step - accuracy: 0.5412 - loss: 0.8718 - val_accuracy: 1.0000 - val_loss: 0.5257
Epoch 2/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 113ms/step - accuracy: 0.5809 - loss: 0.7512 - val_accuracy: 1.0000 - val_loss: 0.4691
Epoch 3/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 111ms/step - accuracy: 0.5401 - loss: 0.8018 - val_accuracy: 1.0000 - val_loss: 0.4656
Epoch 4/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 109ms/step - accuracy: 0.5866 - loss: 0.6906 - val_accuracy: 1.0000 - val_loss: 0.4791
Epoch 5/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 93ms/step - accuracy: 0.6080 - loss: 0.6873 - val_accuracy: 0.9916 - val_loss: 0.5049
Epoch 6/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 93ms/step - accuracy: 0.6823 - loss: 0.5967 - val_accuracy: 0.9328 - val_loss: 0.4998
Epoch 7/30
[1m9/9[0m [32m

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 8 - Accuracy: 0.2821, F1 Score: 0.0000
Training fold 9
Epoch 1/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 908ms/step - accuracy: 0.5371 - loss: 0.8010 - val_accuracy: 0.3707 - val_loss: 0.7169
Epoch 2/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 120ms/step - accuracy: 0.5429 - loss: 0.7633 - val_accuracy: 0.3707 - val_loss: 0.6870
Epoch 3/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 114ms/step - accuracy: 0.5738 - loss: 0.7520 - val_accuracy: 0.3707 - val_loss: 0.6548
Epoch 4/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 106ms/step - accuracy: 0.5802 - loss: 0.7041 - val_accuracy: 0.5948 - val_loss: 0.6242
Epoch 5/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 93ms/step - accuracy: 0.6144 - loss: 0.6510 - val_accuracy: 0.8707 - val_loss: 0.6060
Epoch 6/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 105ms/step - accuracy: 0.6527 - loss: 0.6257 - val_accuracy: 0.879