In [19]:
import os
try:
    import PyQt5.QtCore
    % matplotlib qt
except ImportError:
    % matplotlib inline
import keras
import mne
import numpy as np
import pandas as pd
import scipy.io
import tensorflow as tf
from mne.channels import make_standard_montage
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model


In [20]:
data_dir = os.path.dirname("./data/")
data_files = os.listdir(data_dir)

In [21]:
def annotations_from_eGUI(raw, egui):
    codes = []
    starts = []

    current_state = None

    for i in range(len(egui)):
        if egui[i][0] != current_state:
            starts.append(i)
            current_state = egui[i][0]
            codes.append(str(egui[i][0]))

    starts.append(len(egui))
    codes = np.array(codes)
    sf = raw.info.get('sfreq')
    starts = np.array(starts) / sf
    durations = starts[1:] - starts[:-1]
    starts = starts[:-1]

    raw.set_annotations(mne.Annotations(onset=starts, duration=durations, description=codes))


def raw_from_mat(file):
    mat = scipy.io.loadmat(os.path.join(data_dir, file))

    sampling_freq = mat["o"][0][0][2][0][0]
    n_samples = mat["o"][0][0][3][0][0]
    ch_names = [element[0][0] for element in mat["o"][0][0][6]]

    df = pd.DataFrame(mat["o"][0][0][5], columns=ch_names)
    df = df.drop(columns=["X5"])
    df = df.T
    ch_names.remove("X5")

    ch_types = ['eeg'] * 21
    info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sampling_freq)
    raw = mne.io.RawArray(df.to_numpy(), info)

    montage = make_standard_montage("standard_prefixed")
    raw.set_montage(montage)

    raw.load_data().set_eeg_reference(ref_channels='average')
    annotations_from_eGUI(raw, mat["o"][0][0][4])
    return raw


def filter_raw(raw):
    return raw.load_data().filter(0.1, 30, method="fir", phase="zero-double")

In [22]:
raw_NoMT = [raw_from_mat(file) for file in data_files if "NoMT" in file]
raw_FREEFORM = [raw_from_mat(file) for file in data_files if "FREEFORM" in file]

Creating RawArray with float64 data, n_channels=21, n_times=664400
    Range : 0 ... 664399 =      0.000 ...  3321.995 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Creating RawArray with float64 data, n_channels=21, n_times=664600
    Range : 0 ... 664599 =      0.000 ...  3322.995 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Creating RawArray with float64 data, n_channels=21, n_times=662400
    Range : 0 ... 662399 =      0.000 ...  3311.995 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Creating RawArray with float64 data, n_channels=21, n_times=667600
    Range : 0 ... 667599 =      0.000 ...  3337.995 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Creating RawArray with float64 d

In [23]:
def get_epochs(raw):
    metadata_tmin, metadata_tmax = -1, 1

    all_events, all_event_id = mne.events_from_annotations(raw)
    metadata, events, event_id = mne.epochs.make_metadata(
        events=all_events,
        event_id=all_event_id,
        tmin=metadata_tmin,
        tmax=metadata_tmax,
        sfreq=raw.info["sfreq"],
    )
    return mne.Epochs(raw, events, event_id)


In [24]:
epochs_NoMT = [get_epochs(file) for file in raw_NoMT]
epochs_FREEFORM = [get_epochs(file) for file in raw_FREEFORM]


Used Annotations descriptions: ['0', '1', '2', '3', '4', '5', '6', '91', '92', '99']
Not setting metadata
1931 matching events found
Setting baseline interval to [-0.2, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Used Annotations descriptions: ['0', '1', '2', '3', '4', '5', '6', '91', '92', '99']
Not setting metadata
1919 matching events found
Setting baseline interval to [-0.2, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Used Annotations descriptions: ['0', '1', '2', '3', '4', '5', '6', '91', '92', '99']
Not setting metadata
1925 matching events found
Setting baseline interval to [-0.2, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Used Annotations descriptions: ['0', '1', '2', '3', '4', '5', '6', '91', '92', '99']
Not setting metadata
1935 matching events found
Setting baseline interval to [-0.2, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Used

In [25]:
epochs_NoMT[0].get_data().max()

Using data from preloaded Raw for 1931 events and 141 original time points ...
1 bad epochs dropped


865.3172938443672

In [26]:
epochs_FREEFORM[0].get_data().max()

Using data from preloaded Raw for 1481 events and 141 original time points ...
1 bad epochs dropped


158.92792102206735

In [27]:
epochs_data_NOMT = [file.get_data() for file in epochs_NoMT]
epochs_data_FREEFORM = [file.get_data() for file in epochs_FREEFORM]

Using data from preloaded Raw for 1930 events and 141 original time points ...
Using data from preloaded Raw for 1919 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1925 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1935 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1935 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1935 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1933 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1480 events and 141 original time points ...
Using data from preloaded Raw for 1383 events and 141 original time points ...
1 bad epochs dropped
Using data from preloaded Raw for 1409 events and 141 original time points ...
1 bad epochs dropped


In [28]:
stacked_NOMT = np.vstack(epochs_data_NOMT)
stacked_FREEFORM = np.vstack(epochs_data_FREEFORM)

In [29]:
np.random.shuffle(stacked_NOMT)
np.random.shuffle(stacked_FREEFORM)

In [30]:
X_nomt_train = stacked_NOMT[:12000]
X_nomt_test = stacked_NOMT[12000:]

In [31]:
X_free = stacked_FREEFORM

In [32]:
stacked_NOMT.shape

(13506, 21, 141)

In [33]:
stacked_FREEFORM.shape

(4270, 21, 141)

In [34]:
import random

idy = random.sample(range(0, len(X_free)), X_nomt_test.shape[0])
X_free_test = X_free[idy]

In [35]:
X_free_test.shape


(1506, 21, 141)

In [36]:
layer = layers.Normalization()
layer1 = layers.Normalization()
layer.adapt(X_nomt_train.astype(float))
layer1.adapt(X_free_test.astype(float))

print(X_nomt_train)
print(np.max(X_nomt_train))
print(np.max(X_free_test))
print(np.max(layer(X_nomt_train)))
print(np.max(layer1(X_free_test)))

[[[  9.79012776  12.05393728  -4.71511034 ... -14.55891986 -18.40034843
   -18.51034843]
  [ -2.49157956  -1.24777003   2.89318235 ... -22.78062718 -21.38205575
   -18.89205575]
  [ -0.2762137   -6.84240418   6.4785482  ...  -3.73526132  -4.5966899
    -0.6666899 ]
  ...
  [  2.44939605   4.43320557   4.43415796 ...  -3.25965157  -2.52108014
    -4.31108014]
  [ -0.78060395  -1.58679443  -3.19584204 ...   7.83034843   9.08891986
     9.10891986]
  [ -2.90157956  -2.76777003  -6.26681765 ...   3.15937282   4.38794425
     3.62794425]]

 [[ -1.17514518  -4.00276423  -0.99038328 ...  -6.85324042  -6.01276423
    -2.40847851]
  [ -1.55929152   1.28308943   0.92547038 ...  -1.46738676  -5.15691057
    -0.28262485]
  [ -3.50855981  -3.79617886  -5.65379791 ...  -4.32665505  -2.51617886
    -4.60189315]
  ...
  [ -5.0575842   -2.00520325  -1.2728223  ...  -4.14567944  -3.79520325
    -4.51091754]
  [ -7.37831591  -4.87593496  -4.82355401 ...  -0.60641115  -0.62593496
     0.44835075]
  [ -1.3

In [37]:
latent_dim = 512
keras.backend.clear_session()


class Autoencoder(Model):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            layers.Flatten(),
            layers.Dense(1024, activation='gelu'),
            layers.Dense(512, activation='gelu'),
            layers.Dense(64, activation='gelu'),
        ])
        self.decoder = tf.keras.Sequential([
            layers.Dense(512, activation='gelu'),
            layers.Dense(1024, activation='gelu'),
            layers.Dense(21 * 141, activation='linear'),
            layers.Reshape((21, 141))
        ])

    def call(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


autoencoder = Autoencoder(latent_dim)

In [38]:
opt = keras.optimizers.Adam(learning_rate=0.0010)
autoencoder.compile(optimizer=opt, loss=losses.MeanSquaredError())
autoencoder.fit(layer(X_nomt_train), layer(X_nomt_train),
                epochs=50,
                batch_size=64,
                shuffle=True,
                validation_data=(layer(X_nomt_test[:1000]), layer(X_nomt_test[:1000])))

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50

KeyboardInterrupt: 

In [39]:
err = []
err2 = []
for i in X_nomt_train:
    # need to expand here because the flatten layer assumes that the first dimension is the number of samples
    i = np.expand_dims(i, axis=0)
    encoded = autoencoder.encoder(layer(i)).numpy()
    decoded = autoencoder.decoder(encoded).numpy()
    #print("Error:",(np.square(i-decoded)).mean())
    err.append((np.square(layer(i) - decoded)).mean())
print("###################")

for j in X_free_test:
    j = np.expand_dims(j, axis=0)
    encoded = autoencoder.encoder(layer1(j)).numpy()
    decoded = autoencoder.decoder(encoded).numpy()
    #print("Error:",(np.square(layer(j)-decoded)).mean())
    err2.append((np.square(layer1(j) - decoded)).mean())
print("##############")
print(np.array(err).mean())
print(np.array(err2).mean())


###################
##############
0.31427377
0.61125827


descicion 0.65

In [48]:
def calc_accuracy(a, b, th):
    first = [1 if i < th else 0 for i in a]
    last = [1 if i > th else 0 for i in b]
    return sum(first + last) / len(first + last)

In [53]:
calc_accuracy(err, err2, 3)

0.8868650969939286

In [None]:
print(err[:10])

1,1,1,1,1,1,,0,1,1,1 9/10

In [None]:
print(err2[:10])

0,0,1,0,1,1,1,1,0,0 5/10

Acc:

In [None]:
14 / 20