In [13]:
%load_ext autoreload
%autoreload 2

import sys
from copy import deepcopy

# Local checkouts must be on the path before esinet / invert are imported.
sys.path.insert(0, '../../esinet')
sys.path.insert(0, '../')

import mne
import numpy as np
from matplotlib import pyplot as plt
from scipy.sparse.csgraph import laplacian
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr

from esinet import Net, Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import get_eeg_from_source, unpack_fwd
from invert import (calc_eloreta_D, contextualize, contextualize_bd, focuss,
                    inverse_loreta, inverse_msp)
from invert.cmaps import parula
# Shared keyword arguments for every SourceEstimate.plot(...) call in this notebook.
pp = {'surface': 'white', 'hemi': 'both'}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Get Forward Model

In [2]:
# BioSemi-64 montage and an 'ico3'-sampled forward model.
info = get_info(kind='biosemi64')
fwd = create_forward_model(info=info, sampling='ico3')

# Only entries 1 and 2 of unpack_fwd's result are used: the leadfield and the dipole positions.
unpacked = unpack_fwd(fwd)
leadfield, pos = unpacked[1], unpacked[2]
n_chans, n_dipoles = leadfield.shape

# Pairwise Euclidean distances between all dipole positions.
dist = cdist(pos, pos)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    1.4s remaining:    1.4s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    1.6s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    1.6s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.2s remaining:    0.2s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.2s finished
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:    0.3s remaining:    0.3s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:    0.3s finished


# Simulate sample data

In [21]:
# Alternative: a single focal, essentially noiseless source.
# settings = dict(number_of_sources=1, extents=40, duration_of_trial=0.01, target_snr=99999999999)
settings = dict(number_of_sources=4, extents=(1, 40), duration_of_trial=1, target_snr=999)

# NOTE(review): assumes esinet draws from NumPy's global RNG -- confirm; otherwise this seed has no effect.
np.random.seed(42)

# Defines `sim`, `stc`, `evoked` and `M`; every later cell depends on them, so they must
# be created here rather than survive as leftover kernel state.
sim = Simulation(fwd, info, settings).simulate(2)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()
M = evoked.data

brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
               font_size=14)

-- number of adjacent vertices : 1284
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00,  4.01it/s]
100%|██████████| 2/2 [00:00<00:00, 334.13it/s]


source data shape:  (1284, 1000) (1284, 1000)


100%|██████████| 2/2 [00:00<00:00,  8.48it/s]

Using control points [9.21744162e-10 5.56706075e-09 6.78126172e-08]





# Minimum Norm Estimate

In [None]:
alpha = 0.001
# Tikhonov-regularised minimum-norm estimate. The two forms are algebraically identical
# ((L^T L + aI)^-1 L^T == L^T (L L^T + aI)^-1), so use the one with the smaller system,
# and solve it directly instead of forming an explicit inverse.
if n_chans > n_dipoles:
    D_MNE = np.linalg.solve(leadfield.T @ leadfield + alpha * np.identity(n_dipoles),
                            leadfield.T @ M)
else:
    D_MNE = leadfield.T @ np.linalg.solve(leadfield @ leadfield.T + alpha * np.identity(n_chans), M)

stc_hat = stc.copy()
stc_hat.data = D_MNE
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'MNE', 'title',
               font_size=14)

# Weighted Minimum Norm Estimate

In [23]:
alpha = 0.001
# Depth weighting W = diag(||l_i||) over leadfield columns. (W^T W)^-1 is diagonal, so it is
# applied as a per-column scale of the leadfield rather than by inverting a dense matrix twice.
column_norms = np.linalg.norm(leadfield, axis=0)
weighted_leadfield = leadfield / column_norms**2  # == leadfield @ inv(W.T @ W)

# D = inv(W^T W) L^T (L inv(W^T W) L^T + alpha I)^-1 M, using a linear solve.
gram = weighted_leadfield @ leadfield.T + alpha * np.identity(n_chans)
D_WMNE = weighted_leadfield.T @ np.linalg.solve(gram, M)

stc_hat = stc.copy()
stc_hat.data = D_WMNE
r = np.median([pearsonr(abs(s_pred), abs(s_true))[0] for s_pred, s_true in zip(stc_hat.data.T, stc.data.T)])
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'wMNE (r={r:.2f})', 'title',
               font_size=14)

Using control points [5.32389372e-09 7.85593124e-09 3.48273411e-08]


# MNE with FOCUSS (Focal underdetermined system solution)

In [17]:
alpha = 0.001

# MNE estimate (same regularisation as the MNE cell), passed to `focuss` as its first argument --
# presumably its starting estimate.
D_MNE = leadfield.T @ np.linalg.solve(leadfield @ leadfield.T + alpha * np.identity(n_chans), M)
# NOTE(review): this call was interrupted (KeyboardInterrupt) in the saved run; it may be slow.
D_FOCUSS = focuss(D_MNE, M, leadfield, alpha)

stc_hat = stc.copy()
stc_hat.data = D_FOCUSS
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'MNE FOCUSS', 'title',
               font_size=14)

KeyboardInterrupt: 

# Contextual Minimum Norm

In [24]:
# Instantaneous LORETA estimate, then refined by `contextualize_bd` (the saved output shows it
# training a small LSTM model for `num_epochs` epochs). The ground truth `stc` comes from the
# sample-data cell; it is not re-bound here.
stc_instant = inverse_loreta(M, leadfield, fwd)
stc_cmne = contextualize_bd(stc_instant, fwd, num_epochs=20)

stc_hat = stc.copy()
stc_hat.data = stc_cmne
r = np.median([pearsonr(abs(s_pred), abs(s_true))[0] for s_pred, s_true in zip(stc_hat.data.T, stc.data.T)])
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'stc_cmne (r={r:.2f})', 'title',
               font_size=14)

-- number of adjacent vertices : 1284
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
 lstm (LSTM)                 (None, 128)               723456    
                                                                 
 dense (Dense)               (None, 1284)              165636    
                                                                 
Total params: 889,092
Trainable params: 889,092
Non-trainable params: 0
_________________________________________________________________
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Forward Steps:
Time Step 0/920
Time Step 1/920
Time Step 2/920
Time Step 3/920
Time Step 4/920
Time Step 5/920
Time Step 6/920
Time Step 7/920
Time Step 8/920
Time Step 9/920
Ti

# dSPM

In [None]:
alpha = 0.001

# Toy noise covariance: identity plus a small random perturbation. It is symmetrised
# (a covariance matrix must be symmetric) and seeded so the cell is reproducible.
rng = np.random.default_rng(0)
perturbation = rng.random((n_chans, n_chans)) * 0.1
noise_cov = np.identity(n_chans) + (perturbation + perturbation.T) / 2
source_cov = np.identity(n_dipoles)

# Whitening operator C^{-1/2} from the eigendecomposition of the symmetric noise covariance.
# An element-wise 1/np.sqrt(noise_cov) is NOT the matrix inverse square root.
eigvals, eigvecs = np.linalg.eigh(noise_cov)
assert eigvals.min() > 0, "noise covariance must be positive definite"
whitener = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T
M_norm = whitener @ M
G_norm = whitener @ leadfield

K = source_cov @ G_norm.T @ np.linalg.inv(G_norm @ source_cov @ G_norm.T + alpha**2 * np.identity(n_chans))
# After whitening the noise covariance is the identity, so the dSPM noise normaliser is
# 1 / sqrt(diag(K K^T)); scale the rows of K directly instead of building a dense diagonal matrix.
noise_norm = np.sqrt(np.einsum('ij,ij->i', K, K))
K_dSPM = K / noise_norm[:, np.newaxis]
D_dSPM = K_dSPM @ M_norm

# Normalise each time sample (column) by the mean and std of its absolute values.
abs_D = np.abs(D_dSPM)
D_dSPM_norm = (D_dSPM - abs_D.mean(axis=0)) / abs_D.std(axis=0)

stc_hat = stc.copy()
stc_hat.data = D_dSPM_norm
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'dSPM', 'title',
               font_size=14)

# LORETA

In [25]:
alpha = 0.001
adjacency = mne.spatial_src_adjacency(fwd['src']).toarray()
# Depth weighting B = diag(||l_i||) combined with the graph Laplacian of the source-space adjacency.
B = np.diag(np.linalg.norm(leadfield, axis=0))
laplace_operator = laplacian(adjacency)

# Regularised normal equations (L^T L + alpha * B Lap^T Lap B) J = L^T M,
# solved directly rather than by inverting an n_dipoles x n_dipoles matrix.
smoothness = B @ laplace_operator.T @ laplace_operator @ B
D_LOR = np.linalg.solve(leadfield.T @ leadfield + alpha * smoothness, leadfield.T @ M)

stc_hat = stc.copy()
stc_hat.data = D_LOR
r = np.median([pearsonr(a, b)[0] for a, b in zip(stc.data.T, stc_hat.data.T)])
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'LORETA (r={r:.2f})', 'title',
               font_size=14)

-- number of adjacent vertices : 1284
Using control points [1.98194068e-08 2.37039756e-08 5.56828713e-08]


# sLORETA

In [None]:
alpha = 0.001
K_MNE = leadfield.T @ np.linalg.inv(leadfield @ leadfield.T + alpha * np.identity(n_chans))

# sLORETA standardises each row of the MNE kernel by 1/sqrt(diag(R)), with R = K_MNE @ leadfield.
# Only diag(R) is needed, so compute it without forming the n_dipoles x n_dipoles product.
resolution_diag = np.einsum('ij,ji->i', K_MNE, leadfield)

# Dipoles whose resolution diagonal is not positive cannot be standardised; give them zero
# weight explicitly instead of letting inf/NaN propagate into the estimate.
positive = resolution_diag > 0
weights = np.zeros_like(resolution_diag)
weights[positive] = 1 / np.sqrt(resolution_diag[positive])

K_slor = weights[:, np.newaxis] * K_MNE
D_SLOR = K_slor @ M

stc_hat = stc.copy()
stc_hat.data = D_SLOR
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'sLORETA', 'title',
               font_size=14)

# eLORETA

In [None]:
stop_crit = 0.005
alpha = 0.001

# eLORETA weight matrix D from `calc_eloreta_D` (stop_crit presumably bounds its iterations);
# the second return value is not needed here.
D = calc_eloreta_D(leadfield, alpha, stop_crit=stop_crit)[0]
# Invert D once and reuse it in both places of the weighted minimum-norm kernel.
D_inv = np.linalg.inv(D)
K_elor = D_inv @ leadfield.T @ np.linalg.inv(leadfield @ D_inv @ leadfield.T + alpha * np.identity(n_chans))
D_ELOR = K_elor @ M

stc_hat = stc.copy()
stc_hat.data = D_ELOR
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'eLORETA', 'title',
               font_size=14)


# LAURA

In [None]:
alpha = 200
drop_off = 2

# Neighbour mask: mesh-adjacent dipole pairs, excluding self-pairs and zero distances
# (the latter would make d**-drop_off infinite).
adj = mne.spatial_src_adjacency(fwd["src"], verbose=0).toarray()
d = cdist(pos, pos)
neighbors = adj.astype(bool) & (d > 0)
np.fill_diagonal(neighbors, False)

# Local weights: -d^-drop_off for neighbours, 0 everywhere else (including the diagonal).
A = np.zeros_like(d)
A[neighbors] = -d[neighbors] ** -drop_off
M_j = A  # the identity weighting W = I leaves A unchanged

# Source-space metric W_j = inv(M_j^T M_j). The estimator only needs W_j^-1, which is
# M_j^T M_j itself -- no need to invert twice.
# NOTE(review): confirm against the LAURA reference whether the estimator should use
# M_j^T M_j or its inverse here.
W_j_inv = M_j.T @ M_j

# Data-space metric W_d = I, so the noise term alpha**2 * inv(W_d) is alpha**2 * I.
noise_term = (alpha**2) * np.identity(n_chans)
G = W_j_inv @ leadfield.T @ np.linalg.inv(leadfield @ W_j_inv @ leadfield.T + noise_term)
D_LAURA = G @ M

# Fresh copy, so this cell does not mutate whatever `stc_hat` a previous cell left behind.
stc_hat = stc.copy()
stc_hat.data = D_LAURA
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'LAURA', 'title',
               font_size=14)

# VARETA

In [None]:
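# TODO(VARETA): not implemented yet. Planned alternating scheme:
# 1) With the source estimate J (and the current VARETA estimate) fixed, solve for the weights A.
# 2) With A fixed, solve for J.
# 3) Repeat until convergence.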

In [None]:
 
# alpha = 0.001
# adjacency = mne.spatial_src_adjacency(fwd['src'], verbose=0).toarray()
# B = np.diag(np.linalg.norm(leadfield, axis=0))
# L = laplacian(adjacency)  # non-singular univariate discrete laplacian -> is that correct?
# D_LOR = np.linalg.inv(leadfield.T @ leadfield + alpha * B @ L.T @ L @ B) @ leadfield.T @ M
# D_Last = deepcopy(D_LOR)[:, 0][:, np.newaxis]
# W = np.diag(np.linalg.norm(leadfield, axis=0))  # WMNE weight matrix (depth weighting)
# tau = 1  # controls smoothness
# alpha_2 = 1  # controls importance of grid point
# # for t in range(M.shape[1]):
# t = 0
# A = np.identity(n_dipoles)  # in paper: large lambda
    
# # for _ in range(10):
# term_1 = np.linalg.norm( M[:, t] - leadfield @ D_Last )
# term_2 = np.linalg.norm( A@L * W * D_Last )
# term_3 = tau**2 * np.linalg.norm( L * np.diag(np.log(np.diagonal(A))) - alpha_2 )


# # D_VAR = term_1 + term_2 + term_3

# from scipy.optimize import minimize
# def find_lambda(A, m, leadfield, D_Last, L, W, tau, alpha_2):
#     A = A.reshape(leadfield.shape[1], leadfield.shape[1])
#     term_1 = np.linalg.norm( M[:, t] - leadfield @ D_Last )
#     term_2 = np.linalg.norm( A@L * W * D_Last )
#     term_3 = tau**2 * np.linalg.norm( L * np.diag(np.log(np.diagonal(A))) - alpha_2 )
#     return term_1 + term_2 + term_3
# # find_lambda(A, M[:, t], leadfield, D_Last, L, W, tau, alpha_2)
# minimize(find_lambda, A.flatten(), args=(M[:, t], leadfield, D_Last, L, W, tau, alpha_2), method='L-BFGS-B')

# S-MAP

In [None]:
alpha = 0.001
# Depth-weighting matrix, as in the LORETA cell (LORETA's D_LOR is not recomputed here).
B = np.diag(np.linalg.norm(leadfield, axis=0))

# NOTE(review): np.gradient(B)[0] is the finite difference of the diagonal weight matrix along
# its rows, not a spatial gradient over the source mesh -- confirm this is the intended operator.
gradient = np.gradient(B)[0]
# Regularised normal equations (L^T L + alpha * grad^T grad) J = L^T M, solved directly.
D_SMAP = np.linalg.solve(leadfield.T @ leadfield + alpha * gradient.T @ gradient, leadfield.T @ M)

stc_hat = stc.copy()
stc_hat.data = D_SMAP
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'S-MAP', 'title',
               font_size=14)

# ESINET

In [None]:
# Train an esinet model without LSTM layers on freshly simulated single-sample data, then predict
# on the first sample-data trial. These lines must run: D_ESINET is used here and in the LUCAS cell.
# NOTE: training on N_TRAIN simulated samples is the expensive part of this cell.
N_TRAIN = 5000
sim_train = Simulation(fwd, info, settings=dict(duration_of_trial=0)).simulate(N_TRAIN)
net = Net(fwd, n_lstm_layers=0, activation_function='tanh').fit(sim_train)
D_ESINET = net.predict(sim.eeg_data[0])[0].data

stc_hat = stc.copy()
stc_hat.data = D_ESINET
# Median over time samples, consistent with the other cells.
r = np.median([pearsonr(a, b)[0] for a, b in zip(stc_hat.data.T, stc.data.T)])
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'ESINET, r={r:.3f}', 'title',
               font_size=14)

# Temporally contextualised ESINET estimate, shown on its own copy of the ground truth.
D_cESINET = contextualize(D_ESINET, fwd, num_epochs=15)
stc_hat = stc.copy()
stc_hat.data = D_cESINET
r = np.median([pearsonr(a, b)[0] for a, b in zip(stc_hat.data.T, stc.data.T)])
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'cESINET, r={r:.3f}', 'title',
               font_size=14)

# Backus-Gilbert

In [None]:
# Backus-Gilbert: for each dipole gamma, the sensor weights t minimise t^T E_gamma t subject to
# l^T t = 1, where E_gamma = L diag(dist[gamma]) L^T + L L^T and l is the leadfield summed over
# dipoles. The closed-form solution is t = E^+ l / (l^T E^+ l).
dist = cdist(pos, pos)

l_sum = leadfield.sum(axis=1)  # == leadfield @ ones(n_dipoles)
F = leadfield @ leadfield.T

# Build each row on the fly instead of materialising a dense diag(dist[gamma]) per dipole:
# for 1284 dipoles those matrices alone would need ~17 GB and O(n^3) work each.
T_final = np.empty((n_dipoles, n_chans))
for gamma in range(n_dipoles):
    # leadfield @ diag(w) @ leadfield.T == (leadfield * w) @ leadfield.T
    E_gamma = (leadfield * dist[gamma]) @ leadfield.T + F
    E_pinv_l = np.linalg.pinv(E_gamma) @ l_sum
    T_final[gamma] = E_pinv_l / (l_sum @ E_pinv_l)
D_BG = T_final @ M

stc_hat = stc.copy()
stc_hat.data = D_BG
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'Backus-Gilbert', 'title',
               font_size=14)

# Multiple Sparse Priors

In [26]:
# Multiple Sparse Priors estimate from the sample-data evoked response.
D_MSP = inverse_msp(evoked, fwd)

stc_hat = stc.copy()
stc_hat.data = D_MSP
# Median Pearson correlation between true and estimated source maps over time samples.
correlations = [pearsonr(true_t, est_t)[0] for true_t, est_t in zip(stc.data.T, stc_hat.data.T)]
r = np.median(correlations)
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, f'MSP (r={r:.2f})', 'title', font_size=14)

Using 3 temporal mode(s)
dedh.shape:  (3181, 193) v[0].shape:  (60,) v[1].shape:  (14,)
Iteration 1. Free Energy Improvement: 9.29
Iteration 2. Free Energy Improvement: 6.55
Iteration 3. Free Energy Improvement: 6.70
Iteration 4. Free Energy Improvement: 8.64
Iteration 5. Free Energy Improvement: 5.40
Iteration 6. Free Energy Improvement: 4.56
Iteration 7. Free Energy Improvement: 6.72
Iteration 8. Free Energy Improvement: 6.47
Iteration 9. Free Energy Improvement: 3.85
Iteration 10. Free Energy Improvement: 6.35
Iteration 11. Free Energy Improvement: 3.85
Iteration 12. Free Energy Improvement: 3.17
Iteration 13. Free Energy Improvement: 7.54
Iteration 14. Free Energy Improvement: 2.92
Iteration 15. Free Energy Improvement: 2.71
Iteration 16. Free Energy Improvement: 2.51
Iteration 17. Free Energy Improvement: 2.46
Iteration 18. Free Energy Improvement: 2.75
Iteration 19. Free Energy Improvement: 2.59
Iteration 20. Free Energy Improvement: 3.02
Iteration 21. Free Energy Improvement: 68

Using control points [2.97799097e-09 4.61287339e-09 1.37087461e-08]
Using control points [2.74908186e-09 2.74908186e-09 2.18906066e-08]
Using control points [0.16697388 0.26854478 0.86360814]
Using control points [9.37477380e-09 1.19848137e-08 2.15566455e-08]
Using control points [1.37857776e-09 1.80507212e-09 7.01081296e-09]
Using control points [9.37477380e-09 1.19848137e-08 2.15566455e-08]
Using control points [9.37477380e-09 1.19848137e-08 2.15566455e-08]
Using control points [3.53918266e-09 3.53918266e-09 4.86446529e-08]
Using control points [6.78224105e-09 1.02564527e-08 3.04848123e-08]
Using control points [0.17097981 0.25580586 0.86678645]


  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [2.0347094e-08 2.6747435e-08 4.7864829e-08]
Using control points [3.80365394e-09 5.44439345e-09 1.38161964e-08]
Using control points [0.17097981 0.25580586 0.86678645]
Using control points [0.17097981 0.25580586 0.86678645]
Using control points [2.0347094e-08 2.6747435e-08 4.7864829e-08]


  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [3.80365394e-09 5.44439345e-09 1.38161964e-08]
Using control points [0.17097981 0.25580586 0.86678645]
Using control points [6.78224105e-09 1.02564527e-08 3.04848123e-08]
Using control points [3.80365394e-09 5.44439345e-09 1.38161964e-08]


  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "c:\Users\Lukas\Envs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: '

Using control points [3.04296414e-10 3.04296414e-10 7.01263586e-09]
Using control points [1.49563095e-09 6.75969277e-09 6.75969277e-09]
Using control points [1.63099024e-09 1.94888443e-09 4.11116009e-09]
Using control points [0.08044504 0.09850412 0.26764728]
Using control points [1.63099024e-09 1.94888443e-09 4.11116009e-09]
Using control points [0.08044504 0.09850412 0.26764728]
Using control points [3.52954138e-09 4.94428397e-09 1.40347741e-08]
Using control points [0.08044504 0.09850412 0.26764728]
Using control points [3.60413316e-09 4.73230536e-09 1.11437420e-08]
Using control points [3.60413316e-09 4.73230536e-09 1.11437420e-08]
Using control points [3.52954138e-09 4.94428397e-09 1.40347741e-08]
Using control points [0.08044504 0.09850412 0.26764728]



# LUCAS (ensemble average of all inverse solutions)

In [None]:
# "LUCAS": element-wise average of the inverse solutions computed above.
# NOTE(review): the estimates live on different scales (e.g. sLORETA is standardised, MNE is not),
# so the plain mean is dominated by the largest-magnitude methods -- consider normalising each first.
D_LUCAS = np.mean([D_MNE, D_WMNE, D_LOR, D_SLOR, D_ELOR, D_LAURA, D_SMAP, D_ESINET, D_BG, D_MSP], axis=0)

# Fresh copies, so neither plot mutates an `stc_hat` left over from an earlier cell.
stc_hat = stc.copy()
stc_hat.data = D_LUCAS
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'LUCAS', 'title',
               font_size=14)

stc_hat = stc.copy()
stc_hat.data = contextualize(D_LUCAS, fwd)
brain = stc_hat.plot(**pp)
brain.add_text(0.1, 0.9, 'cLUCAS', 'title',
               font_size=14)

# Saved for later: compare sensor-space topographies of the estimate and the data

In [None]:
# Map the most recent estimate (`stc_hat`) back to sensor space with `get_eeg_from_source`
# (presumably via the forward model) and plot its topographies.
# NOTE(review): `stc_hat` is whatever the last executed estimate cell produced.
evoked_hat = get_eeg_from_source(stc_hat, fwd, info, tmin=stc.tmin)
evoked_hat.plot_topomap()

In [None]:
# Topographies of the sample-data evoked response, for comparison with `evoked_hat`.
evoked.plot_topomap()