In [None]:
# Import general libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
from pathlib import Path
import pickle
import json
import cmasher as cmr
import warnings
from matplotlib.gridspec import GridSpec
import gc
import re
import math
from scipy.stats import gaussian_kde
import seaborn as sns
import matplotlib.patches as patches
from scipy.stats import mannwhitneyu

In [None]:
# --- Input locations ---------------------------------------------------------
# Root of the local repository clone; every data folder below is built from it.
_repo_root = "C:\\Users\\franc\\Documents\\GitHub\\SBI_motor_neuron_behavior"

# Experimental observations (one CSV inside the experimental-data folder)
path_experimental_obs = _repo_root + "\\_experimental_data"
csv_experimental_obs = "experimental_data_dataframe_reorganized_and_filtered_dir_inhibited_persp_other_MUs_as_ref.csv"

# Prior-sampled simulations (training data). To be changed to the folders in which
# the general analyses (.csv) were saved. These simulations serve as the "null"
# predictions of the neuron model (random sampling of the parameter space, without inference).
path_prior_samples_obs_single_muscle = _repo_root + "\\$$$_Simulation_batch_single_muscle"
path_prior_samples_obs_paired_muscles = _repo_root + "\\$$$_Simulation_batches_muscle_pairs"
csv_prior_samples_obs = "___general_analysis_of_simulations.csv"

# Posterior-predicted simulations. To be changed to the folders where you created
# (or downloaded from the repository) the posterior-predicted simulations.
path_posterior_sampled_obs_single_muscle = (
    path_prior_samples_obs_single_muscle
    + "\\Z_inference\\inference_model_2\\posterior_predictive_checks_subjects_grouped"
)
path_subfolder_highest_likelihood_posterior_sample_obs_single_muscle = "_highest_likelihood_sims"

path_posterior_sampled_obs_paired_muscles = (
    path_prior_samples_obs_paired_muscles
    + "\\Z_inference\\inference_model_0\\posterior_predictive_checks_subjects_grouped"
)
path_subfolder_highest_likelihood_posterior_sample_obs_paired_muscles = "_highest_likelihood_sims"

# --- Output location (created relative to the current working directory) -----
path_to_save_into = "validation_posterior_predictive_checks"
os.makedirs(path_to_save_into, exist_ok=True)

# --- What to analyse ---------------------------------------------------------
observations = ["trough_area", "peak_height", "firing_rate", "IPSP_delay"]

perspective = 'other_MUs_as_ref'  # choose from {'other_MUs_as_ref','MU_as_ref','combined'}
direction = 'inhibited'           # choose from {'inhibited','inhibiting'}

# The loaded data frames already contain the duplications giving both AA->BB and
# BB->AA, so only the baseline direction (from AA to BB) needs to be selected here.
pool_pair_name = {
    "within_muscle": ["pool_0<->pool_0"],
    "between_muscles": ["pool_0<->pool_1"],
}

# Columns to group by when summarizing.
# Add "subject" as the first entry unless the SBI was run on pooled subjects.
summarize_over = [
    "muscle_pair",
    "intensity",
    "sim_idx",  # used only for the simulations
]

# --- Filtering thresholds (separate values for simulated and experimental data)
min_r2_for_baseline_curve_fit = {"simulation": 0.1, "experiment": 0.1}
min_r2_for_overall_curve_fit = {"simulation": 0.75, "experiment": 0.75}
min_nb_spikes = {"simulation": 10_000, "experiment": 5_000}

# --- Plot colors, keyed by muscle pair ('AA<->BB') and by single muscle -------
colors_dict = {
    # muscle pairs (both orderings of a mixed pair share one color)
    "VL<->VL": "#D62728",
    "VL<->VM": "#FF9201",
    "VM<->VL": "#FF9201",
    "VM<->VM": "#FFC400",
    "TA<->TA": "#00C71B",
    "FDI<->FDI": "#14BFA8",
    "GM<->GM": "#2489DC",
    "GM<->SOL": "#7D74EC",
    "SOL<->GM": "#7D74EC",
    "SOL<->SOL": "#BB86ED",
    # single muscles
    "VL": "#D62728",
    "VM": "#FFC400",
    "TA": "#00C71B",
    "FDI": "#14BFA8",
    "GM": "#2489DC",
    "SOL": "#BB86ED",
}


In [None]:
def _load_analysis_record(pkl_path, params_file, condition, sim_type):
    """Load one analysis pickle and attach its parameters and bookkeeping fields.

    The JSON named `params_file` is looked up next to `pkl_path`; when it is
    absent, 'sim_parameters' is set to None. 'condition' and 'sim_type' are
    written into the loaded dict before it is returned.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file.
    # Only point this loader at simulation outputs you produced or trust.
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    params_path = pkl_path.parent / params_file
    if params_path.exists():
        with open(params_path, "r", encoding="utf-8") as jf:
            data["sim_parameters"] = json.load(jf)
    else:
        data["sim_parameters"] = None

    data["condition"] = condition
    data["sim_type"] = sim_type
    return data


def load_analyses(
    base_folder,
    filename="analysis_output.pkl",
    params_file="sim_parameters.json",
    best_subfolder_name="_highest_likelihood_sims",
):
    """
    Walk each condition subfolder of `base_folder`, loading every
    `filename` found in its child subfolders into two dicts:
      - all_sims: keyed by "<condition>_sim<N>" (N counts the sorted child folders from 1)
      - best_sims: keyed by "<condition>_simHighestLikelihood"

    The folder `base_folder/best_subfolder_name` is not treated as a condition;
    it is expected to hold one subfolder per condition, each with child folders
    containing `filename`. Because all records of one condition share the same
    key there, the last one (in sorted order) is the one kept.

    Directories and files are visited in sorted order so that keys and dict
    ordering do not depend on the filesystem's listing order.

    Parameters
    ----------
    base_folder : str or Path
        Folder whose immediate subdirectories are the conditions.
    filename : str
        Name of the pickle file to load from each simulation folder.
    params_file : str
        Name of the optional JSON file stored next to each pickle.
    best_subfolder_name : str
        Name of the special subfolder holding the highest-likelihood simulations.

    Returns
    -------
    all_sims : dict[str, dict]
       mapping each "<condition>_sim<N>" -> loaded pickle dict (with
       'sim_parameters', 'condition' and 'sim_type' added)
    best_sims : dict[str, dict]
       mapping each "<condition>_simHighestLikelihood" -> loaded pickle dict
    num_conditions : int
       number of immediate subdirectories of `base_folder`, not counting the
       highest-likelihood subfolder

    Raises
    ------
    FileNotFoundError
        If `base_folder` does not exist (raised by the directory listing).
    """
    base = Path(base_folder)

    # Immediate subdirectories, minus the special highest-likelihood folder
    condition_dirs = sorted(
        d for d in base.iterdir() if d.is_dir() and d.name != best_subfolder_name
    )
    num_conditions = len(condition_dirs)

    all_sims = {}
    best_sims = {}

    # 1) load all the regular sims
    for cond_dir in condition_dirs:
        condition = cond_dir.name
        sim_dirs = sorted(d for d in cond_dir.iterdir() if (d / filename).is_file())
        for idx, sim_dir in enumerate(sim_dirs, start=1):
            all_sims[f"{condition}_sim{idx}"] = _load_analysis_record(
                sim_dir / filename, params_file, condition, "posterior_sample"
            )

    # 2) load the best-likelihood sim(s) from the special subfolder
    best_dir = base / best_subfolder_name
    if best_dir.is_dir():
        for best_cond_dir in sorted(d for d in best_dir.iterdir() if d.is_dir()):
            condition = best_cond_dir.name
            key = f"{condition}_simHighestLikelihood"
            sim_dirs = sorted(d for d in best_cond_dir.iterdir() if (d / filename).is_file())
            for sim_dir in sim_dirs:
                # Any file whose name ends with `filename` is accepted here
                for pkl_path in sorted(sim_dir.glob(f"*{filename}")):
                    best_sims[key] = _load_analysis_record(
                        pkl_path, params_file, condition,
                        "posterior_highest_likelihood_sample",
                    )

    return all_sims, best_sims, num_conditions


In [None]:
### LOAD SIMULATED OBSERVATIONS
# Single-muscle batch
(
    all_sims_single_muscle,
    best_sims_single_muscles,
    n_conditions,
) = load_analyses(
    path_posterior_sampled_obs_single_muscle,
    best_subfolder_name=path_subfolder_highest_likelihood_posterior_sample_obs_single_muscle,
)
print("SINGLE MUSCLE SIMULATIONS:")
print(f"    Simulations from posterior samples: {len(all_sims_single_muscle)} entries, from {n_conditions} conditions")
print(f"    Simulations from single posterior sample wih highest likelihood: {len(best_sims_single_muscles)} entries (one for each of the {n_conditions} conditions)")

# Muscle-pair batches (n_conditions is reused and now refers to this batch)
(
    all_sims_muscle_pair,
    best_sims_muscle_pair,
    n_conditions,
) = load_analyses(
    path_posterior_sampled_obs_paired_muscles,
    best_subfolder_name=path_subfolder_highest_likelihood_posterior_sample_obs_paired_muscles,
)
print("MUSCLE PAIRS SIMULATIONS:")
print(f"    Simulations from posterior samples: {len(all_sims_muscle_pair)} entries, from {n_conditions} conditions")
print(f"    Simulations from single posterior sample wih highest likelihood: {len(best_sims_muscle_pair)} entries (one for each of the {n_conditions} conditions)")

# Merge both batches into one dict each; on a key clash the muscle-pair entry wins
all_sims = {**all_sims_single_muscle, **all_sims_muscle_pair}
best_sims = {**best_sims_single_muscles, **best_sims_muscle_pair}

In [None]:
# List the keys of the merged highest-likelihood simulations (one key per loaded condition)
print("POSTERIOR-PREDICTIVE CHECKS - CONDITIONS LOADED (TOTAL):", best_sims.keys(), sep="\n")

In [None]:
### TURN THE SIMULATED OBSERVATIONS INTO A DATAFRAME (1 row = 1 motor unit)

# pre-compiled patterns used to parse condition names and simulation keys
_subject_re     = re.compile(r'^([A-Za-z]{4})')         # four leading letters of the condition name
_muscle_pair_re = re.compile(r'([A-Za-z]+-[A-Za-z]+)')  # first 'AA-BB' token of the condition name
_intensity_re   = re.compile(r'_(\d+)$')                # trailing '_<digits>' of the condition name
_sim_idx_re     = re.compile(r"_sim(\d+)$")             # trailing '_sim<digits>' of the dict key

# Metadata columns written for every row, in output order (observations follow)
_SIM_ROW_META_COLUMNS = (
    "sim_idx", "mn_idx", "condition", "muscle_pair", "pool_pair", "intensity",
    "subject", "perspective", "direction", "r2_full", "r2_baseline", "n_spikes",
)

# Default mapping of case -> pool pairs to read
_DEFAULT_POOL_PAIR_NAME = {
    "within_muscle": ["pool_0<->pool_0"],
    "between_muscles": ["pool_0<->pool_1"],
}


def _extract_observation(obs, sim_data, hist_branch, subkey, mn_idx):
    """Read one observation for one motor neuron and apply its unit scaling."""
    if obs == "trough_area":
        return hist_branch[subkey]["raw_area"] * -100.0          # sign-flipped, %
    if obs == "peak_height":
        return hist_branch["sync_height"] * 100.0                # %
    if obs == "firing_rate":
        return sim_data["Firing_rates"]["MN"]["mean"][f"MN_{mn_idx}"]
    if obs == "IPSP_delay":
        return hist_branch[f"delay_{subkey}_IPSP"] * 1000.0      # ms
    if obs == "hist_plateau_duration":
        return hist_branch[obs] * 1000.0                         # ms
    if obs == "proportion_of_prob_within_plateau_duration":
        return hist_branch[obs] * 100.0                          # %
    raise KeyError(f"Unrecognized observation '{obs}'")


def sims_dict_to_dataframe(
    sims_dict,
    perspective,
    direction,
    observations,
    pool_pair_name=None,
):
    """
    Flatten a dict of loaded simulation analyses into one row per motor neuron.

    Parameters
    ----------
    sims_dict : dict[str, dict]
        Mapping of simulation key -> analysis dict. Each analysis dict must hold
        'condition' and is read through 'Cross_histograms' and 'Firing_rates'.
    perspective : str
        'other_MUs_as_ref' selects the "inhibited" cross-histogram branch; any
        other value selects "inhibiting".
    direction : str
        Combined with `perspective` to choose the "forward" or "backward" sub-entry.
    observations : list[str]
        Observation columns to extract (see `_extract_observation` for the
        recognized names).
    pool_pair_name : dict[str, list[str]] or None
        Mapping of 'within_muscle' / 'between_muscles' to the pool pairs to read.
        None uses the built-in mapping (pool_0<->pool_0 and pool_0<->pool_1).

    Returns
    -------
    pandas.DataFrame
        One row per (simulation, pool pair, motor neuron). With no rows to emit,
        an empty frame carrying the expected columns is returned, so that code
        selecting columns from it still sees the full schema.

    Raises
    ------
    KeyError
        If an observation name is not recognized or an expected entry is missing
        from the analysis dict.
    """
    if pool_pair_name is None:
        pool_pair_name = _DEFAULT_POOL_PAIR_NAME

    # which cross-histogram branch to read
    key_to_load = "inhibited" if perspective == "other_MUs_as_ref" else "inhibiting"
    # NOTE(review): for direction='inhibited' with perspective='other_MUs_as_ref'
    # this XOR selects "backward", whereas the cell that loads the prior-sampled
    # CSV maps the same combination to the forward columns ('raw_area_fwd',
    # 'delay_forward_IPSP'). Confirm which convention the analysis pickles use.
    subkey_to_load = (
        "forward" if (direction == "inhibited") ^ (perspective == "other_MUs_as_ref") else "backward"
    )

    rows = []

    for sim_key, sim_data in sims_dict.items():
        # sim index (None when the key has no '_sim<N>' suffix)
        m = _sim_idx_re.search(sim_key)
        sim_idx = int(m.group(1)) if m else None

        condition = sim_data["condition"]

        # subject (fallback if missing)
        m_sub = _subject_re.match(condition)
        subject = m_sub.group(1) if m_sub else "pooled_subjects"

        # muscle pair like 'AA-BB' -> 'AA<->BB'
        m_pair = _muscle_pair_re.search(condition)
        muscle_pair_arrow = m_pair.group(1).replace("-", "<->") if m_pair else None

        # intensity
        m_int = _intensity_re.search(condition)
        intensity = int(m_int.group(1)) if m_int else None

        cross_all = sim_data.get("Cross_histograms", {})

        # decide within vs between
        if muscle_pair_arrow and "<->" in muscle_pair_arrow:
            a, b = muscle_pair_arrow.split("<->")
            case = "within_muscle" if a == b else "between_muscles"
        else:
            # fallback: infer from keys present in Cross_histograms
            ch_keys = set(cross_all.keys())
            has_between = ("pool_0<->pool_1" in ch_keys) or ("pool_1<->pool_0" in ch_keys)
            case = "between_muscles" if has_between else "within_muscle"

        # iterate only allowed pool-pairs
        for pool_pair in pool_pair_name[case]:
            ch_dict = cross_all.get(pool_pair)
            if not ch_dict:
                continue

            for mn_idx, hist_data in ch_dict.items():
                hist_branch = hist_data[key_to_load]
                row = {
                    "sim_idx":     sim_idx,
                    "mn_idx":      mn_idx,
                    "condition":   condition,
                    "muscle_pair": muscle_pair_arrow,  # 'AA<->BB'
                    "pool_pair":   pool_pair,          # 'pool_0<->pool_1' etc.
                    "intensity":   intensity,
                    "subject":     subject,
                    "perspective": perspective,
                    "direction":   direction,
                    "r2_full":     hist_branch["r2_full"],
                    "r2_baseline": hist_branch["r2_base"],
                    "n_spikes":    hist_branch["n_spikes"],
                }
                for obs in observations:
                    row[obs] = _extract_observation(
                        obs, sim_data, hist_branch, subkey_to_load, mn_idx
                    )
                rows.append(row)

    if not rows:
        # Keep the schema even when nothing was loaded
        columns = list(dict.fromkeys([*_SIM_ROW_META_COLUMNS, *observations]))
        return pd.DataFrame(columns=columns)

    return pd.DataFrame(rows)

# Build one DataFrame from all posterior-sample simulations ...
df_obs_simulated_from_posterior_samples = sims_dict_to_dataframe(
    sims_dict=all_sims,
    perspective=perspective,
    direction=direction,
    observations=observations,
    pool_pair_name=pool_pair_name,
)

# ... and one from the highest-likelihood simulations only
df_obs_simulated_from_posterior_sample_highest_likelihood = sims_dict_to_dataframe(
    sims_dict=best_sims,
    perspective=perspective,
    direction=direction,
    observations=observations,
    pool_pair_name=pool_pair_name,
)

In [None]:
### LOAD EXPERIMENTAL OBSERVATIONS
# Join folder and file name with pathlib (as done for the prior-sample CSVs)
# rather than with a hard-coded "\\" separator.
df_obs_experiment = pd.read_csv(Path(path_experimental_obs) / csv_experimental_obs)

# Rename mapping {old_name: new_name} aligning the experimental column names
# with those of the simulated-observation frames
rename_map = {
    "MU_idx": "mn_idx",
    "firing_rates_mean": "firing_rate",
    "raw_area": "trough_area",
    "sync_height": "peak_height",
    "IPSP_timing_of_trough": "IPSP_delay",
    "proportion_of_prob_within_plateau_duration": "proportion_of_prob_within_plateau_duration", # same name
    "hist_plateau_duration": "hist_plateau_duration", # same name
    "r2_base": "r2_baseline"
}

# Apply the renaming (only the keys in rename_map get renamed)
df_obs_experiment = df_obs_experiment.rename(columns=rename_map)

# Keep only the columns that also exist in the posterior-sample frame
# (everything else is dropped; columns missing from the CSV are simply absent)
keep_cols = list(df_obs_simulated_from_posterior_samples.columns)
df_obs_experiment = df_obs_experiment.filter(items=keep_cols)

# Same unit scaling as applied to the simulated observations
df_obs_experiment['trough_area'] *= -100 # turn to positive %
df_obs_experiment['peak_height'] *= 100 # turn to %
df_obs_experiment['IPSP_delay'] *= 1000 # turn to ms
# df_obs_experiment['hist_plateau_duration'] *= 1000 # turn to ms
# df_obs_experiment['proportion_of_prob_within_plateau_duration'] *= 100 # turn to %

In [None]:
### LOAD EVERY EXPERIMENTAL .PKL FILE (INDIVIDUAL CROSS-HISTOGRAM ANALYSIS RESULTS)
# (Used only to plot the cross-histograms)
# Files are read from path_experimental_obs and stored under their file stem,
# e.g. "subject1_cross_histogram.pkl" -> key "subject1_cross_histogram".
exp_analysis_results_dict = {}
for pkl_path in Path(path_experimental_obs).glob("*.pkl"):
    with pkl_path.open("rb") as f:
        exp_analysis_results_dict[pkl_path.stem] = pickle.load(f)


In [None]:
### LOAD SIMULATED PRIOR OBSERVATIONS (RANDOM SAMPLING OF PARAMETER SPACE - USED TO TRAIN THE SBI NETWORK)
# Read the single-muscle and the muscle-pair analyses, then stack them.
# concat takes the union of columns; columns missing from one file become NaN.
df_prior_single = pd.read_csv(Path(path_prior_samples_obs_single_muscle) / csv_prior_samples_obs)
df_prior_pairs = pd.read_csv(Path(path_prior_samples_obs_paired_muscles) / csv_prior_samples_obs)
df_obs_simulated_from_prior = pd.concat(
    [df_prior_single, df_prior_pairs],
    ignore_index=True,
    sort=False,
)

# Replace 'sim_name' by the digits following "output_" at the end of the string
# (stored as float; names that do not match become NaN)
df_obs_simulated_from_prior['sim_name'] = (
    df_obs_simulated_from_prior['sim_name']
      .str
      .extract(r'output_(\d+)$', expand=False)   # single group -> Series
      .astype(float)
)

# Rename mapping {old_name: new_name}
rename_map = {
    "sim_name": "sim_idx",
    "MN_index": "mn_idx",
    "Firing_rates_mean": "firing_rate",
    "sync_height": "peak_height",
    "r2_base": "r2_baseline",
}
# The trough-area / IPSP-delay source columns depend on direction and perspective.
# Combinations not listed here add nothing to the mapping.
_forward_columns = {"raw_area_fwd": 'trough_area', "delay_forward_IPSP": 'IPSP_delay'}
_backward_columns = {"raw_area_bwd": 'trough_area', "delay_backward_IPSP": 'IPSP_delay'}
_directional_columns = {
    ('inhibited', 'other_MUs_as_ref'): _forward_columns,
    ('inhibited', 'MU_as_ref'): _backward_columns,
    ('inhibiting', 'other_MUs_as_ref'): _backward_columns,
    ('inhibiting', 'MU_as_ref'): _forward_columns,
}
rename_map.update(_directional_columns.get((direction, perspective), {}))

# Rename, then keep only the columns shared with the posterior-sample frame
keep_cols = list(df_obs_simulated_from_posterior_samples.columns)
df_obs_simulated_from_prior = (
    df_obs_simulated_from_prior
    .rename(columns=rename_map)
    .filter(items=keep_cols)
)

# Unit scaling: trough area -> positive %, peak height -> %, IPSP delay -> ms
for _column, _factor in (('trough_area', -100), ('peak_height', 100), ('IPSP_delay', 1000)):
    df_obs_simulated_from_prior[_column] *= _factor
# df_obs_simulated_from_prior['hist_plateau_duration'] *= 1000 # turn to ms
# df_obs_simulated_from_prior['proportion_of_prob_within_plateau_duration'] *= 100 # turn to %

In [None]:
# Filter all data frames
def _apply_quality_filters(df, source):
    """Keep the rows of `df` that pass the curve-fit, spike-count and selection filters.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'r2_baseline', 'r2_full', 'n_spikes', 'perspective', 'direction'.
    source : str
        'simulation' or 'experiment'; selects which threshold of the notebook-level
        dicts (min_r2_for_baseline_curve_fit, min_r2_for_overall_curve_fit,
        min_nb_spikes) applies.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the retained rows. Later cells assign columns
        into these frames, so a copy is returned rather than a filtered slice.
    """
    keep = (
        (df['r2_baseline'] > min_r2_for_baseline_curve_fit[source])
        & (df['r2_full'] > min_r2_for_overall_curve_fit[source])
        & (df['n_spikes'] > min_nb_spikes[source])
        & (df['perspective'] == perspective)
        & (df['direction'] == direction)
    )
    return df[keep].copy()


df_obs_simulated_from_posterior_samples = _apply_quality_filters(
    df_obs_simulated_from_posterior_samples, 'simulation'
)
df_obs_simulated_from_posterior_sample_highest_likelihood = _apply_quality_filters(
    df_obs_simulated_from_posterior_sample_highest_likelihood, 'simulation'
)
df_obs_experiment = _apply_quality_filters(df_obs_experiment, 'experiment')
df_obs_simulated_from_prior = _apply_quality_filters(df_obs_simulated_from_prior, 'simulation')

In [None]:
# Inspect the filtered prior-sampled observations (a bare last expression is shown as the cell output)
df_obs_simulated_from_prior

In [None]:
# Standardize every observation column with the mean and std of the *prior* data
# (simulated data from the uniform prior). std uses ddof=0 (population-style).
# NOTE(review): despite its name, exp_stats holds statistics of the prior frame,
# and because the prior frame itself is standardized below, re-running this cell
# recomputes exp_stats from already-standardized values.
exp_stats = {
    obs: {
        "mean": df_obs_simulated_from_prior[obs].mean(),
        "std": df_obs_simulated_from_prior[obs].std(ddof=0),
    }
    for obs in observations
}

# Apply the same shift/scale to all four frames, in place (the prior frame last)
_frames_to_standardize = (
    df_obs_experiment,
    df_obs_simulated_from_posterior_samples,
    df_obs_simulated_from_posterior_sample_highest_likelihood,
    df_obs_simulated_from_prior,
)
for obs in observations:
    _center = exp_stats[obs]["mean"]
    _scale = exp_stats[obs]["std"]
    for _frame in _frames_to_standardize:
        _frame[obs] = (_frame[obs] - _center) / _scale

In [None]:
# Summary helpers: per-group descriptive statistics of the observation columns
def iqr(series: pd.Series) -> float:
    """Return the interquartile range of `series` (75th minus 25th percentile)."""
    upper = series.quantile(0.75)
    lower = series.quantile(0.25)
    return float(upper - lower)
iqr.__name__ = "iqr"  # groupby.agg names the output column after the function

def summarize_dataframe(df, summarize_over, observations):
    """
    Summarize `df` per group.

    Groups are formed from the entries of `summarize_over` that are columns of
    `df`. Each column listed in `observations` (and present in `df`) is reduced
    to its mean, std, median and iqr; every remaining column keeps the first
    value of the group. A warning is printed for each remaining column that
    takes more than one distinct value inside some group.

    Returns a DataFrame with flat column names: the grouping columns, the
    carried-over columns under their own name, and `<obs>_mean`, `<obs>_std`,
    `<obs>_median`, `<obs>_iqr` for the observations.

    Raises ValueError when none of `summarize_over` is a column of `df`.
    """
    # Grouping columns actually available in this frame
    group_cols = [name for name in summarize_over if name in df.columns]
    if len(group_cols) == 0:
        raise ValueError(f"None of {summarize_over} found in DataFrame columns.")

    # Columns that are neither grouped on nor summarized
    passthrough_cols = [
        name for name in df.columns
        if not (name in group_cols or name in observations)
    ]

    # Sanity check: carried-over columns should be constant within a group
    for name in passthrough_cols:
        distinct_per_group = df.groupby(group_cols)[name].nunique()
        if (distinct_per_group > 1).any():
            print(f"    Warning: column '{name}' has >1 distinct values in some groups.")

    # Aggregation plan: four statistics per observation, 'first' for the rest
    aggregations = {}
    for name in observations:
        if name in df.columns:
            aggregations[name] = ['mean', 'std', 'median', iqr]
    for name in passthrough_cols:
        aggregations[name] = 'first'

    summary = df.groupby(group_cols, as_index=False).agg(aggregations)

    # Flatten the (column, statistic) MultiIndex: groupers arrive with an empty
    # statistic and 'first' columns keep their plain name
    summary.columns = [
        name if stat in ("", "first") else f"{name}_{stat}"
        for name, stat in summary.columns
    ]
    return summary


# Per-simulation summaries (simulated observations, from the posterior or the prior)
print('Summarizing data frame of simulated observations from inferred posterior')
df_summary_obs_simulated_from_posterior = summarize_dataframe(
    df=df_obs_simulated_from_posterior_samples,
    summarize_over=summarize_over,
    observations=observations,
)
print('Summarizing data frame of simulated observations from prior')
df_summary_obs_simulated_from_prior = summarize_dataframe(
    df=df_obs_simulated_from_prior,
    summarize_over=summarize_over,
    observations=observations,
)
print('Summarizing data frame of simulated observations from inferred posterior - specifically the mode of the posterior (set of parameters with highest likelihood)')
# Give every highest-likelihood row the same sim_idx (0) so the grouping works
df_obs_simulated_from_posterior_sample_highest_likelihood_temp = (
    df_obs_simulated_from_posterior_sample_highest_likelihood.assign(sim_idx=0)
)
df_summary_obs_simulated_highest_likelihood = summarize_dataframe(
    df=df_obs_simulated_from_posterior_sample_highest_likelihood_temp,
    summarize_over=summarize_over,
    observations=observations,
)
# Per-condition summary of the experimental observations
print('Summarizing data frame of experimental observations')
df_summary_obs_experiment = summarize_dataframe(
    df=df_obs_experiment,
    summarize_over=summarize_over,
    observations=observations,
)
# Sorted unique values found in the experimental summary, used as iteration groups
subjects = list(np.unique(df_summary_obs_experiment['subject']))
muscle_pairs = list(np.unique(df_summary_obs_experiment['muscle_pair']))
intensities = list(np.unique(df_summary_obs_experiment['intensity']))

# Posterior predictive checks - accuracy and correlation between posterior-predicted features and experimentally observed features, across conditions

In [None]:
from sklearn.decomposition import PCA
import numpy as np

# Observation names and the summary statistics computed for each of them
observations = ["trough_area", "peak_height", "firing_rate", "IPSP_delay"]
summary_stats = ['mean', 'std', 'median', 'iqr']

# One feature column per (observation, statistic) pair: 4 x 4 = 16 columns
feat_cols = [f"{obs}_{stat}" for obs in observations for stat in summary_stats]
print("feat_cols:", feat_cols)

# PCA training set: the prior summaries, without rows missing any feature
df_prior_for_pca = df_summary_obs_simulated_from_prior.dropna(subset=feat_cols)
print(f"PCA training rows after dropping NaNs: {len(df_prior_for_pca)}")

X_prior = df_prior_for_pca[feat_cols].to_numpy(dtype=float)

# Fit the PCA on the prior features without additional whitening.
# The number of components is capped by both the feature and the row count.
_n_components = min(len(feat_cols), X_prior.shape[0])
pca = PCA(n_components=_n_components).fit(X_prior)

print("PCA fitted. Explained variance ratio (first few):", pca.explained_variance_ratio_[:5])


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, to_rgb
from scipy.stats import pearsonr, spearmanr

# ===========================
# Config
# ===========================
EXCLUDE_PAIRS = ["VL<->VM", "GM<->SOL"]   # muscle pairs removed by the filter helpers below
GROUP_COLS = ["muscle_pair", "intensity"]

SUMMARY_STAT = "mean"          # which summary stat column to use for the feature rows
POST_SUMMARY = "mean"          # "mode" (highest-likelihood) or "mean" (posterior mean) for curves + metrics
STANDARDIZATION_MODE = "exp"   # "prior" = relative to prior SD; "exp" = relative to across-condition SD
NBINS_Y = 30
VSCALE_QUANTILE = 0.999
FIG_W = 10
FIG_H_PER_ROW = 3.0

# base colors for each feature & for PC1 row
feature_colors = dict(
    trough_area="blue",
    peak_height="red",
    firing_rate="orange",
    IPSP_delay="purple",
)
PC1_COLOR = "green"

# The PC1 row needs objects created by earlier cells; fail early if one is missing
for _required_name, _complaint in (
    ("feat_cols", "This cell expects a global 'feat_cols' list (all 16 feature columns)."),
    ("pca", "This cell expects a fitted sklearn PCA instance called 'pca'."),
    ("exp_stats", "This cell expects a global 'exp_stats' dict with prior mean/std for each observation."),
):
    assert _required_name in globals(), _complaint


def white_to_color_cmap(color_hex, steps=256):
    """Return a `steps`-level colormap running from white to `color_hex`."""
    white = (1, 1, 1)
    target = to_rgb(color_hex)
    return LinearSegmentedColormap.from_list("white_to_color", [white, target], N=steps)


def rmse(a, b):
    """Root-mean-square error between two array-likes, returned as a Python float."""
    difference = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    mean_squared = np.mean(difference ** 2)
    return float(np.sqrt(mean_squared))


def r2(pred, true):
    """Coefficient of determination of `pred` against `true`.

    Returns NaN when `true` has no variance (total sum of squares not above zero).
    """
    predicted = np.asarray(pred, dtype=float)
    observed = np.asarray(true, dtype=float)
    residual_ss = np.sum((predicted - observed) ** 2)
    total_ss = np.sum((observed - observed.mean()) ** 2)
    if total_ss > 0:
        return float(1.0 - residual_ss / total_ss)
    return np.nan


def filter_and_dropna_single(df, col, exclude_pairs=None):
    """Drop excluded muscle pairs and rows with a missing value in `col`.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    col : str
        Column that must be non-missing in the returned rows.
    exclude_pairs : iterable of str or None
        Values of 'muscle_pair' to remove. None uses the notebook-level
        EXCLUDE_PAIRS list. Ignored when `df` has no 'muscle_pair' column.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the retained rows.
    """
    kept = df
    if "muscle_pair" in df.columns:
        # Resolve the global only when it is actually needed
        excluded = EXCLUDE_PAIRS if exclude_pairs is None else exclude_pairs
        kept = df[~df["muscle_pair"].isin(excluded)]
    # Copy last so only the surviving rows are duplicated
    return kept.dropna(subset=[col]).copy()


def filter_and_dropna_multifeat(df, cols, exclude_pairs=None):
    """Drop excluded muscle pairs, coerce `cols` to numeric and drop incomplete rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    cols : sequence of str
        Columns converted with `pd.to_numeric(errors="coerce")`; rows where any
        of them ends up missing are removed.
    exclude_pairs : iterable of str or None
        Values of 'muscle_pair' to remove. None uses the notebook-level
        EXCLUDE_PAIRS list. Ignored when `df` has no 'muscle_pair' column.

    Returns
    -------
    pandas.DataFrame
        A new frame with the retained rows and numeric `cols`.
    """
    cols = list(cols)  # accept tuples / other sequences for column selection
    if "muscle_pair" in df.columns:
        excluded = EXCLUDE_PAIRS if exclude_pairs is None else exclude_pairs
        # Copy AFTER filtering so the assignment below writes to an independent
        # frame rather than to a filtered slice of another one
        out = df[~df["muscle_pair"].isin(excluded)].copy()
    else:
        out = df.copy()
    out[cols] = out[cols].apply(pd.to_numeric, errors="coerce")
    return out.dropna(subset=cols)


def make_transform_from_y_true(y_true_z, mode):
    """
    Create the standardization function for the requested mode.

    Parameters
    ----------
    y_true_z : array-like
        Experimental values across conditions (already z-scored with the prior
        statistics, according to the calling code).
    mode : str
        "exp" centres and scales by the mean / population SD (ddof=0) of
        ``y_true_z``; any other value (e.g. "prior") yields the identity.

    Returns
    -------
    (transform, mu, sd)
        ``transform(v)`` maps an array-like to a float array; ``mu`` and ``sd``
        are the constants it uses (0.0 and 1.0 for the identity). A non-positive
        SD is replaced by 1.0 so the transform never divides by zero.
    """
    reference = np.asarray(y_true_z, dtype=float)

    if mode != "exp":
        # "prior" or anything else: values stay in their incoming units
        def transform(v):
            return np.asarray(v, dtype=float)

        return transform, 0.0, 1.0

    mu = float(reference.mean())
    sd = float(reference.std(ddof=0))
    if sd <= 0:
        sd = 1.0

    def transform(v):
        centred = np.asarray(v, dtype=float) - mu
        return centred / sd

    return transform, mu, sd


# Human-readable description of the standardization, shown in the metric boxes.
# make_transform_from_y_true rescales only when the mode is exactly "exp" and is the
# identity for every other value, so the label is keyed on "exp" as well; otherwise an
# unrecognised mode would be labelled "across-condition SD" while no rescaling happens.
if STANDARDIZATION_MODE == "exp":
    std_label = "relative to across-condition SD"
else:
    std_label = "relative to prior SD"

# ---------------------------
# Figure rows: one per entry of `observations`, followed by a single PC1 row
# ---------------------------
row_specs = [dict(kind="feature", feat=name) for name in observations]
row_specs += [dict(kind="pc1", feat="PC1")]

n_rows = len(row_specs)
fig_height = max(FIG_H_PER_ROW * n_rows, 3.0)  # lower bound of 3.0 on the figure height
fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=1,
    figsize=(FIG_W, fig_height),
    squeeze=False,  # keep `axes` 2-D so it can be indexed as axes[row, 0]
)

# Containers filled by the first pass
y_ranges = [None for _ in range(n_rows)]
density_vals_for_vscale = list()

# ===========================
# First pass: compute y-range and gather density values for shared v-scale
# ===========================
# For every row this pass stores:
#   * spec["mu_z"], spec["sd_z"] : constants returned by make_transform_from_y_true
#                                  (read again by the plotting pass),
#   * y_ranges[row_idx]          : padded 1st-99th percentile range of all plotted values,
#   * density_vals_for_vscale    : the flattened per-condition histogram columns.
# Rows with no condition shared by all sources are flagged with spec["skip"] = True.
for row_idx, spec in enumerate(row_specs):
    kind = spec["kind"]

    if kind == "feature":
        feat = spec["feat"]
        col_name = f"{feat}_{SUMMARY_STAT}"

        df_exp  = filter_and_dropna_single(df_summary_obs_experiment, col_name)
        df_post = filter_and_dropna_single(df_summary_obs_simulated_from_posterior, col_name)
        # NOTE(review): "best" is read from the same posterior table as df_post, so
        # best_map below is simply the first row of each condition -- confirm that this
        # row really is the highest-likelihood sample.
        df_best = filter_and_dropna_single(df_summary_obs_simulated_from_posterior, col_name)

        # build maps: cond -> values
        exp_map = {k: float(g.iloc[0][col_name])
                   for k, g in df_exp.groupby(GROUP_COLS, dropna=False)}
        post_map = {}
        for k, g in df_post.groupby(GROUP_COLS, dropna=False):
            vals = g[col_name].to_numpy(dtype=float)
            if vals.size > 0:
                post_map[k] = vals
        best_map = {k: float(g.iloc[0][col_name])
                    for k, g in df_best.groupby(GROUP_COLS, dropna=False)}

        # only conditions present in all three sources can be compared
        conds = sorted(set(exp_map.keys()) & set(post_map.keys()) & set(best_map.keys()))
        if not conds:
            spec["skip"] = True
            continue

        # these are in prior-z space
        y_true_z = np.array([exp_map[c] for c in conds], dtype=float)
        y_best_z = np.array([best_map[c] for c in conds], dtype=float)
        post_vals_all_z = np.concatenate([post_map[c] for c in conds])

        # transform for chosen standardization mode
        transform, mu_z, sd_z = make_transform_from_y_true(y_true_z, STANDARDIZATION_MODE)
        spec["mu_z"] = mu_z
        spec["sd_z"] = sd_z

        y_true_std = transform(y_true_z)
        y_best_std = transform(y_best_z)
        post_vals_all_std = transform(post_vals_all_z)

        # robust y-range in standardized plotting units
        all_vals = np.concatenate([y_true_std, y_best_std, post_vals_all_std])
        q_low, q_high = np.quantile(all_vals, [0.01, 0.99])
        # Fall back to a unit span only when the quantile range is degenerate.
        # (Testing the sign of q_high instead would discard a valid span whenever the
        # values are all <= 0 and keep a zero span when q_low == q_high > 0.)
        span = q_high - q_low
        if not span > 0:
            span = 1.0
        y_min = q_low - 0.1 * span
        y_max = q_high + 0.1 * span
        y_ranges[row_idx] = (y_min, y_max)

        # density image (unsorted) for v-scale, in standardized units
        y_edges = np.linspace(y_min, y_max, NBINS_Y + 1)
        img_unsorted = np.zeros((NBINS_Y, len(conds)), dtype=float)
        for j, cond in enumerate(conds):
            vals_std = transform(post_map[cond])
            # out-of-range samples are counted in the outermost bins
            vals_std = np.clip(vals_std, y_min, y_max)
            counts, _ = np.histogram(vals_std, bins=y_edges)
            col_pdf = counts.astype(float)
            if col_pdf.sum() > 0:
                col_pdf /= col_pdf.sum()  # each column sums to 1
            img_unsorted[:, j] = col_pdf

        density_vals_for_vscale.append(img_unsorted.ravel())

    elif kind == "pc1":
        # PC1 over full 16-D feature space
        df_exp_pc  = filter_and_dropna_multifeat(df_summary_obs_experiment, feat_cols)
        df_post_pc = filter_and_dropna_multifeat(df_summary_obs_simulated_from_posterior, feat_cols)
        # NOTE(review): same source table as df_post_pc (see the note in the feature branch).
        df_best_pc = filter_and_dropna_multifeat(df_summary_obs_simulated_from_posterior, feat_cols)

        exp_map_pc = {k: g.iloc[0][feat_cols].to_numpy(dtype=float)
                      for k, g in df_exp_pc.groupby(GROUP_COLS, dropna=False)}
        post_map_pc = {}
        for k, g in df_post_pc.groupby(GROUP_COLS, dropna=False):
            X = g[feat_cols].to_numpy(dtype=float)
            if X.shape[0] > 0:
                post_map_pc[k] = X
        best_map_pc = {k: g.iloc[0][feat_cols].to_numpy(dtype=float)
                       for k, g in df_best_pc.groupby(GROUP_COLS, dropna=False)}

        conds = sorted(set(exp_map_pc.keys()) & set(post_map_pc.keys()) & set(best_map_pc.keys()))
        if not conds:
            spec["skip"] = True
            continue

        pc_true_z_list = []
        pc_best_z_list = []
        # cond -> PC1 scores of its posterior samples; projected once and reused for
        # both the y-range and the histogram below
        post_pc_per_cond_z = {}

        for cond in conds:
            vec_exp  = exp_map_pc[cond]
            vec_best = best_map_pc[cond]
            X_post   = post_map_pc[cond]

            # first column of the PCA projection = PC1 score
            pc_true_z_list.append(pca.transform(vec_exp[None, :])[:, 0][0])
            pc_best_z_list.append(pca.transform(vec_best[None, :])[:, 0][0])
            post_pc_per_cond_z[cond] = pca.transform(X_post)[:, 0]

        pc_true_z = np.array(pc_true_z_list, dtype=float)
        pc_best_z = np.array(pc_best_z_list, dtype=float)
        post_pc_vals_all_z = np.concatenate([post_pc_per_cond_z[c] for c in conds])

        # standardize PC1 according to mode
        transform, mu_z, sd_z = make_transform_from_y_true(pc_true_z, STANDARDIZATION_MODE)
        spec["mu_z"] = mu_z
        spec["sd_z"] = sd_z

        pc_true_std = transform(pc_true_z)
        pc_best_std = transform(pc_best_z)
        post_pc_vals_all_std = transform(post_pc_vals_all_z)

        all_vals = np.concatenate([pc_true_std, pc_best_std, post_pc_vals_all_std])
        q_low, q_high = np.quantile(all_vals, [0.01, 0.99])
        # same degenerate-span guard as in the feature branch
        span = q_high - q_low
        if not span > 0:
            span = 1.0
        y_min = q_low - 0.1 * span
        y_max = q_high + 0.1 * span
        y_ranges[row_idx] = (y_min, y_max)

        y_edges = np.linspace(y_min, y_max, NBINS_Y + 1)
        img_unsorted = np.zeros((NBINS_Y, len(conds)), dtype=float)
        for j, cond in enumerate(conds):
            pcs_post_std = transform(post_pc_per_cond_z[cond])
            pcs_post_std = np.clip(pcs_post_std, y_min, y_max)
            counts, _ = np.histogram(pcs_post_std, bins=y_edges)
            col_pdf = counts.astype(float)
            if col_pdf.sum() > 0:
                col_pdf /= col_pdf.sum()
            img_unsorted[:, j] = col_pdf

        density_vals_for_vscale.append(img_unsorted.ravel())

# Colour limits shared by every row: pool all histogram values gathered in the
# first pass (or a single 1.0 when no row produced any)
all_density_vals = (
    np.concatenate(density_vals_for_vscale)
    if density_vals_for_vscale
    else np.array([1.0], dtype=float)
)

vmin = 0.0
# upper limit = high quantile of the pooled densities, replaced by 1.0 if not positive
vmax = float(np.quantile(all_density_vals, VSCALE_QUANTILE))
vmax = 1.0 if vmax <= 0 else vmax

# ===========================
# Second pass: plotting with shared v-scale
# ===========================
# For every non-skipped row: rebuild the condition maps, standardize with the constants
# stored in the row spec by the first pass, compute agreement metrics, and draw the
# per-condition posterior histogram image with the experiment / prediction curves on top.
last_im = None  # last image drawn; used afterwards as the colorbar mappable

for row_idx, spec in enumerate(row_specs):
    ax = axes[row_idx, 0]
    kind = spec["kind"]

    if spec.get("skip", False):
        ax.axis("off")
        continue

    y_min, y_max = y_ranges[row_idx]
    x_left = -0.5
    mu_z = spec.get("mu_z", 0.0)
    sd_z = spec.get("sd_z", 1.0)

    # Bind this row's constants as defaults so the helper does not depend on the
    # loop variables keeping their values.
    def transform(v, mu_z=mu_z, sd_z=sd_z):
        return (np.asarray(v, dtype=float) - mu_z) / sd_z

    if kind == "feature":
        feat = spec["feat"]
        col_name = f"{feat}_{SUMMARY_STAT}"

        df_exp  = filter_and_dropna_single(df_summary_obs_experiment, col_name)
        df_post = filter_and_dropna_single(df_summary_obs_simulated_from_posterior, col_name)
        # NOTE(review): "best" is read from the same posterior table as df_post, so
        # best_map is the first row of each condition -- confirm that this row is the
        # highest-likelihood sample (it is what POST_SUMMARY != "mean" plots as "mode").
        df_best = filter_and_dropna_single(df_summary_obs_simulated_from_posterior, col_name)

        exp_map = {k: float(g.iloc[0][col_name])
                   for k, g in df_exp.groupby(GROUP_COLS, dropna=False)}
        post_map = {}
        for k, g in df_post.groupby(GROUP_COLS, dropna=False):
            vals = g[col_name].to_numpy(dtype=float)
            if vals.size > 0:
                post_map[k] = vals
        best_map = {k: float(g.iloc[0][col_name])
                    for k, g in df_best.groupby(GROUP_COLS, dropna=False)}

        conds = sorted(set(exp_map.keys()) & set(post_map.keys()) & set(best_map.keys()))
        if not conds:
            ax.axis("off")
            continue

        # prior-z values (as stored in the DF)
        y_true_z = np.array([exp_map[c] for c in conds], dtype=float)
        y_best_z = np.array([best_map[c] for c in conds], dtype=float)
        y_post_mean_z = np.array([post_map[c].mean() for c in conds], dtype=float)

        # choose which posterior summary to show (still in z-prior)
        if POST_SUMMARY.lower() == "mean":
            y_pred_z = y_post_mean_z
            pred_label = "Posterior mean"
        else:
            y_pred_z = y_best_z
            pred_label = "Posterior mode"

        # standardized values for metrics & plotting, according to STANDARDIZATION_MODE
        y_true_std = transform(y_true_z)
        y_pred_std = transform(y_pred_z)

        # --- standardized metrics ---
        rmse_pred = rmse(y_pred_std, y_true_std)
        r2_pred   = r2(y_pred_std, y_true_std)
        if len(y_true_std) >= 2:
            pear_r, _    = pearsonr(y_pred_std, y_true_std)
            spear_rho, _ = spearmanr(y_pred_std, y_true_std)
            r2_cal = pear_r**2  # linear R²
        else:
            # correlations are undefined for a single condition
            pear_r = spear_rho = np.nan
            r2_cal = np.nan

        # --- unstandardized differences (absolute) in original units ---
        sigma_raw = exp_stats[feat]["std"]  # std used for z-scoring this feature
        diff_raw  = (y_pred_z - y_true_z) * sigma_raw        # predicted − experimental, in raw units
        abs_diff_raw = np.abs(diff_raw)

        mean_abs_diff   = float(np.mean(abs_diff_raw))
        sd_abs_diff     = float(np.std(abs_diff_raw, ddof=0))
        median_abs_diff = float(np.median(abs_diff_raw))
        q1_abs, q3_abs  = [float(q) for q in np.quantile(abs_diff_raw, [0.25, 0.75])]
        min_abs_diff    = float(np.min(abs_diff_raw))
        max_abs_diff    = float(np.max(abs_diff_raw))

        # sorting by experimental value in standardized units
        order = np.argsort(y_true_std)
        conds_sorted = [conds[i] for i in order]
        y_true_sorted = y_true_std[order]
        y_pred_sorted = y_pred_std[order]

        # density image in sorted order (standardized units)
        y_edges = np.linspace(y_min, y_max, NBINS_Y + 1)
        img_sorted = np.zeros((NBINS_Y, len(conds_sorted)), dtype=float)
        for j, cond in enumerate(conds_sorted):
            vals_z = post_map[cond]
            vals_std = transform(vals_z)
            # out-of-range samples are counted in the outermost bins
            vals_std = np.clip(vals_std, y_min, y_max)
            counts, _ = np.histogram(vals_std, bins=y_edges)
            col_pdf = counts.astype(float)
            if col_pdf.sum() > 0:
                col_pdf /= col_pdf.sum()  # each column sums to 1
            img_sorted[:, j] = col_pdf

        base_color = feature_colors.get(feat, "#2a6f97")
        cmap = white_to_color_cmap(base_color, steps=256)

        x_right = len(conds_sorted) - 0.5
        im = ax.imshow(
            img_sorted,
            aspect="auto",
            origin="lower",
            extent=[x_left, x_right, y_min, y_max],
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            # "nearest" keeps every histogram bin a flat cell and matches the PC1 row;
            # interpolation=None would defer to the rcParams default instead.
            interpolation="nearest",
        )
        last_im = im

        x = np.arange(len(conds_sorted))
        # experiment (black) and posterior summary (colored)
        ax.plot(x, y_true_sorted, color="black", lw=1.4, marker="o", markersize=3,
                label="Experiment")
        ax.plot(x, y_pred_sorted, color=base_color, lw=1.2, marker="o", markersize=3,
                alpha=0.9, label=pred_label)

        tick_labels = [f"{mp}_{inten}" for (mp, inten) in conds_sorted]
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=90, ha="right", fontsize=7)

        ax.set_xlim(x_left, x_right)
        ax.set_ylabel(f"{feat} ({SUMMARY_STAT})")

        # title and legend are drawn on the first row only
        if row_idx == 0:
            ax.set_title(
                "Posterior predictive accuracy across conditions\n"
                f"(background: posterior density; lines: experiment vs posterior {POST_SUMMARY})"
            )
            ax.legend(loc="upper right", fontsize=8, frameon=False)

        # standardized metrics (left box)
        txt_std = (f"Posterior {POST_SUMMARY} vs experiment\n"
                   f"({std_label}):\n"
                   f"  RMSE        = {rmse_pred:.3f}\n"
                   f"  R²          = {r2_pred:.3f}\n"
                   f"  R² (linear) = {r2_cal:.3f}\n"
                   f"  Pearson r   = {pear_r:.3f}\n"
                   f"  Spearman ρ  = {spear_rho:.3f}")
        ax.text(0.01, 0.99, txt_std, transform=ax.transAxes,
                va="top", ha="left", fontsize=8, color="black",
                bbox=dict(facecolor="white", alpha=0.75, edgecolor="none"))

        # unstandardized absolute differences (right box)
        txt_raw = (f"|pred − exp| in raw units:\n"
                   f"  mean±SD = {mean_abs_diff:.3g} ± {sd_abs_diff:.3g}\n"
                   f"  median [Q1–Q3] = {median_abs_diff:.3g} "
                   f"[{q1_abs:.3g}–{q3_abs:.3g}]\n"
                   f"  min–max = {min_abs_diff:.3g}–{max_abs_diff:.3g}")
        ax.text(0.52, 0.99, txt_raw, transform=ax.transAxes,
                va="top", ha="left", fontsize=8, color="black",
                bbox=dict(facecolor="white", alpha=0.75, edgecolor="none"))

    elif kind == "pc1":
        # PC1 row using full 16-D feature vectors
        df_exp_pc  = filter_and_dropna_multifeat(df_summary_obs_experiment, feat_cols)
        df_post_pc = filter_and_dropna_multifeat(df_summary_obs_simulated_from_posterior, feat_cols)
        # NOTE(review): same source table as df_post_pc (see the note in the feature branch).
        df_best_pc = filter_and_dropna_multifeat(df_summary_obs_simulated_from_posterior, feat_cols)

        exp_map_pc = {k: g.iloc[0][feat_cols].to_numpy(dtype=float)
                      for k, g in df_exp_pc.groupby(GROUP_COLS, dropna=False)}
        post_map_pc = {}
        for k, g in df_post_pc.groupby(GROUP_COLS, dropna=False):
            X = g[feat_cols].to_numpy(dtype=float)
            if X.shape[0] > 0:
                post_map_pc[k] = X
        best_map_pc = {k: g.iloc[0][feat_cols].to_numpy(dtype=float)
                       for k, g in df_best_pc.groupby(GROUP_COLS, dropna=False)}

        conds = sorted(set(exp_map_pc.keys()) & set(post_map_pc.keys()) & set(best_map_pc.keys()))
        if not conds:
            ax.axis("off")
            continue

        pc_true_z = []
        pc_best_z = []
        pc_mean_post_z = []
        post_pc_per_cond_z = {}  # cond -> PC1 scores of its posterior samples

        for cond in conds:
            vec_exp  = exp_map_pc[cond]
            vec_best = best_map_pc[cond]
            X_post   = post_map_pc[cond]

            # first column of the PCA projection = PC1 score
            pc_true_val_z = pca.transform(vec_exp[None, :])[:, 0][0]
            pc_best_val_z = pca.transform(vec_best[None, :])[:, 0][0]
            pcs_post_z    = pca.transform(X_post)[:, 0]

            pc_true_z.append(pc_true_val_z)
            pc_best_z.append(pc_best_val_z)
            pc_mean_post_z.append(pcs_post_z.mean())
            post_pc_per_cond_z[cond] = pcs_post_z

        pc_true_z = np.array(pc_true_z, dtype=float)
        pc_best_z = np.array(pc_best_z, dtype=float)
        pc_mean_post_z = np.array(pc_mean_post_z, dtype=float)

        # standardized PC1 according to mode
        pc_true_std = transform(pc_true_z)
        if POST_SUMMARY.lower() == "mean":
            pc_pred_z = pc_mean_post_z
            pred_label = "Posterior mean (PC1)"
        else:
            pc_pred_z = pc_best_z
            pred_label = "Posterior mode (PC1)"

        pc_pred_std = transform(pc_pred_z)

        rmse_pred = rmse(pc_pred_std, pc_true_std)
        r2_pred   = r2(pc_pred_std, pc_true_std)
        if len(pc_true_std) >= 2:
            pear_r, _    = pearsonr(pc_pred_std, pc_true_std)
            spear_rho, _ = spearmanr(pc_pred_std, pc_true_std)
            r2_cal = pear_r**2
        else:
            # correlations are undefined for a single condition
            pear_r = spear_rho = np.nan
            r2_cal = np.nan

        # sort by experimental PC1 (standardized)
        order = np.argsort(pc_true_std)
        conds_sorted = [conds[i] for i in order]
        pc_true_sorted = pc_true_std[order]
        pc_pred_sorted = pc_pred_std[order]

        # density image in sorted order
        y_edges = np.linspace(y_min, y_max, NBINS_Y + 1)
        img_sorted = np.zeros((NBINS_Y, len(conds_sorted)), dtype=float)
        for j, cond in enumerate(conds_sorted):
            pcs_post_z = post_pc_per_cond_z[cond]
            pcs_post_std = transform(pcs_post_z)
            pcs_post_std = np.clip(pcs_post_std, y_min, y_max)
            counts, _ = np.histogram(pcs_post_std, bins=y_edges)
            col_pdf = counts.astype(float)
            if col_pdf.sum() > 0:
                col_pdf /= col_pdf.sum()
            img_sorted[:, j] = col_pdf

        x_right = len(conds_sorted) - 0.5
        cmap = white_to_color_cmap(PC1_COLOR, steps=256)
        im = ax.imshow(
            img_sorted,
            aspect="auto",
            origin="lower",
            extent=[x_left, x_right, y_min, y_max],
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            interpolation="nearest",
        )
        last_im = im

        x = np.arange(len(conds_sorted))
        ax.plot(x, pc_true_sorted, color="black", lw=1.4, marker="o", markersize=3,
                label="Experiment (PC1)")
        ax.plot(x, pc_pred_sorted, color=PC1_COLOR, lw=1.2, marker="o", markersize=3,
                alpha=0.9, label=pred_label)

        tick_labels = [f"{mp}_{inten}" for (mp, inten) in conds_sorted]
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=90, ha="right", fontsize=7)

        ax.set_xlim(x_left, x_right)
        ax.set_ylabel("PC1 score\n(all 4×4 features)")

        txt = (f"Posterior {POST_SUMMARY} vs experiment (PC1)\n"
               f"({std_label}):\n"
               f"  RMSE        = {rmse_pred:.3f}\n"
               f"  R²          = {r2_pred:.3f}\n"
               f"  R² (linear) = {r2_cal:.3f}\n"
               f"  Pearson r   = {pear_r:.3f}\n"
               f"  Spearman ρ  = {spear_rho:.3f}")
        ax.text(0.01, 0.99, txt, transform=ax.transAxes,
                va="top", ha="left", fontsize=8, color="black",
                bbox=dict(facecolor="white", alpha=0.75, edgecolor="none"))

# Only the bottom row carries the x-axis label
bottom_ax = axes[-1, 0]
bottom_ax.set_xlabel("Conditions (muscle_pair_intensity, sorted per row)")

# Lay out the rows inside the left 90% of the figure; the right strip holds the colorbar
layout_rect = [0, 0, 0.9, 1.0]
plt.tight_layout(rect=layout_rect)

# One colorbar for all rows, built from the last image that was drawn (if any)
if last_im is not None:
    colorbar_rect = [0.92, 0.15, 0.02, 0.7]
    cax = fig.add_axes(colorbar_rect)
    cbar = fig.colorbar(last_im, cax=cax)
    cbar.set_label("Posterior column PDF (density)")

plt.show()


# Visual inspection of the experimental vs. simulated synchronization cross-histograms
### + calculation of distances to experimental observations (in feature space)

In [None]:
# For each observed metric:
# Mean +- std of experimental observations as vertical line + shaded area
# Histogram of distribution of means from simulated samples from posterior (posterior prediction)
# Histogram of distribution of means from simulated samples from prior (random, null predictions)
# Mean +- std of simulated observation from highest likelihood sample
distance_and_diff_dict = {}

os.makedirs(path_to_save_into, exist_ok=True)

for subject_i in subjects:
    if len(subjects)==1: # if only one subject (= atually not iterating over subjects)
        filtering_subjects = False
        subject_title = "pooled"
    else:
        filtering_subjects = True
        subject_title = subject_i
    for muscle_pair_i in muscle_pairs:
        for intensity_i in intensities:
            dict_storage_current_key = f"{subject_title}_{muscle_pair_i}_{intensity_i}"
            distance_and_diff_dict[dict_storage_current_key] = {}
            # Filter each data frame to get only the corresponding values #################
            temp_df_exp_obs = df_summary_obs_experiment.copy()
            if filtering_subjects and ('subject' in temp_df_exp_obs.columns):
                temp_df_exp_obs = temp_df_exp_obs[temp_df_exp_obs['subject']==subject_i]
            if 'muscle_pair' in temp_df_exp_obs.columns:
                temp_df_exp_obs = temp_df_exp_obs[temp_df_exp_obs['muscle_pair']==muscle_pair_i]
            if 'intensity' in temp_df_exp_obs.columns:
                temp_df_exp_obs = temp_df_exp_obs[temp_df_exp_obs['intensity']==intensity_i]
            #
            temp_df_exp_other_obs = df_obs_experiment.copy()
            if filtering_subjects and ('subject' in temp_df_exp_other_obs.columns):
                temp_df_exp_other_obs = temp_df_exp_other_obs[temp_df_exp_other_obs['subject']!=subject_i]
            if 'muscle_pair' in temp_df_exp_other_obs.columns:
                temp_df_exp_other_obs = temp_df_exp_other_obs[temp_df_exp_other_obs['muscle_pair']!=muscle_pair_i]
            if 'intensity' in temp_df_exp_other_obs.columns:
                temp_df_exp_other_obs = temp_df_exp_other_obs[temp_df_exp_other_obs['intensity']!=intensity_i]
            # Get the data frame for each different experimental condition
            temp_df_exp_obs_each_separately = {}
            for subject_j_iter, subject_j in enumerate(np.unique(df_obs_experiment['subject'].values)):
                if (not filtering_subjects):
                    subject_j_title = "pooled"
                    if (subject_j_iter > 0):
                        break
                else:
                    subject_j_title = subject_j
                for muscle_pair_j in np.unique(df_obs_experiment['muscle_pair'].values):
                    for intensity_j in np.unique(df_obs_experiment['intensity'].values):
                        temp_df_other_exp_cond = df_obs_experiment.copy()
                        if filtering_subjects:
                            temp_df_other_exp_cond = temp_df_other_exp_cond[
                                temp_df_other_exp_cond['subject']==subject_j]
                        #
                        temp_df_other_exp_cond = temp_df_other_exp_cond[
                            temp_df_other_exp_cond['muscle_pair']==muscle_pair_j]
                        temp_df_other_exp_cond = temp_df_other_exp_cond[
                            temp_df_other_exp_cond['intensity']==intensity_j]    
                        #                
                        temp_df_exp_obs_each_separately[
                            f"{subject_j_title}_{muscle_pair_j}_{intensity_j}"] = temp_df_other_exp_cond
            #
            temp_df_sim_obs_posterior_highest_likelihood = df_obs_simulated_from_posterior_sample_highest_likelihood.copy()
            if filtering_subjects:
                temp_df_sim_obs_posterior_highest_likelihood = temp_df_sim_obs_posterior_highest_likelihood[temp_df_sim_obs_posterior_highest_likelihood['subject']==subject_i]
            if 'muscle_pair' in temp_df_sim_obs_posterior_highest_likelihood.columns:
                temp_df_sim_obs_posterior_highest_likelihood = temp_df_sim_obs_posterior_highest_likelihood[temp_df_sim_obs_posterior_highest_likelihood['muscle_pair']==muscle_pair_i]
            if 'intensity' in temp_df_sim_obs_posterior_highest_likelihood.columns:
                temp_df_sim_obs_posterior_highest_likelihood = temp_df_sim_obs_posterior_highest_likelihood[temp_df_sim_obs_posterior_highest_likelihood['intensity']==intensity_i]
            #
            temp_df_sim_obs_posterior = df_summary_obs_simulated_from_posterior.copy()
            if filtering_subjects:
                temp_df_sim_obs_posterior = temp_df_sim_obs_posterior[temp_df_sim_obs_posterior['subject']==subject_i]
            if 'muscle_pair' in temp_df_sim_obs_posterior.columns:
                temp_df_sim_obs_posterior = temp_df_sim_obs_posterior[temp_df_sim_obs_posterior['muscle_pair']==muscle_pair_i]
            if 'intensity' in temp_df_sim_obs_posterior.columns:
                temp_df_sim_obs_posterior = temp_df_sim_obs_posterior[temp_df_sim_obs_posterior['intensity']==intensity_i]
            #
            temp_df_sim_obs_prior = df_summary_obs_simulated_from_prior.copy()
            if filtering_subjects:
                temp_df_sim_obs_prior = temp_df_sim_obs_prior[temp_df_sim_obs_prior['subject']==subject_i]
            if 'muscle_pair' in temp_df_sim_obs_prior.columns:
                temp_df_sim_obs_prior = temp_df_sim_obs_prior[temp_df_sim_obs_prior['muscle_pair']==muscle_pair_i]
            if 'intensity' in temp_df_sim_obs_prior.columns:
                temp_df_sim_obs_prior = temp_df_sim_obs_prior[temp_df_sim_obs_prior['intensity']==intensity_i]
            
            # Compute euclidean distance in observation space ##########################
            distance_and_diff_dict[dict_storage_current_key]['distance'] = {}
            exp_obs_vector = np.zeros(len(observations))
            sim_obs_posterior_highest_likelihood_vector = np.zeros(len(observations))
            sim_obs_posterior_mat = np.zeros((temp_df_sim_obs_posterior.shape[0], len(observations)))
            sim_obs_prior_mat = np.zeros((temp_df_sim_obs_prior.shape[0], len(observations)))
            exp_other_obs_vector = np.zeros((temp_df_exp_other_obs.shape[0], len(observations)))
            for i, obs_metric in enumerate(observations):
                exp_obs_vector[i] = temp_df_exp_obs[f"{obs_metric}_mean"].iloc[0]
                exp_other_obs_vector[:,i] = temp_df_exp_other_obs[f"{obs_metric}"]
                sim_obs_posterior_highest_likelihood_vector[i] = temp_df_sim_obs_posterior_highest_likelihood[f"{obs_metric}"].mean()
                sim_obs_posterior_mat[:,i] = temp_df_sim_obs_posterior[f"{obs_metric}_mean"]
                sim_obs_prior_mat[:,i] = temp_df_sim_obs_prior[f"{obs_metric}_mean"]
            # Euclidean distances, vectorized:
            euc_dist_posterior = np.linalg.norm(
                sim_obs_posterior_mat - exp_obs_vector[np.newaxis, :],
                axis=1)
            distance_and_diff_dict[dict_storage_current_key][
                'distance']['posterior'] = euc_dist_posterior
            euc_dist_posterior_highest_likelihood = np.linalg.norm(
                sim_obs_posterior_highest_likelihood_vector[np.newaxis, :] - exp_obs_vector[np.newaxis, :],
                axis=1)
            distance_and_diff_dict[dict_storage_current_key][
                'distance']['posterior_highest_likelihood'] = euc_dist_posterior_highest_likelihood
            euc_dist_prior = np.linalg.norm(
                sim_obs_prior_mat - exp_obs_vector[np.newaxis, :],
                axis=1)
            distance_and_diff_dict[dict_storage_current_key][
                'distance']['prior'] = euc_dist_prior
            euc_dist_other_experimental_conditions = np.linalg.norm(
                exp_other_obs_vector - exp_obs_vector[np.newaxis, :],
                axis=1)
            euc_dist_other_experimental_conditions = euc_dist_other_experimental_conditions[ # ensure no NaN values
                np.isfinite(euc_dist_other_experimental_conditions)]
            distance_and_diff_dict[dict_storage_current_key][
                'distance']['other_exp_pooled'] = euc_dist_other_experimental_conditions
            # Same, but for each experimental condition separately
            distance_and_diff_dict[dict_storage_current_key][
                'distance']['other_exp_separately'] = {}
            for exp_cond_i_idx, exp_cond_i_vals in temp_df_exp_obs_each_separately.items():
                temp_mat = np.zeros((exp_cond_i_vals.shape[0], len(observations)))
                for i, obs_metric in enumerate(observations):
                    temp_mat[:,i] = exp_cond_i_vals[f"{obs_metric}"]
                temp_euc_dist = np.linalg.norm(
                    temp_mat - exp_obs_vector[np.newaxis, :],
                    axis=1)
                distance_and_diff_dict[dict_storage_current_key][
                    'distance']['other_exp_separately'][exp_cond_i_idx] = temp_euc_dist

            # Start plotting ############################
            N = len(observations)  # =4
            n_grid_rows = math.floor(math.sqrt(N))  # =2
            n_grid_cols = math.ceil (math.sqrt(N))  # =2
            total_rows  = 1 + n_grid_rows           # =3

            fig = plt.figure(figsize=(16, 20))
            gs  = fig.add_gridspec(total_rows, n_grid_cols,
                                height_ratios=[1] + [1]*n_grid_rows,
                                hspace=0.4, wspace=0.3)
            
            # 1) main axes spanning the full width - euclidean distance in "observation/metric space"
            ax_main = fig.add_subplot(gs[0, :])
            # Define a common x‐axis grid spanning the full range
            all_data = np.concatenate([euc_dist_posterior, euc_dist_prior, euc_dist_other_experimental_conditions])
            x_min, x_max = 0, all_data.max() / 2 # all_data.min(), all_data.max() / 2 # dividing the max by two seems a good cutoff
            x_grid = np.linspace(x_min, x_max, 200)
            # Fit a KDE to each (automatic bandwidth)
            if euc_dist_posterior.shape[0] >= 1:
                kde_euc_dist_posterior = gaussian_kde(euc_dist_posterior)
                density_to_plot_posterior = kde_euc_dist_posterior(x_grid)
            else:
                density_to_plot_posterior = np.full((len(x_grid)),np.nan)
            if euc_dist_prior.shape[0] >= 1:
                kde_euc_dist_prior = gaussian_kde(euc_dist_prior)
                density_to_plot_prior = kde_euc_dist_prior(x_grid)
            else:
                density_to_plot_prior = np.full((len(x_grid)),np.nan)
            if euc_dist_other_experimental_conditions.shape[0] >= 1:
                kde_euc_dist_other_experimental_conditions = gaussian_kde(euc_dist_other_experimental_conditions)
                density_to_plot_other_exp = kde_euc_dist_other_experimental_conditions(x_grid)
            else:
                density_to_plot_other_exp = np.full((len(x_grid)),np.nan)
            # Actually plot
            ax_main.axvline(euc_dist_posterior_highest_likelihood[0], linewidth=2, color='blue', alpha=0.7, zorder=10,
                        label=f"Sim with highest likelihood\nDist = {euc_dist_posterior_highest_likelihood[0]:.2f}\n ")
            ax_main.plot(x_grid, density_to_plot_posterior,
                        color="#00A2FFFF", alpha=1, zorder=1, linewidth=3,
                        label=f"Distribution of distances to experimental data for\ndata simulated from posterior (= from inferred params)\nMean dist = {np.mean(euc_dist_posterior):.2f}±{np.std(euc_dist_posterior):.2f}\n ")
            ax_main.plot(x_grid, density_to_plot_prior,
                        color="#AA69FF", alpha=1, zorder=0, linewidth=3,
                        label=f"Distribution of distances to experimental data for\ndata simulated from prior (= random params, null model)\nMean dist = {np.mean(euc_dist_prior):.2f}±{np.std(euc_dist_prior):.2f}\n ")
            ax_main.plot(x_grid, density_to_plot_other_exp,
                        color="#FFC400", alpha=1, zorder=-1, linewidth=3,
                        label=f"Distribution of distances to experimental data for\nexperimental data from other conditions (each MN)\nMean dist = {np.nanmean(euc_dist_other_experimental_conditions):.2f}±{np.nanstd(euc_dist_other_experimental_conditions):.2f}\n ")
            #
            ax_main.legend(fontsize="small")
            ax_main.set_ylabel("Density")
            ax_main.set_xlabel("Euclidean distance in the (standardized) observation/metric space")
            ax_main.set_title(f"Subject: {subject_title}, Muscle pair: {muscle_pair_i}, Intensity: {intensity_i}%\n \nEuclidean distance in observation-space")

            # 2) detail axes in the grid below - Difference for each metric/observation
            detail_axes = []
            distance_and_diff_dict[dict_storage_current_key]['difference'] = {
                'posterior': {},
                'posterior_highest_likelihood': {},
                'prior': {},
                'other_exp_pooled': {},
                'other_exp_separately': {}
            }
            for i in range(N):
                # Create plot
                # row = 1 or 2, col = 0 or 1
                row = 1 + i // n_grid_cols
                col = i %  n_grid_cols
                ax = fig.add_subplot(gs[row, col])
                # Data to display
                obs_exp_mean = temp_df_exp_obs[f"{observations[i]}_mean"].iloc[0]
                obs_exp_std = temp_df_exp_obs[f"{observations[i]}_std"].iloc[0]
                obs_sim_posterior_highest_likelihood_mean = temp_df_sim_obs_posterior_highest_likelihood[f"{observations[i]}"].mean()
                obs_sim_posterior_highest_likelihood_std = temp_df_sim_obs_posterior_highest_likelihood[f"{observations[i]}"].std()
                abs_dist_posterior = np.abs(obs_exp_mean - temp_df_sim_obs_posterior[f"{observations[i]}_mean"].values)
                distance_and_diff_dict[dict_storage_current_key][
                    'difference']['posterior'][observations[i]] = abs_dist_posterior
                abs_dist_posterior_max_likelihood = np.abs(obs_exp_mean - temp_df_sim_obs_posterior_highest_likelihood[f"{observations[i]}"].values)
                distance_and_diff_dict[dict_storage_current_key][
                    'difference']['posterior_highest_likelihood'][observations[i]] = abs_dist_posterior_max_likelihood
                abs_dist_prior = np.abs(obs_exp_mean - temp_df_sim_obs_prior[f"{observations[i]}_mean"].values)
                distance_and_diff_dict[dict_storage_current_key][
                    'difference']['prior'][observations[i]] = abs_dist_prior
                abs_dist_other_exp_obs_pooled = np.abs(obs_exp_mean - temp_df_exp_other_obs[f"{observations[i]}"].values)
                distance_and_diff_dict[dict_storage_current_key][
                    'difference']['other_exp_pooled'][observations[i]] = abs_dist_other_exp_obs_pooled
                # Fill dict with difference for other experimental observations, separately
                for other_cond_i_idx, other_cond_i_df in temp_df_exp_obs_each_separately.items():
                    if other_cond_i_idx not in distance_and_diff_dict[dict_storage_current_key]['difference']['other_exp_separately'].keys():
                        distance_and_diff_dict[dict_storage_current_key]['difference']['other_exp_separately'][other_cond_i_idx] = {}
                    distance_and_diff_dict[dict_storage_current_key]['difference'][
                        'other_exp_separately'][other_cond_i_idx][observations[i]] = np.abs(
                        obs_exp_mean - other_cond_i_df[f"{observations[i]}"].values)
                # exp data
                ax.axvline(obs_exp_mean, linewidth=2, color='red', alpha=0.7, zorder=10,
                           label=f"Experimental observations (each MN)\nMean = {obs_exp_mean:.2f}±{obs_exp_std:.2f}\n ")
                ax.axvline(obs_exp_mean-obs_exp_std, linewidth=2, color='red', linestyle='--', alpha=0.3, zorder=10)
                ax.axvline(obs_exp_mean+obs_exp_std, linewidth=2, color='red', linestyle='--', alpha=0.3, zorder=10)
                ax.axvspan(xmin=obs_exp_mean-obs_exp_std, xmax=obs_exp_mean+obs_exp_std, color="red", alpha=0.1, zorder=-10)
                # sim data highest likelihood posterior
                ax.axvline(obs_sim_posterior_highest_likelihood_mean, linewidth=2, color='blue', alpha=0.7, zorder=10,
                           label=f"Simulated obs from highest-likelihood\nparameters from posterior (each MN)\nMean = {obs_sim_posterior_highest_likelihood_mean:.2f}±{obs_sim_posterior_highest_likelihood_std:.2f}\n ")
                ax.axvline(obs_sim_posterior_highest_likelihood_mean-obs_sim_posterior_highest_likelihood_std, linewidth=2, color='blue', linestyle='--', alpha=0.3, zorder=10)
                ax.axvline(obs_sim_posterior_highest_likelihood_mean+obs_sim_posterior_highest_likelihood_std, linewidth=2, color='blue', linestyle='--', alpha=0.3, zorder=10)
                ax.axvspan(xmin=obs_sim_posterior_highest_likelihood_mean-obs_sim_posterior_highest_likelihood_std, xmax=obs_sim_posterior_highest_likelihood_mean+obs_sim_posterior_highest_likelihood_std, color="blue", alpha=0.1, zorder=-10)
                # Compute the KDEs to display
                # Define a common x‐axis grid spanning the full range
                temp_obs_posterior = temp_df_sim_obs_posterior[f"{observations[i]}_mean"].values
                temp_obs_posterior = temp_obs_posterior[np.isfinite(temp_obs_posterior)]
                temp_obs_prior = temp_df_sim_obs_prior[f"{observations[i]}_mean"].values
                temp_obs_prior = temp_obs_prior[np.isfinite(temp_obs_prior)]
                temp_obs_others = temp_df_exp_other_obs[f"{observations[i]}"].values
                temp_obs_others = temp_obs_others[np.isfinite(temp_obs_others)]
                all_data = np.concatenate([temp_obs_posterior,
                                           temp_obs_prior,
                                           temp_obs_others])
                x_min, x_max = all_data.min(), all_data.max()
                x_grid = np.linspace(x_min, x_max, 200)
                # Fit a KDE to each (automatic bandwidth).
                # gaussian_kde raises a ValueError when given fewer than two samples, so each
                # group needs at least two finite values; otherwise an all-NaN curve is used
                # (same fallback as for an empty group).
                if len(temp_obs_posterior) >= 2:
                    kde_sim_posterior = gaussian_kde(temp_obs_posterior)
                    density_to_plot_posterior = kde_sim_posterior(x_grid)
                else:
                    density_to_plot_posterior = np.full((len(x_grid)),np.nan)
                if len(temp_obs_prior) >= 2:
                    kde_sim_prior = gaussian_kde(temp_obs_prior)
                    density_to_plot_prior = kde_sim_prior(x_grid)
                else:
                    density_to_plot_prior = np.full((len(x_grid)),np.nan)
                if len(temp_obs_others) >= 2:
                    kde_other_experimental_conditions = gaussian_kde(temp_obs_others)
                    density_to_plot_other_exp = kde_other_experimental_conditions(x_grid)
                else:
                    density_to_plot_other_exp = np.full((len(x_grid)),np.nan)
                # sim data posterior
                ax.plot(x_grid, density_to_plot_posterior,
                        color="#00A2FFFF", alpha=1, zorder=1, linewidth=3,
                        label=f"Distribution of means for\ndata simulated from posterior\n(= from inferred params)\nMean diff = {np.mean(abs_dist_posterior):.2f}±{np.std(abs_dist_posterior):.2f}\n ")
                # sim data prior
                ax.plot(x_grid, density_to_plot_prior,
                        color="#AA69FF", alpha=1, zorder = 0, linewidth=3,
                        label=f"Distribution of means for\ndata simulated from prior\n(= random params, null model)\nMean diff = {np.mean(abs_dist_prior):.2f}±{np.std(abs_dist_prior):.2f}\n ")
                # distance to experimental data from other conditions
                ax.plot(x_grid, density_to_plot_other_exp,
                        color="#FFC400", alpha=1, zorder = -1, linewidth=3,
                        label=f"Distribution experimental data\nfrom other conditions (each MN)\nMean diff = {np.nanmean(abs_dist_other_exp_obs_pooled):.2f}±{np.nanstd(abs_dist_other_exp_obs_pooled):.2f}\n ")
                #
                ax.set_ylabel("Density")
                ax.set_xlabel(f"Standardized {observations[i]}\n(pooled experimental observations are transformed\nto have mean = 0, std = 1)")
                ax.legend(fontsize="small")
                ax.set_title(f"{observations[i]}")
                detail_axes.append(ax)

            plt.savefig(f"{path_to_save_into}\\{re.sub(r'[^0-9A-Za-z._-]+', '', str(muscle_pair_i))}_{intensity_i}_{subject_title}.png")
            plt.savefig(f"{path_to_save_into}\\{re.sub(r'[^0-9A-Za-z._-]+', '', str(muscle_pair_i))}_{intensity_i}_{subject_title}.svg")      
            plt.close()
            # plt.show()
            # plt.close()

            ### Display the direct histogram comparisons
            def sanitize(x):
                """Turn a loop value (muscle pair, intensity, ...) into the text used in lookup keys.

                Anything `float()` accepts is rendered as a number, without the decimal part
                when it is integral (10.0 -> "10", 12.5 -> "12.5"). Everything else is
                stringified, '<->' becomes '-', and every character other than letters,
                digits, '-' and '_' is removed.
                """
                try:
                    numeric_value = float(x)
                except (TypeError, ValueError):
                    # Not a number: clean up the text form instead
                    hyphenated = re.sub(r'<->', '-', str(x))
                    return re.sub(r'[^A-Za-z0-9\-_]', '', hyphenated)
                return str(int(numeric_value)) if numeric_value.is_integer() else str(numeric_value)
            #
            if not filtering_subjects:
                subject_str = ''
            else:
                subject_str = f"{subject_i}_"
            # Map the analysis perspective onto the key used to index the cross-histogram dicts below.
            if perspective == 'other_MUs_as_ref':
                perspective_name_for_cross_hist = 'inhibited'
            elif perspective == 'MU_as_ref':
                perspective_name_for_cross_hist = 'inhibiting'
            else:
                # Without this branch the name would either be undefined (NameError further down)
                # or silently keep the value assigned in an earlier iteration / cell run.
                raise ValueError(
                    f"Unknown perspective {perspective!r}: expected 'other_MUs_as_ref' or 'MU_as_ref'")

            ### START THE FIGURE
            plt.figure(figsize=(6,12))
            ## SUBPLOT 1) Experimentally observed cross-histograms
            plt.subplot(2,1,1)
            # Loop through exp_analysis_results_dict and find matching
            temp_list_of_exp_MNs_cross_hist = []
            for exp_analysis_key, exp_analysis_result in exp_analysis_results_dict.items():
                if 'Cross_histograms' not in exp_analysis_result.keys():
                    continue
                # Check if correct subject - only when filtering_subjects is True
                if filtering_subjects:
                    if subject_i not in exp_analysis_key:
                        continue
                # Check if correct intensity
                if str(np.round(intensity_i).astype(int)) not in exp_analysis_key:
                    continue
                # Check if correct muscle pair
                if muscle_pair_i not in exp_analysis_result['Cross_histograms']['cross_histograms']:
                    continue
                # Now load all relevant MNs in the list
                for exp_mn_idx, exp_mn_cross_hist in exp_analysis_result['Cross_histograms']['cross_histograms'][muscle_pair_i].items():
                    # Check if the histogram match the filtering criterion
                    n_spikes_temp = exp_analysis_result['Cross_histograms'][muscle_pair_i][exp_mn_idx][perspective_name_for_cross_hist]['n_spikes']
                    r2_full_temp = exp_analysis_result['Cross_histograms'][muscle_pair_i][exp_mn_idx][perspective_name_for_cross_hist]['r2_full']
                    r2_base_temp = exp_analysis_result['Cross_histograms'][muscle_pair_i][exp_mn_idx][perspective_name_for_cross_hist]['r2_base']
                    if (n_spikes_temp < min_nb_spikes["experiment"]) or (r2_full_temp < min_r2_for_overall_curve_fit["experiment"]) or (r2_base_temp < min_r2_for_baseline_curve_fit["experiment"]):
                        continue
                    temp_list_of_exp_MNs_cross_hist.append(exp_mn_cross_hist[perspective_name_for_cross_hist])
            if len(temp_list_of_exp_MNs_cross_hist) > 0:
                hist_samples_length = len(temp_list_of_exp_MNs_cross_hist[0])
                cross_hist_time = np.linspace(start=-200, stop=200, num=hist_samples_length)
                for mn_i in range(len(temp_list_of_exp_MNs_cross_hist)):
                        plt.plot(cross_hist_time, temp_list_of_exp_MNs_cross_hist[mn_i],
                            color=colors_dict[muscle_pair_i], alpha=0.05, linewidth=1)
                            # alpha=4/len(temp_list_of_exp_MNs_cross_hist), linewidth=1.5)
                # plot mean histogram over all MNs
                mean_exp_cross_hist = np.nanmean(np.array(temp_list_of_exp_MNs_cross_hist), axis=0)
                # plt.plot(cross_hist_time, mean_exp_cross_hist,
                #             color=colors_dict[muscle_pair_i], alpha=1, linewidth=3)
            plt.title(f"Experimental data\n(n = {len(temp_list_of_exp_MNs_cross_hist)} MNs)")
            plt.xlabel("Time (ms)")
            plt.ylabel("Firing probability")
            exp_subplot_ymax = plt.ylim()[1]
            ## SUBPLOT 2) Simulated cross-histograms from posterior estimated parameters
            plt.subplot(2,1,2)
            sim_subplot_ymax = 0
            key_to_load_best_sim = f"{subject_str}{sanitize(muscle_pair_i)}_{sanitize(intensity_i)}_simHighestLikelihood"
            #       # Continue further only if the key exists in the "best_sim" to load
            left, right = map(str.strip, muscle_pair_i.split('<->', 1))
            pool_pair_key = 'pool_0<->pool_0' if left == right else 'pool_0<->pool_1'
            if key_to_load_best_sim in best_sims.keys():
                best_sim_histograms = best_sims[key_to_load_best_sim]['Cross_histograms']['cross_histograms'][pool_pair_key]
                hist_samples_length = len(best_sim_histograms[0][perspective_name_for_cross_hist])
                cross_hist_time = np.linspace(start=-200, stop=200, num=hist_samples_length)
                sum_hist_for_mean = np.zeros_like(cross_hist_time)
                num_valid_mns = 0
                for mn_i in best_sim_histograms.keys():
                    # Only plot valid MNs (fit R² and spike nb meet the conditions)
                    n_spikes_temp = best_sims[key_to_load_best_sim]['Cross_histograms'][pool_pair_key][mn_i][perspective_name_for_cross_hist]['n_spikes']
                    r2_full_temp = best_sims[key_to_load_best_sim]['Cross_histograms'][pool_pair_key][mn_i][perspective_name_for_cross_hist]['r2_full']
                    r2_base_temp = best_sims[key_to_load_best_sim]['Cross_histograms'][pool_pair_key][mn_i][perspective_name_for_cross_hist]['r2_base']
                    if (n_spikes_temp < min_nb_spikes["simulation"]) or (r2_full_temp < min_r2_for_overall_curve_fit["simulation"]) or (r2_base_temp < min_r2_for_baseline_curve_fit["simulation"]):
                        continue
                    sum_hist_for_mean += best_sim_histograms[mn_i][perspective_name_for_cross_hist]
                    num_valid_mns += 1
                    # plot each MN's histogram
                    plt.plot(cross_hist_time, best_sim_histograms[mn_i][perspective_name_for_cross_hist],
                            color=colors_dict[muscle_pair_i], alpha=0.1, linewidth=1) # alpha=3/len(best_sim_histograms), linewidth=2)
                sim_subplot_ymax = plt.ylim()[1]
                # plot mean histogram over all MNs
                if num_valid_mns >= 1:
                    mean_exp_cross_hist = sum_hist_for_mean / num_valid_mns
                    # plt.plot(cross_hist_time, mean_exp_cross_hist,
                    #             color=colors_dict[muscle_pair_i], alpha=1, linewidth=3)
            plt.title(f"Simulation from posterior\n(inferred parameters /w highest likelihood)")
            plt.xlabel("Time (ms)")
            plt.ylabel("Firing probability")
            # Get same Y limits for both subplots
            ylim_to_select = np.max([exp_subplot_ymax,sim_subplot_ymax])
            plt.subplot(2,1,1)
            # plt.axvline(x=0, ymin=0, ymax=1, color='k', linestyle='--', linewidth=1.5, alpha=0.2, zorder=-10)
            plt.ylim(-0.0002, ylim_to_select)
            plt.subplot(2,1,2)
            # plt.axvline(x=0, ymin=0, ymax=1, color='k', linestyle='--', linewidth=1.5, alpha=0.2, zorder=-10)
            plt.ylim(-0.0002, ylim_to_select)
            # 
            plt.suptitle(f"Subject: {subject_title}, Muscle pair: {muscle_pair_i}, Intensity: {intensity_i}%\nCross-histogram comparisons")
            # plt.tight_layout()
            plt.savefig(f"{path_to_save_into}\\{re.sub(r'[^0-9A-Za-z._-]+', '', str(muscle_pair_i))}_{intensity_i}_{subject_title}_cross_histograms.png")
            plt.savefig(f"{path_to_save_into}\\{re.sub(r'[^0-9A-Za-z._-]+', '', str(muscle_pair_i))}_{intensity_i}_{subject_title}_cross_histograms.svg")
            plt.show()

            

# Example display of simulated data in 2D feature-subspace
(part of Fig. 4C in the paper)

In [None]:
def pairplot_grid_marginals(
    df, xcol, ycol,
    *,
    # data scaling / normalization (used for binning/KDE + default axis extents)
    x_range=None, y_range=None,            # (min, max) or None
    # explicit display bounds (override axis limits on all three panels)
    xlim=None, ylim=None,                  # (min, max) or None
    # title & colors
    main_title=None,
    x_color="#FF9100",
    y_color="#0077FF",
    base_color="#808080",
    # per-axis color scales (list of >=2 colors). If None → [base_color, x_color]/[base_color, y_color]
    colors_for_x_axis_colorscale=None,
    colors_for_y_axis_colorscale=None,
    # point color blending of per-axis colors
    blend_space="rgb",                     # "rgb" or "lab"
    point_size=18,
    alpha_pts=0.35,
    # marginals control
    marginal_mode="hist",                  # "hist" or "kde"
    bins=60,                               # used in hist mode
    n_grid=400,                            # used in *marginal* kde mode
    kde_boundary="reflect",                # "reflect" or "none"
    hist_max_count=None,                   # cap for BOTH marginal panels (ignored in KDE mode)
    # overlays (legacy)
    overlay_points=None,                   # dicts with x,y,color,label
    overlay_vlines=None,                   # list[(x, color)] on BR (X marginal)
    overlay_hlines=None,                   # list[(y, color)] on TL (Y marginal)
    # global scatter KDE overlay (legacy)
    scatter_kde=False,
    scatter_kde_levels=(0.2, 0.5, 0.8),
    scatter_kde_bw=None,
    scatter_kde_grid=200,
    scatter_kde_color="#222222",
    scatter_kde_lw=1.5,
    scatter_kde_alpha=0.9,
    scatter_kde_max_points=8000,
    scatter_kde_level_alphas=None,
    scatter_kde_alpha_min=0.3,
    scatter_kde_alpha_max=0.95,

    # ---- NEW: filtered per-condition mode ----
    dataset_mode="full_dataset",           # "full_dataset" | "experimental_condition_filter"
    condition_filters=None,                # dict[str -> list]; matched across keys, same length lists
    condition_colors=None,                 # list of colors, same len as lists in condition_filters
    condition_colors_scatter_kde=None,
    group_scatter_kde=False,               # draw per-group 2D KDE (only if experimental_condition_filter)
    group_kde_levels=(0.25, 0.5, 0.75),
    group_kde_alpha=0.9,
    group_kde_lw=2.0,

    # ---- NEW: overlay summary points from another DF (works in both modes) ----
    overlay_summary_df=None,               # DataFrame containing summary xcol/ycol
    overlay_summary_conditions=None,       # dict[str -> list] same shaping as condition_filters
    overlay_summary_colors=None,           # list of colors (same length)
    overlay_summary_marker="X",
    overlay_summary_size=90,
    overlay_summary_edge="k",
    overlay_summary_lw=1.0,

    # figure / IO
    figsize=(8.4, 8.4),
    savepath=None,
    csv_prefix=None
):
    import numpy as np, pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    # ---------- helpers ----------
    def _hex_to_rgb01(h: str):
        h = h.strip()
        if h.startswith("#"): h = h[1:]
        if len(h)==3: h = "".join([c*2 for c in h])
        return (int(h[0:2],16)/255.0, int(h[2:4],16)/255.0, int(h[4:6],16)/255.0)

    def _to_rgb01(c):
        if isinstance(c, str): return _hex_to_rgb01(c)
        if isinstance(c, (list, tuple)) and len(c)==3: return tuple(float(v) for v in c)
        raise ValueError("Colors must be hex strings like '#AABBCC' or RGB tuples in 0..1")

    def _interp_color_list(color_list, t):
        cols = np.array([_to_rgb01(c) for c in color_list], dtype=float)
        N = len(cols)
        if N < 2: raise ValueError("Provide at least two colors per axis colormap.")
        t = float(np.clip(t, 0.0, 1.0))
        if t <= 0:  return tuple(cols[0])
        if t >= 1:  return tuple(cols[-1])
        pos = t * (N - 1); i0 = int(np.floor(pos)); i1 = i0 + 1; frac = pos - i0
        return tuple((1-frac)*cols[i0] + frac*cols[i1])

    def _norm_to_unit(vals, vmin=None, vmax=None):
        vals = np.asarray(vals, float)
        lo = np.nanmin(vals) if vmin is None else float(vmin)
        hi = np.nanmax(vals) if vmax is None else float(vmax)
        if hi <= lo: return np.zeros_like(vals), (lo, hi)
        return ((vals - lo) / (hi - lo)), (lo, hi)

    def _gaussian_kde_1d(x, grid, bw=None):
        x = np.asarray(x, float); x = x[np.isfinite(x)]
        n = x.size
        if n < 2: return np.zeros_like(grid, dtype=float)
        std = np.std(x, ddof=1)
        if std <= 0:
            mu = float(x[0]); s = (grid.max() - grid.min()) * 1e-3 or 1e-6
            return np.exp(-0.5*((grid-mu)/s)**2)/(s*np.sqrt(2*np.pi))
        if bw is None: bw = std * n**(-1/5)  # Scott
        z = (grid[:,None] - x[None,:]) / bw
        dens = np.exp(-0.5*z*z).sum(axis=1) / (n * bw * np.sqrt(2*np.pi))
        return dens

    def _kde_bounded_reflect(x, grid, lo, hi, bw=None):
        x = np.asarray(x, float); x = x[np.isfinite(x)]
        if x.size == 0: return np.zeros_like(grid)
        x_aug = np.concatenate([x, 2*lo - x, 2*hi - x])  # reflect at bounds
        dens = _gaussian_kde_1d(x_aug, grid, bw=bw)
        area = np.trapz(dens, grid)
        if area > 0: dens = dens / area
        return dens

    def _save_pairplot_grid_csvs(basepath_no_ext, *, mode, x_label, y_label,
                                 x=None, y=None,
                                 x_grid=None, x_kde=None,
                                 y_grid=None, y_kde=None,
                                 x_hist=None, y_hist=None,
                                 overlay_points=None):
        if x is not None and y is not None:
            pd.DataFrame({x_label: x, y_label: y}).to_csv(basepath_no_ext + "_scatter.csv", index=False)
        if mode == "kde":
            if x_grid is not None and x_kde is not None:
                pd.DataFrame({"x": x_grid, "density": x_kde}).to_csv(basepath_no_ext + "_kde_x.csv", index=False)
            if y_grid is not None and y_kde is not None:
                pd.DataFrame({"y": y_grid, "density": y_kde}).to_csv(basepath_no_ext + "_kde_y.csv", index=False)
        elif mode == "hist":
            if x_hist is not None:
                counts, edges = x_hist
                pd.DataFrame({"bin_left": edges[:-1], "bin_right": edges[1:], "count": counts.astype(float)}).to_csv(
                    basepath_no_ext + "_hist_x.csv", index=False
                )
            if y_hist is not None:
                counts, edges = y_hist
                pd.DataFrame({"bin_left": edges[:-1], "bin_right": edges[1:], "count": counts.astype(float)}).to_csv(
                    basepath_no_ext + "_hist_y.csv", index=False
                )
        if overlay_points:
            pd.DataFrame([{
                x_label: float(p["x"]), y_label: float(p["y"]),
                "label": p.get("label",""), "color": p.get("color","")
            } for p in overlay_points]).to_csv(basepath_no_ext + "_overlay_points.csv", index=False)

    def _make_group_masks(df, cond_dict):
        """Return list of (mask, label) for each aligned condition across keys."""
        keys = list(cond_dict.keys())
        lens = {k: len(cond_dict[k]) for k in keys}
        if len(set(lens.values())) != 1:
            raise ValueError("All lists in condition_filters/overlay_summary_conditions must have the same length.")
        L = next(iter(lens.values()))
        out = []
        for i in range(L):
            m = np.ones(len(df), dtype=bool)
            parts = []
            for k in keys:
                v = cond_dict[k][i]
                m &= (df[k] == v)
                parts.append(f"{k}={v}")
            label = ", ".join(parts)
            out.append((m, label))
        return out

    # ---------- data & ranges ----------
    x_all = np.asarray(df[xcol], float)
    y_all = np.asarray(df[ycol], float)

    # ranges for binning/KDE
    xn, (xmin_data, xmax_data) = _norm_to_unit(x_all, *(x_range or (None, None)))
    yn, (ymin_data, ymax_data) = _norm_to_unit(y_all, *(y_range or (None, None)))

    # visual bounds
    xmin_plot, xmax_plot = (xlim if xlim is not None else (xmin_data, xmax_data))
    ymin_plot, ymax_plot = (ylim if ylim is not None else (ymin_data, ymax_data))

    # ---------- per-axis color scales (for FULL mode point coloring) ----------
    if colors_for_x_axis_colorscale is None:
        colors_for_x_axis_colorscale = [base_color, x_color]
    if colors_for_y_axis_colorscale is None:
        colors_for_y_axis_colorscale = [base_color, y_color]

    xn_c = np.clip(np.nan_to_num(xn, nan=0.0), 0.0, 1.0)
    yn_c = np.clip(np.nan_to_num(yn, nan=0.0), 0.0, 1.0)
    cx = np.array([_interp_color_list(colors_for_x_axis_colorscale, t) for t in xn_c], dtype=float)
    cy = np.array([_interp_color_list(colors_for_y_axis_colorscale, t) for t in yn_c], dtype=float)

    # blend per-point colors for full dataset mode
    def _blend_rgb(cx_arr, cy_arr, xn_c, yn_c):
        wy = yn_c / np.clip(xn_c + yn_c, 1e-12, None)  # Y contribution
        wx = 1.0 - wy
        if blend_space.lower() == "lab":
            try:
                from skimage.color import rgb2lab, lab2rgb
                cx_lab = rgb2lab(np.clip(cx_arr.reshape(-1,1,3), 0, 1)).reshape(-1,3)
                cy_lab = rgb2lab(np.clip(cy_arr.reshape(-1,1,3), 0, 1)).reshape(-1,3)
                lab_mix = (wx[:,None]*cx_lab + wy[:,None]*cy_lab)
                rgb = np.clip(lab2rgb(lab_mix.reshape(-1,1,3)).reshape(-1,3), 0, 1)
                return rgb
            except Exception:
                print("[blend_space='lab'] falling back to RGB blend.")
        return np.clip(wx[:,None]*cx_arr + wy[:,None]*cy_arr, 0, 1)

    rgb_full = _blend_rgb(cx, cy, xn_c, yn_c)

    # ---------- layout ----------
    fig = plt.figure(figsize=figsize, constrained_layout=False)
    gs  = GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1],
                   wspace=0.07, hspace=0.07, figure=fig)
    ax_tl = fig.add_subplot(gs[0, 0])  # Y marginal
    ax_tr = fig.add_subplot(gs[0, 1])  # scatter
    ax_bl = fig.add_subplot(gs[1, 0])  # legend
    ax_br = fig.add_subplot(gs[1, 1])  # X marginal

    # ---------- plotting helpers ----------
    def _plot_hist(ax, data, lo, hi, color=None, label=None, vertical=False):
        counts, edges = np.histogram(data[np.isfinite(data)], bins=bins, range=(lo, hi))
        if vertical:
            # standard orientation (x = centers)
            centers = 0.5*(edges[:-1] + edges[1:])
            widths  = (edges[1:] - edges[:-1])
            if color is None:
                # gradient bars (full mode)
                t = np.clip((centers - lo) / (hi - lo + 1e-12), 0, 1)
                bar_colors = [ _interp_color_list(colors_for_y_axis_colorscale, ti) for ti in t ]
                ax.bar(centers, counts, width=widths, color=bar_colors, edgecolor="white", linewidth=1.5, alpha=0.8, align='center')
            else:
                # outline steps for groups
                ax.stairs(counts, edges, fill=False, color=color, linewidth=2.0, label=label)
        else:
            centers = 0.5*(edges[:-1] + edges[1:])
            widths  = (edges[1:] - edges[:-1])
            if color is None:
                t = np.clip((centers - lo) / (hi - lo + 1e-12), 0, 1)
                bar_colors = [ _interp_color_list(colors_for_x_axis_colorscale, ti) for ti in t ]
                ax.bar(centers, counts, width=widths, color=bar_colors, edgecolor="white", linewidth=1.5, alpha=0.8, align='center')
            else:
                ax.stairs(counts, edges, fill=False, color=color, linewidth=2.0, label=label)
        return counts, edges

    def _plot_kde(ax, data, lo, hi, color, label=None):
        grid = np.linspace(lo, hi, n_grid)
        if kde_boundary == "reflect":
            dens = _kde_bounded_reflect(data, grid, lo, hi, bw=None)
        else:
            dens = _gaussian_kde_1d(data, grid, bw=None)
            area = np.trapz(dens, grid)
            if area > 0: dens = dens/area
        ax.plot(grid, dens, color=color, lw=2.2, label=label)
        return grid, dens

    def _kde2d(ax, X, Y, xlim, ylim, levels, color, alpha=0.9, lw=2.0):
        if X.size < 2: return
        X = X[np.isfinite(X)]; Y = Y[np.isfinite(Y)]
        if X.size < 2: return
        gx = np.linspace(xlim[0], xlim[1], scatter_kde_grid)
        gy = np.linspace(ylim[0], ylim[1], scatter_kde_grid)
        XX, YY = np.meshgrid(gx, gy)
        sx = np.std(X, ddof=1); sy = np.std(Y, ddof=1)
        if sx <= 0 or sy <= 0:
            sx = (xlim[1] - xlim[0]) * 1e-3 or 1e-6
            sy = (ylim[1] - ylim[0]) * 1e-3 or 1e-6
        n = X.size; h_scott = n ** (-1/6)
        bx = sx * h_scott; by = sy * h_scott
        dens = np.zeros_like(XX, dtype=float)
        chunk = max(1, int(2e5 // XX.size))
        for start in range(0, n, chunk):
            end = min(n, start+chunk)
            dx = (XX[...,None] - X[None,None,start:end]) / bx
            dy = (YY[...,None] - Y[None,None,start:end]) / by
            dens += np.exp(-0.5*(dx*dx + dy*dy)).sum(axis=2)
        dens /= (n * (2*np.pi*bx*by))
        dmin, dmax = float(np.nanmin(dens)), float(np.nanmax(dens))
        if dmax <= dmin: return
        dens_norm = (dens - dmin) / (dmax - dmin)
        cs = ax.contour(XX, YY, dens_norm, levels=levels, colors=color, linewidths=lw, alpha=alpha)
        return cs

    # ---------- FULL vs FILTERED ----------
    show_filtered = (dataset_mode == "experimental_condition_filter" and condition_filters is not None)

    # --- TL (Y marginal) ---
    y_hist = y_hist_grid = y_hist_dens = None
    if not show_filtered:
        if marginal_mode == "kde":
            y_hist_grid, y_hist_dens = _plot_kde(ax_tl, y_all, ymin_data, ymax_data, color=colors_for_y_axis_colorscale[-1], label=None)
        else:
            y_hist = _plot_hist(ax_tl, y_all, ymin_data, ymax_data, color=None, label=None, vertical=True)
        ax_tl.set_xlim(ymin_plot, ymax_plot)
        if hist_max_count is not None and marginal_mode=="hist":
            ax_tl.set_ylim(0, float(hist_max_count))
        ax_tl.set_ylabel("count" if marginal_mode=="hist" else "density"); ax_tl.set_xlabel(ycol)
    else:
        # per-group overlays
        if condition_colors is None:
            raise ValueError("Provide condition_colors matching condition_filters lengths.")
        group_masks = _make_group_masks(df, condition_filters)
        for (mask, label), color in zip(group_masks, condition_colors):
            ys = y_all[mask]
            if ys.size == 0: continue
            if marginal_mode == "kde":
                _plot_kde(ax_tl, ys, ymin_data, ymax_data, color=color, label=label)
            else:
                _plot_hist(ax_tl, ys, ymin_data, ymax_data, color=color, label=label, vertical=True)
        ax_tl.set_xlim(ymin_plot, ymax_plot)
        if hist_max_count is not None and marginal_mode=="hist":
            ax_tl.set_ylim(0, float(hist_max_count))
        ax_tl.set_ylabel("count" if marginal_mode=="hist" else "density"); ax_tl.set_xlabel(ycol)

    try: ax_tl.set_box_aspect(1)
    except: pass
    if overlay_hlines:
        for yv, col in overlay_hlines:
            ax_tl.axvline(yv, color=col, lw=2.0)

    # --- BR (X marginal) ---
    x_hist = x_hist_grid = x_hist_dens = None
    if not show_filtered:
        if marginal_mode == "kde":
            x_hist_grid, x_hist_dens = _plot_kde(ax_br, x_all, xmin_data, xmax_data, color=colors_for_x_axis_colorscale[-1], label=None)
        else:
            x_hist = _plot_hist(ax_br, x_all, xmin_data, xmax_data, color=None, label=None, vertical=False)
        ax_br.set_xlim(xmin_plot, xmax_plot)
        if hist_max_count is not None and marginal_mode=="hist":
            ax_br.set_ylim(0, float(hist_max_count))
        ax_br.set_xlabel(xcol); ax_br.set_ylabel("count" if marginal_mode=="hist" else "density")
    else:
        if condition_colors is None:
            raise ValueError("Provide condition_colors matching condition_filters lengths.")
        group_masks = _make_group_masks(df, condition_filters)
        for (mask, label), color in zip(group_masks, condition_colors):
            xs = x_all[mask]
            if xs.size == 0: continue
            if marginal_mode == "kde":
                _plot_kde(ax_br, xs, xmin_data, xmax_data, color=color, label=label)
            else:
                _plot_hist(ax_br, xs, xmin_data, xmax_data, color=color, label=label, vertical=False)
        ax_br.set_xlim(xmin_plot, xmax_plot)
        if hist_max_count is not None and marginal_mode=="hist":
            ax_br.set_ylim(0, float(hist_max_count))
        ax_br.set_xlabel(xcol); ax_br.set_ylabel("count" if marginal_mode=="hist" else "density")

    try: ax_br.set_box_aspect(1)
    except: pass
    if overlay_vlines:
        for xv, col in overlay_vlines:
            ax_br.axvline(xv, color=col, lw=2.0)

    # --- TR (scatter) ---
    if not show_filtered:
        ax_tr.scatter(x_all, y_all, s=point_size, c=rgb_full, edgecolor='none', alpha=alpha_pts)
        # optional single global KDE (legacy)
        if scatter_kde:
            _kde2d(ax_tr, x_all, y_all, (xmin_plot, xmax_plot), (ymin_plot, ymax_plot),
                   levels=scatter_kde_levels, color=scatter_kde_color,
                   alpha=scatter_kde_alpha, lw=scatter_kde_lw)
    else:
        group_masks = _make_group_masks(df, condition_filters)
        n_groups = len(group_masks)
        # n_groups = number of condition combinations you’re plotting
        if condition_colors_scatter_kde is None:
            kde_cols = ["#000000"] * n_groups    # default = all black
        else:
            if len(condition_colors_scatter_kde) != n_groups:
                raise ValueError("condition_colors_scatter_kde must have the same length as the number of groups.")
            kde_cols = list(condition_colors_scatter_kde)
        i = -1
        for (mask, label), color in zip(group_masks, condition_colors):
            i += 1
            xs = x_all[mask]; ys = y_all[mask]
            if xs.size == 0: continue
            ax_tr.scatter(xs, ys, s=point_size, c=[color], edgecolor='none', alpha=alpha_pts, label=label)
            if group_scatter_kde:
                _kde2d(ax_tr, xs, ys, (xmin_plot, xmax_plot), (ymin_plot, ymax_plot),
                       levels=group_kde_levels, color=kde_cols[i], alpha=group_kde_alpha, lw=group_kde_lw)

    ax_tr.set_xlim(xmin_plot, xmax_plot); ax_tr.set_ylim(ymin_plot, ymax_plot)
    ax_tr.set_xlabel(xcol); ax_tr.set_ylabel(ycol)
    try: ax_tr.set_box_aspect(1)
    except: pass

    # --- overlay summary points/lines (works in both modes) ---
    def _overlay_summary(df_sum, conds, cols, marker, s, edge, lw):
        masks = _make_group_masks(df_sum, conds)
        for (m, _), col in zip(masks, cols):
            sub = df_sum[m]
            if sub.empty: continue
            # scatter markers
            ax_tr.scatter(sub[xcol], sub[ycol], s=s, marker=marker, color=col,
                          edgecolor=edge, linewidth=lw, zorder=5)
            # vlines and hlines for each row
            for _, row in sub.iterrows():
                xv = float(row[xcol]); yv = float(row[ycol])
                ax_br.axvline(xv, color=col, lw=2.0, alpha=0.9)
                ax_tl.axvline(yv, color=col, lw=2.0, alpha=0.9)

    if overlay_summary_df is not None and overlay_summary_conditions is not None and overlay_summary_colors is not None:
        _overlay_summary(overlay_summary_df, overlay_summary_conditions,
                         overlay_summary_colors, overlay_summary_marker,
                         overlay_summary_size, overlay_summary_edge, overlay_summary_lw)

    # --- legend panel ---
    ax_bl.axis("off")
    handles, labels = ax_tr.get_legend_handles_labels()
    if handles:
        ax_bl.legend(handles, labels, frameon=False, loc="center")

    if main_title:
        fig.suptitle(main_title, y=0.98, fontsize=12)

    # CSVs (unchanged: saves the full scatter data and the *single* marginal used in full mode)
    if csv_prefix:
        _save_pairplot_grid_csvs(
            csv_prefix, mode=marginal_mode, x_label=xcol, y_label=ycol,
            x=x_all, y=y_all,
            x_grid=None, x_kde=None, y_grid=None, y_kde=None,
            x_hist=None, y_hist=None,
            overlay_points=overlay_points
        )

    if savepath:
        plt.savefig(savepath, dpi=180)
    plt.show()


In [None]:
# %% FIGURE : posterior predictive simulations for selected conditions
# Configuration cell: axis choices, color scales and example conditions used by the figure cells below.

# Parameter / observable names used on the x and y axes. These strings are used as column names
# downstream, so their spelling must be left exactly as is.
FIGURE_param_x_axis = 'common_input_std'
FIGURE_param_y_axis = 'disynpatic_inhib_connections_desired_MN_MN'
FIGURE_observable_feature_x_axis = 'peak_height_mean'
FIGURE_observable_feature_y_axis = 'trough_area_mean'
FIGURE_high_freq_CI_scale_color_list = ["#FFE96F","#FF9A3B"] # x axis (both for parameter and observed features)
FIGURE_RI_strength_scale_color_list = [ "#70D7FF",  "#323CCF"] # y axis (both for parameter and observed features)
FIGURE_base_color = "#979797"
FIGURE_color_blend = "additive"
FIGURE_scatter_point_size = 30
FIGURE_alpha_scatter_experimental = 0.3
FIGURE_alpha_scatter_simulation = 0.1
# Hex color per muscle pair ("A<->B") and per single muscle; mixed pairs share one color in both directions.
colors_dict = {
  "VL<->VL": "#D62728",
  "VL<->VM": "#FF9201",
  "VM<->VL": "#FF9201",
  "VM<->VM": "#FFC400",
  "TA<->TA": "#00C71B",
  "FDI<->FDI": "#14BFA8",
  "GM<->GM": "#2489DC",
  "GM<->SOL": "#7D74EC",
  "SOL<->GM": "#7D74EC",
  "SOL<->SOL": "#BB86ED",

  "VL": "#D62728",
  "VM": "#FFC400",
  "TA": "#00C71B",
  "FDI": "#14BFA8",
  "GM": "#2489DC",
  "SOL": "#BB86ED",
}
# Example conditions: the two lists are positionally aligned (i-th muscle_pair goes with i-th intensity).
FIGURE_sampled_experiment_example_conditions = {"muscle_pair": ["GM<->GM","VM<->VM","FDI<->FDI"],
                                                "intensity": [40,10,40]}
FIGURE_sampled_experiment_example_colors = ["#2489DC", "#FFC400", "#14BFA8"] # One color per entry, in the same order as the muscles in FIGURE_sampled_experiment_example_conditions
FIGURE_sampled_experiment_example_colors_darker = ["#1D56BE", "#F08800", "#008B79"] # Darker variants, same order (used for the KDE contours)
# Alternative selection (all six same-muscle pairs at a single intensity), kept for reference:
# FIGURE_sampled_experiment_example_conditions = {"muscle_pair": ["FDI<->FDI","TA<->TA","SOL<->SOL","GM<->GM","VL<->VL","VM<->VM"],
#                                                 # "intensity": [10, 10, 10, 10, 10, 10]}
#                                                 "intensity": [40, 40, 40, 40, 40, 40]}
# FIGURE_sampled_experiment_example_colors = ["#14BFA8", "#00C71B", "#BB86ED", "#2489DC", "#D62728", "#FFC400", ] # Select according to muscles in FIGURE_sampled_experiment_example_conditions
# FIGURE_sampled_experiment_example_colors_darker = ["#008B79", "#1E7712", "#9936B8", "#1D56BE", "#9C0000", "#F08800", ]
FIGURE_kde_alpha_min_max = [1, 1]
# NOTE(review): path_to_save_into is not defined in this cell; it must already exist in the kernel.
FIG_SAVE_DIR = path_to_save_into


In [None]:
# %% FIGURE : posterior predictive simulations (per-condition) for selected conditions, with optional per-simulation averaging
# ----- USER TOGGLES -----
AGGREGATE_SIM_MEANS = True   # True: average across MUs per simulation before plotting
UNSTANDARDIZE_FOR_PLOT = True  # True: convert standardized values back to original scales for the two axes
# Optional custom ranges for display/binning (None = auto from data below)
# NOTE(review): the plotting call further down currently passes hard-coded x_range/y_range/xlim/ylim,
# so these four toggles have no effect on the figure unless that call is switched back to the computed ranges.
X_RANGE = None     # e.g. (0, 0.6)
Y_RANGE = None     # e.g. (0, 2.5)
XLIM    = None     # e.g. (-0.04, 0.64)
YLIM    = None     # e.g. (-0.10, 2.60)

# ----- Axes / inputs (same names you used before) -----
x_feat_mean = FIGURE_observable_feature_x_axis          # e.g. 'peak_height_mean'
y_feat_mean = FIGURE_observable_feature_y_axis          # e.g. 'trough_area_mean'
x_feat_raw  = x_feat_mean.replace("_mean","")           # e.g. 'peak_height'
y_feat_raw  = y_feat_mean.replace("_mean","")           # e.g. 'trough_area'

# Work on copies so the source DataFrames (defined in earlier cells) are left untouched.
df_ppc_all   = df_obs_simulated_from_posterior_samples.copy()
df_hl_all    = df_obs_simulated_from_posterior_sample_highest_likelihood.copy()
# Constant 'sim_idx' so all highest-likelihood rows share one value for that grouping key.
df_hl_all['sim_idx'] = 0

# Sanity: needed columns
for need_col in (x_feat_raw, y_feat_raw, "muscle_pair", "intensity"):
    if need_col not in df_ppc_all.columns:
        raise KeyError(f"df_obs_simulated_from_posterior_samples missing column: {need_col}")
    if need_col not in df_hl_all.columns:
        raise KeyError(f"df_obs_simulated_from_posterior_sample_highest_likelihood missing column: {need_col}")

# Ensure 'intensity' comparable (cast to int when safe)
def _safe_cast_int_inplace(df, col="intensity"):
    if col in df.columns and pd.api.types.is_numeric_dtype(df[col]):
        arr = df[col].to_numpy(dtype=float, copy=False)
        if np.all(np.isfinite(arr)) and np.all(np.isclose(arr % 1, 0, atol=1e-9)):
            df[col] = arr.astype(int)

# Apply the lossless int cast to both working copies so their 'intensity' values share a dtype
# and can be matched against the integer intensities listed in the condition filters.
_safe_cast_int_inplace(df_ppc_all, "intensity")
_safe_cast_int_inplace(df_hl_all, "intensity")

# ----- Unstandardize (only the two axes) using exp_stats computed from PRIOR -----
def _unstandardize_cols(df, cols, stats):
    df = df.copy()
    for c in cols:
        if c not in stats:
            raise KeyError(f"exp_stats missing key for '{c}'")
        mu = float(stats[c]["mean"])
        sd = float(stats[c]["std"])
        df[c] = df[c] * sd + mu
    return df

# Only the two plotted feature columns are mapped back to original units; every other column is left as is.
# NOTE(review): exp_stats is not defined in this cell; it must already exist in the kernel.
if UNSTANDARDIZE_FOR_PLOT:
    df_ppc_all = _unstandardize_cols(df_ppc_all, [x_feat_raw, y_feat_raw], exp_stats)
    df_hl_all  = _unstandardize_cols(df_hl_all,  [x_feat_raw, y_feat_raw], exp_stats)

# ----- Optional: average across MUs per simulation (for main PPC scatter) -----
def _aggregate_per_sim(df, value_cols, prefer_group_order=None):
    """
    Average value_cols per simulation. We assemble group columns from typical keys if present.
    """
    df = df.copy()
    # typical keys that define a unique "simulation":
    default_order = ["subject", "muscle_pair", "intensity", "condition", "sim_idx"]
    if prefer_group_order is None: prefer_group_order = default_order
    group_cols = [c for c in prefer_group_order if c in df.columns]
    if not group_cols:
        # fallback: at least condition/muscle_pair/intensity if available
        for c in ["condition", "muscle_pair", "intensity"]:
            if c in df.columns and c not in group_cols:
                group_cols.append(c)
    # average only requested value cols
    g = df.groupby(group_cols, as_index=False)[value_cols].mean()
    return g

# Main scatter data: per-simulation means when the toggle is on, otherwise the per-row values.
if AGGREGATE_SIM_MEANS:
    df_ppc_main = _aggregate_per_sim(df_ppc_all, [x_feat_raw, y_feat_raw])
else:
    df_ppc_main = df_ppc_all

# For overlays (highest-likelihood): collapse to a single point per condition
# (one row per unique combination of the grouping columns present; sim_idx is constant here)
df_hl_overlay = _aggregate_per_sim(df_hl_all, [x_feat_raw, y_feat_raw])

# ----- Condition filters / colors -----
cond_filters = FIGURE_sampled_experiment_example_conditions  # e.g. {"muscle_pair": [...], "intensity":[...]}
cond_colors  = FIGURE_sampled_experiment_example_colors
# Use the hand-picked darker palette when the config cell defined it; otherwise derive one.
try:
    kde_cols = FIGURE_sampled_experiment_example_colors_darker
except NameError:
    # simple darken fallback
    def _darken_hex(h, factor=0.65):
        """Scale each RGB channel of a '#RRGGBB' color by factor, clamp to 0..255, return upper-case hex."""
        h = h.lstrip("#")
        r,g,b = int(h[0:2],16), int(h[2:4],16), int(h[4:6],16)
        r = max(0, min(255, int(r*factor))); g = max(0, min(255, int(g*factor))); b = max(0, min(255, int(b*factor)))
        return f"#{r:02X}{g:02X}{b:02X}"
    kde_cols = [_darken_hex(c) for c in cond_colors]

# ----- Optional per-axis gradient scales (harmless here; used in "full_dataset" mode) -----
# The kwargs stay empty (or partially filled) if the color-scale constants are not defined in the kernel.
kwargs_axis_scales = {}
try:
    kwargs_axis_scales["colors_for_x_axis_colorscale"] = FIGURE_high_freq_CI_scale_color_list
    kwargs_axis_scales["colors_for_y_axis_colorscale"] = FIGURE_RI_strength_scale_color_list
except NameError:
    pass

# ----- Auto ranges if not provided -----
def _auto_range(df, col, pad=0.02):
    a = df[col].to_numpy(dtype=float)
    a = a[np.isfinite(a)]
    if a.size == 0: return (0.0, 1.0)
    lo, hi = np.min(a), np.max(a)
    if hi <= lo:
        return (lo - 0.5, hi + 0.5)
    span = hi - lo
    return (lo - pad*span, hi + pad*span)

# Resolve binning ranges (no padding) and display limits (5% padding) from the toggles or from the data.
# NOTE(review): the call below passes hard-coded ranges instead of these four variables (the lines that
# forward them are commented out), so they are currently computed but unused.
x_range = X_RANGE if X_RANGE is not None else _auto_range(df_ppc_main, x_feat_raw, pad=0.00)
y_range = Y_RANGE if Y_RANGE is not None else _auto_range(df_ppc_main, y_feat_raw, pad=0.00)
xlim    = XLIM    if XLIM    is not None else _auto_range(df_ppc_main, x_feat_raw, pad=0.05)
ylim    = YLIM    if YLIM    is not None else _auto_range(df_ppc_main, y_feat_raw, pad=0.05)

# ----- Build overlay DF in the format expected by pairplot (x/y columns must match main’s xcol/ycol)
# The rename maps each column to itself, i.e. it only produces a renamed-nothing copy of df_hl_overlay.
overlay_df = df_hl_overlay.rename(columns={x_feat_raw: x_feat_raw, y_feat_raw: y_feat_raw})

# ----- Plot -----
# Figure path (.svg) and the extension-less prefix handed to the plot function as csv_prefix.
FIG_PPC_PATH  = os.path.join(FIG_SAVE_DIR, "_posterior_predictive_checks_features_examples.svg")
FIG_PPC_CSVP  = os.path.join(FIG_SAVE_DIR, "_posterior_predictive_checks_features_examples")

pairplot_grid_marginals(
    df_ppc_main, x_feat_raw, y_feat_raw,
    dataset_mode="experimental_condition_filter",
    main_title=("FIG 6 — Posterior predictive simulations"
                f" ({'per-simulation means' if AGGREGATE_SIM_MEANS else 'per-MU samples'})"),
    condition_filters=cond_filters,
    condition_colors=cond_colors,
    condition_colors_scatter_kde=kde_cols,


    # binning/KDE ranges and display limits
    # (earlier alternatives kept commented out; only the last two lines are active)
    # x_range=x_range, y_range=y_range,
    # xlim=xlim, ylim=ylim,
    # x_range=(0, 0.7), y_range=(0, 5),        # for binning
    # xlim=(-0.05, 0.75), ylim=(-0.2, 5.2),    # for display
    x_range=(0, 0.6), y_range=(0, 2.5),        # for binning
    xlim=(-0.04, 0.64), ylim=(-0.1, 2.6),    # for display

    # visuals
    point_size=FIGURE_scatter_point_size,
    alpha_pts=FIGURE_alpha_scatter_experimental * 1.5,

    # 1D KDE marginals (per group) + per-group 2D KDE on scatter
    marginal_mode="kde",
    n_grid=120,
    kde_boundary="none",
    group_scatter_kde=True,
    group_kde_levels=(0.30, 0.60, 0.90),
    group_kde_alpha=1.0,
    group_kde_lw=1.4,

    # Overlay the highest-likelihood simulation(s) as markers + v/h lines
    overlay_summary_df=overlay_df,
    overlay_summary_conditions=cond_filters,       # same selections
    overlay_summary_colors=cond_colors,
    overlay_summary_marker="X",
    overlay_summary_size=115,
    overlay_summary_edge="k",
    overlay_summary_lw=1.0,

    savepath=FIG_PPC_PATH,
    csv_prefix=FIG_PPC_CSVP,
    **kwargs_axis_scales
)

print("Saved:", FIG_PPC_PATH)


# Posterior Predictive Checks in the first 2 PCs of the full feature space (4 features x 4 summary statistics = 16-D space)
### (even though there is a lot of collinearity within the space)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb
from tqdm.auto import tqdm

# ----------------------- CONFIG -----------------------
GROUP_COLS = ["muscle_pair", "intensity"]   # columns present in all dfs

# Visuals
# Hex color per muscle pair / single muscle. This re-declares the same mapping as the
# earlier figure-configuration cell, so the cell can run on its own.
colors_dict = {
  "VL<->VL": "#D62728", "VL<->VM": "#FF9201", "VM<->VL": "#FF9201", "VM<->VM": "#FFC400",
  "TA<->TA": "#00C71B", "FDI<->FDI": "#14BFA8", "GM<->GM": "#2489DC",
  "GM<->SOL": "#7D74EC", "SOL<->GM": "#7D74EC", "SOL<->SOL": "#BB86ED",
  "VL": "#D62728", "VM": "#FFC400", "TA": "#00C71B", "FDI": "#14BFA8",
  "GM": "#2489DC", "SOL": "#BB86ED",
}

# Whether to compute distances in whitened space (Mahalanobis via prior Σ)
WHITEN_FOR_METRICS = False

# ------------------------------------------------------

# Single source of truth for the feature naming scheme: every feature column is
# "<prefix>_<stat>", and the helper below relies on these two lists as defaults.
SUMMARY_STATS = ["mean", "std", "median", "iqr"]
FEATURE_PREFIXES = ["trough_area", "peak_height", "firing_rate", "IPSP_delay"]

def pick_feature_columns(df: pd.DataFrame,
                         prefixes=FEATURE_PREFIXES,
                         stats=SUMMARY_STATS):
    """Build the ordered list of "<prefix>_<stat>" column names and check they exist in df.

    The order is prefix-major (all stats of the first prefix, then the second, ...).
    Raises KeyError naming the first expected column that df does not contain.
    """
    expected = [f"{p}_{s}" for p in prefixes for s in stats]
    for name in expected:
        if name not in df.columns:
            raise KeyError(f"Expected column '{name}' not found in DataFrame.")
    return expected

# Select 16-D feature columns
# (validated against the prior-simulation summary table; raises KeyError if any column is missing)
feat_cols = pick_feature_columns(df_summary_obs_simulated_from_prior,
                                 prefixes=FEATURE_PREFIXES,
                                 stats=SUMMARY_STATS)

# Optional guard:
# the number of selected columns must equal (#prefixes x #stats)
expected_n = len(FEATURE_PREFIXES) * len(SUMMARY_STATS)  # 4 * 4 = 16
assert len(feat_cols) == expected_n

def make_whitener_from_prior(df_prior_feats: pd.DataFrame):
    """Build a symmetric whitening transform from the prior feature table.

    Returns (mean, W) such that (x - mean) @ W.T has identity covariance when x
    is drawn from the rows of df_prior_feats. W is the inverse square root of
    the sample covariance, obtained from its eigen-decomposition.
    """
    samples = df_prior_feats.to_numpy().astype(float)
    mu = samples.mean(axis=0)
    centered = samples - mu
    # feature-by-feature sample covariance
    cov = np.cov(centered, rowvar=False)
    evals, evecs = np.linalg.eigh(cov)
    # Floor the eigenvalues so tiny or slightly negative ones cannot blow up the inverse sqrt.
    floor = 1e-12
    evals = np.clip(evals, floor, None)
    inv_sqrt = np.diag(1.0 / np.sqrt(evals))
    W = evecs @ inv_sqrt @ evecs.T
    return mu, W

def whiten_rows(X: np.ndarray, mu: np.ndarray, W: np.ndarray):
    """Whiten each row of X: subtract mu, then right-multiply by the transpose of W."""
    centered = X - mu
    return centered @ W.T

def rmse(a, b):
    """Root-mean-square error between two equally shaped (or broadcastable) inputs.

    Inputs are coerced to float arrays first, so plain lists and tuples work as
    well as NumPy arrays (the same contract as the later `rmse` helper in this
    notebook). Returns a Python float.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.sqrt(np.mean((a - b)**2)))
def r2_from_vectors(pred, true, baseline_mean=None):
    """
    Coefficient of determination across the entries of ONE vector pair.

      R² = 1 - SSE / SST, where SSE = sum((pred - true)²) and
      SST = sum((true - baseline_mean)²).

    baseline_mean defaults to 0.0, which suits features z-scored against the prior.
    SST is floored at 1e-12 so an all-baseline `true` does not divide by zero.
    """
    pred = np.asarray(pred)
    true = np.asarray(true)
    reference = 0.0 if baseline_mean is None else baseline_mean
    residual = pred - true
    deviation = true - reference
    sse = np.sum(residual**2)
    sst = np.sum(deviation**2)
    return float(1.0 - sse/max(sst, 1e-12))

def silverman_bandwidth_2d(X):
    """Rule-of-thumb scalar bandwidth factor h = n ** (-1 / (d + 4)).

    n and d are the number of rows and columns of X; despite the name, d is
    read from X rather than fixed at 2.
    """
    n_samples, n_dims = X.shape
    exponent = -1.0/(n_dims+4.0)
    return n_samples ** exponent

def kde2d_grid(points, bw_scale=1.0, grid_n=150, margin=0.1):
    """
    Simple 2D Gaussian KDE on a grid.

    Parameters
    ----------
    points   : array-like of shape (n, 2) with n >= 2 (enforced by an assert).
    bw_scale : multiplier applied to the rule-of-thumb bandwidth factor.
    grid_n   : number of grid nodes along each axis.
    margin   : fraction of the data extent added on each side of the grid.

    Returns
    -------
    (xs, ys, Z): xs and ys are the 1-D grid coordinates (length grid_n each), not
    mesh grids; Z is the density on the mesh built from them, shape (grid_n, grid_n)
    with rows following ys and columns following xs.

    The kernel covariance is the sample covariance of the points scaled by h**2,
    i.e. a full-covariance Gaussian kernel with a single scalar bandwidth factor.
    np.linalg.inv may raise LinAlgError if that covariance is singular.
    """
    X = np.asarray(points, dtype=float)
    n, d = X.shape
    assert d == 2 and n >= 2

    # data bounding box
    xmin, ymin = X.min(axis=0)
    xmax, ymax = X.max(axis=0)
    dx, dy = xmax - xmin, ymax - ymin
    # margins
    xmin -= margin*dx; xmax += margin*dx
    ymin -= margin*dy; ymax += margin*dy

    xs = np.linspace(xmin, xmax, grid_n)
    ys = np.linspace(ymin, ymax, grid_n)
    Xg, Yg = np.meshgrid(xs, ys)

    # covariance & bandwidth
    Σ = np.cov(X, rowvar=False)
    # isotropic scalar h
    h0 = silverman_bandwidth_2d(X) * bw_scale
    # effective covariance for kernel
    Σk = Σ * (h0**2)
    # precompute inverse and norm factor
    invΣ = np.linalg.inv(Σk)
    norm = 1.0 / (2.0*np.pi*np.sqrt(np.linalg.det(Σk)))

    # evaluate density
    # (one pass per data point, accumulating that point's unnormalized kernel over the whole grid)
    Z = np.zeros_like(Xg)
    for k in range(n):
        diff = np.stack([Xg - X[k,0], Yg - X[k,1]], axis=-1)  # (..., 2)
        # quadratic form
        q = diff @ invΣ
        q = q[...,0]*diff[...,0] + q[...,1]*diff[...,1]
        Z += np.exp(-0.5*q)
    # apply the Gaussian normalization and average over the n kernels
    Z *= (norm / n)
    return xs, ys, Z

def hdr_levels_from_density(Z, levels=(0.9, 0.6, 0.3)):
    """
    Density thresholds for highest-density regions.

    For each probability mass in `levels`, returns the density value at which
    the cumulative sum of grid densities (taken from the highest cell downward)
    first reaches that fraction of the total. On a uniform grid the cell area
    is a common factor and drops out. The result is a list in the order of `levels`.
    """
    descending = np.sort(Z.ravel())[::-1]
    cumulative = np.cumsum(descending)
    total_mass = cumulative[-1]
    last = descending.size - 1
    thresholds = []
    for mass in levels:
        pos = np.searchsorted(cumulative, mass*total_mass)
        # keep the index in range in case the target is not reached by any prefix
        pos = np.clip(pos, 0, last)
        thresholds.append(descending[pos])
    return thresholds


In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def dropna_and_report(df: pd.DataFrame, feat_cols, name: str):
    """Return df without the rows that have a NaN in any of feat_cols.

    Prints a one-line report, tagged with `name`, stating how many rows were
    removed and how many were kept. The input frame is not modified.
    """
    df_clean = df.dropna(subset=feat_cols)
    n0, n1 = len(df), len(df_clean)
    dropped = n0 - n1
    if dropped <= 0:
        message = f"[{name}] No NaNs found in feature columns (kept {n1}/{n0})."
    else:
        message = f"[{name}] Dropped {dropped} row(s) with NaNs in feature columns (kept {n1}/{n0})."
    print(message)
    return df_clean

# Clean the PRIOR only (this is the training distribution for PCA)
df_prior_clean = dropna_and_report(df_summary_obs_simulated_from_prior, feat_cols, name="PRIOR")

# Fit PCA on the cleaned prior (16-D features)
# All components are kept (n_components = min(#rows, #features)) so the full variance spectrum is available.
# NOTE(review): the feature columns are passed to PCA as stored, with no rescaling in this cell;
# presumably they are already standardized — TODO confirm.
X_prior = df_prior_clean[feat_cols].to_numpy()
pca = PCA(n_components=min(X_prior.shape[0], X_prior.shape[1]), svd_solver="full")
pca.fit(X_prior)

# Helper to project any dataframe onto first two PCs (no re-fitting)
def project_to_pcs(df, feat_cols, pca_obj, n=2):
    """Transform df[feat_cols] with an already fitted pca_obj and keep the first n components."""
    scores = pca_obj.transform(df[feat_cols].to_numpy())
    return scores[:, :n]

# Cumulative variance explained plot (starts at 0 PCs)
evr = pca.explained_variance_ratio_
# Prepend 0 so the curve starts at (0 components, 0 variance).
cum = np.concatenate([[0.0], np.cumsum(evr)])
xs  = np.arange(cum.size)  # 0..k

plt.figure(figsize=(6, 3.5))
plt.plot(xs, cum, lw=2)
plt.scatter(xs, cum, s=18)
plt.xticks(xs)
plt.ylim(0, 1.01)  # slight headroom above 1 so the last marker is not clipped
plt.xlabel("# Principal Components")
plt.ylabel("Cumulative variance explained")
plt.title("PCA on prior-generated simulations (NaN rows dropped)")
plt.grid(alpha=0.25)
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# ===========================
# Config
# ===========================
# Muscle pairs removed from the posterior / best / experimental tables before computing metrics.
EXCLUDE_PAIRS = ["VL<->VM","GM<->SOL"]          # e.g., ["VL<->VL", "GM<->SOL"]
GROUP_COLS = ["muscle_pair", "intensity"]       # condition identity
# feat_cols must have been defined (as a non-empty list/tuple) by an earlier cell.
assert 'feat_cols' in globals() and isinstance(feat_cols, (list, tuple)) and len(feat_cols) > 0

# ===========================
# Helpers
# ===========================
def rmse(a, b):
    """Root-mean-square error between a and b (both coerced to float arrays); returns a float."""
    err = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    mean_sq = np.mean(err**2)
    return float(np.sqrt(mean_sq))

def r2(pred, true):
    """R² of pred against true, with the total sum of squares taken around true's own mean.

    Returns NaN when that total is not strictly positive (e.g. constant `true`).
    """
    p = np.asarray(pred, dtype=float)
    t = np.asarray(true, dtype=float)
    ss_res = np.sum((p - t)**2)
    ss_tot = np.sum((t - t.mean())**2)
    if ss_tot > 0:
        return float(1.0 - ss_res/ss_tot)
    return np.nan

def coerce_numeric(df, cols):
    """Return a copy of df in which the listed columns are converted with pd.to_numeric.

    Values that cannot be parsed become NaN (errors='coerce'). The input is not modified.
    """
    converted = df[cols].apply(pd.to_numeric, errors='coerce')
    out = df.copy()
    out[cols] = converted
    return out

def clean_df(df, feat_cols, exclude_pairs):
    """Prepare a summary table for the metric computations.

    Steps: coerce feat_cols to numeric, drop the rows whose 'muscle_pair' is in
    exclude_pairs (when that column exists and the list is non-empty), then drop
    rows with NaN in any feature column, printing how many were removed.
    Returns the cleaned frame.
    """
    out = coerce_numeric(df, feat_cols)
    if "muscle_pair" in out.columns and exclude_pairs:
        excluded = out["muscle_pair"].isin(exclude_pairs)
        out = out[~excluded]
    before = len(out)
    out = out.dropna(subset=feat_cols)
    dropped = before - len(out)
    if dropped > 0:
        print(f"[info] Dropped {dropped} rows with NaNs in features.")
    return out

def df_to_map_many(df, feat_cols, group_cols=GROUP_COLS):
    """Map each group key to a float array with one row per member of the group.

    Groups are formed on group_cols, keeping NaN keys (dropna=False).
    """
    return {
        k: g[feat_cols].to_numpy(dtype=float)
        for k, g in df.groupby(group_cols, dropna=False)
    }

def df_to_map_one(df, feat_cols, group_cols=GROUP_COLS):
    """Map each group key to a single float feature vector (the group's first row).

    Groups are formed on group_cols, keeping NaN keys (dropna=False). A warning
    is printed for any group that does not hold exactly one row.
    """
    m = {}
    for k, g in df.groupby(group_cols, dropna=False):
        n_rows = len(g)
        if n_rows != 1:
            print(f"[warn] Expected 1 row for {k}, found {len(g)}. Taking the first.")
        first_row = g.iloc[0]
        m[k] = first_row[feat_cols].to_numpy(dtype=float)
    return m

def fmt(mu, sd, decimals=3):
    """Format a mean and a spread as 'mean ± spread' with a fixed number of decimals."""
    text = f"{mu:.{decimals}f} ± {sd:.{decimals}f}"
    return text

def med_q1_q3(arr):
    """Return (median, 25th percentile, 75th percentile) of arr as floats, ignoring NaNs."""
    values = np.asarray(arr, dtype=float)
    median = float(np.nanmedian(values))
    q1 = float(np.nanpercentile(values, 25))
    q3 = float(np.nanpercentile(values, 75))
    return (median, q1, q3)

def fmt_med_iqr(med, q1, q3, decimals=3):
    """Format a median with its quartiles as 'median [Q1–Q3]' with a fixed number of decimals."""
    text = f"{med:.{decimals}f} [{q1:.{decimals}f}–{q3:.{decimals}f}]"
    return text

def fmt_min_max(vmin, vmax, decimals=3):
    """Format a minimum and a maximum as 'min, max' with a fixed number of decimals."""
    text = f"{vmin:.{decimals}f}, {vmax:.{decimals}f}"
    return text

def pool_concat(list_of_arrays):
    """Concatenate the arrays along axis 0; an empty list yields an empty float array."""
    has_chunks = len(list_of_arrays) != 0
    return np.concatenate(list_of_arrays, axis=0) if has_chunks else np.array([], dtype=float)

# ===========================
# Clean inputs
# ===========================
# Posterior, highest-likelihood and experimental tables all go through the same cleaning
# (numeric coercion, EXCLUDE_PAIRS removal, NaN-row removal).
df_post  = clean_df(df_summary_obs_simulated_from_posterior,      feat_cols, EXCLUDE_PAIRS)
df_best  = clean_df(df_summary_obs_simulated_highest_likelihood,  feat_cols, EXCLUDE_PAIRS)
df_exp   = clean_df(df_summary_obs_experiment,                    feat_cols, EXCLUDE_PAIRS)

# PRIOR: compare each experimental condition to ALL prior rows
# (only numeric coercion + NaN-row removal here; EXCLUDE_PAIRS is not applied to the prior)
df_prior = coerce_numeric(df_summary_obs_simulated_from_prior, feat_cols).dropna(subset=feat_cols)
prior_all = df_prior[feat_cols].to_numpy(dtype=float)   # (N_prior, F)

# Build maps
# Keys are the group keys produced by grouping on GROUP_COLS (muscle_pair, intensity).
post_map  = df_to_map_many(df_post, feat_cols)   # many samples / condition
best_map  = df_to_map_one(df_best, feat_cols)    # one / condition
exp_map   = df_to_map_one(df_exp,  feat_cols)    # one / condition

# The experimental conditions drive the analysis; posterior/best entries are looked up per condition.
conditions = sorted(exp_map.keys())

# ===========================
# Per-condition metrics + pooled accumulators
# ===========================
rows = []                                   # one dict per condition -> summary table
missing = {'posterior': [], 'best': []}     # conditions without posterior / best entries

# Accumulators feeding the POOLED rows built after the per-condition loop.
pooled_rmse_post = []   # arrays (sample-level) across conditions
pooled_r2_post   = []
pooled_rmse_prior = []
pooled_r2_prior   = []
pooled_rmse_oo = []
pooled_r2_oo   = []
pooled_rmse_best = []   # scalars (one per condition)
pooled_r2_best   = []

# For every experimental condition, compare its feature vector against four references:
#   (1) posterior-predictive samples of the same condition,
#   (2) the single highest-likelihood posterior simulation of the same condition,
#   (3) ALL prior simulations,
#   (4) the experimental vectors of every other condition.
# Each comparison yields RMSE and R² values that are summarized per condition (one table
# row) and also stored in the pooled_* accumulators for the POOLED rows.
for cond in tqdm(conditions, desc="Summarizing per-condition"):
    y_true = exp_map[cond]              # (F,)
    # Total sum of squares of the experimental vector around its own mean. It only
    # depends on y_true, so it is computed once and shared by all R² computations below.
    sst = np.sum((y_true - y_true.mean())**2)

    # ---- Posterior vs EXP
    if cond in post_map:
        Ypost = post_map[cond].astype(float, copy=False)  # (n_post, F)
        diffs = Ypost - y_true[None, :]
        rmse_post_samples = np.sqrt(np.mean(diffs**2, axis=1)).astype(float)
        if sst > 0:
            r2_post_samples = (1.0 - np.sum(diffs**2, axis=1) / sst).astype(float)
        else:
            r2_post_samples = np.full(Ypost.shape[0], np.nan, dtype=float)

        # μ±σ
        rmse_post_mu, rmse_post_sd = float(np.nanmean(rmse_post_samples)), float(np.nanstd(rmse_post_samples, ddof=0))
        r2_post_mu,   r2_post_sd   = float(np.nanmean(r2_post_samples)),   float(np.nanstd(r2_post_samples,   ddof=0))
        # med [Q1–Q3]
        rmse_post_med, rmse_post_q1, rmse_post_q3 = med_q1_q3(rmse_post_samples)
        r2_post_med,   r2_post_q1,   r2_post_q3   = med_q1_q3(r2_post_samples)

        pooled_rmse_post.append(rmse_post_samples)
        pooled_r2_post.append(r2_post_samples)
    else:
        rmse_post_mu = rmse_post_sd = np.nan
        r2_post_mu   = r2_post_sd   = np.nan
        rmse_post_med = rmse_post_q1 = rmse_post_q3 = np.nan
        r2_post_med   = r2_post_q1   = r2_post_q3   = np.nan
        missing['posterior'].append(cond)

    # ---- Best posterior (argmax logp) vs EXP
    if cond in best_map:
        y_best = best_map[cond].astype(float, copy=False)
        rmse_best_val = rmse(y_best, y_true)
        r2_best_val   = r2(y_best, y_true)
        pooled_rmse_best.append(rmse_best_val)
        pooled_r2_best.append(r2_best_val)
        # represent as degenerate median/IQR
        rmse_best_med = rmse_best_q1 = rmse_best_q3 = rmse_best_val
        r2_best_med   = r2_best_q1   = r2_best_q3   = r2_best_val
    else:
        rmse_best_val = r2_best_val = np.nan
        rmse_best_med = rmse_best_q1 = rmse_best_q3 = np.nan
        r2_best_med   = r2_best_q1   = r2_best_q3   = np.nan
        missing['best'].append(cond)

    # ---- Prior vs EXP
    diffs_pr = prior_all - y_true[None, :]
    rmse_pr_samples = np.sqrt(np.mean(diffs_pr**2, axis=1)).astype(float)
    if sst > 0:
        r2_pr_samples = (1.0 - np.sum(diffs_pr**2, axis=1) / sst).astype(float)
    else:
        r2_pr_samples = np.full(prior_all.shape[0], np.nan, dtype=float)

    rmse_prior_mu, rmse_prior_sd = float(np.nanmean(rmse_pr_samples)), float(np.nanstd(rmse_pr_samples, ddof=0))
    r2_prior_mu,   r2_prior_sd   = float(np.nanmean(r2_pr_samples)),   float(np.nanstd(r2_pr_samples,   ddof=0))
    rmse_prior_med, rmse_prior_q1, rmse_prior_q3 = med_q1_q3(rmse_pr_samples)
    r2_prior_med,   r2_prior_q1,   r2_prior_q3   = med_q1_q3(r2_pr_samples)

    # Feed the pooled prior accumulators, exactly like the posterior and exp-vs-others
    # sections do; without this the POOLED prior entries always come out as NaN.
    # (This stores one value per prior row for every condition.)
    pooled_rmse_prior.append(rmse_pr_samples)
    pooled_r2_prior.append(r2_pr_samples)

    # ---- EXP vs other EXP
    yjs = [exp_map[c] for c in conditions if c != cond]
    if yjs:
        yjs = np.vstack(yjs).astype(float, copy=False)
        diffs_oo = yjs - y_true[None, :]
        rmse_oo_samples = np.sqrt(np.mean(diffs_oo**2, axis=1)).astype(float)
        if sst > 0:
            r2_oo_samples = (1.0 - np.sum(diffs_oo**2, axis=1) / sst).astype(float)
        else:
            r2_oo_samples = np.full(len(yjs), np.nan, dtype=float)

        rmse_oo_mu, rmse_oo_sd = float(np.nanmean(rmse_oo_samples)), float(np.nanstd(rmse_oo_samples, ddof=0))
        r2_oo_mu,   r2_oo_sd   = float(np.nanmean(r2_oo_samples)),   float(np.nanstd(r2_oo_samples,   ddof=0))
        rmse_oo_med, rmse_oo_q1, rmse_oo_q3 = med_q1_q3(rmse_oo_samples)
        r2_oo_med,   r2_oo_q1,   r2_oo_q3   = med_q1_q3(r2_oo_samples)

        pooled_rmse_oo.append(rmse_oo_samples)
        pooled_r2_oo.append(r2_oo_samples)
    else:
        rmse_oo_mu = rmse_oo_sd = np.nan
        r2_oo_mu   = r2_oo_sd   = np.nan
        rmse_oo_med = rmse_oo_q1 = rmse_oo_q3 = np.nan
        r2_oo_med   = r2_oo_q1   = r2_oo_q3   = np.nan

    # The condition key is the (muscle_pair, intensity) group key.
    pair, inten = cond
    rows.append({
        "muscle_pair": pair,
        "intensity": inten,

        # Posterior (μ±σ and med[IQR])
        "RMSE_post (μ±σ)": fmt(rmse_post_mu, rmse_post_sd),
        "RMSE_post (med [Q1–Q3])": fmt_med_iqr(rmse_post_med, rmse_post_q1, rmse_post_q3),
        "R²_post (μ±σ)"  : fmt(r2_post_mu,   r2_post_sd),
        "R²_post (med [Q1–Q3])": fmt_med_iqr(r2_post_med, r2_post_q1, r2_post_q3),

        # Best (single value + degenerate med[IQR])
        "RMSE_best": f"{rmse_best_val:.3f}",
        "RMSE_best (med [Q1–Q3])": fmt_med_iqr(rmse_best_med, rmse_best_q1, rmse_best_q3),
        "R²_best":   f"{r2_best_val:.3f}",
        "R²_best (med [Q1–Q3])": fmt_med_iqr(r2_best_med, r2_best_q1, r2_best_q3),

        # Prior (μ±σ and med[IQR])
        "RMSE_prior (μ±σ)": fmt(rmse_prior_mu, rmse_prior_sd),
        "RMSE_prior (med [Q1–Q3])": fmt_med_iqr(rmse_prior_med, rmse_prior_q1, rmse_prior_q3),
        "R²_prior (μ±σ)"  : fmt(r2_prior_mu,   r2_prior_sd),
        "R²_prior (med [Q1–Q3])": fmt_med_iqr(r2_prior_med, r2_prior_q1, r2_prior_q3),

        # EXP vs others (μ±σ and med[IQR])
        "RMSE_exp-others (μ±σ)": fmt(rmse_oo_mu, rmse_oo_sd),
        "RMSE_exp-others (med [Q1–Q3])": fmt_med_iqr(rmse_oo_med, rmse_oo_q1, rmse_oo_q3),
        "R²_exp-others (μ±σ)"  : fmt(r2_oo_mu,   r2_oo_sd),
        "R²_exp-others (med [Q1–Q3])": fmt_med_iqr(r2_oo_med, r2_oo_q1, r2_oo_q3),

        # keep raw numbers for your "AVERAGE" row
        "_rmse_post_mu": rmse_post_mu, "_rmse_post_sd": rmse_post_sd,
        "_r2_post_mu":   r2_post_mu,   "_r2_post_sd":   r2_post_sd,
        "_rmse_best":    rmse_best_val, "_r2_best":     r2_best_val,
        "_rmse_prior_mu": rmse_prior_mu, "_rmse_prior_sd": rmse_prior_sd,
        "_r2_prior_mu":   r2_prior_mu,   "_r2_prior_sd":   r2_prior_sd,
        "_rmse_oo_mu": rmse_oo_mu, "_rmse_oo_sd": rmse_oo_sd,
        "_r2_oo_mu":   r2_oo_mu,   "_r2_oo_sd":   r2_oo_sd,
    })

# One row per experimental condition, ordered by muscle pair and then intensity.
summary_table = pd.DataFrame(rows).sort_values(by=["muscle_pair", "intensity"]).reset_index(drop=True)

# ===========================
# POOLED rows
# ===========================
# Flatten each accumulator into one 1-D array; an accumulator that was never filled
# becomes an empty array.
pool_post_rmse  = pool_concat(pooled_rmse_post)
pool_post_r2    = pool_concat(pooled_r2_post)
pool_prior_rmse = pool_concat(pooled_rmse_prior)
pool_prior_r2   = pool_concat(pooled_r2_prior)
pool_oo_rmse    = pool_concat(pooled_rmse_oo)
pool_oo_r2      = pool_concat(pooled_r2_oo)
# "best" accumulators hold scalars (one per condition), so they are converted directly.
pool_best_rmse  = np.asarray(pooled_rmse_best, dtype=float) if pooled_rmse_best else np.array([], dtype=float)
pool_best_r2    = np.asarray(pooled_r2_best,   dtype=float) if pooled_r2_best   else np.array([], dtype=float)

# pooled means ± sd
# Each entry falls back to the literal "nan ± nan" when its pooled array is empty.
row_pooled = {
    "muscle_pair": "— POOLED (samples) —",
    "intensity": "",
    "RMSE_post (μ±σ)": fmt(np.nanmean(pool_post_rmse),  np.nanstd(pool_post_rmse,  ddof=0)) if pool_post_rmse.size  else "nan ± nan",
    "R²_post (μ±σ)"  : fmt(np.nanmean(pool_post_r2),    np.nanstd(pool_post_r2,    ddof=0)) if pool_post_r2.size    else "nan ± nan",
    "RMSE_best"      : fmt(np.nanmean(pool_best_rmse),  np.nanstd(pool_best_rmse,  ddof=0)) if pool_best_rmse.size  else "nan ± nan",
    "R²_best"        : fmt(np.nanmean(pool_best_r2),    np.nanstd(pool_best_r2,    ddof=0)) if pool_best_r2.size    else "nan ± nan",
    "RMSE_prior (μ±σ)": fmt(np.nanmean(pool_prior_rmse), np.nanstd(pool_prior_rmse, ddof=0)) if pool_prior_rmse.size else "nan ± nan",
    "R²_prior (μ±σ)"  : fmt(np.nanmean(pool_prior_r2),   np.nanstd(pool_prior_r2,   ddof=0)) if pool_prior_r2.size   else "nan ± nan",
    "RMSE_exp-others (μ±σ)": fmt(np.nanmean(pool_oo_rmse), np.nanstd(pool_oo_rmse, ddof=0)) if pool_oo_rmse.size else "nan ± nan",
    "R²_exp-others (μ±σ)"  : fmt(np.nanmean(pool_oo_r2),   np.nanstd(pool_oo_r2,   ddof=0)) if pool_oo_r2.size   else "nan ± nan",
}

# pooled med[IQR]
row_pooled_median = {
    "muscle_pair": "— POOLED (med [Q1–Q3]) —",
    "intensity": "",
    "RMSE_post (med [Q1–Q3])": fmt_med_iqr(*med_q1_q3(pool_post_rmse))  if pool_post_rmse.size  else "nan [nan–nan]",
    "R²_post (med [Q1–Q3])"  : fmt_med_iqr(*med_q1_q3(pool_post_r2))    if pool_post_r2.size    else "nan [nan–nan]",
    "RMSE_best (med [Q1–Q3])": fmt_med_iqr(*med_q1_q3(pool_best_rmse))  if pool_best_rmse.size  else "nan [nan–nan]",
    "R²_best (med [Q1–Q3])"  : fmt_med_iqr(*med_q1_q3(pool_best_r2))    if pool_best_r2.size    else "nan [nan–nan]",
    "RMSE_prior (med [Q1–Q3])": fmt_med_iqr(*med_q1_q3(pool_prior_rmse)) if pool_prior_rmse.size else "nan [nan–nan]",
    "R²_prior (med [Q1–Q3])"  : fmt_med_iqr(*med_q1_q3(pool_prior_r2))   if pool_prior_r2.size   else "nan [nan–nan]",
    "RMSE_exp-others (med [Q1–Q3])": fmt_med_iqr(*med_q1_q3(pool_oo_rmse)) if pool_oo_rmse.size else "nan [nan–nan]",
    "R²_exp-others (med [Q1–Q3])"  : fmt_med_iqr(*med_q1_q3(pool_oo_r2))   if pool_oo_r2.size   else "nan [nan–nan]",
}

# pooled min/max (requested to appear below pooled means ± sd)
row_pooled_minmax = {
    "muscle_pair": "— POOLED (min/max) —",
    "intensity": "",
    "RMSE_post (min,max)": fmt_min_max(np.nanmin(pool_post_rmse),  np.nanmax(pool_post_rmse))   if pool_post_rmse.size  else "nan, nan",
    "R²_post (min,max)"  : fmt_min_max(np.nanmin(pool_post_r2),    np.nanmax(pool_post_r2))     if pool_post_r2.size    else "nan, nan",
    "RMSE_best (min,max)": fmt_min_max(np.nanmin(pool_best_rmse),  np.nanmax(pool_best_rmse))   if pool_best_rmse.size  else "nan, nan",
    "R²_best (min,max)"  : fmt_min_max(np.nanmin(pool_best_r2),    np.nanmax(pool_best_r2))     if pool_best_r2.size    else "nan, nan",
    "RMSE_prior (min,max)": fmt_min_max(np.nanmin(pool_prior_rmse), np.nanmax(pool_prior_rmse)) if pool_prior_rmse.size else "nan, nan",
    "R²_prior (min,max)"  : fmt_min_max(np.nanmin(pool_prior_r2),   np.nanmax(pool_prior_r2))   if pool_prior_r2.size   else "nan, nan",
    "RMSE_exp-others (min,max)": fmt_min_max(np.nanmin(pool_oo_rmse), np.nanmax(pool_oo_rmse)) if pool_oo_rmse.size else "nan, nan",
    "R²_exp-others (min,max)"  : fmt_min_max(np.nanmin(pool_oo_r2),   np.nanmax(pool_oo_r2))   if pool_oo_r2.size   else "nan, nan",
}

summary_table = pd.concat(
    [summary_table,
     pd.DataFrame([row_pooled]),
     pd.DataFrame([row_pooled_median]),
     pd.DataFrame([row_pooled_minmax])],
    ignore_index=True
)

# ===========================
# (Optional) keep your “AVERAGE” row (avg of per-condition means/SDs)
# ===========================
# Select the per-condition rows ONCE: every aggregate row ("— POOLED … —",
# "— AVERAGE … —") has a label starting with an em dash. na=True marks
# non-string labels as aggregate rows, so they are excluded exactly as the
# former `... == False` comparison excluded them.
_is_condition_row = ~summary_table["muscle_pair"].str.startswith("—", na=True)
_per_condition = summary_table.loc[_is_condition_row]

# Posterior: average of the per-condition means and of the per-condition SDs
post_rmse_mu_avg = float(np.nanmean(_per_condition["_rmse_post_mu"]))
post_rmse_sd_avg = float(np.nanmean(_per_condition["_rmse_post_sd"]))
post_r2_mu_avg   = float(np.nanmean(_per_condition["_r2_post_mu"]))
post_r2_sd_avg   = float(np.nanmean(_per_condition["_r2_post_sd"]))

# Best: one value per condition, so report mean and (population) SD across conditions
best_rmse_mu = float(np.nanmean(_per_condition["_rmse_best"]))
best_rmse_sd = float(np.nanstd(_per_condition["_rmse_best"], ddof=0))
best_r2_mu   = float(np.nanmean(_per_condition["_r2_best"]))
best_r2_sd   = float(np.nanstd(_per_condition["_r2_best"], ddof=0))

# Prior: average of the per-condition means and of the per-condition SDs
prior_rmse_mu_avg = float(np.nanmean(_per_condition["_rmse_prior_mu"]))
prior_rmse_sd_avg = float(np.nanmean(_per_condition["_rmse_prior_sd"]))
prior_r2_mu_avg   = float(np.nanmean(_per_condition["_r2_prior_mu"]))
prior_r2_sd_avg   = float(np.nanmean(_per_condition["_r2_prior_sd"]))

# EXP vs others: average of the per-condition means and of the per-condition SDs
oo_rmse_mu_avg = float(np.nanmean(_per_condition["_rmse_oo_mu"]))
oo_rmse_sd_avg = float(np.nanmean(_per_condition["_rmse_oo_sd"]))
oo_r2_mu_avg   = float(np.nanmean(_per_condition["_r2_oo_mu"]))
oo_r2_sd_avg   = float(np.nanmean(_per_condition["_r2_oo_sd"]))

bottom_avg = {
    "muscle_pair": "— AVERAGE (per-condition means) —",
    "intensity": "",
    "RMSE_post (μ±σ)": fmt(post_rmse_mu_avg, post_rmse_sd_avg),
    "R²_post (μ±σ)"  : fmt(post_r2_mu_avg,   post_r2_sd_avg),
    "RMSE_best"      : fmt(best_rmse_mu, best_rmse_sd),
    "R²_best"        : fmt(best_r2_mu,   best_r2_sd),
    "RMSE_prior (μ±σ)": fmt(prior_rmse_mu_avg, prior_rmse_sd_avg),
    "R²_prior (μ±σ)"  : fmt(prior_r2_mu_avg,   prior_r2_sd_avg),
    "RMSE_exp-others (μ±σ)": fmt(oo_rmse_mu_avg, oo_rmse_sd_avg),
    "R²_exp-others (μ±σ)"  : fmt(oo_r2_mu_avg,  oo_r2_sd_avg),
}

summary_table = pd.concat([summary_table, pd.DataFrame([bottom_avg])], ignore_index=True)

# Drop internal raw columns (the "_"-prefixed numeric helpers used only for the AVERAGE row)
summary_table = summary_table[[c for c in summary_table.columns if not c.startswith("_")]]

# Build the output path with os.path.join instead of a hard-coded "\\" separator,
# so the export also works outside Windows.
summary_table.to_csv(os.path.join(path_to_save_into, "summary_table_rmse_r2.csv"), index=False)


In [None]:
summary_table # display the table assembled (and written to CSV) in the previous cell

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from tqdm.auto import tqdm
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D

# -------------------------------------------------
# Config
# -------------------------------------------------
# EXCLUDE_PAIRS = []  # e.g. ["VL<->VL", "GM<->SOL"] # Already defined above
# Distinct intensity values present in the experimental summary, ascending;
# the plotting loop below draws one figure per value.
INTENSITIES   = sorted(df_summary_obs_experiment["intensity"].unique().tolist())

# Appearance
KDE_GRID_RES     = 200     # number of grid points per axis (shared by the 2D KDE grid and the 1D marginals)
PAD_FRAC         = 0.1    # padding added on each side, as a fraction of the global x/y range
LINE_ALPHA_MARG  = 0.85    # opacity for 1D marginal lines
LINE_WIDTH_MARG  = 1.6     # line width for 1D marginal lines
CONTOUR_FILL_A   = 0.08    # alpha of the filled region bounded by the FIRST entry of CONTOUR_PROBS
MARKER_EXP_SIZE  = 80      # scatter size of the experimental point ('X' marker)
MARKER_HL_SIZE   = 58      # scatter size of the highest-likelihood point ('o' marker)
# Both values below are passed to scipy's gaussian_kde as a scalar `bw_method`,
# which scipy uses directly as the bandwidth factor (it is NOT a multiplier of
# Scott's rule): larger = smoother, smaller = sharper.
KDE_BW_SCALE       = 0.5   # bandwidth factor for the 2D KDE contours
MARGINAL_BW_SCALE  = 0.5   # bandwidth factor for the 1D marginal curves

# --- Contour config ---
# Probability masses enclosed by the drawn highest-density contours, in drawing order.
CONTOUR_PROBS   = (0.6, 0.3) # (0.90, 0.60, 0.30)   # e.g. (0.5,) or (0.8, 0.5, 0.2)
# NOTE: despite the "_WS" name these are line ALPHAS (one per contour), not widths;
# the tuple is broadcast to len(CONTOUR_PROBS) by _broadcast.
CONTOUR_LINE_WS = (0.3, 0.6) # (0.8, 0.55, 0.35)    # alphas for the lines; will be broadcast to len(CONTOUR_PROBS)

def _broadcast(vals, n):
    # make len(vals) == n (repeat last if shorter; trim if longer; scalars allowed)
    try:
        v = list(vals)
    except TypeError:
        v = [vals]
    if len(v) < n:
        v = v + [v[-1]] * (n - len(v))
    return v[:n]

def kde_levels_for_mass(Z, dx, dy, probs):
    """
    Density thresholds of the highest-density regions of a gridded density.

    For each p in `probs` (scalar or sequence, clipped to [0, 1]) the returned
    threshold t is such that the grid cells with Z >= t hold a probability mass
    of about p, where each cell contributes Z * dx * dy.

    Returns a 1-D float array with one threshold per entry of `probs`.

    NOTE: this function is defined a second time further down in the same cell
    (with a default for `probs` and a list return value); that later definition
    is the one in effect when the plotting loop runs.
    """
    probs = np.atleast_1d(np.clip(probs, 0.0, 1.0))
    dens_desc = np.sort(Z.ravel())[::-1]          # densities, highest first
    cum_mass = np.cumsum(dens_desc) * dx * dy     # mass enclosed down to each density
    # first position where the enclosed mass reaches p, capped at the lowest density
    hit = np.searchsorted(cum_mass, probs, side="left")
    hit = np.minimum(hit, dens_desc.size - 1)
    return dens_desc[hit].astype(float)


# Colors for muscle pairs
# Maps a "A<->B" muscle-pair label to the hex colour used for its contours,
# marginals, markers and legend entry. Both orderings of a mixed pair
# (e.g. "VL<->VM" / "VM<->VL") share one colour; pairs missing from this dict
# fall back to "#444444" at the lookup sites.
colors_dict = {
  "VL<->VL": "#D62728",
  "VL<->VM": "#FF9201",
  "VM<->VL": "#FF9201",
  "VM<->VM": "#FFC400",
  "TA<->TA": "#00C71B",
  "FDI<->FDI": "#14BFA8",
  "GM<->GM": "#2489DC",
  "GM<->SOL": "#7D74EC",
  "SOL<->GM": "#7D74EC",
  "SOL<->SOL": "#BB86ED",
}

# --- Metric helpers (full 16-D space; features are z-scored to prior) ---
def rmse(a, b):
    """Root-mean-square error between two array-likes (both cast to float), as a Python float."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    mean_sq = np.mean(diff ** 2)
    return float(np.sqrt(mean_sq))

def r2(pred, true):
    """
    Coefficient of determination of `pred` against `true`: 1 - SS_res / SS_tot.

    Returns NaN when the total sum of squares of `true` is not strictly
    positive (e.g. `true` is constant), since R² is undefined there.
    """
    y_hat = np.asarray(pred, dtype=float)
    y = np.asarray(true, dtype=float)
    residual_ss = np.sum((y_hat - y) ** 2)
    total_ss = np.sum((y - y.mean()) ** 2)
    if not (total_ss > 0):
        return np.nan
    return float(1.0 - residual_ss / total_ss)

# --- Build per-intensity dicts: dict[muscle_pair] -> np.ndarray[n_rows, 16] ---
def dict_by_pair(df, intensity, exclude_pairs=None):
    """
    Split the rows of `df` at one intensity into per-muscle-pair feature matrices.

    Rows with a missing value in any of the `feat_cols` columns are dropped.
    Pairs listed in `exclude_pairs` (if given and non-empty) are skipped.

    Returns a dict mapping each muscle-pair label to a float array of shape
    (n_rows, len(feat_cols)).
    """
    at_intensity = df[df["intensity"] == intensity].dropna(subset=feat_cols)
    return {
        pair: block[feat_cols].to_numpy(dtype=float)
        for pair, block in at_intensity.groupby("muscle_pair")
        if not (exclude_pairs and pair in exclude_pairs)
    }

# --- 2D KDE helpers with probability-mass contours ---
def kde2d_on_grid(points_2d, xgrid, ygrid):
    """
    Fit a Gaussian KDE to N x 2 points and evaluate it on a uniform grid.

    Returns (Xg, Yg, Z, dx, dy): the meshgrid coordinates, the density Z with
    shape (len(ygrid), len(xgrid)), and the grid spacings along x and y
    (computed from the end points, i.e. assuming uniform spacing).
    """
    samples = np.asarray(points_2d, dtype=float)
    # KDE_BW_SCALE is a scalar bw_method: scipy uses it directly as the
    # bandwidth factor (this is not Scott's rule).
    density = gaussian_kde(samples.T, bw_method=KDE_BW_SCALE)

    Xg, Yg = np.meshgrid(xgrid, ygrid)
    grid_points = np.vstack([Xg.ravel(), Yg.ravel()])
    Z = density(grid_points).reshape(len(ygrid), len(xgrid))

    # grid spacings
    n_x, n_y = len(xgrid), len(ygrid)
    dx = (xgrid[-1] - xgrid[0]) / (n_x - 1)
    dy = (ygrid[-1] - ygrid[0]) / (n_y - 1)
    return Xg, Yg, Z, dx, dy

def kde_levels_for_mass(Z, dx, dy, probs=(0.90, 0.60, 0.30)):
    """
    Density thresholds whose superlevel sets {Z >= t} capture the given masses.

    Each grid cell contributes Z * dx * dy of probability mass. For every p in
    `probs` the threshold is the density at which the mass accumulated from the
    densest cell downwards first reaches p (or the lowest density on the grid
    if p is never reached).

    Returns a list of thresholds in the same order as `probs`.
    """
    dens_desc = np.sort(Z.ravel())[::-1]          # densities, highest first
    cum_mass = np.cumsum(dens_desc) * dx * dy
    last = dens_desc.size - 1
    return [
        dens_desc[min(max(np.searchsorted(cum_mass, p, side="left"), 0), last)]
        for p in probs
    ]

# -------------------------------------------------
# Plot per intensity with marginals
# -------------------------------------------------
# One figure per intensity: 2D KDE contours of the posterior-predictive cloud of
# every muscle pair in the (PC1, PC2) plane, the experimental and
# highest-likelihood points, and 1D marginal KDEs along both axes.
for inten in INTENSITIES:
    # Slice dicts (drop excluded pairs)
    post_by_pair = dict_by_pair(df_summary_obs_simulated_from_posterior, inten, EXCLUDE_PAIRS)
    exp_by_pair  = dict_by_pair(df_summary_obs_experiment,                   inten, EXCLUDE_PAIRS)
    hl_by_pair   = dict_by_pair(df_summary_obs_simulated_highest_likelihood, inten, EXCLUDE_PAIRS)

    present_pairs = set(post_by_pair.keys()) | set(exp_by_pair.keys()) | set(hl_by_pair.keys())
    if not present_pairs:
        print(f"[info] intensity={inten}: nothing to plot after exclusions.")
        continue

    # ---------- Project to 2D (PC1,PC2) and gather global extents ----------
    proj_post = {}
    proj_exp  = {}
    proj_hl   = {}
    xs, ys = [], []

    for pair in present_pairs:
        if pair in post_by_pair:
            Z_post = pca.transform(post_by_pair[pair])[:, :2]
            proj_post[pair] = Z_post
            xs.append(Z_post[:, 0]); ys.append(Z_post[:, 1])
        # Only the FIRST row is used for the experimental / highest-likelihood point
        if pair in exp_by_pair and exp_by_pair[pair].shape[0] >= 1:
            z_exp = pca.transform(exp_by_pair[pair])[:, :2].reshape(-1, 2)[0]
            proj_exp[pair] = z_exp
            xs.append([z_exp[0]]); ys.append([z_exp[1]])
        if pair in hl_by_pair and hl_by_pair[pair].shape[0] >= 1:
            z_hl = pca.transform(hl_by_pair[pair])[:, :2].reshape(-1, 2)[0]
            proj_hl[pair] = z_hl
            xs.append([z_hl[0]]); ys.append([z_hl[1]])

    # Global bounds + padding shared by main & marginals
    if xs:
        xmin, xmax = np.min(np.concatenate(xs)), np.max(np.concatenate(xs))
        ymin, ymax = np.min(np.concatenate(ys)), np.max(np.concatenate(ys))
    else:
        xmin = xmax = ymin = ymax = 0.0
    dx = xmax - xmin; dy = ymax - ymin
    # Degenerate extent (single point / no data): fall back to a unit range
    if dx <= 0: dx = 1.0
    if dy <= 0: dy = 1.0
    xmin -= PAD_FRAC * dx; xmax += PAD_FRAC * dx
    ymin -= PAD_FRAC * dy; ymax += PAD_FRAC * dy

    # Shared grids
    xgrid = np.linspace(xmin, xmax, KDE_GRID_RES)
    ygrid = np.linspace(ymin, ymax, KDE_GRID_RES)

    # ---------- Figure with GridSpec: left marginal | main ; bottom marginal under main ----------
    fig = plt.figure(figsize=(8.6, 7.1))
    gs  = GridSpec(nrows=2, ncols=2, width_ratios=[1.4, 6.6], height_ratios=[6.6, 1.4],
                   wspace=0.05, hspace=0.05)

    ax_left   = fig.add_subplot(gs[0, 0])              # y-marginal (density vs y), horizontal lines
    ax_main   = fig.add_subplot(gs[0, 1])              # main 2-D KDE + points
    ax_bottom = fig.add_subplot(gs[1, 1], sharex=ax_main)  # x-marginal (density vs x)

    # empty bottom-left cell
    ax_blank = fig.add_subplot(gs[1, 0]); ax_blank.axis('off')

    # ---------- Plot 2D KDEs and points on main ----------
    rmse_means, rmse_sds, r2_means, r2_sds = [], [], [], []
    rmse_hl, r2_hl = [], []

    for pair in tqdm(sorted(present_pairs), desc=f"Intensity {inten}: 2D KDEs", leave=False):
        col = colors_dict.get(pair, "#444444")

        # The 2D KDE is drawn only for pairs with at least 5 posterior samples
        if pair in proj_post and proj_post[pair].shape[0] >= 5:
            Z_post = proj_post[pair]
            Xg, Yg, Zden, dxg, dyg = kde2d_on_grid(Z_post, xgrid, ygrid)

            levels = kde_levels_for_mass(Zden, dxg, dyg, probs=CONTOUR_PROBS)
            alphas = _broadcast(CONTOUR_LINE_WS, len(levels))

            # Fill the region enclosed by the FIRST prob in CONTOUR_PROBS.
            # contourf needs strictly increasing levels, so skip the fill when
            # the threshold coincides with the density maximum.
            if len(levels) and levels[0] < Zden.max():
                ax_main.contourf(
                    Xg, Yg, Zden,
                    levels=[levels[0], Zden.max()],
                    colors=[col],
                    alpha=CONTOUR_FILL_A
                )

            # Draw one contour line per requested probability (same order as CONTOUR_PROBS).
            # This loop must stay INSIDE the guard above: otherwise a pair without a
            # posterior cloud would redraw the previous pair's density in its own colour
            # (or hit a NameError if it is the first pair).
            for lvl, a in zip(levels, alphas):
                ax_main.contour(Xg, Yg, Zden, levels=[lvl], colors=[col], linewidths=1.1, alpha=a)

        # Experimental and highest-likelihood points
        if pair in proj_exp:
            xe, ye = proj_exp[pair]
            ax_main.scatter(xe, ye, marker='X', s=MARKER_EXP_SIZE,
                            facecolor=col, edgecolor='black', linewidths=0.9)
        if pair in proj_hl:
            xh, yh = proj_hl[pair]
            ax_main.scatter(xh, yh, marker='o', s=MARKER_HL_SIZE,
                            facecolor=col, edgecolor='black', linewidths=0.9)

        # Full-feature-space metrics vs EXP (posterior cloud): one RMSE / R² per posterior sample
        if pair in post_by_pair and pair in exp_by_pair and exp_by_pair[pair].shape[0] >= 1:
            X_post = post_by_pair[pair]
            exp_vec = exp_by_pair[pair][0, :].astype(float, copy=False)
            errs = np.sqrt(np.mean((X_post - exp_vec[None, :])**2, axis=1))
            r2s  = [r2(X_post[i, :], exp_vec) for i in range(X_post.shape[0])]
            rmse_means.append(float(np.mean(errs)))
            rmse_sds.append(float(np.std(errs, ddof=0)))
            r2_means.append(float(np.mean(r2s)))
            r2_sds.append(float(np.std(r2s, ddof=0)))

        # Full-feature-space metrics vs EXP (highest-likelihood)
        if (pair in hl_by_pair and pair in exp_by_pair and
            hl_by_pair[pair].shape[0] >= 1 and exp_by_pair[pair].shape[0] >= 1):
            hl_vec  = hl_by_pair[pair][0, :].astype(float, copy=False)
            exp_vec = exp_by_pair[pair][0, :].astype(float, copy=False)
            rmse_hl.append(rmse(hl_vec, exp_vec))
            r2_hl.append(r2(hl_vec, exp_vec))

    # Main axes cosmetics
    ax_main.set_xlim(xmin, xmax)
    ax_main.set_ylim(ymin, ymax)
    ax_main.set_xlabel("PC1")
    ax_main.set_ylabel("PC2")
    ax_main.grid(alpha=0.25)

    # ---------- Marginal 1-D KDEs ----------
    # Bottom (PC1): density vs x
    bottom_max = 0.0
    for pair in sorted(present_pairs):
        # marginals (and the matching experimental line) need >= 3 posterior samples
        if pair not in proj_post or proj_post[pair].shape[0] < 3:
            continue
        col = colors_dict.get(pair, "#444444")
        x_vals = proj_post[pair][:, 0]
        kde_x  = gaussian_kde(x_vals, bw_method=MARGINAL_BW_SCALE)
        y_dens = kde_x(xgrid)
        bottom_max = max(bottom_max, float(np.max(y_dens)))
        ax_bottom.plot(xgrid, y_dens, color=col, alpha=LINE_ALPHA_MARG, lw=LINE_WIDTH_MARG)
        # Experimental vertical line
        if pair in proj_exp:
            xe = proj_exp[pair][0]
            ax_bottom.axvline(xe, color=col, alpha=0.9, lw=1.6)

    ax_bottom.set_xlim(xmin, xmax)
    ax_bottom.set_ylim(0, bottom_max * 1.05 if bottom_max > 0 else 1.0)
    ax_bottom.set_yticks([])
    ax_bottom.grid(alpha=0.15)

    # Left (PC2): density vs y (plotted horizontally)
    left_max = 0.0
    for pair in sorted(present_pairs):
        if pair not in proj_post or proj_post[pair].shape[0] < 3:
            continue
        col = colors_dict.get(pair, "#444444")
        y_vals = proj_post[pair][:, 1]
        kde_y  = gaussian_kde(y_vals, bw_method=MARGINAL_BW_SCALE)
        x_dens = kde_y(ygrid)
        left_max = max(left_max, float(np.max(x_dens)))
        ax_left.plot(x_dens, ygrid, color=col, alpha=LINE_ALPHA_MARG, lw=LINE_WIDTH_MARG)
        # Experimental horizontal line
        if pair in proj_exp:
            ye = proj_exp[pair][1]
            ax_left.axhline(ye, color=col, alpha=0.9, lw=1.6)

    ax_left.set_ylim(ymin, ymax)
    ax_left.set_xlim(0, left_max * 1.05 if left_max > 0 else 1.0)
    ax_left.set_xticks([])
    ax_left.invert_xaxis()  # densities grow into the left margin
    ax_left.grid(alpha=0.15)

    # ---------- Title + legend ----------
    def μσ(a):
        """Mean and population SD of a sequence; (nan, nan) when it is empty."""
        return (np.mean(a), np.std(a, ddof=0)) if len(a) else (np.nan, np.nan)

    # Mean ± SD ACROSS muscle pairs of the per-pair mean metrics
    rmseμ, rmseσ = μσ(rmse_means)
    r2μ,   r2σ   = μσ(r2_means)
    rmse_hlμ, rmse_hlσ = μσ(rmse_hl)
    r2_hlμ,   r2_hlσ   = μσ(r2_hl)

    ax_main.set_title(
        f"Posterior predictive (intensity={inten})  |  "
        f"RMSE_post: {rmseμ:.2f}±{rmseσ:.2f}   R²_post: {r2μ:.2f}±{r2σ:.2f}   |   "
        f"RMSE_HL: {rmse_hlμ:.2f}   R²_HL: {r2_hlμ:.2f}"
    )

    handles = [
        Line2D([0], [0], marker='X', color='black', lw=0, markerfacecolor='white',
               label='Experiment (filled cross)', markersize=8),
        Line2D([0], [0], marker='o', color='black', lw=0, markerfacecolor='white',
               label='Highest-likelihood (circle)', markersize=7),
    ]
    for pair in sorted(present_pairs):
        handles.append(Line2D([0], [0], color=colors_dict.get(pair, "#444444"),
                              lw=6, alpha=0.9, label=pair))
    ax_main.legend(handles=handles, bbox_to_anchor=(1.02, 1.0), loc='upper left',
                   frameon=False, fontsize=8)

    plt.tight_layout()
    plt.show()
