In [1]:
import pickle
import numpy as np

participantCount = 32

rawData = [{}]*participantCount

for i in range(participantCount):
    rawData[i] = pickle.load(open(f'../data/s{i+1:02}.dat', 'rb'), encoding="latin1")

labels = np.array(list(map(lambda participant : participant['labels'], rawData)))
data = np.array(list(map(lambda participant : participant['data'], rawData)))

def get_y(emotion):
    return (labels.reshape(-1, 4)[:,emotion] >= 5).astype(int)

def get_eeg_x():
    return data[:, :, :32, :].reshape(-1, 32, 8064)

In [2]:
from keras import backend as K

def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

In [3]:
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.model_selection import KFold
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM

import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)


Y = get_y(0)
X = np.moveaxis(get_eeg_x(), 1, 2)

for i in [0, 1, 3]:
    print("EMOTION NUMBER", i)
    Y = get_y(i)
    kfold = KFold(n_splits=5, shuffle=False)
    cvscores = []
    for train, test in kfold.split(X, Y):
        model = Sequential()
        model.add(LSTM(20, input_dim=32))
        model.add(Dense(1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc',f1_m,precision_m, recall_m])

        model.fit(X[train], Y[train], validation_data=(X[test], Y[test]), epochs=20, batch_size=120, verbose=2, shuffle=False)

        scores = model.evaluate(X[test], Y[test], verbose=0)
        print("%s: %.2f%%, %s %.2f%%" % (model.metrics_names[1], scores[1]*100, model.metrics_names[2], scores[2]*100))
        cvscores.append(scores[1] * 100)

    print("%.2f%% (+/- %.2f%%)" % (np.mean(cvscores), np.std(cvscores)))    
    print("EMOTION", i)

EMOTION NUMBER 0
Epoch 1/20
9/9 - 3s - loss: 0.7416 - acc: 0.4795 - f1_m: 0.5142 - precision_m: 0.5261 - recall_m: 0.5133 - val_loss: 0.7556 - val_acc: 0.4883 - val_f1_m: 0.5830 - val_precision_m: 0.6898 - val_recall_m: 0.5328
Epoch 2/20
9/9 - 3s - loss: 0.7056 - acc: 0.5312 - f1_m: 0.5663 - precision_m: 0.5745 - recall_m: 0.5667 - val_loss: 0.7525 - val_acc: 0.4961 - val_f1_m: 0.5996 - val_precision_m: 0.6930 - val_recall_m: 0.5630
Epoch 3/20
9/9 - 2s - loss: 0.6848 - acc: 0.5615 - f1_m: 0.5957 - precision_m: 0.6006 - recall_m: 0.5999 - val_loss: 0.7522 - val_acc: 0.4922 - val_f1_m: 0.6007 - val_precision_m: 0.6904 - val_recall_m: 0.5677
Epoch 4/20
9/9 - 2s - loss: 0.6688 - acc: 0.5811 - f1_m: 0.6162 - precision_m: 0.6169 - recall_m: 0.6231 - val_loss: 0.7528 - val_acc: 0.4883 - val_f1_m: 0.5987 - val_precision_m: 0.6876 - val_recall_m: 0.5683
Epoch 5/20
9/9 - 2s - loss: 0.6554 - acc: 0.6064 - f1_m: 0.6400 - precision_m: 0.6427 - recall_m: 0.6451 - val_loss: 0.7535 - val_acc: 0.4883 -

Epoch 20/20
9/9 - 2s - loss: 0.5460 - acc: 0.7793 - f1_m: 0.8172 - precision_m: 0.7503 - recall_m: 0.8999 - val_loss: 0.7264 - val_acc: 0.5312 - val_f1_m: 0.5207 - val_precision_m: 0.5066 - val_recall_m: 0.5395
acc: 53.12%, f1_m 56.13%
Epoch 1/20
9/9 - 3s - loss: 0.7368 - acc: 0.4814 - f1_m: 0.4709 - precision_m: 0.5230 - recall_m: 0.4414 - val_loss: 0.7469 - val_acc: 0.4648 - val_f1_m: 0.4392 - val_precision_m: 0.5524 - val_recall_m: 0.3689
Epoch 2/20
9/9 - 2s - loss: 0.7084 - acc: 0.5186 - f1_m: 0.5200 - precision_m: 0.5638 - recall_m: 0.4947 - val_loss: 0.7456 - val_acc: 0.4766 - val_f1_m: 0.4388 - val_precision_m: 0.5182 - val_recall_m: 0.3830
Epoch 3/20
9/9 - 2s - loss: 0.6915 - acc: 0.5439 - f1_m: 0.5449 - precision_m: 0.5878 - recall_m: 0.5208 - val_loss: 0.7449 - val_acc: 0.4805 - val_f1_m: 0.4512 - val_precision_m: 0.5577 - val_recall_m: 0.3860
Epoch 4/20
9/9 - 2s - loss: 0.6785 - acc: 0.5596 - f1_m: 0.5680 - precision_m: 0.6039 - recall_m: 0.5495 - val_loss: 0.7433 - val_acc:

Epoch 19/20
9/9 - 2s - loss: 0.5282 - acc: 0.7793 - f1_m: 0.8129 - precision_m: 0.7845 - recall_m: 0.8499 - val_loss: 0.7498 - val_acc: 0.4531 - val_f1_m: 0.5278 - val_precision_m: 0.5989 - val_recall_m: 0.5330
Epoch 20/20
9/9 - 2s - loss: 0.5215 - acc: 0.7793 - f1_m: 0.8118 - precision_m: 0.7866 - recall_m: 0.8450 - val_loss: 0.7522 - val_acc: 0.4492 - val_f1_m: 0.5245 - val_precision_m: 0.5968 - val_recall_m: 0.5276
acc: 44.92%, f1_m 51.95%
Epoch 1/20
9/9 - 3s - loss: 0.7406 - acc: 0.4951 - f1_m: 0.5025 - precision_m: 0.5390 - recall_m: 0.4862 - val_loss: 0.7403 - val_acc: 0.4922 - val_f1_m: 0.4040 - val_precision_m: 0.4560 - val_recall_m: 0.3883
Epoch 2/20
9/9 - 2s - loss: 0.7027 - acc: 0.5371 - f1_m: 0.5525 - precision_m: 0.5823 - recall_m: 0.5443 - val_loss: 0.7339 - val_acc: 0.5000 - val_f1_m: 0.4136 - val_precision_m: 0.4581 - val_recall_m: 0.4016
Epoch 3/20
9/9 - 2s - loss: 0.6825 - acc: 0.5645 - f1_m: 0.5880 - precision_m: 0.6071 - recall_m: 0.5889 - val_loss: 0.7287 - val_acc

Epoch 18/20
9/9 - 2s - loss: 0.5502 - acc: 0.7510 - f1_m: 0.8103 - precision_m: 0.7549 - recall_m: 0.8789 - val_loss: 0.7523 - val_acc: 0.4844 - val_f1_m: 0.5631 - val_precision_m: 0.5450 - val_recall_m: 0.5905
Epoch 19/20
9/9 - 2s - loss: 0.5439 - acc: 0.7559 - f1_m: 0.8149 - precision_m: 0.7602 - recall_m: 0.8831 - val_loss: 0.7569 - val_acc: 0.4844 - val_f1_m: 0.5630 - val_precision_m: 0.5449 - val_recall_m: 0.5905
Epoch 20/20
9/9 - 2s - loss: 0.5376 - acc: 0.7676 - f1_m: 0.8247 - precision_m: 0.7686 - recall_m: 0.8932 - val_loss: 0.7610 - val_acc: 0.4766 - val_f1_m: 0.5604 - val_precision_m: 0.5408 - val_recall_m: 0.5905
acc: 47.66%, f1_m 47.88%
Epoch 1/20
9/9 - 3s - loss: 0.7472 - acc: 0.5059 - f1_m: 0.5160 - precision_m: 0.5790 - recall_m: 0.4795 - val_loss: 0.7400 - val_acc: 0.5508 - val_f1_m: 0.6871 - val_precision_m: 0.7210 - val_recall_m: 0.6668
Epoch 2/20
9/9 - 2s - loss: 0.7112 - acc: 0.5459 - f1_m: 0.5655 - precision_m: 0.6119 - recall_m: 0.5384 - val_loss: 0.7394 - val_ac

Epoch 17/20
9/9 - 3s - loss: 0.5372 - acc: 0.7461 - f1_m: 0.7818 - precision_m: 0.7453 - recall_m: 0.8231 - val_loss: 0.7318 - val_acc: 0.5078 - val_f1_m: 0.6124 - val_precision_m: 0.6547 - val_recall_m: 0.5843
Epoch 18/20
9/9 - 3s - loss: 0.5308 - acc: 0.7568 - f1_m: 0.7909 - precision_m: 0.7515 - recall_m: 0.8359 - val_loss: 0.7316 - val_acc: 0.5273 - val_f1_m: 0.6285 - val_precision_m: 0.6639 - val_recall_m: 0.6077
Epoch 19/20
9/9 - 2s - loss: 0.5221 - acc: 0.7656 - f1_m: 0.7989 - precision_m: 0.7583 - recall_m: 0.8447 - val_loss: 0.7326 - val_acc: 0.5273 - val_f1_m: 0.6264 - val_precision_m: 0.6643 - val_recall_m: 0.6029
Epoch 20/20
9/9 - 3s - loss: 0.5141 - acc: 0.7715 - f1_m: 0.8049 - precision_m: 0.7642 - recall_m: 0.8511 - val_loss: 0.7337 - val_acc: 0.5273 - val_f1_m: 0.6300 - val_precision_m: 0.6627 - val_recall_m: 0.6123
acc: 52.73%, f1_m 61.23%
Epoch 1/20
9/9 - 3s - loss: 0.7330 - acc: 0.4971 - f1_m: 0.5359 - precision_m: 0.5496 - recall_m: 0.5393 - val_loss: 0.7013 - val_a

Epoch 16/20
9/9 - 3s - loss: 0.5393 - acc: 0.7695 - f1_m: 0.8153 - precision_m: 0.7670 - recall_m: 0.8776 - val_loss: 0.7114 - val_acc: 0.4961 - val_f1_m: 0.6151 - val_precision_m: 0.6050 - val_recall_m: 0.6381
Epoch 17/20
9/9 - 3s - loss: 0.5317 - acc: 0.7744 - f1_m: 0.8179 - precision_m: 0.7701 - recall_m: 0.8792 - val_loss: 0.7134 - val_acc: 0.5039 - val_f1_m: 0.6197 - val_precision_m: 0.6091 - val_recall_m: 0.6434
Epoch 18/20
9/9 - 2s - loss: 0.5239 - acc: 0.7861 - f1_m: 0.8283 - precision_m: 0.7793 - recall_m: 0.8899 - val_loss: 0.7161 - val_acc: 0.5078 - val_f1_m: 0.6229 - val_precision_m: 0.6110 - val_recall_m: 0.6487
Epoch 19/20
9/9 - 3s - loss: 0.5159 - acc: 0.7949 - f1_m: 0.8350 - precision_m: 0.7831 - recall_m: 0.8995 - val_loss: 0.7194 - val_acc: 0.5117 - val_f1_m: 0.6243 - val_precision_m: 0.6135 - val_recall_m: 0.6487
Epoch 20/20
9/9 - 3s - loss: 0.5089 - acc: 0.7930 - f1_m: 0.8322 - precision_m: 0.7864 - recall_m: 0.8911 - val_loss: 0.7223 - val_acc: 0.5039 - val_f1_m: 0

Epoch 15/20
9/9 - 2s - loss: 0.5343 - acc: 0.7461 - f1_m: 0.8125 - precision_m: 0.7428 - recall_m: 0.9018 - val_loss: 0.7711 - val_acc: 0.4727 - val_f1_m: 0.3748 - val_precision_m: 0.3551 - val_recall_m: 0.4039
Epoch 16/20
9/9 - 3s - loss: 0.5268 - acc: 0.7539 - f1_m: 0.8189 - precision_m: 0.7504 - recall_m: 0.9065 - val_loss: 0.7746 - val_acc: 0.4766 - val_f1_m: 0.3761 - val_precision_m: 0.3572 - val_recall_m: 0.4039
Epoch 17/20
9/9 - 3s - loss: 0.5193 - acc: 0.7637 - f1_m: 0.8260 - precision_m: 0.7592 - recall_m: 0.9115 - val_loss: 0.7783 - val_acc: 0.4688 - val_f1_m: 0.3713 - val_precision_m: 0.3521 - val_recall_m: 0.3992
Epoch 18/20
9/9 - 2s - loss: 0.5113 - acc: 0.7725 - f1_m: 0.8304 - precision_m: 0.7672 - recall_m: 0.9107 - val_loss: 0.7823 - val_acc: 0.4688 - val_f1_m: 0.3715 - val_precision_m: 0.3517 - val_recall_m: 0.3992
Epoch 19/20
9/9 - 2s - loss: 0.5037 - acc: 0.7773 - f1_m: 0.8344 - precision_m: 0.7703 - recall_m: 0.9159 - val_loss: 0.7859 - val_acc: 0.4688 - val_f1_m: 0

Epoch 14/20
9/9 - 2s - loss: 0.5316 - acc: 0.7676 - f1_m: 0.8152 - precision_m: 0.7614 - recall_m: 0.8801 - val_loss: 0.7891 - val_acc: 0.4805 - val_f1_m: 0.6163 - val_precision_m: 0.7110 - val_recall_m: 0.5886
Epoch 15/20
9/9 - 2s - loss: 0.5233 - acc: 0.7744 - f1_m: 0.8207 - precision_m: 0.7661 - recall_m: 0.8867 - val_loss: 0.7931 - val_acc: 0.4727 - val_f1_m: 0.6137 - val_precision_m: 0.7048 - val_recall_m: 0.5886
Epoch 16/20
9/9 - 3s - loss: 0.5143 - acc: 0.7900 - f1_m: 0.8333 - precision_m: 0.7806 - recall_m: 0.8960 - val_loss: 0.7963 - val_acc: 0.4688 - val_f1_m: 0.6109 - val_precision_m: 0.7032 - val_recall_m: 0.5836
Epoch 17/20
9/9 - 3s - loss: 0.5059 - acc: 0.7998 - f1_m: 0.8394 - precision_m: 0.7903 - recall_m: 0.8970 - val_loss: 0.7996 - val_acc: 0.4688 - val_f1_m: 0.6272 - val_precision_m: 0.7086 - val_recall_m: 0.6069
Epoch 18/20
9/9 - 2s - loss: 0.4978 - acc: 0.8086 - f1_m: 0.8464 - precision_m: 0.7976 - recall_m: 0.9034 - val_loss: 0.8052 - val_acc: 0.4531 - val_f1_m: 0

Epoch 13/20
9/9 - 3s - loss: 0.5516 - acc: 0.7363 - f1_m: 0.8136 - precision_m: 0.7424 - recall_m: 0.9058 - val_loss: 0.6435 - val_acc: 0.6367 - val_f1_m: 0.7896 - val_precision_m: 0.8043 - val_recall_m: 0.7789
Epoch 14/20
9/9 - 3s - loss: 0.5443 - acc: 0.7354 - f1_m: 0.8149 - precision_m: 0.7406 - recall_m: 0.9119 - val_loss: 0.6427 - val_acc: 0.6406 - val_f1_m: 0.7928 - val_precision_m: 0.8061 - val_recall_m: 0.7835
Epoch 15/20
9/9 - 2s - loss: 0.5371 - acc: 0.7471 - f1_m: 0.8230 - precision_m: 0.7499 - recall_m: 0.9176 - val_loss: 0.6420 - val_acc: 0.6445 - val_f1_m: 0.7959 - val_precision_m: 0.8078 - val_recall_m: 0.7881
Epoch 16/20
9/9 - 3s - loss: 0.5297 - acc: 0.7500 - f1_m: 0.8249 - precision_m: 0.7521 - recall_m: 0.9188 - val_loss: 0.6413 - val_acc: 0.6484 - val_f1_m: 0.7990 - val_precision_m: 0.8095 - val_recall_m: 0.7926
Epoch 17/20
9/9 - 3s - loss: 0.5223 - acc: 0.7627 - f1_m: 0.8347 - precision_m: 0.7600 - recall_m: 0.9313 - val_loss: 0.6411 - val_acc: 0.6484 - val_f1_m: 0

Epoch 12/20
9/9 - 3s - loss: 0.5220 - acc: 0.7617 - f1_m: 0.8320 - precision_m: 0.7696 - recall_m: 0.9106 - val_loss: 0.6932 - val_acc: 0.6016 - val_f1_m: 0.7362 - val_precision_m: 0.6559 - val_recall_m: 0.8463
Epoch 13/20
9/9 - 3s - loss: 0.5130 - acc: 0.7725 - f1_m: 0.8389 - precision_m: 0.7758 - recall_m: 0.9179 - val_loss: 0.6952 - val_acc: 0.6016 - val_f1_m: 0.7362 - val_precision_m: 0.6559 - val_recall_m: 0.8463
Epoch 14/20
9/9 - 3s - loss: 0.5042 - acc: 0.7744 - f1_m: 0.8408 - precision_m: 0.7763 - recall_m: 0.9211 - val_loss: 0.6975 - val_acc: 0.6016 - val_f1_m: 0.7364 - val_precision_m: 0.6562 - val_recall_m: 0.8468
Epoch 15/20
9/9 - 3s - loss: 0.4960 - acc: 0.7842 - f1_m: 0.8481 - precision_m: 0.7806 - recall_m: 0.9325 - val_loss: 0.6995 - val_acc: 0.5938 - val_f1_m: 0.7222 - val_precision_m: 0.6365 - val_recall_m: 0.8468
Epoch 16/20
9/9 - 3s - loss: 0.4873 - acc: 0.7910 - f1_m: 0.8538 - precision_m: 0.7864 - recall_m: 0.9379 - val_loss: 0.7018 - val_acc: 0.6016 - val_f1_m: 0

Epoch 11/20
9/9 - 3s - loss: 0.5929 - acc: 0.7021 - f1_m: 0.7909 - precision_m: 0.7351 - recall_m: 0.8640 - val_loss: 0.6815 - val_acc: 0.5547 - val_f1_m: 0.6539 - val_precision_m: 0.5606 - val_recall_m: 0.8373
Epoch 12/20
9/9 - 3s - loss: 0.5839 - acc: 0.7109 - f1_m: 0.7983 - precision_m: 0.7357 - recall_m: 0.8804 - val_loss: 0.6820 - val_acc: 0.5664 - val_f1_m: 0.6617 - val_precision_m: 0.5645 - val_recall_m: 0.8507
Epoch 13/20
9/9 - 3s - loss: 0.5756 - acc: 0.7246 - f1_m: 0.8131 - precision_m: 0.7391 - recall_m: 0.9094 - val_loss: 0.6831 - val_acc: 0.5820 - val_f1_m: 0.6707 - val_precision_m: 0.5706 - val_recall_m: 0.8641
Epoch 14/20
9/9 - 3s - loss: 0.5679 - acc: 0.7266 - f1_m: 0.8152 - precision_m: 0.7358 - recall_m: 0.9196 - val_loss: 0.6833 - val_acc: 0.5742 - val_f1_m: 0.6668 - val_precision_m: 0.5671 - val_recall_m: 0.8602
Epoch 15/20
9/9 - 3s - loss: 0.5601 - acc: 0.7373 - f1_m: 0.8215 - precision_m: 0.7437 - recall_m: 0.9240 - val_loss: 0.6837 - val_acc: 0.5781 - val_f1_m: 0