# Tutorial 4: The temporal model
In tutorials 1–3 we used simple fully-connected neural networks to predict sources from single time points of EEG data. To harness the full information in the EEG, however, we can also incorporate multiple time points at once.

For time-series data we can use recurrent neural networks (RNNs). A prominent RNN is the long short-term memory (LSTM) network, which can exploit dependencies between consecutive time points.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import mne
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import sys; sys.path.insert(0, '../')
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet.forward import create_forward_model, get_info
# Shared keyword arguments for every source-estimate .plot() call below.
plot_params = {'surface': 'white', 'hemi': 'both', 'verbose': 0}


## Forward model
To get started, we create a generic forward model using the esinet.forward module.

In [None]:
# Sampling rate used for all simulations in this tutorial (Hz).
SFREQ_HZ = 100

# Channel info from esinet, with the sampling rate overridden, then a generic
# forward model built from that info.
info = get_info()
# NOTE(review): some MNE versions refuse direct assignment to info['sfreq'];
# if that happens, set the rate when the info is created instead — verify.
info['sfreq'] = SFREQ_HZ
fwd = create_forward_model(info=info)

## Simulation
Next, we simulate our training data. To use the LSTM architecture, the simulated data need a temporal dimension. This is controlled via the *duration_of_trial* setting shown below. We set the duration to 0.5 seconds, which together with our sampling rate of 100 Hz yields 50 time points:
```
100 Hz * 0.5 s = 50
```

Note that for publication-ready inverse solutions you should increase the number of training samples to 100,000.

In [None]:
# Training-set size; raise this for publication-quality results.
n_samples = 10000
# Trials of 0.5 s; the target SNR is passed as the range (0.5, 10).
settings = {'duration_of_trial': 0.5, 'target_snr': (0.5, 10)}


def _simulate_pair(count):
    """Simulate `count` temporal samples and return them together with the
    same data converted to single time points (for the dense baseline)."""
    temporal = Simulation(fwd, info, verbose=True, settings=settings).simulate(n_samples=count)
    return temporal, util.convert_simulation_temporal_to_single(temporal)


# Training data first, then an independent 2000-sample test set.
sim_lstm, sim_dense = _simulate_pair(n_samples)
sim_lstm_test, sim_dense_test = _simulate_pair(2000)

## Build & train LSTM network
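The neural network class *Net()* recognizes the temporal structure in the simulations and automatically builds an LSTM architecture; below we only choose the number of LSTM layers and units and the model type.
As a baseline, we also train a fully-connected (dense) network on the same data converted to single time points.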

In [None]:
# --- Dense (fully-connected) baseline, trained on single time points ---
train_params = {
    'epochs': 200,
    'patience': 20,
    'tensorboard': True,
    'dropout': 0.1,
    'loss': 'mse',
    'optimizer': 'adam',
    'return_history': True,
}
model_params = {
    'activation_function': 'relu',
    'n_dense_layers': 3,
    'n_dense_units': 300,
}
net_dense = Net(fwd, **model_params)
_, history_dense = net_dense.fit(sim_dense, **train_params)


# --- LSTM ("v2" model type), trained on the temporal simulations ---
model_params = {
    'activation_function': 'relu',
    'n_lstm_layers': 2,
    'n_lstm_units': 75,
    'model_type': 'v2',
}
net_lstm = Net(fwd, **model_params)

# Same schedule as the dense run, but optimizer=None (presumably Net.fit's
# built-in default — verify) and a small batch size of 8.
train_params = {**train_params, 'optimizer': None, 'batch_size': 8}
_, history_lstm = net_lstm.fit(sim_lstm, **train_params)



# Numeric Evaluation

In [1]:
settings_eval = dict(duration_of_trial=0.5, target_snr=(0.5, 10))
n_samples = 100
sim_eval = Simulation(fwd, info, verbose=True, settings=settings_eval).simulate(n_samples=n_samples)
sim_eval_dense = deepcopy(sim_eval).to_nontemporal()
# Evaluate the new simulations
mse_dense = net_dense.evaluate_nmse(sim_eval_dense)
mse_lstm = net_lstm.evaluate_nmse(sim_eval)
perc_median_diff = 100*(1-(np.median(mse_lstm) / np.median(mse_dense)))
# Plot
%matplotlib qt
diff = mse_dense.flatten() - mse_lstm.flatten()
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm. \n ({perc_median_diff:.1f}% better)'
decim = 1
import seaborn as sns; sns.set(style='whitegrid')
fig, ax = plt.subplots(1,1, figsize=(10,10))
ax.scatter(mse_dense[::decim], mse_lstm[::decim], s=0.5)
ax.set_xlim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.set_ylim([-np.percentile(mse_dense, 5), np.percentile(mse_dense, 99)])
ax.plot([0, 1], [0, 1], linewidth=2, color='black')
ax.set_xlabel('Dense Errors')
ax.set_ylabel('LSTM Errors')
# ax.axis('equal')
ax.set_title(title)
# plt.savefig(r'C:\Users\lukas\Desktop\lstm.png', dpi=600)

NameError: name 'Simulation' is not defined

# Visual Evaluation
We can now evaluate the performance of our LSTM network on a newly simulated, and therefore unseen, sample.

In [None]:
%matplotlib qt
settings_eval = dict(duration_of_trial=0.2, target_snr=(0.5, 10))

# Simulate new data
sim_test = Simulation(fwd, info, settings=settings_eval).simulate(1)
idx = 0
# Predict sources
prediction_dense = net_dense.predict(sim_test)
prediction_lstm = net_lstm.predict(sim_test)


# Plot True Source
brain = sim_test.source_data[idx].plot(**plot_params)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title')

# Plot True EEG
evoked = sim_test.eeg_data[idx].average()
evoked.plot()
evoked.plot_topomap(title='Ground Truth')
evoked = util.get_eeg_from_source(sim_test.source_data[idx], fwd, info, tmin=0.)
evoked.plot_topomap(title='Ground Truth Noiseless')


# Plot predicted source Dense
brain = prediction_dense.plot(**plot_params)
brain.add_text(0.1, 0.9, 'Dense', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_dense, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='Dense')

# Plot predicted source LSTM
brain = prediction_lstm.plot(**plot_params)
brain.add_text(0.1, 0.9, 'LSTM', 'title')
# Plot predicted EEG
evoked_esi = util.get_eeg_from_source(prediction_lstm, fwd, info, tmin=0.)
evoked_esi.plot()
evoked_esi.plot_topomap(title='LSTM')

error_dense = ((prediction_dense.data - sim_test.source_data[idx].data)**2).flatten()
error_lstm = ((prediction_lstm.data - sim_test.source_data[idx].data)**2).flatten()

diff = error_dense - error_lstm
relative_better_predictions = np.sum(diff>0)/ len(diff)
title = f'{relative_better_predictions*100:.1f} % of samples were better with lstm'
print(title)

# Noise-loss Plot

In [None]:
# Sweep trial duration x target SNR and record the median evaluate_nmse error
# of both networks for every combination.
durations_of_trials = [0, 0.2, 0.5, 1]
target_snrs = [0.125, 0.25, 0.5, 1, 2, 4, 8, 16]
# Evaluation samples per trial duration, index-aligned with
# durations_of_trials (longer trials get fewer samples).
n_samples_list = [1000, 50, 20, 10]
if len(n_samples_list) != len(durations_of_trials):
    raise ValueError('n_samples_list needs exactly one entry per trial duration')

# result[0] holds the LSTM errors, result[1] the dense errors; rows index the
# trial durations, columns index the target SNRs.
result = [np.zeros((len(durations_of_trials), len(target_snrs))),
    np.zeros((len(durations_of_trials), len(target_snrs)))]
net_names = ['LSTM network', 'FC network']  # same order as `result`
xlabel = 'SNR'
# The stored values are medians of evaluate_nmse, not plain mean squared errors.
ylabel = 'Median nMSE'
for i, (duration_of_trial, n_samples_eval) in enumerate(zip(durations_of_trials, n_samples_list)):
    for j, target_snr in enumerate(target_snrs):
        settings_eval = dict(duration_of_trial=duration_of_trial, target_snr=target_snr)
        print(settings_eval)
        # Fresh, unseen simulations for this (duration, SNR) cell.
        sim_test = Simulation(fwd, info, settings=settings_eval).simulate(n_samples_eval)
        # NOTE(review): unlike the numeric-evaluation cell, the dense network
        # is evaluated on temporal simulations here (no to_nontemporal()) —
        # confirm evaluate_nmse handles this as intended.
        error_lstm = np.median(net_lstm.evaluate_nmse(sim_test))
        error_dense = np.median(net_dense.evaluate_nmse(sim_test))

        result[0][i, j] = error_lstm
        result[1][i, j] = error_dense

In [None]:
%matplotlib qt
t = target_snrs
fig, ax = plt.subplots()
for i, name in enumerate(net_names):
    for j, duration_of_trial in enumerate(durations_of_trials):
        ax.semilogx(t, result[i][j, :], label=f'{name} {duration_of_trial}')
plt.legend()    
plt.xticks(target_snrs, labels=target_snrs)
plt.xlabel(xlabel)
plt.ylabel(ylabel)

In [None]:
from scipy.stats import wilcoxon, ttest_rel
# Paired t-test on the dense ("FC network") errors: first vs. last entry of
# durations_of_trials, paired across the SNR levels.
fc_errors = result[1]
ttest_rel(fc_errors[0], fc_errors[-1])