# Tutorial 4: The temporal model
In tutorials 1-3 we have used simple fully-connected neural networks to predict sources from single time instances of EEG data. To harness the full information within the EEG, however, we can also incorporate multiple time instances at once. 

For time-series data we can use recurrent neural networks (RNNs). A prominent RNN is the long-short-term memory (LSTM) network, which makes use of temporal information in a quite useful manner.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import mne
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
plot_params = dict(surface='white', hemi='both', verbose=0)


## Forward model
To get started we just create some generic forward model using the esinet.forward module.

In [None]:
info = get_info()
info['sfreq'] = 100
fwd = create_forward_model(info=info)

## Simulation
Next, we simulate our training data. In order to invoke the LSTM architecture we need to simulate data that have a temporal dimension. This is controlled via the *duration_of_trial* setting as shown below. We set the duration to 0.2 seconds, which together with our sampling rate of 100 Hz yields 20 time points:
```
100 Hz * 0.2 s = 20
```

Note, that for publication-ready inverse solutions you should increase the number of training samples to 100,000.

In [None]:
n_samples = 10000
settings = dict(duration_of_trial=0.2, target_snr=(0.5, 10))

sim_lstm = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=n_samples)
sim_dense = util.convert_simulation_temporal_to_single(sim_lstm)

sim_lstm_test = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=2000)
sim_dense_test = util.convert_simulation_temporal_to_single(sim_lstm_test)

## Build & train LSTM network
The neural network class *Net()* is intelligent and recognizes the temporal structure in the simulations.
It will automatically build the LSTM network architecture without further specification.

In [8]:
train_params = dict(epochs=200, patience=20, tensorboard=True, dropout=0, loss='mse', optimizer='adam')
model_params = dict(n_lstm_layers=1, n_lstm_units=100, activation_function='relu')
# Train 
net_dense = Net(fwd, **model_params)
net_dense.fit(sim_dense, **train_params)

# Train LSTM for single time points
# net_lstm = Net(fwd, **model_params)
# net_lstm.fit(sim_lstm, **train_params)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 100)               6200      
_________________________________________________________________
dense_5 (Dense)              (None, 1284)              129684    
Total params: 135,884
Trainable params: 135,884
Non-trainable params: 0
_________________________________________________________________
[<function weighted_huber_loss.<locals>.loss at 0x000002C0CC5A2700>]
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35

<esinet.net.Net at 0x2c071512460>

In [12]:
from esinet.losses import weighted_mse_loss
import tensorflow as tf

ground_truths = np.stack([s.data for s in sim_eval.source_data], axis=0)

p_lstm=net_lstm.predict(sim_eval)
predictions_lstm = np.stack([s.data for s in p_lstm], axis=0)

p_dense=net_dense.predict(sim_eval)
predictions_dense = np.stack([s.data for s in p_dense], axis=0)


predictions_lstm = np.swapaxes(predictions_lstm, 1,2)
predictions_lstm = predictions_lstm.reshape(2000, 1284)

predictions_dense = np.swapaxes(predictions_dense, 1,2)
predictions_dense = predictions_dense.reshape(2000, 1284)

ground_truths = np.swapaxes(ground_truths, 1,2)
ground_truths = ground_truths.reshape(2000, 1284)

predictions_lstm = tf.cast(predictions_lstm, dtype=tf.float32)
predictions_dense = tf.cast(predictions_dense, dtype=tf.float32)

ground_truths = tf.cast(ground_truths, dtype=tf.float32)

fun = weighted_mse_loss(weight=1.0, scale=False)
print('LSTM: ', fun(ground_truths, predictions_lstm).numpy())
print('Dense: ', fun(ground_truths, predictions_dense).numpy())

LSTM:  5.159307e-19
Dense:  4.963919e-19


In [15]:
from tensorflow.keras import backend as K
print(K.mean(K.square(ground_truths-predictions_lstm)))
print(K.mean(K.square(ground_truths-predictions_dense)))

tf.Tensor(5.159307e-19, shape=(), dtype=float32)
tf.Tensor(4.963919e-19, shape=(), dtype=float32)


# Numeric Evaluation

In [55]:
settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))
n_samples = 100
sim_eval = Simulation(fwd, info, verbose=True, settings=settings_eval).simulate(n_samples=n_samples)
sim_eval_dense = deepcopy(sim_eval).to_nontemporal()
# Evaluate the new simulations
mse_dense = net_dense.evaluate_nmse(sim_eval_dense)
mse_lstm = net_lstm.evaluate_nmse(sim_eval)

print(np.median(mse_lstm), np.median(mse_dense))

Simulate Source


  0%|          | 0/100 [00:00<?, ?it/s]

Converting Source Data to mne.SourceEstimate object


  0%|          | 0/100 [00:00<?, ?it/s]


Project sources to EEG...

Create EEG trials with noise...


  0%|          | 0/100 [00:00<?, ?it/s]


Convert EEG matrices to a single instance of mne.Epochs...
(1, 2000)
(100, 20)
0.019384857206812513 0.02826184160447744


In [82]:
# settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))
# n_samples = 1000
# sim_eval = Simulation(fwd, info, verbose=True, settings=settings_eval).simulate(n_samples=n_samples)
# sim_eval_dense = deepcopy(sim_eval).to_nontemporal()
# # Evaluate the new simulations
# mse_dense = net_dense.evaluate_nmse(sim_eval_dense)
# mse_lstm = net_lstm.evaluate_nmse(sim_eval)
# perc_median_diff = 100*(1-(np.median(mse_lstm) / np.median(mse_dense)))
# # Plot
# %matplotlib qt
# diff = mse_dense.flatten() - mse_lstm.flatten()
# relative_better_predictions = np.sum(diff>0)/ len(diff)
# title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm. \n ({perc_median_diff:.1f}% better)'
# decim = 1
import seaborn as sns; sns.set(style='whitegrid')
fig, ax = plt.subplots(1,1, figsize=(10,10))
ax.scatter(mse_dense[::decim], mse_lstm[::decim], s=0.5)
ax.set_xlim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.set_ylim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.plot([0, 1], [0, 1], linewidth=2, color='black')
ax.set_xlabel('Dense Errors')
ax.set_ylabel('LSTM Errors')
# ax.axis('equal')
ax.set_title(title)
plt.savefig(r'C:\Users\lukas\Desktop\lstm.png', dpi=600)

# Visual Evaluation
We can now evaluate the performance of our LSTM network using a newly simulated, thus unseen simulated sample.

In [6]:
%matplotlib qt
settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))

# Simulate new data
sim_test = Simulation(fwd, info, settings=settings_eval).simulate(1)
idx = 0
# Predict sources
prediction_dense = net_dense.predict(sim_test)
prediction_lstm = net_lstm.predict(sim_test)


# Plot True Source
brain = sim_test.source_data[idx].plot(**plot_params)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title')

# Plot True EEG
evoked = sim_test.eeg_data[idx].average()
evoked.plot()
evoked.plot_topomap(title='Ground Truth')
evoked = util.get_eeg_from_source(sim_test.source_data[idx], fwd, info, tmin=0.)
evoked.plot_topomap(title='Ground Truth Noiseless')


# Plot predicted source Dense
brain = prediction_dense.plot(**plot_params)
brain.add_text(0.1, 0.9, 'Dense', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_dense, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='Dense')

# Plot predicted source LSTM
brain = prediction_lstm.plot(**plot_params)
brain.add_text(0.1, 0.9, 'LSTM', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_lstm, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='LSTM')

error_dense = ((prediction_dense.data - sim_test.source_data[idx].data)**2).flatten()
error_lstm = ((prediction_lstm.data - sim_test.source_data[idx].data)**2).flatten()

diff = error_dense - error_lstm
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm'
print(title)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

56.5 % of samples were better with lstm


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [1.19318743e-09 1.67361899e-09 3.05720313e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [1.16309050e-09 2.13878146e-09 4.25039559e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [3.31846915e-09 4.01882137e-09 5.65349509e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [1.16309050e-09 2.13878146e-09 4.25039559e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [3.31846915e-09 4.01882137e-09 5.65349509e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [4.09237262e-10 7.75631868e-10 4.63485824e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [6.72407097e-10 7.65050635e-10 1.26336163e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [2.91266412e-10 4.04218161e-10 7.58409964e-10]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [6.72407097e-10 7.65050635e-10 1.26336163e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [4.09237262e-10 7.75631868e-10 4.63485824e-09]
Using control points [2.91266412e-10 4.04218161e-10 7.58409964e-10]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [9.04268205e-10 1.15940588e-09 1.90292830e-09]
Using control points [1.97750172e-09 2.25603606e-09 4.72387520e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [8.97302466e-10 1.47243492e-09 3.41262587e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [1.97750172e-09 2.25603606e-09 4.72387520e-09]
Reading labels from parcellation...
   read 181 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\lh.HCPMMP1.annot
Reading labels from parcellation...
   read 181 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\rh.HCPMMP1.annot
Reading labels from parcellation...
   read 9 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\lh.oasis.chubs.annot
Reading labels from parcellation...
   read 9 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\rh.oasis.chubs.annot
Reading labels from parcellation...
   read 9 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\lh.oasis.chubs.annot
Reading labels from parcellation...
   read 9 labels from C:\Users\lukas\mne_data\MNE-sample-data\subjects\fsaverage\label\rh.oasis.chubs.annot
Reading labels from parcellation...
   read 82 labels from C:\Users\luka

  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


In [None]:
prediction_dense.data /= np.max(np.abs(prediction_dense.data), axis=0)
prediction_lstm.data /= np.max(np.abs(prediction_lstm.data), axis=0)

error_dense = ((prediction_dense.data - sim_test.source_data[idx].data)**2).flatten()
error_lstm = ((prediction_lstm.data - sim_test.source_data[idx].data)**2).flatten()

diff = error_dense - error_lstm
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm'
print(title)

In [None]:
x = np.array(
    [[1,2,3],
    [4,5,6]]
)
print(x)

print(x.T[::-1].T)

In [None]:
np.max(np.abs(sigs), axis=1)

In [None]:
sigs = np.random.randn(10, 20)
sigs /= np.max(np.abs(sigs), axis=0)
plt.figure()
[plt.plot(sig) for sig in sigs.T]