In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt
import sys; sys.path.insert(0, '../')
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
pp = dict(surface='white', hemi='both')

# Get Forward Model

In [9]:
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling='ico4')

leadfield, pos = unpack_fwd(fwd)[1:3]
n_chans, n_dipoles = leadfield.shape

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.2s remaining:    0.4s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.2s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.2s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.2s remaining:    0.5s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.3s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.3s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.4s remaining:    0.7s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.6s remaining:    0.3s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.6s finished


# Get sample data

In [46]:
# settings = dict(number_of_sources=1, extents=40, duration_of_trial=0.01, target_snr=99999999999)
settings = dict(number_of_sources=3, extents=(1, 40), duration_of_trial=1, target_snr=2)

sim = Simulation(fwd, info, settings).simulate(2)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()
M = evoked.data

brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
               font_size=14)

-- number of adjacent vertices : 5124
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:03<00:00,  1.97s/it]
100%|██████████| 2/2 [00:00<00:00, 62.47it/s]


source data shape:  (5124, 1000) (5124, 1000)


100%|██████████| 2/2 [00:00<00:00, 19.42it/s]

Using control points [5.19182089e-10 1.74571140e-09 5.22952957e-08]





For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


# Regularization Optimizations

In [60]:
from invert import make_inverse_operator, apply_inverse_operator
from invert.evaluate import nmse, corr
from tqdm.notebook import tqdm
solver = "lor"
alpha_heuristic = 1/settings["target_snr"]**2
inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, alpha=alpha_heuristic, verbose=0)
stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
error_heuristic = np.median(corr(stc.data, stc_hat.data))

alphas = np.linspace(1.5e3, 2.5e3, 20)
errors = []
for alpha in tqdm(alphas):
    inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, alpha=alpha, verbose=0)
    stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
    error = np.median(corr(stc.data, stc_hat.data))
    errors.append(error)

plt.figure()
plt.plot(alphas, errors)
plt.xlabel("alphas")
plt.ylabel("corr")
plt.ylim(0,1)
plt.plot([alpha_heuristic, alpha_heuristic], plt.ylim(), 'r')




  0%|          | 0/20 [00:00<?, ?it/s]

In [54]:
solver = "dSPM"
# alpha_heuristic = 1/settings["target_snr"]**2
alpha = 2000
inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, alpha=alpha, verbose=0)
stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
stc_hat.plot(**pp)

Using control points [6.52956037e-05 7.85430437e-05 1.65699895e-04]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


<mne.viz._brain._brain.Brain at 0x2480de520d0>

Using control points [4.54311450e-05 5.30130690e-05 7.36848788e-05]
Using control points [4.54311450e-05 5.30130690e-05 7.36848788e-05]


# Tests

In [None]:
from invert import make_inverse_operator, apply_inverse_operator
from invert.adapters import contextualize_bd
from invert.evaluate import nmse

solver = "lstm"
# inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, verbose=1)

stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
error = np.median(nmse(stc.data, stc_hat.data))
print(f"error={error:.4f}")
stc_hat.plot(**pp, brain_kwargs=dict(title=solver))

# stc_hat.data = contextualize_bd(stc_hat.data, leadfield, num_epochs=20)
# error = np.median(nmse(stc.data, stc_hat.data))
# print(f"error={error:.4f}")
# stc_hat.plot(**pp, brain_kwargs=dict(title=solver + " contextualized"))




  warn("Method 'bounded' does not support relative tolerance in x; "


error=0.0260
Using control points [2.69993792e-09 3.79865658e-09 1.90327871e-08]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


<mne.viz._brain._brain.Brain at 0x1975446e2b0>

Using control points [1.33156213e-09 2.41506867e-09 3.25942281e-08]
Using control points [2.45756571e-09 3.58205804e-09 7.57540420e-09]


In [None]:
from invert import all_solvers, make_inverse_operator, apply_inverse_operator
from invert.adapters import contextualize_bd
from invert.evaluate import nmse, corr
stc.plot(**pp, brain_kwargs=dict(title="Ground Truth"))
errors = dict()
for solver in all_solvers:
    print(solver)
    inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, alpha=1/4.5)
    stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
    # stc_hat.plot(**pp, brain_kwargs=dict(title=solver))
    errors[solver] = corr(stc.data, stc_hat.data)
    
solver = "esinet"
stc_hat = net.predict(evoked)[0]
# stc_hat.plot(**pp, brain_kwargs=dict(title=solver))
errors[solver] = corr(stc.data, stc_hat.data)

solver += " contextualized"
stc_hat.data = contextualize_bd(stc_hat.data, leadfield)
errors[solver] = corr(stc.data, stc_hat.data)
# stc_hat.plot(**pp, brain_kwargs=dict(title=solver))

In [None]:
solver = "LORETA"
inverse_operator = make_inverse_operator(fwd, solver=solver, evoked=evoked, alpha=1/4.5)
stc_hat = apply_inverse_operator(evoked, inverse_operator, fwd)
stc_hat.plot(**pp, brain_kwargs=dict(title=solver))

stc_hat.data = contextualize_bd(stc_hat.data, leadfield, steps_per_ep=25, verbose=1)
errors[solver] = corr(stc.data, stc_hat.data)
stc_hat.plot(**pp, brain_kwargs=dict(title=solver))

In [None]:
import seaborn as sns
import pandas as pd
%matplotlib qt
f, ax = plt.subplots(figsize=(7, 6))
# ax.set_yscale("log")
sns.boxplot(data=pd.DataFrame(errors))