In [1]:
%load_ext autoreload
%autoreload 2

import sys; 
sys.path.insert(0, '../../esinet')
sys.path.insert(0, '../')

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
from invert.cmaps import parula
pp = dict(surface='white', hemi='both')

# Get forward model

Build an `ico3` source space and forward model for the `biosemi32` EEG montage.

In [3]:
# EEG montage and ico3-sampled forward model.
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling='ico3')

# Only items 1 and 2 of unpack_fwd's result are needed here:
# the leadfield matrix and the dipole positions.
fwd_items = unpack_fwd(fwd)
leadfield = fwd_items[1]
pos = fwd_items[2]
n_chans, n_dipoles = leadfield.shape  # (channels, dipoles)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.2s remaining:    0.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.2s remaining:    0.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s finished


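# Simulate sample data

Simulate two sparse source patches at a target SNR of 4.5. The first simulated sample is the ground truth that every inverse solver below is compared against.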

In [14]:
# Noise-free single-source alternative, kept for quick sanity checks:
# settings = dict(number_of_sources=1, extents=40, duration_of_trial=0.01, target_snr=99999999999)

# Two sources, extents (1, 2) (presumably a sampled range — confirm in
# esinet), trial duration 1 (the output shows 1000 time samples), SNR 4.5.
settings = {
    'number_of_sources': 2,
    'extents': (1, 2),
    'duration_of_trial': 1,
    'target_snr': 4.5,
}

sim = Simulation(fwd, info, settings)
sim = sim.simulate(2)

# Only the first of the simulated samples is used as ground truth.
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()
M = evoked.data

brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title', font_size=14)

-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]
100%|██████████| 2/2 [00:00<00:00, 250.65it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00, 15.60it/s]


Using control points [0.000000e+00 0.000000e+00 2.614745e-08]


# Train esinet

Fit a dense (LSTM-free) esinet model on 10,000 simulated single-time-point samples with 1–5 sources.

In [5]:
from esinet import Net, Simulation

# Training set: 10,000 samples with duration_of_trial=0 (the output shows a
# single time point per sample), (1, 5) sources and extents (25, 40) —
# presumably sampled ranges — at the same target SNR as the sample data.
sim = Simulation(
    fwd,
    evoked.info,
    settings=dict(
        duration_of_trial=0,
        number_of_sources=(1, 5),
        extents=(25, 40),
        target_snr=4.5,
    ),
).simulate(10000)

# Dense-only network: no LSTM layers, one 128-unit tanh hidden layer
# (see the "Dense-model" summary in the output).
net = Net(fwd, n_lstm_layers=0, n_dense_units=128, activation_function="tanh")
net = net.fit(sim)

-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 10000/10000 [00:45<00:00, 221.94it/s]
100%|██████████| 10000/10000 [00:00<00:00, 26629.84it/s]


source data shape:  (1284, 1) (1284, 1)


100%|██████████| 10000/10000 [00:11<00:00, 841.11it/s]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for epoch in epochs]
  epochs = [epoch.set_eeg_reference('average', projection=True, verbose=0) for ep

preprocess data
Model: "Dense-model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 FC_0 (TimeDistributed)      (None, None, 128)         4224      
                                                                 
 Drop_0 (Dropout)            (None, None, 128)         0         
                                                                 
 FC_Out (TimeDistributed)    (None, None, 1284)        165636    
                                                                 
Total params: 169,860
Trainable params: 169,860
Non-trainable params: 0
_________________________________________________________________
fit model
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25

In [None]:
from invert import make_inverse_operator, apply_inverse_operator
from invert.adapters import contextualize_bd
from invert.evaluate import nmse

# Compare one solver's raw estimate with its contextualized version, using the
# median NMSE against the ground-truth `stc`.
solver = "lor"
inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked)

# Raw inverse solution.
stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
error = np.median(nmse(stc.data, stc_hat.data))
print(f"{solver}: error={error:.4f}")
stc_hat.plot(**pp, brain_kwargs=dict(title=solver))

# Contextualized solution. Write it to a copy so `stc_hat` keeps the raw
# estimate it was plotted with, instead of being overwritten in place.
stc_hat_ctx = stc_hat.copy()
stc_hat_ctx.data = contextualize_bd(stc_hat.data, leadfield, num_epochs=20)
error = np.median(nmse(stc.data, stc_hat_ctx.data))
print(f"{solver} contextualized: error={error:.4f}")
stc_hat_ctx.plot(**pp, brain_kwargs=dict(title=solver + " contextualized"))


In [15]:
from invert import all_solvers, make_inverse_operator, apply_inverse_operator
# `corr` is used below; it was commented out of this import, so the cell only
# worked on a kernel where `corr` was already defined (hidden state).
from invert.evaluate import nmse, corr

# Set to True to render every solver's estimate (one brain plot per solver).
PLOT_ESTIMATES = False

stc.plot(**pp, brain_kwargs=dict(title="Ground Truth"))

# NOTE: despite its name, `errors` maps each solver to corr(ground truth,
# estimate); the name is kept because the plotting cell reads it.
errors = dict()
for solver in all_solvers:
    print(solver)
    # Regularization tied to the simulated SNR of the sample data
    # (settings["target_snr"] == 4.5, i.e. the previous alpha=1/4.5).
    inverse_operator = make_inverse_operator(
        fwd, solver=solver, evoked=evoked, alpha=1 / settings["target_snr"]
    )
    stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
    if PLOT_ESTIMATES:
        stc_hat.plot(**pp, brain_kwargs=dict(title=solver))
    errors[solver] = corr(stc.data, stc_hat.data)

# The trained esinet model is scored the same way as the classical solvers.
solver = "esinet"
stc_hat = net.predict(evoked)[0]
if PLOT_ESTIMATES:
    stc_hat.plot(**pp, brain_kwargs=dict(title=solver))
errors[solver] = corr(stc.data, stc_hat.data)

Using control points [0.000000e+00 0.000000e+00 2.614745e-08]
MNE
wMNE
dSPM
LORETA
sLORETA
eLORETA
LAURA


  A = -d**-drop_off


Backus-Gilbert
Multiple Sparse Priors
Using 16 temporal mode(s)
Iteration 1. Free Energy Improvement: 10.61
Iteration 2. Free Energy Improvement: 8.17
Iteration 3. Free Energy Improvement: 145.50
Iteration 4. Free Energy Improvement: 108.53
Iteration 5. Free Energy Improvement: 361.29
Iteration 6. Free Energy Improvement: 13.43
Iteration 7. Free Energy Improvement: 4.58
Iteration 8. Free Energy Improvement: 4.93
Iteration 9. Free Energy Improvement: 4.31
Iteration 10. Free Energy Improvement: 3.83
Iteration 11. Free Energy Improvement: 25.82
Iteration 12. Free Energy Improvement: 3.68
Iteration 13. Free Energy Improvement: 3.30
Iteration 14. Free Energy Improvement: 5.82
Iteration 15. Free Energy Improvement: 2.82
Iteration 16. Free Energy Improvement: 2.87
Iteration 17. Free Energy Improvement: 3.05
Iteration 18. Free Energy Improvement: 3.30
Iteration 19. Free Energy Improvement: 3.92
Iteration 20. Free Energy Improvement: 3.24
Iteration 21. Free Energy Improvement: 2.93
Iteration 22

  Fc = Ft/2 + np.e*hP*np.e/2 + np.log(np.linalg.det( Ph/hP )) / 2
  r = _umath_linalg.det(a, signature=signature)



Using 16 temporal mode(s)
ReML Iteration 0: 144.53473655715158
ReML Iteration 1: 43.8073681211909
ReML Iteration 2: -28.54360072335367
Free-energy:  [[nan nan]
 [nan nan]]
final h:  [ 5.99200605 -5.56264071]
Bayesian Beamformer


  Fc = Ft/2 + np.e*hP*np.e/2 + np.log(np.linalg.det( Ph/hP )) / 2


Using 16 temporal mode(s)
ReML Iteration 0: 11.677461075275954
ReML Iteration 1: 8.615339169215279
ReML Iteration 2: 7.908630837588918
ReML Iteration 3: 8.671320759472252
ReML Iteration 4: 9.389705744709946
ReML Iteration 5: 9.960523315341852
ReML Iteration 6: 10.393965754117922
ReML Iteration 7: 10.709937906717858
ReML Iteration 8: 10.926732525962318
ReML Iteration 9: 11.06029681266668
ReML Iteration 10: 11.124608063424116
ReML Iteration 11: 11.131960074604187
ReML Iteration 12: 11.093109648351433
ReML Iteration 13: 11.01739056389676
ReML Iteration 14: 10.912797713867256
ReML Iteration 15: 10.786102007902766
ReML Iteration 16: 10.64295899001238
ReML Iteration 17: 10.488040812551269
ReML Iteration 18: 10.325164504158522
ReML Iteration 19: 10.157411170011834
ReML Iteration 20: 9.987241659387157
ReML Iteration 21: 9.816600304492603
ReML Iteration 22: 9.647000737885968
ReML Iteration 23: 9.479606680783919
ReML Iteration 24: 9.315294546153343
ReML Iteration 25: 9.1547109640652
ReML Iterati

  Fa = Ft/2 - np.trace(C*P*YY*P)/2 - N*n*np.log(2*np.pi)/2  - N * np.log(np.linalg.det(C))/2
  warn("Method 'bounded' does not support relative tolerance in x; "


In [16]:
import seaborn as sns
import pandas as pd

# Draw with the notebook's default backend so the figure is saved with the
# notebook. Switching to `%matplotlib qt` mid-notebook opens an external
# window and breaks on headless kernels.
fig, ax = plt.subplots(figsize=(7, 6))

# One box per solver over the values stored in `errors`, which the
# solver-comparison cell fills with corr(ground truth, estimate). A log
# y-scale is not used because correlations can be negative.
sns.boxplot(data=pd.DataFrame(errors), ax=ax)
ax.set(
    xlabel="Solver",
    ylabel="corr(ground truth, estimate)",
    title="Correlation of each solver's estimate with the ground truth",
)
ax.tick_params(axis="x", labelrotation=45)  # solver names are long
fig.tight_layout()

<AxesSubplot:>