## 1. Data Downloads

### Warning: Executing these blocks will automatically create directories and download datasets.

In [4]:
import requests
import re
import os
import pathlib
import urllib

In [2]:
CONTEXT = 'pn4/'
MATERIAL = 'eegmmidb/'
URL = 'https://www.physionet.org/' + CONTEXT + MATERIAL

USERDIR = '/Users/Jimmy/data/PhysioNet/'

page = requests.get(URL).text
FOLDERS = sorted(list(set(re.findall(r'S[0-9]+', page))))

URLS = [URL+x+'/' for x in FOLDERS]

In [3]:
#for folder in FOLDERS:
#    pathlib.Path(USERDIR + folder).mkdir(parents=True, exist_ok=True)

In [7]:
for i, folder in enumerate(FOLDERS):
    page = requests.get(URLS[i]).text
    subs = list(set(re.findall(r'S[0-9]+R[0-9]+', page)))
    
    print('Working on {}, {:.1%} completed'.format(folder, (i+1)/len(FOLDERS)))
    for sub in subs:
        urllib.request.urlretrieve(URLS[i]+sub+'.edf', os.path.join(USERDIR, folder, sub+'.edf'))
        urllib.request.urlretrieve(URLS[i]+sub+'.edf.event', os.path.join(USERDIR, folder, sub+'.edf.event'))

Working on S077, 0.9% completed
Working on S078, 1.8% completed
Working on S079, 2.8% completed
Working on S080, 3.7% completed
Working on S081, 4.6% completed
Working on S082, 5.5% completed
Working on S083, 6.4% completed
Working on S084, 7.3% completed
Working on S085, 8.3% completed
Working on S086, 9.2% completed
Working on S087, 10.1% completed
Working on S088, 11.0% completed
Working on S089, 11.9% completed
Working on S090, 12.8% completed
Working on S091, 13.8% completed
Working on S092, 14.7% completed
Working on S093, 15.6% completed
Working on S094, 16.5% completed
Working on S095, 17.4% completed
Working on S096, 18.3% completed
Working on S097, 19.3% completed
Working on S098, 20.2% completed
Working on S099, 21.1% completed
Working on S100, 22.0% completed
Working on S101, 22.9% completed
Working on S102, 23.9% completed
Working on S103, 24.8% completed
Working on S104, 25.7% completed
Working on S105, 26.6% completed
Working on S106, 27.5% completed
Working on S107, 28.

## 2. Raw Data Import

In [42]:
import warnings
warnings.filterwarnings('ignore')
    
import numpy as np
import pandas as pd

from glob import glob

from mne import find_events, Epochs, concatenate_raws, pick_types
from mne.channels import read_montage
from mne.io import read_raw_edf

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Get file paths
PATH = '/Users/jimmy/data/PhysioNet/'
SUBS = glob(PATH+ 'S[0-9]*')
FNAMES = sorted([x[-4:] for x in SUBS])

In [47]:
# Import sample data
raw = read_raw_edf(os.path.join('/Users/jimmy/data/PhysioNet/S100', 'S100R04.edf'), preload=True)
montage = read_montage('standard_1020')
raw.info['montage'] = montage

Extracting EDF parameters from /Users/jimmy/data/PhysioNet/S100/S100R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 15743  =      0.000 ...   122.992 secs...
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported


In [44]:
raw.get_data(picks=picks).shape

(64, 15744)

In [48]:
raw.find_edf_events()

[[0.0, 5.125, 'T0'],
 [5.125, 5.125, 'T1'],
 [10.25, 5.125, 'T0'],
 [15.38, 5.125, 'T2'],
 [20.5, 5.125, 'T0'],
 [25.62, 5.125, 'T2'],
 [30.75, 5.125, 'T0'],
 [35.88, 5.125, 'T1'],
 [41.0, 5.125, 'T0'],
 [46.12, 5.125, 'T1'],
 [51.25, 5.125, 'T0'],
 [56.38, 5.125, 'T2'],
 [61.5, 5.125, 'T0'],
 [66.62, 5.125, 'T1'],
 [71.75, 5.125, 'T0'],
 [76.88, 5.125, 'T2'],
 [82.0, 5.125, 'T0'],
 [87.12, 5.125, 'T2'],
 [92.25, 5.125, 'T0'],
 [97.38, 5.125, 'T1'],
 [102.5, 5.125, 'T0'],
 [107.6, 5.125, 'T2'],
 [112.8, 5.125, 'T0'],
 [117.9, 5.125, 'T1']]

In [33]:
picks = pick_types(raw.info, eeg=True)

In [6]:
# Event codes mean different actions for two groups of runs
event_0 = '01,02'.split(',')
event_1 = '03,04,07,08,11,12'.split(',')
event_2 = '05,06,09,10,13,14'.split(',')

In [114]:
def get_data(subj_num=FNAMES, resampling=True):
    """Import each subject`s trials and make a 3D array
        The output shape: (Trial*Channel*Frames)
        
        Set 'resampling=False' to exclude some edf+ files recored
        at low sampling rate, 128Hz. Majority was sampled at 160Hz."""
    # To calculated the completion rate
    count=0
    
    # Initiate X, y
    X = []
    y = []
    
    for fname in subj_num:
        count+=1
        print('working on {}, {:.1%} completed'.format(fname, count/len(subj_num)))
        
        fnames = glob(os.path.join(PATH, fname, fname+'R*'+'.edf'))
    

        for i, fname in enumerate(fnames):
            
            # Import data into MNE raw object
            raw = read_raw_edf(fname, preload=True, verbose=False)
            picks = pick_types(raw.info, eeg=True)
            
            # Resampling
            if resampling:
                events = find_events(raw, initial_event=True, verbose=False)
                raw, events = raw.copy().resample(128, npad='auto', events=events, verbose=False)
                sfreq = 120
                
            else:
                sfreq = 160
                if raw.info['sfreq'] != 160:
                    print(f'{fname} is sampled at 128Hz so will be excluded.')
                    pass
            
            # High-pass filtering
            raw.filter(l_freq=1, h_freq=None, picks=picks)
            
            # Get annotation
            events = raw.find_edf_events()
            
            # Get data
            data = raw.get_data(picks=picks)
            
            # Epoch period
            epoch_sec = 4
        
            # Experiment number 0,1
            if fname[-6:-4] in event_0:
                for n in range(15):
                    
                    X.append(data[:, int(sfreq*epoch_sec*n):int(sfreq*epoch_sec*(n+1))])
                    y.append(0)
                    
                    if X[-1].shape != (64, sfreq*epoch_sec): print(F'shape error!: {fname}, {X[-1].shape}') 
                    
            # Experiment number 3,4,7,8,11,12        
            elif fname[-6:-4] in event_1:
                for n in range(len(events)):
                
                    if events[n][2] == 'T0':
                        y.append(0)
                    elif events[n][2] == 'T1':
                        y.append(1)
                    else:
                        y.append(2)
                    
                    X.append(data[:, int(events[n][0]*sfreq):int(events[n][0]*sfreq)+sfreq*epoch_sec])
                    if X[-1].shape != (64, sfreq*epoch_sec): print(F'shape error!: {fname}, {X[-1].shape}')
                        
            # Experiment number 5,6,9,10,13,14
            else:
                for n in range(len(events)):
            
                    if events[n][2] == 'T0':
                        y.append(0)
                    elif events[n][2] == 'T1':
                        y.append(3)
                    else:
                        y.append(4)
                
                    X.append(data[:, int(events[n][0]*sfreq):int(events[n][0]*sfreq)+sfreq*epoch_sec])
                    if X[-1].shape != (64, sfreq*epoch_sec): print(F'shape error!: {fname}, {X[-1].shape}')
                        
    X = np.stack(X)
    y = np.array(y).reshape((-1,1))
    return X, y

In [115]:
X,y = get_data(FNAMES, resampling=True)

working on S001, 0.9% completed
working on S002, 1.8% completed
working on S003, 2.8% completed
working on S004, 3.7% completed
working on S005, 4.6% completed
working on S006, 5.5% completed
working on S007, 6.4% completed
working on S008, 7.3% completed
working on S009, 8.3% completed
working on S010, 9.2% completed
working on S011, 10.1% completed
working on S012, 11.0% completed
working on S013, 11.9% completed
working on S014, 12.8% completed
working on S015, 13.8% completed
working on S016, 14.7% completed
working on S017, 15.6% completed
working on S018, 16.5% completed
working on S019, 17.4% completed
working on S020, 18.3% completed
working on S021, 19.3% completed
working on S022, 20.2% completed
working on S023, 21.1% completed
working on S024, 22.0% completed
working on S025, 22.9% completed
working on S026, 23.9% completed
working on S027, 24.8% completed
working on S028, 25.7% completed
working on S029, 26.6% completed
working on S030, 27.5% completed
working on S031, 28.

EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
EDF+ with overlapping events are not fully supported
working on S101, 92.7% completed
working on S102, 93.6% completed
working on S103, 94.5% completed
working on S104, 95.4% completed
working on S105, 96.3% completed
working on S106, 97.2% completed
working on S107, 98.2% completed
working on S108, 99.1% completed
working on S109, 100.0% completed


In [117]:
X.shape

(42619, 64, 480)

In [118]:
y.shape

(42619, 1)

## 3. Modeling

In [157]:
from keras.layers import Conv1D, Dense, Flatten, MaxPool1D, Activation
from keras.models import Sequential
from sklearn.preprocessing import OneHotEncoder 

In [139]:
oh = OneHotEncoder()
y = oh.fit_transform(y).toarray()

In [159]:
model = Sequential()
model.add(Conv1D(filters=10, input_shape=(64,480), kernel_size=5, padding='valid', strides=2, activation='elu'))
model.add(MaxPool1D(pool_size=(4), strides=2))
model.add(Flatten())
model.add(Dense(64, activation='elu'))
model.add(Dense(5, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

print(model.summary())
model.fit(X,y, batch_size=64, epochs=1)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1d_3 (Conv1D)            (None, 30, 10)            24010     
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 14, 10)            0         
_________________________________________________________________
flatten_8 (Flatten)          (None, 140)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 64)                9024      
_________________________________________________________________
dense_16 (Dense)             (None, 5)                 325       
Total params: 33,359
Trainable params: 33,359
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/1


<keras.callbacks.History at 0x1c1ac85128>

In [154]:
model = Sequential()
model.add(Dense(128, input_shape=(64,480), activation='elu'))
model.add(Flatten())
model.add(Dense(64, activation='elu'))
model.add(Dense(5, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_12 (Dense)             (None, 64, 128)           61568     
_________________________________________________________________
flatten_7 (Flatten)          (None, 8192)              0         
_________________________________________________________________
dense_13 (Dense)             (None, 64)                524352    
_________________________________________________________________
dense_14 (Dense)             (None, 5)                 325       
Total params: 586,245
Trainable params: 586,245
Non-trainable params: 0
_________________________________________________________________


In [155]:
model.fit(X, y, batch_size=64, epochs=1)

Epoch 1/1


<keras.callbacks.History at 0x1c1b7bc240>

In [None]:
model.predict()