In [1]:
import mne
import pickle

In [2]:
with open('bcci_data_preprocessed.pkl','rb') as f:
    dataset=pickle.load(f)

print("Preprocessed data has been loaded from 'bcci_data_preprocessed.pkl'")

  "class": algorithms.Blowfish,


Preprocessed data has been loaded from 'bcci_data_preprocessed.pkl'


In [3]:
input_window_samples = 1000

In [4]:
import torch

from braindecode.models import Deep4Net
from braindecode.util import set_random_seeds

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))
# Extract number of chans from dataset
n_chans = dataset[0][0].shape[0]



In [7]:
model = Deep4Net(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
    batch_norm_alpha=0
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    _ = model.cuda()

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
Deep4Net (Deep4Net)                      [1, 23, 1000]             [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 23, 1000]             [1, 23, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-2            [1, 23, 1000, 1]          [1, 1, 1000, 23]          --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 1000, 23]          [1, 25, 991, 1]           14,650                    --
├─BatchNorm2d (bnorm): 1-4               [1, 25, 991, 1]           [1, 25, 991, 1]           50                        --
├─Expression (conv_nonlin): 1-5          [1, 25, 991, 1]           [1, 25, 991, 1]           --                        --
├─MaxPool2d (pool): 1-6                  [1, 25, 991, 1]           [1, 25, 330, 1]           --                        [3, 1]
├─Expressi

In [8]:
model.to_dense_prediction_model()

In [9]:
n_preds_per_input = model.get_output_shape()[2]

In [10]:
from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])

# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    preload=True
)

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


In [11]:
splitted = windows_dataset.split('session')
train_set = splitted['0train']  # Session train
valid_set = splitted['1test'] 

In [12]:
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier
from braindecode.training import CroppedLoss

# For deep4 they should be:
lr = 1 * 0.01
weight_decay = 0.5 * 0.001

batch_size=64
n_epochs=30

In [13]:
clf = EEGClassifier(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.nll_loss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_schedu6ler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
    classes=classes,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
_ = clf.fit(train_set, y=None, epochs=n_epochs)

  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr       dur
-------  ----------------  ------------  ----------------  ------------  ------  --------
      1            [36m0.2909[0m        [32m1.8675[0m            [35m0.2990[0m      [31m270.3975[0m  0.0100  973.5391
      2            0.2500        [32m1.4941[0m            0.2500     1353.4364  0.0100  641.8380
      3            0.2535        [32m1.3362[0m            0.2519      554.2991  0.0099  630.1099
      4            0.2581        [32m1.2333[0m            0.2531      621.9261  0.0097  2016.2003
      5            0.2654        [32m1.1677[0m            0.2600      521.8598  0.0095  2009.0304
      6            0.2500        [32m1.1106[0m            0.2496      725.6219  0.0093  1588.7855
      7            0.2500        [32m1.0926[0m            0.2515     1903.9499  0.0090  591.6234
      8            0.2500        [32m1.0529[0m            0.2504     3242.8428  0.0086  590.859

In [14]:
import joblib

In [15]:
file1='Deep_clf_no_batch_norm.sav'

In [16]:
joblib.dump(clf,file1)

['Deep_clf_no_batch_norm.sav']