In [1]:
%matplotlib qt
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../invert')
import mne
import pickle as pkl
from time import time
from scipy.spatial.distance import cdist
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


from invert.forward import get_info, create_forward_model
from invert.solvers.esinet import generator
from invert.util import pos_from_forward

pp = {"surface": "white", "hemi": "both", "verbose": 0}  # shared keyword arguments for SourceEstimate.plot

# Load simulations, source estimates and the forward model

In [15]:
# Which simulation setting to evaluate ("single" or "extended").
sim_type = "extended"
fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
# NOTE(review): pickle.load can execute arbitrary code -- only load files produced by this project.
with open(fn, 'rb') as f:
    # The tuple layout must match whatever wrote this pickle (not visible in this notebook).
    stc_dict, x_test, y_test, sim_info, proc_time_make, proc_time_apply = pkl.load(f)

# Fixed-orientation forward model; verbose=0 on both calls keeps the cell output free of MNE logging.
fwd = mne.read_forward_solution("forward_model/64ch_ico3-fwd.fif", verbose=0)
fwd = mne.convert_forward_solution(fwd, force_fixed=True, verbose=0)
pos = pos_from_forward(fwd)
source_model = fwd['src']
# Vertex numbers of the two source-space entries, used to build SourceEstimate objects.
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]
# Row i: indices of all sources sorted by their Euclidean distance to source i (nearest first).
argsorted_distance_matrix = np.argsort(cdist(pos, pos), axis=1)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


# Compute evaluation metrics for every solver and test sample

In [16]:
import os

from invert.evaluate import evaluate_all

# One metrics record (dict) per (solver, sample); turned into a DataFrame when plotting.
n_samples = x_test.shape[0]
results = []
for solver_name, stc_list in stc_dict.items():
    # Fix: this previously read an undefined `proc_time`; the loading step only provides
    # `proc_time_make` and `proc_time_apply`.
    # NOTE(review): assumes proc_time_apply[solver] holds one apply-time per sample (it is indexed
    # per sample below) -- confirm whether the inverse-operator build time (proc_time_make) should be added.
    solver_times = proc_time_apply[solver_name]
    for i in range(n_samples):
        # Prepend a singleton batch axis to both prediction and ground truth.
        y_pred = stc_list[i].data[np.newaxis]
        y_true = y_test[i].T[np.newaxis]
        result = evaluate_all(y_true, y_pred, pos, argsorted_distance_matrix)
        result["Method"] = solver_name
        result["Time"] = solver_times[i]
        results.append(result)

fn = f"results/results_{sim_type}.pkl"
# Create the output directory on a fresh checkout instead of failing on open().
os.makedirs(os.path.dirname(fn), exist_ok=True)
with open(fn, 'wb') as f:
    pkl.dump(results, f)

  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, o

# Load Results

In [14]:
# Reload the saved metric records of one simulation setting ("single" or "extended").
sim_type = "single"
fn = "results/results_{}.pkl".format(sim_type)
with open(fn, mode="rb") as f:
    results = pkl.load(f)

{'FLEX-MUSIC', 'MCE', 'MM Champagne', 'TRAP-MUSIC', 'eLORETA'}

In [28]:
sns.set(style="whitegrid", font_scale=1.)

# Keep ticks on the left of the y-axis and remove them from the right.
# (The former `labelbottom=False` was dropped: Axes.tick_params ignores it when axis='y'.)
tick_params = dict(
    axis='y',       # only the y-axis is affected
    which='both',   # major and minor ticks
    left=True,      # ticks along the left edge are on
    right=False,    # ticks along the right edge are off
)

# Thicker, dashed median line in every box.
medianprops = {
    "linewidth": 2,
    "linestyle": "dashed",
}
order = ['FLEX-MUSIC', 'TRAP-MUSIC', 'MM Champagne', 'MCE', 'eLORETA']
df = pd.DataFrame(results)

# One figure per metric: (column, title, log y-scale?, y-limits or None).
metric_plots = [
    ("AUC", "Area under ROC curve", False, (-0.05, 1.05)),
    ("Mean_Squared_Error", "Mean squared error", False, None),
    ("Normalized_Mean_Squared_Error", "Normalized mean squared error", True, None),
    ("Mean_Localization_Error", "Mean localization error", False, (-3, 43)),
    ("Sparsity", "Sparsity", False, None),
    ("Time", "Computation time", True, None),
]

for column, title, log_scale, ylim in metric_plots:
    fig, ax = plt.subplots()
    sns.boxplot(data=df, x="Method", y=column, order=order, medianprops=medianprops, ax=ax)
    if column == "AUC":
        # Reference line at AUC = 0.5 (chance level).
        ax.axhline(y=0.5, color='grey', linestyle='-.')
    if log_scale:
        ax.set_yscale("log")
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(title)
    ax.tick_params(**tick_params)


# Single vs. extended sources, grouped per method

In [32]:
sns.set(style="whitegrid", font_scale=1.)

# Keep ticks on the left of the y-axis and remove them from the right.
# (The former `labelbottom=False` was dropped: Axes.tick_params ignores it when axis='y'.)
tick_params = dict(
    axis='y',       # only the y-axis is affected
    which='both',   # major and minor ticks
    left=True,      # ticks along the left edge are on
    right=False,    # ticks along the right edge are off
)

# Thicker, dashed median line in every box.
medianprops = {
    "linewidth": 2,
    "linestyle": "dashed",
}
order = ['FLEX-MUSIC', 'TRAP-MUSIC', 'MM Champagne', 'MCE', 'eLORETA']

# Load the saved metrics of both simulation settings and tag every row with its setting.
# "Source Extent" is the hue column and therefore the legend title.
frames = []
for sim_type, extent_label in [("single", "Single"), ("extended", "Extended")]:
    fn = f"results/results_{sim_type}.pkl"
    with open(fn, 'rb') as f:
        frame = pd.DataFrame(pkl.load(f))
    frame["Source Extent"] = extent_label
    frames.append(frame)
# ignore_index gives the combined frame unique row labels (each part's index starts at 0).
df = pd.concat(frames, ignore_index=True)

# One figure per metric: (column, title, log y-scale?, y-limits or None).
metric_plots = [
    ("AUC", "Area under ROC curve", False, (-0.05, 1.05)),
    ("Mean_Squared_Error", "Mean squared error", False, None),
    ("Normalized_Mean_Squared_Error", "Normalized mean squared error", True, None),
    ("Mean_Localization_Error", "Mean localization error", False, (-3, 43)),
    ("Sparsity", "Sparsity", False, None),
    ("Time", "Computation time", True, None),
]

for column, title, log_scale, ylim in metric_plots:
    # The AUC figure is drawn larger; the others use the default figure size.
    fig, ax = plt.subplots(figsize=(10, 7) if column == "AUC" else None)
    sns.boxplot(data=df, x="Method", y=column, hue="Source Extent", order=order,
                medianprops=medianprops, ax=ax)
    if column == "AUC":
        # Reference line at AUC = 0.5 (chance level).
        ax.axhline(y=0.5, color='grey', linestyle='-.')
    if log_scale:
        ax.set_yscale("log")
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(title)
    ax.tick_params(**tick_params)


# Plot an example

In [9]:
pp = dict(surface='white', hemi='both', verbose=0)
sample = 0  # index of the test sample to visualize

tmin = 0
tstep = 1/1000  # NOTE(review): assumes 1 kHz sampling of the simulated data -- confirm
subject = "fsaverage"
# Ground truth of the chosen sample (fix: previously hardcoded y_test[0], ignoring `sample`).
stc = mne.SourceEstimate(y_test[sample].T, vertices, tmin=tmin, tstep=tstep,
                         subject=subject, verbose=0)
# Keep a handle to the Brain window; this also suppresses the object repr in the cell output.
brain_true = stc.plot(**pp, brain_kwargs=dict(title="True"))

# Set to True to also open one Brain window per solver for the same sample.
plot_predictions = False
if plot_predictions:
    for solver, stc_list in stc_dict.items():
        stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver))
    display(sim_info.iloc[sample])

<mne.viz._brain._brain.Brain at 0x2ab63917cd0>