In [1]:
import sys; sys.path.insert(0, r'../../invert')
import mne
import pickle as pkl
from time import time
from scipy.spatial.distance import cdist
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from invert.evaluate import eval_mean_localization_error
from tqdm.notebook import tqdm
import os
from time import sleep
from config import *

# Make sure a local "results" folder exists (relative to the current working directory).
os.makedirs('results/', exist_ok=True)
# Shared settings passed to the seaborn bar plots in the figure cells.
estimator = "mean"
errorbar = ("ci", 95)
# Subject name used for the MNI lookup in get_pos() and attached to the forward models.
subject = "fsaverage"
# Math-text rho, used in figure titles and axis labels.
rho = r"$\rho$"


In [2]:
def get_pos(fwd):
    """Return the MNI coordinates of all source-space vertices of a forward model.

    The vertices of both entries of ``fwd["src"]`` (hemisphere index 0, then 1)
    are converted with ``mne.vertex_to_mni`` and stacked along the first axis.
    Uses the notebook-level ``subject`` and ``subjects_dir`` settings.

    Parameters
    ----------
    fwd : forward model with a two-element ``"src"`` entry.

    Returns
    -------
    Array with one row per vertex, hemisphere 0 rows first.
    """
    per_hemisphere = [
        mne.vertex_to_mni(
            fwd["src"][hemi]["vertno"],
            hemi,
            subject=subject,
            subjects_dir=subjects_dir,
            verbose=0,
        )
        for hemi in (0, 1)
    ]
    return np.concatenate(per_hemisphere, axis=0)

# Load Forward Models

In [17]:
base_path = r"D:\data\flex_ssm"

fwds = {}


def _load_fixed_forward(fwd_filename):
    """Read a forward model from <base_path>/forward_models and convert it to fixed orientation."""
    fwd = mne.read_forward_solution(
        os.path.join(base_path, "forward_models", fwd_filename), verbose=0
    )
    fwd = mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True, use_cps=True)
    fwd.subject = subject
    return fwd


# Coarse model: positions, pairwise distances within the grid, and source adjacency.
fwd_ico4 = _load_fixed_forward("fsaverage_biosemi256_Clean_coarse-fwd.fif")
pos_ico4 = get_pos(fwd_ico4)
distances_ico4 = cdist(pos_ico4, pos_ico4)
adjacency_ico4 = mne.spatial_src_adjacency(fwd_ico4["src"], verbose=0)

# Fine model: note that the distance matrix is cross-grid
# (rows = coarse positions, columns = fine positions).
fwd_oct6 = _load_fixed_forward("fsaverage_biosemi256_Clean_fine-fwd.fif")
pos_oct6 = get_pos(fwd_oct6)
distances_oct6 = cdist(pos_ico4, pos_oct6)
adjacency_oct6 = mne.spatial_src_adjacency(fwd_oct6["src"], verbose=0)

# Pickled object stored next to the forward models, kept as `info`.
fullpath = os.path.join(base_path, "forward_models", "info_biosemi256.pkl")
with open(fullpath, "rb") as f:
    info = pkl.load(f)

    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]
    No patch info available. The standard source space normals will be employed in the rotation to the local surface coordinates....
    Changing to fixed-orientation forward solution with surface-based source orientations...
    [done]


# Evaluate all files

In [28]:
# Evaluate every prediction file: for each solver and sample compute the matched
# mean localization error and store all rows as a pickled list of dicts.
base_path = r"D:\data\flex_ssm"
path_evaluation = os.path.join(base_path, "predictions/")
path_results = os.path.join(base_path, "results/")
# Create the output folder up front so that the pkl.dump at the end of a long
# evaluation run cannot fail because the directory is missing.
os.makedirs(path_results, exist_ok=True)
# The ground truth is always compared on the ico4 adjacency / positions.
adjacency_true = adjacency_ico4
pos_true = pos_ico4
# Only the "non-greedy" prediction files are evaluated in this pass.
eval_filenames = [f for f in os.listdir(path_evaluation) if "non-greedy" in f]
for filename in eval_filenames:
    fullpath = os.path.join(path_evaluation, filename)
    fn_results = os.path.join(path_results, filename.replace("sim_and_preds_", "results_")).replace(".pkl", "-2.pkl")

    print("FN: ", fn_results)
    if ".pkl" not in filename or os.path.isfile(fn_results):
        print("\tis processed or not a regular file")
        continue
    # NOTE: pickle can execute arbitrary code on load; only open trusted files here.
    with open(fullpath, "rb") as f:
        stc_dict, _, y_test, sim_info, proc_time_make, proc_time_apply = pkl.load(f)

    # Flatten a list of batches (list of lists) into one flat list of samples.
    # The emptiness check keeps y_test[0] from raising on an empty list.
    if isinstance(y_test, list) and y_test and isinstance(y_test[0], list):
        y_test = [sample for batch in y_test for sample in batch]

    n_samples = len(y_test)
    print("\t", fn_results)
    # Every file that is not explicitly "coarse" (or is explicitly "fine") is
    # evaluated against the oct6 grid.
    name_lower = filename.lower()
    if "fine" in name_lower or "coarse" not in name_lower:
        print("\t\tyep")
        adjacency_pred = adjacency_oct6
        distances = distances_oct6
        pos_pred = pos_oct6
    else:
        adjacency_pred = adjacency_ico4
        distances = distances_ico4
        pos_pred = pos_ico4

    # Matching options depend only on the file, so build them once per file.
    if "non-greedy" in filename:
        match_kwargs = dict(mode="match", threshold=0.75, max_maxima=2, max_iter=0)
    else:
        match_kwargs = dict(mode="match")

    results = []
    for solver_name in stc_dict.keys():
        print(solver_name)
        # Dedicated index name: the original reused `i` for both the file loop
        # and the sample loop.
        for i_sample in tqdm(range(n_samples)):
            y_pred = stc_dict[solver_name][i_sample].toarray()
            y_true = y_test[i_sample].toarray().T
            # Both maps are reduced to the mean absolute value over the last axis.
            mle_match = eval_mean_localization_error(
                abs(y_true).mean(axis=-1),
                abs(y_pred).mean(axis=-1),
                adjacency_true, adjacency_pred,
                pos_true, pos_pred, distances,
                **match_kwargs)

            result = dict(
                Method=solver_name,
                mle_match=mle_match,
                Time_Make=proc_time_make[solver_name][i_sample],
                Time_Apply=proc_time_apply[solver_name][i_sample],
            )
            # Attach the simulation settings of this sample to the result row.
            result.update(sim_info.iloc[i_sample, :].to_dict())
            results.append(result)
    # Free the large per-file objects before the next file is loaded.
    del stc_dict, y_test, proc_time_make, proc_time_apply
    with open(fn_results, 'wb') as f:
        pkl.dump(results, f)

FN:  D:\data\flex_ssm\results/results_figure-1_Clean-Coarse_non-greedy-2.pkl
	 D:\data\flex_ssm\results/results_figure-1_Clean-Coarse_non-greedy-2.pkl
MCMV


  0%|          | 0/2800 [00:00<?, ?it/s]

Champagne


  0%|          | 0/2800 [00:00<?, ?it/s]

FN:  D:\data\flex_ssm\results/results_figure-1_Clean-Fine_non-greedy-2.pkl
	 D:\data\flex_ssm\results/results_figure-1_Clean-Fine_non-greedy-2.pkl
		yep
MCMV


  0%|          | 0/2800 [00:00<?, ?it/s]

Champagne


  0%|          | 0/2800 [00:00<?, ?it/s]

FN:  D:\data\flex_ssm\results/results_figure-1_Ratio-30-Fine_non-greedy-2.pkl
	 D:\data\flex_ssm\results/results_figure-1_Ratio-30-Fine_non-greedy-2.pkl
		yep
MCMV


  0%|          | 0/2800 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Plot Figure 1

In [16]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
sns.set_theme(style="ticks", font_scale=1.2)

# Figure 1: bar plots of the localization error (column "mle_match") per
# inter-source correlation, one column of panels per forward-model condition.
# filenames = os.listdir("results")
filenames = {
    'No Forward Error': 'results_figure-1_Clean Coarse.pkl',
    'Source Modelling Error': 'results_figure-1_Clean Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:50)': 'results_figure-1_Ratio-50 Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:30)': 'results_figure-1_Ratio-30 Fine.pkl',
}

n_sources_range = (2, 3)
keep_isc = (0.1, 0.5, 0.9, 0.95)
ylim = (0, 30)
hue_order = ["SSM", "AP", "RAP"]
row_titles = ("A", "B")
fig, axs = plt.subplots(len(n_sources_range), len(filenames), figsize=(14, 8), sharey=True, sharex=True)
# Legend entries are collected from the panels drawn in THIS cell, so a stale
# `ax` left in the kernel by another cell can never leak into the legend.
legend_handles, legend_labels = [], []

for i_row, n_sources in enumerate(n_sources_range):
    # NOTE(review): n_sources only ends up in the y-label below; the data is never
    # filtered by it, so both rows appear to plot the same values. Confirm whether
    # a filter on the number of sources is missing here.
    for i_col, (title, filename) in enumerate(filenames.items()):
        filename = os.path.join(base_path, "results", filename)

        if not os.path.isfile(filename):
            print(f"{filename} does not exist")
            continue
        with open(filename, 'rb') as f:
            results = pkl.load(f)

        df = pd.DataFrame(results)
        is_ssm = df["Method"].str.contains("SSM")
        df_not_ssm = df[~is_ssm]
        df_ssm = df[is_ssm]
        # Pick the SSM variant with the lowest mean mle_match ...
        df_temp = df_ssm.groupby("Method").mean(numeric_only=True).sort_values("mle_match").reset_index()
        best_ssm = df_temp.iloc[0]["Method"]
        # ... and relabel it as plain "SSM". assign() returns a new frame, which
        # avoids writing into a filtered slice (SettingWithCopyWarning).
        df_ssm = df_ssm[df_ssm["Method"] == best_ssm].assign(Method="SSM")
        print(f"Best SSM: {best_ssm}")

        # SSM first, then all other methods.
        df = pd.concat([df_ssm, df_not_ssm])

        # Keep only the selected inter-source correlations.
        df = df[df["inter_source_correlations"].isin(keep_isc)]

        ax = axs[i_row, i_col]
        sns.barplot(
            hue="Method",
            x="inter_source_correlations",
            y="mle_match",
            errorbar=errorbar,
            estimator=estimator,
            # hue_order=hue_order,
            ax=ax,
            data=df)
        ax.set_ylim(*ylim)

        if i_row == 0:
            ax.set_title(title)
            ax.set_xlabel("")
        else:
            ax.set_title("")
            ax.set_xlabel("Inter-source correlation")
        if i_col == 0:
            row_title = row_titles[i_row]
            ax.text(-0.3, 1.1, row_title, transform=ax.transAxes, fontsize=24, fontweight='bold', va='top', ha='right')
            ax.set_ylabel(f"{n_sources} sources\n\nMLE (mm)")
        else:
            ax.set_ylabel("")

        # Remember the legend entries of the most recently drawn panel, then drop
        # the per-panel legend (it may be missing if nothing was drawn).
        legend_handles, legend_labels = ax.get_legend_handles_labels()
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

# One shared, horizontal legend below the panels.
fig.subplots_adjust(bottom=0.2)
if legend_handles:
    fig.legend(legend_handles[:3], legend_labels[:3], loc='upper center', bbox_to_anchor=(0.5, 0.12), ncol=3)

# save figure
# for dpi in (150, 300):
#     fig.savefig(f'figures/biosemi256/Figure_1_{estimator}s_{dpi}.png', format='png', dpi=dpi)

Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2


<matplotlib.legend.Legend at 0x235a8728b80>

## ISBI

In [61]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
sns.set_theme(style="ticks", font_scale=1.2)

# Transposed variant of Figure 1: one row per forward-model condition, one
# column per entry of n_sources_range, with a fixed SSM variant.
filenames = {
    'No Forward Error': 'results_figure-1_Clean Coarse.pkl',
    'Source Modelling Error': 'results_figure-1_Clean Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:50)': 'results_figure-1_Ratio-50 Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:30)': 'results_figure-1_Ratio-30 Fine.pkl',
}
best_ssm = "SSM_-3"
n_sources_range = (2, 3)
keep_isc = (0.1, 0.5, 0.9, 0.95)
ylim = (0, 21)
hue_order = ["SSM", "AP", "RAP", "MCMV", "Champagne"]
col_titles = ("A", "B")
fig, axs = plt.subplots(len(filenames), len(n_sources_range), figsize=(9, 14), sharey=False, sharex=False)
# Legend entries are taken from a panel that was actually drawn in this cell.
legend_handles, legend_labels = [], []

for i_col, n_sources in enumerate(n_sources_range):
    # NOTE(review): n_sources is only used for the column title; the data is not
    # filtered by it, so both columns appear to show the same values — confirm.
    for i_row, (title, filename) in enumerate(filenames.items()):
        filename = os.path.join(base_path, "results", filename)
        print(filename)

        if not os.path.isfile(filename):
            print(f"{filename} does not exist")
            continue
        with open(filename, 'rb') as f:
            results = pkl.load(f)

        # filename_non_greedy = filename.split(".")[0].replace(" ", "-") + "_non-greedy.pkl"
        # with open(filename_non_greedy, 'rb') as f:
        #     results_non_greedy = pkl.load(f)
        # results.extend(results_non_greedy)

        df = pd.DataFrame(results)
        is_ssm = df["Method"].str.contains("SSM")
        df_not_ssm = df[~is_ssm]
        # Keep only the fixed SSM variant and relabel it as "SSM"; assign()
        # returns a new frame instead of writing into a filtered slice.
        df_ssm = df[is_ssm & (df["Method"] == best_ssm)].assign(Method="SSM")
        print(f"Best SSM: {best_ssm}")

        df = pd.concat([df_ssm, df_not_ssm])
        df = df[df["inter_source_correlations"].isin(keep_isc)]

        ax = axs[i_row, i_col]
        sns.barplot(
            hue="Method",
            x="inter_source_correlations",
            y="mle_match",
            errorbar=errorbar,
            estimator=estimator,
            ax=ax,
            data=df)
        ax.set_ylim(*ylim)
        ax.set_xlabel("")

        if i_row == 0:
            # Top row carries the column title and the panel letter.
            ax.set_title(f"{n_sources} sources")
            col_title = col_titles[i_col]
            ax.text(0.5, 1.4, col_title, transform=ax.transAxes, fontsize=24, fontweight='bold', va='top', ha='right')
        if i_col == 0:
            ax.set_ylabel(f"{title}\n\nMLE (mm)", fontsize=12)
        else:
            ax.set_ylabel("")

        legend_handles, legend_labels = ax.get_legend_handles_labels()
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

# Set the x-label for the bottom row only.
for bottom_ax in axs[-1, :]:
    bottom_ax.set_xlabel("Inter-source correlation")

# Increase left margin and adjust other parameters as needed.
fig.subplots_adjust(left=0.15, bottom=0.125, right=0.95, top=0.9, hspace=0.3)
# Horizontal legend below the figure. The labels come from the plotted artists
# themselves: the original passed hue_order as labels, which mislabels the bars
# whenever the plotted methods differ from that list in order or number.
if legend_handles:
    fig.legend(legend_handles, legend_labels, loc='upper center', bbox_to_anchor=(0.55, 0.04), ncol=len(legend_labels))
plt.tight_layout(pad=2)
# save figure (create the target folder first so savefig cannot fail on it)
fig_dir = os.path.join(base_path, "figures", "biosemi256")
os.makedirs(fig_dir, exist_ok=True)
for dpi in (150, 300):
    fig.savefig(os.path.join(fig_dir, f"Figure_1_{estimator}s_transposed_{dpi}.png"), format='png', dpi=dpi)

D:\data\flex_ssm\results\results_figure-1_Clean Coarse.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Clean Fine.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Ratio-50 Fine.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Ratio-30 Fine.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Clean Coarse.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Clean Fine.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Ratio-50 Fine.pkl
Best SSM: SSM_-3
D:\data\flex_ssm\results\results_figure-1_Ratio-30 Fine.pkl
Best SSM: SSM_-3


# Plot Figure 2 - Noise Color Modes

In [28]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

sns.set_theme(style="ticks", font_scale=1.1)

# Figure 2: one figure per (inter-source correlation, SNR) pair; rows are noise
# color modes, columns are forward-model conditions.
filenames = {
    'No Forward Error': os.path.join(base_path, "results", 'results_figure-2_Clean Coarse.pkl'),
    'Source Modelling Error': os.path.join(base_path, "results", 'results_figure-2_Clean Fine.pkl'),
    'Source Modelling\n + Conductivity Error (1:50)': os.path.join(base_path, "results", 'results_figure-2_Ratio-50 Fine.pkl'),
    'Source Modelling\n + Conductivity Error (1:30)': os.path.join(base_path, "results", 'results_figure-2_Ratio-30 Fine.pkl'),
}

noise_color_modes = ('diagonal', 'banded', 'cholesky')
noise_color_modes_texts = ("Diagonal", "Banded", "Cholesky")
row_labels = ("A", "B", "C")
xlabels = ("SD of noise power", "Noise color coefficient", "Noise color coefficient")

# Read every result file once; the original re-read each pickle for every
# panel of every figure.
results_by_title = {}
for title, filename in filenames.items():
    if not os.path.isfile(filename):
        print(f"{filename} does not exist")
        continue
    with open(filename, 'rb') as f:
        results_by_title[title] = pd.DataFrame(pkl.load(f))

for isc in (0.5, 0.9):
    for i_fig, snr in enumerate((-5, 0, 5)):
        fig, axs = plt.subplots(len(noise_color_modes), len(filenames), figsize=(14, 12), sharey=False, sharex=False)

        # Share the y-axis within each row. Axes.sharey replaces the deprecated
        # get_shared_y_axes().join(...) call that triggered the warning in this
        # cell's output.
        for row in range(len(noise_color_modes)):
            for col in range(1, len(filenames)):
                axs[row, col].sharey(axs[row, 0])

        fig.suptitle(f"SNR: {snr} dB, {rho}={isc}", fontsize=16)
        # Legend entries of the most recently drawn panel of this figure.
        legend_handles, legend_labels = [], []
        for i_row, (noise_color_mode, noise_color_modes_text) in enumerate(zip(noise_color_modes, noise_color_modes_texts)):
            for i_col, title in enumerate(filenames):
                if title not in results_by_title:
                    continue
                df_all = results_by_title[title]

                # Select correlation mode, inter-source correlation and SNR
                # without modifying the cached frame.
                df = df_all[df_all["correlation_mode"].astype(str) == noise_color_mode]
                df = df[df["inter_source_correlations"] == isc]
                df = df[df["snr"] == snr]
                df = df.assign(correlation_mode=noise_color_modes_text)

                # Select regularization for SSM: pick the SSM variant with the
                # lowest mean mle_match and relabel it as plain "SSM".
                is_ssm = df["Method"].str.contains("SSM")
                df_not_ssm = df[~is_ssm]
                df_ssm = df[is_ssm]
                df_temp = df_ssm.groupby("Method").mean(numeric_only=True).sort_values("mle_match").reset_index()
                best_ssm = df_temp.iloc[0]["Method"]
                # assign() avoids writing into a filtered slice.
                df_ssm = df_ssm[df_ssm["Method"] == best_ssm].assign(Method="SSM")
                print(f"Best SSM: {best_ssm}")

                # SSM first, then all other methods.
                df = pd.concat([df_ssm, df_not_ssm])

                ax = axs[i_row, i_col]
                sns.barplot(
                    hue="Method",
                    x="noise_color_coeff",
                    y="mle_match",
                    errorbar=errorbar,
                    estimator=estimator,
                    ax=ax,
                    data=df)
                ax.set_xlabel(xlabels[i_row])
                # Only the top row carries the condition title.
                if i_row == 0:
                    ax.set_title(title)

                if i_col == 0:
                    ax.set_ylabel(f"{noise_color_modes_text}\n\nMLE (mm)")
                    row_label = row_labels[i_row]
                    ax.text(-0.17, 1.1, row_label, transform=ax.transAxes, fontsize=24, va='top', ha='right')
                else:
                    ax.set_ylabel("")

                # Remember the legend entries, then remove the per-panel legend.
                legend_handles, legend_labels = ax.get_legend_handles_labels()
                panel_legend = ax.get_legend()
                if panel_legend is not None:
                    panel_legend.remove()

        # Make room below the panels and add one horizontal legend there.
        fig.subplots_adjust(bottom=0.2, hspace=0.35)
        if legend_handles:
            fig.legend(legend_handles[:3], legend_labels[:3], loc='upper center', bbox_to_anchor=(0.5, 0.15), ncol=3)

        # save figure
        # for dpi in (150, 300):
        #     fig.savefig(f'figures/biosemi256/Figure_2_{snr}dB_rho{isc}_{estimator}s_{dpi}.png', format='png', dpi=dpi)

  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-3
Best SSM: SSM_-1
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-1
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-2


  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


  axs[row, col].get_shared_y_axes().join(axs[row, col], axs[row, 0])


Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


## ISBI

In [63]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re

sns.set_theme(style="ticks", font_scale=1.1)
# Transposed variant of Figure 2 with a fixed SSM variant: rows are
# forward-model conditions, columns are noise color modes.
best_ssm = "SSM_-3"
filenames = {
    'No Forward Error': os.path.join(base_path, "results", 'results_figure-2_Clean Coarse.pkl'),
    'Source Modelling Error': os.path.join(base_path, "results", 'results_figure-2_Clean Fine.pkl'),
    'Source Modelling\n + Conductivity Error (1:50)': os.path.join(base_path, "results", 'results_figure-2_Ratio-50 Fine.pkl'),
    'Source Modelling\n + Conductivity Error (1:30)': os.path.join(base_path, "results", 'results_figure-2_Ratio-30 Fine.pkl'),
}

noise_color_modes = ('diagonal', 'banded', 'cholesky')
noise_color_modes_texts = ("Diagonal", "Banded", "Cholesky")
col_labels = ("A", "B", "C")
xlabels = ("SD of noise power", "Noise color coefficient", "Noise color coefficient")
ylims = ((0, 18), (0, 18), (0, 60))
# Index of the bottom row, derived from the number of conditions so the
# x-labels stay on the last row if `filenames` changes.
last_row = len(filenames) - 1
for isc in (0.9,):  # (0.5, 0.9):
    for i_fig, snr in enumerate((0, )):  # enumerate((-5, 0, 5)):
        fig, axs = plt.subplots(len(filenames), len(noise_color_modes), figsize=(8, 12), sharey=False, sharex=False)

        # fig.suptitle(f"SNR: {snr} dB, {rho}={isc}", fontsize=16)
        # Legend entries of the most recently drawn panel of this figure.
        legend_handles, legend_labels = [], []
        for i_col, (noise_color_mode, noise_color_modes_text) in enumerate(zip(noise_color_modes, noise_color_modes_texts)):
            for i_row, (title, filename) in enumerate(filenames.items()):
                if not os.path.isfile(filename):
                    print(f"{filename} does not exist")
                    continue
                with open(filename, 'rb') as f:
                    results = pkl.load(f)

                # filename_non_greedy = filename.split(".")[0].replace(" ", "-") + "_non-greedy.pkl"
                # with open(filename_non_greedy, 'rb') as f:
                #     results_non_greedy = pkl.load(f)
                # results.extend(results_non_greedy)

                df = pd.DataFrame(results)
                # Select correlation mode, inter-source correlation and SNR.
                df = df[df["correlation_mode"].astype(str) == noise_color_mode]
                df = df[df["inter_source_correlations"] == isc]
                df = df[df["snr"] == snr]
                # assign() returns a new frame instead of writing into a slice.
                df = df.assign(correlation_mode=noise_color_modes_text)

                # Keep only the fixed SSM variant and relabel it as "SSM".
                is_ssm = df["Method"].str.contains("SSM")
                df_not_ssm = df[~is_ssm]
                df_ssm = df[is_ssm & (df["Method"] == best_ssm)].assign(Method="SSM")
                print(f"Best SSM: {best_ssm}")

                # SSM first, then all other methods.
                df = pd.concat([df_ssm, df_not_ssm])

                ax = axs[i_row, i_col]
                sns.barplot(
                    hue="Method",
                    x="noise_color_coeff",
                    y="mle_match",
                    errorbar=errorbar,
                    estimator=estimator,
                    ax=ax,
                    data=df)

                # x-label only on the bottom row.
                ax.set_xlabel(xlabels[i_col] if i_row == last_row else "")
                if i_row == 0:
                    ax.set_title(f"{noise_color_modes_text}")
                    ax.text(0.5, 1.4, col_labels[i_col], transform=ax.transAxes, fontsize=18, fontweight='bold', va='top', ha='right')
                else:
                    ax.set_title("")

                if i_col == 0:
                    ax.set_ylabel(f"{title}\n\nMLE (mm)", fontsize=12)
                else:
                    ax.set_ylabel("")
                ax.set_ylim(ylims[i_col])

                # Remember the legend entries, then remove the per-panel legend.
                legend_handles, legend_labels = ax.get_legend_handles_labels()
                panel_legend = ax.get_legend()
                if panel_legend is not None:
                    panel_legend.remove()

        # Adjust spacing and add one horizontal legend below the panels.
        fig.subplots_adjust(bottom=0.1, hspace=0.25, wspace=0.25, left=0.15)
        if legend_handles:
            fig.legend(legend_handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 0.05), ncol=3)

        # save figure (create the target folder first so savefig cannot fail on it)
        fig_dir = os.path.join(base_path, "figures", "biosemi256")
        os.makedirs(fig_dir, exist_ok=True)
        for dpi in (150, 300):
            fig.savefig(os.path.join(fig_dir, f"Figure_2_{snr}dB_rho{isc}_{estimator}s_{dpi}_tranposed.png"), format='png', dpi=dpi)

Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3


In [11]:
# Scratch check of a Figure 2 output path. NOTE(review): relies on `snr`, `isc` and `dpi`
# still being defined by a previously executed cell, so it fails on a fresh kernel.
os.path.join(base_path, "figures", "biosemi256", f"Figure_2_{snr}dB_rho{isc}_{estimator}s_{dpi}.png")

'D:\\data\\flex_ssm\\figures\\biosemi256\\Figure_2_0dB_rho0.9_means_300.png'

# Figure 3 - Time Points

In [35]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
sns.set_theme(style="ticks", font_scale=1.2)

# Figure 3: localization error (column "mle_match") per number of time samples;
# rows are inter-source correlations, columns are forward-model conditions.
# filenames = os.listdir("results")
filenames = {
    'No Forward Error': 'results_figure-3_Clean Coarse.pkl',
    'Source Modelling Error': 'results_figure-3_Clean Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:50)': 'results_figure-3_Ratio-50 Fine.pkl',
    'Source Modelling\n + Conductivity Error (1:30)': 'results_figure-3_Ratio-30 Fine.pkl',
}

isc_range = (0.5, 0.9, 0.95)
row_labels = ("A", "B", "C")
fig, axs = plt.subplots(len(isc_range), len(filenames), figsize=(14, 8), sharey=True, sharex=True)
# Legend entries are collected from the panels drawn in THIS cell, so a stale
# `ax` left in the kernel by another cell can never leak into the legend.
legend_handles, legend_labels = [], []

for i_row, isc in enumerate(isc_range):
    for i_col, (title, filename) in enumerate(filenames.items()):
        # NOTE(review): this reads from a "results" folder relative to the working
        # directory, while other figure cells read from base_path — confirm which
        # location is intended.
        filename = os.path.join("results", filename)
        if not os.path.isfile(filename):
            print(f"{filename} does not exist")
            continue
        with open(filename, 'rb') as f:
            results = pkl.load(f)

        df = pd.DataFrame(results)
        df = df[df["inter_source_correlations"] == isc]

        is_ssm = df["Method"].str.contains("SSM")
        df_not_ssm = df[~is_ssm]
        df_ssm = df[is_ssm]
        # Pick the SSM variant with the lowest mean mle_match ...
        df_temp = df_ssm.groupby("Method").mean(numeric_only=True).sort_values("mle_match").reset_index()
        best_ssm = df_temp.iloc[0]["Method"]
        # ... and relabel it as plain "SSM". assign() returns a new frame, which
        # avoids writing into a filtered slice (SettingWithCopyWarning).
        df_ssm = df_ssm[df_ssm["Method"] == best_ssm].assign(Method="SSM")
        print(f"Best SSM: {best_ssm}")

        # SSM first, then all other methods.
        df = pd.concat([df_ssm, df_not_ssm])

        ax = axs[i_row, i_col]
        sns.barplot(
            hue="Method",
            x="n_timepoints",
            y="mle_match",
            errorbar=errorbar,
            estimator=estimator,
            ax=ax,
            data=df)
        # ax.set_ylim(*ylim)

        if i_row == 0:
            ax.set_title(title)
            ax.set_xlabel("")
        else:
            ax.set_title("")
            ax.set_xlabel("No. of Time Samples")
        if i_col == 0:
            row_label = row_labels[i_row]
            ax.text(-0.25, 1.1, row_label, transform=ax.transAxes, fontsize=24, fontweight='bold', va='top', ha='right')
            ax.set_ylabel(f"{rho}={isc}\n\nMLE (mm)")
        else:
            ax.set_ylabel("")

        # Remember the legend entries, then remove the per-panel legend
        # (it may be missing if nothing was drawn).
        legend_handles, legend_labels = ax.get_legend_handles_labels()
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

# One shared, horizontal legend below the panels.
fig.subplots_adjust(bottom=0.2)
if legend_handles:
    fig.legend(legend_handles[:3], legend_labels[:3], loc='upper center', bbox_to_anchor=(0.5, 0.12), ncol=3)

# save figure (create the target folder first so savefig cannot fail on it)
fig_dir = 'figures/biosemi256'
os.makedirs(fig_dir, exist_ok=True)
for dpi in (150, 300):
    fig.savefig(f'{fig_dir}/Figure_3_{estimator}s_{dpi}.png', format='png', dpi=dpi)

Best SSM: SSM_-3
Best SSM: SSM_-1
Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2


# Figure 4 - Misestimation

In [36]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
sns.set_theme(style="ticks", font_scale=1.2)

# Figure 4: localization error (column "mle_match") per inter-source
# correlation; rows are the assumed number of sources, columns are
# forward-model conditions. "XXX" in the file names is replaced by the assumed
# number of sources.
# filenames = os.listdir("results")
filenames = {
    'No Forward Error': 'results_figure-1_Clean Coarse-XXX-assumed.pkl',
    'Source Modelling Error': 'results_figure-1_Clean Fine-XXX-assumed.pkl',
    'Source Modelling\n + Conductivity Error (1:50)': 'results_figure-1_Ratio-50 Fine-XXX-assumed.pkl',
    'Source Modelling\n + Conductivity Error (1:30)': 'results_figure-1_Ratio-30 Fine-XXX-assumed.pkl',
}

# NOTE(review): keep_isc is defined but not applied in this cell (the Figure 1
# cell filters by it) — confirm whether the filter was left out on purpose.
keep_isc = (0.1, 0.5, 0.9, 0.95)
n_sources_assumed_range = (1,2,3,4,5)
row_titles = ("A", "B", "C", "D", "E")
fig, axs = plt.subplots(len(n_sources_assumed_range), len(filenames), figsize=(13, 10), sharey=True, sharex=True)
n_hat = r"$\hat{n}$"
# Legend entries are collected from the panels drawn in THIS cell, so a stale
# `ax` left in the kernel by another cell can never leak into the legend.
legend_handles, legend_labels = [], []
for i_row, n_sources in enumerate(n_sources_assumed_range):
    row_title = row_titles[i_row]
    for i_col, (title, filename) in enumerate(filenames.items()):
        # NOTE(review): read from a "results" folder relative to the working
        # directory, while other figure cells read from base_path — confirm.
        filename = os.path.join("results", filename)
        filename = filename.replace("XXX", str(n_sources))
        if not os.path.isfile(filename):
            print(f"{filename} does not exist")
            continue
        with open(filename, 'rb') as f:
            results = pkl.load(f)

        df = pd.DataFrame(results)

        is_ssm = df["Method"].str.contains("SSM")
        df_not_ssm = df[~is_ssm]
        df_ssm = df[is_ssm]
        # Pick the SSM variant with the lowest mean mle_match ...
        df_temp = df_ssm.groupby("Method").mean(numeric_only=True).sort_values("mle_match").reset_index()
        best_ssm = df_temp.iloc[0]["Method"]
        # ... and relabel it as plain "SSM". assign() returns a new frame, which
        # avoids writing into a filtered slice (SettingWithCopyWarning).
        df_ssm = df_ssm[df_ssm["Method"] == best_ssm].assign(Method="SSM")
        print(f"Best SSM: {best_ssm}")

        # SSM first, then all other methods.
        df = pd.concat([df_ssm, df_not_ssm])

        ax = axs[i_row, i_col]
        sns.barplot(
            hue="Method",
            x="inter_source_correlations",
            y="mle_match",
            errorbar=errorbar,
            estimator=estimator,
            ax=ax,
            data=df)
        # ax.set_ylim(*ylim)

        if i_row == 0:
            ax.set_title(title)
            ax.set_xlabel("")
        else:
            ax.set_title("")
            ax.set_xlabel(f"{rho}")
        if i_col == 0:
            ax.text(-0.3, 1.1, row_title, transform=ax.transAxes, fontsize=18, fontweight='bold', va='top', ha='right')
            # The third row (index 2) is labelled as the correct assumption.
            if i_row == 2:
                ax.set_ylabel(f"Sources Assumed\n{n_hat}={n_sources} (correct)\n\nMLE (mm)")
            else:
                ax.set_ylabel(f"{n_hat}={n_sources}\n\nMLE (mm)")
        else:
            ax.set_ylabel("")

        # Remember the legend entries, then remove the per-panel legend
        # (it may be missing if nothing was drawn).
        legend_handles, legend_labels = ax.get_legend_handles_labels()
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

# One shared, horizontal legend below the panels.
fig.subplots_adjust(bottom=0.2)
if legend_handles:
    fig.legend(legend_handles[:3], legend_labels[:3], loc='upper center', bbox_to_anchor=(0.5, 0.12), ncol=3)

# save figure (create the target folder first so savefig cannot fail on it)
fig_dir = 'figures/biosemi256'
os.makedirs(fig_dir, exist_ok=True)
for dpi in (150, 300):
    fig.savefig(f'{fig_dir}/Figure_1-misestimation_{estimator}s_{dpi}.png', format='png', dpi=dpi)

Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-2
Best SSM: SSM_-1
Best SSM: SSM_-1
Best SSM: SSM_-1


# Figure App.1 - Misestimation with Cholesky Noise

In [52]:
%matplotlib qt
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import seaborn as sns
import os
import re
# Seaborn theme applied to everything drawn from here on.
sns.set_theme(font_scale=1.2, style="ticks")

# LaTeX string for "n hat", used in the y-axis labels of the first column.
n_hat = r"$\hat{n}$"

# One grid row per assumed number of sources, each tagged with a panel letter.
n_sources_assumed_range = tuple(range(1, 6))
row_titles = tuple("ABCDE")

# NOTE(review): keep_isc is not referenced anywhere else in this cell -- confirm it is still needed.
keep_isc = (0.1, 0.5, 0.9, 0.95)

# Column title -> result-file name. The "XXX" placeholder is replaced by the assumed
# number of sources when the file is loaded in the plotting loop.
filenames = dict([
    ('No Forward Error', 'results_figure-12_Clean Coarse-XXX-assumed.pkl'),
    ('Source Modelling Error', 'results_figure-12_Clean Fine-XXX-assumed.pkl'),
    ('Source Modelling\n + Conductivity Error (1:50)', 'results_figure-12_Ratio-50 Fine-XXX-assumed.pkl'),
    ('Source Modelling\n + Conductivity Error (1:30)', 'results_figure-12_Ratio-30 Fine-XXX-assumed.pkl'),
])

# Grid of panels (rows: assumed source counts, columns: result files) with shared axes.
fig, axs = plt.subplots(
    nrows=len(n_sources_assumed_range),
    ncols=len(filenames),
    figsize=(13, 10),
    sharex=True,
    sharey=True,
)
# Draw one bar-plot panel per (assumed source count, result file) pair.
for i_row, n_sources in enumerate(n_sources_assumed_range):
    row_title = row_titles[i_row]
    for i_col, (title, filename) in enumerate(filenames.items()):
        # Build the path of this panel's result file; "XXX" stands for the assumed source count.
        filename = os.path.join("results", filename).replace("XXX", str(n_sources))
        if not os.path.isfile(filename):
            # A missing result file leaves this panel empty instead of aborting the figure.
            print(f"{filename} does not exist")
            continue
        # NOTE(security): pickle.load can execute arbitrary code -- only load result
        # files that you generated yourself.
        with open(filename, 'rb') as f:
            results = pkl.load(f)

        df = pd.DataFrame(results)

        # Split rows into SSM variants and all other methods (mask computed once).
        is_ssm = df["Method"].str.contains("SSM")
        df_not_ssm = df[~is_ssm]
        df_ssm = df[is_ssm]
        # Pick the sub type of SSM with the lowest mean mle_match
        best_ssm = df_ssm.groupby("Method")["mle_match"].mean().idxmin()
        # Keep only that variant and relabel it "SSM". `.assign` returns a new frame, so no
        # value is written into a filtered slice of `df` (which pandas flags with
        # SettingWithCopyWarning).
        df_ssm = df_ssm[df_ssm["Method"] == best_ssm].assign(Method="SSM")
        print(f"Best SSM: {best_ssm}")

        # Concatenate the two dataframes (SSM rows first, then the other methods)
        df = pd.concat([df_ssm, df_not_ssm])

        ax = axs[i_row, i_col]
        sns.barplot(
            hue="Method",
            x="inter_source_correlations",
            y="mle_match",
            errorbar=errorbar,
            estimator=estimator,
            ax=ax,
            data=df)

        # Column titles only on the top row; every other row gets the rho x-label instead.
        if i_row == 0:
            ax.set_title(title)
            ax.set_xlabel("")
        else:
            ax.set_title("")
            ax.set_xlabel(f"{rho}")
        # Row letter and y-label only on the first column.
        if i_col == 0:
            ax.text(-0.3, 1.1, row_title, transform=ax.transAxes, fontsize=18, fontweight='bold', va='top', ha='right')
            # NOTE(review): row index 2 is labelled "(correct)"; presumably the simulations
            # contain 3 true sources -- confirm against the simulation settings.
            if i_row == 2:
                ax.set_ylabel(f"Sources Assumed\n{n_hat}={n_sources} (correct)\n\nMLE (mm)")
            else:
                ax.set_ylabel(f"{n_hat}={n_sources}\n\nMLE (mm)")
        else:
            ax.set_ylabel("")
        # Remove the per-panel legend (if one was drawn); a single shared legend is added
        # to the figure after the loop.
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

# legend is horizontal bar below the figure:
# `ax` is the axes most recently assigned inside the loop above; every panel's own legend
# was removed there, so its handles/labels are reused for one shared figure-level legend.
handles, labels = ax.get_legend_handles_labels()
# Free up space at the bottom of the figure for the shared legend.
fig.subplots_adjust(bottom=0.2)
# Only the first three legend entries are displayed.
fig.legend(handles[:3], labels[:3], loc='upper center', bbox_to_anchor=(0.5, 0.12), ncol=3)

# save figure
# savefig does not create missing folders, so make sure the target directory exists first.
os.makedirs('figures/biosemi256', exist_ok=True)
for dpi in (150, 300):
    fig.savefig(f'figures/biosemi256/Figure_12-misestimation_{estimator}s_{dpi}.png', format='png', dpi=dpi)

Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-3
Best SSM: SSM_-2
Best SSM: SSM_-3
Best SSM: SSM_-3
