In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import mne
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import get_info, create_forward_model
import os

import tensorflow as tf
from tensorflow.keras.layers import (Dense, MultiHeadAttention, Bidirectional, TimeDistributed, 
    LSTM, GRU, InputLayer, Attention, BatchNormalization, RepeatVector, Input, Activation, dot, 
    concatenate)

plot_params = dict(surface='white', hemi='both', verbose=0)

# Get Training Data

In [2]:
info = get_info(sfreq=100)
fwd = create_forward_model(info=info, sampling='ico2')
sim = Simulation(fwd, info).simulate(n_samples=5000)

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    1.4s remaining:    2.4s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    1.4s remaining:    0.8s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    1.5s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.0s finished


Simulating data based on sparse patches.


100%|██████████| 5000/5000 [00:23<00:00, 210.30it/s]
100%|██████████| 5000/5000 [00:00<00:00, 26253.25it/s]


source data shape:  (324, 100) (324, 100)


100%|██████████| 5000/5000 [00:43<00:00, 114.54it/s]


In [3]:
X = np.stack([eeg.get_data()[0] for eeg in sim.eeg_data], axis=0)
y = np.stack([src.data for src in sim.source_data], axis=0)

## Scale

In [4]:
X = np.swapaxes(X, 1,2)
y = np.swapaxes(y, 1,2)

for n in range(X.shape[0]):
    for t in range(X.shape[1]):
        X[n,t,:] -= X[n,t,:].mean()
        X[n,t,:] /= X[n,t,:].std()
        
        # y[n,t,:] -= y[n,t,:].mean()
        y[n,t,:] /= np.abs(y[n,t,:]).max()

## LSTM

In [10]:
input_shape= (None, X.shape[2])
output_shape = (None, y.shape[2])

model = tf.keras.models.Sequential(name='Attention')

model.add(InputLayer(input_shape=input_shape))

model.add(Bidirectional(LSTM(75, return_sequences=True)))
model.add(Bidirectional(LSTM(75, return_sequences=True)))

model.add(TimeDistributed(Dense(y.shape[2])))

model.build()
model.summary()
model.compile(optimizer='adam', loss=tf.keras.losses.CosineSimilarity())
model.fit(X, y, validation_split=0.1, epochs=100)

Model: "Attention"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bidirectional_4 (Bidirection (None, None, 150)         82200     
_________________________________________________________________
bidirectional_5 (Bidirection (None, None, 150)         135600    
_________________________________________________________________
time_distributed_4 (TimeDist (None, None, 324)         48924     
Total params: 266,724
Trainable params: 266,724
Non-trainable params: 0
_________________________________________________________________
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100


<tensorflow.python.keras.callbacks.History at 0x265d7640100>

## Transformer

In [12]:
from keras.models import Model
input_shape= (X.shape[1], X.shape[2])
output_shape = (y.shape[1], y.shape[2])



n_hidden = 150
n_blocks = 1
input_train = Input(shape=input_shape)
output_train = Input(shape=output_shape)

for i in range(n_blocks):
    if i == 0:
        first_layer = input_train
    else:
        first_layer = decoder_combined_context

    encoder_stack_h, encoder_last_h, encoder_last_c = LSTM(
        n_hidden, return_state=True, return_sequences=True)(first_layer)
    # print(encoder_stack_h)
    # print(encoder_last_h)
    # print(encoder_last_c)

    encoder_last_h = BatchNormalization(momentum=0.6)(encoder_last_h)
    encoder_last_c = BatchNormalization(momentum=0.6)(encoder_last_c)
    decoder_input = RepeatVector(output_train.shape[1])(encoder_last_h)
    # print(decoder_input)

    decoder_stack_h = LSTM(n_hidden, return_state=False, return_sequences=True)(
    decoder_input, initial_state=[encoder_last_h, encoder_last_c])
    # print(decoder_stack_h)


    attention = dot([decoder_stack_h, encoder_stack_h], axes=[2, 2])
    attention = Activation('sigmoid')(attention)
    # print(attention)

    context = dot([attention, encoder_stack_h], axes=[2,1])
    context = BatchNormalization(momentum=0.6)(context)
    # print(context)


    decoder_combined_context = concatenate([context, decoder_stack_h])
    print(decoder_combined_context)


out = TimeDistributed(Dense(output_train.shape[2]))(decoder_combined_context)
print(out)

model = Model(inputs=input_train, outputs=out)
model.build(input_shape=input_shape)
model.compile(optimizer='adam', loss=tf.keras.losses.CosineSimilarity())
model.summary()
model.fit(X, y, validation_split=0.1, epochs=100)

KerasTensor(type_spec=TensorSpec(shape=(None, 100, 300), dtype=tf.float32, name=None), name='concatenate_25/concat:0', description="created by layer 'concatenate_25'")
KerasTensor(type_spec=TensorSpec(shape=(None, 100, 324), dtype=tf.float32, name=None), name='time_distributed_7/Reshape_1:0', description="created by layer 'time_distributed_7'")
Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Total params: 407,124
Trainable params: 406,224
Non-trainable params: 900
__________________________________________________________________________________________________
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
E

KeyboardInterrupt: 