# Imports

In [1]:
import os
import logging
import importlib
importlib.reload(logging)
log = logging.getLogger()
log.setLevel('INFO')
import sys

logging.basicConfig(format='%(asctime)s %(levelname)s |: %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import mne
from mne.io import concatenate_raws
import matplotlib.pyplot as plt
from scipy import signal
from sklearn import preprocessing
import numpy as np
from torchsummary import summary
import torch

from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
import pickle



# Data processing

In [9]:
path = './processed_data/'
data_type = 'XDAWN'
X_pseudo_test = np.load(path + f'X_pseudo_test_{data_type}.npy', )
X_word_test = np.load(path + f'X_word_test_{data_type}.npy', )
X_pseudo_train = np.load(path + f'X_pseudo_train_{data_type}.npy', )
X_word_train = np.load(path + f'X_word_train_{data_type}.npy', )
try:
    X_pseudo_valid = np.load(path + f'X_pseudo_valid_{data_type}.npy', )
    X_word_valid = np.load(path + f'X_word_valid_{data_type}.npy', )
except:
    print('No validation data provided')

No validation data provided


In [10]:
X_test = np.concatenate((X_word_test,X_pseudo_test))
X_test = (X_test * 1e8).astype(np.float32)
y_test = np.zeros(X_test.shape[0])
y_test[:X_word_test.shape[0]] = 1
ind = np.arange(X_test.shape[0])
np.random.shuffle(ind)
X_test = X_test[ind,:,:]
y_test = y_test[ind]

X_train = np.concatenate((X_word_train,X_pseudo_train))
X_train = (X_train * 1e8).astype(np.float32)
y_train = np.zeros(X_train.shape[0])
y_train[:X_word_train.shape[0]] = 1
ind = np.arange(X_train.shape[0])
np.random.shuffle(ind)
X_train = X_train[ind,:,:]
y_train = y_train[ind]

y_test = y_test.astype(np.int64)
y_train = y_train.astype(np.int64)

try:
    X_valid = np.concatenate((X_word_valid,X_pseudo_valid))
    X_valid = (X_valid * 1e8).astype(np.float32)
    y_valid = np.zeros(X_valid.shape[0])
    y_valid[:X_word_valid.shape[0]] = 0
    y_valid[X_word_valid.shape[0]:] = 1
    ind = np.arange(X_valid.shape[0])
    np.random.shuffle(ind)
    X_valid = X_valid[ind,:,:]
    y_valid = y_valid[ind]

    y_valid = y_valid.astype(np.int64)
except:
    print('No validation data provided')

In [11]:
print(f'X_train; {X_train.shape}')
print(f'y_train; {y_train.shape}')
print(f'X_test; {X_test.shape}')
print(f'y_test; {y_test.shape}')
try:
    print(f'X_valid; {X_valid.shape}')
    print(f'y_valid; {y_valid.shape}')
    print(X_valid.shape[0]+X_test.shape[0]+X_train.shape[0])
except:
    print('No validation data provided')

X_train; (12000, 19, 306)
y_train; (12000,)
X_test; (6063, 19, 306)
y_test; (6063,)
X_valid; (2773, 19, 306)
y_valid; (2773,)
20836


In [12]:
train_set = SignalAndTarget(X_train, y=y_train)
valid_set = SignalAndTarget(X_test, y=y_test)

In [13]:
X = np.concatenate((X_train,X_test))
y = np.concatenate((y_train,y_test))
train_set = SignalAndTarget(X=X, y=y)
valid_set = SignalAndTarget(X_valid, y=y_valid)

# ShallowFBCSPNet model

In [14]:
# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length = auto ensures we only get a single output in the time dimension
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                        input_time_length=train_set.X.shape[2],
                        final_conv_length='auto')
if cuda:
    model.cuda()
    

optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
# optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1,)

In [15]:
model.fit(train_set.X, train_set.y, epochs=10, batch_size=64, scheduler='cosine',
         validation_data=(valid_set.X, valid_set.y),)

2020-04-28 12:20:56,732 INFO |: Run until first stop...
2020-04-28 12:21:01,932 INFO |: Epoch 0
2020-04-28 12:21:01,933 INFO |: train_loss                4.47374
2020-04-28 12:21:01,934 INFO |: valid_loss                5.69890
2020-04-28 12:21:01,935 INFO |: train_misclass            0.49289
2020-04-28 12:21:01,935 INFO |: valid_misclass            0.65489
2020-04-28 12:21:01,936 INFO |: runtime                   0.00000
2020-04-28 12:21:01,936 INFO |: 
2020-04-28 12:21:11,297 INFO |: Time only for training updates: 9.36s
2020-04-28 12:21:16,465 INFO |: Epoch 1
2020-04-28 12:21:16,465 INFO |: train_loss                0.70020
2020-04-28 12:21:16,466 INFO |: valid_loss                2.38905
2020-04-28 12:21:16,467 INFO |: train_misclass            0.30582
2020-04-28 12:21:16,467 INFO |: valid_misclass            0.71114
2020-04-28 12:21:16,468 INFO |: runtime                   14.56537
2020-04-28 12:21:16,468 INFO |: 
2020-04-28 12:21:25,728 INFO |: Time only for training updates: 9.2

KeyboardInterrupt: 

In [46]:
model_path = './models/'
torch.save(model.network, model_path + f'ShallowFBCSPNet_{data_type}')

In [27]:
model_path = './models/'
pickle_out = open(model_path + f'ShallowFBCSPNet_{data_type}.pickle',"wb")
pickle.dump(model, pickle_out)

In [83]:
try:
    validation_set = SignalAndTarget(X=X_valid, y=y_valid)
    print(model.evaluate(validation_set.X, validation_set.y))
except:
    print('No validation data provided')

{'loss': 2.2557923793792725, 'misclass': 0.3764875586007934, 'runtime': 0.0009815692901611328}


In [74]:
print(model.predict_classes(validation_set.X[:20]))

[0 1 1 1 1 1 0 1 0 1 1 1 0 0 0 1 1 1 1 0]


In [75]:
print(validation_set.y[:20])

[1 1 1 0 1 0 1 1 0 1 1 0 1 1 0 1 1 0 1 0]


# Deep4Net model

In [18]:
# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 306
# final_conv_length determines the size of the receptive field of the ConvNet
model = Deep4Net(in_chans=19, n_classes=2, input_time_length=input_time_length,
                 filter_length_3=5, filter_length_4=5,
                 pool_time_stride=2,
                 stride_before_pool=True,
                        final_conv_length=1)
if cuda:
    model.cuda()
    
optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
model.compile(loss=F.nll_loss, optimizer=optimizer,  iterator_seed=1, cropped=True)

In [20]:
input_time_length = 306
model.fit(train_set.X, train_set.y, epochs=10, batch_size=64, scheduler='cosine',
          input_time_length=input_time_length,
         validation_data=(valid_set.X, valid_set.y),)

2020-04-23 23:51:09,148 INFO |: Run until first stop...
2020-04-23 23:51:17,392 INFO |: Epoch 0
2020-04-23 23:51:17,393 INFO |: train_loss                25.65163
2020-04-23 23:51:17,393 INFO |: valid_loss                26.84682
2020-04-23 23:51:17,394 INFO |: train_misclass            0.50067
2020-04-23 23:51:17,395 INFO |: valid_misclass            0.52086
2020-04-23 23:51:17,395 INFO |: runtime                   0.00000
2020-04-23 23:51:17,396 INFO |: 
2020-04-23 23:52:35,254 INFO |: Time only for training updates: 77.10s
2020-04-23 23:52:43,688 INFO |: Epoch 1
2020-04-23 23:52:43,689 INFO |: train_loss                0.42233
2020-04-23 23:52:43,689 INFO |: valid_loss                0.63215
2020-04-23 23:52:43,690 INFO |: train_misclass            0.14608
2020-04-23 23:52:43,691 INFO |: valid_misclass            0.31667
2020-04-23 23:52:43,692 INFO |: runtime                   86.10577
2020-04-23 23:52:43,692 INFO |: 
2020-04-23 23:53:55,842 INFO |: Time only for training updates: 

KeyboardInterrupt: 

In [32]:
sum(validation_set.y)

671