In [1]:
import logging
import os.path
import time
from collections import OrderedDict
import sys

import numpy as np
import torch.nn.functional as F
from torch import optim

from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
    RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import set_random_seeds, np_to_var
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (bandpass_cnt,
                                             exponential_running_standardize)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne
from braindecode.datautil.signal_target import SignalAndTarget

import warnings
warnings.filterwarnings("ignore")

log = logging.getLogger(__name__)
# logging.basicConfig(filename="./outputlog.csv")

In [2]:
data_folder = '/Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/'
subject_id = 1 
low_cut_hz = 4 
ival = [-500, 4000]
max_epochs = 1600
max_increase_epochs = 160
batch_size = 60
high_cut_hz = 38
factor_new = 1e-3
init_block_size = 1000
valid_set_fraction = 0.2
marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
                        ('Foot', [3]), ('Tongue', [4])])   

In [3]:
def load_subject_data(subject_id):
    train_filename = 'A{:02d}T.gdf'.format(subject_id)
    test_filename = 'A{:02d}E.gdf'.format(subject_id)
    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)
    train_label_filepath = train_filepath.replace('.gdf', '.mat')
    test_label_filepath = test_filepath.replace('.gdf', '.mat')

    train_loader = BCICompetition4Set2A(
        train_filepath, labels_filename=train_label_filepath)
    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath)
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()
    
    train_set = data_preprocessing(train_cnt)
    test_set = data_preprocessing(test_cnt)
    
    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
    
    return SignalAndTarget(np.concatenate((train_set.X, test_set.X), axis=0),
                           np.concatenate((train_set.y, test_set.y), axis=0))

In [4]:
def data_preprocessing(train_cnt):
    train_cnt = train_cnt.drop_channels(['STI 014', 'EOG-left',
                                         'EOG-central', 'EOG-right'])
    assert len(train_cnt.ch_names) == 22
    # lets convert to millvolt for numerical stability of next operations
    train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
    train_cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, train_cnt.info['sfreq'],
                               filt_order=3,
                               axis=1), train_cnt)
    train_cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
                                                  init_block_size=init_block_size,
                                                  eps=1e-4).T,
        train_cnt)
    return train_cnt

In [5]:
def load_train_data(subjects):
    data = load_subject_data(subjects.pop(0))
    x = data.X
    y = data.y
    for s in subjects:
        data = load_subject_data(s)
        x = np.concatenate((x, data.X), axis=0)
        y = np.concatenate((y, data.y), axis=0)
    return SignalAndTarget(x, y)      

In [25]:
def load_single_data(subject_id):
    train_filename = 'A{:02d}T.gdf'.format(subject_id)
    test_filename = 'A{:02d}E.gdf'.format(subject_id)
    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)
    train_label_filepath = train_filepath.replace('.gdf', '.mat')
    test_label_filepath = test_filepath.replace('.gdf', '.mat')

    train_loader = BCICompetition4Set2A(
        train_filepath, labels_filename=train_label_filepath)
    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath)
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()
    
    train_set = data_preprocessing(train_cnt)
    test_set = data_preprocessing(test_cnt)
    
    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
    
    return train_set, test_set

In [20]:
def run_exp(train_data, test_data, low_cut_hz, model, cuda):
    batch_size = 60
    train_set, valid_set = split_into_two_sets(
        train_data, first_set_fraction=1-valid_set_fraction)

    set_random_seeds(seed=20181221, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    input_time_length = train_set.X.shape[2]
    if model == 'shallow':
        model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    elif model == 'deep':
        model = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    optimizer = optim.Adam(model.parameters())

    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    exp = Experiment(model, train_set, valid_set, test_data, iterator=iterator,
                     loss_function=F.nll_loss, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, cuda=cuda)
    exp.run()
    return exp

In [21]:
subjects_list = [i for i in range(1, 10)]
subjects_list.remove(9)
train_data = load_train_data(subjects_list)
test_data = load_subject_data(9)

Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...
2018-12-21 22:09:21,853 INFO : Trial per class:
Counter({'Tongue': 72, 'Foot': 72, 'Right Hand': 72, 'Left Hand': 72})
2018-12-21 22:09:22,143 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})
Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  27

In [None]:
subjects_list = [i for i in range(1, 10)]
subjects_list.remove(9)
train_data, test_data = load_single_data(1)

Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Extracting EDF parameters from /Users/debaojian/OneDrive/OneDrive - UNSW/UNSW/Research/EEG/BCI_Competition/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


In [22]:
train_data.X.shape

(4608, 22, 1125)

In [24]:
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
cuda = True
exp = run_exp(train_data, test_data, low_cut_hz, "deep", cuda)
log.info("Last 10 epochs")
log.info("\n" + str(exp.epochs_df.iloc[-10:]))

2018-12-21 22:29:10,587 INFO : Model: 
Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 25, kernel_size=(10, 1), stride=(1, 1))
  (conv_spat): Conv2d(25, 25, kernel_size=(1, 22), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=elu)
  (pool): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (pool_nonlin): Expression(expression=identity)
  (drop_2): Dropout(p=0.5)
  (conv_2): Conv2d(25, 50, kernel_size=(10, 1), stride=(1, 1), bias=False)
  (bnorm_2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlin_2): Expression(expression=elu)
  (pool_2): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (pool_nonlin_2): Expression(expression=identity)
  (drop_3): Dropout(p=0.5)
  (conv_3): Conv2d(50, 100, kernel_size=(10, 1), stri

2018-12-21 22:30:31,881 INFO : runtime                   8.55348
2018-12-21 22:30:31,882 INFO : 
2018-12-21 22:30:36,829 INFO : Time only for training updates: 4.95s
2018-12-21 22:30:40,430 INFO : Epoch 10
2018-12-21 22:30:40,431 INFO : train_loss                19.66461
2018-12-21 22:30:40,431 INFO : valid_loss                19.67782
2018-12-21 22:30:40,432 INFO : test_loss                 19.64951
2018-12-21 22:30:40,432 INFO : train_misclass            0.75014
2018-12-21 22:30:40,433 INFO : valid_misclass            0.74946
2018-12-21 22:30:40,434 INFO : test_misclass             0.75000
2018-12-21 22:30:40,434 INFO : runtime                   8.55192
2018-12-21 22:30:40,434 INFO : 
2018-12-21 22:30:40,442 INFO : New best valid_misclass: 0.749458
2018-12-21 22:30:40,442 INFO : 
2018-12-21 22:30:45,392 INFO : Time only for training updates: 4.95s
2018-12-21 22:30:48,989 INFO : Epoch 11
2018-12-21 22:30:48,990 INFO : train_loss                36.02463
2018-12-21 22:30:48,991 INFO : v

2018-12-21 22:32:14,586 INFO : 
2018-12-21 22:32:19,532 INFO : Time only for training updates: 4.94s
2018-12-21 22:32:23,131 INFO : Epoch 22
2018-12-21 22:32:23,132 INFO : train_loss                41.14233
2018-12-21 22:32:23,132 INFO : valid_loss                41.31526
2018-12-21 22:32:23,133 INFO : test_loss                 41.42434
2018-12-21 22:32:23,134 INFO : train_misclass            0.74986
2018-12-21 22:32:23,134 INFO : valid_misclass            0.75054
2018-12-21 22:32:23,135 INFO : test_misclass             0.75000
2018-12-21 22:32:23,135 INFO : runtime                   8.55745
2018-12-21 22:32:23,136 INFO : 
2018-12-21 22:32:28,086 INFO : Time only for training updates: 4.95s
2018-12-21 22:32:31,687 INFO : Epoch 23
2018-12-21 22:32:31,687 INFO : train_loss                44.76258
2018-12-21 22:32:31,688 INFO : valid_loss                44.80281
2018-12-21 22:32:31,689 INFO : test_loss                 44.94884
2018-12-21 22:32:31,689 INFO : train_misclass            0.750

2018-12-21 22:34:02,252 INFO : Time only for training updates: 4.95s
2018-12-21 22:34:05,851 INFO : Epoch 34
2018-12-21 22:34:05,852 INFO : train_loss                41.37487
2018-12-21 22:34:05,853 INFO : valid_loss                41.34791
2018-12-21 22:34:05,853 INFO : test_loss                 41.49054
2018-12-21 22:34:05,854 INFO : train_misclass            0.75014
2018-12-21 22:34:05,854 INFO : valid_misclass            0.74946
2018-12-21 22:34:05,855 INFO : test_misclass             0.75000
2018-12-21 22:34:05,855 INFO : runtime                   8.56056
2018-12-21 22:34:05,855 INFO : 
2018-12-21 22:34:05,862 INFO : New best valid_misclass: 0.749458
2018-12-21 22:34:05,863 INFO : 
2018-12-21 22:34:10,812 INFO : Time only for training updates: 4.95s
2018-12-21 22:34:14,410 INFO : Epoch 35
2018-12-21 22:34:14,411 INFO : train_loss                97.28780
2018-12-21 22:34:14,412 INFO : valid_loss                97.18096
2018-12-21 22:34:14,412 INFO : test_loss                 97.547

2018-12-21 22:35:57,093 INFO : Epoch 47
2018-12-21 22:35:57,093 INFO : train_loss                38.28332
2018-12-21 22:35:57,094 INFO : valid_loss                38.24726
2018-12-21 22:35:57,095 INFO : test_loss                 38.34924
2018-12-21 22:35:57,095 INFO : train_misclass            0.75014
2018-12-21 22:35:57,096 INFO : valid_misclass            0.74946
2018-12-21 22:35:57,096 INFO : test_misclass             0.75000
2018-12-21 22:35:57,097 INFO : runtime                   8.55482
2018-12-21 22:35:57,097 INFO : 
2018-12-21 22:36:02,050 INFO : Time only for training updates: 4.95s
2018-12-21 22:36:05,650 INFO : Epoch 48
2018-12-21 22:36:05,651 INFO : train_loss                84.02932
2018-12-21 22:36:05,652 INFO : valid_loss                84.04325
2018-12-21 22:36:05,652 INFO : test_loss                 84.30143
2018-12-21 22:36:05,653 INFO : train_misclass            0.75014
2018-12-21 22:36:05,654 INFO : valid_misclass            0.74946
2018-12-21 22:36:05,654 INFO : te

2018-12-21 22:37:48,311 INFO : runtime                   8.55240
2018-12-21 22:37:48,311 INFO : 
2018-12-21 22:37:53,260 INFO : Time only for training updates: 4.95s
2018-12-21 22:37:56,863 INFO : Epoch 61
2018-12-21 22:37:56,863 INFO : train_loss                26.39727
2018-12-21 22:37:56,864 INFO : valid_loss                26.37831
2018-12-21 22:37:56,865 INFO : test_loss                 26.40571
2018-12-21 22:37:56,865 INFO : train_misclass            0.75014
2018-12-21 22:37:56,866 INFO : valid_misclass            0.74946
2018-12-21 22:37:56,867 INFO : test_misclass             0.75000
2018-12-21 22:37:56,867 INFO : runtime                   8.54846
2018-12-21 22:37:56,868 INFO : 
2018-12-21 22:38:01,820 INFO : Time only for training updates: 4.95s
2018-12-21 22:38:05,417 INFO : Epoch 62
2018-12-21 22:38:05,418 INFO : train_loss                38.15761
2018-12-21 22:38:05,418 INFO : valid_loss                38.21518
2018-12-21 22:38:05,419 INFO : test_loss                 38.320

2018-12-21 22:39:48,065 INFO : valid_misclass            0.75054
2018-12-21 22:39:48,066 INFO : test_misclass             0.75000
2018-12-21 22:39:48,067 INFO : runtime                   8.55337
2018-12-21 22:39:48,067 INFO : 
2018-12-21 22:39:53,018 INFO : Time only for training updates: 4.95s
2018-12-21 22:39:56,619 INFO : Epoch 75
2018-12-21 22:39:56,619 INFO : train_loss                34.17495
2018-12-21 22:39:56,620 INFO : valid_loss                34.27625
2018-12-21 22:39:56,621 INFO : test_loss                 34.47404
2018-12-21 22:39:56,621 INFO : train_misclass            0.75014
2018-12-21 22:39:56,622 INFO : valid_misclass            0.74946
2018-12-21 22:39:56,623 INFO : test_misclass             0.75000
2018-12-21 22:39:56,623 INFO : runtime                   8.55226
2018-12-21 22:39:56,623 INFO : 
2018-12-21 22:40:01,573 INFO : Time only for training updates: 4.95s
2018-12-21 22:40:05,169 INFO : Epoch 76
2018-12-21 22:40:05,170 INFO : train_loss                5.08379


2018-12-21 22:41:47,824 INFO : test_loss                 69.63081
2018-12-21 22:41:47,825 INFO : train_misclass            0.74986
2018-12-21 22:41:47,825 INFO : valid_misclass            0.75054
2018-12-21 22:41:47,826 INFO : test_misclass             0.75000
2018-12-21 22:41:47,827 INFO : runtime                   8.55306
2018-12-21 22:41:47,827 INFO : 
2018-12-21 22:41:52,778 INFO : Time only for training updates: 4.95s
2018-12-21 22:41:56,378 INFO : Epoch 89
2018-12-21 22:41:56,379 INFO : train_loss                8.62701
2018-12-21 22:41:56,379 INFO : valid_loss                8.66933
2018-12-21 22:41:56,380 INFO : test_loss                 8.77647
2018-12-21 22:41:56,382 INFO : train_misclass            0.75014
2018-12-21 22:41:56,382 INFO : valid_misclass            0.74946
2018-12-21 22:41:56,383 INFO : test_misclass             0.75000
2018-12-21 22:41:56,384 INFO : runtime                   8.55368
2018-12-21 22:41:56,385 INFO : 
2018-12-21 22:42:01,335 INFO : Time only for t

2018-12-21 22:43:47,574 INFO : Epoch 102
2018-12-21 22:43:47,576 INFO : train_loss                9.92787
2018-12-21 22:43:47,576 INFO : valid_loss                9.94389
2018-12-21 22:43:47,577 INFO : test_loss                 10.20961
2018-12-21 22:43:47,577 INFO : train_misclass            0.75014
2018-12-21 22:43:47,578 INFO : valid_misclass            0.74946
2018-12-21 22:43:47,578 INFO : test_misclass             0.75000
2018-12-21 22:43:47,579 INFO : runtime                   8.55011
2018-12-21 22:43:47,580 INFO : 
2018-12-21 22:43:52,533 INFO : Time only for training updates: 4.95s
2018-12-21 22:43:56,133 INFO : Epoch 103
2018-12-21 22:43:56,134 INFO : train_loss                1.79963
2018-12-21 22:43:56,134 INFO : valid_loss                1.82542
2018-12-21 22:43:56,135 INFO : test_loss                 1.80268
2018-12-21 22:43:56,135 INFO : train_misclass            0.73114
2018-12-21 22:43:56,137 INFO : valid_misclass            0.72668
2018-12-21 22:43:56,137 INFO : test_

KeyboardInterrupt: 