In [None]:
import os
import sys

# Put the notebook's working directory first on the module search path so
# project-local helper modules next to the notebook are importable.
current_directory = os.getcwd()
sys.path.insert(0, current_directory)

In [None]:
# Libraries used by the cells below.
import mne  # NOTE(review): not referenced in the cells shown here — confirm it is still needed
import pickle as pkl
import os  # already imported in the first cell; re-importing is harmless
import numpy as np
import matplotlib.pyplot as plt
import time
# Switches TensorFlow tensors to NumPy-style behavior (attribute/operator semantics).
# NOTE(review): no TensorFlow tensors are created in the cells shown here — confirm
# whether this is still required before removing it.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

#### =========================================================================
#### Section 1: Data Loading Setup
#### This section defines a helper function to load data from pickle files;
#### the following cell sets up the directory paths for loading the pre-computed
#### EMD (Earth Mover's Distance) results that are ready for plotting.
#### =========================================================================

def load_data(file_path):
    """Unpickle and return the object stored at ``file_path``.

    Parameters
    ----------
    file_path : str or path-like
        Location of a pickle file.

    Returns
    -------
    object
        Whatever object was pickled into the file.

    Notes
    -----
    ``pickle`` can execute arbitrary code while loading, so only point this
    at trusted result files (e.g. ones produced by this project's pipeline).
    """
    with open(file_path, mode="rb") as handle:
        return pkl.load(handle)

In [None]:
# Sub-directory names, relative to the current working directory:
#   folder_load -> pre-computed EMD pickles read by the plotting loop
#   folder_name -> evaluated inverse solutions (not read in this notebook; kept for context)
folder_load, folder_name = "Results", "Evaluated_Data"

# Resolve both to absolute paths in one pass.
folder_pathload, folder_path = (
    os.path.join(os.getcwd(), sub_dir) for sub_dir in (folder_load, folder_name)
)

#### =========================================================================
#### Section 2: Main Plotting Loop
#### This is the primary execution block. It systematically iterates through
#### simulation parameters (SNR, source correlation, patch ranks) and,
#### crucially, different inverse models ('coarse-80', 'fine-80', 'fine-50').
#### This allows for the creation of plots that show how different types of
#### modeling errors affect the performance of various source localization methods.
#### =========================================================================

In [None]:
batch_size = 10  # number of monte-carlo repetitions (not read by the loop below)
Patchranks_Full = [[1, 2]]
Smoothness_order = 2

# Modeling-error conditions (inverse models) compared side by side in every figure.
inverse_model_names = ['coarse-80', 'fine-80', 'fine-50']

# Fixed left-to-right order of the methods inside each group of bars, so every
# figure uses the same layout. Methods not listed here are not plotted.
desired_order = ['dSPM', 'MNE', 'Convexity-Champagne', 'sLORETA', 'RAP-MUSIC', 'FLEX-MUSIC', 'PATCH RAP', 'AP', 'FLEX-AP',
                 'PATCH AP']

# Some result files label methods differently from `desired_order`; map those
# raw names onto the canonical labels before reordering.
method_label_map = {"Patch AP": "PATCH AP", "Patch RAP": "PATCH RAP"}

bar_width = 0.35
min_group_spacing = 5  # x-distance between the starts of consecutive groups of bars

# The output directory is the same for every figure, so create it once up front.
folder_save = os.path.join(os.getcwd(), "Manuscript_Figures/Effect_of_MError")
os.makedirs(folder_save, exist_ok=True)


def _summarize_emd(emd_results):
    """Collapse one condition's EMD results into per-method summary statistics.

    Parameters
    ----------
    emd_results : dict
        Maps a method name to the collection of EMD values for that method
        (presumably one value per Monte-Carlo repetition; confirm against the
        script that writes the pickles).

    Returns
    -------
    dict
        Maps the canonical method label (after `method_label_map`) to a
        ``(median, stderr)`` tuple. ``stderr`` is ``np.std(values) / sqrt(n)``
        using NumPy's default ``ddof=0``. If two raw names map to the same
        label, the first one encountered is kept.
    """
    summary = {}
    for raw_name, values in emd_results.items():
        label = method_label_map.get(raw_name, raw_name)
        if label not in summary:
            summary[label] = (np.median(values), np.std(values) / np.sqrt(len(values)))
    return summary


for snr_db in range(-5, 5, 5):  # -5 dB and 0 dB
    for corr_coeff in [0.3, 0.5, 0.9]:
        for Patchranks in Patchranks_Full:
            # Start the timer once per figure so the reported time covers loading
            # every condition plus plotting (not just the last condition).
            start_time = time.time()

            # ---- Load and summarize the EMD results of every modeling-error condition ----
            inv_NAME = []
            summaries = []
            for inv_name in inverse_model_names:
                filename = f"EMD_Len_{inv_name}_Data_corr_{corr_coeff}_smooth_{Smoothness_order}_patchranks_{Patchranks}_snr_{snr_db}.pkl"
                file_path = os.path.join(folder_pathload, filename)
                EMD_results = load_data(file_path)
                summaries.append(_summarize_emd(EMD_results))
                inv_NAME.append(inv_name)

            # One shared method list for all groups (instead of the last condition's)
            # keeps every bar aligned with its label even if a condition lacks a method.
            available_methods = set().union(*summaries)
            method_names = [method for method in desired_order if method in available_methods]
            # NaN marks a method that is missing for a given condition; it is skipped when plotting.
            all_medians_values = [[s[m][0] if m in s else np.nan for m in method_names] for s in summaries]
            all_std_devs_values = [[s[m][1] if m in s else np.nan for m in method_names] for s in summaries]

            # ---- Draw one group of bars per modeling-error condition ----
            # Widen the gap between groups only if the methods would not fit in the default spacing.
            group_spacing = max(min_group_spacing, (len(method_names) + 1) * bar_width)
            group_starts = np.arange(len(inv_NAME)) * group_spacing

            fig, ax = plt.subplots(figsize=(18, 11))
            labelled_methods = set()
            for ss in range(len(inv_NAME)):
                for ii, method_name in enumerate(method_names):
                    median = all_medians_values[ss][ii]
                    if np.isnan(median):
                        continue  # method not available for this condition
                    # Pin each method to a fixed cycle color so it looks the same in
                    # every group, regardless of how many bars were drawn before it.
                    bar_kwargs = {"color": f"C{ii}"}
                    if method_name not in labelled_methods:
                        # Label each method exactly once so the legend has one entry per method.
                        bar_kwargs["label"] = method_name
                        labelled_methods.add(method_name)
                    ax.bar(group_starts[ss] + ii * bar_width, median, bar_width,
                           yerr=all_std_devs_values[ss][ii], capsize=5, **bar_kwargs)

            # Configure plot labels and ticks.
            ax.set_xlabel('Modeling Errors', fontsize=24)
            ax.set_ylabel('Earth Movers Distance', fontsize=24)
            # Center each tick label under its group of bars.
            ax.set_xticks(group_starts + (bar_width * len(method_names)) / 2 - bar_width / 2)
            ax.set_xticklabels(inv_NAME)
            ax.yaxis.grid(True)
            ax.tick_params(axis='both', which='major', labelsize=20)

            # Single legend call placed outside the axes. A second ax.legend() call
            # would replace the first one and discard its fontsize.
            ax.legend(fontsize=20, loc='upper left', bbox_to_anchor=(1, 1))

            # ---- Save, display, and release the figure ----
            filename = f"LEN_Data_corr_{corr_coeff}_snr_db_{snr_db}_patchranks_{Patchranks}.png"
            figure_filename = os.path.join(folder_save, filename)
            print(figure_filename)
            # Save before show() so the written file does not depend on backend show() behavior.
            fig.savefig(figure_filename, dpi=300, bbox_inches='tight', format='png')
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across loop iterations

            end_time = time.time()  # Record end time
            elapsed_time = end_time - start_time
            print(f"Elapsed time for corr_coeff={corr_coeff}, Smoothness_order={Smoothness_order}, Patchranks={Patchranks}, snr_db={snr_db}: {elapsed_time} seconds")