In [86]:
#!/usr/bin/env python
# coding: utf-8
'''Subject-adaptative classification with KU Data,
using Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping , Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''
import argparse
import json
import logging
import sys
from os.path import join as pjoin

import h5py
import torch
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from torch import nn
import numpy as np

datapath = "dataset-h5-2/KU_mi_smt.h5"
outpath = "model-predictions"
modelpath = "pretrained_models"
scheme = 4
rate = 100
lr = 0.0005
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(0)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200

subjs = [1]

# Get data from single subject.
def get_data(subj):
    dpath = '/s' + str(subj)
    X = dfile["s1/X"]
    Y = dfile["s1/Y"]
    return X[:], Y[:]


X, Y = get_data(subjs[0])
n_classes = 2
in_chans = X.shape[1]
# final_conv_length = auto ensures we only get a single output in the time dimension
model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                 input_time_length=X.shape[2],
                 final_conv_length='auto').cuda()

# Deprecated.


def reset_conv_pool_block(network, block_nr):
    suffix = "_{:d}".format(block_nr)
    conv = getattr(network, 'conv' + suffix)
    kernel_size = conv.kernel_size
    n_filters_before = conv.in_channels
    n_filters = conv.out_channels
    setattr(network, 'conv' + suffix,
            nn.Conv2d(
                n_filters_before,
                n_filters,
                kernel_size,
                stride=(1, 1),
                bias=False,
            ))
    setattr(network, 'bnorm' + suffix,
            nn.BatchNorm2d(
                n_filters,
                momentum=0.1,
                affine=True,
                eps=1e-5,
            ))
    # Initialize the layers.
    conv = getattr(network, 'conv' + suffix)
    bnorm = getattr(network, 'bnorm' + suffix)
    nn.init.xavier_uniform_(conv.weight, gain=1)
    nn.init.constant_(bnorm.weight, 1)
    nn.init.constant_(bnorm.bias, 0)


def reset_model(checkpoint):
    # Load the state dict of the model.
    model.network.load_state_dict(checkpoint['model_state_dict'])

    # # Resets the last conv block
    # reset_conv_pool_block(model.network, block_nr=4)
    # reset_conv_pool_block(model.network, block_nr=3)
    # reset_conv_pool_block(model.network, block_nr=2)
    # # Resets the fully-connected layer.
    # # Parameters of newly constructed modules have requires_grad=True by default.
    # n_final_conv_length = model.network.conv_classifier.kernel_size[0]
    # n_prev_filter = model.network.conv_classifier.in_channels
    # n_classes = model.network.conv_classifier.out_channels
    # model.network.conv_classifier = nn.Conv2d(
    #     n_prev_filter, n_classes, (n_final_conv_length, 1), bias=True)
    # nn.init.xavier_uniform_(model.network.conv_classifier.weight, gain=1)
    # nn.init.constant_(model.network.conv_classifier.bias, 0)

    if scheme != 5:
        # Freeze all layers.
        for param in model.network.parameters():
            param.requires_grad = False

        if scheme in {1, 2, 3, 4}:
            # Unfreeze the FC layer.
            for param in model.network.conv_classifier.parameters():
                param.requires_grad = True

        if scheme in {2, 3, 4}:
            # Unfreeze the conv4 layer.
            for param in model.network.conv_4.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_4.parameters():
                param.requires_grad = True

        if scheme in {3, 4}:
            # Unfreeze the conv3 layer.
            for param in model.network.conv_3.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_3.parameters():
                param.requires_grad = True

        if scheme == 4:
            # Unfreeze the conv2 layer.
            for param in model.network.conv_2.parameters():
                param.requires_grad = True
            for param in model.network.bnorm_2.parameters():
                param.requires_grad = True

    # Only optimize parameters that requires gradient.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.network.parameters()),
                      lr=lr, weight_decay=0.5*0.001)
    model.compile(loss=F.nll_loss, optimizer=optimizer,
                  iterator_seed=20200205, )

cutoff = int(rate * 200 / 100)
# Use only session 1 data for training
assert(cutoff <= 200)

In [87]:
fold = 0
subj = 1
suffix = '_s' + str(subj) + '_f' + str(fold)
checkpoint = torch.load(pjoin(modelpath, 'model_f' + str(fold) + '.pt'),
                        map_location='cuda:' + '0')
print("checkpoint loaded")
reset_model(checkpoint)
print("model reset")
X, Y = get_data(subj)
X_train, Y_train = X[:cutoff], Y[:cutoff]
X_val, Y_val = X[200:300], Y[200:300]
X_test, Y_test = X[300:], Y[300:]
print("data loaded")
model.fit(X_train, Y_train, epochs=TRAIN_EPOCH,
          batch_size=BATCH_SIZE, scheduler='cosine',
          validation_data=(X_val, Y_val), remember_best_column='valid_loss')
print("model fitted")
model.epochs_df.to_csv(pjoin(outpath, 'epochs' + suffix + '.csv')) # crashes here
print("did something")


checkpoint loaded
model reset
data loaded
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.


FileNotFoundError: [Errno 2] No such file or directory: 'model-predictions\\epochs_s1_f0.csv'

In [88]:
test_loss = model.evaluate(X_test, Y_test)
print(test_loss)


.
.
.
.
.
.
{'loss': 0.2673458755016327, 'misclass': 0.07999999999999996, 'runtime': 0.0009851455688476562}


In [89]:
X, Y = get_data(subjs[0])
X_channeled = X[:,(0, 8, 16, 24, 32, 40, 48, 56),:]
print(X_channeled.shape)

X_downscaled = []
for i in range(0, 1000, 4):
    X_downscaled.append(X_channeled[:,:,i])
print(len(X_downscaled))
X_downscaled = np.asarray(X_downscaled)
X_downscaled = np.moveaxis(X_downscaled[:,:,:], [0],[-1])
print(X_downscaled.shape)

X_test, Y_test = X_downscaled[300:], Y[300:]
print(X_test.shape)


(400, 8, 1000)
250
(400, 8, 250)
(100, 8, 250)


In [90]:
test_loss = model.evaluate(X_test, Y_test)
print(test_loss)

.


RuntimeError: Calculated padded input size per channel: (241 x 8). Kernel size: (1 x 62). Kernel size can't be greater than actual input size