In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import dlc_bci as bci
from dlc_practical_prologue import *

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

Load dataset (Dont convert to one hot labels if using cross entropy loss)

In [2]:
tr_input, tr_target = bci.load("bci", train=True, one_khz=False)
te_input, te_target = bci.load("bci", train=False, one_khz=False)

tr_target_onehot = convert_to_one_hot_labels(tr_input, tr_target)
te_target_onehot = convert_to_one_hot_labels(te_input, te_target)

Standardize

In [3]:
tr_input = torch.nn.functional.normalize(tr_input, p=2, dim=0) 
te_input = torch.nn.functional.normalize(te_input, p=2, dim=0) 

tr_input, tr_target, tr_target_onehot = Variable(tr_input.view(-1,28*50)), Variable(tr_target), Variable(tr_target_onehot)
te_input, te_target, te_target_onehot = Variable(te_input.view(-1,28*50)), Variable(te_target), Variable(te_target_onehot)

Output of NN is still a vector (cross entropy loss handles log-softmax)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 50, 200)
        self.fc2 = nn.Linear(200, 2)
        
    def forward(self, x, mode=False):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=mode)
        x = self.fc2(x)
        #x = F.softmax(x)
        return x

In [5]:
# construct and train model
model = Net()
tr_loss, te_loss = bci.train_model(model, tr_input, tr_target, 4, te_input, te_target, 4, 50)
#torch.save(model.state_dict(), os.getcwd() + "v1.pth")

# compute train and test errors
nb_tr_errors = bci.compute_nb_errors(model, tr_input, tr_target_onehot, 4)
nb_te_errors = bci.compute_nb_errors(model, te_input, te_target_onehot, 4)

print('tr error {:0.2f}% {:d}/{:d}'.format((100 * nb_tr_errors) / tr_input.size(0), nb_tr_errors, tr_input.size(0)))
print('te error {:0.2f}% {:d}/{:d}'.format((100 * nb_te_errors) / te_input.size(0), nb_te_errors, te_input.size(0)))

epoch 0 tr loss 55.19 te loss 17.36
epoch 1 tr loss 54.94 te loss 17.37
epoch 2 tr loss 54.74 te loss 17.47
epoch 3 tr loss 54.62 te loss 17.52
epoch 4 tr loss 54.65 te loss 17.52
epoch 5 tr loss 54.23 te loss 17.65
epoch 6 tr loss 54.55 te loss 17.79
epoch 7 tr loss 53.84 te loss 18.08
epoch 8 tr loss 53.55 te loss 18.16
epoch 9 tr loss 53.33 te loss 18.83
epoch 10 tr loss 53.43 te loss 19.14
epoch 11 tr loss 53.11 te loss 19.94
epoch 12 tr loss 53.16 te loss 20.03
epoch 13 tr loss 52.46 te loss 20.60
epoch 14 tr loss 53.26 te loss 20.16
epoch 15 tr loss 51.62 te loss 21.80
epoch 16 tr loss 52.16 te loss 21.69
epoch 17 tr loss 52.26 te loss 22.01
epoch 18 tr loss 51.41 te loss 23.27
epoch 19 tr loss 50.78 te loss 23.17
epoch 20 tr loss 51.78 te loss 24.47
epoch 21 tr loss 50.84 te loss 23.84
epoch 22 tr loss 50.09 te loss 26.10
epoch 23 tr loss 50.80 te loss 24.70
epoch 24 tr loss 49.99 te loss 25.42
epoch 25 tr loss 49.39 te loss 26.01
epoch 26 tr loss 48.90 te loss 28.50
epoch 27 tr

In [None]:
import matplotlib as mpl
mpl.rcParams['figure.dpi']= 150

plt.figure()
plt.plot(tr_loss, label='training loss')
plt.plot(te_loss, label='validation loss')
plt.legend(loc='upper left')