In [None]:
from common import Trial, safe_log, nll, channel_map, load_df
import numpy as np
import os
import pandas as pd
import sys
from sklearn.model_selection import train_test_split
import keras
from keras.layers import Dense, Flatten, BatchNormalization, Dropout, Lambda
from keras.layers import Conv2D, AveragePooling2D
from keras.models import Sequential
import matplotlib.pylab as plt
import tensorflow as tf
from keras.callbacks import TensorBoard
import logging
import sys
import IPython
import pickle


# Send log records to stdout with a timestamped format at INFO level.
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
# `log` is the root logger; it is used by later cells for all messages.
log = logging.getLogger()
# Set the level explicitly as well: basicConfig() is a no-op when the root
# logger already has handlers installed.
log.setLevel(logging.INFO)


# TensorFlow 1.x session configuration shared with Keras:
# limit this process to a fraction (0.2) of GPU memory and let the
# allocation grow on demand instead of being reserved up front.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
config.gpu_options.allow_growth = True
# Must run before any Keras layer is built so the backend uses this session.
keras.backend.tensorflow_backend.set_session(tf.Session(config=config))

In [None]:
# Load the dataset through the project helper imported from `common`.
# Later cells read its 'eeg' and 'subject_class' columns.
df = load_df()

In [None]:
# Size of the DataFrame object in bytes, as reported by sys.getsizeof.
sys.getsizeof(df)

In [None]:
# Preview the first rows of the loaded frame.
df.head()

In [None]:
# Summary statistics of the frame's columns.
df.describe()

In [None]:
# Log how many rows belong to each class label (label 1 first, then 0)
# to check the class balance before training.
for class_label in (1, 0):
    log.info(len(df.loc[df['subject_class'] == class_label]))


In [None]:
# Features: stack the per-row 'eeg' arrays along a new last axis, move that
# axis to the front so the first dimension indexes samples, then append the
# trailing single-channel axis Keras' Conv2D expects.
X = np.rollaxis(np.dstack(df['eeg'].values), -1)
X = X.reshape((len(X), 64, 256, 1))

# Labels: one-hot encode the binary class column for categorical crossentropy.
y = keras.utils.to_categorical(df['subject_class'].values, num_classes=2)

In [None]:
# Sanity check: after the reshape this is (n_samples, 64, 256, 1).
X.shape

In [None]:
# Hold out 10% of the samples. The seed is fixed so that the partition, and
# therefore every metric computed from it below, is identical after a
# Restart & Run All; without it each run shuffles differently.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=SPLIT_SEED)

In [None]:
# Layer printout of the reference network this Keras model is adapted from.
# The reference is laid out time-first; the Keras model below uses a
# (64, 256, 1) input, so its kernel sizes are transposed.
"""
Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (conv_spat): Conv2d(40, 40, kernel_size=(1, 64), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=square)
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
  (pool_nonlin): Expression(expression=safe_log)
  (drop): Dropout(p=0.5)
  (conv_classifier): Conv2d(40, 2, kernel_size=(30, 1), stride=(1, 1), dilation=(15, 1))
  (softmax): LogSoftmax()
  (squeeze): Expression(expression=_squeeze_final_output)
)
"""

# Hyper-parameters used by this cell and by the training cell.
input_shape = (64, 256, 1)
num_classes = 2
batch_size = 128
epochs = 150

# Every convolution / pooling stage is followed by batch normalisation;
# the reference's dropout stage is left disabled.
model = Sequential([
    # Convolution along the second (length-256) axis.
    Conv2D(40, kernel_size=(1, 25), input_shape=input_shape),
    BatchNormalization(momentum=0.1),
    # Convolution spanning all 64 rows, collapsing the first axis to size 1.
    Conv2D(40, kernel_size=(64, 1)),
    BatchNormalization(momentum=0.1),
    # square -> average pool -> safe_log, as in the reference printout.
    Lambda(lambda x: x ** 2),
    AveragePooling2D(pool_size=(1, 75), strides=(1, 1)),
    Lambda(lambda x: safe_log(x)),
    BatchNormalization(momentum=0.1),
    # Dropout(0.5) intentionally not applied here.
    # NOTE(review): dilation_rate=(15, 1) applies to the kernel axis of size 1,
    # so it appears to have no effect on this (1, 30) kernel - confirm whether
    # (1, 15) was intended before changing (it would alter the output shape).
    Conv2D(2, kernel_size=(1, 30), dilation_rate=(15, 1)),
    BatchNormalization(momentum=0.1),
    # Classification head.
    Flatten(),
    Dense(num_classes, activation='softmax'),
])


In [None]:
# Configure training: categorical crossentropy on the one-hot labels,
# Adadelta with its default settings, and accuracy as the tracked metric.
model.compile(
    optimizer=keras.optimizers.Adadelta(),
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

# Print the layer-by-layer output shapes and parameter counts.
model.summary()

In [None]:
class AccuracyHistory(keras.callbacks.Callback):
    """Keras callback that records the validation accuracy after each epoch.

    The collected values are available as ``self.acc`` (one entry per epoch)
    and are plotted by the training cell once ``fit`` has finished.
    """

    # The log key depends on the Keras release: older versions report
    # 'val_acc', newer ones name it after the metric string passed to
    # compile() ('accuracy' -> 'val_accuracy'). Accept either.
    _VAL_ACC_KEYS = ('val_acc', 'val_accuracy')

    def on_train_begin(self, logs=None):
        """Reset the history at the start of every ``fit`` call."""
        self.acc = []

    def on_epoch_end(self, batch, logs=None):
        """Append this epoch's validation accuracy (``None`` if not reported).

        ``batch`` keeps its original name for compatibility; Keras passes the
        epoch index in this position.
        """
        # `logs=None` instead of a shared mutable `{}` default.
        logs = logs or {}
        value = next(
            (logs[key] for key in self._VAL_ACC_KEYS
             if logs.get(key) is not None),
            None)
        self.acc.append(value)


history = AccuracyHistory()



In [None]:
# Write training curves for TensorBoard under ./logs.
tensor_board = TensorBoard('./logs/baseline_shallow_batch_normalization')

# Train the model. NOTE(review): the held-out split is used both as
# validation data here and for the final evaluate() below, so the reported
# "test" score comes from data that was also monitored during training.
model.fit(X_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(X_test, y_test),
          callbacks=[history, tensor_board])


# evaluate() returns [loss, accuracy] for the loss/metrics given to compile().
score = model.evaluate(X_test, y_test, verbose=0)
# logging applies %-formatting to its extra arguments, so the message needs a
# placeholder; without one the record fails to format and the value is lost.
log.info('Test loss: %s', score[0])
log.info('Test accuracy: %s', score[1])

# Per-epoch validation accuracy collected by the `history` callback.
plt.plot(range(1, len(history.acc) + 1), history.acc)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.show()

In [None]:
# Optional audible "training finished" notification.
# The path is machine-specific, so only play it when the file actually exists;
# otherwise the cell would raise and stop a Restart & Run All on other machines.
done_sound_path = "F:\\Tresorit\\01 - Startup Screen.mp3"
if os.path.isfile(done_sound_path):
    IPython.display.display(
        IPython.display.Audio(filename=done_sound_path, autoplay=True))
else:
    log.info('Notification sound not found, skipping: %s', done_sound_path)


In [None]:
# Snapshot the trained weights and persist them with IPython's %store magic
# so another kernel session can restore them with `%store -r weights`.
weights = model.get_weights()

%store weights

In [None]:

# Persist the architecture (as JSON) together with the trained weights in a
# single pickle file, stored as a (json_string, weights_list) tuple.
with open('json_weights_bn.pkl', 'wb') as pkl_file:
    architecture_json = model.to_json()
    layer_weights = model.get_weights()
    pickle.dump((architecture_json, layer_weights), pkl_file)

In [None]:
# Loss and accuracy on the training split, for comparison with the
# held-out score logged after training.
model.evaluate(x=X_train, y=y_train)

In [None]:
# The original cell was a dangling `model.get` attribute access (apparently an
# unfinished tab-completion), which raises AttributeError on a Keras model.
# NOTE(review): replaced with a valid inspection call - confirm the intended one.
model.get_config()