In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt
import sys; sys.path.insert(0, '../')
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
# Keyword arguments unpacked into the source-estimate `.plot(**pp)` calls below.
pp = {'surface': 'white', 'hemi': 'both'}

# Get Forward Model

In [2]:
# Channel info and forward model that the remaining cells build on.
info = get_info(kind='biosemi64')
fwd = create_forward_model(info=info, sampling='ico3')

# Items 1 and 2 of the unpacked forward model are kept as `leadfield` and
# `pos`; the two axes of the leadfield are read as channel and dipole counts.
fwd_parts = unpack_fwd(fwd)
leadfield, pos = fwd_parts[1], fwd_parts[2]
n_chans, n_dipoles = leadfield.shape

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    2.3s remaining:    3.9s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    2.6s remaining:    1.5s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    2.7s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.1s remaining:    0.2s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished


# Get Sample Data

In [93]:
# Parameters handed to the esinet Simulation for the single demo sample.
settings = {
    'number_of_sources': 3,
    'extents': (25, 40),
    'duration_of_trial': 1,
    'target_snr': 25,
}

# Two samples are simulated, but only the first (index 0) is used below.
# NOTE(review): no random seed is set in this cell, so results presumably
# differ between runs -- confirm how esinet draws its random numbers.
simulation = Simulation(fwd, info, settings)
sim = simulation.simulate(2)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()

# Plot the simulated source estimate and label the window as the reference.
brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title', font_size=14)

-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]
100%|██████████| 2/2 [00:00<00:00, 284.76it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00, 10.31it/s]

Using control points [1.94443010e-11 2.02768995e-10 6.27168357e-08]





For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


# Regularization Optimizations

In [94]:
from invert.evaluate import corr, nmse
from invert.solvers.lucas import SolverLUCAS

# Prepare the LUCAS solver: build its inverse operator, optimise its weights
# against the forward model, and show the resulting weights.
solver = SolverLUCAS()
solver.make_inverse_operator(fwd, verbose=0)
solver.optimize_weights(fwd, info)
solver.plot_weights()

# Invert the simulated evoked response and plot it under the solver's name.
stc_hat = solver.apply_inverse_operator(evoked)
stc_hat.plot(**pp, brain_kwargs={'title': solver.name})

# Score the estimate against the simulated sources.
mean_corr = np.mean(corr(stc_hat.data, stc.data))
print("Mean correlation: ", mean_corr)
mean_nmse = np.mean(nmse(stc_hat.data, stc.data))
print("Mean NMSE: ", mean_nmse)

Preparing MNE
Preparing wMNE
Preparing dSPM
alpha must be set to a float when using Dynamic Statistical Parametric Mapping, auto does not work yet.
Preparing LORETA
Preparing sLORETA
Preparing eLORETA
Preparing LAURA
Preparing Backus-Gilbert
Preparing S-MAP
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 256/256 [00:01<00:00, 207.79it/s]
100%|██████████| 256/256 [00:00<00:00, 11133.32it/s]


source data shape:  (1284, 1) (1284, 1)


100%|██████████| 256/256 [00:00<00:00, 1108.13it/s]


sample  0
sample  1
sample  2
sample  3
sample  4
sample  5
sample  6
sample  7
sample  8
sample  9
sample  10
sample  11
sample  12
sample  13
sample  14
sample  15
sample  16
sample  17
sample  18
sample  19
sample  20
sample  21
sample  22
sample  23
sample  24
sample  25
sample  26
sample  27
sample  28
sample  29
sample  30
sample  31
sample  32
sample  33
sample  34
sample  35
sample  36
sample  37
sample  38
sample  39
sample  40
sample  41
sample  42
sample  43
sample  44
sample  45
sample  46
sample  47
sample  48
sample  49
sample  50
sample  51
sample  52
sample  53
sample  54
sample  55
sample  56
sample  57
sample  58
sample  59
sample  60
sample  61
sample  62
sample  63
sample  64
sample  65
sample  66
sample  67
sample  68
sample  69
sample  70
sample  71
sample  72
sample  73
sample  74
sample  75
sample  76
sample  77
sample  78
sample  79
sample  80
sample  81
sample  82
sample  83
sample  84
sample  85
sample  86
sample  87
sample  88
sample  89
sample  90
sample  9

In [95]:
from invert import Solver
from invert.evaluate import corr, nmse

# Same workflow as above, this time with the generic Solver factory and the
# "elor" solver key; alpha="auto" is passed when building the operator.
solver = Solver(solver="elor")
solver.make_inverse_operator(fwd, alpha="auto")

stc_hat = solver.apply_inverse_operator(evoked)
stc_hat.plot(**pp, brain_kwargs={'title': solver.name})

# Print both scores of the estimate against the simulated sources.
for label, metric in (("Mean correlation: ", corr), ("Mean NMSE: ", nmse)):
    print(label, np.mean(metric(stc_hat.data, stc.data)))

Using control points [2.35909273e-09 2.98296382e-09 1.14005580e-08]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`
Mean correlation:  0.21165194027897546
Mean NMSE:  0.020785161545233133


# Adapt

In [69]:
from invert.adapters import contextualize_bd, contextualize


def adapt_and_report(adapter, label, stc_est, stc_true, forward, base_name):
    """Run one adapter on an inverse solution, plot it and print its scores.

    Both adapters were previously plotted with the identical title
    ``"c" + solver.name`` and printed with identical labels, so the two
    results could not be told apart. Each result is now tagged with `label`.

    Parameters
    ----------
    adapter : callable
        Adapter from ``invert.adapters``; called as
        ``adapter(stc_est, forward, lstm_look_back=10, num_units=128,
        num_epochs=50, verbose=0)``.
    label : str
        Prefix placed in front of `base_name` in the plot title and in the
        printed score lines.
    stc_est
        Inverse solution to adapt.
    stc_true
        Reference source estimate; its ``.data`` is compared with the
        adapted result.
    forward
        Forward model handed to the adapter.
    base_name : str
        Name of the solver that produced `stc_est`.

    Returns
    -------
    The adapted source estimate returned by `adapter`.
    """
    stc_adapted = adapter(
        stc_est, forward,
        lstm_look_back=10, num_units=128, num_epochs=50, verbose=0,
    )
    title = label + base_name
    stc_adapted.plot(**pp, brain_kwargs=dict(title=title))

    print(f"[{title}] Mean correlation: ", np.mean(corr(stc_adapted.data, stc_true.data)))
    print(f"[{title}] Mean NMSE: ", np.mean(nmse(stc_adapted.data, stc_true.data)))
    return stc_adapted


# Same order as before: the "bd" variant first, then the plain variant.
stc_hat_cbd = adapt_and_report(contextualize_bd, "cbd", stc_hat, stc, fwd, solver.name)
stc_hat_c = adapt_and_report(contextualize, "c", stc_hat, stc, fwd, solver.name)

Using control points [2.36416085e-09 3.78566290e-09 2.79187920e-08]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`
Mean correlation:  0.4776289649789299
Mean NMSE:  0.0074936934532251346
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_28 (LSTM)              (None, 128)               723456    
                                                                 
 dense_28 (Dense)            (None, 1284)              165636    
                                                                 
Total params: 889,092
Trainable params: 889,092
Non-trainable params: 0
_________________________________________________________________
Using control points [2.14741503e-09 3.54469978e-09 2.77268462e-08]
For automatic theme 

# Evaluation

In [37]:
from invert import all_solvers, Solver

# Ranges (tuples) are passed here for source count, extent and SNR, where the
# single-sample cell above passed fixed values.
settings = {
    'number_of_sources': (1, 10),
    'extents': (1, 40),
    'duration_of_trial': 1,
    'target_snr': (1, 25),
}

# NOTE: despite its name, `errors` collects the mean Pearson correlation of
# each solver per simulation.
errors = {sname: [] for sname in all_solvers}
# Inverse operators kept across iterations, keyed by solver name.
solvers = {}
for i in range(25):
    print(i)
    sim = Simulation(fwd, info, settings).simulate(2)
    stc = sim.source_data[0]
    evoked = sim.eeg_data[0].average()

    for solver_name in all_solvers:
        solver = Solver(solver=solver_name)
        solver_label = solver.name.lower()
        needs_data = solver_label == "multiple sparse priors" or "bayesian" in solver_label

        if needs_data:
            # These solvers receive the evoked data, so their operator is
            # rebuilt for every simulation.
            solvers[solver_name] = solver.make_inverse_operator(fwd, evoked, alpha="auto")
        elif solver_name not in solvers:
            # All other operators are built once and reused afterwards.
            solvers[solver_name] = solver.make_inverse_operator(fwd, alpha="auto")

        stc_hat = solvers[solver_name].apply_inverse_operator(evoked)

        # One Pearson r per column of the transposed data, then averaged.
        per_sample_r = [
            pearsonr(true_col, est_col)[0]
            for true_col, est_col in zip(stc.data.T, stc_hat.data.T)
        ]
        error = np.mean(per_sample_r)
        errors[solver_name].append(error)

0
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  3.95it/s]
100%|██████████| 2/2 [00:00<00:00, 365.44it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  5.87it/s]


alpha must be set to a float when using Dynamic Statistical Parametric Mapping, auto does not work yet.
1
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.63it/s]
100%|██████████| 2/2 [00:00<00:00, 358.86it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  7.17it/s]


2
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.98it/s]
100%|██████████| 2/2 [00:00<00:00, 571.12it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.35it/s]


3
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.99it/s]
100%|██████████| 2/2 [00:00<00:00, 517.50it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  7.11it/s]


4
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.25it/s]
100%|██████████| 2/2 [00:00<00:00, 569.22it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.54it/s]


5
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.20it/s]
100%|██████████| 2/2 [00:00<00:00, 680.12it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.52it/s]


6
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.00it/s]
100%|██████████| 2/2 [00:00<00:00, 633.58it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.66it/s]


7
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.02it/s]
100%|██████████| 2/2 [00:00<00:00, 210.17it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.48it/s]


8
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.84it/s]
100%|██████████| 2/2 [00:00<00:00, 157.44it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.93it/s]


9
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.09it/s]
100%|██████████| 2/2 [00:00<00:00, 321.64it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.88it/s]


10
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.78it/s]
100%|██████████| 2/2 [00:00<00:00, 372.05it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.67it/s]


11
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.22it/s]
100%|██████████| 2/2 [00:00<00:00, 499.35it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.80it/s]


12
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.08it/s]
100%|██████████| 2/2 [00:00<00:00, 693.56it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  7.02it/s]


13
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.12it/s]
100%|██████████| 2/2 [00:00<00:00, 355.84it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.38it/s]


14
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.00it/s]
100%|██████████| 2/2 [00:00<00:00, 399.23it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.59it/s]


15
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.12it/s]
100%|██████████| 2/2 [00:00<00:00, 444.92it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.72it/s]


16
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 214.79it/s]
100%|██████████| 2/2 [00:00<00:00, 127.96it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.65it/s]


17
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.88it/s]
100%|██████████| 2/2 [00:00<00:00, 156.59it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.96it/s]


18
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.80it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.52it/s]


19
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.99it/s]
100%|██████████| 2/2 [00:00<00:00, 127.17it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.64it/s]


20
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.00it/s]
100%|██████████| 2/2 [00:00<00:00, 296.86it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.88it/s]


21
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.92it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.95it/s]


22
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  5.01it/s]
100%|██████████| 2/2 [00:00<00:00, 432.83it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.89it/s]


23
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.70it/s]
100%|██████████| 2/2 [00:00<00:00, 492.98it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.45it/s]


24
-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  3.96it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  6.84it/s]


In [52]:
import pandas as pd
import seaborn as sns
sns.set(font_scale=0.8)

# `errors` holds, per solver, the mean Pearson correlation obtained in each
# simulation of the evaluation loop -- so higher is better, despite the name.
df = pd.DataFrame(errors)

# Order the columns by median score so the boxes read from worst to best.
sorted_index = df.median().sort_values().index
df = df[sorted_index]

# Explicit figure/axes so the plot can be labelled and does not leave an
# `<AxesSubplot:>` repr as the cell output.
fig, ax = plt.subplots()
sns.boxplot(data=df, ax=ax)
ax.set(
    title="Inverse solver comparison on simulated data",
    xlabel="Solver",
    ylabel="Mean Pearson correlation per simulation",
)
# Solver names are long; rotate them so neighbouring labels do not overlap.
ax.tick_params(axis='x', rotation=90)
fig.tight_layout()

<AxesSubplot:>

In [None]:
from invert.solvers.multiple_sparse_priors import SolverMultipleSparsePriors
from invert.solvers.loreta import SolverLORETA, SolverSLORETA, SolverELORETA
from invert.solvers.wrop import SolverBackusGilbert, SolverLAURA
from invert.solvers.smap import SolverSMAP
from invert.solvers.minimum_norm_estimates import SolverDynamicStatisticalParametricMapping, SolverWeightedMinimumNorm, SolverMinimumNorm

# Solver classes to instantiate and plot one after another, using the
# `evoked` data currently in the namespace.
solvers = [
    SolverMultipleSparsePriors,
    SolverLORETA,
    SolverSLORETA,
    SolverELORETA,
    SolverBackusGilbert,
    SolverLAURA,
    SolverSMAP,
    SolverDynamicStatisticalParametricMapping,
    SolverWeightedMinimumNorm,
    SolverMinimumNorm,
]

for solver in solvers:
    solver_ = solver()
    # Only "Multiple Sparse Priors" is given the evoked data when its
    # inverse operator is built; every other solver gets the forward model only.
    operator_args = (fwd, evoked) if solver_.name == "Multiple Sparse Priors" else (fwd,)
    solver_.make_inverse_operator(*operator_args, alpha='auto')

    stc_hat = solver_.apply_inverse_operator(evoked)
    stc_hat.plot(**pp, brain_kwargs={'title': solver_.name})



In [None]:
from mne import make_ad_hoc_cov
from mne.minimum_norm import apply_inverse as mne_apply
from mne.minimum_norm import make_inverse_operator as mne_inverse

# Reference solution computed with MNE-Python's own minimum-norm functions,
# using an ad-hoc noise covariance derived from the evoked channel info.
noise_cov = make_ad_hoc_cov(evoked.info, verbose=0)

# Options for the MNE inverse operator, gathered in one place.
inverse_kwargs = dict(noise_cov=noise_cov, fixed=True, loose=0, depth=0, verbose=0)
mne_io = mne_inverse(evoked.info, fwd, **inverse_kwargs)

stc_hat = mne_apply(evoked, mne_io, method="MNE", verbose=0)
stc_hat.plot(**pp)