# Tutorial 4: The temporal model
In tutorials 1-3 we have used simple fully-connected neural networks to predict sources from single time instances of EEG data. To harness the full information within the EEG, however, we can also incorporate multiple time instances at once. 

For time-series data we can use recurrent neural networks (RNNs). A prominent RNN is the long-short-term memory (LSTM) network, which makes use of temporal information in a quite useful manner.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import mne
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
plot_params = dict(surface='white', hemi='both', verbose=0)


## Forward model
To get started we just create some generic forward model using the esinet.forward module.

In [None]:
info = get_info()
info['sfreq'] = 100
fwd = create_forward_model(info=info)

## Simulation
Next, we simulate our training data. In order to invoke the LSTM architecture we need to simulate data that have a temporal dimension. This is controlled via the *duration_of_trial* setting as shown below. We set the duration to 0.2 seconds, which together with our sampling rate of 100 Hz yields 20 time points:
```
100 Hz * 0.2 s = 20
```

Note, that for publication-ready inverse solutions you should increase the number of training samples to 100,000.

In [None]:
n_samples = 10000
settings = dict(duration_of_trial=0.2, target_snr=(0.5, 10))

sim_lstm = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=n_samples)
sim_dense = util.convert_simulation_temporal_to_single(sim_lstm)

sim_lstm_test = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=2000)
sim_dense_test = util.convert_simulation_temporal_to_single(sim_lstm_test)

## Build & train LSTM network
The neural network class *Net()* is intelligent and recognizes the temporal structure in the simulations.
It will automatically build the LSTM network architecture without further specification.

In [29]:
train_params = dict(epochs=200, patience=20, tensorboard=True, dropout=0, loss='mse', optimizer='adam')
model_params = dict(n_lstm_layers=1, n_lstm_units=100, activation_function='relu')
# Train 
# net_dense = Net(fwd, **model_params)
# net_dense.fit(sim_dense, **train_params)

# Train LSTM for single time points
net_lstm = Net(fwd, **model_params)
net_lstm.fit(sim_lstm, **train_params)


Model: "LSTM_v3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input (InputLayer)              [(None, None, 61)]   0                                            
__________________________________________________________________________________________________
FC1 (TimeDistributed)           (None, None, 100)    6200        Input[0][0]                      
__________________________________________________________________________________________________
LSTM1 (Bidirectional)           (None, None, 200)    129600      Input[0][0]                      
__________________________________________________________________________________________________
Dropout1 (Dropout)              (None, None, 100)    0           FC1[0][0]                        
____________________________________________________________________________________________

KeyboardInterrupt: 

In [19]:
predictions.shape

(100, 1284, 20)

In [27]:
from esinet.losses import weighted_mse_loss
import tensorflow as tf
fun = weighted_mse_loss(weight=2.0)
fun(ground_truths, predictions)

<tf.Tensor: shape=(), dtype=float32, numpy=0.00017742967>

In [26]:
p=net_lstm.predict(sim_eval)
predictions = np.stack([s.data for s in p], axis=0)
ground_truths = np.stack([s.data for s in sim_eval.source_data], axis=0)

predicitons = np.swapaxes(predictions, 1,2)
predictions = predictions.reshape(2000, 1284)

ground_truths = np.swapaxes(ground_truths, 1,2)
ground_truths = ground_truths.reshape(2000, 1284)

predictions = tf.cast(predictions, dtype=tf.float32)
ground_truths = tf.cast(ground_truths, dtype=tf.float32)

<tf.Tensor: shape=(2000, 1284), dtype=float32, numpy=
array([[ 9.2233549e-10, -2.1166515e-09,  2.9879305e-10, ...,
        -3.4170871e-11, -4.9712912e-10, -1.2248715e-09],
       [-8.8395430e-10, -8.8119956e-11,  3.8926989e-10, ...,
        -1.2773120e-09, -2.9352482e-10, -3.1089797e-10],
       [ 1.6790618e-10,  4.0304060e-10,  6.0788219e-10, ...,
         1.3933509e-10,  2.5180380e-10,  2.5817642e-10],
       ...,
       [ 1.7008616e-09,  1.6684041e-09,  1.4448674e-10, ...,
        -3.5912759e-10, -2.4945682e-10, -4.1895312e-10],
       [-2.8113362e-10, -4.4815146e-10, -4.4510864e-10, ...,
        -6.0656014e-10, -9.1317853e-10, -5.7077520e-10],
       [-5.6258942e-10, -6.3030475e-10, -7.5334294e-10, ...,
         3.3599089e-11, -1.2047902e-10,  1.6030796e-09]], dtype=float32)>

In [None]:
kwargs = dict(epochs=15, patience=10, return_history=True, tensorboard=True, learning_rate=0.001)
# Train 
net_dense, hist_dense = Net(fwd, model_type='auto').fit(sim_dense, **kwargs)
net_lstm, hist_lstm = Net(fwd, model_type='temporal').fit(sim_lstm, **kwargs)

# Numeric Evaluation

In [None]:
settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))
n_samples = 100
sim_eval = Simulation(fwd, info, verbose=True, settings=settings_eval).simulate(n_samples=n_samples)

# Evaluate the new simulations
mse_dense = net_dense.evaluate_nmse(sim_eval)
mse_lstm = net_lstm.evaluate_nmse(sim_eval)
perc_median_diff = 100*(1-(np.median(mse_lstm) / np.median(mse_dense)))
# Plot
%matplotlib qt
diff = mse_dense.flatten() - mse_lstm.flatten()
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm. \n ({perc_median_diff:.1f}% better)'
decim = 1
fig, ax = plt.subplots(1,1, figsize=(10,10))
ax.scatter(mse_dense[::decim], mse_lstm[::decim], s=0.5)
ax.set_xlim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.set_ylim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.plot([0, 1], [0, 1], linewidth=2, color='black')
ax.set_xlabel('Dense Errors')
ax.set_ylabel('LSTM Errors')
# ax.axis('equal')
ax.set_title(title)

## Evaluate performance
We can now evaluate the performance of our LSTM network using a newly simulated, thus unseen simulated sample.

In [None]:
%matplotlib qt
settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))

# Simulate new data
sim_test = Simulation(fwd, info, settings=settings_eval).simulate(1)
idx = 0
# Predict sources
prediction_dense = net_dense.predict(sim_test)
prediction_lstm = net_lstm.predict(sim_test)


# Plot True Source
brain = sim_test.source_data[idx].plot(**plot_params)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title')

# Plot True EEG
evoked = sim_test.eeg_data[idx].average()
evoked.plot()
evoked.plot_topomap(title='Ground Truth')
evoked = util.get_eeg_from_source(sim_test.source_data[idx], fwd, info, tmin=0.)
evoked.plot_topomap(title='Ground Truth Noiseless')


# Plot predicted source Dense
brain = prediction_dense.plot(**plot_params)
brain.add_text(0.1, 0.9, 'Dense', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_dense, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='Dense')

# Plot predicted source LSTM
brain = prediction_lstm.plot(**plot_params)
brain.add_text(0.1, 0.9, 'LSTM', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_lstm, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='LSTM')

error_dense = ((prediction_dense.data - sim_test.source_data[idx].data)**2).flatten()
error_lstm = ((prediction_lstm.data - sim_test.source_data[idx].data)**2).flatten()
dif = error_dense - error_lstm
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm'
print(title)
