In [1]:
%matplotlib qt
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../invert')
import mne
import pickle as pkl

from time import time

from invert.forward import get_info, create_forward_model
from invert.solvers.esinet import generator
from invert import Solver

# Load Files

In [4]:
fwd = mne.read_forward_solution("forward_model/64ch_ico3-fwd.fif", verbose=0)
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
fn = "forward_model/64ch_info.pkl"
with open(fn, 'rb') as f:
    info = pkl.load(f)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


# Simulate Data

## Settings

In [5]:
from copy import deepcopy
generator_args_single = dict(
    use_cov = False, 
    return_mask = False, 
    n_sources = 10, 
    n_orders = (0,0),
    snr_range = (0.1, 100),
    amplitude_range = (0.001, 1),
    batch_repetitions = 1,
    batch_size = 500,
    scale_data = False,
    n_timepoints = 20,
    beta_range = (0, 3),
    return_info = True)

generator_args_extended = deepcopy(generator_args_single)
generator_args_extended["n_orders"] = (1, 3)

generator_args = dict(single=generator_args_single, extended=generator_args_extended)
solver_names = ["FLEX-MUSIC", "TRAP-MUSIC", "eLORETA", "Convexity Champagne", "MCMV"]
recompute_list = ["FLEX-MUSIC", "TRAP-MUSIC", "Convexity Champagne", "MCMV"]

In [6]:
# Generate Samples
solvers = {solver_name: Solver(solver_name) for solver_name in solver_names}
for sim_type, generator_args_ in generator_args.items():
    proc_time_make = dict()
    proc_time_apply = dict()
    gen_test = generator(fwd, **generator_args_)
    x_test, y_test, sim_info = gen_test.__next__()

    stc_dict = dict()

    for i_sample, (x_sample, y_sample) in enumerate(zip(x_test, y_test)):
        print("Sample ", i_sample)
        evoked = mne.EvokedArray(x_sample.T, info, tmin=0)
        for solver_name, solver in solvers.items():
            if i_sample == 0:
                stc_dict[solver_name] = []
                proc_time_make[solver_name] = []
                proc_time_apply[solver_name] = []
                
            if solver_name in recompute_list or i_sample == 0:
                start_make = time()
                solver.make_inverse_operator(fwd, evoked, alpha="auto")
                end_make = time()
            
            start_apply = time()
            stc = solver.apply_inverse_operator(evoked)
            end_apply = time()
            
            stc_dict[solver_name].append(stc)
            proc_time_make[solver_name].append(end_make - start_make)
            proc_time_apply[solver_name].append(end_apply - start_apply)
            
    fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump([stc_dict, x_test, y_test, sim_info, proc_time_make, proc_time_apply], f)

Sample  0
Sample  1
Sample  2
Sample  3
Sample  4
Sample  5
Sample  6
Sample  7
Sample  8
Sample  9
Sample  10
Sample  11
Sample  12
Sample  13
Sample  14
Sample  15
Sample  16
Sample  17
Sample  18
Sample  19
Sample  20
Sample  21
Sample  22
Sample  23
Sample  24
Sample  25
Sample  26
Sample  27
Sample  28
Sample  29
Sample  30
Sample  31
Sample  32
Sample  33
Sample  34
Sample  35
Sample  36
Sample  37
Sample  38
Sample  39
Sample  40
Sample  41
Sample  42
Sample  43
Sample  44
Sample  45
Sample  46
Sample  47
Sample  48
Sample  49
Sample  50
Sample  51
Sample  52
Sample  53
Sample  54
Sample  55
Sample  56
Sample  57
Sample  58
Sample  59
Sample  60
Sample  61
Sample  62
Sample  63
Sample  64
Sample  65
Sample  66
Sample  67
Sample  68
Sample  69
Sample  70
Sample  71
Sample  72
Sample  73
Sample  74
Sample  75
Sample  76
Sample  77
Sample  78
Sample  79
Sample  80
Sample  81
Sample  82
Sample  83
Sample  84
Sample  85
Sample  86
Sample  87
Sample  88
Sample  89
Sample  90
Sample  9

In [10]:
solvers.keys()

dict_keys(['FLEX-MUSIC', 'TRAP-MUSIC', 'eLORETA', 'Convexity Champagne', 'MCMV', 'LSTM'])

In [14]:
[solver.save(f"saved_solvers/{solver.name}") for solver in solvers.values()]



INFO:tensorflow:Assets written to: saved_solvers/LSTM\LSTM_0\assets


INFO:tensorflow:Assets written to: saved_solvers/LSTM\LSTM_0\assets


[<invert.solvers.music.SolverFLEXMUSIC at 0x188f12eb580>,
 <invert.solvers.music.SolverTRAPMUSIC at 0x188f12eb5b0>,
 <invert.solvers.loreta.SolverELORETA at 0x188f12eb5e0>,
 <invert.solvers.empirical_bayes.SolverConvexityChampagne at 0x188f12eb610>,
 <invert.solvers.beamformer.SolverMCMV at 0x188f12eb640>,
 <invert.solvers.esinet.SolverLSTM at 0x188f12eb670>]

# Plot samples

In [22]:
pp = dict(surface='white', hemi='both', verbose=0)
sample = 2

stc_ = stc.copy()
stc_.data = y_test[sample].T
stc_.plot(**pp, brain_kwargs=dict(title="True"))

for solver, stc_list in stc_dict.items():
    stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver))
display(sim_info.iloc[sample])

n_sources                                            2
amplitudes    [0.907917529887389, 0.26380035363660265]
snr                                          16.147458
Name: 2, dtype: object

Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [9.48860114e-05 2.84658034e-04 2.11078202e-02]
Using control points [9.48860114e-05 2.84658034e-04 2.11078202e-02]
Using control points [0.00000000e+00 7.78693932e-06 1.87411353e-02]
Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [0.00012208 0.00099767 0.0134913 ]
Using control points [8.04834742e-05 2.33942621e-04 1.30086290e-02]
Using control points [0.00000000e+00 1.82359132e-05 1.06895868e-02]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00012208 0.00099767 0.0134913 ]
Using control points [0.00054711 0.0006143  0.00121755]
Using control points [0.00054711 0.0006143  0.00121755]
Using control points [0.00019641 0.00021089 0.00029645]
Using co