In [1]:
# %load import_setup.py
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf

import os
os.chdir("../")

%load_ext autoreload
%autoreload 2

RANDOM_STATE = 42  # for reproducibility
np.random.seed(RANDOM_STATE)

NUM_ELECTRODES = 64
FS = 240          # Hz
NUM_TRAIN_LETTERS = 85
NUM_TEST_LETTERS = 100
NUM_ROWCOLS = 12    # col:1-6 row:7-12
NUM_REPEAT = 15
SECONDS_TO_SLICE = 0.65    # after the simulation 0.65s data is treated as a sample
DATA_DIR = "./data/raw/BCI_Comp_III_Wads_2004/"

In [2]:
# %load data_preparation.py
from src import pipelines

A_train = pipelines.signal_mat_sub_band_norm(DATA_DIR+"Subject_A_Train")
A_test = pipelines.signal_mat_sub_band_norm(DATA_DIR + "Subject_A_Test")

A_train['signal'] = A_train['signal'].reshape([-1,NUM_ELECTRODES,A_train['signal'].shape[-1],1])
A_test['signal'] = A_test['signal'].reshape([-1,NUM_ELECTRODES,A_test['signal'].shape[-1],1]) 

from keras import layers, models
from keras.utils import to_categorical

A_train['label'] = to_categorical(A_train['label']) 

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(A_train['signal'], A_train['label'], test_size=0.05,
                                                  random_state=RANDOM_STATE, stratify=A_train['label'])



Using TensorFlow backend.


In [3]:
from src.models.PAMI import CNN_1_P300_PAMI_BCIIII
model_PAMI = CNN_1_P300_PAMI_BCIIII(Ns = 10,seconds_to_slice=SECONDS_TO_SLICE)
model_PAMI.compile(optimizer = 'sgd',loss = 'binary_crossentropy',metrics = ['acc', 'mse'])

In [4]:
from keras.callbacks import EarlyStopping
earlystopping = EarlyStopping(monitor = "val_mean_squared_error",patience = 10)

In [5]:
model_PAMI.fit(x = X_train, y = y_train, batch_size=32, epochs = 1000, callbacks = [earlystopping], 
              validation_data = [X_val,y_val])

Train on 14535 samples, validate on 765 samples
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000


<keras.callbacks.History at 0x16a7212cf98>

In [7]:
from src.pipelines import test_pipeline,PARADIGM,accuracy
predictions = test_pipeline(A_test['signal'],A_test['code'],model_PAMI,15,PARADIGM)

from src.pipelines import accuracy
accuracy(predictions,'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU')

0.92

In [8]:
model_PAMI_2 = CNN_1_P300_PAMI_BCIIII(Ns = 10,seconds_to_slice=SECONDS_TO_SLICE)
model_PAMI_2.compile(optimizer = 'sgd',loss = 'binary_crossentropy',metrics = ['acc', 'mse'])

earlystopping_2 = EarlyStopping(monitor = "val_mean_squared_error",patience = 3)

model_PAMI_2.fit(x = X_train, y = y_train, batch_size=32, epochs = 1000, callbacks = [earlystopping_2], 
              validation_data = [X_val,y_val])

predictions = test_pipeline(A_test['signal'],A_test['code'],model_PAMI_2,15,PARADIGM)
accuracy(predictions,'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU')

Train on 14535 samples, validate on 765 samples
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000


0.93

In [12]:
model_PAMI_2.history.history

{'acc': [0.82583419335925956,
  0.82903336774536041,
  0.82903336773305814,
  0.82903336774125969,
  0.82903336774125969,
  0.82906776749066646,
  0.83020295837633296,
  0.83130374957820496,
  0.83278293776930423,
  0.83477812180783195,
  0.83543171657907378,
  0.83708290333677327,
  0.83814929483844769,
  0.83952528383053571,
  0.84031647748458327,
  0.84103887168902647,
  0.84375644998120658,
  0.84413484692122465,
  0.84434124527003784,
  0.8458548331613347,
  0.84554523565861872,
  0.84764361888385531,
  0.84733402136063551,
  0.84919160646714831,
  0.84884760921912628,
  0.85018919848641206,
  0.85053319575493791,
  0.85005159959540488],
 'loss': [0.46286677810839866,
  0.4482021279174272,
  0.44180694344448068,
  0.43489483228695935,
  0.426855905699,
  0.41820377556983246,
  0.41035331958064847,
  0.40366161060284045,
  0.39829230537765575,
  0.39400341278464268,
  0.39045997343401256,
  0.38742242677333966,
  0.38433829045271112,
  0.38173253696154269,
  0.3786403203363225,
  0