In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../')
import pickle as pkl
import numpy as np
import pandas as pd
from copy import deepcopy
import mne
import seaborn as sns
import matplotlib.pyplot as plt
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet import forward

plot_params = dict(surface='white', hemi='both', verbose=0)

# Forward Model

In [2]:
info = forward.get_info()
info['sfreq'] = 100
fwd = forward.create_forward_model(info=info)
fwd_free = forward.create_forward_model(info=info, fixed_ori=False)

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    2.1s remaining:    3.6s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    2.3s remaining:    1.3s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    2.4s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(

# Load Models

In [3]:
lstm_standard = util.load_net('models/LSTM Medium_1-1000points_standard-cosine_0')
dense_standard = util.load_net('models/Dense Medium_1-1000points_standard-cosine_0')
convdip_standard = util.load_net('models/ConvDip Medium_1-1000points_standard-cosine_0')

models = [lstm_standard, dense_standard, convdip_standard]
model_names = ['LSTM', 'Fully-Connected', 'ConvDip']

# Plot single ground truth and predictions

In [4]:
import seaborn as sns
%matplotlib qt
sns.reset_orig()

%load_ext autoreload
%autoreload 2
model_names_tmp = deepcopy(model_names)
plot_params = dict(surface='white', hemi='both', verbose=0, 
    clim=dict(kind='percent', pos_lims=[20, 30, 100]))

settings_eval = dict(method='standard')
# settings_eval = dict( method='standard')

# Simulate new data
sim = Simulation(fwd, info, settings=settings_eval).simulate(2)
# snr = sim.simulation_info['target_snr'].values[0]
snr = None
# print(sim.simulation_info)
idx = 0
# Predict sources using the esinet models
predictions = [model.predict(sim) for model in models]

# Predict sources with classical methods
# eLORETA
prediction_elor_data = util.wrap_mne_inverse(fwd, sim, method='eLORETA', 
    add_baseline=True, n_baseline=400)[idx].data.astype(np.float32)
prediction_elor = deepcopy(predictions[0][0])
prediction_elor.data = prediction_elor_data / np.abs(np.max(prediction_elor_data))
# MNE
prediction_mne_data = util.wrap_mne_inverse(fwd, sim, method='MNE', 
    add_baseline=True, n_baseline=400)[idx].data.astype(np.float32)
prediction_mne = deepcopy(predictions[0][0])
prediction_mne.data = prediction_mne_data / np.abs(np.max(prediction_mne_data))
# Beamformer
prediction_lcmv_data = util.wrap_mne_inverse(fwd, sim, method='lcmv', 
    add_baseline=True, n_baseline=400, parallel=False)[idx].data.astype(np.float32)
prediction_lcmv = deepcopy(predictions[0][0])
prediction_lcmv.data = prediction_lcmv_data / np.max(np.abs(prediction_lcmv_data))

# Get predictions and names in order
predictions.append([prediction_elor])
predictions.append([prediction_mne])
predictions.append([prediction_lcmv])

model_names_tmp.append('eLORETA')
model_names_tmp.append('MNE')
model_names_tmp.append('LCMV')

# Plot True Source
brain = sim.source_data[idx].plot(**plot_params)
brain.add_text(0.1, 0.9, f'Ground Truth {sim.simulation_info.number_of_sources.values[0]} sources', 'title')
# Plot True EEG
evoked = sim.eeg_data[idx].average()
# evoked.plot()
evoked.plot_topomap(title='Ground Truth')
# evoked = util.get_eeg_from_source(sim.source_data[idx], fwd, info, tmin=0.)
# evoked.plot_topomap(title='Ground Truth Noiseless')

model_selection = model_names_tmp#['LCMV',]
# Plot predicted sources
for model_name, prediction in zip(model_names_tmp, predictions):
    
    if not any([model_name.lower() in model_select.lower() for model_select in model_selection]):
        continue
    error = util.batch_nmse(sim.source_data[idx].data, prediction[idx].data)
    r = util.batch_corr(sim.source_data[idx].data, prediction[idx].data)
    
    brain = prediction[idx].plot(**plot_params)

    title = f'{model_name}, error: {error:.4}, r: {r}'
    print(title)
    brain.add_text(0.1, 0.9, title, 'title')
    # Plot predicted EEG
    # evoked_esi = util.get_eeg_from_source(prediction[idx], fwd, info, tmin=0.)
    # evoked_esi.plot_topomap(title=model_name)


Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 142.66it/s]
100%|██████████| 2/2 [00:00<00:00, 4466.78it/s]

source data shape:  (1284, 100) (1284, 100)



100%|██████████| 2/2 [00:00<00:00, 105.42it/s]


interpolating for convdip...


2it [00:00, 13.03it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  data_cov = mne.compute_raw_covariance(raw, tmin=tmin,
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  data_cov = mne.compute_raw_covariance(raw, tmin=tmin,


<RawArray | 61 x 500 (5.0 s), ~325 kB, data loaded> 4.0 info empirical
<RawArray | 61 x 500 (5.0 s), ~325 kB, data loaded> 4.0 info empirical
LSTM, error: 0.009551, r: 0.7136614161906574
Fully-Connected, error: 0.01353, r: 0.5887072193723496
ConvDip, error: 0.01734, r: 0.5246880901425557
eLORETA, error: 0.1012, r: 0.1614547572011845
MNE, error: 0.05877, r: 0.12983144421458173
LCMV, error: 0.03435, r: 0.08681544504562365
Using control points [2.01647748e-10 2.71510836e-10 1.23118085e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [2.01647748e-10 2.71510836e-10 1.23118085e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [7.47266520e-10 1.72226714e-09 8.31365150e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [1.07800510e-09 2.88319091e-09 9.75376176e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [7.47266520e-10 1.72226714e-09 8.31365150e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'
  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [7.19298552e-10 9.88622884e-10 2.39252744e-09]


  File "C:\Users\lukas\virtualenvs\esienv\lib\site-packages\mne\viz\_brain\_brain.py", line 1348, in _on_button_release
    self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y)
AttributeError: 'RenderWindowInteractor' object has no attribute 'FindPokedRenderer'


Using control points [6.89443689e-10 9.57919761e-10 5.66304863e-09]


## Create or Load Evaluation Set

In [4]:
# n_samples = 1000
# duration_of_trial = (0.01, 2)
# method = 'standard'
# settings = dict(duration_of_trial=duration_of_trial, method=method)
# sim_test = Simulation(fwd, info, verbose=False, settings=settings).simulate(n_samples=n_samples)
# if type(duration_of_trial) == tuple:
#     sim_test.save(f'simulations\\sim_test_{n_samples}_{int(duration_of_trial[0]*100)}-{int(duration_of_trial[1]*100)}points_{method}.pkl')
# else:
#     sim_test.save(f'simulations\\sim_test_{n_samples}_{int(duration_of_trial*100)}points_{method}.pkl')

# or Load
with open(f'simulations\\sim_test_1000_1-200points_standard.pkl', 'rb') as f:
    sim_test = pkl.load(f)

# Calculate Performance Metrics

In [None]:
%load_ext autoreload
%autoreload 2

from esinet.evaluate import eval_mean_localization_error, eval_nmse, eval_auc, eval_mse
from esinet.util import wrap_mne_inverse
from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

model_names_tmp = deepcopy(model_names)
# Predict
print('predict esinets...')
predictions = [model.predict(sim_test) for model in models]

print('predict elor')
pred_elor = wrap_mne_inverse(fwd, sim_test, method='eLORETA', add_baseline=True, 
    n_baseline=400)
model_names_tmp.append('eLORETA')
predictions.append(pred_elor)

print('predict MNE')
pred_mne = wrap_mne_inverse(fwd, sim_test, method='MNE', add_baseline=True, 
    n_baseline=400)
model_names_tmp.append('MNE')
predictions.append(pred_mne)

print('predict LCMV')
pred_lcmv = wrap_mne_inverse(fwd, sim_test, method='lcmv', 
    parallel=False, add_baseline=True, n_baseline=400)
model_names_tmp.append('LCMV')
predictions.append(pred_lcmv)

pos = util.unpack_fwd(fwd)[2]
argsorted_distance_matrix = np.argsort(cdist(pos, pos), axis=-1)

metrics = dict()
true_sources = np.concatenate([src.data for src in sim_test.source_data], axis=1).T

for prediction, model_name in tqdm(zip(predictions, model_names_tmp)):
    print('\n', model_name, ':\n')
     
    predicted_sources = np.concatenate([src.data for src in prediction], axis=1).T

    print('mle calculation....')
    mean_localization_errors = [eval_mean_localization_error(true_source, predicted_source, pos, argsorted_distance_matrix=argsorted_distance_matrix) for true_source, predicted_source in tqdm(zip(true_sources, predicted_sources))]
    print(len(mean_localization_errors), len(true_sources), len(predicted_sources))
    print('auc calculation....')
    # aucs_combined = [eval_auc(true_source, predicted_source, pos, epsilon=0.25, n_redraw=5) for true_source, predicted_source in tqdm(zip(true_sources, predicted_sources))]
    aucs_combined = Parallel(n_jobs=-1, backend='loky') \
        (delayed(eval_auc)(true_source, predicted_source, pos, epsilon=0.25, n_redraw=5)
        for true_source, predicted_source in tqdm(zip(true_sources, predicted_sources)))
    print('nmse calculation....')
    nmses = [eval_nmse(true_source, predicted_source) for true_source, predicted_source in tqdm(zip(true_sources, predicted_sources))]
    print('mse calculation....')
    mses = [eval_mse(true_source, predicted_source) for true_source, predicted_source in tqdm(zip(true_sources, predicted_sources))]

    aucs_far = [auc[1] for auc in np.array(aucs_combined)]
    aucs_close = [auc[0] for auc in np.array(aucs_combined)]
    aucs_combined = [np.nanmean([auc[0], auc[1]]) for auc in np.array(aucs_combined)]

    metric = pd.DataFrame(dict(
        mean_localization_errors=mean_localization_errors,
        aucs_combined=aucs_combined,
        aucs_far=aucs_far,
        aucs_close=aucs_close,
        nmses=nmses,
        mses=mses
        )
    )
    metric.name = model_name
    metrics[model_name] = metric
    

with open(f'results\\metrics_{len(true_sources)}_1-200points_standard.pkl', 'wb') as f:
    pkl.dump([metrics, sim_test.simulation_info], f)

# Load Metrics

In [5]:
with open(f'results\\metrics_101947_1-200points_standard.pkl', 'rb') as f:
    [metrics, simulation_info] = pkl.load(f)

# Prepare Metrics

In [8]:
durs_in_samples = [round(dur*100) for dur in sim_test.simulation_info.duration_of_trials.values]

idx = 0
indices = []
for dur in durs_in_samples:
    idc = [idx, dur+idx]
    indices.append(idc)
    idx += dur
metrics_short = dict()


for method in metrics.keys():
    metrics_short[method] = pd.DataFrame(columns=metrics[method].columns)
    for id, idc in enumerate(indices):
        sample_summary = metrics[method].iloc[idc[0]:idc[1]].apply(np.nanmean, axis=0)
        sample_summary.sample_id = id
        metrics_short[method] = metrics_short[method].append(sample_summary, ignore_index=True)
    metrics_short[method].method = method
    # metrics_short[method].iloc[:, 0:6].values = np.real(metrics_short[method].iloc[:, 0:6])

dfs = [df[1] for df in list(metrics_short.items())]

for i, (key, df) in enumerate(metrics_short.items()):
    dfs[i]['method'] = [key]*df.shape[0]
    dfs[i]['sample_id'] = np.arange(df.shape[0])
df_aio = pd.concat(dfs)
df_aio.head()

  results[i] = self.f(v)


Unnamed: 0,mean_localization_errors,aucs_combined,aucs_far,aucs_close,nmses,mses,method,sample_id
0,23.219618,0.868293,0.942479,0.794107,0.012646,4.666444e-19,LSTM,0
1,11.719351,0.822743,0.8756,0.769886,0.008011,1.222845e-19,LSTM,1
2,14.562478,0.843674,0.921228,0.766119,0.0135,2.5045299999999996e-19,LSTM,2
3,19.684041,0.887201,0.966092,0.808311,0.012914,1.005835e-19,LSTM,3
4,19.647893,0.828699,0.875775,0.781622,0.010828,2.464672e-19,LSTM,4


# Plot Evaluation Metrics

## Boxplot Overview 
### All Sources

In [23]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.6, font='helvetica')
%matplotlib qt

cols = ['mean_localization_errors', 'aucs_combined', 'nmses']  # df_aio.iloc[:, 0:6].columns
metric_names = ["Mean Localization Error [mm]", "Area Under the Curve", "Normalized Mean Squared Error"]
for y, metric_name in zip(cols, metric_names):
    plt.figure(figsize=(13, 7))
    sns.boxplot(data=df_aio, x='method', y=y)
    plt.title(metric_name)
    plt.xlabel("Inverse Solution")
    plt.ylabel(metric_name)
    plt.ylim(0, None)
# util.multipage(r'C:\Users\lukas\Sync\lstm_inverse_problem\figures\results\boxplot_overview_allsims.pdf', png=True)
print(df_aio.describe())

       mean_localization_errors  aucs_combined     aucs_far   aucs_close  \
count               5991.000000    5997.000000  5997.000000  5997.000000   
mean                  18.291031       0.774498     0.831871     0.717124   
std                    5.242296       0.124975     0.125013     0.133415   
min                    0.000000       0.147200     0.141287     0.153114   
25%                   15.181946       0.694053     0.768795     0.615655   
50%                   18.794063       0.785411     0.858439     0.712460   
75%                   21.760584       0.864222     0.924865     0.811050   
max                   38.030185       1.000000     1.000000     1.000000   

             nmses          mses   sample_id  
count  5997.000000  6.000000e+03  6000.00000  
mean      0.035624  2.433638e-15   499.50000  
std       0.044970  7.350359e-14   288.69905  
min       0.000658  1.265231e-23     0.00000  
25%       0.010361  3.070457e-20   249.75000  
50%       0.017352  8.683084e-20 

### Single Sources

In [22]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.6, font='helvetica')
%matplotlib qt

idc_single_source = np.argwhere(simulation_info.number_of_sources.values==1)[:, 0]
df_single = df_aio[df_aio.sample_id.isin(idc_single_source)]

cols = ['mean_localization_errors', 'aucs_combined', 'nmses']  # df_single.iloc[:, 0:6].columns
metric_names = ["Mean Localization Error [mm]", "Area Under the Curve", "Normalized Mean Squared Error"]
for y, metric_name in zip(cols, metric_names):
    plt.figure(figsize=(13, 7))
    sns.boxplot(data=df_single, x='method', y=y)
    # sns.swarmplot(data=df_single, x='method', y=y)
    plt.title(metric_name)
    plt.xlabel("Inverse Solution")
    plt.ylabel(metric_name)
    plt.ylim(0, None)
# util.multipage(r'C:\Users\lukas\Sync\lstm_inverse_problem\figures\results\boxplot_overview_singlesource.pdf', png=True)

# Tables

## All Source

In [53]:
df_of_interest = df_aio

methods = set(df_of_interest.method.values)
columns = ['method', 'mean_localization_errors', 'aucs_combined', 'nmses']  # df_aio.iloc[:, 0:6].columns
column_names = ["Inverse Solution", "MLE [mm] (SD)", "AUC [%] (SD)", "nMSE (SD)"]
decimals = [None, 2, 2, 4]

table = pd.DataFrame(columns=column_names)
for i, method in enumerate(methods):
    row_dict = {column_names[0]:method}
    for j, (column, column_name, decimal) in enumerate(zip(columns[1:], column_names[1:], decimals[1:])):
        values = df_of_interest[df_of_interest.method==method][column].values
        row_dict[column_name] = str(round(np.nanmedian(values), decimal)) + ' (' + str(round(np.nanstd(values), decimal)) + ')'
    table = table.append(row_dict, ignore_index=True)
table = table.sort_values("AUC [%] (SD)", ascending=False)
table

Unnamed: 0,Inverse Solution,MLE [mm] (SD),AUC [%] (SD),nMSE (SD)
5,LSTM,15.85 (4.27),0.86 (0.08),0.0089 (0.0032)
3,Fully-Connected,16.55 (3.87),0.85 (0.07),0.0109 (0.0036)
1,ConvDip,17.24 (3.96),0.81 (0.08),0.0124 (0.0038)
2,eLORETA,20.32 (5.45),0.74 (0.1),0.0832 (0.0752)
0,MNE,22.48 (4.84),0.72 (0.11),0.0447 (0.0206)
4,LCMV,20.67 (5.7),0.61 (0.13),0.0348 (0.0193)


## Single Source

In [45]:
df_single = df_aio[df_aio.sample_id.isin(idc_single_source)]

df_of_interest = df_single

methods = set(df_of_interest.method.values)
columns = ['method', 'mean_localization_errors', 'aucs_combined', 'nmses']  # df_aio.iloc[:, 0:6].columns
column_names = ["Inverse Solution", "MLE [mm] (SD)", "AUC [%] (SD)", "nMSE (SD)"]
decimals = [None, 2, 2, 4]

table = pd.DataFrame(columns=column_names)
for i, method in enumerate(methods):
    row_dict = {column_names[0]:method}
    for j, (column, column_name, decimal) in enumerate(zip(columns[1:], column_names[1:], decimals[1:])):
        values = df_of_interest[df_of_interest.method==method][column].values
        row_dict[column_name] = str(round(np.nanmedian(values), decimal)) + ' (' + str(round(np.nanstd(values), decimal)) + ')'
    table = table.append(row_dict, ignore_index=True)
table

Unnamed: 0,Inverse Solution,MLE [mm] (SD),AUC [%] (SD),nMSE (SD)
0,MNE,20.33 (8.43),0.89 (0.16),0.0372 (0.0248)
1,ConvDip,14.13 (5.63),0.96 (0.06),0.0108 (0.0042)
2,eLORETA,10.87 (8.01),0.91 (0.15),0.0677 (0.0696)
3,Fully-Connected,13.32 (5.51),0.96 (0.08),0.0109 (0.0044)
4,LCMV,17.73 (9.54),0.66 (0.21),0.0292 (0.0242)
5,LSTM,12.08 (5.53),0.99 (0.07),0.0073 (0.0031)


## Quadratic Scatter

In [None]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.2, font='helvetica')
%matplotlib qt
methods_of_interest = ['LSTM Standard', 'Dense Standard']
# df_select = df_aio[df_aio['method'].str.contains('|'.join(methods_of_interest))]

cols = df_aio.iloc[:, 0:6].columns
for method_name in cols:

    vals_A = df_aio[df_aio['method'].str.contains(methods_of_interest[0])][method_name].values
    vals_B = df_aio[df_aio['method'].str.contains(methods_of_interest[1])][method_name].values
    d = {methods_of_interest[0]: vals_A, methods_of_interest[1]: vals_B,} 
    df_tmp = pd.DataFrame(d)
    plt.figure(figsize=(10, 10))
    ax = sns.scatterplot(data=df_tmp, x=methods_of_interest[0], y=methods_of_interest[1])


    xlim, ylim = (plt.xlim(), plt.ylim())
    plt.plot(xlim, ylim, '--k')
    xlim = (xlim[0]*0.95, xlim[1]*1.05)
    ylim = (ylim[0]*0.95, ylim[1]*1.05)
    plt.ylim(ylim)
    plt.xlim(xlim)
    
    # Title
    prop_higher = np.sum(vals_B > vals_A) / len(vals_A)
    cohens_d = (np.nanmean(vals_A) - np.nanmean(vals_B)) / np.mean([np.nanstd(vals_A), np.nanstd(vals_B)])
    median_diff = np.abs(np.nanmedian(vals_A-vals_B))
    method_name_title = method_name.replace('_', ' ').title()
    title = f'{method_name_title} ({methods_of_interest[1]} higher in {100*prop_higher:.1f} %)\nmedian_difference: {median_diff}\ncohens d: {abs(cohens_d):.2f}'
    plt.title(title)

    plt.gca().set_aspect('equal', adjustable='box')
    del d, df_tmp

## Dependence on anything

In [11]:
import pandas as pd
%matplotlib qt
sns.set(font_scale=1.2, font='helvetica')
# pd.DataFrame( metrics , index=model_names)
target_column = 'target_snr'
binning = True
n_bins = 6


df = sim_test.simulation_info
if binning:
    minimum = np.min([np.min(arr) for arr in df[target_column].values])
    maximum = np.max([np.max(arr) for arr in df[target_column].values])
    
    bins = np.linspace(minimum, maximum*1.01, num=n_bins)
    bin_labels = [str(round(bins[i], 1)) + ' - ' + str(round(bins[i+1], 1)) for i in range(len(bins)-1)]
    new_target_column = 'bins ' + target_column
    df[new_target_column] = np.digitize(df[target_column].values, bins=bins)
    target_column = new_target_column
    
else:
    bins = list(set(df[target_column].values))
    bins[-1] *= 1.01
    bin_labels = [str(bins[i]) for i in range(len(bins))]


for i, model_name in enumerate(list(set(df_aio.method.values))):
    cols = df_aio[df_aio.method==model_name].iloc[:, 0:6].columns
    values = df_aio[df_aio.method==model_name].iloc[:, 0:6].values

    for metric_name, metric in zip(cols, values.T):
        col_name = model_name.replace(' ', '_') + '_' + metric_name.replace(' ', '_')
        df[col_name] = metric

dep_var_regex = target_column
dep_var_label = target_column.replace('_', ' ').title()
metric_names_nice = [col.replace('_', ' ').title() for col in df_aio.iloc[:, :6].columns]
for metric_name, metric_name_nice in zip(df_aio.iloc[:, :6].columns,  metric_names_nice):
    df_temp = pd.concat((df.filter(regex=dep_var_regex), df.filter(regex='_'+metric_name)), axis=1).melt(dep_var_regex, var_name='cols', value_name='vals')
    g = sns.catplot(x=dep_var_regex, y='vals', hue='cols', capsize=.2, kind='point', data=df_temp)
    g.set(xticklabels=bin_labels, ylabel=metric_name_nice, xlabel=dep_var_label)
    g._legend.remove()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()

## Dependence on Duration

In [6]:
from tqdm.notebook import tqdm
from scipy.spatial.distance import cdist
from esinet import evaluate

n_simulations = 1000
settings = dict(duration_of_trial=5.0,)
sim = Simulation(fwd, info, settings=settings).simulate(n_simulations)

times = sim.source_data[0].times

pos = util.unpack_fwd(fwd)[2]
argsorted_distance_matrix = np.argsort(cdist(pos, pos), axis=-1)

mles = dict()
nmses = dict()
mses = dict()
aucs_combined = dict()
aucs_close = dict()
aucs_far = dict()

points = dict()
for i, time in tqdm(enumerate(times[:-1][::1])):
    for model, model_name in zip(models, model_names):
        if not model_name in mles.keys():
            mles[model_name] = []    
            nmses[model_name] = []
            mses[model_name] = []    
            aucs_combined[model_name] = []
            aucs_close[model_name] = []
            aucs_far[model_name] = []
            points[model_name] = []

        mle = []
        nmse = []
        mse = []
        auc = []
        new_sim = deepcopy(sim).crop(tmin=time)
        y_pred = model.predict(new_sim)
        
        for idx in range(n_simulations):
            y_true = sim.source_data[idx].data[:, -1][:, np.newaxis]
            y_est = y_pred[idx].data[:, -1][:, np.newaxis]
            mle.append( evaluate.eval_mean_localization_error(y_true, y_est, pos, 
                argsorted_distance_matrix=argsorted_distance_matrix) )
            nmse.append( evaluate.eval_nmse(y_true, y_est) )
            mse.append( evaluate.eval_mse(y_true, y_est) )
            auc.append( evaluate.eval_auc(y_true, y_est, pos, n_redraw=5, epsilon=0.25) )

        auc_combined = [(a[0]+a[1])/2 for a in auc]
        auc_close = [a[0] for a in auc]
        auc_far = [a[1] for a in auc]
            

        
        mles[model_name].append(np.nanmean(mle))
        nmses[model_name].append(np.nanmean(nmse))
        mses[model_name].append(np.nanmean(mse))
        aucs_combined[model_name].append(np.nanmean(auc_combined))
        aucs_close[model_name].append(np.nanmean(auc_close))
        aucs_far[model_name].append(np.nanmean(auc_far))
        points[model_name].append(new_sim.source_data[idx].shape[1])

params = [mles, nmses, mses, aucs_combined, aucs_close, aucs_far]
param_names = ['mles', 'nmses', 'mses', 'aucs_combined', 'aucs_close', 'aucs_far']

with open(f'results\\metrics_duration.pkl', 'wb') as f:
    pkl.dump([params, param_names], f)

Simulating data based on sparse patches.


100%|██████████| 1000/1000 [00:11<00:00, 88.67it/s]
100%|██████████| 1000/1000 [00:01<00:00, 643.56it/s]


source data shape:  (1284, 500) (1284, 500)


100%|██████████| 1000/1000 [00:52<00:00, 18.88it/s]


0it [00:00, ?it/s]

interpolating for convdip...


1000it [05:53,  2.83it/s]


interpolating for convdip...


1000it [05:52,  2.83it/s]


interpolating for convdip...


1000it [05:49,  2.86it/s]


interpolating for convdip...


1000it [05:49,  2.86it/s]


interpolating for convdip...


1000it [05:49,  2.86it/s]


interpolating for convdip...


1000it [05:48,  2.87it/s]


interpolating for convdip...


1000it [05:51,  2.85it/s]


interpolating for convdip...


1000it [05:50,  2.86it/s]


interpolating for convdip...


1000it [05:46,  2.88it/s]


interpolating for convdip...


1000it [05:46,  2.88it/s]


interpolating for convdip...


1000it [05:47,  2.88it/s]


interpolating for convdip...


1000it [05:52,  2.84it/s]


interpolating for convdip...


1000it [05:46,  2.89it/s]


interpolating for convdip...


1000it [05:48,  2.87it/s]


interpolating for convdip...


1000it [05:55,  2.81it/s]


interpolating for convdip...


1000it [05:46,  2.89it/s]


interpolating for convdip...


1000it [05:48,  2.87it/s]


interpolating for convdip...


1000it [05:48,  2.87it/s]


interpolating for convdip...


1000it [05:47,  2.88it/s]


interpolating for convdip...


1000it [05:49,  2.86it/s]


interpolating for convdip...


1000it [05:47,  2.88it/s]


interpolating for convdip...


1000it [05:44,  2.90it/s]


interpolating for convdip...


1000it [05:44,  2.91it/s]


interpolating for convdip...


1000it [05:44,  2.91it/s]


interpolating for convdip...


1000it [05:41,  2.93it/s]


interpolating for convdip...


1000it [05:39,  2.94it/s]


interpolating for convdip...


1000it [05:40,  2.93it/s]


interpolating for convdip...


1000it [05:40,  2.94it/s]


interpolating for convdip...


1000it [05:40,  2.94it/s]


interpolating for convdip...


1000it [05:41,  2.93it/s]


interpolating for convdip...


1000it [05:37,  2.96it/s]


interpolating for convdip...


1000it [05:36,  2.97it/s]


interpolating for convdip...


1000it [05:33,  3.00it/s]


interpolating for convdip...


1000it [05:37,  2.96it/s]


interpolating for convdip...


1000it [05:31,  3.02it/s]


interpolating for convdip...


1000it [05:35,  2.98it/s]


interpolating for convdip...


1000it [05:30,  3.02it/s]


interpolating for convdip...


1000it [05:30,  3.03it/s]


interpolating for convdip...


1000it [05:32,  3.00it/s]


interpolating for convdip...


1000it [05:31,  3.01it/s]


interpolating for convdip...


1000it [05:33,  3.00it/s]


interpolating for convdip...


1000it [05:30,  3.02it/s]


interpolating for convdip...


1000it [05:27,  3.06it/s]


interpolating for convdip...


1000it [05:26,  3.06it/s]


interpolating for convdip...


1000it [05:27,  3.05it/s]


interpolating for convdip...


1000it [05:24,  3.08it/s]


interpolating for convdip...


1000it [05:28,  3.04it/s]


interpolating for convdip...


1000it [05:22,  3.10it/s]


interpolating for convdip...


1000it [05:22,  3.10it/s]


interpolating for convdip...


1000it [05:23,  3.09it/s]


interpolating for convdip...


1000it [05:24,  3.08it/s]


interpolating for convdip...


1000it [05:21,  3.11it/s]


interpolating for convdip...


1000it [05:25,  3.07it/s]


interpolating for convdip...


1000it [05:21,  3.11it/s]


interpolating for convdip...


1000it [05:19,  3.13it/s]


interpolating for convdip...


1000it [05:22,  3.10it/s]


interpolating for convdip...


1000it [05:19,  3.13it/s]


interpolating for convdip...


1000it [05:18,  3.14it/s]


interpolating for convdip...


1000it [05:18,  3.14it/s]


interpolating for convdip...


1000it [05:15,  3.17it/s]


interpolating for convdip...


1000it [05:14,  3.18it/s]


interpolating for convdip...


1000it [05:14,  3.18it/s]


interpolating for convdip...


1000it [05:12,  3.20it/s]


interpolating for convdip...


1000it [05:18,  3.14it/s]


interpolating for convdip...


1000it [05:15,  3.17it/s]


interpolating for convdip...


1000it [05:15,  3.17it/s]


interpolating for convdip...


1000it [05:17,  3.15it/s]


In [27]:
params = [mles, nmses, mses, aucs_combined, aucs_close, aucs_far]
param_names = ['mles', 'nmses', 'mses', 'aucs_combined', 'aucs_close', 'aucs_far']
for param, name in zip(params, param_names):
    plt.figure(figsize=(8,6))

    for model_name in model_names:
        plt.plot(points[model_name], param[model_name], label=model_name)

    plt.xlabel('Number of available data points')
    plt.title(f'{name} ground truth vs prediction')
    plt.legend()

# Get interpolated Topomap for publication

In [60]:
from mne.viz.topomap import (_setup_interp, _make_head_outlines, _check_sphere, 
    _check_extrapolate)
from mne.channels.layout import _find_topomap_coords

model = models[0]
eeg, src = model._handle_data_input((sim,))
eeg_prep = np.swapaxes(eeg[0].get_data(), 1,2)
elec_pos = _find_topomap_coords(model.info, model.info.ch_names)
interpolator = model.make_interpolator(elec_pos, res=model.interp_channel_shape[0])
eeg_prep_interp = deepcopy(eeg_prep)
for i, sample in tqdm(enumerate(eeg_prep)):
    list_of_time_slices = []
    for time_slice in sample:
        time_slice_interp = interpolator.set_values(time_slice)()[::-1]
        list_of_time_slices.append(time_slice_interp)

sns.set(style='white')
tick_font_size = 50
plt.figure()
plt.imshow(np.mean(list_of_time_slices, axis=0)*8e6, cmap='RdBu_r')
plt.tick_params(left = False, right = False , labelleft = False ,
                labelbottom = False, bottom = False)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_font_size)

0it [00:00, ?it/s]

## Speed test

In [None]:
import time

def new_sim_params(sr=100, packages_per_second=20):
    package_size = int( round( sr / packages_per_second  ) )
    package_interval = package_size/sr

    n_chan = len(sim_test.eeg_data.ch_names)
    data_package = np.random.randn(n_chan, package_size)

    sim_data_package = Simulation(fwd, info, settings=dict(duration_of_trial=0.01*package_size)).simulate(1)
    print(f'performing predictions {packages_per_second} times per second')

    return sim_data_package, package_interval

packages_per_second = 50
sim_data_package, package_interval = new_sim_params(packages_per_second=packages_per_second)

while True:
    start = time.time()
    # stc = net_dense.predict(sim_data_package)
    stc = models[0].predict(sim_data_package)

    stop = time.time()
    diff = stop-start
    if stop-start > package_interval:
        print(f"took longer than expected: {diff} (instead of {package_interval})")
        print(f'decreasing package interval by one')
        packages_per_second -= 1
        sim_data_package, package_interval = new_sim_params(packages_per_second=packages_per_second)
        print(f'packages_per_second={packages_per_second}\n')
        continue
    time.sleep(package_interval-diff)

