<a href="https://colab.research.google.com/github/HopeRetina/EEG-related/blob/main/EEG_data_preprocessing_practice_for_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here I used the mne library to preprocess EEG data from [this paper.](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0188629)

In [1]:
!pip install mne

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mne
  Downloading mne-1.3.1-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mne
Successfully installed mne-1.3.1


In [2]:
from glob import glob
import os
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
cd /content/drive/MyDrive/data_collection/

/content/drive/MyDrive/data_collection


In [5]:
all_file_path = glob('dataverse_files/*.edf')

print(len(all_file_path))

28


In [6]:
all_file_path

['dataverse_files/h01.edf',
 'dataverse_files/h02.edf',
 'dataverse_files/h03.edf',
 'dataverse_files/h04.edf',
 'dataverse_files/h05.edf',
 'dataverse_files/h06.edf',
 'dataverse_files/h07.edf',
 'dataverse_files/h08.edf',
 'dataverse_files/h09.edf',
 'dataverse_files/h10.edf',
 'dataverse_files/h11.edf',
 'dataverse_files/h12.edf',
 'dataverse_files/h14.edf',
 'dataverse_files/h13.edf',
 'dataverse_files/s01.edf',
 'dataverse_files/s02.edf',
 'dataverse_files/s03.edf',
 'dataverse_files/s04.edf',
 'dataverse_files/s05.edf',
 'dataverse_files/s06.edf',
 'dataverse_files/s07.edf',
 'dataverse_files/s08.edf',
 'dataverse_files/s09.edf',
 'dataverse_files/s10.edf',
 'dataverse_files/s11.edf',
 'dataverse_files/s12.edf',
 'dataverse_files/s13.edf',
 'dataverse_files/s14.edf']

In [7]:
# Separate patient data from healthy (control) data by the starting letter

healthy_path = [i for i in all_file_path if 'h' in i.split('/')[-1]]
patient_path = [i for i in all_file_path if 's' in i.split('/')[-1]]

print(len(healthy_path))
print(len(patient_path))

14
14


In [12]:
# Write a function to read data

def read_data(file_path):
  data = mne.io.read_raw_edf(file_path, preload=True)
  data.set_eeg_reference()
  data.filter(l_freq=0.5, h_freq=45)
  epochs = mne.make_fixed_length_epochs(data, duration=5, overlap=1)
  array = epochs.get_data()
  return array


In [13]:
sample = read_data(healthy_path[0])

Extracting EDF parameters from /content/drive/MyDrive/data_collection/dataverse_files/h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 sec)

Not setting metadata
231 matching events found


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  19 out of  19 | elapsed:    0.2s finished


No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 231 events and 1250 original time points ...
0 bad epochs dropped


In [10]:
# check the shape: num of epochs, num of channel, signal length
sample.shape

(231, 19, 1250)

In [11]:
%%capture
control_arr = [read_data(i) for i in healthy_path]
patient_arr = [read_data(i) for i in patient_path]

In [14]:
control_arr[0].shape

(231, 19, 1250)

In [15]:
# Create labels
control_labels = [len(i)*[0] for i in control_arr]
patient_labels = [len(i)*[1] for i in patient_arr]

print(len(control_labels))
print(len(patient_labels))

14
14


In [16]:
# Combine data and labels
data = control_arr + patient_arr
labels = control_labels + patient_labels

In [17]:
group_list = [[i]*len(j) for i,j in enumerate(data)]

len(group_list)

28

In [None]:
group_list[0]

In [24]:
data_arr = np.vstack(data)
label_arr = np.hstack(labels)
group_arr = np.hstack(group_list)

print(data_arr.shape)
print(label_arr.shape)
print(group_arr.shape)

(7201, 19, 1250)
(7201,)
(7201,)


In [25]:
from scipy import stats

# Write udfs to make feature matrix

def mean(x):
  return np.mean(x, axis=-1)

def std(x):
  return np.std(x, axis=-1)

def ptp(x):
  return np.ptp(x, axis=-1)

def var(x):
  return np.var(x, axis=-1)

def minimum(x):
  return np.min(x, axis=-1)
def maximum(x):
  return np.max(x, axis=-1)

def argmin(x):
  return np.argmin(x, axis=-1)

def argmax(x):
  return np.argmax(x, axis=-1)

def rms(x):
  return np.sqrt(np.mean(x**2, axis=-1))

def abs_diff_signal(x):
  return np.sum(np.abs(np.diff(x, axis=-1)), axis=-1)

def skewness(x):
  return stats.skew(x, axis=-1)

def kurtosis(x):
  return stats.kurtosis(x, axis=-1)

def concatenate_features(x):
  return np.concatenate((mean(x),
                         std(x),
                         ptp(x),
                         var(x),
                         minimum(x),
                         maximum(x),
                         argmin(x),
                         argmax(x),
                         rms(x),
                         abs_diff_signal(x),
                         skewness(x),
                         kurtosis(x)),
                        axis=-1)

In [26]:
# Create features

features = []

for d in data_arr:
  features.append(concatenate_features(d))

In [28]:
features_arr = np.array(features) # Convert into an np array
features_arr.shape

(7201, 228)

#Now try a few classification models
##logistic regresion

In [29]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold, GridSearchCV

In [30]:
clf = LogisticRegression()

gkf=GroupKFold(5)

pipeline = Pipeline( [('scaler', StandardScaler()), 
                      ('clf', clf)] )

param_grid = {'clf__C': [0.1, 0.5, 0.7, 1, 3, 5, 7]}

gscv = GridSearchCV(pipeline, param_grid, cv=gkf, n_jobs=12)

gscv.fit(features_arr, label_arr, groups=group_arr)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


GridSearchCV(cv=GroupKFold(n_splits=5),
             estimator=Pipeline(steps=[('scaler', StandardScaler()),
                                       ('clf', LogisticRegression())]),
             n_jobs=12, param_grid={'clf__C': [0.1, 0.5, 0.7, 1, 3, 5, 7]})

In [32]:
gscv.best_score_

0.641518291987081