In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import dlc_bci as bci
from dlc_practical_prologue import *

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

Load dataset

In [2]:
tr_input, tr_target = bci.load("bci", train=True, one_khz=True)
te_input, te_target = bci.load("bci", train=False, one_khz=True)

# convert to one hot labels
tr_target = convert_to_one_hot_labels(tr_input, tr_target)
te_target = convert_to_one_hot_labels(te_input, te_target)

#b = np.zeros((np.shape(tr_target.numpy())[0], np.amax(tr_target.numpy()) + 1))
#b[np.arange(np.shape(tr_target.numpy())[0]), tr_target] = 1
#tr_target = torch.from_numpy(b)
#b = np.zeros((np.shape(te_target.numpy())[0], np.amax(te_target.numpy()) + 1))
#b[np.arange(np.shape(te_target.numpy())[0]), te_target] = 1
#te_target = torch.from_numpy(b)

# normalize
#tr_input = torch.nn.functional.normalize(tr_input, p=2, dim=0) 
#te_input = torch.nn.functional.normalize(te_input, p=2, dim=0) 

# cnn expects 4D tensor
tr_input = tr_input[:, np.newaxis, :, :]
te_input = te_input[:, np.newaxis, :, :]

# convert to pytorch variable
tr_input, tr_target = Variable(tr_input), Variable(tr_target)
te_input, te_target = Variable(te_input), Variable(te_target)

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1a = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=(28, 3), stride=3)
        self.conv1b = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=(28, 5), stride=3)
        self.conv1c = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=(28, 7), stride=3)
        self.conv1d = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=(28, 9), stride=3)
        
        self.fc1 = nn.Linear(400, 2)
        
    def forward(self, x, mode=False):
        
        # convolutions
        x1a = self.conv1a(x)
        x1b = self.conv1b(x)
        x1c = self.conv1c(x)
        x1d = self.conv1d(x)
        
        # max poolings
        x2a = nn.functional.max_pool2d(x1a, kernel_size=(1, 166))
        x2b = nn.functional.max_pool2d(x1b, kernel_size=(1, 166))
        x2c = nn.functional.max_pool2d(x1c, kernel_size=(1, 165))
        x2d = nn.functional.max_pool2d(x1d, kernel_size=(1, 164))
               
        # concatenations
        x = torch.cat((x2a, x2b, x2c, x2d), dim = 1)
        x = x.view(4,-1)
    
        # dropout to fully connected layers
        x = nn.functional.dropout(x, p=0.5, training=mode)
        x = self.fc1(x)
        x = nn.functional.softmax(x, dim=1)
        return x

In [4]:
# construct and train model
model = Net()
tr_loss, te_loss = bci.train_model(model, tr_input, tr_target, 4, te_input, te_target, 4, 2)

# save model and losses
torch.save(model.state_dict(), os.getcwd() + "v2.pth")

# compute train and test errors
nb_tr_errors = bci.compute_nb_errors(model, tr_input, tr_target, 4)
nb_te_errors = bci.compute_nb_errors(model, te_input, te_target, 4)

print('tr error {:0.2f}% {:d}/{:d}'.format((100 * nb_tr_errors) / tr_input.size(0), nb_tr_errors, tr_input.size(0)))
print('te error {:0.2f}% {:d}/{:d}'.format((100 * nb_te_errors) / te_input.size(0), nb_te_errors, te_input.size(0)))

epoch 0 tr loss 117.56 te loss 38.00
epoch 1 tr loss 117.50 te loss 38.00
epoch 2 tr loss 117.50 te loss 38.00
epoch 3 tr loss 118.50 te loss 38.00
epoch 4 tr loss 118.00 te loss 38.00
epoch 5 tr loss 118.50 te loss 38.00
epoch 6 tr loss 118.00 te loss 38.00
epoch 7 tr loss 119.00 te loss 38.00
epoch 8 tr loss 118.50 te loss 38.00
epoch 9 tr loss 118.00 te loss 38.00
epoch 10 tr loss 117.50 te loss 38.00
epoch 11 tr loss 118.50 te loss 38.00
epoch 12 tr loss 118.50 te loss 38.00
epoch 13 tr loss 118.50 te loss 38.00
epoch 14 tr loss 118.00 te loss 38.00
epoch 15 tr loss 118.50 te loss 36.00
epoch 16 tr loss 117.50 te loss 36.00
epoch 17 tr loss 117.50 te loss 36.00
epoch 18 tr loss 117.00 te loss 36.00
epoch 19 tr loss 117.50 te loss 36.00
epoch 20 tr loss 116.50 te loss 36.00
epoch 21 tr loss 117.00 te loss 36.00
epoch 22 tr loss 116.00 te loss 36.00
epoch 23 tr loss 117.00 te loss 36.00
epoch 24 tr loss 117.00 te loss 36.00
tr error 48.42% 153/316
te error 47.00% 47/100


In [None]:
import matplotlib as mpl
mpl.rcParams['figure.dpi']= 150

plt.figure()
plt.plot(tr_loss, label='training loss')
plt.plot(te_loss, label='validation loss')
plt.legend(loc='upper left')