In [1]:
%matplotlib qt
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../invert')
import mne
import pickle as pkl
from time import time
from scipy.spatial.distance import cdist
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


from invert.forward import get_info, create_forward_model
from invert.solvers.esinet import generator
from invert.util import pos_from_forward

pp = dict(surface='white', hemi='both', verbose=0)

# Load Source Estimates

In [2]:
sim_type = "extended"
fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
with open(fn, 'rb') as f:
    stc_dict, x_test, y_test, sim_info, proc_time_make, proc_time_apply = pkl.load(f)

fwd = mne.read_forward_solution("forward_model/64ch_ico3-fwd.fif", verbose=0)
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
pos = pos_from_forward(fwd)
source_model = fwd['src']
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]
distances = cdist(pos, pos)
argsorted_distance_matrix = np.argsort(distances, axis=1)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


# Calculate Results

In [6]:
from invert.evaluate import evaluate_all

n_samples = x_test.shape[0]
results = []
for solver_name in stc_dict.keys():
    print(solver_name)
    for i in range(n_samples):
        y_pred = stc_dict[solver_name][i].data
        y_true = y_test[i].T
        result = evaluate_all(y_true, y_pred, pos, argsorted_distance_matrix, distances)
        result["Method"] = solver_name
        result["Time Make"] = proc_time_make[solver_name][i]
        result["Time Apply"] = proc_time_apply[solver_name][i]
        
        results.append(result)

fn = f"results/results_{sim_type}.pkl"
with open(fn, 'wb') as f:
    pkl.dump(results, f)

FLEX-MUSIC


  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linal

TRAP-MUSIC


  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linalg.norm(y, axis=0)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
  y_est_normed = y_est / np.max(np.abs(y_est))
  distribution2 = distribution2 / np.sum(distribution2)
  y_scaled = y / np.linal

eLORETA
Convexity Champagne


  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,


MCMV


# Load Results

In [14]:
sim_type = "single"
fn = f"results/results_{sim_type}.pkl"
with open(fn, 'rb') as f:
    results = pkl.load(f)

# Plot

In [6]:
sns.set(style="whitegrid", font_scale=1.)

tick_params = dict(
    axis='y',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    left=True,      # ticks along the bottom edge are off
    right=False,         # ticks along the top edge are off
    labelbottom=False
)

medianprops = {
    "linewidth": 2,
    "linestyle": "dashed"
    }
order = ['FLEX-MUSIC', 'TRAP-MUSIC', 'Convexity Champagne', 'eLORETA', "MCMV"]
df = pd.DataFrame(results)
# df["Sparsity"] = [val[0] for val in df["Sparsity"].values]

plt.figure()
sns.boxplot(data=df, x="Method", y="AUC", order=order, medianprops=medianprops)
plt.axhline(y=0.5, color='grey', linestyle='-.')
plt.title("Area under ROC curve")
plt.ylim(-0.05, 1.05)
plt.tick_params(**tick_params)

plt.figure()
sns.boxplot(data=df, x="Method", y="Mean_Squared_Error", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)

plt.figure()
g = sns.boxplot(data=df, x="Method", y="Normalized_Mean_Squared_Error", order=order, medianprops=medianprops)
g.set_yscale("log")
plt.tick_params(**tick_params)

plt.figure()
sns.boxplot(data=df, x="Method", y="Mean_Localization_Error", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.ylim(-3, 43)

plt.figure()
sns.boxplot(data=df, x="Method", y="Sparsity", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)


plt.figure()
g = sns.boxplot(data=df, x="Method", y="Time", order=order, medianprops=medianprops)
g.set_yscale("log")
plt.tick_params(**tick_params)


# Single and extended sources in groups

In [3]:
sns.set(style="whitegrid", font_scale=1.)

tick_params = dict(
    axis='y',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    left=True,      # ticks along the bottom edge are off
    right=False,         # ticks along the top edge are off
    labelbottom=False
)

medianprops = {
    "linewidth": 2,
    "linestyle": "dashed"
    }
order = ['FLEX-MUSIC', 'TRAP-MUSIC', 'Convexity Champagne', 'MCMV', 'eLORETA']


sim_type = "single"
fn = f"results/results_{sim_type}.pkl"
with open(fn, 'rb') as f:
    results_single = pkl.load(f)

sim_type = "extended"
fn = f"results/results_{sim_type}.pkl"
with open(fn, 'rb') as f:
    results_extended = pkl.load(f)

df_single = pd.DataFrame(results_single)
df_single["Source Extend"] = "Single"
df_extended = pd.DataFrame(results_extended)
df_extended["Source Extend"] = "Extended"

df = pd.concat([df_single, df_extended])
# df["Sparsity"] = [val[0] for val in df["Sparsity"].values]

plt.figure(figsize=(10,7))
sns.boxplot(data=df, x="Method", y="AUC", hue="Source Extend", order=order, medianprops=medianprops)
plt.axhline(y=0.5, color='grey', linestyle='-.')
plt.title("Area under ROC curve")
plt.ylim(-0.05, 1.05)
plt.tick_params(**tick_params)

plt.figure(figsize=(10,7))
g = sns.boxplot(data=df, x="Method", y="Corr", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.ylim(-0.05, 1.05)

plt.figure(figsize=(10,7))
g = sns.boxplot(data=df, x="Method", y="EMD", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)

plt.figure(figsize=(10,7))
sns.boxplot(data=df, x="Method", y="Mean_Squared_Error", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)

plt.figure(figsize=(10,7))
g = sns.boxplot(data=df, x="Method", y="Normalized_Mean_Squared_Error", hue="Source Extend", order=order, medianprops=medianprops)
g.set_yscale("log")
plt.tick_params(**tick_params)

plt.figure(figsize=(10,7))
sns.boxplot(data=df, x="Method", y="Mean_Localization_Error", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.ylim(-3, 43)

plt.figure(figsize=(10,7))
sns.boxplot(data=df, x="Method", y="Sparsity_pred", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)

plt.figure(figsize=(10,7))
g = sns.boxplot(data=df, x="Method", y="Time", hue="Source Extend", order=order, medianprops=medianprops)
g.set_yscale("log")
plt.tick_params(**tick_params)


In [10]:
from scipy.stats import pearsonr
method = "FLEX-MUSIC"
df_temp = df[df.Method==method]

a = df_temp.Sparsity_true.values
b = df_temp.Sparsity_pred.values
a[np.isnan(a)] = np.nanmedian(a)
b[np.isnan(b)] = np.nanmedian(a)
r, p = pearsonr(a, b)
print(f"{method} sparsity: r = {r:.2f}")

plt.figure()
sns.scatterplot(data=df_temp, y="Sparsity_pred", x="Sparsity_true")
maxval = np.max( [plt.ylim()[1], plt.xlim()[1]])
plt.ylim(0,  maxval)
plt.xlim(0, maxval)
plt.plot([0, maxval], [0, maxval], 'k')


method = "Convexity Champagne"
df_temp = df[df.Method==method]

a = df_temp.Sparsity_true.values
b = df_temp.Sparsity_pred.values
a[np.isnan(a)] = np.nanmedian(a)
b[np.isnan(b)] = np.nanmedian(a)
r, p = pearsonr(a, b)
print(f"{method} sparsity: r = {r:.2f}")

plt.figure()
sns.scatterplot(data=df_temp, y="Sparsity_pred", x="Sparsity_true")
maxval = np.max( [plt.ylim()[1], plt.xlim()[1]])
plt.ylim(0,  maxval)
plt.xlim(0, maxval)
plt.plot([0, maxval], [0, maxval], 'k')

method = "FLEX-MUSIC"
df_temp = df[df.Method==method]

r, p = pearsonr(df_temp.Active_True.values, df_temp.Active_Pred.values)
print(f"{method} Active Dipoles: r = {r:.2f}")

plt.figure()
sns.scatterplot(data=df_temp, y="Active_Pred", x="Active_True")
maxval = np.max( [plt.ylim()[1], plt.xlim()[1]])
plt.ylim(0,  maxval)
plt.xlim(0, maxval)
plt.plot([0, maxval], [0, maxval], 'k')


method = "Convexity Champagne"
df_temp = df[df.Method==method]

r, p = pearsonr(df_temp.Active_True.values, df_temp.Active_Pred.values)
print(f"{method} Active Dipoles: r = {r:.2f}")

plt.figure()
sns.scatterplot(data=df_temp, y="Active_Pred", x="Active_True")
maxval = np.max( [plt.ylim()[1], plt.xlim()[1]])
plt.ylim(0,  maxval)
plt.xlim(0, maxval)
plt.plot([0, maxval], [0, maxval], 'k')


FLEX-MUSIC sparsity: r = 0.80
Convexity Champagne sparsity: r = 0.17
FLEX-MUSIC Active Dipoles: r = 0.83
Convexity Champagne Active Dipoles: r = 0.17


[<matplotlib.lines.Line2D at 0x22397b9b6a0>]

In [8]:
df.head()

Unnamed: 0,Mean_Squared_Error,Normalized_Mean_Squared_Error,Mean_Localization_Error,AUC,Corr,EMD,Sparsity_pred,Sparsity_true,Active_True,Active_Pred,Method,Time,Source Extend
0,1.984523e-08,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.000779,0.000779,FLEX-MUSIC,0.427364,Single
1,6.161808e-08,0.000174,0.0,1.0,0.920804,7268.832897,1.710411,1.722149,0.002336,0.002336,FLEX-MUSIC,0.466751,Single
2,6.454046e-07,0.001269,1.93978,0.997959,0.692695,62176.801113,3.381233,2.512096,0.005452,0.014798,FLEX-MUSIC,0.425861,Single
3,6.913167e-08,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.000779,0.000779,FLEX-MUSIC,0.449303,Single
4,1.874573e-07,0.000182,0.0,1.0,0.921881,8764.065738,1.955567,1.912822,0.003115,0.003115,FLEX-MUSIC,0.485701,Single


# Paper Plots 1

In [33]:
sns.set(style="whitegrid", font_scale=1.)

tick_params = dict(
    axis='y',          # changes apply to the y-axis
    which='both',      # both major and minor ticks are affected
    left=True,      # ticks along the bottom edge are off
    right=False,         # ticks along the top edge are off
    labelbottom=False
)

medianprops = {
    "linewidth": 2,
    "linestyle": "dashed"
    }


sim_type = "single"
fn = f"results/results_{sim_type}.pkl"
with open(fn, 'rb') as f:
    results_single = pkl.load(f)

sim_type = "extended"
fn = f"results/results_{sim_type}.pkl"
with open(fn, 'rb') as f:
    results_extended = pkl.load(f)

df_single = pd.DataFrame(results_single)
df_single["Source Extend"] = "Single"
df_extended = pd.DataFrame(results_extended)
df_extended["Source Extend"] = "Extended"

df = pd.concat([df_single, df_extended])
df.Method[df.Method=="Convexity Champagne"] = "Champagne"
df.Method[df.Method=="FLEX-MUSIC"] = "FLEX\nMUSIC"
df.Method[df.Method=="TRAP-MUSIC"] = "TRAP\nMUSIC"
df.Method[df.Method=="eLORETA"] = "eLOR"
df.rename(columns = {col: col.replace("-", " ").replace("_", " ") for col in df.columns}, inplace = True)
order = ['FLEX\nMUSIC', 'TRAP\nMUSIC', 'Champagne', 'MCMV', 'eLOR']

fig1 = plt.figure(figsize=(16,5))
plt.subplot(131)
sns.boxplot(data=df, x="Method", y="Mean Localization Error", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.ylim(-3, 43)
plt.ylabel("Mean Localization Error [mm]")
plt.gca().get_legend().remove()

plt.subplot(132)
g = sns.boxplot(data=df, x="Method", y="EMD", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.gca().get_legend().remove()

plt.subplot(133)
sns.boxplot(data=df, x="Method", y="Mean Squared Error", hue="Source Extend", order=order, medianprops=medianprops)
plt.ylim(-0.05e-5, 1e-5)
plt.tick_params(**tick_params)


plt.tight_layout(pad=2)
df.rename(columns={"Sparsity pred": "Sparsity"}, inplace=True)

fig2 = plt.figure(figsize=(10,5))
plt.subplot(121)
sns.boxplot(data=df, x="Method", y="Sparsity", hue="Source Extend", order=order, medianprops=medianprops)
plt.tick_params(**tick_params)
plt.gca().get_legend().remove()

plt.subplot(122)
g = sns.boxplot(data=df, x="Method", y="Time", hue="Source Extend", order=order, medianprops=medianprops)
g.set_yscale("log")
plt.ylabel("Computation Time [s]")
plt.tick_params(**tick_params)

plt.tight_layout(pad=2)

from scipy.stats import pearsonr
sns.set(style="white", font_scale=1.)
tick_params = dict(
    axis='both',          # changes apply to the y-axis
    which='both',      # both major and minor ticks are affected
    direction="inout",
    left=True,      # ticks along the bottom edge are off
    bottom=True,
    right=False,         # ticks along the top edge are off
    labelbottom=True
)

fig3 = plt.figure(figsize=(8, 6))
for i, method in enumerate(order):
    print(method)
    plt.subplot(2,3,i+1)
    df_temp = df[df.Method==method]

    a = df_temp["Sparsity true"].values
    b = df_temp["Sparsity"].values
    nans = (np.isnan(a) | np.isnan(b))
    a[nans] = np.nanmedian(a)
    b[nans] = np.nanmedian(b)
    r, p = pearsonr(a, b)
    # print(f"{method} sparsity: r = {r:.2f}")

    
    sns.scatterplot(data=df_temp, y="Sparsity", x="Sparsity true")
    plt.gca().set_aspect('equal', adjustable='box')

    

    maxval = np.max( [plt.ylim()[1], plt.xlim()[1]])
    plt.ylim(0,  maxval)
    plt.xlim(0, maxval)
    
    
    plt.plot([0, maxval], [0, maxval], 'k')
    plt.ylabel("Predicted Sparsity")
    plt.xlabel("True Sparsity")
    plt.title(method)
    r_text = f"r = {r:.2f}"
    if p<0.001:
        r_text += " ***"
    elif p<0.01:
        r_text += " **"
    elif p<0.05:
        r_text += " *"
    plt.text(maxval/2.6, maxval/1.2, r_text)
    xticks = plt.xticks()[0]
    yticks = plt.yticks()[0]
    longest_ticks = xticks if len(xticks)>len(yticks) else yticks
    plt.yticks(longest_ticks)
    plt.xticks(longest_ticks)
    plt.ylim(0,  maxval)
    plt.xlim(0, maxval)
    plt.tick_params(**tick_params)

    
plt.tight_layout(pad=2)

fig1.savefig("figures/accuracy.png", dpi=600)
fig2.savefig("figures/sparsity_time.png", dpi=600)
fig3.savefig("figures/extend_estimation.png", dpi=600)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.Method[df.Method=="Convexity Champagne"] = "Champagne"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.Method[df.Method=="FLEX-MUSIC"] = "FLEX\nMUSIC"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.Method[df.Method=="TRAP-MUSIC"] = "TRAP\nMUSIC"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.Metho

FLEX
MUSIC
TRAP
MUSIC
Champagne
MCMV
eLOR


# Table

In [31]:
df_table = df.copy()

df_table.Method[df.Method=="FLEX\nMUSIC"] = "FLEX-MUSIC"
df_table.Method[df.Method=="TRAP\nMUSIC"] = "TRAP-MUSIC"

df_table = df_table.groupby("Method").median()
df_table = df_table[["Mean Localization Error", "EMD", "Mean Squared Error", "Sparsity", "Time"]]
for col in df_table.columns:
    if not "Squared" in col:
        df_table[col] = df_table[col].round(2)

print(df_table.to_latex())

\begin{tabular}{lrrrrr}
\toprule
{} &  Mean Localization Error &        EMD &  Mean Squared Error &  Sparsity &  Time \\
Method     &                          &            &                     &           &       \\
\midrule
Champagne  &                     7.73 &  123775.42 &        4.674609e-07 &      6.29 &  2.66 \\
FLEX-MUSIC &                     2.11 &   66291.31 &        4.147606e-07 &      2.36 &  0.35 \\
MCMV       &                    16.25 &  184793.86 &        9.648467e-07 &     30.04 &  1.82 \\
TRAP-MUSIC &                     5.50 &  125233.79 &        5.371027e-07 &      1.73 &  0.32 \\
eLOR       &                    19.05 &  185186.10 &        8.155678e-07 &     28.29 &  0.32 \\
\bottomrule
\end{tabular}



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_table.Method[df.Method=="FLEX\nMUSIC"] = "FLEX-MUSIC"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_table.Method[df.Method=="TRAP\nMUSIC"] = "TRAP-MUSIC"
  print(df_table.to_latex())


# Plot brains

In [184]:
import pickle as pkl
import mne
import sys; sys.path.insert(0, '../invert')
from invert.util import pos_from_forward
from copy import deepcopy
import pandas as pd

clim=dict(kind="value", pos_lims=(0.0, 0.01, 1))
pp = dict(surface='inflated', hemi='both', background="white", verbose=0, colorbar=False, time_viewer=False)

sim_type = "extended"
fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
with open(fn, 'rb') as f:
    stc_dict_ext, x_test_ext, y_test_ext, sim_info_ext, _, _ = pkl.load(f)

sim_type = "single"
fn = f"evaluation/sim_and_preds_{sim_type}.pkl"
with open(fn, 'rb') as f:
    stc_dict_sing, x_test_sing, y_test_sing, sim_info_sing, _, _ = pkl.load(f)

with open("forward_model/64ch_info.pkl", "rb") as f:
    info = pkl.load(f)

# Combine
stc_dict = deepcopy(stc_dict_sing)
for key, value in stc_dict_ext.items():
    stc_dict[key].extend(value)
x_test = np.concatenate([x_test_sing, x_test_ext], axis=0)
y_test = np.concatenate([y_test_sing, y_test_ext], axis=0)
sim_info = pd.concat([sim_info_sing, sim_info_ext])


fwd = mne.read_forward_solution("forward_model/64ch_ico3-fwd.fif", verbose=0)
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
pos = pos_from_forward(fwd)
source_model = fwd['src']
vertices = [source_model[0]['vertno'], source_model[1]['vertno']]

samples = [23, 36, 508, 514]
imgs = []
colorbars = []
names = []
for sample in samples:
    
    tmin = 0
    tstep = 1/1000
    subject = "fsaverage"

    evoked = mne.EvokedArray(x_test[sample].T, info, tmin=0)
    # evoked.plot_joint()
    stc = mne.SourceEstimate(y_test[sample].T, vertices, tmin=tmin, tstep=tstep, 
                            subject=subject, verbose=0)
    first_sample = stc.data[:, 0]
    first_sample /= np.max(abs(first_sample))
    stc.data = np.tile(first_sample, (20,1)).T
    brain = stc.plot(**pp, brain_kwargs=dict(title="True"), clim=clim)
    pp["colorbar"] = True
    brain_cb = stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver), clim=clim)
    pp["colorbar"] = False

    img = brain.screenshot()
    colorbar = brain_cb.screenshot()
    brain.close()
    brain_cb.close()
    imgs.append( img )
    colorbars.append( colorbar )
    names.append("Ground Truth")
    
    for solver, stc_list in stc_dict.items():
        first_sample = stc_list[sample].data[:, 0]
        first_sample /= np.max(abs(first_sample))
        stc_list[sample].data = np.tile(first_sample, (20,1)).T
        if solver == "eLORETA" or solver == "MCMV":
            brain = stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver), clim=dict(kind="value", pos_lims=(0.2, 0.5, 1)))
            pp["colorbar"] = True
            brain_cb = stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver), clim=dict(kind="value", pos_lims=(0.2, 0.5, 1)))
            pp["colorbar"] = False
        else:
            brain = stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver), clim=clim)
            pp["colorbar"] = True
            brain_cb = stc_list[sample].plot(**pp, brain_kwargs=dict(title=solver), clim=clim)
            pp["colorbar"] = False
            
        img = brain.screenshot()
        colorbar = brain_cb.screenshot()
        
        brain.close()
        brain_cb.close()

        imgs.append( img )
        colorbars.append(colorbar)
        names.append(solver)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


## Final Plot and Save

In [195]:
%matplotlib qt
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(20,8))
for i, (name, img) in enumerate(zip(names, imgs)):
    plt.subplot(int(len(imgs)/6)+1,6,i+1)
    plt.imshow(img[122:691, :])
    plt.axis('off')
    if i<6:
        if "Champagne" in name:
            plt.title("Champagne", fontsize=18)
        else:
            plt.title(name, fontsize=18)
for j, colorbar in enumerate(colorbars[:6]):
    plt.subplot(int(len(imgs)/6)+1,6,i+j+2)
    plt.imshow(colorbar[731:761, :])
    plt.axis('off')
plt.tight_layout()
fig.savefig("figures/brains.png", dpi=600)


'Ground Truth'

## Helps to decide for the right samples:

In [141]:
sim_info[sim_info.n_sources==3].head(25)

Unnamed: 0,n_sources,amplitudes,snr
6,3,"[0.7807408956566793, 0.7782117061516455, 0.544...",21.766431
9,3,"[0.6414994226703741, 0.7819144808592774, 0.138...",46.935714
25,3,"[0.339622791244461, 0.9859668040487106, 0.0735...",16.049346
40,3,"[0.2525028471127769, 0.8796421349726212, 0.583...",94.770972
43,3,"[0.9514130873655432, 0.1143115201823572, 0.383...",77.249802
53,3,"[0.5998499648461517, 0.4349899475082236, 0.796...",77.374201
63,3,"[0.24392823657689203, 0.9582270089814386, 0.10...",81.931384
76,3,"[0.4829919480971907, 0.45328753837912167, 0.61...",84.380941
77,3,"[0.3695619633879521, 0.9033622108169519, 0.449...",59.65226
82,3,"[0.9468621389452988, 0.38910627588799757, 0.03...",51.493159


In [148]:
clim=dict(kind="percent", pos_lims=(0, 0, 100))
pp["clim"] = clim
# for i in range(510, 520):
stc_dict["FLEX-MUSIC"][514].plot(**pp)

<mne.viz._brain._brain.Brain at 0x1f3c63f4e80>