In [1]:
%matplotlib qt
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../invert')
import mne
import pickle as pkl
from time import time

from invert.forward import get_info, create_forward_model
from invert.solvers.esinet import generator
from invert import Solver

# Load Files

In [2]:
# Read the forward solution from disk and convert it to a fixed-orientation
# solution in a single expression.
fwd = mne.convert_forward_solution(
    mne.read_forward_solution("forward_model/64ch_ico3-fwd.fif", verbose=0),
    force_fixed=True,
)

# `info` is stored as a pickle next to the forward model; it is later handed to
# mne.EvokedArray together with the simulated data.
# NOTE(review): pickle can execute arbitrary code on load - only open trusted files.
fn = "forward_model/64ch_info.pkl"
with open(fn, mode="rb") as fh:
    info = pkl.load(fh)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


# Simulate Data

## Settings

In [17]:
def _make_generator_args(n_orders, snr_range):
    """Return the keyword arguments for one simulation scenario.

    The scenarios differ only in `n_orders` and `snr_range`; every other
    setting is shared. The resulting dict is unpacked into `generator(...)`.
    """
    return dict(
        use_cov=False,
        return_mask=False,
        n_sources=10,
        n_orders=n_orders,
        snr_range=snr_range,
        amplitude_range=(0.001, 1),
        batch_repetitions=1,
        batch_size=1000,
        scale_data=False,
        n_timepoints=20,
        beta_range=(0, 3),
        return_info=True,
    )


generator_args_single = _make_generator_args(n_orders=(0, 0), snr_range=(0.1, 100))
generator_args_extended = _make_generator_args(n_orders=(1, 4), snr_range=(0.1, 20))

# Simulation type -> generator settings; iterated in this order (extended first).
generator_args = {
    "extended": generator_args_extended,
    "single": generator_args_single,
}

# Solvers to evaluate.
solver_names = [
    "FLEX-MUSIC",
    "TRAP-MUSIC",
    "eLORETA",
    "Convexity Champagne",
    "MCMV",
]
# Solvers whose inverse operator is rebuilt for every sample
# (all others are built on the first sample only).
recompute_list = [
    "FLEX-MUSIC",
    "TRAP-MUSIC",
    "Convexity Champagne",
    "MCMV",
]

In [19]:
# Generate Samples
# For every simulation type: draw one batch from the generator, run each solver
# on every sample, and pickle the estimates together with the simulated data,
# the simulation info and the per-sample timings.
# NOTE(review): `solver_mne` is not used anywhere in this cell; kept in case a
# later cell relies on it - confirm and remove if not.
solver_mne = Solver("MNE").make_inverse_operator(fwd)

solvers = {solver_name: Solver(solver_name) for solver_name in solver_names}

for sim_type, generator_args_ in generator_args.items():
    gen_test = generator(fwd, **generator_args_)
    x_test, y_test, sim_info = next(gen_test)

    # One result list per solver, reset for each simulation type.
    stc_dict = {solver_name: [] for solver_name in solvers}
    proc_time_make = {solver_name: [] for solver_name in solvers}
    proc_time_apply = {solver_name: [] for solver_name in solvers}

    for i_sample, (x_sample, y_sample) in enumerate(zip(x_test, y_test)):
        print("Sample ", i_sample)
        evoked = mne.EvokedArray(x_sample.T, info, tmin=0)
        for solver_name, solver in solvers.items():
            # Solvers in `recompute_list` get a fresh inverse operator for every
            # sample; all others are built once, on the first sample.
            if solver_name in recompute_list or i_sample == 0:
                start_make = time()
                solver.make_inverse_operator(fwd, evoked, alpha="auto")
                make_duration = time() - start_make
            else:
                # No operator was built for this sample. Record 0.0 instead of
                # re-using the timestamps left over from the previous solver,
                # which would log a wrong make time. The entry is still
                # appended so all lists stay aligned with the sample index.
                make_duration = 0.0

            start_apply = time()
            stc = solver.apply_inverse_operator(evoked)
            apply_duration = time() - start_apply

            stc_dict[solver_name].append(stc)
            proc_time_make[solver_name].append(make_duration)
            proc_time_apply[solver_name].append(apply_duration)

    fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump([stc_dict, x_test, y_test, sim_info, proc_time_make, proc_time_apply], f)

Sample  0
Sample  1
Sample  2
Sample  3
Sample  4
Sample  5
Sample  6


KeyboardInterrupt: 

# Plot samples

In [22]:
# Plot the simulated ("True") source activity of one sample next to the
# estimate of every solver, then show the simulation info of that sample.
pp = dict(surface='white', hemi='both', verbose=0)
sample = 2

# Use a solver estimate of the plotted sample as the container for the ground
# truth instead of relying on `stc`, a variable leaked from the innermost loop
# of the simulation cell.
# NOTE(review): `stc_dict`, `y_test` and `sim_info` hold the results of the
# simulation type processed last - confirm this is the one intended here.
template_stc = next(iter(stc_dict.values()))[sample]
stc_ = template_stc.copy()
stc_.data = y_test[sample].T
stc_.plot(**pp, brain_kwargs=dict(title="True"))

# `solver_label` avoids rebinding the global `solver`, which holds a Solver
# object after the simulation cell.
for solver_label, stc_list in stc_dict.items():
    stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver_label))
display(sim_info.iloc[sample])

n_sources                                            2
amplitudes    [0.907917529887389, 0.26380035363660265]
snr                                          16.147458
Name: 2, dtype: object

Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [9.48860114e-05 2.84658034e-04 2.11078202e-02]
Using control points [9.48860114e-05 2.84658034e-04 2.11078202e-02]
Using control points [0.00000000e+00 7.78693932e-06 1.87411353e-02]
Using control points [5.26725532e-05 1.53177692e-04 2.36555659e-02]
Using control points [0.00012208 0.00099767 0.0134913 ]
Using control points [8.04834742e-05 2.33942621e-04 1.30086290e-02]
Using control points [0.00000000e+00 1.82359132e-05 1.06895868e-02]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00041055 0.00047471 0.00077506]
Using control points [0.00012208 0.00099767 0.0134913 ]
Using control points [0.00054711 0.0006143  0.00121755]
Using control points [0.00054711 0.0006143  0.00121755]
Using control points [0.00019641 0.00021089 0.00029645]
Using co