# Data

## Download dataset

In [1]:
from datautils import downloader

path  = 'data'
dpath = downloader(path, subjects=list(range(1,10)))

Subjects [1 2 3 4 5 6 7 8 9] are now in data


## Extracting .mat files

In [2]:
from datautils import mat_extractor

tr_name = 'A01T.mat'
te_name = 'A01E.mat'
x_train, y_train = mat_extractor(path=dpath/tr_name)
x_test , y_test  = mat_extractor(path=dpath/te_name)

print (f'*** Shapes ***\nx_train:\t{x_train.shape}\ny_train:\t{y_train.shape}')
print (f'x_test:\t\t{x_test.shape}\ny_test:\t\t{y_test.shape}')

*** Shapes ***
x_train:	(288, 22, 1000)
y_train:	(288,)
x_test:		(288, 22, 1000)
y_test:		(288,)


## cropping for augmentation

In [3]:
from datautils import cropper

x_train, y_train = cropper(x_train, y_train, window=500, step=500)
x_test , y_test  = cropper(x_test , y_test, window=500, step=500)

print (f'*** Shapes ***\nx_train:\t{x_train.shape}\ny_train:\t{y_train.shape}')
print (f'x_test:\t\t{x_test.shape}\ny_test:\t\t{y_test.shape}')

*** Shapes ***
x_train:	(576, 22, 500)
y_train:	(576,)
x_test:		(576, 22, 500)
y_test:		(576,)


## Split for validation

In [14]:
from sklearn.model_selection import train_test_split

x_tr, x_va, y_tr, y_va = train_test_split(x_train, y_train-1, test_size=0.25, random_state=216, shuffle=True)

print (x_tr.shape, x_va.shape)

(432, 22, 500) (144, 22, 500)


# EEGNet

In [15]:
import numpy as np
import torch

# adding a second dimension, becasue we have Conv2d in our structure
x_tr = np.expand_dims(x_tr, axis=1)
x_va = np.expand_dims(x_va, axis=1)

x_tr, x_va, y_tr, y_va = map(torch.tensor, [x_tr, x_va, y_tr, y_va])
print (x_tr.shape, x_va.shape)

torch.Size([432, 1, 22, 500]) torch.Size([144, 1, 22, 500])


In [27]:
from models import EEGNet
from fitting import train

if torch.cuda.is_available():
    target_device = 'cuda'
else:
    target_device = 'cpu'
    
model = EEGNet().to(target_device)
hist = train(model, x_tr, y_tr, x_va, y_va, batch_size=144, epochs=100, learning_rate=0.001, period=10)

*** Epoch: 1 ***
Train Loss: 1.3886 --- Train Acc 25.23
Valid Loss: 1.3861 --- Valid Acc: 29.17
*** Epoch: 10 ***
Train Loss: 1.2209 --- Train Acc 53.01
Valid Loss: 1.2801 --- Valid Acc: 42.36
*** Epoch: 20 ***
Train Loss: 1.0246 --- Train Acc 54.17
Valid Loss: 1.1013 --- Valid Acc: 45.14
*** Epoch: 30 ***
Train Loss: 0.9436 --- Train Acc 58.80
Valid Loss: 1.0585 --- Valid Acc: 46.53
*** Epoch: 40 ***
Train Loss: 0.8861 --- Train Acc 65.05
Valid Loss: 1.0323 --- Valid Acc: 49.31
*** Epoch: 50 ***
Train Loss: 0.8432 --- Train Acc 66.90
Valid Loss: 0.9991 --- Valid Acc: 52.08
*** Epoch: 60 ***
Train Loss: 0.8035 --- Train Acc 69.68
Valid Loss: 0.9640 --- Valid Acc: 54.17
*** Epoch: 70 ***
Train Loss: 0.7710 --- Train Acc 72.92
Valid Loss: 0.9310 --- Valid Acc: 56.25
*** Epoch: 80 ***
Train Loss: 0.7238 --- Train Acc 73.15
Valid Loss: 0.9075 --- Valid Acc: 59.72
*** Epoch: 90 ***
Train Loss: 0.7151 --- Train Acc 71.53
Valid Loss: 0.9020 --- Valid Acc: 60.42
*** Epoch: 100 ***
Train Loss: 