In [1]:
import logging
import os.path
import time
from collections import OrderedDict
import sys

import numpy as np
import torch.nn.functional as F
from torch import optim

from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
    RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import set_random_seeds, np_to_var
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (bandpass_cnt,
                                             exponential_running_standardize)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

# Notebook-level logger; the main experiment cell uses it to report the final epochs.
log = logging.getLogger(__name__)

<p>Here's the Deep ConvNet architecture as described in the paper</p>
<img src="DeepConvNet.png" style="width: 700px; float:left;">

<p>And here's the detailed view of the Deep ConvNet from the paper</p>
<img src="DeepConvNetDetail.png" style="width: 700px; float:left;">

In [2]:
def _preprocess_cnt(cnt, low_cut_hz, high_cut_hz, factor_new, init_block_size):
    """Apply the BCIC IV 2a preprocessing pipeline to one continuous recording.

    Steps: drop the stimulus/EOG channels, rescale to microvolts, band-pass
    filter, then exponential running standardization.

    Parameters
    ----------
    cnt : mne.io.Raw
        Continuous recording as returned by ``BCICompetition4Set2A.load()``.
    low_cut_hz, high_cut_hz : float
        Band-pass edges passed to ``bandpass_cnt``.
    factor_new : float
        Update factor for ``exponential_running_standardize``.
    init_block_size : int
        Number of initial samples used to seed the running standardization.

    Returns
    -------
    mne.io.Raw
        The preprocessed recording with the 22 remaining EEG channels.
    """
    cnt = cnt.drop_channels(['STI 014', 'EOG-left',
                             'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22
    sfreq = cnt.info['sfreq']
    # MNE stores EEG in volts; multiplying by 1e6 converts to microvolts
    # (not millivolts), which keeps the following operations numerically stable.
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    # Filter along axis 1, i.e. over time for a (channels, time) array.
    cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, sfreq,
                               filt_order=3,
                               axis=1), cnt)
    # The transposes hand the standardizer a (time, channels) array and
    # restore (channels, time) afterwards.
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
                                                  init_block_size=init_block_size,
                                                  eps=1e-4).T,
        cnt)
    return cnt


def _create_model(model, n_chans, n_classes, input_time_length):
    """Build the braindecode network selected by ``model``.

    Raises
    ------
    ValueError
        If ``model`` is neither ``'shallow'`` nor ``'deep'``.
    """
    if model == 'shallow':
        net_cls = ShallowFBCSPNet
    elif model == 'deep':
        net_cls = Deep4Net
    else:
        # Previously an unknown name fell through and crashed later with an
        # opaque AttributeError on the str; fail early instead.
        raise ValueError(
            "model must be 'shallow' or 'deep', got {!r}".format(model))
    return net_cls(n_chans, n_classes, input_time_length=input_time_length,
                   final_conv_length='auto').create_network()


def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    """Train and evaluate a ConvNet on one subject of BCI Competition IV 2a.

    Parameters
    ----------
    data_folder : str
        Folder containing ``A0xT.gdf``/``A0xE.gdf`` and the matching ``.mat``
        label files.
    subject_id : int
        Subject number, used to build the file names (``A{:02d}``).
    low_cut_hz : float
        Lower edge of the band-pass filter (the upper edge is fixed at 38 Hz).
    model : str
        ``'shallow'`` (ShallowFBCSPNet) or ``'deep'`` (Deep4Net).
    cuda : bool
        Whether to move the network to the GPU.

    Returns
    -------
    Experiment
        The finished experiment; per-epoch metrics are in ``exp.epochs_df``.

    Raises
    ------
    ValueError
        If ``model`` is not ``'shallow'`` or ``'deep'``.
    """
    ival = [-500, 4000]
    max_epochs = 100
    max_increase_epochs = 16
    batch_size = 60
    high_cut_hz = 38
    factor_new = 1e-3
    init_block_size = 1000
    valid_set_fraction = 0.2

    # Loading data
    train_filename = 'A{:02d}T.gdf'.format(subject_id)
    test_filename = 'A{:02d}E.gdf'.format(subject_id)
    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)
    train_label_filepath = train_filepath.replace('.gdf', '.mat')
    test_label_filepath = test_filepath.replace('.gdf', '.mat')

    train_loader = BCICompetition4Set2A(
        train_filepath, labels_filename=train_label_filepath)
    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath)
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Preprocessing: identical pipeline for both sessions
    train_cnt = _preprocess_cnt(train_cnt, low_cut_hz, high_cut_hz,
                                factor_new, init_block_size)
    test_cnt = _preprocess_cnt(test_cnt, low_cut_hz, high_cut_hz,
                               factor_new, init_block_size)

    marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
                              ('Foot', [3]), ('Tongue', [4])])

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(
        train_set, first_set_fraction=1 - valid_set_fraction)

    set_random_seeds(seed=20190706, cuda=cuda)

    # Run experiment
    n_classes = len(marker_def)
    n_chans = int(train_set.X.shape[1])
    input_time_length = train_set.X.shape[2]
    net = _create_model(model, n_chans, n_classes, input_time_length)
    if cuda:
        net.cuda()

    optimizer = optim.Adam(net.parameters())

    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    exp = Experiment(net, train_set, valid_set, test_set, iterator=iterator,
                     loss_function=F.nll_loss, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, cuda=cuda)
    exp.run()
    return exp

In [3]:
if __name__ == '__main__':
    # Route log records (DEBUG and above) to the notebook's stdout.
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    # Experiment configuration. These names stay module-level on purpose:
    # the exploration cells further down reuse `subject_id`.
    data_folder = './data/BCICIV_2a_gdf/'  # needs the .gdf files and the competition .mat label files
    subject_id = 1      # valid subjects: 1-9
    low_cut_hz = 4      # either 0 or 4
    model = 'deep'      # either 'shallow' or 'deep'
    cuda = False

    exp = run_exp(data_folder=data_folder, subject_id=subject_id,
                  low_cut_hz=low_cut_hz, model=model, cuda=cuda)

    last_epochs = exp.epochs_df.tail(10)
    log.info("Last 10 epochs")
    log.info("\n" + str(last_epochs))

Extracting EDF parameters from ./data/BCICIV_2a_gdf/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Extracting EDF parameters from ./data/BCICIV_2a_gdf/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


2019-02-10 11:54:09,309 INFO : Trial per class:
Counter({'Tongue': 72, 'Foot': 72, 'Right Hand': 72, 'Left Hand': 72})
2019-02-10 11:54:09,637 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})
2019-02-10 11:54:10,047 INFO : Run until first stop...
2019-02-10 11:54:14,021 INFO : Epoch 0
2019-02-10 11:54:14,022 INFO : train_loss                1.44188
2019-02-10 11:54:14,023 INFO : valid_loss                1.40538
2019-02-10 11:54:14,023 INFO : test_loss                 1.42932
2019-02-10 11:54:14,024 INFO : train_misclass            0.74783
2019-02-10 11:54:14,024 INFO : valid_misclass            0.77586
2019-02-10 11:54:14,025 INFO : test_misclass             0.75347
2019-02-10 11:54:14,025 INFO : runtime                   0.00000
2019-02-10 11:54:14,026 INFO : 
2019-02-10 11:54:14,029 INFO : New best valid_misclass: 0.775862
2019-02-10 11:54:14,030 INFO : 
2019-02-10 11:54:18,168 INFO : Time only for training updates: 4.14s
2019-02-10 11:5

2019-02-10 11:55:43,747 INFO : runtime                   8.00628
2019-02-10 11:55:43,748 INFO : 
2019-02-10 11:55:43,755 INFO : New best valid_misclass: 0.775862
2019-02-10 11:55:43,756 INFO : 
2019-02-10 11:55:47,878 INFO : Time only for training updates: 4.12s
2019-02-10 11:55:51,722 INFO : Epoch 12
2019-02-10 11:55:51,722 INFO : train_loss                2.43393
2019-02-10 11:55:51,723 INFO : valid_loss                3.06059
2019-02-10 11:55:51,724 INFO : test_loss                 2.58014
2019-02-10 11:55:51,724 INFO : train_misclass            0.70000
2019-02-10 11:55:51,725 INFO : valid_misclass            0.77586
2019-02-10 11:55:51,725 INFO : test_misclass             0.71181
2019-02-10 11:55:51,726 INFO : runtime                   7.92919
2019-02-10 11:55:51,727 INFO : 
2019-02-10 11:55:51,735 INFO : New best valid_misclass: 0.775862
2019-02-10 11:55:51,736 INFO : 
2019-02-10 11:55:55,951 INFO : Time only for training updates: 4.21s
2019-02-10 11:55:59,763 INFO : Epoch 13
2019

2019-02-10 11:57:11,845 INFO : Time only for training updates: 3.39s
2019-02-10 11:57:14,760 INFO : Epoch 24
2019-02-10 11:57:14,761 INFO : train_loss                1.89408
2019-02-10 11:57:14,761 INFO : valid_loss                2.62887
2019-02-10 11:57:14,761 INFO : test_loss                 2.15680
2019-02-10 11:57:14,762 INFO : train_misclass            0.63913
2019-02-10 11:57:14,763 INFO : valid_misclass            0.74138
2019-02-10 11:57:14,763 INFO : test_misclass             0.65625
2019-02-10 11:57:14,764 INFO : runtime                   6.28800
2019-02-10 11:57:14,764 INFO : 
2019-02-10 11:57:18,132 INFO : Time only for training updates: 3.37s
2019-02-10 11:57:21,105 INFO : Epoch 25
2019-02-10 11:57:21,106 INFO : train_loss                1.86945
2019-02-10 11:57:21,107 INFO : valid_loss                2.62878
2019-02-10 11:57:21,107 INFO : test_loss                 2.13694
2019-02-10 11:57:21,108 INFO : train_misclass            0.63043
2019-02-10 11:57:21,108 INFO : vali

2019-02-10 11:58:30,359 INFO : 
2019-02-10 11:58:33,751 INFO : Time only for training updates: 3.39s
2019-02-10 11:58:36,647 INFO : Epoch 37
2019-02-10 11:58:36,648 INFO : train_loss                1.56343
2019-02-10 11:58:36,648 INFO : valid_loss                2.66270
2019-02-10 11:58:36,649 INFO : test_loss                 2.04908
2019-02-10 11:58:36,649 INFO : train_misclass            0.53043
2019-02-10 11:58:36,650 INFO : valid_misclass            0.72414
2019-02-10 11:58:36,650 INFO : test_misclass             0.57986
2019-02-10 11:58:36,651 INFO : runtime                   6.28776
2019-02-10 11:58:36,651 INFO : 
2019-02-10 11:58:40,026 INFO : Time only for training updates: 3.37s
2019-02-10 11:58:42,964 INFO : Epoch 38
2019-02-10 11:58:42,965 INFO : train_loss                1.47872
2019-02-10 11:58:42,965 INFO : valid_loss                2.61857
2019-02-10 11:58:42,966 INFO : test_loss                 1.98409
2019-02-10 11:58:42,966 INFO : train_misclass            0.51739
201

2019-02-10 12:00:01,248 INFO : valid_loss                2.28112
2019-02-10 12:00:01,248 INFO : test_loss                 1.69636
2019-02-10 12:00:01,249 INFO : train_misclass            0.39130
2019-02-10 12:00:01,249 INFO : valid_misclass            0.62069
2019-02-10 12:00:01,249 INFO : test_misclass             0.51389
2019-02-10 12:00:01,250 INFO : runtime                   6.99704
2019-02-10 12:00:01,250 INFO : 
2019-02-10 12:00:04,778 INFO : Time only for training updates: 3.53s
2019-02-10 12:00:07,904 INFO : Epoch 51
2019-02-10 12:00:07,905 INFO : train_loss                1.04676
2019-02-10 12:00:07,905 INFO : valid_loss                2.20684
2019-02-10 12:00:07,906 INFO : test_loss                 1.63942
2019-02-10 12:00:07,907 INFO : train_misclass            0.35652
2019-02-10 12:00:07,907 INFO : valid_misclass            0.60345
2019-02-10 12:00:07,908 INFO : test_misclass             0.53125
2019-02-10 12:00:07,908 INFO : runtime                   6.69318
2019-02-10 12:

2019-02-10 12:01:35,051 INFO : valid_loss                2.30893
2019-02-10 12:01:35,052 INFO : test_loss                 1.66307
2019-02-10 12:01:35,052 INFO : train_misclass            0.35217
2019-02-10 12:01:35,053 INFO : valid_misclass            0.62069
2019-02-10 12:01:35,053 INFO : test_misclass             0.51736
2019-02-10 12:01:35,054 INFO : runtime                   8.61481
2019-02-10 12:01:35,055 INFO : 
2019-02-10 12:01:39,464 INFO : Time only for training updates: 4.41s
2019-02-10 12:01:43,649 INFO : Epoch 64
2019-02-10 12:01:43,649 INFO : train_loss                0.83794
2019-02-10 12:01:43,650 INFO : valid_loss                2.14882
2019-02-10 12:01:43,651 INFO : test_loss                 1.56477
2019-02-10 12:01:43,651 INFO : train_misclass            0.31739
2019-02-10 12:01:43,652 INFO : valid_misclass            0.62069
2019-02-10 12:01:43,652 INFO : test_misclass             0.51042
2019-02-10 12:01:43,653 INFO : runtime                   8.57420
2019-02-10 12:

2019-02-10 12:03:27,195 INFO : test_misclass             0.45833
2019-02-10 12:03:27,196 INFO : runtime                   8.65366
2019-02-10 12:03:27,197 INFO : 
2019-02-10 12:03:31,625 INFO : Time only for training updates: 4.43s
2019-02-10 12:03:35,772 INFO : Epoch 77
2019-02-10 12:03:35,773 INFO : train_loss                0.51855
2019-02-10 12:03:35,773 INFO : valid_loss                1.87214
2019-02-10 12:03:35,774 INFO : test_loss                 1.44317
2019-02-10 12:03:35,775 INFO : train_misclass            0.21739
2019-02-10 12:03:35,775 INFO : valid_misclass            0.58621
2019-02-10 12:03:35,776 INFO : test_misclass             0.47569
2019-02-10 12:03:35,776 INFO : runtime                   8.52422
2019-02-10 12:03:35,777 INFO : 
2019-02-10 12:03:40,296 INFO : Time only for training updates: 4.52s
2019-02-10 12:03:44,443 INFO : Epoch 78
2019-02-10 12:03:44,444 INFO : train_loss                0.50474
2019-02-10 12:03:44,444 INFO : valid_loss                1.90540
201

2019-02-10 12:05:19,776 INFO : New best valid_misclass: 0.517241
2019-02-10 12:05:19,776 INFO : 
2019-02-10 12:05:24,197 INFO : Time only for training updates: 4.42s
2019-02-10 12:05:28,480 INFO : Epoch 90
2019-02-10 12:05:28,481 INFO : train_loss                0.26356
2019-02-10 12:05:28,481 INFO : valid_loss                1.59395
2019-02-10 12:05:28,482 INFO : test_loss                 1.19538
2019-02-10 12:05:28,482 INFO : train_misclass            0.12609
2019-02-10 12:05:28,483 INFO : valid_misclass            0.48276
2019-02-10 12:05:28,483 INFO : test_misclass             0.43056
2019-02-10 12:05:28,484 INFO : runtime                   8.59076
2019-02-10 12:05:28,485 INFO : 
2019-02-10 12:05:28,494 INFO : New best valid_misclass: 0.482759
2019-02-10 12:05:28,494 INFO : 
2019-02-10 12:05:32,948 INFO : Time only for training updates: 4.45s
2019-02-10 12:05:37,176 INFO : Epoch 91
2019-02-10 12:05:37,177 INFO : train_loss                0.31811
2019-02-10 12:05:37,178 INFO : valid

2019-02-10 12:07:20,339 INFO : Epoch 93
2019-02-10 12:07:20,339 INFO : train_loss                0.79550
2019-02-10 12:07:20,340 INFO : valid_loss                1.74296
2019-02-10 12:07:20,341 INFO : test_loss                 1.48954
2019-02-10 12:07:20,341 INFO : train_misclass            0.31250
2019-02-10 12:07:20,342 INFO : valid_misclass            0.51724
2019-02-10 12:07:20,342 INFO : test_misclass             0.48958
2019-02-10 12:07:20,343 INFO : runtime                   10.24871
2019-02-10 12:07:20,343 INFO : 
2019-02-10 12:07:25,885 INFO : Time only for training updates: 5.54s
2019-02-10 12:07:30,487 INFO : Epoch 94
2019-02-10 12:07:30,488 INFO : train_loss                0.72007
2019-02-10 12:07:30,489 INFO : valid_loss                1.49777
2019-02-10 12:07:30,489 INFO : test_loss                 1.40184
2019-02-10 12:07:30,490 INFO : train_misclass            0.28472
2019-02-10 12:07:30,490 INFO : valid_misclass            0.46552
2019-02-10 12:07:30,491 INFO : test_mi

2019-02-10 12:09:34,078 INFO : Epoch 106
2019-02-10 12:09:34,079 INFO : train_loss                0.34880
2019-02-10 12:09:34,079 INFO : valid_loss                0.64383
2019-02-10 12:09:34,080 INFO : test_loss                 1.01905
2019-02-10 12:09:34,081 INFO : train_misclass            0.15972
2019-02-10 12:09:34,081 INFO : valid_misclass            0.27586
2019-02-10 12:09:34,082 INFO : test_misclass             0.38542
2019-02-10 12:09:34,082 INFO : runtime                   10.29368
2019-02-10 12:09:34,083 INFO : 
2019-02-10 12:09:34,092 INFO : New best valid_misclass: 0.275862
2019-02-10 12:09:34,093 INFO : 
2019-02-10 12:09:39,741 INFO : Time only for training updates: 5.65s
2019-02-10 12:09:44,355 INFO : Epoch 107
2019-02-10 12:09:44,356 INFO : train_loss                0.40467
2019-02-10 12:09:44,356 INFO : valid_loss                0.73012
2019-02-10 12:09:44,357 INFO : test_loss                 1.08364
2019-02-10 12:09:44,357 INFO : train_misclass            0.16667
2019

2019-02-10 12:11:47,735 INFO : Epoch 119
2019-02-10 12:11:47,736 INFO : train_loss                0.25870
2019-02-10 12:11:47,737 INFO : valid_loss                0.40620
2019-02-10 12:11:47,737 INFO : test_loss                 0.96800
2019-02-10 12:11:47,738 INFO : train_misclass            0.11458
2019-02-10 12:11:47,739 INFO : valid_misclass            0.20690
2019-02-10 12:11:47,739 INFO : test_misclass             0.39931
2019-02-10 12:11:47,740 INFO : runtime                   10.23401
2019-02-10 12:11:47,740 INFO : 
2019-02-10 12:11:47,750 INFO : New best valid_misclass: 0.206897
2019-02-10 12:11:47,750 INFO : 
2019-02-10 12:11:53,322 INFO : Time only for training updates: 5.57s
2019-02-10 12:11:58,025 INFO : Epoch 120
2019-02-10 12:11:58,025 INFO : train_loss                0.25331
2019-02-10 12:11:58,026 INFO : valid_loss                0.39786
2019-02-10 12:11:58,027 INFO : test_loss                 0.97656
2019-02-10 12:11:58,027 INFO : train_misclass            0.12847
2019

In [4]:
# Rebuild the subject's file paths the same way run_exp does, so the raw
# training recording can be inspected by hand in the next cells.
# NOTE(review): relies on `subject_id` defined in the main experiment cell.
data_folder = './data/BCICIV_2a_gdf/'
train_filename, test_filename = (
    pattern.format(subject_id) for pattern in ('A{:02d}T.gdf', 'A{:02d}E.gdf'))
train_filepath, test_filepath = (
    os.path.join(data_folder, fname) for fname in (train_filename, test_filename))
train_label_filepath, test_label_filepath = (
    gdf_path.replace('.gdf', '.mat') for gdf_path in (train_filepath, test_filepath))
train_loader = BCICompetition4Set2A(train_filepath,
                                    labels_filename=train_label_filepath)

# Notes on what the loader's load() step involves, per the author's earlier notes:
# the continuous data is extracted, events and an artifact-trial mask are derived
# from it, and both are stored in cnt.info['events'] / cnt.info['artifact_trial_mask'].



In [5]:
import mne

# Read the training GDF directly with MNE, bypassing BCICompetition4Set2A,
# to look at the unprocessed recording.
raw_edf = mne.io.read_raw_edf(train_filepath, stim_channel='auto')
# load_data() returns the Raw instance itself, so the data can be fetched in one chain.
data = raw_edf.load_data().get_data()

Extracting EDF parameters from ./data/BCICIV_2a_gdf/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(train_filepath, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(train_filepath, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(train_filepath, stim_channel='auto')


In [6]:
# Show the MNE measurement info (channel list, sampling rate, filter settings).
raw_edf.info

<Info | 16 non-empty fields
    bads : list | 0 items
    ch_names : list | EEG-Fz, EEG-0, EEG-1, EEG-2, EEG-3, EEG-4, EEG-5, ...
    chs : list | 26 items (EEG: 25, STIM: 1)
    comps : list | 0 items
    custom_ref_applied : bool | False
    dev_head_t : Transform | 3 items
    events : list | 0 items
    highpass : float | 0.5 Hz
    hpi_meas : list | 0 items
    hpi_results : list | 0 items
    lowpass : float | 100.0 Hz
    meas_date : tuple | 2005-01-17 12:00:00 GMT
    nchan : int | 26
    proc_history : list | 0 items
    projs : list | 0 items
    sfreq : float | 250.0 Hz
    acq_pars : NoneType
    acq_stim : NoneType
    ctf_head_t : NoneType
    description : NoneType
    dev_ctf_t : NoneType
    dig : NoneType
    experimenter : NoneType
    file_id : NoneType
    gantry_angle : NoneType
    hpi_subsystem : NoneType
    kit_system_id : NoneType
    line_freq : NoneType
    meas_id : NoneType
    proj_id : NoneType
    proj_name : NoneType
    subject_info : NoneType
    xp

In [7]:
# (n_channels, n_samples): all 26 channels, including STIM and EOG, before any dropping.
data.shape

(26, 672528)