# EEG Classification

#Import packages

In [1]:
from glob import glob
import os
import mne
import numpy as np
import pandas
import matplotlib.pyplot as plt

#Import the files and create a variable

In [2]:
# Collect every EDF recording in the dataset folder.
# glob() returns matches in arbitrary, OS-dependent order; sorting makes the
# recording order (and therefore the per-recording group ids derived from it
# further down) reproducible across machines.
all_file_path = sorted(glob('dataverse_files/*.edf'))
print(len(all_file_path))

28


In [3]:
# Peek at the first path to see the naming scheme used for the recordings.
all_file_path[0]

'dataverse_files\\h01.edf'

#Separate healthy and schizophrenia patients

In [4]:
# Classify each recording by the first letter of its file name
# ('h...' = healthy control, 's...' = schizophrenia patient).
# os.path.basename is used rather than splitting on '\\' because a manual
# split only works with Windows separators and raises IndexError on
# '/'-separated paths (Linux/macOS).
healthy_file_path=[i for i in all_file_path if os.path.basename(i).startswith('h')]
patient_file_path=[i for i in all_file_path if os.path.basename(i).startswith('s')]
print(len(healthy_file_path),len(patient_file_path))

14 14


#Filter and segment data

In [5]:
def read_data(file_path,l_freq=0.5,h_freq=45,duration=5,overlap=1):
    """Load one EDF recording and return it as an array of fixed-length epochs.

    The recording is loaded into memory, re-referenced with MNE's default
    ``set_eeg_reference()`` call, band-pass filtered, and cut into
    fixed-length, overlapping epochs.

    Parameters
    ----------
    file_path : str
        Path to the ``.edf`` file to read.
    l_freq : float, default 0.5
        Lower pass-band edge of the band-pass filter, in Hz.
    h_freq : float, default 45
        Upper pass-band edge of the band-pass filter, in Hz.
    duration : float, default 5
        Length of each epoch, in seconds.
    overlap : float, default 1
        Overlap between consecutive epochs, in seconds.

    Returns
    -------
    numpy.ndarray
        Epoch data as returned by ``Epochs.get_data()``; for the recordings
        used here this is shaped (n_epochs, n_channels, n_samples).
    """
    raw=mne.io.read_raw_edf(file_path,preload=True)
    # Both calls below modify ``raw`` in place.
    raw.set_eeg_reference()
    raw.filter(l_freq=l_freq,h_freq=h_freq)
    epochs=mne.make_fixed_length_epochs(raw,duration=duration,overlap=overlap)
    return epochs.get_data()

In [6]:
# Run the preprocessing pipeline on a single healthy recording as a sanity check.
sample_data=read_data(healthy_file_path[0])

Extracting EDF parameters from C:\Users\DIPANKAR KARMAKAR\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 s)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s


Not setting metadata
231 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 231 events and 1250 original time points ...


[Parallel(n_jobs=1)]: Done  19 out of  19 | elapsed:    0.5s finished


0 bad epochs dropped


In [7]:
# Axes are (number of epochs, number of channels, samples per epoch).
sample_data.shape # number of epochs, channels, length of signal

(231, 19, 1250)

In [8]:
%%capture
# %%capture silences the per-file log output produced by read_data.
# Each list holds one epoch array per recording; the number of epochs
# differs between recordings, so they are kept as lists for now.
control_epochs_array=[read_data(i) for i in healthy_file_path]
patient_epochs_array=[read_data(i) for i in patient_file_path]

In [9]:
# Compare the first two control recordings: only the epoch count (axis 0) differs.
control_epochs_array[0].shape,control_epochs_array[1].shape

((231, 19, 1250), (227, 19, 1250))

In [10]:
# One label per epoch, grouped per recording: 0 = healthy control, 1 = patient.
# Patient epochs must get label 1; labelling them 0 as well would leave the
# dataset with a single class and make classification meaningless.
control_epoch_labels=[len(i)*[0] for i in control_epochs_array]
patient_epoch_labels=[len(i)*[1] for i in patient_epochs_array]
len(control_epoch_labels),len(patient_epoch_labels)

(14, 14)

In [11]:
# Controls first, then patients; data and labels are joined in the same
# order so that entry k of each list refers to the same recording.
data_list = [*control_epochs_array, *patient_epochs_array]
label_list = [*control_epoch_labels, *patient_epoch_labels]

In [12]:
# Give every epoch the index of the recording it came from, so epochs of
# one recording can later be kept together as a group.
group_list = [
    len(recording_epochs) * [recording_idx]
    for recording_idx, recording_epochs in enumerate(data_list)
]
len(group_list)

28

In [13]:
# Flatten the per-recording lists into single arrays indexed by epoch.
# Joining along axis 0 stacks the epochs of all recordings one after another.
data_array = np.concatenate(data_list, axis=0)
label_array = np.concatenate(label_list, axis=0)
group_array = np.concatenate(group_list, axis=0)
print(data_array.shape,label_array.shape,group_array.shape)

(7201, 19, 1250) (7201,) (7201,)


# Feature Extraction


In [14]:
from scipy import stats
def mean(x):
    """Return the arithmetic mean of ``x`` taken along its last axis."""
    last_axis = -1
    return np.mean(x, axis=last_axis)
def std(x):
    """Return the (population) standard deviation of ``x`` along its last axis."""
    last_axis = -1
    return np.std(x, axis=last_axis)
def ptp(x):
    """Return the peak-to-peak range (max minus min) of ``x`` along its last axis."""
    last_axis = -1
    return np.ptp(x, axis=last_axis)
def var(x):
    """Return the (population) variance of ``x`` along its last axis."""
    last_axis = -1
    return np.var(x, axis=last_axis)
def minim(x):
    """Return the smallest value of ``x`` along its last axis."""
    last_axis = -1
    return np.min(x, axis=last_axis)
def maxim(x):
    """Return the largest value of ``x`` along its last axis."""
    last_axis = -1
    return np.max(x, axis=last_axis)

def argminim(x):
    """Return the index of the smallest value of ``x`` along its last axis."""
    last_axis = -1
    return np.argmin(x, axis=last_axis)

def argmaxim(x):
    """Return the index of the largest value of ``x`` along its last axis."""
    last_axis = -1
    return np.argmax(x, axis=last_axis)

def rms(x):
    """Return the root-mean-square of ``x`` along its last axis."""
    squared = x ** 2
    mean_square = np.mean(squared, axis=-1)
    return np.sqrt(mean_square)

def abs_diff_signal(x):
    """Return the sum of absolute differences between consecutive values of ``x`` along its last axis."""
    steps = np.diff(x, axis=-1)
    step_sizes = np.abs(steps)
    return np.sum(step_sizes, axis=-1)
def skewness(x):
    """Return the sample skewness of ``x`` along its last axis (via ``scipy.stats.skew``)."""
    last_axis = -1
    return stats.skew(x, axis=last_axis)
def kurtosis(x):
    """Return the kurtosis of ``x`` along its last axis (via ``scipy.stats.kurtosis`` with its defaults)."""
    last_axis = -1
    return stats.kurtosis(x, axis=last_axis)

def concatenate_features(x):
    """Compute every per-channel statistic for ``x`` and join them into one vector.

    Each extractor reduces the last axis of ``x``; the results are joined
    along the last axis in the fixed order listed below.
    """
    extractors = (
        mean, std, ptp, var, minim, maxim,
        argminim, argmaxim, rms, abs_diff_signal,
        skewness, kurtosis,
    )
    return np.concatenate([extract(x) for extract in extractors], axis=-1)

In [15]:
# Build one feature vector per epoch (iterating over axis 0 of data_array).
features = [concatenate_features(epoch) for epoch in data_array]

In [16]:
# Stack the per-epoch feature vectors into a 2-D matrix with one row per
# epoch; its columns are the 12 statistics computed for each channel.
features_array=np.array(features)
features_array.shape

(7201, 228)