In [None]:
import sys; sys.path.insert(0, r'C:\Users\Lukas\OneDrive\Dokumente\projects\invert')
import mne
import pickle as pkl
from time import time
from scipy.spatial.distance import cdist
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from invert.evaluate import eval_mean_localization_error, eval_mean_localization_error_old
from tqdm.notebook import tqdm
import os
from time import sleep
from config import *

# Folder that receives the per-file result pickles written by the evaluation loop;
# create it up front so later writes cannot fail on a missing directory.
os.makedirs(name='D:/data/flex_ssm/results', exist_ok=True)

# Aggregation passed to seaborn's barplot in the plotting cells
# (the Figure 2 cell later rebinds this to "median").
estimator = "mean"

In [None]:
def get_pos(fwd):
    """Return the MNI coordinates of all source vertices of a forward solution.

    Left-hemisphere vertices (``fwd["src"][0]``) come first, followed by the
    right-hemisphere vertices (``fwd["src"][1]``), stacked along axis 0.
    Uses the module-level ``subject`` and ``subjects_dir`` (presumably provided
    by ``config``) for the MNI transform.
    """
    per_hemisphere = []
    for hemi_idx in (0, 1):
        hemi_vertices = fwd["src"][hemi_idx]["vertno"]
        per_hemisphere.append(
            mne.vertex_to_mni(hemi_vertices, hemi_idx, subject=subject,
                              subjects_dir=subjects_dir, verbose=0)
        )
    return np.concatenate(per_hemisphere, axis=0)

# Load Forward Models

In [None]:
fwds = {}  # kept for compatibility; not populated in this cell
data_root = "D:/data/flex_ssm/"


def _load_fixed_fwd(model_entry):
    """Read ``model_entry["path_fwd"]`` below ``data_root`` and convert it to a
    fixed-orientation, surface-oriented forward solution tagged with subject "sample"."""
    fwd = mne.read_forward_solution(os.path.join(data_root, model_entry["path_fwd"]), verbose=0)
    fwd = mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True, use_cps=True)
    fwd.subject = "sample"
    return fwd


# Coarse (ico4) model: used as the ground-truth space in the evaluation loop.
fwd_ico4 = _load_fixed_fwd(forward_models[0])
pos_ico4 = get_pos(fwd_ico4)

distances_ico4 = cdist(pos_ico4, pos_ico4)
adjacency_ico4 = mne.spatial_src_adjacency(fwd_ico4["src"], verbose=0).toarray()

# Fine (oct6) model. Distances go from every ico4 vertex (rows) to every oct6 vertex (columns).
fwd_oct6 = _load_fixed_fwd(forward_models[1])
pos_oct6 = get_pos(fwd_oct6)
distances_oct6 = cdist(pos_ico4, pos_oct6)
adjacency_oct6 = mne.spatial_src_adjacency(fwd_oct6["src"], verbose=0).toarray()

# Measurement info belonging to the first forward model.
fullpath = os.path.join(data_root, forward_models[0]["path_info"])
with open(fullpath, "rb") as f:
    info = pkl.load(f)

# Sanity-Check Forward Models and Evaluate All Files

In [None]:
%matplotlib qt
idx = 100

model = forward_models[1]
fwd_test = mne.read_forward_solution(model["path_fwd"], verbose=0)
fwd_test = mne.convert_forward_solution(fwd_test, force_fixed=True)
with open(model["path_info"], "rb") as f:
    info_test = pkl.load(f)
data_1 = fwd_test["sol"]["data"][:, idx][:, np.newaxis]

evoked = mne.EvokedArray(data_1, info_test)
evoked.plot_joint(title=model["name"])

model = forward_models[3]
fwd_test = mne.read_forward_solution(model["path_fwd"], verbose=0)
fwd_test = mne.convert_forward_solution(fwd_test, force_fixed=True)
with open(model["path_info"], "rb") as f:
    info_test = pkl.load(f)
data_2 = fwd_test["sol"]["data"][:, idx][:, np.newaxis]

evoked = mne.EvokedArray(data_1, info_test)
evoked.plot_joint(title=model["name"])

from scipy.stats import pearsonr
print(pearsonr(data_1[:, 0], data_2[:, 0])[0])

import matplotlib.pyplot as plt
plt.figure()
plt.scatter(data_1[:, 0], data_2[:, 0])

In [None]:
path_evaluation = "D:/data/flex_ssm/evaluation/"
eval_filenames = os.listdir(path_evaluation)
adjacency_true = adjacency_ico4

# eval_filenames = [f for f in eval_filenames if "non-greedy" in f]
for i, filename in enumerate(eval_filenames):
    # if not "figure-4" in filename:
    #     continue
    # print(filename)

    fullpath = os.path.join(path_evaluation, filename)
    figure = filename.split("_")[3]
    greedyness = fullpath.split("_")[-1].split(".")[0]
    model_error = fullpath.split("_")[-2]
    
    if "assumed" in fullpath:
        assumption = fullpath[:fullpath.find("assumed")-1].split("_")[-1]
        assumed = f"_{assumption}-assumed"
    else:
        assumed = ""
    
    fn = f"D:/data/flex_ssm/results/results_{figure}_{model_error}_{greedyness}{assumed}.pkl"
    print(fullpath, figure, greedyness, model_error)

    
    print("FN: ", fn)
    if not ".pkl" in filename or os.path.isfile(fn):
        print("\tis processed or not a regular file")
        continue
    # break
    with open(fullpath, "rb") as f:
        stc_dict, x_test, y_test, sim_info, proc_time_make, proc_time_apply = pkl.load(f)
    n_samples = len(y_test)
    print("\t", fn)
    if "fine" in filename.lower() or (not "coarse" in filename.lower() and not "fine" in filename.lower()):
        print("\t\tyep")
        adjacency_pred = adjacency_oct6
        distances = distances_oct6
    else:
        adjacency_pred = adjacency_ico4
        distances = distances_ico4
    

    results = []
    for solver_name in stc_dict.keys():
        print(solver_name)
        for i in tqdm(range(n_samples)):
            # y_pred = stc_dict[solver_name][i].data
            y_pred = stc_dict[solver_name][i].toarray()
            # y_true = y_test[i].T.toarray()
            y_true = y_test[i].toarray()
            # mle_dle = eval_mean_localization_error(abs(y_true).mean(axis=-1), abs(y_pred).mean(axis=-1), adjacency_true, adjacency_pred, distances, mode="dle")
            # mle_est = eval_mean_localization_error(abs(y_true).mean(axis=-1), abs(y_pred).mean(axis=-1), adjacency_true, adjacency_pred, distances, mode="est")
            # mle_true = eval_mean_localization_error(abs(y_true).mean(axis=-1), abs(y_pred).mean(axis=-1), adjacency_true, adjacency_pred, distances, mode="true")
            # mle_match = eval_mean_localization_error(abs(y_true).mean(axis=-1), abs(y_pred).mean(axis=-1), adjacency_true, adjacency_pred, distances, mode="match")
            mle_amir = eval_mean_localization_error(abs(y_true).mean(axis=-1), abs(y_pred).mean(axis=-1), adjacency_true, adjacency_pred, distances, mode="amir")
            result = dict(Method=solver_name, MLE_amir=mle_amir, Time_Make=proc_time_make[solver_name][i], Time_Apply=proc_time_apply[solver_name][i])
            # result = dict(Method=solver_name, MLE_match=mle_match, MLE_dle=mle_dle, MLE_est=mle_est, MLE_true=mle_true, MLE_amir=mle_amir, Time_Make=proc_time_make[solver_name][i], Time_Apply=proc_time_apply[solver_name][i])
            result.update(sim_info.iloc[i, :].to_dict())
            results.append(result)
            if i%100 == 0:
                print("\t ", i)
    del stc_dict, y_test, proc_time_make, proc_time_apply
    with open(fn, 'wb') as f:
        pkl.dump(results, f)

In [None]:
# Shapes of the last prediction / ground-truth pair left over from the evaluation
# loop (useful to confirm both are oriented the same way).
(y_pred.shape, y_true.shape)

# Plot Figure 1

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-1_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-1_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-1_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-1_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-1_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-1_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="inter_source_correlations", y="MLE", data=df, errorbar=("ci", 95), estimator=estimator)
    plt.ylim(0, 18)
    plt.title(title)
    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    
    fig.savefig(f'figures/initial_results/Figure_1_{savestring}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 12 - 3 Sources

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-12_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-12_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-12_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-12_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-12_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-12_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="inter_source_correlations", y="MLE_dle", data=df, errorbar=("ci", 95), estimator=estimator)
    plt.ylim(0, 28)
    plt.title(title)
    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    
    fig.savefig(f'figures/initial_results/Figure_12_{savestring}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 13 - Equal-Magnitude Sources

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-13_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-13_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-13_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-13_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-13_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-13_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="inter_source_correlations", y="MLE_dle", data=df, errorbar=("ci", 95), estimator=estimator)
    plt.ylim(0, 18)
    plt.title(title)
    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    
    fig.savefig(f'figures/initial_results/Figure_13_{savestring}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 14 - White Spectrum Sources

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-14_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-14_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-14_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-14_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-14_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-14_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="inter_source_correlations", y="MLE_dle", data=df, errorbar=("ci", 95), estimator=estimator)
    plt.ylim(0, 18)
    plt.title(title)
    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    
    fig.savefig(f'figures/initial_results/Figure_14_{savestring}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 2

In [None]:
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
import matplotlib.pyplot as plt

# Figure 2: localization error vs. SNR for each forward-model error condition,
# collected into a single 3x3 grid (panels without data are hidden).
filenames = {
    'No Error': 'results_figure-2_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-2_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-2_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-2_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-2_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-2_Translation-Posterior-2_greedy.pkl',
    # Add additional filenames if needed to fill 3x3 grid
}

# Set the theme before creating the axes so they pick up the style.
sns.set_theme(style="whitegrid")
# NOTE: rebinds the notebook-wide `estimator`; plotting cells run afterwards use "median".
estimator = "median"
# Tick labels for the SNR categories, in the order seaborn places them.
new_xticks = [-15, -10, -5, 0, 5]

fig, axes = plt.subplots(3, 3, figsize=(15, 10))
axes = axes.flatten()  # 1-D indexing: axes[0] .. axes[8]
plot_idx = 0

for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    # Remove the second SSM
    results = [r for r in results if "Adaptive-Reg" not in r["Method"]]
    df = pd.DataFrame(results)

    ax = axes[plot_idx]
    # NOTE(review): the evaluation cell currently writes only "MLE_amir"; "MLE_match"
    # must come from older result files — confirm the column exists.
    sns.barplot(hue="Method", x="snr", y="MLE_match", data=df, errorbar=("ci", 95),
                estimator=estimator, ax=ax)
    ax.set_ylim(0, 10)
    ax.set_title(title)
    # Bars sit at categorical positions 0..n-1, so relabel the existing ticks instead of
    # moving them: set_xticks([-15, ...]) would stretch the axis to -15 and squeeze the bars.
    tick_positions = ax.get_xticks()
    if len(tick_positions) == len(new_xticks):
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(new_xticks)

    plot_idx += 1

# Hide grid cells that received no plot.
for ax in axes[plot_idx:]:
    ax.set_visible(False)

# Adjust layout and save the entire figure.
# `pth_results` is expected to come from `from config import *`.
plt.tight_layout()
os.makedirs(pth_results, exist_ok=True)
plt.savefig(f'{pth_results}/Figure_2_{estimator}.png', format='png', dpi=300)


In [None]:
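# Given an mne.io.Raw object, the data before the first annotated event could be kept like so.
# Annotation onsets are in seconds (not sample indices) and are usually relative to the
# recording start, so subtract raw.first_time to get crop's time base — verify for files
# whose annotations have orig_time=None:
# raw = mne.io.read_raw_fif("path/to/file.fif")
# raw.crop(tmax=raw.annotations.onset[0] - raw.first_time)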

# Plot Figure 3

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-3_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-3_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-3_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-3_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-3_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-3_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="n_timepoints", y="MLE", data=df, errorbar=("ci", 95), estimator=estimator)
    plt.ylim(0, 6)
    plt.title(title)
    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    
    fig.savefig(f'figures/initial_results/Figure_3_{savestring}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 4

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
import matplotlib.pyplot as plt

# Assuming 'filenames_' is defined as in your snippet above
filenames_ = {
    "-1": {
        'No Error': 'results_figure-4_Clean-Coarse_greedy_-1-assumed.pkl',
        'Source Modelling Error': 'results_figure-4_Clean-Fine_greedy_-1-assumed.pkl',
        'SME + Rotation Right 1°': 'results_figure-4_Rotation-Right-1_greedy_-1-assumed.pkl',
        'SME + Rotation Right 2°': 'results_figure-4_Rotation-Right-2_greedy_-1-assumed.pkl',
        'SME + Translation Post. 1mm': 'results_figure-4_Translation-Posterior-1_greedy_-1-assumed.pkl',
        'SME + Translation Post. 2mm': 'results_figure-4_Translation-Posterior-2_greedy_-1-assumed.pkl'
    },
    "0": {
        'No Error': 'results_figure-4_Clean-Coarse_greedy_0-assumed.pkl',
        'Source Modelling Error': 'results_figure-4_Clean-Fine_greedy_0-assumed.pkl',
        'SME + Rotation Right 1°': 'results_figure-4_Rotation-Right-1_greedy_0-assumed.pkl',
        'SME + Rotation Right 2°': 'results_figure-4_Rotation-Right-2_greedy_0-assumed.pkl',
        'SME + Translation Post. 1mm': 'results_figure-4_Translation-Posterior-1_greedy_0-assumed.pkl',
        'SME + Translation Post. 2mm': 'results_figure-4_Translation-Posterior-2_greedy_0-assumed.pkl'
    },
    "1": {
        'No Error': 'results_figure-4_Clean-Coarse_greedy_1-assumed.pkl',
        'Source Modelling Error': 'results_figure-4_Clean-Fine_greedy_1-assumed.pkl',
        'SME + Rotation Right 1°': 'results_figure-4_Rotation-Right-1_greedy_1-assumed.pkl',
        'SME + Rotation Right 2°': 'results_figure-4_Rotation-Right-2_greedy_1-assumed.pkl',
        'SME + Translation Post. 1mm': 'results_figure-4_Translation-Posterior-1_greedy_1-assumed.pkl',
        'SME + Translation Post. 2mm': 'results_figure-4_Translation-Posterior-2_greedy_1-assumed.pkl'
    },
}

# Set the theme for seaborn
sns.set_theme(style="whitegrid")


# Row and column titles
row_titles = ['-1', '0', '1'] # Assumed sources
column_titles = [
    'No Error', 'Source Modelling Error', 'SME + Rotation Right 1°',
    'SME + Rotation Right 2°', 'SME + Translation Post. 1mm', 'SME + Translation Post. 2mm'
]

metrics_description = {
    "Dipole Localization Error": "MLE_dle",
    "Localization Error (True)": "MLE_true",
    "Localization Error (Estimated)": "MLE_est",
    "Matched Localization Error": "MLE_match"
    }

for suptitle, y_var in metrics_description.items():
    # Create a subplot grid: 3 rows x 6 columns
    fig, axes = plt.subplots(3, 6, figsize=(20, 10), sharex='col', sharey='row')
    fig.subplots_adjust(hspace=0.4, wspace=0.4) # Adjust space between plots
    # Iterate over each assumption and filenames
    for i, (assumption, filenames) in enumerate(filenames_.items()):
        if assumption == "-1":
            assumption_text = "1 source less"
        elif assumption == "0":
            assumption_text = "Correct number of sources"
        elif assumption == "1":
            assumption_text = "1 source more"

        for j, (title, filename) in enumerate(filenames.items()):
            ax = axes[i, j] # Current subplot axis
            filename = os.path.join("results", filename)
            if not os.path.isfile(filename):
                ax.text(0.5, 0.5, f"{filename} does not exist", ha='center')
                continue
            with open(filename, 'rb') as f:
                results = pkl.load(f)

            df = pd.DataFrame(results)

            # Plot on the specified subplot axis
            sns.barplot(
                hue="Method", x="inter_source_correlations", y=y_var, data=df, 
                errorbar=("ci", 95), estimator=estimator, ax=ax
            )

            # Set row and column titles
            if j == 0:
                ax.set_ylabel(assumption_text + "\n\n" + "Mean Localization Error [mm]")
            else:
                ax.set_ylabel('')
            
            if i == 2:
                ax.set_xlabel('Inter-source Correlations')
            else:
                ax.set_xlabel('')

            # Set the title for the first row and remove from others
            if i == 0:
                ax.set_title(column_titles[j])
            else:
                ax.set_title('')
            
            # Remove legend if not 0th column
            if j != 0:
                ax.get_legend().remove()

    plt.suptitle(suptitle)

    # Adjust the layout
    plt.tight_layout()

    # Save the figure
    plt.savefig(f'figures/initial_results/Figure_4_{y_var}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 42

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
import matplotlib.pyplot as plt

# Assuming 'filenames_' is defined as in your snippet above
filenames_ = {
    "-1": {
        'No Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Coarse_greedy_-1-assumed.pkl',
        'Source Modelling Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Fine_greedy_-1-assumed.pkl',
        'SME + Rotation Right 1°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-1_greedy_-1-assumed.pkl',
        'SME + Rotation Right 2°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-2_greedy_-1-assumed.pkl',
        'SME + Translation Post. 1mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-1_greedy_-1-assumed.pkl',
        'SME + Translation Post. 2mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-2_greedy_-1-assumed.pkl'
    },
    "0": {
        'No Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Coarse_greedy_0-assumed.pkl',
        'Source Modelling Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Fine_greedy_0-assumed.pkl',
        'SME + Rotation Right 1°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-1_greedy_0-assumed.pkl',
        'SME + Rotation Right 2°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-2_greedy_0-assumed.pkl',
        'SME + Translation Post. 1mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-1_greedy_0-assumed.pkl',
        'SME + Translation Post. 2mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-2_greedy_0-assumed.pkl'
    },
    "1": {
        'No Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Coarse_greedy_1-assumed.pkl',
        'Source Modelling Error': 'D:/data/flex_ssm/results/results_figure-42_Clean-Fine_greedy_1-assumed.pkl',
        'SME + Rotation Right 1°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-1_greedy_1-assumed.pkl',
        'SME + Rotation Right 2°': 'D:/data/flex_ssm/results/results_figure-42_Rotation-Right-2_greedy_1-assumed.pkl',
        'SME + Translation Post. 1mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-1_greedy_1-assumed.pkl',
        'SME + Translation Post. 2mm': 'D:/data/flex_ssm/results/results_figure-42_Translation-Posterior-2_greedy_1-assumed.pkl'
    },
}

# Set the theme for seaborn
sns.set_theme(style="whitegrid")


# Row and column titles
row_titles = ['-1', '0', '1'] # Assumed sources
column_titles = [
    'No Error', 'Source Modelling Error', 'SME + Rotation Right 1°',
    'SME + Rotation Right 2°', 'SME + Translation Post. 1mm', 'SME + Translation Post. 2mm'
]

metrics_description = {
    # "Dipole Localization Error": "MLE_dle",
    # "Localization Error (True)": "MLE_true",
    # "Localization Error (Estimated)": "MLE_est",
    # "Matched Localization Error": "MLE_match",
    "Amirs Matched Localization Error": "MLE_amir"
    }

for suptitle, y_var in metrics_description.items():
    # Create a subplot grid: 3 rows x 6 columns
    fig, axes = plt.subplots(3, 6, figsize=(20, 10), sharex='col', sharey='row')
    fig.subplots_adjust(hspace=0.4, wspace=0.4) # Adjust space between plots
    # Iterate over each assumption and filenames
    for i, (assumption, filenames) in enumerate(filenames_.items()):
        if assumption == "-1":
            assumption_text = "1 source less"
        elif assumption == "0":
            assumption_text = "Correct number of sources"
        elif assumption == "1":
            assumption_text = "1 source more"

        for j, (title, filename) in enumerate(filenames.items()):
            ax = axes[i, j] # Current subplot axis
            filename = os.path.join("results", filename)
            if not os.path.isfile(filename):
                ax.text(0.5, 0.5, f"{filename} does not exist", ha='center')
                continue
            with open(filename, 'rb') as f:
                results = pkl.load(f)

            df = pd.DataFrame(results)

            # Plot on the specified subplot axis
            sns.barplot(
                hue="Method", x="inter_source_correlations", y=y_var, data=df, 
                errorbar=("ci", 95), estimator=estimator, ax=ax
            )

            # Set row and column titles
            if j == 0:
                ax.set_ylabel(assumption_text + "\n\n" + "Mean Localization Error [mm]")
            else:
                ax.set_ylabel('')
            
            if i == 2:
                ax.set_xlabel('Inter-source Correlations')
            else:
                ax.set_xlabel('')

            # Set the title for the first row and remove from others
            if i == 0:
                ax.set_title(column_titles[j])
            else:
                ax.set_title('')
            
            # Remove legend if not 0th column
            if j != 0:
                ax.get_legend().remove()
                
    plt.suptitle(suptitle)

    # Adjust the layout
    plt.tight_layout()

    # Save the figure
    plt.savefig(f'D:/data/flex_ssm/figures/initial_results/Figure_42_{y_var}_{estimator}s.png', format='png', dpi=300)

# Plot Figure 5

In [None]:
%matplotlib qt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

# three strong colors and three light colors with same tone (seaborn)
colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2", "#ccb974", "#64b5cd"]

# filenames = os.listdir("results")
filenames = {
    'No Error': 'results_figure-5_Clean-Coarse_greedy.pkl',
    'Source Modelling Error': 'results_figure-5_Clean-Fine_greedy.pkl',
    'SME + Rotation Right 1°': 'results_figure-5_Rotation-Right-1_greedy.pkl',
    'SME + Rotation Right 2°': 'results_figure-5_Rotation-Right-2_greedy.pkl',
    'SME + Translation Post. 1mm': 'results_figure-5_Translation-Posterior-1_greedy.pkl',
    'SME + Translation Post. 2mm': 'results_figure-5_Translation-Posterior-2_greedy.pkl'
}
for title, filename in filenames.items():
    filename = os.path.join("results", filename)
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results = pkl.load(f)

    df = pd.DataFrame(results)
    df["n_orders"] = [ii[0] for ii in df["n_orders"].values]

    sns.set_theme(style="whitegrid")
    import matplotlib.pyplot as plt
    fig = plt.figure()
    sns.barplot(hue="Method", x="n_orders", y="MLE_dle", data=df, errorbar=("ci", 95), estimator=estimator, palette=colors)
    # sns.barplot(hue="Method", x="n_orders", y="MLE_dle", data=df, errorbar=("ci", 95), estimator="mean", palette=colors)
    plt.ylim(0, 12)
    plt.title(title)
    plt.ylabel("Mean Localization Error [mm]")

    # save figure
    savestring = re.sub('[+°.]', '', title).replace("  ", " ").replace(" ", "_")
    fig.savefig(f'figures/initial_results/Figure_5_{savestring}_{estimator}s.png', format='png', dpi=300)

In [None]:
import pickle as pkl
import mne
import numpy as np
from copy import deepcopy
from config import forward_models

# Keyword arguments shared by every SourceEstimate.plot call in the next cell.
pp = {"surface": "inflated", "hemi": "both", "subject": "sample", "cortex": "low_contrast"}

# Stored simulations and solver predictions of the Figure-1 fine-grid / greedy run.
fn = r"C:\Users\lukas\Dokumente\projects\flex_ssm\ssm_paper_analysis\evaluation\sim_and_preds_figure-1_Clean-Fine_greedy.pkl"
with open(fn, "rb") as f:
    (stc, stc_dict, x_test, y_test, sim_info,
     proc_time_make, proc_time_apply) = pkl.load(f)

# Coarse (forward_models[0]) and fine (forward_models[1]) models: read both first,
# then convert each to fixed, surface-oriented sources.
fwd_coarse, fwd_fine = (
    mne.read_forward_solution(forward_models[k]["path_fwd"], verbose=0) for k in (0, 1)
)
fwd_coarse, fwd_fine = (
    mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True, verbose=0)
    for fwd in (fwd_coarse, fwd_fine)
)

# New array: the fine gain matrix with every column scaled to unit L2 norm
# (the forward solution itself is left untouched).
leadfield = fwd_fine["sol"]["data"] / np.linalg.norm(fwd_fine["sol"]["data"], axis=0)

with open(forward_models[0]["path_info"], "rb") as f:
    info = pkl.load(f)

# Rebinds the module-level subject / subjects_dir to the local MNE sample data.
subjects_dir = r"C:\Users\lukas\mne_data\MNE-sample-data\subjects"
subject = "sample"

# src_coarse = mne.setup_source_space(subject, spacing="ico4", surface='white',
#                                         subjects_dir=subjects_dir, add_dist=False,
#                                         n_jobs=-1, verbose=0)
# src_fine = mne.setup_source_space(subject, spacing="oct6", surface='white',
#                                         subjects_dir=subjects_dir, add_dist=False,
#                                         n_jobs=-1, verbose=0)

In [None]:
%matplotlib qt
idx = 5

stc = mne.SourceEstimate(y_test[idx].toarray().T, vertices=[fwd_coarse["src"][0]["vertno"], fwd_coarse["src"][1]["vertno"]], tmin=0, tstep=1, subject="sample")
stc.subject = "sample"
stc.plot(brain_kwargs=dict(title=f"Ground Truth sample {idx}"), **pp)
evoked = mne.EvokedArray(x_test[idx].T, info)
evoked.plot_joint(title="Ground Truth")

solver = "SSM"
J = stc_dict[solver][idx].toarray()
stc = mne.SourceEstimate(J, vertices=[fwd_fine["src"][0]["vertno"], fwd_fine["src"][1]["vertno"]], tmin=0, tstep=1, subject="sample")
stc.subject = "sample"
stc.plot(brain_kwargs=dict(title=f"{solver} sample {idx}"), **pp)
x_hat = leadfield @ J
evoked = mne.EvokedArray(x_hat, info)
evoked.plot_joint(title=solver)

solver = "AP"
J = stc_dict[solver][idx].toarray()
stc = mne.SourceEstimate(J, vertices=[fwd_fine["src"][0]["vertno"], fwd_fine["src"][1]["vertno"]], tmin=0, tstep=1, subject="sample")
stc.subject = "sample"
stc.plot(brain_kwargs=dict(title=f"{solver} sample {idx}"), **pp)
x_hat = leadfield @ J
evoked = mne.EvokedArray(x_hat, info)
evoked.plot_joint(title=solver)

In [None]:
import numpy as np
from scipy.optimize import linear_sum_assignment

def localization_error_metric(true_positions, estimated_positions):
    """
    Mean localization error between true and estimated dipoles after optimal matching.

    Each true dipole is paired with at most one estimated dipole (and vice versa)
    such that the summed Euclidean distance is minimal
    (``scipy.optimize.linear_sum_assignment``). When the counts differ, only
    ``min(n_true, n_estimated)`` pairs are formed: surplus estimates (false
    positives) and surplus true dipoles (misses) are *not* penalised.

    Parameters
    ----------
    true_positions : array-like, shape (n_true_dipoles, n_dims)
        The true positions of the dipoles (e.g. 3-D coordinates).
    estimated_positions : array-like, shape (n_estimated_dipoles, n_dims)
        The estimated positions of the dipoles.

    Returns
    -------
    mean_distance : float
        The mean distance over the matched pairs, in the units of the inputs;
        ``nan`` if either input contains no dipoles.
    """
    # Accept lists as well as arrays (the broadcasting below needs ndarrays).
    true_positions = np.asarray(true_positions, dtype=float)
    estimated_positions = np.asarray(estimated_positions, dtype=float)

    # Nothing can be matched: return nan explicitly instead of averaging an empty
    # selection (same value, but without numpy's "Mean of empty slice" warning).
    if true_positions.size == 0 or estimated_positions.size == 0:
        return np.nan

    # distance_matrix[i, j] = ||true_i - estimated_j||
    distance_matrix = np.linalg.norm(true_positions[:, np.newaxis] - estimated_positions, axis=2)

    # Optimal one-to-one assignment minimising the total distance.
    true_indices, estimated_indices = linear_sum_assignment(distance_matrix)

    # Average (not sum) of the matched distances.
    return distance_matrix[true_indices, estimated_indices].mean()
    
# Example usage: two true dipoles, three estimates. The far-away third estimate
# ([100, 100, 100]) has no true partner, so the optimal matching leaves it out and
# it does not change the error.
true_dipole_positions = np.array(
    [
        [-20, 31, 19],
        [24, 29, 21],
        # [28, -20, 21],
    ]
)

estimated_dipole_positions = np.array(
    [
        [-19, 25, 19],
        [24, 27, 22],
        # [30, -20, 22],
        [100, 100, 100],
    ]
)
error = localization_error_metric(true_dipole_positions, estimated_dipole_positions)
# The metric averages the matched distances, so label it as a mean.
print("Mean localization error:", error)


In [None]:
from scipy.spatial.distance import cdist
# Full pairwise Euclidean distance matrix (rows: true dipoles, columns: estimates).
cdist(XA=true_dipole_positions, XB=estimated_dipole_positions, metric="euclidean")