In [1]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import os
import sys
from mne.datasets import eegbci
import glob
from IPython.display import clear_output
import numpy as np

from mne.datasets import eegbci

In [2]:
class EEG:
    """Download, load, filter and epoch PhysioNet eegmmidb recordings with MNE.

    Constructing an instance calls ``eegbci.load_data`` for every subject and
    then reads every ``subject x run`` EDF file into one concatenated
    ``mne.io.Raw`` stored in ``self.raw``.

    Parameters
    ----------
    path : str
        Local root directory passed to ``eegbci.load_data`` as ``path``.
    base_url : str
        Remote URL passed to ``eegbci.load_data`` as ``base_url``.
    subjects : sequence of int
        Subject numbers to load (iterated more than once, and ``len`` is taken,
        so it must be a sequence, not a one-shot iterator).
    runs : sequence of int
        Run numbers loaded for every subject.
    """

    def __init__(self, path, base_url, subjects, runs):
        # Directory layout below ``path`` used to rebuild per-file paths.
        # NOTE(review): mirrors the layout written by the installed MNE
        # version's eegbci.load_data -- confirm if MNE is upgraded.
        self.subpath = 'MNE-eegbci-data/files/eegmmidb/1.0.0'
        self.path = path
        self.base_url = base_url
        self.subjects = subjects
        self.runs = runs
        # Filled by create_epochs(); initialised so get_X_y() can detect
        # "not created yet" instead of raising AttributeError.
        self.epochs = None

        # download data if it does not exist in path, then load it.
        self.load_data()
        self.data_to_raw()

    def load_data(self):
        """Call ``eegbci.load_data`` for every subject with ``self.runs``."""
        print(f">>> Start download from: {self.base_url}.")
        print(f"Downloading files to: {self.path}.")
        for subject in self.subjects:
            eegbci.load_data(subject, self.runs, path=self.path, base_url=self.base_url)
        print("Done.")

    @staticmethod
    def _edf_name(subject, run, extension="edf"):
        """Return ``(subject_dir, file_name)``, e.g. ``(1, 3) -> ('S001', 'S001R03.edf')``."""
        sname = f"S{str(subject).zfill(3)}"
        return sname, f"{sname}R{str(run).zfill(2)}.{extension}"

    def data_to_raw(self):
        """Read all EDF files, concatenate them and store the result in ``self.raw``.

        Channel names are normalised with ``eegbci.standardize`` and the
        ``standard_1005`` montage is attached.

        Raises
        ------
        ValueError
            If ``subjects`` or ``runs`` is empty (nothing to concatenate).
        """
        fullpath = os.path.join(self.path, *self.subpath.split(sep='/'))
        print(f">>> Extract all subjects from: {fullpath}.")
        total = len(self.subjects) * len(self.runs)
        if total == 0:
            raise ValueError("EEG needs at least one subject and one run to load.")

        # Subject-major order: all runs of subject 1, then subject 2, ...
        pairs = [(subject, run) for subject in self.subjects for run in self.runs]
        raws = []
        for count, (subject, run) in enumerate(pairs, start=1):
            subject_dir, file_name = self._edf_name(subject, run)
            print(f"Loading file #{count}/{total}: {file_name}")
            path_file = os.path.join(fullpath, subject_dir, file_name)
            raws.append(mne.io.read_raw_edf(path_file, preload=True, verbose='WARNING'))

        raw = mne.io.concatenate_raws(raws)         # Concatenate all EEGs together
        eegbci.standardize(raw)                     # Standardize channel names
        montage = mne.channels.make_standard_montage('standard_1005')
        raw.set_montage(montage)
        self.raw = raw
        print("Done.")

    def filter(self, freq):
        """Filter ``self.raw`` in place with an FIR ('firwin') design.

        Parameters
        ----------
        freq : tuple
            ``(low, high)`` cut-off frequencies forwarded to ``Raw.filter`` as
            ``l_freq`` / ``h_freq``.

        NOTE(review): epochs created before this call are presumably not
        updated (they are built with ``preload=True``); call
        ``create_epochs()`` again after filtering.
        """
        low, high = freq
        print(">>> Apply filter.")
        # verbose=20 is the numeric logging.INFO level.
        self.raw.filter(low, high, fir_design='firwin', verbose=20)

    def get_events(self, event_id=None):
        """Extract events from the annotations of ``self.raw``.

        Parameters
        ----------
        event_id : dict, optional
            Annotation description -> integer code. Defaults to
            ``{'T1': 0, 'T2': 1}``.

        Returns
        -------
        tuple
            ``(events, event_id)`` as returned by ``mne.events_from_annotations``.
        """
        if event_id is None:
            event_id = dict(T1=0, T2=1)  # the events we want to extract
        events, event_id = mne.events_from_annotations(self.raw, event_id=event_id)
        return events, event_id

    def get_epochs(self, events, event_id, tmin=-1, tmax=4):
        """Build preloaded ``mne.Epochs`` over the good EEG channels of ``self.raw``.

        Parameters
        ----------
        events, event_id
            As returned by :meth:`get_events`.
        tmin, tmax : float, optional
            Epoch window relative to each event, in seconds (defaults -1 and 4).
        """
        picks = mne.pick_types(self.raw.info, eeg=True, exclude='bads')
        return mne.Epochs(self.raw, events, event_id, tmin, tmax, proj=True,
                          picks=picks, baseline=None, preload=True)

    def create_epochs(self):
        """Extract events and store the resulting epochs in ``self.epochs``."""
        print(">>> Create Epochs.")
        events, event_id = self.get_events()
        self.epochs = self.get_epochs(events, event_id)
        print("Done.")

    def get_X_y(self):
        """Return ``(X, y)`` from ``self.epochs``, creating the epochs first if needed.

        ``X`` is ``epochs.get_data()`` and ``y`` is the last column of
        ``epochs.events`` (the event code of each epoch). Both are also stored
        on the instance as ``self.X`` / ``self.y``.
        """
        # getattr also covers instances whose __init__ did not set `epochs`.
        if getattr(self, "epochs", None) is None:
            self.create_epochs()
        self.X = self.epochs.get_data()
        self.y = self.epochs.events[:, -1]
        return self.X, self.y

In [3]:
# Local download/cache root: ~/datasets (passed to eegbci.load_data as `path`).
path = os.path.join(os.path.expanduser("~"), 'datasets')
base_url = 'https://physionet.org/files/eegmmidb/1.0.0/'

# Subjects 1-10 and the runs to load for each of them.
# Other run sets (e.g. 6, 10, 14) can be substituted here.
subjects = list(range(1, 11))
runs = [3, 4, 7, 8]

# Downloads any missing files and concatenates all recordings into eeg.raw.
eeg = EEG(path, base_url, subjects, runs)

# Apply the filter: (low, high) cut-off frequencies in Hz.
freq = (8., 30.)
eeg.filter(freq=freq)

# Epochs are built after filtering so they contain the filtered signal.
eeg.create_epochs()

>>> Start download from: https://physionet.org/files/eegmmidb/1.0.0/.
Downloading files to: C:\Users\supha\datasets.
Done.
>>> Extract all subjects from: C:\Users\supha\datasets\MNE-eegbci-data\files\eegmmidb\1.0.0.
Loading file #1/40: S001R03.edf
Loading file #2/40: S001R04.edf
Loading file #3/40: S001R07.edf
Loading file #4/40: S001R08.edf
Loading file #5/40: S002R03.edf
Loading file #6/40: S002R04.edf
Loading file #7/40: S002R07.edf
Loading file #8/40: S002R08.edf
Loading file #9/40: S003R03.edf
Loading file #10/40: S003R04.edf
Loading file #11/40: S003R07.edf
Loading file #12/40: S003R08.edf
Loading file #13/40: S004R03.edf
Loading file #14/40: S004R04.edf
Loading file #15/40: S004R07.edf
Loading file #16/40: S004R08.edf
Loading file #17/40: S005R03.edf
Loading file #18/40: S005R04.edf
Loading file #19/40: S005R07.edf
Loading file #20/40: S005R08.edf
Loading file #21/40: S006R03.edf
Loading file #22/40: S006R04.edf
Loading file #23/40: S006R07.edf
Loading file #24/40: S006R08.edf
L

In [4]:
# Data array X and label vector y from the epochs created above.
X, y = eeg.get_X_y()
# Sanity check: exactly one label per epoch.
assert X.shape[0] == y.shape[0], "expected one label per epoch"

print(X.shape, y.shape)

(600, 64, 801) (600,)
