In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../')
import pickle as pkl
import numpy as np
import pandas as pd
from copy import deepcopy
import mne
import seaborn as sns
import matplotlib.pyplot as plt
from esinet import util
from esinet import Simulation
from esinet import Net
from esinet import forward

# Default options for source-estimate brain plots; later plotting cells redefine plot_params.
plot_params = {'surface': 'white', 'hemi': 'both', 'verbose': 0}

# Forward Model

In [2]:
# Channel info from esinet's forward helper, with the sampling rate fixed at 100 Hz.
sampling_frequency = 100
info = forward.get_info()
info['sfreq'] = sampling_frequency

# Two forward models on the same info: default orientation settings and free orientations.
fwd = forward.create_forward_model(info=info)
fwd_free = forward.create_forward_model(fixed_ori=False, info=info)

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    1.2s remaining:    2.0s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    1.2s remaining:    0.7s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    1.2s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:    0.0s remaining:    0.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(

# Load Models

In [19]:
# Pretrained esinet networks; the file names indicate a shared training configuration.
model_files = (
    'models/LSTM Medium_1-1000points_standard-cosine_0',
    'models/Dense Medium_1-1000points_standard-cosine_0',
    'models/ConvDip Medium_1-1000points_standard-cosine_0',
)
lstm_standard, dense_standard, convdip_standard = (util.load_net(path) for path in model_files)

# Models and their display names, in matching order.
models = [lstm_standard, dense_standard, convdip_standard]
model_names = ['LSTM', 'Fully-Connected', 'ConvDip']

# Load Evaluation Set and Metrics

In [21]:
# Evaluation set and the precomputed per-time-point metrics of every method.
# Forward slashes resolve on Windows and POSIX alike; on POSIX a backslash is an ordinary
# filename character, so a backslash-separated path would not be found there.
# SECURITY: pickle.load can execute arbitrary code - only load files produced by this project.
with open('simulations/sim_test_1000_1-200points_standard.pkl', 'rb') as f:
    sim_test = pkl.load(f)

with open('results/metrics_101947_1-200points_standard.pkl', 'rb') as f:
    metrics, simulation_info = pkl.load(f)

# Prepare Metrics

In [24]:
# Collapse the per-time-point metrics into one summary row (NaN-aware mean) per test trial.
# Durations are in seconds; the evaluation source estimates use 10 ms steps (100 Hz).
durs_in_samples = [round(dur * 100) for dur in sim_test.simulation_info.duration_of_trials.values]

# [start, stop) row range of each trial inside the concatenated metrics frames.
bounds = np.cumsum([0] + durs_in_samples)
indices = [[int(start), int(stop)] for start, stop in zip(bounds[:-1], bounds[1:])]

# Rows are collected in a list and converted once. DataFrame.append was removed in
# pandas 2.0. Assigning `series.sample_id` / `frame.method` only sets Python attributes,
# not columns, so both columns are created explicitly here.
metrics_short = dict()
for method, metrics_method in metrics.items():
    rows = [metrics_method.iloc[start:stop].apply(np.nanmean, axis=0) for start, stop in indices]
    summary = pd.DataFrame(rows, columns=metrics_method.columns).reset_index(drop=True)
    summary['method'] = [method] * summary.shape[0]
    summary['sample_id'] = np.arange(summary.shape[0])
    metrics_short[method] = summary

# Long-format frame: one row per (method, test trial).
dfs = list(metrics_short.values())
df_aio = pd.concat(dfs)
df_aio.head()

  results[i] = self.f(v)


Unnamed: 0,mean_localization_errors,aucs_combined,aucs_far,aucs_close,nmses,mses,method,sample_id
0,23.219618,0.868293,0.942479,0.794107,0.012646,4.666444e-19,LSTM,0
1,11.719351,0.822743,0.8756,0.769886,0.008011,1.222845e-19,LSTM,1
2,14.562478,0.843674,0.921228,0.766119,0.0135,2.5045299999999996e-19,LSTM,2
3,19.684041,0.887201,0.966092,0.808311,0.012914,1.005835e-19,LSTM,3
4,19.647893,0.828699,0.875775,0.781622,0.010828,2.464672e-19,LSTM,4


# Plot single ground truth and predictions

In [44]:
# best_idx is otherwise first defined in the comparison-plot cell further down; compute it here
# so this cell also runs on a fresh kernel. It is the evaluation trial ranked 500th by LSTM nMSE.
best_idx = np.argsort(df_aio[df_aio.method == 'LSTM'].nmses.values)[500]
sim_test.source_data[best_idx]


<SourceEstimate | 1284 vertices, subject : fsaverage, tmin : 0.0 (ms), tmax : 850.0 (ms), tstep : 10.0 (ms), data shape : (1284, 86), ~873 kB>

In [47]:
# Quick look at one source estimate of the evaluation set.
# `colormap` and `sim` are only created in the comparison-plot cell further down. This cell
# defines its own colormap and plots from `sim_test` (loaded above) so it runs on a fresh kernel.
colormap = 'RdBu_r'
plot_params = dict(surface='white', hemi='split', size=(800*2,400*2), verbose=0, time_viewer=False,
    background='w', colorbar=False,
    views=['lat', 'med'], colormap=colormap, initial_time=0.0)

sim_test.source_data[0].plot(**plot_params)


<mne.viz._brain._brain.Brain at 0x169fcf95e50>

In [56]:
# Lowest combined AUC the LSTM reaches on any test trial.
df_aio.loc[df_aio['method'] == 'LSTM', 'aucs_combined'].to_numpy().min()

0.48953488372093024

0      False
1      False
2      False
3       True
4      False
       ...  
995    False
996    False
997    False
998    False
999    False
Name: number_of_sources, Length: 1000, dtype: bool

In [72]:
# Display the combined metrics frame (one row per method and test trial); the notebook
# display truncates the middle rows.
df_aio


Unnamed: 0,mean_localization_errors,aucs_combined,aucs_far,aucs_close,nmses,mses,method,sample_id
0,23.219618,0.868293,0.942479,0.794107,0.012646,4.666444e-19,LSTM,0
1,11.719351,0.822743,0.875600,0.769886,0.008011,1.222845e-19,LSTM,1
2,14.562478,0.843674,0.921228,0.766119,0.013500,2.504530e-19,LSTM,2
3,19.684041,0.887201,0.966092,0.808311,0.012914,1.005835e-19,LSTM,3
4,19.647893,0.828699,0.875775,0.781622,0.010828,2.464672e-19,LSTM,4
...,...,...,...,...,...,...,...,...
995,25.356679,0.464930,0.477803,0.452057,0.026314,4.479224e-19,LCMV,995
996,24.598093,0.557767,0.531220,0.584315,0.050748,7.401235e-21,LCMV,996
997,24.471249,0.602101,0.681807,0.522396,0.041049,1.591045e-18,LCMV,997
998,18.847806,0.617973,0.662383,0.573563,0.038285,2.770887e-18,LCMV,998


In [91]:
import seaborn as sns
%matplotlib qt
sns.reset_orig()

%load_ext autoreload
%autoreload 2
model_names_tmp = deepcopy(model_names)
colormap = 'RdBu_r'
plot_params = dict(surface='white', hemi='split', size=(800*2,400*2), verbose=0, time_viewer=False, 
    background='w', colorbar=False, views=['lat', 'med'], 
    colormap=colormap, initial_time=0.0, transparent=True
)
fractions = [0., 0.2, 0.99]

settings_eval = dict(
    method='standard', 
    number_of_sources=3,
    duration_of_trial=2.0)

# Simulate new data
sim = Simulation(fwd, info, settings=settings_eval).simulate(2)
# best_idx = np.argmin(df_aio[df_aio.method=='LSTM'].nmses.values)
best_idx = np.argsort(df_aio[df_aio.method=='LSTM'].nmses.values)[500]
# best_idx = df_aio[(df_aio.method=='LSTM')][simulation_info.number_of_sources==2].sort_values('nmses', ascending=True).sample_id.values[0]

# sim.source_data[0] = sim_test.source_data[best_idx]
# sim.eeg_data[0] = sim_test.eeg_data[best_idx]

idx = 0
# snr = sim.simulation_info['target_snr'].values[0]
snr = None
# print(sim.simulation_info)
# Predict sources using the esinet models
predictions = [model.predict(sim) for model in models]

# Predict sources with classical methods
# eLORETA
prediction_elor_data = util.wrap_mne_inverse(fwd, sim, method='eLORETA', 
    add_baseline=True, n_baseline=400)[idx].data.astype(np.float32)
prediction_elor = deepcopy(predictions[0][0])
prediction_elor.data = prediction_elor_data / np.abs(np.max(prediction_elor_data))
# MNE
prediction_mne_data = util.wrap_mne_inverse(fwd, sim, method='MNE', 
    add_baseline=True, n_baseline=400)[idx].data.astype(np.float32)
prediction_mne = deepcopy(predictions[0][0])
prediction_mne.data = prediction_mne_data / np.abs(np.max(prediction_mne_data))
# Beamformer
prediction_lcmv_data = util.wrap_mne_inverse(fwd, sim, method='lcmv', 
    add_baseline=True, n_baseline=400, parallel=False)[idx].data.astype(np.float32)
prediction_lcmv = deepcopy(predictions[0][0])
prediction_lcmv.data = prediction_lcmv_data / np.max(np.abs(prediction_lcmv_data))

# Get predictions and names in order
predictions.append([prediction_elor])
predictions.append([prediction_mne])
predictions.append([prediction_lcmv])

model_names_tmp.append('eLORETA')
model_names_tmp.append('MNE')
model_names_tmp.append('LCMV')

# Plot True Source
pos_lims = [np.max(np.abs(sim.source_data[idx].data[:, 0]))*frac for frac in fractions]
brain = sim.source_data[idx].plot(**plot_params, clim=dict(kind='value', pos_lims=pos_lims))
brain.add_text(0.1, 0.9, f'Ground Truth', 'title')
# brain = sim.source_data[idx].plot()

model_selection = model_names_tmp
# Plot predicted sources
for model_name, prediction in zip(model_names_tmp, predictions):
    
    # if not any([model_name.lower() in model_select.lower() for model_select in model_selection]):
    #     continue
    # error = util.batch_nmse(sim.source_data[idx].data, prediction[idx].data)
    # r = util.batch_corr(sim.source_data[idx].data, prediction[idx].data)
    pos_lims = [np.max(np.abs(prediction[idx].data[:, 0]))*frac for frac in fractions]
    brain = prediction[idx].plot(**plot_params, clim=dict(kind='value', pos_lims=pos_lims))
    title = f'{model_name}'
    print(title)
    brain.add_text(0.1, 0.9, title, 'title')
    # brain = prediction[idx].plot()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Simulating data based on sparse patches.


100%|██████████| 2/2 [00:00<00:00, 190.85it/s]
100%|██████████| 2/2 [00:00<00:00, 114.35it/s]


source data shape:  (1284, 200) (1284, 200)


100%|██████████| 2/2 [00:00<00:00, 43.47it/s]


interpolating for convdip...


2it [00:00,  6.62it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)


  0%|          | 0/2 [00:00<?, ?it/s]

  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  data_cov = mne.compute_raw_covariance(raw, tmin=tmin,
  epochs.set_eeg_reference(projection=True, verbose=verbose)#.apply_baseline(baseline=baseline)
  data_cov = mne.compute_raw_covariance(raw, tmin=tmin,


LSTM
Fully-Connected
ConvDip
eLORETA
MNE
LCMV


# Plot Evaluation Metrics

## Boxplot Overview 
### All Sources

In [None]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.6, font='helvetica')
%matplotlib qt


df_aio["aucs_combined"] *= 100
cols = ['mean_localization_errors', 'aucs_combined', 'nmses']
metric_names = ["Mean Localization Error [mm]", "Area Under the Curve [%]", "Normalized Mean Squared Error"]
ylims = [[0, None], [0, None], [0, 0.15]]
for y, metric_name, ylim in zip(cols, metric_names, ylims):
    plt.figure(figsize=(13, 7))
    sns.boxplot(data=df_aio, x='method', y=y)
    plt.title(metric_name)
    plt.xlabel("Inverse Solution")
    plt.ylabel(metric_name)
    plt.ylim(ylim)
util.multipage(r'C:\Users\lukas\Sync\lstm_inverse_problem\figures\results\boxplot_overview_allsources.pdf', png=True)

### Single Sources

In [None]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.6, font='helvetica')
%matplotlib qt

idc_single_source = np.argwhere(simulation_info.number_of_sources.values==1)[:, 0]
df_single = df_aio[df_aio.sample_id.isin(idc_single_source)]
# df_single["aucs_combined"] *= 100
cols = ['mean_localization_errors', 'aucs_combined', 'nmses']  # df_single.iloc[:, 0:6].columns
metric_names = ["Mean Localization Error [mm]", "Area Under the Curve [%]", "Normalized Mean Squared Error"]
ylims = [[0, None], [0, None], [0, 0.15]]
for y, metric_name, ylim in zip(cols, metric_names, ylims):
    plt.figure(figsize=(13, 7))
    sns.boxplot(data=df_single, x='method', y=y)
    # sns.swarmplot(data=df_single, x='method', y=y)
    plt.title(metric_name)
    plt.xlabel("Inverse Solution")
    plt.ylabel(metric_name)
    plt.ylim(ylim)
util.multipage(r'C:\Users\lukas\Sync\lstm_inverse_problem\figures\results\boxplot_overview_singlesource.pdf', png=True)

# Tables

## All Sources

In [None]:
# Summary table over all test trials: median (SD) per method, sorted by median AUC (descending).
df_of_interest = df_aio

columns = ['method', 'mean_localization_errors', 'aucs_combined', 'nmses']
column_names = ["Inverse Solution", "MLE [mm] (SD)", "AUC [%] (SD)", "nMSE (SD)"]
decimals = [None, 2, 2, 4]
# The all-sources boxplot cell converts df_aio['aucs_combined'] to percent in place. Scale by 100
# only while the column is still a fraction; otherwise the table would report AUC x 10 000.
auc_scaler = 100 if df_of_interest['aucs_combined'].max() <= 1 else 1
scalers = [None, 1, auc_scaler, 1]

rows = []
auc_medians = []
# unique() gives a deterministic row order; a set of strings iterates differently per session.
for method in df_of_interest.method.unique():
    df_method = df_of_interest[df_of_interest.method == method]
    row_dict = {column_names[0]: method}
    for column, column_name, decimal, scaler in zip(columns[1:], column_names[1:], decimals[1:], scalers[1:]):
        values = df_method[column].values
        row_dict[column_name] = str(round(np.nanmedian(values)*scaler, decimal)) + ' (' + str(round(np.nanstd(values)*scaler, decimal)) + ')'
    rows.append(row_dict)
    auc_medians.append(np.nanmedian(df_method['aucs_combined'].values))

# Sort on the numeric median AUC. Sorting the formatted strings is lexicographic and would
# rank e.g. "9.5 (...)" above "85.3 (...)".
order = np.argsort(auc_medians, kind='stable')[::-1]
table = pd.DataFrame(rows, columns=column_names).iloc[order].reset_index(drop=True)
table

## Single Source

In [None]:
# Summary table restricted to test trials with exactly one simulated source.
# The selection is recomputed here so this cell does not depend on another cell having run.
idc_single_source = np.argwhere(simulation_info.number_of_sources.values == 1)[:, 0]
df_single = df_aio[df_aio.sample_id.isin(idc_single_source)]

df_of_interest = df_single

columns = ['method', 'mean_localization_errors', 'aucs_combined', 'nmses']
column_names = ["Inverse Solution", "MLE [mm] (SD)", "AUC [%] (SD)", "nMSE (SD)"]
decimals = [None, 2, 2, 4]
# The all-sources boxplot cell converts df_aio['aucs_combined'] to percent in place. Scale by 100
# only while the column is still a fraction; otherwise the table would report AUC x 10 000.
auc_scaler = 100 if df_of_interest['aucs_combined'].max() <= 1 else 1
scalers = [None, 1, auc_scaler, 1]

rows = []
auc_medians = []
# unique() gives a deterministic row order; a set of strings iterates differently per session.
for method in df_of_interest.method.unique():
    df_method = df_of_interest[df_of_interest.method == method]
    row_dict = {column_names[0]: method}
    for column, column_name, decimal, scaler in zip(columns[1:], column_names[1:], decimals[1:], scalers[1:]):
        values = df_method[column].values
        row_dict[column_name] = str(round(np.nanmedian(values)*scaler, decimal)) + ' (' + str(round(np.nanstd(values)*scaler, decimal)) + ')'
    rows.append(row_dict)
    auc_medians.append(np.nanmedian(df_method['aucs_combined'].values))

# Same numeric ordering (median AUC, descending) as the all-sources table.
order = np.argsort(auc_medians, kind='stable')[::-1]
table = pd.DataFrame(rows, columns=column_names).iloc[order].reset_index(drop=True)
table

# Pairwise Scatter Plots (square axes, one point per test trial)

In [None]:
import seaborn as sns; sns.set(style='whitegrid', font_scale=1.2, font='helvetica')
%matplotlib qt
methods_of_interest = ['LSTM Standard', 'Dense Standard']
# df_select = df_aio[df_aio['method'].str.contains('|'.join(methods_of_interest))]

cols = df_aio.iloc[:, 0:6].columns
for method_name in cols:

    vals_A = df_aio[df_aio['method'].str.contains(methods_of_interest[0])][method_name].values
    vals_B = df_aio[df_aio['method'].str.contains(methods_of_interest[1])][method_name].values
    d = {methods_of_interest[0]: vals_A, methods_of_interest[1]: vals_B,} 
    df_tmp = pd.DataFrame(d)
    plt.figure(figsize=(10, 10))
    ax = sns.scatterplot(data=df_tmp, x=methods_of_interest[0], y=methods_of_interest[1])


    xlim, ylim = (plt.xlim(), plt.ylim())
    plt.plot(xlim, ylim, '--k')
    xlim = (xlim[0]*0.95, xlim[1]*1.05)
    ylim = (ylim[0]*0.95, ylim[1]*1.05)
    plt.ylim(ylim)
    plt.xlim(xlim)
    
    # Title
    prop_higher = np.sum(vals_B > vals_A) / len(vals_A)
    cohens_d = (np.nanmean(vals_A) - np.nanmean(vals_B)) / np.mean([np.nanstd(vals_A), np.nanstd(vals_B)])
    median_diff = np.abs(np.nanmedian(vals_A-vals_B))
    method_name_title = method_name.replace('_', ' ').title()
    title = f'{method_name_title} ({methods_of_interest[1]} higher in {100*prop_higher:.1f} %)\nmedian_difference: {median_diff}\ncohens d: {abs(cohens_d):.2f}'
    plt.title(title)

    plt.gca().set_aspect('equal', adjustable='box')
    del d, df_tmp

# Dependence on Noise

In [None]:
import pandas as pd
%matplotlib qt
sns.set(font_scale=1.4, font='helvetica', style='whitegrid')
# pd.DataFrame( metrics , index=model_names)
target_column = 'target_snr'
binning = True
n_bins = 6


df = sim_test.simulation_info
if binning:
    minimum = np.min([np.min(arr) for arr in df[target_column].values])
    maximum = np.max([np.max(arr) for arr in df[target_column].values])
    
    bins = np.linspace(minimum, maximum*1.01, num=n_bins)
    bin_labels = [str(int(round(bins[i], 0))) + ' - ' + str(int(round(bins[i+1], 0))) for i in range(len(bins)-1)]
    new_target_column = 'bins ' + target_column
    df[new_target_column] = np.digitize(df[target_column].values, bins=bins)
    target_column = new_target_column
    
else:
    bins = list(set(df[target_column].values))
    bins[-1] *= 1.01
    bin_labels = [str(bins[i]) for i in range(len(bins))]


for i, model_name in enumerate(list(set(df_aio.method.values))):
    cols = df_aio[df_aio.method==model_name].iloc[:, 0:6].columns
    values = df_aio[df_aio.method==model_name].iloc[:, 0:6].values

    for metric_name, metric in zip(cols, values.T):
        col_name = model_name.replace(' ', '_') + '_' + metric_name.replace(' ', '_')
        df[col_name] = metric
dep_var_regex = target_column
dep_var_label = target_column.replace('_', ' ').title()
metric_names_nice = ["Mean Localization Error [mm]", "Area Under the Curve [%]", "Normalized Mean Squared Error"]
metric_names = ['mean_localization_errors', 'aucs_combined', 'nmses']
scalers = [1, 100, 1]
for metric_name, metric_name_nice, scaler in zip(metric_names,  metric_names_nice, scalers):
    df_temp = pd.concat((df.filter(regex=dep_var_regex), df.filter(regex='_'+metric_name)), axis=1).melt(dep_var_regex, var_name='cols', value_name='vals')
    df_temp.cols = [val.split('_')[0] for val in df_temp.cols.values]
    # df_temp.vals*=scaler
    g = sns.catplot(x=dep_var_regex, y='vals', hue='cols', capsize=.2, kind='point', data=df_temp, height=6, aspect=1.5)
    g.set(xticklabels=bin_labels, ylabel=metric_name_nice, xlabel=dep_var_label)
    g._legend.remove()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.title(metric_name_nice)

    # break
g = sns.catplot(x=dep_var_regex, y='vals', hue='cols', capsize=.2, kind='point', data=df_temp, height=6, aspect=1.5)
g.set(xticklabels=bin_labels, ylabel=metric_name_nice, xlabel=dep_var_label)
g._legend.remove()
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.ylim(plt.ylim()[0], plt.ylim()[1]/8)
plt.tight_layout()
plt.title(metric_name_nice)
util.multipage(f'C:/Users/lukas/Sync/lstm_inverse_problem/figures/results/dependence_{target_column}.pdf', png=True)

# Dependence on Eccentricity

In [None]:
import pandas as pd
%matplotlib qt
sns.set(font_scale=1.4, font='helvetica', style='whitegrid')
# pd.DataFrame( metrics , index=model_names)
target_column = 'eccentricity'
binning = True
n_bins = 4

eccentricity = np.zeros(df.shape[0])
for i, positions in enumerate(df.positions.values):
    eccentricity[i] = np.mean(np.sqrt((positions**2).sum(axis=1)))
df["eccentricity"] = eccentricity


df = sim_test.simulation_info
if binning:
    minimum = np.min([np.min(arr) for arr in df[target_column].values])
    maximum = np.max([np.max(arr) for arr in df[target_column].values])
    
    bins = np.linspace(minimum, maximum*1.01, num=n_bins)
    bin_labels = [str(int(round(bins[i], 0))) + ' - ' + str(int(round(bins[i+1], 0))) for i in range(len(bins)-1)]
    new_target_column = 'bins ' + target_column
    df[new_target_column] = np.digitize(df[target_column].values, bins=bins)
    target_column = new_target_column
    
else:
    bins = list(set(df[target_column].values))
    bins[-1] *= 1.01
    bin_labels = [str(bins[i]) for i in range(len(bins))]


for i, model_name in enumerate(list(set(df_aio.method.values))):
    cols = df_aio[df_aio.method==model_name].iloc[:, 0:6].columns
    values = df_aio[df_aio.method==model_name].iloc[:, 0:6].values

    for metric_name, metric in zip(cols, values.T):
        col_name = model_name.replace(' ', '_') + '_' + metric_name.replace(' ', '_')
        df[col_name] = metric
dep_var_regex = target_column
dep_var_label = target_column.replace('_', ' ').title()
metric_names_nice = ["Mean Localization Error [mm]", "Area Under the Curve [%]", "Normalized Mean Squared Error"]
metric_names = ['mean_localization_errors', 'aucs_combined', 'nmses']
scalers = [1, 100, 1]
for metric_name, metric_name_nice, scaler in zip(metric_names,  metric_names_nice, scalers):
    df_temp = pd.concat((df.filter(regex=dep_var_regex), df.filter(regex='_'+metric_name)), axis=1).melt(dep_var_regex, var_name='cols', value_name='vals')
    df_temp.cols = [val.split('_')[0] for val in df_temp.cols.values]
    # df_temp.vals*=scaler
    g = sns.catplot(x=dep_var_regex, y='vals', hue='cols', capsize=.2, kind='point', data=df_temp, height=6, aspect=1.5)
    g.set(xticklabels=bin_labels, ylabel=metric_name_nice, xlabel=dep_var_label)
    g._legend.remove()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.title(metric_name_nice)

    # break
g = sns.catplot(x=dep_var_regex, y='vals', hue='cols', capsize=.2, kind='point', data=df_temp, height=6, aspect=1.5)
g.set(xticklabels=bin_labels, ylabel=metric_name_nice, xlabel=dep_var_label)
g._legend.remove()
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.ylim(plt.ylim()[0], plt.ylim()[1]/8)
plt.tight_layout()
plt.title(metric_name_nice)
# util.multipage(f'C:/Users/lukas/Sync/lstm_inverse_problem/figures/results/dependence_{target_column}.pdf', png=True)

# Dependence on Duration

## Load Dependence on Duration

In [None]:
# Metrics as a function of the number of available data points (plotted in the next cell).
# Forward slashes resolve on Windows and POSIX alike.
# SECURITY: pickle.load can execute arbitrary code - only load files produced by this project.
with open('results/metrics_duration.pkl', 'rb') as f:
    params, param_names = pkl.load(f)

# Keep three of the six stored metrics; the names below replace the stored ones.
mles, nmses, mses, aucs_combined, aucs_close, aucs_far = params
params = [mles, nmses, aucs_combined]
param_names = ['mles', 'nmses', 'aucs_combined']
param_names_nice = ["Mean Localization Error [mm]", "Normalized Mean Squared Error", "Area Under the Curve [%]"]
params[-1]

## Plot Dependence On Duration

In [None]:
%matplotlib qt
sns.set(style='white', font_scale=1.6, font='helvetica')

ylims = [[None, None], [None, None], [None, None]]
scalers = [1, 1, 100]
linewidth = 2.5
x = np.arange(0, 201)[::5][::-1]
for param, name, name_nice, ylim, scaler in zip(params, param_names, param_names_nice, ylims, scalers):
    plt.figure(figsize=(12,6))

    for model_name in params[0].keys():
        if len(param[model_name])==1:
            plt.plot(x, np.array(param[model_name]*len(x))*scaler, label=model_name, linewidth=linewidth)
        else:
            plt.plot(x, np.array(param[model_name])*scaler, label=model_name, linewidth=linewidth)


    plt.xlabel('Number of available data points')
    plt.ylabel(name_nice)
    plt.ylim(ylim)
    plt.title(name_nice)
    plt.legend(loc=2, bbox_to_anchor=(1.05,1), borderaxespad=0)
    plt.tight_layout()

util.multipage(r'C:\Users\lukas\Sync\lstm_inverse_problem\figures\results\dependence_duration.pdf', png=True)

# Get interpolated Topomap for publication

In [None]:
from mne.viz.topomap import (_setup_interp, _make_head_outlines, _check_sphere, 
    _check_extrapolate)
from mne.channels.layout import _find_topomap_coords

# Render the 2-D interpolated EEG map that the LSTM model builds from `sim`, for a publication figure.
model = models[0]
eeg, src = model._handle_data_input((sim,))
# Swap the last two axes - presumably (epochs, channels, times) -> (epochs, times, channels); TODO confirm.
eeg_prep = np.swapaxes(eeg[0].get_data(), 1,2)
elec_pos = _find_topomap_coords(model.info, model.info.ch_names)
interpolator = model.make_interpolator(elec_pos, res=model.interp_channel_shape[0])

# Interpolate every time slice of the last sample onto the grid (rows flipped with [::-1]).
# The former loop wrapped the samples in `tqdm`, which is never imported here (NameError), and
# kept only the last sample's slices anyway.
# NOTE(review): only the last sample is visualised - confirm this is intended.
sample = eeg_prep[-1]
list_of_time_slices = [interpolator.set_values(time_slice)()[::-1] for time_slice in sample]

sns.set(style='white')
tick_font_size = 50
plt.figure()
# Time-averaged map. imshow autoscales its colour range, so the 8e6 factor only changes the
# colorbar tick values, not the image itself.
plt.imshow(np.mean(list_of_time_slices, axis=0)*8e6, cmap='RdBu_r')
plt.tick_params(left = False, right = False , labelleft = False ,
                labelbottom = False, bottom = False)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_font_size)

## Speed test

In [None]:
import time

def new_sim_params(sr=100, packages_per_second=20):
    """Simulate one trial the size of a single streaming data package.

    Parameters
    ----------
    sr : int
        Sampling rate in Hz used to size the package.
    packages_per_second : int
        Number of packages (predictions) to process per second; must be >= 1.

    Returns
    -------
    sim_data_package : Simulation
        Simulated data holding one trial of ``package_size`` samples.
    package_interval : float
        Time budget in seconds for processing one package.

    Raises
    ------
    ValueError
        If ``packages_per_second`` < 1 (0 would otherwise raise ZeroDivisionError).
    """
    if packages_per_second < 1:
        raise ValueError(f'packages_per_second must be >= 1, got {packages_per_second}')
    package_size = int( round( sr / packages_per_second  ) )
    package_interval = package_size/sr

    # duration_of_trial is in seconds (2.0 s produced 200 samples at 100 Hz above), so
    # package_size samples at the simulation's rate take package_size / info['sfreq'] seconds.
    # The former hard-coded 0.01 equals 1 / info['sfreq'] for the 100 Hz set at the top.
    # An unused random data package (built from sim_test.eeg_data.ch_names) is no longer created.
    duration = package_size / info['sfreq']
    sim_data_package = Simulation(fwd, info, settings=dict(duration_of_trial=duration)).simulate(1)
    print(f'performing predictions {packages_per_second} times per second')

    return sim_data_package, package_interval

# Find the highest prediction rate the first model sustains. Start at 50 packages/s and lower the
# rate by one whenever a prediction overruns its time budget. The run is bounded in time and by
# a 1 package/s floor, so the cell terminates (a bare `while True` would hang Run All).
BENCHMARK_SECONDS = 60
packages_per_second = 50
sim_data_package, package_interval = new_sim_params(packages_per_second=packages_per_second)

benchmark_end = time.perf_counter() + BENCHMARK_SECONDS
while time.perf_counter() < benchmark_end:
    start = time.perf_counter()
    stc = models[0].predict(sim_data_package)
    diff = time.perf_counter() - start

    if diff > package_interval:
        print(f"took longer than expected: {diff} (instead of {package_interval})")
        if packages_per_second <= 1:
            print('cannot keep up even at 1 package per second; stopping')
            break
        print('decreasing packages per second by one')
        packages_per_second -= 1
        sim_data_package, package_interval = new_sim_params(packages_per_second=packages_per_second)
        print(f'packages_per_second={packages_per_second}\n')
        continue
    # Wait out the rest of this package's interval to mimic a real-time stream.
    time.sleep(package_interval-diff)

print(f'final rate: {packages_per_second} packages per second')

