In [None]:
%load_ext autoreload
%autoreload 2

import sys; 
sys.path.insert(0, '../../esinet')
sys.path.insert(0, '../')

import numpy as np
from copy import deepcopy
from scipy.sparse.csgraph import laplacian
from matplotlib import pyplot as plt
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr
import mne
from esinet import Simulation
from esinet.forward import get_info, create_forward_model
from esinet.util import unpack_fwd
from invert.cmaps import parula
pp = dict(surface='white', hemi='both')

# Forward Model

In [None]:
info = get_info(kind='biosemi32')
fwd = create_forward_model(info=info, sampling='ico3')
adjacency = mne.spatial_src_adjacency(fwd['src'], verbose=0).toarray()
leadfield, pos = unpack_fwd(fwd)[1:3]
n_chans, n_dipoles = leadfield.shape
dist = cdist(pos, pos)

# Simulation

In [193]:
settings = dict(number_of_sources=5, extents=(1, 40), duration_of_trial=0.01, target_snr=99999999999)
# settings = dict(number_of_sources=1, extents=25, duration_of_trial=0.01, target_snr=99999)

sim = Simulation(fwd, info, settings).simulate(2)
stc = sim.source_data[0]
evoked = sim.eeg_data[0].average()
y = evoked.data[:, 0]
x = stc.data[:, 0]

brain = stc.plot(**pp)
brain.add_text(0.1, 0.9, 'Ground Truth', 'title',
               font_size=14)

Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 19.61it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 285.68it/s]

Using control points [2.34699485e-09 5.67297007e-09 5.62983532e-08]





For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


Using control points [5.61701841e-09 5.61701841e-09 2.22843899e-08]


# Defs

In [194]:
def soft_thresholding(r, lam):
    r = np.squeeze(np.array(r))
    C = np.sign(r) * np.clip(abs(r) - lam, a_min=0, a_max=None)
    return C

def zero_norm(x):
    return (x!=0).sum()

adjacency_non_diag = deepcopy(adjacency)
np.fill_diagonal(adjacency_non_diag, 0)
neigh_idc = [np.where(ad)[0] for ad in adjacency_non_diag]

def smoothness_l1(x, adjacency, lam_smooth=0.1):
    amplitudes_of_neighbors = get_amplitudes_of_neighbors(x, adjacency)
    term = np.sum( abs(x) - (amplitudes_of_neighbors * lam_smooth) )
    # print(abs(x).sum(), (amplitudes_of_neighbors * lam_smooth).sum())
    return term
def get_l1(x):
    return np.sum( abs(x) )

def get_amplitudes_of_neighbors(x, neigh_idc):
    # x_adj = (x[:, np.newaxis] @ x[np.newaxis, :] )* adjacency
    # np.fill_diagonal(x_adj, 0)
    # amplitudes_of_neighbors = np.sum(abs(x_adj), axis=0)
    amplitudes_of_neighbors = [np.mean(abs(x[idc])) for idc in neigh_idc]
    return np.array(amplitudes_of_neighbors)

def get_gradient_of_neighbors(x, neigh_idc):
    # x_adj = (x[:, np.newaxis] @ x[np.newaxis, :] )* adjacency
    # np.fill_diagonal(x_adj, 0)
    # amplitudes_of_neighbors = np.sum(abs(x_adj), axis=0)
    amplitudes_of_neighbors = [np.mean(abs(x[idc])) for idc in neigh_idc]
    return abs(x) - np.array(amplitudes_of_neighbors)**2



# ISTA

In [None]:
t_total = int(1e5)
A = deepcopy(leadfield)

n_chans, n_dipoles = A.shape
beta = 1 / np.sum(A**2)
lam = 1e-14
x_t = np.zeros(n_dipoles)
errors = []
A_H = np.matrix(A).getH()
# for t in range(t_total):
t = 0
while True:
    v_t = y - A @ x_t
    errors.append( np.linalg.norm(v_t) )
    if np.mod(t, 1000) == 0:
        print(f"iter {t} error {errors[-1]} maxval {abs(x_t).max()}")
    r = x_t + beta * A_H @ v_t
    x_tp = soft_thresholding(r, lam)
    if np.any(np.isnan(x_tp)):# or (t > 1 and errors [-1] > errors[-2]):
        break
    x_t = x_tp
    t += 1

## Evaluate

In [None]:
stc_ = stc.copy()
stc_.data[:, 0] = x_t / abs(x_t).max()
stc_.plot(**pp)

# FISTA

In [203]:
t_total = int(1e5)
A = deepcopy(leadfield)
n_chans, n_dipoles = A.shape
beta = 1 / np.sum(A**2)
lam = 1e-14
patience = 10000
x_t = np.zeros(n_dipoles)
x_t_prev = np.zeros(n_dipoles)
x_best = np.zeros(n_dipoles)
error_best = np.inf
errors = []
A_H = np.matrix(A).getH()
for t in range(t_total):
# t = 0
# while True:
    v_t = y - A @ x_t
    
    r = x_t + beta * A_H @ v_t + ((t-2)/(t+1)) * (x_t - x_t_prev)
    x_tplus = soft_thresholding(r, lam)
    
    x_t_prev = deepcopy(x_t)
    x_t = x_tplus
    error = np.sum((y - A @ x_t)**2) * 0.5 + lam * abs(x_t).sum()
    errors.append( error )

    if error < error_best:
        x_best = deepcopy(x_t)
        error_best = error
    if t>patience and  (np.any(np.isnan(x_tplus))  or np.all(np.array(errors[-patience:-1]) < errors[-1] )):
        break
    if np.mod(t, 1000) == 0:
        print(f"iter {t} error {errors[-1]} maxval {abs(x_t).max()}")

print(f"Finished after {t} iterations, error = {error_best}")
print("Neighbor Term: ", np.sum(get_gradient_of_neighbors(x_t, adjacency)*np.sign(x_t))*1e-5, "L1 Term: ", np.sum(x_best))

iter 0 error 2.7312522523848175e-10 maxval 1.0015689852552044e-10
iter 1000 error 7.041236428244176e-16 maxval 2.5818149650497366e-09
iter 2000 error 4.97396098571199e-16 maxval 3.92990996816069e-09
iter 3000 error 4.0816077341013986e-16 maxval 5.5387359214237704e-09
iter 4000 error 3.762934594791658e-16 maxval 6.761021383041815e-09
iter 5000 error 3.161366937276759e-16 maxval 8.44329553239271e-09
iter 6000 error 1.6653006954950299e-16 maxval 9.099291216199116e-09
iter 7000 error 4.742679402625187e-16 maxval 9.118952777984684e-09
iter 8000 error 2.643690319697674e-16 maxval 1.0341329437709912e-08
iter 9000 error 2.6758298782002387e-16 maxval 1.0870015870632981e-08
iter 10000 error 2.798162603386793e-16 maxval 1.199448229317194e-08
iter 11000 error 2.0519171315656232e-16 maxval 1.302577414538684e-08
iter 12000 error 2.988221839267419e-16 maxval 1.4056519565592392e-08
iter 13000 error 2.0539717333160852e-16 maxval 1.4347264621618977e-08
iter 14000 error 2.122112653697268e-16 maxval 1.464

## Evaluate

In [204]:
stc_ = stc.copy()
stc_.data[:, 0] = x_best / abs(x_best).max()
stc_.plot(**pp, brain_kwargs=dict(title="FISTA"))

Using control points [5.67297007e-09 1.86717457e-08 6.59507534e-01]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


<mne.viz._brain._brain.Brain at 0x2840dd17550>

Using control points [0.01421222 0.12895882 0.95355272]


# S-Fista

In [201]:
t_total = int(1e5)
A = deepcopy(leadfield)
n_chans, n_dipoles = A.shape
beta = 1 / np.sum(A**2)
lam = 1e-14
patience = 10000
x_t = np.zeros(n_dipoles)
x_t_prev = np.zeros(n_dipoles)
x_best = np.zeros(n_dipoles)
lam_gradient = 1e-5
error_best = np.inf
errors = []
A_H = np.matrix(A).getH()
for t in range(t_total):
# t = 0
# while True:
    v_t = y - A @ x_t
    r = x_t + beta * A_H @ v_t + ((t-2)/(t+1)) * (x_t - x_t_prev) + get_gradient_of_neighbors(x_t, adjacency) * lam_gradient # * np.sign(x_t)
    
    x_tplus = soft_thresholding(r, lam)
    
    x_t_prev = deepcopy(x_t)
    x_t = x_tplus
     
    error = np.sum((y - A @ x_t)**2) * 0.5 + lam * abs(x_t).sum() + np.sum(get_gradient_of_neighbors(x_t, adjacency) * lam_gradient )  # * np.sign(x_t)
    errors.append( error )

    if errors[-1] < error_best:
        x_best = deepcopy(x_t)
        error_best = errors[-1]
    if t>patience and  (np.any(np.isnan(x_tplus))  or np.all(np.array(errors[-patience:-1]) < errors[-1] )):
        break
    if np.mod(t, 100) == 0:
        print(f"iter {t} error {errors[-1]} maxval {abs(x_t).max()}")

print(f"Finished after {t} iterations, error = {error_best}")
print("Neighbor Term: ", np.sum(get_gradient_of_neighbors(x_t, adjacency)*np.sign(x_t))*lam_gradient, "L1 Term: ", np.sum(x_best))

iter 0 error 2.7359715425398443e-10 maxval 1.0015689852552044e-10
iter 100 error 2.6099725186117462e-12 maxval 1.165446975323222e-09
iter 200 error 2.4767326912724186e-12 maxval 1.4973652289947094e-09
iter 300 error 2.330893687803134e-12 maxval 1.71461236223066e-09
iter 400 error 2.225570947298678e-12 maxval 1.9070616029349935e-09
iter 500 error 2.1307568375809016e-12 maxval 2.087351690679838e-09
iter 600 error 2.0688192522689617e-12 maxval 2.3427175759392686e-09
iter 700 error 2.0304930966354808e-12 maxval 2.6104566417851464e-09
iter 800 error 2.0126248441725004e-12 maxval 2.899000675890662e-09
iter 900 error 2.005232806930119e-12 maxval 3.1877998581756076e-09
iter 1000 error 2.024993310893875e-12 maxval 3.4875256948271964e-09
iter 1100 error 2.033892838861355e-12 maxval 3.810262232880209e-09
iter 1200 error 2.061128752358467e-12 maxval 4.191389703510768e-09
iter 1300 error 2.1184281798814412e-12 maxval 4.668857709394873e-09
iter 1400 error 2.2099116193579173e-12 maxval 5.229221149644

# Evaluate

In [202]:
stc_ = stc.copy()
stc_.data[:, 0] = x_best / abs(x_best).max()
stc_.plot(**pp, brain_kwargs=dict(title="S-FISTA"))

Using control points [0.0048464  0.0318313  0.61397084]
For automatic theme detection, "darkdetect" has to be installed! You can install it with `pip install darkdetect`
To use light mode, "qdarkstyle" has to be installed! You can install it with `pip install qdarkstyle`


<mne.viz._brain._brain.Brain at 0x28409998640>

Using control points [0.31341884 0.40073647 0.86956696]


# MXNE

In [None]:
noise_cov = mne.make_ad_hoc_cov(evoked.info)
evoked.set_eeg_reference("average", projection=True)
stc_mxne = mne.inverse_sparse.mixed_norm(evoked, fwd, noise_cov, alpha='sure', 
        loose=0, depth=0.8, maxit=3000, tol=0.0001, active_set_size=10, 
        debias=True, time_pca=True, weights=None, weights_min=0.0, solver='auto', 
        n_mxne_iter=1, return_residual=False, return_as_dipoles=False, 
        dgap_freq=10, rank=None, pick_ori=None, sure_alpha_grid='auto', 
        random_state=None, verbose=None)
stc_mxne.data /= abs(stc_mxne.data).max()
stc_mxne.plot(**pp, brain_kwargs=dict(title="MXNE"))

In [None]:
stc_mxne.data.shape

In [None]:
from scipy.stats import pearsonr
corr = lambda x, y: pearsonr(x, y)[0]
print(corr(stc.data.flatten(), stc_mxne.data.flatten()))
print(corr(stc.data.flatten(), stc_.data.flatten()))

In [None]:
noise_cov = mne.make_ad_hoc_cov(evoked.info)
evoked.set_eeg_reference("average", projection=True)
stc_rap = mne.beamformer.rap_music(evoked, fwd, noise_cov, n_dipoles=5, return_residual=False, verbose=None)


In [None]:
mne.inverse_sparse.make_stc_from_dipoles(stc_rap, fwd["src"]).plot(**pp)

In [None]:
np.any(np.diff(fwd["src"][1]["vertno"])<0)

In [None]:
fwd