# Two-Stage Deep Learning Project

<p>This project was developed by the Montreal Hacknight community. It is derived from the braindecode framework and is designed to be integrated into the MOABB project.</p>

<ul><li>https://github.com/robintibor/braindecode</li>
<li>https://github.com/NeuroTechX/moabb</li></ul>

<p>This project would not have been possible without the contributions of Arna (...), Justin (...), Yannick Roy, Eamon Egan, (...). It is currently supported by Fred Simard (fs@re-ak.com).</p>

<p>This project consists of deriving a two-stage training process from the braindecode implementation of a shallow convolutional network (ShallowFBCSPNet). The core motivation is to keep the strengths of deep learning, one of the most powerful families of ML methods available, while mitigating its main drawback: it requires an immense amount of data.</p>

<p>The goal is to develop a method to train the network in two stages:</p>
<ol>
<li>Train the network on a large dataset; this first training acts as a kind of prior over the network's weights.</li>
<li>Train the network a second time, this time on the dataset of interest.</li>
</ol>

<p>The premise behind this approach is that the second training should converge faster than training a deep network from scratch, without sacrificing performance. We also expect the second training to exploit the specificities of the dataset of interest and thereby outperform the first training, which was done on a generic dataset.</p>
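<p>In machine-learning terms, the two stages are pre-training followed by fine-tuning: the weights learned in stage 1 become the starting point of stage 2 instead of a fresh initialization. The sketch below illustrates only that mechanism, with a NumPy logistic regression standing in for the ShallowFBCSPNet. The synthetic "shared" and "subject-specific" data, the hypothetical `train_logreg` helper, and all learning rates and epoch counts are illustrative assumptions, not part of this project.</p>

```python
import numpy as np

def train_logreg(X, y, w=None, lr=0.1, n_epochs=100):
    """Full-batch gradient descent on the mean logistic loss.

    If `w` is given, training starts from a copy of those weights (stage 2);
    otherwise it starts from zeros (stage 1, or training from scratch).
    """
    w = np.zeros(X.shape[1]) if w is None else w.copy()
    for _ in range(n_epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))   # predicted P(y = 1)
        grad = X.T @ (p - y) / len(y)        # gradient of the mean log-loss
        w -= lr * grad
    return w

def accuracy(X, y, w):
    return np.mean((X @ w > 0) == y)

rng = np.random.default_rng(0)
n_features = 20
w_shared = rng.normal(size=n_features)                     # structure common to everyone
w_subject = w_shared + 0.5 * rng.normal(size=n_features)   # subject-specific deviation

# Stage 1 data: large and generic
X_large = rng.normal(size=(5000, n_features))
y_large = (X_large @ w_shared > 0).astype(float)

# Stage 2 data: small, from the "subject of interest"
X_small = rng.normal(size=(100, n_features))
y_small = (X_small @ w_subject > 0).astype(float)

w_prior = train_logreg(X_large, y_large, lr=0.5, n_epochs=200)             # stage 1
w_final = train_logreg(X_small, y_small, w=w_prior, lr=0.1, n_epochs=50)   # stage 2

X_test = rng.normal(size=(1000, n_features))
y_test = (X_test @ w_subject > 0).astype(float)
print("stage 1 only :", accuracy(X_test, y_test, w_prior))
print("stage 1 + 2  :", accuracy(X_test, y_test, w_final))
```

<p>Whether stage 2 actually improves on stage 1 depends on the data and on the learning rate and number of epochs chosen for stage 2; the sketch makes no claim either way.</p>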

In [1]:
"""
#
# Download the github, which contains the package with the networks
#
"""

# reset to default base path
import os
os.chdir('/content/')

# remove any previous clone of dl-eeg-playground
!rm -rf dl-eeg-playground

# clone dl-eeg-playground and move into the folder containing the wrappers
!git clone https://github.com/NeuroTechX/dl-eeg-playground.git
os.chdir('dl-eeg-playground/brainDecode/towardMoabbIntegration')


Cloning into 'dl-eeg-playground'...
remote: Counting objects: 171, done.
remote: Compressing objects: 100% (26/26), done.
remote: Total 171 (delta 26), reused 25 (delta 14), pack-reused 131
Receiving objects: 100% (171/171), 1.25 MiB | 17.06 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [2]:
#
# Download the whole BNCI 002-2014 dataset:
# 14 subjects, each with a training (T) and an evaluation (E) file
#
base_url = "http://bnci-horizon-2020.eu/database/data-sets/002-2014/"
for subject in range(1, 15):
    for session in ("T", "E"):
        fname = "S{:02d}{}.mat".format(subject, session)
        # IPython expands {base_url}{fname} into the full URL
        !wget {base_url}{fname}

# move files into a dedicated folder
!mkdir BBCIData
!mv *.mat BBCIData

# install braindecode
!pip install braindecode -q

# install pytorch 0.3.0 (CUDA 8.0 build if a GPU is available, CPU build otherwise)
# ref: http://pytorch.org/
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.3.0.post4-{platform}-linux_x86_64.whl torchvision
import torch

import scipy.io as sio
import numpy as np
from os import listdir
from os.path import isfile, join

# prepare data containers
y = []
X = []

"""
Now, let's load data.

We read the file for the desired subject, and parse the data to extract:
- samplingRate
- trialLength
- X, a M x N x K matrix, which stands for trial x chan x samples
    - the actual values are 160 x 15 x 2560
- y, a M vector containing the labels {0,1}

ref: Dataset description: https://lampx.tugraz.at/~bci/database/002-2014/description.pdf
"""

folder = "BBCIData"

for f in listdir(folder):
    # read file
    d1T = sio.loadmat(folder + "/" + f)
    
    samplingRate = d1T['data'][0][0][0][0][3][0][0]
    trialLength = 7*samplingRate

    # run through all training runs
    for run in range(len(d1T['data'][0])):
        y.append(d1T['data'][0][run][0][0][2][0]) # labels
        timestamps = d1T['data'][0][run][0][0][1][0] # timestamps
        rawData = d1T['data'][0][run][0][0][0].transpose() # chan x data

        # parse out data based on timestamps
        for start in timestamps:
            end = start + trialLength
            X.append(rawData[:,start:end]) # 15 x trialLength

    del rawData
    del d1T

# arrange data into numpy arrays
# torch expects float32 for samples
# and int64 for labels; labels are shifted from {1,2} to {0,1}
X = np.array(X).astype(np.float32)
y = (np.array(y).flatten()-1).astype(np.int64)
print()
print(X.shape)
print(y.shape)


# randomly permute the trials (same permutation for X and y)
idx = np.random.permutation(X.shape[0])

X = X[idx,:,:]
y = y[idx]



(2240, 15, 3584)
(2240,)
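<p>The loop above cuts fixed-length trials out of each run's continuous recording. The same step, written as a small standalone function and run on synthetic data, is sketched below. `epoch_run` and the synthetic run are hypothetical, and the function assumes the onsets are 0-based sample indices (positions exported from MATLAB may be 1-based and would then need shifting by one). It also shows the label shift from {1,2} to {0,1} and a shuffle that applies the same permutation to trials and labels. The demo variables carry a `_demo` suffix so they do not overwrite the notebook's `X` and `y`.</p>

```python
import numpy as np

def epoch_run(raw, onsets, trial_length):
    """Slice fixed-length trials out of a continuous recording.

    raw          : array of shape (n_channels, n_samples)
    onsets       : trial start positions, as 0-based sample indices
    trial_length : number of samples per trial
    Returns a float32 array of shape (n_trials, n_channels, trial_length).
    """
    trials = [raw[:, start:start + trial_length] for start in onsets]
    if any(t.shape[1] != trial_length for t in trials):
        raise ValueError("a trial runs past the end of the recording")
    return np.stack(trials).astype(np.float32)

sampling_rate = 512                        # Hz
trial_length = 7 * sampling_rate           # 7-second trials -> 3584 samples
rng = np.random.default_rng(0)

raw = rng.normal(size=(15, 20 * trial_length))   # one synthetic run, 15 channels
onsets = np.arange(20) * trial_length            # 20 back-to-back trials
labels = rng.integers(1, 3, size=20)             # labels in {1, 2}, as in the .mat files

X_demo = epoch_run(raw, onsets, trial_length)
y_demo = (labels - 1).astype(np.int64)           # map {1, 2} -> {0, 1}

perm = rng.permutation(len(X_demo))              # shuffle trials and labels together
X_demo, y_demo = X_demo[perm], y_demo[perm]
print(X_demo.shape, y_demo.shape)                # (20, 15, 3584) (20,)
```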


In [3]:

# stage 1: train the network on the whole dataset (all subjects) with the general trainer
from brainDecodeSKLearnWrapper.ShallowFBCSPNet_GeneralTrainer import ShallowFBCSPNet_GeneralTrainer

classifier = ShallowFBCSPNet_GeneralTrainer()
classifier.fit(X,y)




Epoch 0
Train  Loss: 0.96520
Train  Accuracy: 48.1%
Test   Loss: 0.87153
Test   Accuracy: 55.4%
Epoch 1
Train  Loss: 1.00303
Train  Accuracy: 48.6%
Test   Loss: 0.91248
Test   Accuracy: 56.4%
Epoch 2
Train  Loss: 0.72225
Train  Accuracy: 55.4%
Test   Loss: 0.74053
Test   Accuracy: 50.7%
Epoch 3
Train  Loss: 0.70537
Train  Accuracy: 55.3%
Test   Loss: 0.73199
Test   Accuracy: 52.9%
Epoch 4
Train  Loss: 0.69171
Train  Accuracy: 55.0%
Test   Loss: 0.72662
Test   Accuracy: 48.9%
Epoch 5
Train  Loss: 0.72288
Train  Accuracy: 53.0%
Test   Loss: 0.73340
Test   Accuracy: 53.6%
Epoch 6
Train  Loss: 0.78398
Train  Accuracy: 51.5%
Test   Loss: 0.86019
Test   Accuracy: 47.1%
Epoch 7
Train  Loss: 0.71520
Train  Accuracy: 53.4%
Test   Loss: 0.72639
Test   Accuracy: 52.5%
Epoch 8
Train  Loss: 0.69865
Train  Accuracy: 57.3%
Test   Loss: 0.72801
Test   Accuracy: 53.2%
Epoch 9
Train  Loss: 1.28257
Train  Accuracy: 50.5%
Test   Loss: 1.42833
Test   Accuracy: 45.7%
Epoch 10
Train  Loss: 0.62912
[...]
Epoch 32
Train  Loss: 0.59916
Train  Accuracy: 66.2%
Test   Loss: 0.61190
Test   Accuracy: 66.4%
Epoch 33
Train  Loss: 0.68292
Train  Accuracy: 60.3%
Test   Loss: 0.71491
Test   Accuracy: 63.9%
Epoch 34
Train  Loss: 0.44035
Train  Accuracy: 80.5%
Test   Loss: 0.51628
Test   Accuracy: 71.4%
Epoch 35
Train  Loss: 0.54559
Train  Accuracy: 70.0%
Test   Loss: 0.59993
Test   Accuracy: 68.9%
Epoch 36
Train  Loss: 0.45137
Train  Accuracy: 80.2%
Test   Loss: 0.53663
Test   Accuracy: 71.4%
Epoch 37
Train  Loss: 0.48272
Train  Accuracy: 75.0%
Test   Loss: 0.56258
Test   Accuracy: 72.1%
Epoch 38
Train  Loss: 0.47355
Train  Accuracy: 75.3%
Test   Loss: 0.54563
Test   Accuracy: 70.7%
Epoch 39
Train  Loss: 1.26676
Train  Accuracy: 53.1%
Test   Loss: 1.26652
Test   Accuracy: 56.8%
Epoch 40
Train  Loss: 0.43430
Train  Accuracy: 79.4%
Test   Loss: 0.51697
Test   Accuracy: 72.5%
Epoch 41
Train  Loss: 0.42779
Train  Accuracy: 80.8%
Test   Loss: 0.50886
Test   Accuracy: 73.2%
Epoch 42
Train  Loss: 0.66922
[...]
Epoch 64
Train  Loss: 0.51234
Train  Accuracy: 72.4%
Test   Loss: 0.62003
Test   Accuracy: 67.9%
Epoch 65
Train  Loss: 0.73421
Train  Accuracy: 62.1%
Test   Loss: 0.78431
Test   Accuracy: 62.9%
Epoch 66
Train  Loss: 0.40095
Train  Accuracy: 82.5%
Test   Loss: 0.47589
Test   Accuracy: 73.9%
Epoch 67
Train  Loss: 0.52046
Train  Accuracy: 72.1%
Test   Loss: 0.60212
Test   Accuracy: 67.9%
Epoch 68
Train  Loss: 0.38521
Train  Accuracy: 83.2%
Test   Loss: 0.48385
Test   Accuracy: 75.4%
Epoch 69
Train  Loss: 0.37042
Train  Accuracy: 84.7%
Test   Loss: 0.46098
Test   Accuracy: 76.4%
Epoch 70
Train  Loss: 0.43591
Train  Accuracy: 78.2%
Test   Loss: 0.56555
Test   Accuracy: 69.6%
Epoch 71
Train  Loss: 0.37889
Train  Accuracy: 83.7%
Test   Loss: 0.46378
Test   Accuracy: 75.7%
Epoch 72
Train  Loss: 0.36435
Train  Accuracy: 84.7%
Test   Loss: 0.45732
Test   Accuracy: 75.4%
Epoch 73
Train  Loss: 0.40326
Train  Accuracy: 80.4%
Test   Loss: 0.49092
Test   Accuracy: 75.7%
Epoch 74
Train  Loss: 0.37117
[...]
Epoch 96
Train  Loss: 0.33377
Train  Accuracy: 86.7%
Test   Loss: 0.47036
Test   Accuracy: 75.4%
Epoch 97
Train  Loss: 0.46763
Train  Accuracy: 77.6%
Test   Loss: 0.53869
Test   Accuracy: 75.0%
Epoch 98
Train  Loss: 0.50159
Train  Accuracy: 71.7%
Test   Loss: 0.62868
Test   Accuracy: 67.1%
Epoch 99
Train  Loss: 0.38434
Train  Accuracy: 82.4%
Test   Loss: 0.49387
Test   Accuracy: 75.4%
Epoch 100
Train  Loss: 0.37618
Train  Accuracy: 82.8%
Test   Loss: 0.46827
Test   Accuracy: 75.4%
Epoch 101
Train  Loss: 0.33765
Train  Accuracy: 86.9%
Test   Loss: 0.44799
Test   Accuracy: 77.9%
Epoch 102
Train  Loss: 0.40315
Train  Accuracy: 79.9%
Test   Loss: 0.52444
Test   Accuracy: 73.6%
Epoch 103
Train  Loss: 0.77058
Train  Accuracy: 61.3%
Test   Loss: 0.84855
Test   Accuracy: 62.5%
Epoch 104
Train  Loss: 0.68528
Train  Accuracy: 63.8%
Test   Loss: 0.77914
Test   Accuracy: 63.6%
Epoch 105
Train  Loss: 0.48181
Train  Accuracy: 74.7%
Test   Loss: 0.59006
Test   Accuracy: 72.5%
Epoch 106
[...]
Train  Loss: 0.31769
Train  Accuracy: 87.1%
Test   Loss: 0.49386
Test   Accuracy: 75.7%
Epoch 128
Train  Loss: 0.30330
Train  Accuracy: 89.1%
Test   Loss: 0.44533
Test   Accuracy: 77.9%
Epoch 129
Train  Loss: 0.31116
Train  Accuracy: 88.2%
Test   Loss: 0.48268
Test   Accuracy: 75.4%
Epoch 130
Train  Loss: 0.30884
Train  Accuracy: 88.0%
Test   Loss: 0.45188
Test   Accuracy: 75.7%
Epoch 131
Train  Loss: 0.31765
Train  Accuracy: 85.9%
Test   Loss: 0.50170
Test   Accuracy: 75.0%
Epoch 132
Train  Loss: 0.32732
Train  Accuracy: 85.8%
Test   Loss: 0.49140
Test   Accuracy: 76.4%
Epoch 133
Train  Loss: 0.29584
Train  Accuracy: 87.7%
Test   Loss: 0.49519
Test   Accuracy: 73.6%
Epoch 134
Train  Loss: 0.36212
Train  Accuracy: 82.4%
Test   Loss: 0.51922
Test   Accuracy: 76.4%
Epoch 135
Train  Loss: 0.36532
Train  Accuracy: 83.5%
Test   Loss: 0.50801
Test   Accuracy: 74.6%
Epoch 136
Train  Loss: 0.29607
Train  Accuracy: 89.3%
Test   Loss: 0.46769
Test   Accuracy: 75.7%
Epoch 137
Train  Loss: 0.48818
[...]
Train  Loss: 0.37439
Train  Accuracy: 82.4%
Test   Loss: 0.57646
Test   Accuracy: 73.9%
Epoch 159
Train  Loss: 0.30443
Train  Accuracy: 89.1%
Test   Loss: 0.49829
Test   Accuracy: 75.4%


ShallowFBCSPNet_GeneralTrainer(filter_time_length=75, n_filters_spat=5,
                n_filters_time=10, nb_epoch=160, pool_time_length=60,
                pool_time_stride=30)
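<p>The repr above lists the hyperparameters of the general training. As a back-of-the-envelope check, assume each 3584-sample trial is fed to the network in one piece and that, as in braindecode's Shallow ConvNet, the temporal convolution (stride 1) and the mean pooling are unpadded. The number of time points left after these layers can then be computed as below; `valid_out_len` is a hypothetical helper, not part of the wrapper.</p>

```python
def valid_out_len(n, kernel, stride=1):
    """Length of the output of an unpadded 1-D convolution or pooling window."""
    return (n - kernel) // stride + 1

n_samples = 3584                                               # 7 s x 512 Hz, as printed above
after_conv = valid_out_len(n_samples, kernel=75)               # filter_time_length=75
after_pool = valid_out_len(after_conv, kernel=60, stride=30)   # pool_time_length=60, pool_time_stride=30
n_classifier_inputs = 5 * after_pool                           # n_filters_spat=5 feature maps
print(after_conv, after_pool, n_classifier_inputs)             # 3510 116 580
```

<p>Under these assumptions, a final layer spanning the whole remaining time axis sees 5 x 116 = 580 values per trial.</p>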

In [5]:

from brainDecodeSKLearnWrapper.ShallowFBCSPNet_SpecializedTrainer import ShallowFBCSPNet_SpecializedTrainer

"""
Now, let's load the data of a single file, i.e. one session (T or E) of one subject.

The parsing is the same as in the first data-loading cell:
- X, an M x N x K array, i.e. trials x channels x samples (M x 15 x 3584 here)
- y, a vector of M labels in {0,1}

ref: Dataset description: https://lampx.tugraz.at/~bci/database/002-2014/description.pdf
"""

folder = "BBCIData"
datasetID = 1

# prepare data containers
y = []
X = []

f = listdir(folder)[datasetID]
  
# read file
d1T = sio.loadmat(folder + "/" + f)

samplingRate = d1T['data'][0][0][0][0][3][0][0]
trialLength = 7*samplingRate

# run through all training runs
for run in range(len(d1T['data'][0])):
    y.append(d1T['data'][0][run][0][0][2][0]) # labels
    timestamps = d1T['data'][0][run][0][0][1][0] # timestamps
    rawData = d1T['data'][0][run][0][0][0].transpose() # chan x data

    # parse out data based on timestamps
    for start in timestamps:
        end = start + trialLength
        X.append(rawData[:,start:end]) # 15 x trialLength

del rawData
del d1T

# arrange data into numpy arrays
# torch expects float32 for samples
# and int64 for labels; labels are shifted from {1,2} to {0,1}
X = np.array(X).astype(np.float32)
y = (np.array(y).flatten()-1).astype(np.int64)
print()
print(X.shape)
print(y.shape)


# rand permute dataset
idx = np.random.permutation(X.shape[0])

X = X[idx,:,:]
y = y[idx]

# stage 2: continue training the network obtained in stage 1 on the selected data
specializedClassifier = ShallowFBCSPNet_SpecializedTrainer(network=classifier.model)
specializedClassifier.lr = 0.01
#specializedClassifier.configure(initial_lr=0.01)
specializedClassifier.fit(X,y)

SyntaxError: ignored
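<p>Because `listdir` returns files in an arbitrary order, picking a file by its position in the listing is fragile, and a position says nothing about which subject or session the file holds. A sketch of a more explicit selection, assuming the `SxxT.mat` / `SxxE.mat` naming used by the download cell (`subject_files` is a hypothetical helper, not part of the wrapper):</p>

```python
import re

def subject_files(filenames, subject):
    """Return the training (T) and evaluation (E) files of one subject, in that order.

    Assumes names of the form 'S01T.mat', 'S01E.mat', ..., 'S14E.mat';
    any other file name is ignored.
    """
    pattern = re.compile(r"S(\d{2})([TE])\.mat$")
    found = {}
    for name in filenames:
        match = pattern.match(name)
        if match and int(match.group(1)) == subject:
            found[match.group(2)] = name
    return [found[session] for session in ("T", "E") if session in found]

names = ["S02E.mat", "S01E.mat", "S02T.mat", "S01T.mat"]   # e.g. an unordered listdir() result
print(subject_files(names, 1))   # ['S01T.mat', 'S01E.mat']
```

<p>In the notebook, `subject_files(listdir(folder), 1)` would return the two files of subject 1, which could then be parsed with the same loop as in the first data-loading cell.</p>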

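In [0]:

# alternatively, start stage 2 from a network stored on disk
# (assumes a model was previously saved as "myModel.pth")
specializedClassifier = ShallowFBCSPNet_SpecializedTrainer(filename="myModel.pth")
specializedClassifier.lr = 0.01
#specializedClassifier.configure(initial_lr=0.01)
specializedClassifier.fit(X,y)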


Best scores so far:

Subject 0: 76.9%
Subject 1: 50.0%


