In [1]:
import logging
import os.path
import sys
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim

from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.signalproc import (bandpass_cnt,
                                             exponential_running_standardize)
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
    RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import set_random_seeds, np_to_var

import warnings
# Blanket filter: no category or module is given, so every warning raised
# after this point is suppressed for the rest of the session.
warnings.filterwarnings("ignore")

# Logger used by the cells below (run_exp and the final training cell).
log = logging.getLogger(__name__)
# Disabled alternative that would send log records to a file instead:
# logging.basicConfig(filename="./outputlog.csv")

In [2]:
# --- Data location ----------------------------------------------------------
# Folder holding the per-subject recordings (A01T.gdf, A01E.gdf, ...) and the
# matching .mat label files. Set the BCI_DATA_FOLDER environment variable to
# run the notebook on another machine; the fallback is the original path.
data_folder = os.environ.get(
    'BCI_DATA_FOLDER',
    '/Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/')

# Not read by any of the visible cells: the loader functions take their own
# subject_id argument and the held-out subject is chosen through `out`.
subject_id = 1

# --- Preprocessing (read by data_preprocessing) -----------------------------
low_cut_hz = 4           # lower edge of the band-pass filter
high_cut_hz = 38         # upper edge of the band-pass filter
factor_new = 1e-3        # forwarded to exponential_running_standardize
init_block_size = 1000   # forwarded to exponential_running_standardize

# --- Trial cutting (read by the load_* functions) ---------------------------
# Window handed to create_signal_target_from_raw_mne; presumably milliseconds
# relative to each trial marker -- TODO confirm against braindecode.
ival = [-500, 4000]
# Class name -> list of marker codes, in label order.
marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
                          ('Foot', [3]), ('Tongue', [4])])

# --- Training (read by run_exp) ---------------------------------------------
max_epochs = 1600            # argument of the MaxEpochs stop criterion
max_increase_epochs = 160    # argument of NoDecrease('valid_misclass', ...)
# NOTE: run_exp assigns its own local batch_size = 60, so changing this value
# alone does not change the batch size used for training.
batch_size = 60
valid_set_fraction = 0.2     # run_exp trains on the first 1 - this fraction

In [3]:
def load_subject_data(subject_id):
    """Load both recordings of one subject and merge them into a single set.

    The file loading, preprocessing and trial cutting are delegated to
    ``load_single_data``, which performs exactly the steps this function
    previously duplicated line for line. The trials of the train file come
    first, followed by the trials of the test file.

    Parameters
    ----------
    subject_id : int
        Subject number; it is formatted with two digits into the file names
        ``A{id}T.gdf`` and ``A{id}E.gdf`` inside ``data_folder``.

    Returns
    -------
    SignalAndTarget
        ``X`` and ``y`` of both recordings concatenated along axis 0.
    """
    # The name is looked up when this function is called, so it does not
    # matter that load_single_data is defined in a later cell, as long as
    # that cell has been executed before the first call.
    train_set, test_set = load_single_data(subject_id)

    return SignalAndTarget(np.concatenate((train_set.X, test_set.X), axis=0),
                           np.concatenate((train_set.y, test_set.y), axis=0))

In [4]:
def data_preprocessing(train_cnt):
    """Run the fixed cleaning pipeline on one continuous recording.

    Despite the parameter name, the loader functions of this notebook call
    this for the train recording as well as for the test recording.

    Steps, in order:
      1. drop the channels 'STI 014', 'EOG-left', 'EOG-central' and
         'EOG-right', then check that exactly 22 channels remain;
      2. multiply all samples by 1e6;
      3. band-pass between ``low_cut_hz`` and ``high_cut_hz`` (filter order
         3) using the recording's own ``info['sfreq']``;
      4. apply ``exponential_running_standardize`` with ``factor_new``,
         ``init_block_size`` and ``eps=1e-4``.

    Returns whatever the last ``mne_apply`` call produces.
    """
    reduced = train_cnt.drop_channels(
        ['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
    assert len(reduced.ch_names) == 22

    def _scale(samples):
        # Rescaling for numerical stability of the following steps. The
        # original note called this a millivolt conversion; a factor of 1e6
        # would turn volts into microvolts -- TODO confirm the input unit.
        return samples * 1e6

    scaled = mne_apply(_scale, reduced)
    sampling_rate = scaled.info['sfreq']

    def _band_pass(samples):
        return bandpass_cnt(samples, low_cut_hz, high_cut_hz, sampling_rate,
                            filt_order=3,
                            axis=1)

    filtered = mne_apply(_band_pass, scaled)

    def _standardize(samples):
        # The array is transposed on the way in and back on the way out;
        # presumably the helper expects time along the first axis -- TODO
        # confirm.
        standardized = exponential_running_standardize(
            samples.T, factor_new=factor_new,
            init_block_size=init_block_size,
            eps=1e-4)
        return standardized.T

    return mne_apply(_standardize, filtered)

In [5]:
def load_train_data(subjects):
    """Load several subjects and stack all their trials into one set.

    Parameters
    ----------
    subjects : iterable of int
        Subject numbers, loaded in the given order through
        ``load_subject_data``. The argument itself is left untouched (the
        previous implementation removed its first element with ``pop(0)``,
        silently shrinking the caller's list).

    Returns
    -------
    SignalAndTarget
        ``X`` and ``y`` of all subjects concatenated along axis 0, in the
        order of ``subjects``.

    Raises
    ------
    IndexError
        If ``subjects`` is empty. The exception type matches what the former
        ``pop(0)`` on an empty list raised.
    """
    # Work on a private copy so the caller's list is never modified.
    subject_ids = list(subjects)
    if not subject_ids:
        raise IndexError('load_train_data() needs at least one subject id')

    per_subject = [load_subject_data(s) for s in subject_ids]
    # One concatenation per array instead of re-copying the growing result
    # once for every additional subject.
    x = np.concatenate([data.X for data in per_subject], axis=0)
    y = np.concatenate([data.y for data in per_subject], axis=0)
    return SignalAndTarget(x, y)

In [12]:
def load_single_data(subject_id):
    """Load one subject's two recordings and return them as separate sets.

    The train file ``A{id}T.gdf`` and the test file ``A{id}E.gdf`` are looked
    up in ``data_folder``; the labels are read from the files with the same
    names but a ``.mat`` extension. Both recordings are loaded, then both are
    passed through ``data_preprocessing``, then both are cut into trials
    using ``marker_def`` and ``ival``.

    Parameters
    ----------
    subject_id : int
        Subject number, formatted with two digits into the file names.

    Returns
    -------
    tuple
        ``(train_set, test_set)`` as returned by
        ``create_signal_target_from_raw_mne``.
    """
    # Index 0 is always the train file, index 1 the test file.
    gdf_paths = [os.path.join(data_folder, pattern.format(subject_id))
                 for pattern in ('A{:02d}T.gdf', 'A{:02d}E.gdf')]

    # Each stage handles train before test and finishes before the next
    # stage starts, so files are read and messages emitted in a fixed order.
    loaders = [BCICompetition4Set2A(path,
                                    labels_filename=path.replace('.gdf', '.mat'))
               for path in gdf_paths]
    recordings = [loader.load() for loader in loaders]
    cleaned = [data_preprocessing(recording) for recording in recordings]
    train_set, test_set = [
        create_signal_target_from_raw_mne(cnt, marker_def, ival)
        for cnt in cleaned]

    return train_set, test_set

In [7]:
def run_exp(train_data, test_data, low_cut_hz, model, cuda):
    """Build a network, train it on ``train_data`` and monitor ``test_data``.

    Parameters
    ----------
    train_data : SignalAndTarget
        Training trials. They are divided with ``split_into_two_sets``; the
        first ``1 - valid_set_fraction`` become the training set and the
        rest the validation set.
    test_data : SignalAndTarget
        Trials handed to the experiment as its test set.
    low_cut_hz : number
        Accepted for compatibility with existing calls; never read here.
    model : str or network object
        ``'shallow'`` builds a ShallowFBCSPNet, ``'deep'`` a Deep4Net. Any
        non-string object is used as the network unchanged, as before.
    cuda : bool
        Whether the network is moved to the GPU; also forwarded to the seed
        helper and to the experiment.

    Returns
    -------
    Experiment
        The experiment after ``run()`` has finished; the notebook reads
        ``epochs_df`` from it.

    Raises
    ------
    ValueError
        If ``model`` is a string other than ``'shallow'`` or ``'deep'``.
        Previously such a value stayed a plain string and only failed later
        with an AttributeError when used as a network.
    """
    # Reject unknown names before any data is split or any seed is set.
    if isinstance(model, str) and model not in ('shallow', 'deep'):
        raise ValueError(
            "Unknown model {!r}; expected 'shallow' or 'deep'".format(model))

    # NOTE: this local value, not the notebook-level batch_size, is the one
    # handed to the batch iterator below.
    batch_size = 60
    train_set, valid_set = split_into_two_sets(
        train_data, first_set_fraction=1-valid_set_fraction)

    set_random_seeds(seed=20181221, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    input_time_length = train_set.X.shape[2]
    if model == 'shallow':
        model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    elif model == 'deep':
        model = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    # Adam with its default hyper-parameters.
    optimizer = optim.Adam(model.parameters())

    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    # Stop on whichever criterion fires first: the epoch limit or the
    # NoDecrease rule on valid_misclass.
    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    exp = Experiment(model, train_set, valid_set, test_data, iterator=iterator,
                     loss_function=F.nll_loss, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, cuda=cuda)
    exp.run()
    return exp

In [42]:
# Leave-one-subject-out split: subject `out` is held back as test data and
# the remaining subjects out of 1..9 are merged into the training data.
out = 1
subjects_list = [subject for subject in range(1, 10) if subject != out]
train_data = load_train_data(subjects_list)
test_data = load_subject_data(out)

Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...
Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A02E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...
2018-12-22 19:41:37,214 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})
2018-12-22 19:41:37,470 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})
Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  26

In [23]:
# Sanity check of the merged training array. The second entry should be 22,
# the channel count asserted in data_preprocessing; the axes are presumably
# (trials, channels, samples) -- TODO confirm.
train_data.X.shape

(4608, 22, 1125)

In [None]:
# Send log records (DEBUG and above) to the cell output.
# NOTE(review): this is configured after the data-loading cell; on a fresh
# top-to-bottom run, records logged before this point will not use this
# format -- consider moving the call next to the imports.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
# Train on the GPU only when one is actually available, so the notebook also
# runs on CPU-only machines instead of failing when the model is moved.
cuda = torch.cuda.is_available()
exp = run_exp(train_data, test_data, low_cut_hz, "deep", cuda)
log.info("Last 10 epochs")
log.info("\n" + str(exp.epochs_df.iloc[-10:]))

2018-12-22 20:30:54,692 INFO : Model: 
Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 25, kernel_size=(10, 1), stride=(1, 1))
  (conv_spat): Conv2d(25, 25, kernel_size=(1, 22), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=elu)
  (pool): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (pool_nonlin): Expression(expression=identity)
  (drop_2): Dropout(p=0.5)
  (conv_2): Conv2d(25, 50, kernel_size=(10, 1), stride=(1, 1), bias=False)
  (bnorm_2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlin_2): Expression(expression=elu)
  (pool_2): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (pool_nonlin_2): Expression(expression=identity)
  (drop_3): Dropout(p=0.5)
  (conv_3): Conv2d(50, 100, kernel_size=(10, 1), stri

2018-12-22 20:32:15,786 INFO : test_loss                 1.03192
2018-12-22 20:32:15,786 INFO : train_misclass            0.59848
2018-12-22 20:32:15,787 INFO : valid_misclass            0.50108
2018-12-22 20:32:15,787 INFO : test_misclass             0.47917
2018-12-22 20:32:15,788 INFO : runtime                   8.56870
2018-12-22 20:32:15,789 INFO : 
2018-12-22 20:32:15,794 INFO : New best valid_misclass: 0.501085
2018-12-22 20:32:15,794 INFO : 
2018-12-22 20:32:20,774 INFO : Time only for training updates: 4.98s
2018-12-22 20:32:24,355 INFO : Epoch 10
2018-12-22 20:32:24,356 INFO : train_loss                1.25207
2018-12-22 20:32:24,356 INFO : valid_loss                1.05866
2018-12-22 20:32:24,356 INFO : test_loss                 0.94292
2018-12-22 20:32:24,357 INFO : train_misclass            0.57922
2018-12-22 20:32:24,358 INFO : valid_misclass            0.44360
2018-12-22 20:32:24,359 INFO : test_misclass             0.45486
2018-12-22 20:32:24,360 INFO : runtime         

2018-12-22 20:34:07,251 INFO : train_misclass            0.38877
2018-12-22 20:34:07,251 INFO : valid_misclass            0.43492
2018-12-22 20:34:07,252 INFO : test_misclass             0.34028
2018-12-22 20:34:07,252 INFO : runtime                   8.57064
2018-12-22 20:34:07,252 INFO : 
2018-12-22 20:34:12,234 INFO : Time only for training updates: 4.98s
2018-12-22 20:34:15,815 INFO : Epoch 23
2018-12-22 20:34:15,816 INFO : train_loss                0.97936
2018-12-22 20:34:15,817 INFO : valid_loss                1.11822
2018-12-22 20:34:15,818 INFO : test_loss                 0.89163
2018-12-22 20:34:15,818 INFO : train_misclass            0.42404
2018-12-22 20:34:15,818 INFO : valid_misclass            0.46312
2018-12-22 20:34:15,819 INFO : test_misclass             0.37500
2018-12-22 20:34:15,821 INFO : runtime                   8.56802
2018-12-22 20:34:15,821 INFO : 
2018-12-22 20:34:20,804 INFO : Time only for training updates: 4.98s
2018-12-22 20:34:24,386 INFO : Epoch 24
201

2018-12-22 20:36:07,231 INFO : valid_loss                0.97684
2018-12-22 20:36:07,232 INFO : test_loss                 0.73883
2018-12-22 20:36:07,232 INFO : train_misclass            0.27564
2018-12-22 20:36:07,232 INFO : valid_misclass            0.43384
2018-12-22 20:36:07,234 INFO : test_misclass             0.32292
2018-12-22 20:36:07,234 INFO : runtime                   8.57172
2018-12-22 20:36:07,234 INFO : 
2018-12-22 20:36:12,216 INFO : Time only for training updates: 4.98s
2018-12-22 20:36:15,796 INFO : Epoch 37
2018-12-22 20:36:15,797 INFO : train_loss                0.70559
2018-12-22 20:36:15,798 INFO : valid_loss                1.00907
2018-12-22 20:36:15,799 INFO : test_loss                 0.86616
2018-12-22 20:36:15,799 INFO : train_misclass            0.24444
2018-12-22 20:36:15,801 INFO : valid_misclass            0.42733
2018-12-22 20:36:15,801 INFO : test_misclass             0.35938
2018-12-22 20:36:15,801 INFO : runtime                   8.57316
2018-12-22 20:

2018-12-22 20:37:59,696 INFO : 
2018-12-22 20:38:04,859 INFO : Time only for training updates: 5.16s
2018-12-22 20:38:08,555 INFO : Epoch 50
2018-12-22 20:38:08,556 INFO : train_loss                0.51399
2018-12-22 20:38:08,556 INFO : valid_loss                1.01523
2018-12-22 20:38:08,557 INFO : test_loss                 0.74988
2018-12-22 20:38:08,558 INFO : train_misclass            0.13375
2018-12-22 20:38:08,558 INFO : valid_misclass            0.41649
2018-12-22 20:38:08,559 INFO : test_misclass             0.30208
2018-12-22 20:38:08,559 INFO : runtime                   8.86166
2018-12-22 20:38:08,560 INFO : 
2018-12-22 20:38:13,632 INFO : Time only for training updates: 5.07s
2018-12-22 20:38:17,280 INFO : Epoch 51
2018-12-22 20:38:17,281 INFO : train_loss                0.49959
2018-12-22 20:38:17,281 INFO : valid_loss                1.06035
2018-12-22 20:38:17,282 INFO : test_loss                 0.71286
2018-12-22 20:38:17,282 INFO : train_misclass            0.12534
201

2018-12-22 20:40:00,483 INFO : valid_misclass            0.43275
2018-12-22 20:40:00,484 INFO : test_misclass             0.29167
2018-12-22 20:40:00,484 INFO : runtime                   8.57019
2018-12-22 20:40:00,485 INFO : 
2018-12-22 20:40:05,470 INFO : Time only for training updates: 4.98s
2018-12-22 20:40:09,052 INFO : Epoch 64
2018-12-22 20:40:09,052 INFO : train_loss                0.35960
2018-12-22 20:40:09,053 INFO : valid_loss                1.06241
2018-12-22 20:40:09,054 INFO : test_loss                 0.71514
2018-12-22 20:40:09,055 INFO : train_misclass            0.06077
2018-12-22 20:40:09,055 INFO : valid_misclass            0.41106
2018-12-22 20:40:09,056 INFO : test_misclass             0.29514
2018-12-22 20:40:09,056 INFO : runtime                   8.57119
2018-12-22 20:40:09,056 INFO : 
2018-12-22 20:40:14,041 INFO : Time only for training updates: 4.98s
2018-12-22 20:40:17,620 INFO : Epoch 65
2018-12-22 20:40:17,620 INFO : train_loss                0.34733
201