In [None]:
import numpy as np
import pandas as pd
import torch
from torch.distributions import constraints, MultivariateNormal
import random
import os
import inspect
import pickle
import re
from sklearn.neighbors import NearestNeighbors
import pathlib
import json
import ast
from tqdm.auto import tqdm

# For plotting
import matplotlib.pyplot as plt
import seaborn as sns
import cmasher as cmr
from scipy.spatial import ConvexHull
from scipy.stats import gaussian_kde
import matplotlib.lines as mlines
from matplotlib.colors import to_rgb
import matplotlib.patches as mpatches
from matplotlib.patches import Ellipse
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy.stats import chi2

# sbi imports
from sbi import utils as sbi_utils
from sbi import inference as sbi_inference
from sbi.analysis import plot_summary, pairplot, conditional_pairplot
from sbi.inference import NPE, ImportanceSamplingPosterior
from sbi.utils import RestrictedPrior, get_density_thresholder

# for summary‐statistics
from scipy.stats import iqr as scipy_iqr

# for embeddings
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, Normalizer, normalize
from sklearn.metrics import silhouette_score
from itertools import product

# for posterior predictive checks comparisons
from scipy.stats import wilcoxon, binomtest


In [None]:
# PATHS: to be replaced with the folders in which the data is stored on your computer - the script does not need the raw output files but the summary CSVs.
simulated_data_path = "C:\\Users\\franc\\Documents\\Mapping_Recurrent_Inhibition_minimal_data_to_run_scripts\\Files_to_run_scripts\\simulated_data_SBI_training_single_muscles" # for single muscle
# simulated_data_path = "C:\\Users\\franc\\Documents\\Mapping_Recurrent_Inhibition_minimal_data_to_run_scripts\\Files_to_run_scripts\\simulated_data_SBI_training_between_muscles" # for muscle pairs
experimental_data_path = "C:\\Users\\franc\\Documents\\Mapping_Recurrent_Inhibition_minimal_data_to_run_scripts\\Files_to_run_scripts\\experimental_data"
experimental_dataframe_to_load = "experimental_data_dataframe_reorganized_and_filtered_dir_inhibited_persp_other_MUs_as_ref.csv"
# ^ Make sure to plug-in the right csv according to choice of "perspective_to_use" and "direction_to_use"

path_to_save = f"{simulated_data_path}\\$_SBI_inference"
# If a neural estimator was already trained and saved, specify the folder in which it has been saved
# # SINGLE MUSCLES CONFIG - if loading trained network and posterior estimates
path_to_load = f"C:\\Users\\franc\\Documents\\GitHub\\Mapping_Recurrent_Inhibition\\Simulation_based_inference\\saved_posterior_density_estimators\\single_muscles_density_estimator" # Load the trained neural density estimator reported in the paper (single muscle case), available in the repository
# # BETWEEN MUSCLES CONFIG - if loading trained network and posterior estimates
# path_to_load = f"C:\\Users\\franc\\Documents\\GitHub\\Mapping_Recurrent_Inhibition\\Simulation_based_inference\\saved_posterior_density_estimators\\paired_muscles_density_estimator" # Load the trained neural density estimator reported in the paper (single muscle case), available in the repository
# NOTE that an existing density estimator may perform very poorly if it has been trained on a set of training data that doesn't reflect the testing data
rerun_network_training_and_sampling = False # False # True # If false, will load the posterior-infering network pickle file instead, and the posterior estimates (samples) as a pickle file too
rerun_network_training_for_heldout_data = False # False # True # If false, will load the posterior-infering network (trained on 90% of the full training dataset) pickle file instead of training a new one. 
# When loading, the held-out network is read from the pickle file 'sbi_check_heldout_sims.pkl' inside the 'path_to_load' folder.
# ^ i.e. if False, that pickle file is looked up in 'path_to_load'
# The network training is fast but if the sampling process for the posterior estimate is very long and appears to run indefinitely,
# it means that some experimental observations are out-of-distribution relative to the training data (sbi usually returns a warning)

# ### #
# /!\ POSTERIOR PREDICTIVE CHECKS PARAMETERS (e.g., number of parameter sets to sample from posterior and to simulate from) ARE AT THE END OF THE NOTEBOOK
# ### #

# Resolve the output folder: a fresh, numbered sub-folder when retraining,
# otherwise the folder the saved estimator is loaded from.
if rerun_network_training_and_sampling:
    if not os.path.exists(path_to_save):
        os.mkdir(path_to_save)
    # Earlier runs are stored as 'inference_model_<n>'; continue that numbering.
    run_folder_matches = [re.fullmatch(r'inference_model_(\d+)', entry)
                          for entry in os.listdir(f"{path_to_save}")]
    used_indices = [int(found.group(1)) for found in run_folder_matches if found]
    inference_idx = max(used_indices) + 1 if used_indices else 0
    path_to_save = f"{path_to_save}\\inference_model_{inference_idx}"
    os.makedirs(path_to_save, exist_ok=True)
elif not os.path.exists(path_to_load):
    # Nothing to retrain and nothing to load from: the configuration is unusable.
    raise ValueError(f"Please specify a valid, existing 'path_to_load'")
else:
    # Loading a saved estimator: outputs are written next to it.
    path_to_save = path_to_load

# ---- User-specified hyperparameters ----
perspective_to_use = 'other_MUs_as_ref'  # choose from {'other_MUs_as_ref','MU_as_ref','combined','most_spikes'}
direction_to_use = 'inhibited'  # choose from {'inhibited','inhibiting'}
raw_or_corrected = 'raw'  # choose from {'raw', 'corrected'}
normalize_scale_of_features_for_inference = True

# Map direction -> name of the connectivity-estimate column (the column name
# differs between the two directions; the raw/corrected suffix is appended).
_connectivity_prefix_by_direction = {
    'inhibited': 'inhibited_by_estimation',
    'inhibiting': 'inhibiting_estimation',
}
if direction_to_use not in _connectivity_prefix_by_direction:
    raise ValueError("direction_to_use must be 'inhibited' or 'inhibiting'")
inhib_connectivity_colname = (
    f'{_connectivity_prefix_by_direction[direction_to_use]}_{raw_or_corrected}'
)
asym_colname = f'asymmetry_diff_{raw_or_corrected}'

# Filtering
min_r2_for_baseline_curve_fit = {"simulation":0.1,
                                 "experiment": 0.1}
min_r2_for_overall_curve_fit = {"simulation":0.75,
                                 "experiment": 0.75}
min_nb_spikes = {"simulation":10_000,
                "experiment": 5_000}

# Simulation input parameters
version_of_sim_with_only_single_pool_input_vals = False # if True, older simulation version
within_or_between_pools_sbi = 'within' # choose from {'within', 'between'}
nb_pools = 1 # 1 for SINGLE MUSCLE CONFIG (1 pool simulated for training dataset) # 2 for BETWEEN MUSCLES CONFIG (2 pools simulated for training dataset)
# ^ If within, will consider only a single pool and use excitatory_input_baseline[0], disynpatic_inhib_connections_desired_MN_MN[0,0] as well as frequency_range_of_common_input[0]
# ^ If between, will consider that there are two pools and will perform sbi only on the following parameters:
#       - excitatory_input_baseline[0] and excitatory_input_baseline[1]
#       - disynpatic_inhib_connections_desired_MN_MN[0,1] and disynpatic_inhib_connections_desired_MN_MN[1,0]
#       - between_pool_excitatory_input_correlation
# !!! Please adapt input_sim_parameters !!! #
if within_or_between_pools_sbi == 'within':
    use_only_same_muscle_pair = True
    use_only_different_muscle_pair = False
    filter_intensity = [10, 40]
    filter_muscles = [] # Leaving empty keeps all muscles
elif within_or_between_pools_sbi == 'between':
    use_only_different_muscle_pair = True
    use_only_same_muscle_pair = False
    # The between-pool simulations have their priors coming from the posterior obtained from the within-pool inference from the muscle pair and from the intensity of interest.
    # So the mapping of the experimental data should apply only to the muscles and intensity from which the priors have been taken from
    filter_intensity = [10, 40]
    filter_muscles = [] # e.g. ["SOL","GM"] or ["VM","VL"]: will consider only those muscles for experimental data
else:
    # Fail here rather than with a NameError on the undefined filter flags further down.
    raise ValueError("within_or_between_pools_sbi must be 'within' or 'between'")

# Input param to be inferred #########################################
# # BETWEEN MUSCLES CONFIG 
# input_sim_parameters_to_infer = ["excitatory_input_baseline_self",
#     "disynpatic_inhib_connections_desired_MN_MN_other_pool",
#     "between_pool_excitatory_input_correlation"]
# # SINGLE MUSCLES CONFIG
input_sim_parameters_to_infer = ["excitatory_input_baseline",
    "disynpatic_inhib_connections_desired_MN_MN",
    "common_input_high_freq_middle_of_range", # "common_input_characteristics.Frequency_middle_of_range.pool_0.input_1"
    "common_input_high_freq_half_width_range", # "common_input_characteristics.Frequency_half_width_of_range.pool_0.input_1"
    "common_input_std"] # "common_input_std[0][1]"

# Input parameters not fixed, but still NOT to be inferred #########################################
# coming from previously-inferred posterior and used as inference features for training (ground-truth) and inference on experimental data (repeated sampling of the single-muscle posterior)
# # BETWEEN MUSCLES CONFIG 
# input_sim_parameters_as_features = [ # Each row comes from a single 'MN to pool', so always entierly defined by "self" VS "other"
#     "disynpatic_inhib_connections_desired_MN_MN_self",
#     "common_input_high_freq_middle_of_range_self",
#     "common_input_high_freq_half_width_range_self",
#     "common_input_std_self"
# ]
# # SINGLE MUSCLES CONFIG
input_sim_parameters_as_features = []

# Below: only used if len(input_sim_parameters_as_features) >= 1
previously_estimated_posterior_samples_for_experimental_data_SBI = 100 # duplicate each experimental data by this value, and assign to each duplicated row a random sample from the posterior
previously_estimated_posterior_results_path = "C:\\Users\\franc\\Documents\\GitHub\\SBI_motor_neuron_behavior\\$$$_Simulation_batch_single_muscle\\$_SBI_inference\\inference_model_0"
previously_estimated_posterior_each_subject_csv = "posterior_samples_each_subject_df.csv"
previously_estimated_posterior_subjects_grouped_csv = "posterior_samples_subjects_grouped_df.csv"
map_strings_of_posterior_estimated_parameters_to_param_used_as_features = {
    "disynpatic_inhib_connections_desired_MN_MN": "disynpatic_inhib_connections_desired_MN_MN_self",
    "common_input_high_freq_middle_of_range": "common_input_high_freq_middle_of_range_self",
    "common_input_high_freq_half_width_range": "common_input_high_freq_half_width_range_self",
    "common_input_std": "common_input_std_self"
}
# ^ The keys need to be the same as in input_sim_parameters_as_features


# The param names below need to exist in 'input_sim_parameters_to_infer' - they are used for plotting
# # BETWEEN MUSCLES CONFIG 
# specific_input_parameters_of_interest = ["disynpatic_inhib_connections_desired_MN_MN_other_pool",
#     "between_pool_excitatory_input_correlation"] # "common_input_std[0][1]"
# # SINGLE MUSCLES CONFIG
specific_input_parameters_of_interest = ["disynpatic_inhib_connections_desired_MN_MN",
    "common_input_high_freq_middle_of_range", # "common_input_characteristics.Frequency_middle_of_range.pool_0.input_1"
    "common_input_high_freq_half_width_range", # "common_input_characteristics.Frequency_half_width_of_range.pool_0.input_1"
    "common_input_std"] # "common_input_std[0][1]"
specific_input_parameters_of_interest_corresponding_indices = [input_sim_parameters_to_infer.index(p) 
    for p in specific_input_parameters_of_interest]
# ^ Position of each entry of specific_input_parameters_of_interest within input_sim_parameters_to_infer (index correspondence between the two lists)

# Change the pool(s) (and thus the strings) when several pools are considered
mapping_from_common_input_characteristics_to_input_param_names = {
    "Frequency_middle_of_range.pool_0.input_1": "common_input_high_freq_middle_of_range",
    "Frequency_half_width_of_range.pool_0.input_1": "common_input_high_freq_half_width_range",
}

# PRIORS
# # SINGLE MUSCLES CONFIG
priors_per_parameters_to_infer = { # "param_name": [low, high]
    "excitatory_input_baseline": [20*1e3, 70*1e3], 
    "disynpatic_inhib_connections_desired_MN_MN": [0, 3],
    "common_input_high_freq_middle_of_range": [2.5, 75],
    "common_input_high_freq_half_width_range": [2.5, 25],
    "common_input_std": [0, 7.0*1e3]
}
# # BETWEEN MUSCLES CONFIG (pairs)
# priors_per_parameters_to_infer = { # "param_name": [low, high]
#     "excitatory_input_baseline_self": [20*1e3, 90*1e3], # + 20 relative to the single muscle case with only self-inhibition
#     "disynpatic_inhib_connections_desired_MN_MN_other_pool": [0, 3],
#     "between_pool_excitatory_input_correlation": [0, 1]
# }

# Define summary-statistics functions to be used here
summary_funcs = {
    'mean'  : np.nanmean,
    'median': lambda arr: np.nanmedian(arr),
    'sd'    : np.nanstd,
    'iqr'   : lambda arr: scipy_iqr(arr, nan_policy='omit')
}

# Features to use for posterior inference
features_for_inference = [inhib_connectivity_colname,
            'sync_height',
            'IPSP_timing_of_trough',
            'Firing_rates_mean']
# The feature names below need to exist in 'features_for_inference' - they are used for plotting
specific_features_of_interest = [inhib_connectivity_colname,
            'sync_height']
# In the between-muscles case, also using 'input_sim_parameters_as_features' (ground-truth for training, samples from posterior for inference)
features_for_inference += input_sim_parameters_as_features

## Remapping colnames from simulation results
# Which simulated IPSP-delay column plays the role of 'IPSP_timing_of_trough'
# depends on both the direction and the perspective. Combinations not listed
# here leave the mapping empty.
colname_renames = {}
if (direction_to_use == 'inhibited') and (perspective_to_use == 'other_MUs_as_ref'):
    colname_renames = {"delay_forward_IPSP": "IPSP_timing_of_trough"}
elif (direction_to_use == 'inhibiting') and (perspective_to_use == 'other_MUs_as_ref'):
    # This branch previously repeated the 'inhibited' condition above and could never run.
    # NOTE(review): 'inhibiting' is inferred from the symmetry of the other three branches - confirm.
    colname_renames = {"delay_backward_IPSP": "IPSP_timing_of_trough"}
elif (direction_to_use == 'inhibited') and (perspective_to_use == 'MU_as_ref'):
    colname_renames = {"delay_backward_IPSP": "IPSP_timing_of_trough"}
elif (direction_to_use == 'inhibiting') and (perspective_to_use == 'MU_as_ref'):
    colname_renames = {"delay_forward_IPSP": "IPSP_timing_of_trough"}

### For SBI
sbi_density_estimator = "maf" # "maf" (masked autoregressive flow) is default; "mdn" (mixture density network) avoid leakages at the priors' boundaries
num_posterior_samples = {"simulation":10_000,
                         "experiment": 10_000,
                         "experiment_with_posterior_estimates_as_features": 100} # 1_000} # Note that the total number of samples per condition will be multiplied by 'previously_estimated_posterior_samples_for_experimental_data_SBI'
best_posterior_estimate_method = "logp" # "logp" or "knn". The 'correct' way is logp and it is the most accurate, but it can be extremely long if the model cannot appropriately reproduce the epxerimental data - otherwise it's very fast
# Training settings for the neural density estimator.
# NOTE(review): presumably forwarded to the sbi trainer's train() call further down - confirm.
network_training_hyperparameters = {
    "num_atoms": 10, # default is 10
    "training_batch_size": 200, # default is 200
    "learning_rate": 0.0005, # default is 0.0005
    "validation_fraction": 0.1, # default is 10%
    "max_num_epochs": 2000, # Upper bound: stop training after at most this many epochs
    "stop_after_epochs": 20 # Early-stopping patience in sbi: stop once validation performance has not improved for this many epochs
}
# # Check tutorials online:
# # https://sbi-dev.github.io/sbi/v0.24.0/tutorials/03_density_estimators/
# # https://sbi-dev.github.io/sbi/v0.24.0/tutorials/09_sampler_interface/
# # https://sbi-dev.github.io/sbi/v0.24.0/tutorials/15_importance_sampled_posteriors/
# # https://sbi-dev.github.io/sbi/v0.24.0/tutorials/17_plotting_functionality/
# # https://sbi-dev.github.io/sbi/0.22/tutorial/07_conditional_distributions/ # https://sbi-dev.github.io/sbi/v0.24.0/tutorials/05_conditional_distributions/ 
# sampling_algorithm = "vi" # "direct" (default) ; "mcmc" (Markov Chain Monte Carlo) ; "vi" (variational inference) ; "rejection" ; "is" (importance sampling)
# # a sampling method may be necessary if sampling_algorithm == 'mcmc', 'vi' or 'is'
# # ValueError: You passed `sample_with='rejection' but you did not specify a `proposal` in `rejection_sampling_parameters`. Until sbi v0.22.0, this was interpreted as directly sampling from the posterior. As of sbi v0.23.0, you instead have to use `sample_with='direct'` to do so.

# Mapping from simulation colname to experiment colnames
colname_map = {
    "ISI_cov": "cov_mean",
    "Firing_rates_mean":"firing_rates_mean"
} # The rest of the colnames should have the correct mapping already

muscle_colors_dict = {
    "FDI<->FDI": "#00D6D6",
    "TA<->TA":   "#2CA02C",
    "VL<->VL":   "#D62728",
    "VL<->VM":   "#D65927",
    "VM<->VM":   "#FFC400",
    "VM<->VL":   "#FFA600",
    "GM<->GM":   "#339DFF",
    "GM<->SOL":   "#339DFF",
    "SOL<->SOL": "#9467BD",
    "SOL<->GM": "#9467BD"
}
muscle_colormaps_dict = {
    "FDI<->FDI":    cmr.cosmic,
    "TA<->TA":      cmr.nuclear,
    "VL<->VL":      cmr.ember,
    "VL<->VM":      cmr.ember,
    "VM<->VM":      cmr.amber,
    "VM<->VL":      cmr.amber,
    "GM<->GM":      cmr.bubblegum,
    "GM<->SOL":     cmr.bubblegum,
    "SOL<->SOL":    cmr.freeze,
    "SOL<->GM":     cmr.freeze    
}
subjects_color_dict = {
    'DeFr':"#FFC400",
    'HuFr':"#0091F8",
    'KaPa':"#2CA02C",
    'LeCl':"#D62728",
    'LoTi':"#9467BD",
    'MeJu':"#00D6D6"
}

param_colors_dict = { # Should be the same paramas as in input_sim_parameters
    "disynpatic_inhib_connections_desired_MN_MN": "#1F77B4",
    "common_input_std": "#D62728"
}

# ---- Helper to help with Numpy and PyTorch conversions
def to_np(x):
    """Return *x* as a float NumPy array.

    Tensors are detached, moved to the CPU and converted through plain Python
    lists, so ``tensor.numpy()`` is never called. Anything else is handed
    straight to ``np.asarray``.
    """
    if not isinstance(x, torch.Tensor):
        return np.asarray(x, dtype=float)
    nested_values = x.detach().cpu().tolist()
    return np.asarray(nested_values, dtype=float)

In [None]:
# Persist the run configuration as JSON next to the trained model.
# Only done for a new inference run (a loaded run already has its file).
if rerun_network_training_and_sampling:
    # Functions cannot be written to JSON, so each summary function is stored
    # by its __name__ ("<lambda>" for lambdas, and as fallback when absent).
    serializable_summary_funcs = {
        stat_name: getattr(stat_func, "__name__", "<lambda>")
        for stat_name, stat_func in summary_funcs.items()
    }

    # Everything that defines this run, in one dict:
    cfg = {
        "perspective_to_use":                perspective_to_use,
        "direction_to_use":                  direction_to_use,
        "min_r2_for_baseline_curve_fit":     min_r2_for_baseline_curve_fit,
        "min_r2_for_overall_curve_fit":      min_r2_for_overall_curve_fit,
        "min_nb_spikes":                     min_nb_spikes,
        "within_or_between_pools_sbi":       within_or_between_pools_sbi,
        "use_only_same_muscle_pair":         use_only_same_muscle_pair,
        'use_only_different_muscle_pair':    use_only_different_muscle_pair,
        'filter_intensity':                  filter_intensity,
        'filter_muscles':                    filter_muscles,
        "input_sim_parameters_to_infer":     input_sim_parameters_to_infer,
        "summary_funcs":                     serializable_summary_funcs,
        "features_for_inference":            features_for_inference,
        "input_sim_parameters_as_features": input_sim_parameters_as_features,
        "specific_features_of_interest":     specific_features_of_interest,
        "sbi_density_estimator":             sbi_density_estimator,
        "num_posterior_samples":             num_posterior_samples,
        "version_of_sim_with_only_single_pool_input_vals": version_of_sim_with_only_single_pool_input_vals
    }

    hyperparameters_json_path = f"{path_to_save}\\_model_hyperparameters.json"
    with open(hyperparameters_json_path, "w") as f:
        json.dump(cfg, f, indent=2)

In [None]:
# Load data frames - simulation
# Reads the summary CSV '___general_analysis_of_simulations.csv' from 'simulated_data_path'
# (the path is built with a Windows-style backslash separator).
df_simulation = pd.read_csv(f"{simulated_data_path}\\___general_analysis_of_simulations.csv")

In [None]:
# Load data frames - experiment
# Reads the CSV named by 'experimental_dataframe_to_load' from 'experimental_data_path'
# (the path is built with a Windows-style backslash separator).
df_experiment = pd.read_csv(f"{experimental_data_path}\\{experimental_dataframe_to_load}")

In [None]:
# Filter data frame
# 1) Quality thresholds: keep rows whose baseline fit, overall fit and spike
#    count are strictly above the minima configured for each data source.
df_simulation = df_simulation[df_simulation['r2_base']>min_r2_for_baseline_curve_fit['simulation']]
df_simulation = df_simulation[df_simulation['r2_full']>min_r2_for_overall_curve_fit['simulation']]
df_simulation = df_simulation[df_simulation['n_spikes']>min_nb_spikes['simulation']]
df_experiment = df_experiment[df_experiment['r2_base']>min_r2_for_baseline_curve_fit['experiment']]
df_experiment = df_experiment[df_experiment['r2_full']>min_r2_for_overall_curve_fit['experiment']]
df_experiment = df_experiment[df_experiment['n_spikes']>min_nb_spikes['experiment']]

# 2) Optional muscle / intensity selection (experimental data only).
if len(filter_muscles) > 0:
    df_experiment = df_experiment[df_experiment['muscle_of_MU'].isin(filter_muscles)]
if len(filter_intensity) > 0:
    df_experiment = df_experiment[df_experiment['intensity'].isin(filter_intensity)]


def _pair_sides(pair_labels):
    """Split labels such as 'SOL<->GM' into their left and right parts (two Series)."""
    parts = pair_labels.str.split("<->")
    return parts.str[0], parts.str[1]


# 3) Same-pair / different-pair selection. The sides are computed from the
#    frame being indexed, AFTER the filters above, so every boolean mask has
#    exactly the index of the frame it selects from (no implicit reindexing).
left, right = _pair_sides(df_experiment["muscle_pair"])
left_sim, right_sim = _pair_sides(df_simulation["pool_pair"])
if use_only_same_muscle_pair:
    df_experiment = df_experiment[left == right].copy()
if use_only_different_muscle_pair:
    # Recomputed in case the same-pair selection above already shrank the frame.
    left, right = _pair_sides(df_experiment["muscle_pair"])
    df_experiment = df_experiment[left != right].copy()
    df_simulation = df_simulation[left_sim != right_sim].copy()

# Rename columns to get matching names
# colname_map is "simulation -> experiment"; invert it so we go "experiment -> simulation":
inv_map = { exp_name: sim_name for sim_name, exp_name in colname_map.items() }
# Now rename only those columns in df_experiment:
df_experiment = df_experiment.rename(columns=inv_map)


In [None]:
# Add some new columns to the simulated dataframe to "explode" array parameters into the relevant single float values (one per row)
# new_cols = [
#     "excitatory_input_baseline_self",
#     "disynpatic_inhib_connections_desired_MN_MN_self",
#     "disynpatic_inhib_connections_desired_MN_MN_other_pool",
#     "between_pool_excitatory_input_correlation",
#     "common_input_high_freq_middle_of_range_self",
#     "common_input_high_freq_half_width_range_self",
#     "common_input_std_self"
# ]

# Just for testing (making the data frame much smaller)
# df_simulation_test = df_simulation.copy()
# df_simulation_test = df_simulation_test.iloc[0:100]
# df_simulation_test

# ---------- robust parsers ----------
# One decimal number, optionally signed and with an exponent. A bare trailing
# dot is accepted ("20000.", "1.e+03") because that is how NumPy prints
# floats; without it "1.e+03" would be read as the two numbers "1" and "+03".
_num_re = re.compile(r'[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?')


def _parse_fixed_shape(s, shape):
    """Shared implementation: return a float array of *shape* read from *s*, else None.

    *s* may be an ndarray (accepted when it holds the right number of
    elements) or anything whose ``str()`` contains exactly that many numbers.
    Missing scalars (None, NaN, pd.NA) and any size mismatch give None.
    """
    expected_count = int(np.prod(shape))
    if isinstance(s, np.ndarray):
        if s.size != expected_count:
            return None
        return s.astype(float).reshape(shape)
    # Only test scalars for missingness: pd.isna on a list/array returns an
    # array whose truth value is ambiguous.
    if pd.api.types.is_scalar(s) and pd.isna(s):
        return None
    nums = _num_re.findall(str(s))
    if len(nums) != expected_count:
        return None
    return np.array([float(v) for v in nums], dtype=float).reshape(shape)


def parse_vec_len2(s):
    """Return np.ndarray shape (2,) from a string like '[x y]' (no commas), else None."""
    return _parse_fixed_shape(s, (2,))


def parse_mat_2x2(s):
    """Return np.ndarray shape (2,2) from a string like '[[a b]\\n [c d]]', else None."""
    return _parse_fixed_shape(s, (2, 2))

def parse_dict_str(s):
    """Parse a dict written as a Python literal (single quotes allowed) into a real dict.

    Returns *s* itself when it already is a dict, and None when it is a
    missing scalar (None/NaN), cannot be parsed as a literal, or parses to
    something other than a dict - so callers can rely on getting a dict or None.
    """
    if isinstance(s, dict):
        return s
    # Only test scalars for missingness: pd.isna on array-likes is ambiguous.
    if pd.api.types.is_scalar(s) and pd.isna(s):
        return None
    try:
        parsed = ast.literal_eval(str(s))
    except Exception:
        # Deliberately best-effort: any unparsable cell is treated as missing.
        return None
    return parsed if isinstance(parsed, dict) else None

# --- per-row extractor (row is a Series) ---

def extract_new_fields_row(row, direction_to_use='inhibited'):
    """Pull the per-pool scalar parameters out of one simulation row.

    Parameters
    ----------
    row : pd.Series
        One row of the simulation frame. ``row['pool']`` selects the "self"
        pool; the other pool is ``1 - pool`` (pools are assumed to be {0, 1}).
    direction_to_use : str
        'inhibited' reads the inhibition matrix at [other, self] (other pool
        -> self, receiving inhibition); any other value reads [self, other]
        (self -> other pool, delivering inhibition).

    Returns
    -------
    pd.Series
        Seven values keyed by the new column names. A value is NaN whenever
        its source cell could not be parsed.
    """
    self_idx = int(row['pool'])
    other_idx = 1 - self_idx

    # 1) Baseline excitatory input: length-2 vector indexed by pool.
    baseline_vec = parse_vec_len2(row.get('excitatory_input_baseline'))
    if baseline_vec is None:
        baseline_self = np.nan
    else:
        baseline_self = float(baseline_vec[self_idx])

    # 2) Disynaptic inhibition: 2x2 matrix indexed [from_pool, to_pool].
    inhib_matrix = parse_mat_2x2(row.get('disynpatic_inhib_connections_desired_MN_MN'))
    if inhib_matrix is None:
        inhib_self = np.nan
        inhib_other = np.nan
    else:
        inhib_self = float(inhib_matrix[self_idx, self_idx])
        if direction_to_use == 'inhibited':
            inhib_other = float(inhib_matrix[other_idx, self_idx])
        else:
            inhib_other = float(inhib_matrix[self_idx, other_idx])

    # 3) Common-input characteristics: nested dict section -> pool -> input.
    characteristics = parse_dict_str(row.get('common_input_characteristics'))
    if not characteristics:
        characteristics = {}
    pool_key = f'pool_{self_idx}'

    def _input_1_value(section):
        # NaN as soon as any level of the nesting is absent.
        return characteristics.get(section, {}).get(pool_key, {}).get('input_1', np.nan)

    freq_middle = _input_1_value('Frequency_middle_of_range')
    freq_half_width = _input_1_value('Frequency_half_width_of_range')

    # 4) Common-input std: 2x2 matrix [pool, input index]; column 1 of the self pool.
    std_matrix = parse_mat_2x2(row.get('common_input_std'))
    if std_matrix is None:
        std_self = np.nan
    else:
        std_self = float(std_matrix[self_idx, 1])

    # 5) Scalar copied through unchanged.
    pool_correlation = row.get('between_pool_excitatory_input_correlation', np.nan)

    extracted = {
        "excitatory_input_baseline_self": baseline_self,
        "disynpatic_inhib_connections_desired_MN_MN_self": inhib_self,
        "disynpatic_inhib_connections_desired_MN_MN_other_pool": inhib_other,
        "between_pool_excitatory_input_correlation": pool_correlation,
        "common_input_high_freq_middle_of_range_self": freq_middle,
        "common_input_high_freq_half_width_range_self": freq_half_width,
        "common_input_std_self": std_self,
    }
    return pd.Series(extracted)

# --- run with a progress bar on your (possibly filtered) frame ---

# pick the frame to process (your 0:100 slice in tests)
# Work on a copy of the frame whose 'pool' column is cast to int
df_sim = df_simulation.assign(pool=df_simulation['pool'].astype(int))


def _extract_for_configured_direction(sim_row):
    # Row-wise adapter: fixes the direction to the notebook-level setting.
    return extract_new_fields_row(sim_row, direction_to_use=direction_to_use)


# Register DataFrame.progress_apply, then build one row of new fields per simulation row
tqdm.pandas(desc="Extracting fields", mininterval=0.5)
new_cols = df_sim.progress_apply(
    _extract_for_configured_direction,
    axis=1,
    result_type='expand',
)

# Append the new columns to the original frame (both share the same index)
df_simulation = pd.concat([df_simulation, new_cols], axis=1)



In [None]:
df_simulation # checking the result of the previous cell

In [None]:
### Simulation data frame
# Parse the "common_input_characteristics" column here directly, not later - because it may relate to several free parameters (instead of 1 like the other string parameters which need to be parsed)
if "common_input_characteristics" in df_simulation.columns:
    # 1) Parse each string into an actual dict
    dicts = df_simulation['common_input_characteristics'].apply(ast.literal_eval)
    # 2) Use json_normalize to flatten, joining nested keys with dots.
    #    A plain list is passed so the result is always numbered 0..n-1 in row order.
    flat = pd.json_normalize(dicts.tolist(), sep='.')
    # 3) Give the flattened rows the labels of the rows they came from. df_simulation
    #    was filtered earlier, so its index has gaps; concatenating a 0..n-1 frame
    #    onto it would align on labels and attach values to the wrong rows
    #    (plus NaN-filled extra rows).
    flat.index = df_simulation.index
    cols_to_take = [col for col in mapping_from_common_input_characteristics_to_input_param_names if col in flat.columns]
    selected_cols = flat[cols_to_take].rename(columns=mapping_from_common_input_characteristics_to_input_param_names)
    df_simulation = pd.concat([df_simulation, selected_cols], axis=1)
###
# Add the renamed IPSP delay/timing column (a copy; the source column is kept)
for col_previous_name, col_to_add in colname_renames.items():
    if col_previous_name in df_simulation.columns:
        df_simulation[col_to_add] = df_simulation[col_previous_name]
# Make sure 'sim_name', 'perspective', etc. exist for the simulation data frame
required_columns = ['sim_name',
    'perspective'] + input_sim_parameters_to_infer + features_for_inference + input_sim_parameters_as_features
for col in required_columns:
    if col not in df_simulation.columns:
        raise KeyError(f"Required column '{col}' not found in CSV.")

# Filter to just the chosen perspective
# for simulation
df_simulation = df_simulation[df_simulation['perspective'] == perspective_to_use].copy()
# for experiment = a bit trickier when using different muscles as pairs, because of the choice to use the "pair direction" that maximizes the number of spikes (for example VM<->VL instead of both VM<->VL and VL<->VM)
df_experiment = df_experiment[df_experiment['perspective'] == perspective_to_use].copy()

# Choose the direction for experimental data (only experimental data, because stored as separate columns for the simulated data)
# .copy() so that the column assignments below write to an independent frame, not to a filtered view
df_experiment = df_experiment[df_experiment['direction']==direction_to_use].copy()
if raw_or_corrected == 'raw':
    df_experiment[inhib_connectivity_colname] = df_experiment['raw_area']
elif raw_or_corrected == 'corrected':
    df_experiment[inhib_connectivity_colname] = df_experiment['corrected_area']

# Drop rows with a missing connectivity estimate or a missing sync_height:
df_simulation = df_simulation[~df_simulation[inhib_connectivity_colname].isna() & ~df_simulation['sync_height'].isna()].copy()
df_experiment = df_experiment[~df_experiment[inhib_connectivity_colname].isna() & ~df_experiment['sync_height'].isna()].copy()

# ── SCALE THE COLUMNS BEFORE COMPUTING SUMMARIES ──
df_simulation[inhib_connectivity_colname] = df_simulation[inhib_connectivity_colname] * (-100)  # make inhibition positive %
df_experiment[inhib_connectivity_colname] = df_experiment[inhib_connectivity_colname] * (-100)  # make inhibition positive %

df_simulation['sync_height']  = df_simulation['sync_height']  * 100     # make sync_height in % units
df_experiment['sync_height']  = df_experiment['sync_height']  * 100     # make sync_height in % units

df_simulation['IPSP_timing_of_trough'] = df_simulation['IPSP_timing_of_trough'] * 1000 # convert to ms
df_experiment['IPSP_timing_of_trough'] = df_experiment['IPSP_timing_of_trough'] * 1000 # convert to ms

# NOTE(review): the asymmetry column is named differently in the two frames
# ('asymmetry_estimation_diff_*' for simulation, 'asymmetry_diff_*' for experiment) - confirm against the CSVs.
df_simulation[f'asymmetry_estimation_diff_{raw_or_corrected}'] = df_simulation[f'asymmetry_estimation_diff_{raw_or_corrected}'] * 100
df_experiment[f'asymmetry_diff_{raw_or_corrected}'] = df_experiment[f'asymmetry_diff_{raw_or_corrected}'] / 100 # seems to be already multiplied by 100 twice in this data frame, so dividing by 100


In [None]:
df_simulation # checking the result of the previous cell

In [None]:
df_experiment # checking the result of the previous cell

In [None]:
# 1) A helper to parse ANY JSON‐style array string into a np.array
def parse_array(s: str, shape):
    """Read every number in an array-like string and reshape it to *shape*.

    Brackets and commas are treated as whitespace, the remaining tokens are
    converted to floats, and the result is reshaped. Non-string input
    (numbers, lists, arrays, None/NaN) is converted with ``np.asarray``.

    Returns an array of NaNs with the requested *shape* whenever the input
    cannot be converted or does not hold the right number of values, so a
    missing or malformed cell never raises.
    """
    if isinstance(s, str):
        tokens = re.sub(r"[\[\],]", " ", s).split()
        try:
            flat = np.array(tokens, dtype=float)
        except ValueError:
            # At least one token is not a number.
            return np.full(shape, np.nan)
    else:
        try:
            flat = np.asarray(s, dtype=float).ravel()
        except (TypeError, ValueError):
            return np.full(shape, np.nan)
    try:
        return flat.reshape(shape)
    except ValueError:
        # Wrong number of values for the requested shape.
        return np.full(shape, np.nan)

# 'df' is an alias, not a copy: every column added below is added to df_simulation itself
df = df_simulation  # your DataFrame

# 2)  parse the raw array columns with the shapes we expect
#     (each '<name>_array' column holds one NumPy array per row; cells that do not
#     fit the expected shape are handled by parse_array)
nb_of_common_inputs_in_simulations = 2 # There is actually a unique common input, but there are two "spaces" for common input in the parameters
if version_of_sim_with_only_single_pool_input_vals:
    # Older simulation version: the baseline cell is passed to np.full as a single fill value, repeated nb_pools times
    df['excitatory_input_baseline_array'] = (
        df['excitatory_input_baseline']
          .map(lambda v: np.full(nb_pools, v, dtype=float))
    )
    df['common_input_std_array'] = (
        df['common_input_std']
        .map(lambda s: parse_array(s, (1, nb_of_common_inputs_in_simulations)))
    )
    df['frequency_range_of_common_input_array'] = (
        df['frequency_range_of_common_input']
        .map(lambda s: parse_array(s, (1, nb_of_common_inputs_in_simulations, 2)))
    )
else: # New simulation version, with a new dim per input
    # Here I am hard coding "2" as the number of pools nb_pools (even if only one pool was simulated - this is because of how the simulation parameters are organized)
    # NOTE(review): the baseline below is still parsed with shape (nb_pools,), not (2,), unlike the
    # two columns after it - confirm that the stored baseline really has nb_pools entries.
    df['excitatory_input_baseline_array'] = (
        df['excitatory_input_baseline']
        .map(lambda s: parse_array(s, (nb_pools,))) # .map(lambda s: parse_array(s, (nb_pools,)))
    )
    df['common_input_std_array'] = (
        df['common_input_std']
        .map(lambda s: parse_array(s, (2, nb_of_common_inputs_in_simulations))) # .map(lambda s: parse_array(s, (nb_pools, nb_of_common_inputs_in_simulations)))
    )
    df['frequency_range_of_common_input_array'] = (
        df['frequency_range_of_common_input']
        .map(lambda s: parse_array(s, (2, nb_of_common_inputs_in_simulations, 2))) # .map(lambda s: parse_array(s, (nb_pools, nb_of_common_inputs_in_simulations, 2)))
    )

# disynpatic_inhib_connections_desired_MN_MN_array did not change across simulation versions
df['disynpatic_inhib_connections_desired_MN_MN_array'] = (
    df['disynpatic_inhib_connections_desired_MN_MN']
    .map(lambda s: parse_array(s, (2, 2))) # hard-coding the number of pools here, because 2x2 all the time
)

# 3) Selector for “within” vs “between” for baseline excitatory input
def select_excit_baseline(arr, within_between, pool_pair, idx_between):
    """
    Pick one scalar baseline excitatory input out of a per-pool vector.

    arr:            per-pool values, or a non-indexable scalar (e.g. NaN after a
                    failed parse), which is simply returned as a float
    within_between: 'within' always takes element 0; any other value selects
                    according to pool_pair
    pool_pair:      'pool_0<->pool_1' uses idx_between[0]; anything else uses
                    idx_between[1]
    idx_between:    pair of indices into arr used outside 'within' mode
    """
    indexable = hasattr(arr, "__getitem__")
    if not indexable:
        # Nothing to index: hand the scalar back unchanged (as a float)
        return float(arr)

    if within_between == 'within':
        position = 0
    elif pool_pair == 'pool_0<->pool_1':
        position = idx_between[0]
    else:
        position = idx_between[1]
    return float(arr[position])
    
# 3.1) Selector for second (high freq) common input STD
def select_high_freq_input_std(arr, within_between, pool_pair, idx_between):
    """
    Pick the STD stored at index 1 of the second axis (the second common-input
    slot) for one row of a 2-D array.

    arr:            2-D values indexed as arr[row][input], or a non-indexable
                    scalar (e.g. NaN after a failed parse), returned as a float
    within_between: 'within' always reads row 0; any other value selects the
                    row according to pool_pair
    pool_pair:      'pool_0<->pool_1' uses row idx_between[0]; anything else
                    uses row idx_between[1]
    idx_between:    pair of row indices used outside 'within' mode
    """
    if not hasattr(arr, "__getitem__"):
        # Nothing to index: hand the scalar back unchanged (as a float)
        return float(arr)

    if within_between == 'within':
        row = 0
    elif pool_pair == 'pool_0<->pool_1':
        row = idx_between[0]
    else:
        row = idx_between[1]
    # Column 1 is the second common-input slot
    return float(arr[row][1])

# 3b) Guarded slicers
def slice_cis(arr):
    """Best-effort arr[:, 0]; anything that cannot be sliced that way is returned untouched."""
    if not hasattr(arr, "__getitem__"):
        return arr
    try:
        first_column = arr[:, 0]
    except Exception:
        # Deliberate best effort: wrong rank / unsupported index -> keep the input as is
        return arr
    return first_column

def slice_freq(arr):
    """Best-effort arr[:, 0, 1]; anything that cannot be sliced that way is returned untouched."""
    if not hasattr(arr, "__getitem__"):
        return arr
    try:
        picked = arr[:, 0, 1]
    except Exception:
        # Deliberate best effort: wrong rank / unsupported index -> keep the input as is
        return arr
    return picked

# 4) Build scalar columns from the parsed "_array" columns.
#    Both assignments overwrite the raw column of the same name with one float per row.

# excitatory_input_baseline: select_excit_baseline picks element 0 in 'within' mode,
# otherwise the element designated by pool_pair and idx_between
df['excitatory_input_baseline'] = df.apply(
    lambda row: select_excit_baseline(
        row['excitatory_input_baseline_array'],
        within_or_between_pools_sbi,
        row['pool_pair'],
        idx_between=(0,1)
    ), axis=1
)

# common_input_std: select_high_freq_input_std picks element [row][1] of the parsed array,
# i.e. the STD stored in the second common-input slot
df['common_input_std'] = df.apply(
    lambda row: select_high_freq_input_std(
        row['common_input_std_array'],
        within_or_between_pools_sbi,
        row['pool_pair'],
        idx_between=(0,1)
    ), axis=1
)

# 5) disynaptic selector
def select_disyn(arr, pool_pair, perspective, direction):
    """
    Extract the disynaptic inhibition weights relevant to one row.

    arr:         2-D array indexed as arr[i, j] (2x2 in this notebook)
    pool_pair:   'pool_0<->pool_0', 'pool_1<->pool_1', 'pool_0<->pool_1' or
                 'pool_1<->pool_0'; the first pool named is the reference pool
    perspective: 'MU_as_ref' or 'other_MUs_as_ref'. Only validated (for
                 cross-pool pairs): `direction` already encodes the
                 perspective, so both values select the same entries.
    direction:   'inhibiting' or 'inhibited' (only used for cross-pool pairs)

    Returns (inhib_self, inhib_according_to_direction) as floats:
      - inhib_self is arr[ref, ref]
      - for same-pool pairs the second value equals inhib_self
      - for cross-pool pairs it is arr[ref, other] when 'inhibiting' and
        arr[other, ref] when 'inhibited'
    If `arr` is not a 2-D array (e.g. a NaN left by a failed parse), a message
    is printed and (nan, nan) is returned instead of failing on the indexing,
    so that such rows can be dropped later like any other missing value.

    Raises ValueError for an unknown pool_pair, or (cross-pool pairs only) an
    unknown perspective or direction.
    """
    self_pairs = {'pool_0<->pool_0': 0, 'pool_1<->pool_1': 1}
    cross_pairs = {'pool_0<->pool_1': (0, 1), 'pool_1<->pool_0': (1, 0)}

    if pool_pair in self_pairs:
        ref = other = self_pairs[pool_pair]
        is_cross = False
    elif pool_pair in cross_pairs:
        ref, other = cross_pairs[pool_pair]
        is_cross = True
        # Perspective is checked before direction, as in the original nesting
        if perspective not in ('MU_as_ref', 'other_MUs_as_ref'):
            raise ValueError("perspective should be either 'MU_as_ref' or 'other_MUs_as_ref'")
        if direction not in ('inhibiting', 'inhibited'):
            raise ValueError("direction should be either 'inhibiting' or 'inhibited'")
    else:
        raise ValueError("Only two pools, pool_0 and pool_1 are accepted as input")

    if not hasattr(arr, "shape") or len(arr.shape) != 2:
        # Previously this only printed and then crashed on arr[0,0]
        print("Incorrect array shape")
        return np.nan, np.nan

    inhib_self = float(arr[ref, ref])
    if not is_cross:
        return inhib_self, inhib_self
    if direction == 'inhibiting':
        return inhib_self, float(arr[ref, other])
    return inhib_self, float(arr[other, ref])

# Names of the two scalar columns filled from select_disyn's (self, directional) return pair.
# Note: the second name overwrites the raw 'disynpatic_inhib_connections_desired_MN_MN' column.
out_cols = [
    'disynpatic_inhib_connections_desired_MN_MN_self',
    'disynpatic_inhib_connections_desired_MN_MN'
]
# result_type='expand' spreads each returned tuple over the two target columns
df[out_cols]  = df.apply(
    lambda row: select_disyn(
        row['disynpatic_inhib_connections_desired_MN_MN_array'],
        row['pool_pair'],
        row['perspective'],
        row['direction']
    ), axis=1,
    result_type='expand'
)

# 6) df now holds both the parsed "_array" columns and scalar versions:
#    excitatory_input_baseline_array, excitatory_input_baseline
#    common_input_std_array, common_input_std
#    frequency_range_of_common_input_array (no scalar version is derived in this section)
#    disynpatic_inhib_connections_desired_MN_MN_array, disynpatic_inhib_connections_desired_MN_MN(_self)

# Alternative kept for reference: drop rows with NaN in 'specific_input_parameters_of_interest'
# df_simulation = df.dropna(subset=specific_input_parameters_of_interest).copy()

# Drop rows where any column listed in 'input_sim_parameters_to_infer' is NaN;
# .copy() makes df_simulation independent of df
df_simulation = df.dropna(subset=input_sim_parameters_to_infer).copy()

# Keep only the first occurrence of each column name (some of the
# input_sim_parameters_to_infer columns can be duplicated upstream)
df_simulation = df_simulation.loc[:,~df_simulation.columns.duplicated()]


In [None]:
df_simulation # display the result of the previous cell

In [None]:
# Standardize each feature relative to the training data (simulation)

def apply_standardization(df, stats, cols):
    """
    In-place Z-score: (x - μ)/σ for each col in cols,
    using stats.loc[col,'μ'] and stats.loc[col,'σ'].

    df:    DataFrame to modify
    stats: DataFrame with index as feature names and columns ['μ','σ']
    cols:  columns in df to standardize

    A feature whose σ is (numerically) zero is only centred (x - μ) instead of
    being divided by zero, which would turn the whole column into NaN/inf.
    This mirrors destandardize(), which only shifts such features back by μ.
    """
    for feat in cols:
        μ, σ = stats.loc[feat, ['μ','σ']]
        if np.isclose(σ, 0.0):
            # constant feature in the reference data: centre only
            df[feat] = df[feat] - μ
        else:
            df[feat] = (df[feat] - μ) / σ

def destandardize(df, stats, cols, name_map=None): # to be able to reverse the operation
    """
    Undo a Z-score standardization in place: x = z * σ + μ.

    df:       DataFrame whose columns are overwritten
    stats:    DataFrame indexed by feature name, with columns ['μ','σ']
    cols:     names of the df columns to convert back (names as they are in df)
    name_map: optional {df column name: stats index name}, for columns that
              were renamed after the stats were computed

    Raises KeyError when a feature has no entry in stats.index.
    """
    lookup = name_map if name_map else {}
    for column in cols:
        stats_key = lookup.get(column, column)
        if stats_key not in stats.index:
            raise KeyError(f"Stats for feature '{stats_key}' not found in stats.index.")
        center, spread = stats.loc[stats_key, ['μ','σ']]
        # A (near-)zero σ means the column is only shifted back by μ
        rescaled = df[column] if np.isclose(spread, 0.0) else df[column] * spread
        df[column] = rescaled + center

if normalize_scale_of_features_for_inference:
    # Compute the per-feature mean & std on the simulated data (drawn from the prior)
    norm_stats = ( # norm_stats records the original mean and std, so the transform can be reversed with destandardize() when needed
        df_simulation[features_for_inference]
        .agg(['mean','std'])
        .transpose()
        .rename(columns={'mean':'μ','std':'σ'})
    )

    # Standardize the simulations (in place)
    apply_standardization(df_simulation, norm_stats, features_for_inference)

    # Standardize the real data, using the simulation statistics
    # First make sure all features exist in the experimental data frame
    def add_placeholder_columns(df: pd.DataFrame, placeholders: dict, *, inplace: bool = True):
        """
        Add the columns of `placeholders` that do not already exist in df, filled
        with the provided placeholder values. Existing columns are left untouched.

        df:           DataFrame to complete
        placeholders: {col_name: scalar_value}
        inplace:      if True (default) df itself is modified, otherwise a copy is

        Returns the DataFrame that received the columns (df itself when inplace=True).
        """
        target = df if inplace else df.copy()
        to_add = [c for c in placeholders if c not in target.columns]
        for c in to_add:
            target[c] = placeholders[c]
        return target
    # Placeholder value for each selected feature = its simulation mean μ
    # (taken before standardization, since norm_stats was computed above)
    placeholders = {
        feat: float(norm_stats.loc[feat, 'μ'])
        for feat in features_for_inference
        if feat in norm_stats.index
    }
    # Add only the missing columns to df_experiment, filled with μ as placeholders
    add_placeholder_columns(df_experiment, placeholders, inplace=True)
    # Quick report. Note: this counts every placeholder feature now present in
    # df_experiment (pre-existing ones included), not only the newly added columns.
    added = [c for c in placeholders if c in df_experiment.columns]
    print(f"Added/ensured {len(added)} placeholder columns in df_experiment.")
    # Apply the standardization to the real data (in place)
    apply_standardization(df_experiment, norm_stats, features_for_inference)

In [None]:
# Sanity check: scatter the first two features of features_for_inference,
# once for the experimental data and once for the simulated training data.
df_experiment[features_for_inference] # not displayed (not the last expression of the cell); only fails if a feature column is missing
plt.figure()
plt.scatter(x=df_experiment[features_for_inference[0]], y=df_experiment[features_for_inference[1]], alpha=0.1)
plt.title("Experimental data")
plt.xlabel(f"{features_for_inference[0]} (standardized relative to training data)")
plt.ylabel(f"{features_for_inference[1]} (standardized relative to training data)")

df_simulation[features_for_inference] # same remark as above
plt.figure()
# much lower alpha than for the experimental data, so overlapping points stay readable
plt.scatter(x=df_simulation[features_for_inference[0]], y=df_simulation[features_for_inference[1]], alpha=0.005)
plt.title("Simulated training data")
plt.xlabel(f"{features_for_inference[0]} (standardized relative to training data)")
plt.ylabel(f"{features_for_inference[1]} (standardized relative to training data)")

In [None]:
# ── SUMMARY STATISTICS OF THE SIMULATED DATA: one value per (sim_name, pool_pair) group, feature and statistic ──

all_series = []

# every summary is computed within a (sim_name, pool_pair) group
group_cols = ["sim_name", "pool_pair"]

for feat in features_for_inference:
    grp = df_simulation.groupby(group_cols)[feat]

    # Features listed in input_sim_parameters_as_features are summarised by
    # their group mean only (the resulting column keeps the feature name)
    if feat in input_sim_parameters_as_features:
        all_series.append(grp.mean())
        continue

    # Every other feature gets one column per summary function, named "{feature}_{stat}"
    for summary_key, summary_fx in summary_funcs.items():
        series = grp.apply(lambda arr: summary_fx(arr)).rename(f"{feat}_{summary_key}")
        all_series.append(series)

# Side-by-side concatenation on the shared (sim_name, pool_pair) MultiIndex,
# then bring sim_name & pool_pair back as ordinary columns
df_simulation_summary = pd.concat(all_series, axis=1).reset_index()

# ── ATTACH THE INPUT PARAMETERS OF EACH (sim_name, pool_pair) ──
param_cols = group_cols + input_sim_parameters_to_infer

# one parameter row per group: the first occurrence is kept
param_df = df_simulation[param_cols].drop_duplicates(subset=group_cols).reset_index(drop=True)

df_simulation_summary = df_simulation_summary.merge(param_df, on=group_cols, how="inner")
# df_simulation_summary has one row per (sim_name, pool_pair) and one column per feature-stat, plus the parameter columns

In [None]:
df_simulation_summary # display the summary table built in the previous cell

In [None]:
# ── SUMMARY STATISTICS OF THE EXPERIMENTAL DATA: one value per subject / muscle pair / intensity, feature and statistic ──
# Each function of summary_funcs is applied to each feature of features_for_inference.
group_keys = ["subject", "muscle_pair", "intensity"]

# Collects one Series per output column
exp_all_series = []
for feat in features_for_inference:
    # values of this feature, grouped by subject, muscle_pair and intensity
    grp = df_experiment.groupby(group_keys)[feat]

    # Features listed in input_sim_parameters_as_features are summarised by
    # their group mean only (the resulting column keeps the feature name)
    if feat in input_sim_parameters_as_features:
        exp_all_series.append(grp.mean())
        continue

    # Every other feature gets one column per summary function, named "{feature}_{stat}"
    for summary_key, summary_fx in summary_funcs.items():
        series = grp.apply(lambda arr: summary_fx(arr)).rename(f"{feat}_{summary_key}")
        exp_all_series.append(series)

# ── Single DataFrame; the grouping keys go back from the MultiIndex to columns ──
df_experiment_summary = pd.concat(exp_all_series, axis=1).reset_index()

In [None]:
df_experiment_summary # display the per-subject summary table built in the previous cell

In [None]:
# ── SUMMARY STATISTICS OF THE EXPERIMENTAL DATA: one value per muscle pair / intensity, feature and statistic ──
# Same as the per-subject summary, but with all subjects pooled together.
# Each function of summary_funcs is applied to each feature of features_for_inference.
group_keys = ["muscle_pair", "intensity"]

# Collects one Series per output column
exp_all_series = []
for feat in features_for_inference:
    # values of this feature, grouped by muscle_pair and intensity
    grp = df_experiment.groupby(group_keys)[feat]

    # Features listed in input_sim_parameters_as_features are summarised by
    # their group mean only (the resulting column keeps the feature name)
    if feat in input_sim_parameters_as_features:
        exp_all_series.append(grp.mean())
        continue

    # Every other feature gets one column per summary function, named "{feature}_{stat}"
    for summary_key, summary_fx in summary_funcs.items():
        series = grp.apply(lambda arr: summary_fx(arr)).rename(f"{feat}_{summary_key}")
        exp_all_series.append(series)

# ── Single DataFrame; the grouping keys go back from the MultiIndex to columns ──
df_experiment_summary_grouped_subjects = pd.concat(exp_all_series, axis=1).reset_index()

In [None]:
df_experiment_summary_grouped_subjects # display the pooled-subjects summary table built in the previous cell

In [None]:
# Function to visually check coverage of training data (summaries) relative to experimental observations (summaries)
def _pairgrid_kde_curve(values, x_grid):
    """Evaluate a 1-D Gaussian KDE of `values` on `x_grid`; return None when no KDE can be built."""
    if values.size <= 1:
        return None
    try:
        return gaussian_kde(values)(x_grid)
    except (np.linalg.LinAlgError, ValueError):
        # Degenerate data (e.g. all values identical): skip this curve, keep the figure
        return None


def _pairgrid_draw_diagonal(ax, exp_vals, exp_muscles, sim_vals, muscle_colors_dict):
    """Draw one experimental KDE per muscle pair plus the simulated KDE (dashed black) on `ax`."""
    sim_clean = sim_vals[~np.isnan(sim_vals)]
    pooled = np.concatenate([exp_vals[~np.isnan(exp_vals)], sim_clean])
    if pooled.size == 0:
        return

    # Common x-grid spanning experimental and simulated values
    vmin, vmax = pooled.min(), pooled.max()
    span = vmax - vmin
    if span == 0:
        # Single-value fallback
        x_grid = np.linspace(vmin - 1, vmax + 1, 200)
    else:
        x_grid = np.linspace(vmin - 0.05*span, vmax + 0.05*span, 300)

    for m, color in muscle_colors_dict.items():
        data_m = exp_vals[exp_muscles == m]
        y_m = _pairgrid_kde_curve(data_m[~np.isnan(data_m)], x_grid)
        if y_m is None:
            continue
        ax.fill_between(x_grid, y_m, facecolor=color, alpha=0.3)
        ax.plot(x_grid, y_m, color=color, linewidth=1.2, alpha=1.0)

    y_sim = _pairgrid_kde_curve(sim_clean, x_grid)
    if y_sim is not None:
        ax.plot(x_grid, y_sim, "--k", linewidth=1.5, alpha=1.0, label="Sim KDE")


def _pairgrid_draw_sim_density(ax, pts_sim):
    """
    Shade `ax` with the [0,1]-normalised 2-D KDE of the simulated points.

    Returns the padded (x_min, x_max, y_min, y_max) extent of the shading,
    or None when the KDE could not be computed (nothing is drawn then).
    """
    x_min, y_min = pts_sim.min(axis=0)
    x_max, y_max = pts_sim.max(axis=0)
    # Small margin around the simulated range (unit span if the range is degenerate)
    x_span = x_max - x_min if x_max > x_min else 1.0
    y_span = y_max - y_min if y_max > y_min else 1.0
    x_min, x_max = x_min - 0.03*x_span, x_max + 0.03*x_span
    y_min, y_max = y_min - 0.03*y_span, y_max + 0.03*y_span

    # 100×100 evaluation grid over the padded extent
    xi_mesh, yi_mesh = np.meshgrid(np.linspace(x_min, x_max, 100),
                                   np.linspace(y_min, y_max, 100))
    grid_coords = np.vstack([xi_mesh.ravel(), yi_mesh.ravel()])
    try:
        zi = gaussian_kde(pts_sim.T)(grid_coords).reshape(xi_mesh.shape)
    except (np.linalg.LinAlgError, ValueError):
        # Degenerate simulated cloud: no background, the rest of the cell is still drawn
        return None

    zi_min, zi_max = zi.min(), zi.max()
    zi_norm = (zi - zi_min) / (zi_max - zi_min) if zi_max > zi_min else np.zeros_like(zi)
    ax.imshow(
        zi_norm,
        origin="lower",
        extent=(x_min, x_max, y_min, y_max),
        cmap="gray",
        aspect="auto",
        alpha=0.5,
        zorder=0,
    )
    return x_min, x_max, y_min, y_max


def _pairgrid_draw_offdiagonal(ax, sim_x, sim_y, exp_subsets, col_x, col_y, muscle_colors_dict):
    """Draw simulated density + convex hull and the experimental scatter for one off-diagonal cell."""
    valid_sim = (~np.isnan(sim_x)) & (~np.isnan(sim_y))
    pts_sim = np.column_stack([sim_x[valid_sim], sim_y[valid_sim]])

    # Background: density of the simulated points (needs at least 2 points)
    sim_extent = _pairgrid_draw_sim_density(ax, pts_sim) if len(pts_sim) > 1 else None

    # Experimental points, one marker style per intensity level
    drawn_x, drawn_y = [], []
    for marker, subset in exp_subsets:
        x_exp = subset[col_x].to_numpy(dtype=float)
        y_exp = subset[col_y].to_numpy(dtype=float)
        mask_xy = (~np.isnan(x_exp)) & (~np.isnan(y_exp))
        if not np.any(mask_xy):
            continue
        colors = subset["muscle_pair"].map(muscle_colors_dict).fillna("gray")
        ax.scatter(
            x_exp[mask_xy],
            y_exp[mask_xy],
            c=list(colors[mask_xy]),
            marker=marker,
            edgecolor="k",
            alpha=0.7,
            s=30,
            zorder=2,
        )
        drawn_x.append(x_exp[mask_xy])
        drawn_y.append(y_exp[mask_xy])

    # Outline: convex hull of the simulated points (needs at least 3 points)
    if len(pts_sim) >= 3:
        try:
            hull = ConvexHull(pts_sim)
        except (RuntimeError, ValueError):
            # QhullError (a RuntimeError) on degenerate input such as collinear points
            hull = None
        if hull is not None:
            hull_pts = pts_sim[hull.vertices]
            hull_loop = np.vstack([hull_pts, hull_pts[0]])
            ax.plot(hull_loop[:,0], hull_loop[:,1], "--k", lw=1.2, zorder=3)

    # Axis limits: the shaded extent, widened where needed so that experimental
    # points lying OUTSIDE the simulated range stay visible (they are exactly
    # what a coverage check must reveal). Without a background, autoscaling is kept.
    if sim_extent is not None:
        x_lo, x_hi, y_lo, y_hi = sim_extent
        x_pad, y_pad = 0.03*(x_hi - x_lo), 0.03*(y_hi - y_lo)
        if drawn_x:
            xs, ys = np.concatenate(drawn_x), np.concatenate(drawn_y)
            xs, ys = xs[np.isfinite(xs)], ys[np.isfinite(ys)]
            if xs.size:
                x_lo, x_hi = min(x_lo, xs.min() - x_pad), max(x_hi, xs.max() + x_pad)
            if ys.size:
                y_lo, y_hi = min(y_lo, ys.min() - y_pad), max(y_hi, ys.max() + y_pad)
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(y_lo, y_hi)


def plot_pairgrid_experiment_vs_sim(
    df_experiment_summary,
    df_simulation_summary,
    *,
    features: list[str],
    stat: str,
    muscle_colors_dict: dict[str,str],
    savepath: str | None = None,
    intensity_markers: dict | None = None,
):
    """
    For a given `stat` (e.g. "mean", "sd", "median", "iqr") and a list of `features`,
    build an (N × N) grid (N = number of plotted features):
      - diagonal (i == j): KDE density plots of df_experiment_summary[f"{feature}_{stat}"],
        one curve per muscle_pair (low-alpha fill, high-alpha outline), plus a dashed-black
        KDE of simulated values for that same feature-stat.
      - off-diagonal (i != j): scatter of experimental
            x = df_experiment_summary[f"{features[j]}_{stat}"]
            y = df_experiment_summary[f"{features[i]}_{stat}"]
        colored by muscle_pair (one marker per intensity level), drawn over a gray
        density map of the simulated points and the dashed convex-hull outline of
        the simulated points:
            x_sim = df_simulation_summary[f"{features[j]}_{stat}"]
            y_sim = df_simulation_summary[f"{features[i]}_{stat}"]
        Axis limits cover the simulated range and are widened to include any
        experimental point falling outside of it.
    One combined legend (muscle_pair→color) is drawn only in the top-left diagonal cell.

    Features listed in the global `input_sim_parameters_as_features` are not plotted.
    When the global `normalize_scale_of_features_for_inference` is set, a figure
    title states that the values are standardized.

    Parameters
    ----------
    df_experiment_summary : pandas.DataFrame
        Must contain columns:
          - "muscle_pair"
          - "intensity"
          - For each plotted feat in `features`, a column named f"{feat}_{stat}"

    df_simulation_summary : pandas.DataFrame
        Must contain, for each plotted feat in `features`, a column named f"{feat}_{stat}".

    features : list[str]
        e.g. ["inhibiting_estimation_raw", "sync_height", "ISI_cov", ...]

    stat : str
        One of the summary-stats, e.g. "mean", "sd", "median", "iqr".

    muscle_colors_dict : dict[str,str]
        Maps each `muscle_pair` string to a matplotlib color (hex or name).
        Muscle pairs missing from the dict are scattered in gray.

    savepath : str or None
        If non-None, the figure is saved to `savepath`.

    intensity_markers : dict or None
        Maps an "intensity" value to the matplotlib marker used for its points in the
        off-diagonal cells. Defaults to {10: "o", 40: "^"}; rows with any other
        intensity are not scattered.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If no feature is left to plot.
    """
    if intensity_markers is None:
        intensity_markers = {10: "o", 40: "^"}

    # remove input_sim_parameters_as_features from features if present
    features = [f for f in features if f not in input_sim_parameters_as_features]
    n_feats = len(features)
    if n_feats == 0:
        raise ValueError("No feature left to plot after removing input_sim_parameters_as_features")

    fig, axes = plt.subplots(
        nrows=n_feats,
        ncols=n_feats,
        figsize=(3*n_feats, 3*n_feats),
        squeeze=False,
        sharex=False,
        sharey=False,
    )

    # Quantities that do not depend on the grid cell are computed once
    muscles = df_experiment_summary["muscle_pair"].astype(str).to_numpy()
    present_pairs = set(df_experiment_summary["muscle_pair"].unique())
    exp_subsets = [
        (marker, df_experiment_summary[df_experiment_summary["intensity"] == level])
        for level, marker in intensity_markers.items()
    ]

    for i_row, feat_i in enumerate(features):
        col_i = f"{feat_i}_{stat}"
        for j_col, feat_j in enumerate(features):
            ax = axes[i_row][j_col]
            col_j = f"{feat_j}_{stat}"

            if i_row != j_col:
                # OFF-DIAGONAL: simulated density + hull, experimental scatter on top
                _pairgrid_draw_offdiagonal(
                    ax,
                    df_simulation_summary[col_j].to_numpy(dtype=float),
                    df_simulation_summary[col_i].to_numpy(dtype=float),
                    exp_subsets,
                    col_j,
                    col_i,
                    muscle_colors_dict,
                )
                ax.set_xlabel(f"{feat_j} ({stat})")
                ax.set_ylabel(f"{feat_i} ({stat})")
                continue

            # DIAGONAL: KDE density plots of experiment (one per muscle) + sim KDE
            _pairgrid_draw_diagonal(
                ax,
                df_experiment_summary[col_i].to_numpy(dtype=float),
                muscles,
                df_simulation_summary[col_i].to_numpy(dtype=float),
                muscle_colors_dict,
            )
            ax.set_xlabel(f"{feat_i} ({stat})")
            ax.set_ylabel("Density")
            ax.set_title(f"{feat_i} ({stat})")
            ax.grid(True, linestyle="--", alpha=0.2)

            # Only add the legend once (top-left diagonal cell)
            if i_row == 0:
                legend_handles = [
                    Line2D(
                        [0], [0],
                        marker="s",
                        color="w",
                        markerfacecolor=color,
                        label=m,
                        markersize=8,
                        markeredgecolor="k",
                    )
                    for m, color in muscle_colors_dict.items()
                    if m in present_pairs
                ]
                # "Sim KDE" line
                legend_handles.append(
                    Line2D([0], [0], linestyle="--", color="k", linewidth=1.5, label="Sim KDE")
                )
                ax.legend(
                    handles=legend_handles,
                    loc="upper right",
                    fontsize="small",
                    framealpha=0.7,
                    title="muscle_pair",
                )

    if normalize_scale_of_features_for_inference:
        # The standardization statistics are computed on the simulated (training) data
        fig.suptitle("Values standardized relative to the simulated training data")
    fig.tight_layout()

    if savepath is not None:
        fig.savefig(savepath, dpi=150, bbox_inches="tight")

    plt.show()


In [None]:
# One pair-grid figure per summary statistic, each saved as pairgrid_<stat>.png in path_to_save
summary_stats_names = summary_funcs.keys()
for stat in summary_stats_names:
    plot_pairgrid_experiment_vs_sim(
        df_experiment_summary,
        df_simulation_summary,
        features=features_for_inference,
        stat=stat,
        muscle_colors_dict=muscle_colors_dict,
        # os.path.join instead of a hard-coded "\\": outside Windows the backslash
        # would become part of the file name instead of separating folder and file
        savepath=os.path.join(path_to_save, f"pairgrid_{stat}.png"),
    )

# Evaluating inference on held-out simulated data (posterior estimates VS ground truth)

### Training model (90% of the training data; 10% used as test set)

In [None]:
# ── Build Torch tensors & prior for SBI ──
# Collect the “feature‐summary” column names in exactly the same order:
summary_colnames = []
for feat in features_for_inference:
    if feat in input_sim_parameters_as_features:
        summary_colnames += [f"{feat}"]
    else:
        for summary_stat in summary_funcs.keys():
            summary_colnames +=  [f"{feat}_{summary_stat}"]
theta_colnames = input_sim_parameters_to_infer
# The commented lines cause error with recent numpy versions
# sim_obs     = torch.from_numpy(df_simulation_summary[summary_colnames].to_numpy(dtype=float)).float()
# theta_raw = torch.from_numpy(df_simulation_summary[theta_colnames].to_numpy(dtype=float)).float()
sim_obs = torch.tensor(
    df_simulation_summary[summary_colnames].values.tolist(),
    dtype=torch.float32,
)
theta_raw = torch.tensor(
    df_simulation_summary[theta_colnames].values.tolist(),
    dtype=torch.float32,
)

# BOUNDARIES OF PRIORS
low_original = []
high_original = []
low_unit = []
high_unit = []
# ── Build low/high tensors from priors_per_parameters dict ──
low_original  = torch.tensor([priors_per_parameters_to_infer[name][0] for name in theta_colnames],
                     dtype=torch.float32)
high_original = torch.tensor([priors_per_parameters_to_infer[name][1] for name in theta_colnames],
                     dtype=torch.float32)
# ── Define normalization ↔︎ denormalization helpers ──
def theta_to_unit(theta):
    """Map from original θ-space into [0,1]^d."""
    return (theta - low_original) / (high_original - low_original)
def unit_to_theta(theta_unit):
    """Rescale θ from the unit box [0,1]^d back to the original prior box.

    Inverse of `theta_to_unit`, using the module-level `low_original` /
    `high_original` tensors.
    """
    prior_span = high_original - low_original
    return theta_unit * prior_span + low_original
# ── Normalized bounds, normalized training θ, and the uniform prior on [0,1]^d ──
# Mapping the bounds through theta_to_unit should give all zeros (low) and all
# ones (high); the prior below uses exact 0/1 tensors of the same shape.
low_unit, high_unit = theta_to_unit(low_original), theta_to_unit(high_original)
theta_unit = theta_to_unit(theta_raw)
prior = sbi_utils.BoxUniform(
    low=torch.zeros_like(low_unit),
    high=torch.ones_like(high_unit),
)

In [None]:
# Two paths, selected by rerun_network_training_for_heldout_data:
#   True  -> draw a random train / held-out split, train an SNPE network on the
#            training part and pickle the trained object together with the split.
#   False -> reload a previously pickled run from path_to_load.
# Both paths leave the same names defined: inference_net_held_out_data,
# theta_ground_truth_hold_out / _training, obs_hold_out / obs_training,
# hold_out_idx, training_idx and held_out_proportion.
if rerun_network_training_for_heldout_data:
    # PARAMETERS AND SELECTION OF HELD OUT SAMPLES
    held_out_proportion = 0.1 # fraction of simulations kept for testing (e.g. with 12_000 examples: training on 10_800 and testing on 1_200)
    sim_samples_nb = df_simulation_summary.shape[0]
    # NOTE(review): the seed is drawn from the `random` module, so the split is only reproducible if `random` was seeded earlier — confirm
    rng = np.random.default_rng(seed=random.randint(0,int(1e5))) # random selector
    hold_out_idx = rng.choice(sim_samples_nb,
                                size=np.round(held_out_proportion*sim_samples_nb).astype(int),
                                replace=False)
    hold_out_idx = np.sort(hold_out_idx)
    # Training rows = every row index that was not held out
    training_idx = np.setdiff1d(np.arange(sim_samples_nb), hold_out_idx)

    # Converts numpy int to Python list (this makes Pytorch happy)
    hold_out_idx = hold_out_idx.tolist()
    training_idx = training_idx.tolist()

    # Separate held-out (testing) and training data
    # θ is taken from the unit-normalized tensor, so targets live in [0,1]^d like the prior
    theta_ground_truth_hold_out = theta_unit[hold_out_idx]
    theta_ground_truth_training = theta_unit[training_idx]
    obs_hold_out = sim_obs[hold_out_idx]
    obs_training = sim_obs[training_idx]

    # Train the SBI network - SNPE
    inference_net_held_out_data = sbi_inference.SNPE(prior=prior,
                                                    density_estimator=sbi_density_estimator,
                                                    show_progress_bars=True)
    inference_net_held_out_data.append_simulations(theta_ground_truth_training, obs_training)
    inference_net_held_out_data.train(
            num_atoms                  = network_training_hyperparameters['num_atoms'],       # default is 10
            force_first_round_loss     = True,    # start fresh
            training_batch_size        = network_training_hyperparameters['training_batch_size'],     # default is 200
            learning_rate              = network_training_hyperparameters['learning_rate'],  # default is 0.0005
            validation_fraction        = network_training_hyperparameters['validation_fraction'],      # default is 10%
            max_num_epochs             = network_training_hyperparameters['max_num_epochs'],    # upper bound on the number of training epochs
            stop_after_epochs          = network_training_hyperparameters['stop_after_epochs'],      # early-stopping patience: epochs without validation improvement before training stops (see sbi docs)
            show_train_summary         = True
        )
    
    # Saving in a dictionary the relevant data for checking the performance of the network trained on heldout data
    heldout_data_sbi_check = {
        "neural_density_estimator_network": inference_net_held_out_data,
        "theta_ground_truth_hold_out": theta_ground_truth_hold_out,
        "obs_hold_out": obs_hold_out,
        "theta_ground_truth_training": theta_ground_truth_training,
        "obs_training": obs_training,
        "hold_out_idx": hold_out_idx,
        "training_idx": training_idx,
        "held_out_proportion": held_out_proportion
    }
    sbi_check_heldout_sims_pickle_path = f"{path_to_save}\\sbi_check_heldout_sims.pkl"
    with open(sbi_check_heldout_sims_pickle_path, "wb") as f:
        pickle.dump(heldout_data_sbi_check, f)
    print(f"✅ Neural density estimator trained on heldout data saved to '{sbi_check_heldout_sims_pickle_path}'")
else:
    # Load previously trained network
    # NOTE(security): pickle.load executes arbitrary code from the file — only load pickles you created yourself.
    sbi_check_heldout_sims_pickle_path = f"{path_to_load}\\sbi_check_heldout_sims.pkl"
    with open(sbi_check_heldout_sims_pickle_path, "rb") as f:
        heldout_data_sbi_check = pickle.load(f)
    print(f"✅  Neural density estimator trained on heldout data loaded from '{sbi_check_heldout_sims_pickle_path}'")
    inference_net_held_out_data = heldout_data_sbi_check["neural_density_estimator_network"]
    theta_ground_truth_hold_out = heldout_data_sbi_check["theta_ground_truth_hold_out"]
    obs_hold_out = heldout_data_sbi_check["obs_hold_out"] # saved observations; their scaling may differ from the current sim_obs, so they are taken from the pickle rather than recomputed
    theta_ground_truth_training = heldout_data_sbi_check["theta_ground_truth_training"]
    obs_training = heldout_data_sbi_check["obs_training"]
    hold_out_idx_raw = heldout_data_sbi_check["hold_out_idx"]
    training_idx_raw = heldout_data_sbi_check["training_idx"]
    # Pickles without "held_out_proportion" fall back to the ratio implied by the saved index lists
    held_out_proportion = heldout_data_sbi_check["held_out_proportion"] if "held_out_proportion" in heldout_data_sbi_check else len(hold_out_idx_raw) / (len(hold_out_idx_raw) + len(training_idx_raw))
    # Normalize any np.array / np.int64 into plain Python int lists
    hold_out_idx   = [int(i) for i in (hold_out_idx_raw.tolist()   if hasattr(hold_out_idx_raw, "tolist")   else hold_out_idx_raw)]
    training_idx   = [int(i) for i in (training_idx_raw.tolist()   if hasattr(training_idx_raw, "tolist")   else training_idx_raw)]


In [None]:
# --- Consistency checks and auto-alignment of sim_obs to saved obs_* ---
# Compares the saved held-out tensors with the rows of the current
# theta_unit / sim_obs at hold_out_idx. If the observations differ, sim_obs is
# OVERWRITTEN with a per-feature affine rescaling estimated from the training rows.

# 1) Basic checks: do θ and obs match the current tensors?
#    (torch.allclose with its default tolerances; the result for θ is only printed.)
theta_match = torch.allclose(
    theta_ground_truth_hold_out,
    theta_unit[hold_out_idx]
)

obs_match = torch.allclose(
    obs_hold_out,
    sim_obs[hold_out_idx]
)

print("theta match:", theta_match)
print("obs match :", obs_match)

# 2) If obs don't match, infer the training-time transform and align sim_obs.
#    We assume an affine, feature-wise transform of the form:
#        Y ≈ (X - shift) / scale
#    where:
#        X = current sim_obs[training_idx]
#        Y = saved obs_training (the canonical training space)
if not obs_match:
    print(
        "→ obs_hold_out and sim_obs[hold_out_idx] differ in scale; "
        "attempting to align sim_obs to the saved training/held-out scale."
    )

    def describe_stats(name: str, X: torch.Tensor):
        """Print rough summary stats to see who looks standardized.

        Returns the per-column mean and (population, unbiased=False) standard
        deviation of X, both computed along dim 0.
        """
        mean = X.mean(dim=0)
        std  = X.std(dim=0, unbiased=False)
        mean_abs = float(mean.abs().mean())
        std_mean = float(std.mean())
        print(f"{name}: mean(|μ|) ≈ {mean_abs:.3f}, mean(σ) ≈ {std_mean:.3f}")
        return mean, std

    # Use the TRAINING subset to infer the transform that was used originally.
    # X_train: current representation (raw or differently scaled)
    # Y_train: saved representation (what SNPE actually saw during training)
    X_train = sim_obs[training_idx]   # shape [N_train, D]
    Y_train = obs_training               # shape [N_train, D]

    X_mean, X_std = describe_stats("sim_obs[training_idx]", X_train)
    Y_mean, Y_std = describe_stats("obs_training",          Y_train)

    # Infer per-feature scale and shift assuming:
    #   Y ≈ (X - shift) / scale  ⇒  X ≈ scale * Y + shift
    # Matching the first two moments gives scale = σ_X / σ_Y and shift = μ_X − scale·μ_Y.
    eps   = 1e-12                              # guards the divisions below against zero std
    scale = X_std / (Y_std + eps)              # element-wise [D]
    shift = X_mean - scale * Y_mean           # element-wise [D]

    # Apply the same transform to ALL rows in sim_obs so that:
    #   sim_obs_aligned ≈ Y when restricted to training/held-out indices
    # NOTE(review): a feature with zero std in X_train gets scale = 0 and is divided by eps here — confirm no constant features
    sim_obs = (sim_obs - shift) / (scale + eps)

    # Re-check match specifically on the held-out set
    # (looser tolerances than step 1, since the transform is only estimated)
    obs_match_aligned = torch.allclose(
        obs_hold_out,
        sim_obs[hold_out_idx],
        atol=1e-5,
        rtol=1e-4,
    )
    print("obs match after alignment:", obs_match_aligned)

# At this point:
# - θ was compared only (theta_match); nothing is corrected if it is False.
# - sim_obs has been rescaled (if needed); its held-out rows match obs_hold_out
#   only if obs_match, or obs_match_aligned after rescaling, printed True.
# - Nothing stops execution when the alignment fails, so check the printed
#   flags before using sim_obs with the loaded SNPE model.


In [None]:
# Plot the training and validation loss curves recorded by the SBI trainer
# and save the figure to path_to_save.
# plt.figure() is what creates a canvas of the requested size; calling
# plt.plot(figsize=...) with no data draws nothing and ignores the size.
plt.figure(figsize=(8, 6))
plt.plot(inference_net_held_out_data.summary['training_loss'],
            label="Training loss", linewidth=3, color='blue', alpha=0.5)
plt.plot(inference_net_held_out_data.summary['validation_loss'],
            label="Validation loss", linewidth=3, color='red', alpha=0.5)
plt.xlabel("epochs_trained")
plt.ylabel("Loss - negative log-probability\n(neg-log-prob of observation,\ngiven inferred generative parameters)")
# (1 - held_out_proportion) is the share of simulations used for training
plt.title(f"Training results on SBI neural net\nwith {(1-held_out_proportion)*100}% of the simulated observations")
plt.legend()
plt.tight_layout()
plt.savefig(f"{path_to_save}\\network_training_loss_inference_check.png")
plt.show()

In [None]:
### DIFFERENCE & DISTANCE FROM POSTERIOR (VECTORS SAMPLED FROM POSTERIOR) TO GROUND TRUTH

posterior_net_held_out_data = inference_net_held_out_data.build_posterior()
# For each held-out simulation:
#   - draw samples from the posterior conditioned on that simulation's observation,
#   - compute (sample − ground truth) per parameter, its absolute value, and the
#     Euclidean distance in θ-space,
#   - store the mean and SD of each of those across the posterior draws.
n_held_out_sims = len(hold_out_idx)
n_params_to_infer = theta_unit.shape[-1]
# Set the number of draws BEFORE reading it, so n_posterior_draws is the value actually used below.
num_posterior_samples["simulation"] = 1_000 # 10_000
n_posterior_draws = num_posterior_samples["simulation"]

# One row per held-out simulation; filled in by the loop below.
estimate_to_ground_truth_diff_for_each_theta_mean = torch.zeros(n_held_out_sims, n_params_to_infer)
estimate_to_ground_truth_diff_for_each_theta_std = torch.zeros(n_held_out_sims, n_params_to_infer)
estimate_to_ground_truth_abs_diff_for_each_theta_mean = torch.zeros(n_held_out_sims, n_params_to_infer)
estimate_to_ground_truth_abs_diff_for_each_theta_std = torch.zeros(n_held_out_sims, n_params_to_infer)
estimate_to_ground_truth_euc_dist_in_theta_space_mean = torch.zeros(n_held_out_sims)
estimate_to_ground_truth_euc_dist_in_theta_space_std = torch.zeros(n_held_out_sims)

# The message is built from local names rather than a subscript with double quotes
# inside the f-string, which is a SyntaxError before Python 3.12.
print(
    f"Posterior estimation of simulated samples that were kept out of the training dataset ({n_held_out_sims} samples)\n"
    f"Drawing {n_posterior_draws} samples from posterior for each simulation"
)
# All three dictionaries are keyed by the simulation's index in the full simulation table.
posterior_samples_thetas_for_held_out_sims = {}
posterior_log_probs_for_held_out_sims = {}
ground_truth_thetas_for_held_out_sims = {}

# held_out_pos: row in the held-out tensors / result arrays.
# sim_id: original global index, used for the dict keys.
# The loop variable is deliberately not named `iter`: that would rebind the
# builtin iter() for every later cell of the notebook.
for held_out_pos, sim_id in enumerate(hold_out_idx):

    # # 0) Make fewer iterations for testing purpose
    # if held_out_pos > 100:
    #     break
    # ^ CAREFUL !!!! If this early break is enabled, the average distance will be much smaller
    #   because the result arrays are initialized with zeros and the skipped rows stay at zero !!!
    if held_out_pos % 20 == 0:
        print(f"    Iteration {held_out_pos}/{n_held_out_sims}")

    # 1) Grab ground truth and observation *from saved tensors*
    theta_ground_truth = theta_ground_truth_hold_out[held_out_pos].unsqueeze(0)    # [1, D]
    obs_hold_out_current_sim = obs_hold_out[held_out_pos].unsqueeze(0)            # [1, Dx]

    # 2) Draw posterior samples and evaluate their log-probability under the same posterior
    theta_samples = posterior_net_held_out_data.sample(
        (n_posterior_draws,),
        x=obs_hold_out_current_sim,
        show_progress_bars=False,
    )  # [N, D]

    logp_sample = posterior_net_held_out_data.log_prob(
        theta_samples, x=obs_hold_out_current_sim
    )  # [N]

    posterior_samples_thetas_for_held_out_sims[sim_id] = theta_samples
    posterior_log_probs_for_held_out_sims[sim_id]      = logp_sample
    ground_truth_thetas_for_held_out_sims[sim_id]      = theta_ground_truth

    # 3) Per-parameter diffs (ground truth broadcasts over the N draws)
    diffs = theta_samples - theta_ground_truth               # [N, D]
    abs_diffs = diffs.abs()

    estimate_to_ground_truth_diff_for_each_theta_mean[held_out_pos, :]     = diffs.mean(dim=0)
    estimate_to_ground_truth_diff_for_each_theta_std[held_out_pos, :]      = diffs.std(dim=0)
    estimate_to_ground_truth_abs_diff_for_each_theta_mean[held_out_pos, :] = abs_diffs.mean(dim=0)
    estimate_to_ground_truth_abs_diff_for_each_theta_std[held_out_pos, :]  = abs_diffs.std(dim=0)

    # 4) Euclidean distance in θ-space
    euc_dists = torch.norm(diffs, dim=1)   # [N]
    estimate_to_ground_truth_euc_dist_in_theta_space_mean[held_out_pos] = euc_dists.mean()
    estimate_to_ground_truth_euc_dist_in_theta_space_std[held_out_pos]  = euc_dists.std()


In [None]:
### FIGURE CHECK # 1
# Comparing posterior distribution (right) centered on ground truth (difference of 0 = ground truth) for each parameter to infer, and compared to prior distribution (left; flat)

# ===========================================
# Centered prior/posterior overlays with:
# - progress bars
# - optional subsampling of held-out sims
# - global y-scale across all panels
# - configurable x-limits
# - aggregate mean ± SD bands
# - info gain (KL) and shrinkage summaries; per-parameter info gain (1D KL to Uniform[0,1])
# - RED curve: best-guess offsets KDE/HIST across sims
# ===========================================

from tqdm.auto import tqdm

# -------- CONFIG --------
# Settings for the centered prior/posterior overlay figure, followed by the
# choice of which held-out simulations are drawn individually.
PLOT_KIND = 'kde'           # 'hist' or 'kde'
SUBSAMPLE_SIMS = 100 # 100       # e.g., 300 (or None for all held-out sims)
SUBSAMPLE_SEED = 0          # seed of the generator that picks the displayed subset
XLIM = (-1, 1) # (-0.6, 0.6)          # shared x-range for all panels
FIGSIZE = (12, 20)
ALPHA_LINE = 0.05  # line alpha
NBINS = 101                 # hist bins
KDE_POINTS = 401            # points for x-grid (kde)
KDE_BW = 0.15               # KDE bandwidth on centered axis # (lower = more precise, less smooth)
LINEWIDTH_THIN = 2.0        # per-simulation curves
LINEWIDTH_THICK = 2.0       # mean / best-guess curves (currently the same width as the thin ones)
AGGREGATE_ON_ALL = True     # mean±SD bands computed from ALL held-out sims (not just displayed)
SHOW_BEST_GUESS_CURVE = True
SHOW_KL_TEXT = True
SAVEFIG_PATH = path_to_save # falsy value = do not save the figure

# ---- Names & IDs ----
# Fall back to generic θ[j] labels if the parameter-name list is not defined in this session
param_names = input_sim_parameters_to_infer if 'input_sim_parameters_to_infer' in globals() \
             else [f"θ[{j}]" for j in range(theta_unit.shape[-1])]

# normalize keys to plain int
all_ids = sorted(int(k) for k in posterior_samples_thetas_for_held_out_sims.keys())
assert len(all_ids) > 0, "No held-out sims found."

# Ground-truth θ per simulation id, filled while centering (display subset / all sims)
theta_true_by_sim_display = {}
theta_true_by_sim_all = {}

# choose which sims to DRAW (display subset)
if SUBSAMPLE_SIMS is not None and SUBSAMPLE_SIMS < len(all_ids):
    rng = np.random.default_rng(SUBSAMPLE_SEED)
    sel_ids = list(rng.choice(all_ids, size=SUBSAMPLE_SIMS, replace=False))
else:
    sel_ids = all_ids

D = len(param_names)        # number of inferred parameters = number of figure rows
first_key = all_ids[0]
# Number of posterior draws stored for the first simulation; reused as the prior sample size
n_draws_per_sim = int(posterior_samples_thetas_for_held_out_sims[first_key].shape[0])

def sample_prior(n, d):
    """Draw `n` prior samples and return them as a NumPy array.

    If a module-level `prior` exists, it is sampled (without gradient tracking)
    and converted with `to_np`. Otherwise the function falls back to
    Uniform[0, 1]^d of shape (n, d), drawn from a generator seeded with 0, so
    the fallback returns the same values on every call.
    """
    if 'prior' not in globals():
        fallback_rng = np.random.default_rng(0)
        return fallback_rng.uniform(0.0, 1.0, size=(n, d))
    with torch.no_grad():
        drawn = prior.sample((n,))   # torch tensor [n, d]
    return to_np(drawn)

def kde_curve(samples, xgrid, bw):
    """Gaussian kernel density estimate of `samples`, evaluated on `xgrid`.

    Samples are first clipped to the plotting window XLIM. Each sample
    contributes a normal kernel of standard deviation `bw`; the result is the
    average kernel value at every grid point. An empty sample array yields an
    all-zero curve shaped like `xgrid`.
    """
    window_lo, window_hi = XLIM
    clipped = np.clip(samples, window_lo, window_hi)
    if clipped.size == 0:
        return np.zeros_like(xgrid)
    normalizer = 1.0 / np.sqrt(2*np.pi*bw*bw)
    # rows = samples, columns = grid points
    standardized = (xgrid[None, :] - clipped[:, None]) / bw
    kernels = normalizer * np.exp(-0.5 * standardized**2)
    return kernels.mean(axis=0)

def hist_curve(samples, edges):
    """Density-normalized histogram of `samples` over the bins given by `edges`.

    Samples are clipped to the plotting window XLIM before binning. The
    returned array has one value per bin, i.e. len(edges) - 1 entries.
    """
    clipped = np.clip(samples, XLIM[0], XLIM[1])
    density = np.histogram(clipped, bins=edges, density=True)[0]
    return density

def kl_to_uniform_1d(samples, M=256, eps=1e-12):
    """
    Histogram estimate of KL(p || Uniform[0,1]) in nats for 1D samples.

    Samples are clipped to [0,1] and binned into M equal-width bins. With bin
    masses p_k (normalized to sum to 1, then floored at `eps` so log is finite)
    and uniform masses q_k = 1/M:
        KL = sum_k p_k * log(p_k / q_k) = sum_k p_k * log(p_k) + log(M).
    Returns 0.0 when there are no samples to bin.
    """
    bounded = np.clip(samples, 0.0, 1.0)
    bin_mass = np.histogram(bounded, bins=M, range=(0.0, 1.0))[0].astype(np.float64)
    total = bin_mass.sum()
    if total <= 0:
        return 0.0
    probs = np.clip(bin_mass / total, eps, 1.0)
    neg_entropy = np.sum(probs * np.log(probs))
    return float(neg_entropy + np.log(M))

def flat_prior_curve_centered(theta_true_j, xgrid, edges=None):
    """
    Density curve of the Uniform[0,1] prior on the axis centered at θ_true.

    After centering, the prior support is [-θ_true, 1-θ_true]; it is further
    clipped to the plotting window XLIM. The curve is 1.0 inside the support
    and NaN outside, so that nanmean/nanstd across simulations ignore the
    outside region. If nothing of the support is visible, the curve is all NaN
    with the shape of `xgrid`.

    Without `edges` (KDE/line mode) `xgrid` holds grid points and the ends of
    the support are included. With `edges` (hist mode) one value per bin is
    returned, 1.0 for every bin that overlaps the support.
    """
    center = float(theta_true_j)
    support_lo = max(-center, XLIM[0])
    support_hi = min(1.0 - center, XLIM[1])

    if support_lo >= support_hi:
        # support lies entirely outside the visible window
        return np.full_like(xgrid, np.nan, dtype=float)

    if edges is not None:
        # HIST mode: a bin counts as covered if any part of it overlaps the support
        bin_left, bin_right = edges[:-1], edges[1:]
        covered = (bin_left < support_hi) & (bin_right > support_lo)
        return np.where(covered, 1.0, np.nan)

    # KDE/line mode: point-wise membership, both ends inclusive
    inside = (xgrid >= support_lo) & (xgrid <= support_hi)
    return np.where(inside, 1.0, np.nan)


# ---- Collect centered arrays (DISPLAY subset) ----
# One list per parameter; each list receives one 1D array (the draws of that
# parameter minus its ground truth) per simulation.
posterior_centered_disp = [[] for _ in range(D)]
prior_centered_disp     = [[] for _ in range(D)]

# For bands & best-guess & KL we'll use ALL sims (robust)
# When AGGREGATE_ON_ALL is False these names are aliases of the display lists,
# so the aggregate loop below must not append to them again (see the guard there).
posterior_centered_all = [[] for _ in range(D)] if AGGREGATE_ON_ALL else posterior_centered_disp
prior_centered_all     = [[] for _ in range(D)] if AGGREGATE_ON_ALL else prior_centered_disp

# best-guess offsets per parameter (across ALL sims)
best_guess_offsets_by_param = [[] for _ in range(D)]
# per-parameter KL list across sims
kl_list_by_param = [[] for _ in range(D)]

print(f"Collecting centered samples for display (sims={len(sel_ids)}) "
      f"{'(aggregate on ALL held-out sims)' if AGGREGATE_ON_ALL else '(aggregate on displayed sims)'}")

# display subset
for sim_id in tqdm(sel_ids, desc="Centering (display subset)"):
    theta_true = to_np(ground_truth_thetas_for_held_out_sims[sim_id]).reshape(-1)
    theta_true_by_sim_display[sim_id] = theta_true
    post = to_np(posterior_samples_thetas_for_held_out_sims[sim_id])              # [n_draws, D]
    # Fresh prior draws, as many as there are posterior draws
    # NOTE(review): the centered prior samples are stored but not read again in the visible part of this cell (the prior panels use the analytic flat curve)
    prior_s = sample_prior(n_draws_per_sim, post.shape[1])
    centered_post = post - theta_true[None, :]
    centered_prior = prior_s - theta_true[None, :]
    for j in range(D):
        posterior_centered_disp[j].append(centered_post[:, j])
        prior_centered_disp[j].append(centered_prior[:, j])

# aggregate scope (ALL sims) + best-guess offsets + 1D KLs
# Best-guess offsets and KL values are always collected over all_ids,
# regardless of AGGREGATE_ON_ALL.
for sim_id in tqdm(all_ids, desc="Aggregate scope (bands, best-guess, KL)"):
    theta_true = to_np(ground_truth_thetas_for_held_out_sims[sim_id]).reshape(-1)
    theta_true_by_sim_all[sim_id] = theta_true
    post = to_np(posterior_samples_thetas_for_held_out_sims[sim_id])              # [n_draws, D]
    prior_s = sample_prior(n_draws_per_sim, post.shape[1])

    # centered arrays
    cp = post - theta_true[None, :]
    cpr = prior_s - theta_true[None, :]
    if AGGREGATE_ON_ALL:
        for j in range(D):
            posterior_centered_all[j].append(cp[:, j])
            prior_centered_all[j].append(cpr[:, j])

    # best-guess (argmax logp) and offsets
    # Best guess = the posterior draw with the highest stored log-probability;
    # falls back to the posterior mean if no log-probs were stored for this sim.
    if sim_id in posterior_log_probs_for_held_out_sims:
        logp = to_np(posterior_log_probs_for_held_out_sims[sim_id]).reshape(-1)
        idx_best = int(np.argmax(logp))
        best = post[idx_best, :]
    else:
        best = post.mean(axis=0)
    offset = best - theta_true                             # [D]
    for j in range(D):
        best_guess_offsets_by_param[j].append(float(offset[j]))

    # 1D KL per parameter (posterior vs Uniform[0,1])
    # computed on the un-centered draws, which live in the unit-normalized θ-space
    for j in range(D):
        kl_list_by_param[j].append(kl_to_uniform_1d(post[:, j], M=256))

# convert lists to arrays
for j in range(D):
    best_guess_offsets_by_param[j] = np.asarray(best_guess_offsets_by_param[j], dtype=float)
    kl_list_by_param[j] = np.asarray(kl_list_by_param[j], dtype=float)

# ---- Shrinkage metrics per parameter (computed across ALL held-out sims) ----
# Reference values for the Uniform[0,1] prior: variance 1/12 and a 90%
# equal-tailed interval of width 0.90. "Shrinkage" is the prior value divided
# by the posterior value, so larger means a tighter posterior.
PRIOR_VAR = 1.0 / 12.0
PRIOR_CI90_WIDTH = 0.90
EPS = 1e-12   # floor for posterior widths/variances to avoid division by zero

ci90_shrink_mean = np.zeros(D, dtype=float)
ci90_shrink_sd   = np.zeros(D, dtype=float)
var_shrink_mean  = np.zeros(D, dtype=float)
var_shrink_sd    = np.zeros(D, dtype=float)

# KL values were accumulated in nats; divide by ln(2) to display them in bits
kl_bits_list_by_param = [arr / np.log(2.0) for arr in kl_list_by_param]

for j in range(D):
    # one entry per held-out simulation for parameter j
    widths = []
    vsh    = []
    for sid in all_ids:
        post = to_np(posterior_samples_thetas_for_held_out_sims[sid])[:, j]  # [n_draws]
        # width of the 90% equal-tailed credible interval, floored at EPS
        q05, q95 = np.quantile(post, [0.05, 0.95])
        widths.append(max(q95 - q05, EPS))
        # prior variance over (floored) posterior variance
        post_var = max(float(np.var(post, ddof=0)), EPS)
        vsh.append(PRIOR_VAR / post_var)

    widths = np.asarray(widths, dtype=float)
    vsh    = np.asarray(vsh, dtype=float)

    # interval shrinkage = prior interval width / posterior interval width
    ci_sh = PRIOR_CI90_WIDTH / widths

    # mean and population SD across simulations
    ci90_shrink_mean[j], ci90_shrink_sd[j] = float(np.mean(ci_sh)), float(np.std(ci_sh, ddof=0))
    var_shrink_mean[j],  var_shrink_sd[j]  = float(np.mean(vsh)),   float(np.std(vsh, ddof=0))


# ---- Build common x-grids for plotting & precompute curves with progress bars ----
# hist mode: xgrid holds the LEFT bin edges (one value per bin);
# kde mode:  xgrid holds KDE_POINTS evaluation points and edges is None.
if PLOT_KIND == 'hist':
    edges = np.linspace(XLIM[0], XLIM[1], NBINS + 1)
    xgrid = edges[:-1]
else:
    edges = None
    xgrid = np.linspace(XLIM[0], XLIM[1], KDE_POINTS)

# Caches of curves for display subset
curves_prior  = [[] for _ in range(D)]   # each: list of y arrays (one per sim)
curves_post   = [[] for _ in range(D)]
# Aggregate bands (mean ± SD across sims' curves)
band_prior    = [None for _ in range(D)] # dict with 'mean','std'
band_post     = [None for _ in range(D)]
# Best-guess density curve (one per param)
best_guess_curve = [None for _ in range(D)]

print("Computing density curves (display subset)…")
for j in range(D):
    # Prior curves are analytic: a flat line over the centered support of Uniform[0,1]
    for sid in tqdm(sel_ids, desc=f"[{param_names[j]}] prior curves", leave=False):
        theta_true = theta_true_by_sim_display[sid]
        if PLOT_KIND == 'hist':
            y = flat_prior_curve_centered(theta_true[j], xgrid, edges)   # step-ready
        else:
            y = flat_prior_curve_centered(theta_true[j], xgrid, None)    # line-ready
        curves_prior[j].append(y)
    # Posterior curves are estimated from the centered draws of each displayed sim
    for arr in tqdm(posterior_centered_disp[j], desc=f"[{param_names[j]}] post curves", leave=False):
        y = hist_curve(arr, edges) if PLOT_KIND == 'hist' else kde_curve(arr, xgrid, KDE_BW)
        curves_post[j].append(y)

print("Computing aggregate mean ± SD bands & best-guess curves…")
for j in range(D):

    # Simulations that enter the band: all held-out sims, or only the displayed ones
    if AGGREGATE_ON_ALL:
        sim_ids_for_band = all_ids
    else:
        sim_ids_for_band = sel_ids

    Ys_prior = []
    for sid in tqdm(sim_ids_for_band, desc=f"[{param_names[j]}] band prior", leave=False):
        theta_true = (theta_true_by_sim_all if AGGREGATE_ON_ALL else theta_true_by_sim_display)[sid]
        if PLOT_KIND == 'hist':
            y = flat_prior_curve_centered(theta_true[j], xgrid, edges)
        else:
            y = flat_prior_curve_centered(theta_true[j], xgrid, None)
        Ys_prior.append(y)

    Ys_post = []
    for arr in tqdm((posterior_centered_all[j] if AGGREGATE_ON_ALL else posterior_centered_disp[j]),
                    desc=f"[{param_names[j]}] band post", leave=False):
        Ys_post.append(hist_curve(arr, edges) if PLOT_KIND == 'hist' else kde_curve(arr, xgrid, KDE_BW))

    # Use nan-aggregates for prior band (since we inserted NaNs outside support)
    # so the prior mean at a grid point averages only the sims whose support covers it
    if len(Ys_prior) > 0:
        YP = np.vstack(Ys_prior)
        band_prior[j] = dict(mean=np.nanmean(YP, axis=0), std=np.nanstd(YP, axis=0))
    else:
        band_prior[j] = dict(mean=None, std=None)

    # Posterior curves contain no NaNs, so plain mean / population SD across sims
    if len(Ys_post) > 0:
        YQ = np.vstack(Ys_post)
        mu_post = YQ.mean(axis=0)
        sd_post = YQ.std(axis=0, ddof=0)
        band_post[j] = dict(mean=mu_post, std=sd_post)
    else:
        band_post[j] = dict(mean=None, std=None)

    # best-guess offsets density (red curve)
    # density, across sims, of (best-guess − ground truth) for parameter j
    if SHOW_BEST_GUESS_CURVE and best_guess_offsets_by_param[j].size > 0:
        offsets = np.clip(best_guess_offsets_by_param[j], XLIM[0], XLIM[1])
        if PLOT_KIND == 'hist':
            best_guess_curve[j] = hist_curve(offsets, edges)
        else:
            best_guess_curve[j] = kde_curve(offsets, xgrid, KDE_BW)

# ---- Global y-limits across ALL panels (include red curves & bands) ----
# The upper limit is the highest finite peak over every per-simulation curve,
# every mean + SD band and every best-guess curve, plus 5% headroom.
ymax = 0.0
for j in range(D):
    peaks = list(curves_prior[j]) + list(curves_post[j])
    prior_band, post_band = band_prior[j], band_post[j]
    if prior_band['mean'] is not None:
        peaks.append(prior_band['mean'] + prior_band['std'])
    if post_band['mean'] is not None:
        peaks.append(post_band['mean'] + post_band['std'])
    if best_guess_curve[j] is not None:
        peaks.append(best_guess_curve[j])
    for curve in peaks:
        # keeping ymax as the first argument means an all-NaN curve (nanmax -> NaN) leaves it unchanged
        ymax = max(ymax, float(np.nanmax(curve)))
ymax = ymax * 1.05
if np.isfinite(ymax) and ymax > 0:
    ylims = (0.0, ymax)
else:
    # nothing usable to scale on: fall back to a unit-height axis
    ylims = (0.0, 1.0)

# ---- Plot ----
# One row per parameter: left = prior centered on the ground truth,
# right = posterior centered on the ground truth (0 on the x-axis = ground truth).
fig, axes = plt.subplots(D, 2, figsize=FIGSIZE, sharex=True, sharey=True)
# With a single row plt.subplots returns a 1D array; wrap it so axes[j, col] indexing still works
if D == 1: axes = np.array([axes])

for j in range(D):
    # LEFT: PRIOR
    axL = axes[j, 0]
    for y in curves_prior[j]:
        # NOTE: this condition is true for both supported PLOT_KIND values, so the prior is
        # always drawn as a step curve and the else branch is only reached for other values
        if PLOT_KIND == 'hist' or PLOT_KIND == 'kde': # always hist for prior
            axL.step(xgrid, y, where='post', alpha=ALPHA_LINE, lw=LINEWIDTH_THIN, color='grey')
        else:
            axL.plot(xgrid, y, alpha=ALPHA_LINE, lw=LINEWIDTH_THIN, color='grey')
    if band_prior[j]['mean'] is not None:
        mu = band_prior[j]['mean']; sd = band_prior[j]['std']
        if PLOT_KIND == 'hist' or PLOT_KIND == 'kde': # always hist for prior
            axL.step(xgrid, mu, where='post', lw=LINEWIDTH_THICK, color='grey', alpha=0.8, label='Mean density')
            # lower edge of the band is clipped at 0 (a density cannot be negative)
            axL.fill_between(xgrid, np.clip(mu - sd, 0, None), mu + sd, step='post',
                             color='grey', alpha=0.12, label='±1 SD (across sims)')
        else:
            axL.plot(xgrid, mu, lw=LINEWIDTH_THICK, color='grey', alpha=0.8, label='Mean density')
            axL.fill_between(xgrid, np.clip(mu - sd, 0, None), mu + sd,
                             color='grey', alpha=0.12, label='±1 SD (across sims)')
    # dotted vertical line marks the ground truth (offset 0)
    axL.axvline(0.0, ls=':', lw=1.2, color='k')
    axL.set_ylabel(param_names[j])
    axL.set_xlim(*XLIM); axL.set_ylim(*ylims)
    if j == 0:
        axL.set_title("Prior (centered)")
        axL.legend(loc='upper right', fontsize=8, frameon=False)

    # RIGHT: POSTERIOR
    axR = axes[j, 1]
    # thin, nearly transparent curves: one per displayed held-out simulation
    for y in curves_post[j]:
        if PLOT_KIND == 'hist':
            axR.step(xgrid, y, where='post', alpha=ALPHA_LINE, lw=LINEWIDTH_THIN, color='purple')
        else:
            axR.plot(xgrid, y, alpha=ALPHA_LINE, lw=LINEWIDTH_THIN, color='purple')
    if band_post[j]['mean'] is not None:
        mu = band_post[j]['mean']; sd = band_post[j]['std']
        if PLOT_KIND == 'hist':
            axR.step(xgrid, mu, where='post', lw=LINEWIDTH_THICK, color='purple', alpha=0.9, label='Mean density')
            axR.fill_between(xgrid, np.clip(mu - sd, 0, None), mu + sd, step='post',
                             color='purple', alpha=0.2, label='±1 SD (across sims)')
        else:
            axR.plot(xgrid, mu, lw=LINEWIDTH_THICK, color='purple', alpha=0.9, label='Mean density')
            axR.fill_between(xgrid, np.clip(mu - sd, 0, None), mu + sd,
                             color='purple', alpha=0.2, label='±1 SD (across sims)')

    # RED: best-guess offsets density
    if SHOW_BEST_GUESS_CURVE and best_guess_curve[j] is not None:
        if PLOT_KIND == 'hist':
            axR.step(xgrid, best_guess_curve[j], where='post', lw=LINEWIDTH_THICK, color='red', alpha=0.6, label='Best guess (distribution of difference relative to ground truth)')
        else:
            axR.plot(xgrid, best_guess_curve[j], lw=LINEWIDTH_THICK, color='red', alpha=0.6, label=f'Best guess\n(distribution of difference\nrelative to ground truth)')

    axR.axvline(0.0, ls=':', lw=1.2, color='k')
    axR.set_xlim(*XLIM); axR.set_ylim(*ylims)
    if j == 0:
        axR.set_title("Posterior (centered)")
        axR.legend(loc='upper right', fontsize=8, frameon=False)

    if SHOW_KL_TEXT:
        # KL: use bits for display
        # mean ± population SD of the per-simulation KL values for this parameter
        kl_bits = kl_bits_list_by_param[j]
        kl_mean_bits = float(kl_bits.mean()) if kl_bits.size else np.nan
        kl_sd_bits   = float(kl_bits.std(ddof=0)) if kl_bits.size else np.nan

        # Add shrinkage summaries (mean ± SD across sims)
        ci_mu, ci_sd = ci90_shrink_mean[j], ci90_shrink_sd[j]
        vs_mu, vs_sd = var_shrink_mean[j],  var_shrink_sd[j]

        # Text box in the top-left corner of the posterior panel (axes coordinates)
        # NOTE(review): the label says "confidence interval", while the value is computed from posterior quantiles (a credible interval)
        axR.text(0.02, 0.95,
                "Info gain (from uniform to marginal posterior):\n"
                f"  KL divergence between uniform[0,1] and posterior\n  (entropy of posterior)\n  = {kl_mean_bits:.3f} ± {kl_sd_bits:.3f} bits\n"
                f" \nShrinkage (X-fold decrease from prior to posterior):\n"
                f"  90% confidence interval  = {ci_mu:.2f} ×  ± {ci_sd:.2f}\n"
                f"  Variance   = {vs_mu:.2f} ×  ± {vs_sd:.2f}",
                transform=axR.transAxes, va='top', ha='left',
                color='black', fontsize=8,
                bbox=dict(facecolor='white', alpha=0.6, edgecolor=None))


# Label the x-axis on the bottom row only (the x-axis is shared across rows).
for ax in axes[-1, :]:
    ax.set_xlabel("Centered value (θ − θ_true)")

plt.tight_layout()
if SAVEFIG_PATH:
    # The previous target was "\\ Accuracy_posterior": a file name starting with a
    # space and relying on matplotlib's default format. Use a clean, explicit PNG name.
    plt.savefig(f"{SAVEFIG_PATH}\\Accuracy_posterior.png", dpi=300)
plt.show()


In [None]:
### FIGURE CHECK # 2.1
# Comparing posterior distribution to ground truth, this time for each held out simulation (distribution density as color scale on the y axis, ground-truth as black line)

# ============================================================
# Posterior density maps (light theme + custom white→color cmap)
#  - Per-parameter 2D density (columns = held-out sims, y∈[0,1])
#  - Ground-truth line overlay
#  - Best-guess accuracy metrics (MAE/RMSE/R²) + expected baselines (Uniform[0,1])
#  - Option to SHARE the density color scale across all subplots
# ============================================================

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, to_rgb
from tqdm.auto import tqdm
from scipy.stats import pearsonr

# ---------------- CONFIG ----------------
# Settings for the per-parameter posterior density maps.
NBINS_Y         = 100              # vertical bins (resolution in y)
BASE_COLOR      = "purple" # "#2a6f97"        # choose any color; 0-density is white → high density = this color
SHARE_VSCALE    = True             # if True, all subplots share the same color scale (comparable intensities)
VSCALE_QUANTILE = 0.999            # when sharing, take this global quantile as vmax (robust to outliers); 1.0 = max
FIG_W           = 10
FIG_H_PER_ROW   = 4
COLORBAR_MODE   = "shared"         # "shared" (one colorbar for all) or "per-axes"
SAVEFIG         = path_to_save  # or None

# Pick held-out index
HELDOUT_PICK = 19 # Choose the held out simulation to highlight
# An int HELDOUT_PICK is treated as a POSITION in hold_out_idx and mapped to the
# simulation's global index; any other value is converted to int and used as is.
heldout_idx = int(hold_out_idx[HELDOUT_PICK] if isinstance(HELDOUT_PICK, int) else HELDOUT_PICK)

# ----------------------------------------

def _to_np(x):
    """Convert torch.Tensor or array-like to a NumPy array without using tensor.numpy()."""
    if isinstance(x, torch.Tensor):
        # Go through Python lists → NumPy, avoids PyTorch's NumPy bridge entirely
        return np.asarray(x.detach().cpu().tolist(), dtype=float)
    else:
        return np.asarray(x, dtype=float)

# Light-theme matplotlib defaults: white backgrounds (axes, figure, saved
# files) and black text, axis edges, axis labels and tick marks.
plt.rcParams.update({
    **dict.fromkeys(
        ("axes.facecolor", "figure.facecolor", "savefig.facecolor"),
        "white",
    ),
    **dict.fromkeys(
        ("text.color", "axes.edgecolor", "axes.labelcolor", "xtick.color", "ytick.color"),
        "black",
    ),
})

# Custom colormap: white → BASE_COLOR (linear)
def white_to_color_cmap(hex_color="#2a6f97", steps=256):
    """Build a two-stop linear colormap that runs from white to ``hex_color``.

    ``hex_color`` is resolved through ``to_rgb``; ``steps`` is forwarded as the
    number of quantisation levels ``N``.
    """
    white = (1, 1, 1)
    target_rgb = to_rgb(hex_color)
    return LinearSegmentedColormap.from_list(
        "white_to_color", [white, target_rgb], N=steps
    )

# Colormap shared by every density panel.
CMAP = white_to_color_cmap(BASE_COLOR, 256)

# Resolve names and held-out IDs
# Parameter labels come from input_sim_parameters_to_infer when that global
# exists; otherwise generic θ[j] labels are generated from the second dimension
# of the first cached posterior-sample array.
param_names = input_sim_parameters_to_infer if 'input_sim_parameters_to_infer' in globals() else \
              [f"θ[{j}]" for j in range(next(iter(posterior_samples_thetas_for_held_out_sims.values())).shape[1])]
# Keys are cast to int and sorted; the sorted ids are later used to index the
# ground-truth / sample dicts.
# NOTE(review): assumes the dict keys hash/compare equal to their int() form — confirm.
heldout_ids = sorted(int(k) for k in posterior_samples_thetas_for_held_out_sims.keys())
N = len(heldout_ids)   # number of held-out simulations
D = len(param_names)   # number of inferred parameters
assert N > 0 and D > 0, "No held-out sims or parameters found."

# Collect GT, best-guess, and posterior samples
# Row i of GT / BEST_GUESS and entry i of SAMPLES_PER_SIM all refer to heldout_ids[i].
GT = np.zeros((N, D), dtype=np.float32)
BEST_GUESS = np.zeros((N, D), dtype=np.float32)
SAMPLES_PER_SIM = []

for i, sid in enumerate(heldout_ids):
    theta_true = _to_np(ground_truth_thetas_for_held_out_sims[sid]).reshape(-1)
    post = _to_np(posterior_samples_thetas_for_held_out_sims[sid])  # [n_draws, D]
    GT[i, :] = theta_true
    if sid in posterior_log_probs_for_held_out_sims:
        # Best guess = the posterior draw with the highest cached log-probability.
        # NOTE(review): assumes the log-prob entries are aligned row-for-row with `post` — confirm.
        logp = _to_np(posterior_log_probs_for_held_out_sims[sid]).reshape(-1)
        idx  = int(np.argmax(logp))
        BEST_GUESS[i, :] = post[idx, :]
    else:
        # No log-probs cached for this sim: fall back to the posterior mean.
        BEST_GUESS[i, :] = post.mean(axis=0)
    SAMPLES_PER_SIM.append(post)

# Metric helpers
def mae(a, b):
    """Return the mean absolute error between ``a`` and ``b`` as a float.

    Both inputs are coerced with ``np.asarray`` so plain lists/tuples work as
    well as NumPy arrays; the usual broadcasting rules apply.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(np.mean(np.abs(diff)))
def rmse(a, b):
    """Return the root-mean-square error between ``a`` and ``b`` as a float.

    Both inputs are coerced with ``np.asarray`` so plain lists/tuples work as
    well as NumPy arrays; the usual broadcasting rules apply.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(np.sqrt(np.mean(diff**2)))
def r2(pred, true):
    """Return the coefficient of determination of ``pred`` against ``true``.

    Computed as ``1 - SS_res / SS_tot``. Inputs are coerced with ``np.asarray``
    so plain lists/tuples are accepted. When ``true`` has no variance
    (``SS_tot == 0``) the score is undefined and ``np.nan`` is returned.
    """
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    ss_res = np.sum((pred - true)**2)
    ss_tot = np.sum((true - true.mean())**2)
    return float(1.0 - ss_res/ss_tot) if ss_tot > 0 else np.nan

# Precompute images (so we can share color scale if requested)
# For every parameter, build an [NBINS_Y, N] image: one column per held-out
# simulation (columns ordered by that parameter's ground truth), one row per
# histogram bin over [0, 1].
y_edges = np.linspace(0.0, 1.0, NBINS_Y + 1)
images_per_param = []   # list of [NBINS_Y, N]
gt_sorted_per_param = []  # x alignments

for j, pname in enumerate(param_names):
    gt_j = GT[:, j]
    order = np.argsort(gt_j)          # column order for this parameter
    gt_sorted = gt_j[order]
    img = np.zeros((NBINS_Y, N), dtype=np.float32)
    for col, idx in enumerate(tqdm(order, desc=f"Building density for {pname}", leave=False)):
        samples = SAMPLES_PER_SIM[idx][:, j]
        # Out-of-range draws are clipped, so they accumulate in the first/last bin.
        samples = np.clip(samples, 0.0, 1.0)
        counts, _ = np.histogram(samples, bins=y_edges)
        col_pdf = counts.astype(np.float32)
        Z = col_pdf.sum()
        if Z > 0:
            # Each column is scaled to sum to 1 (per-bin probability mass).
            col_pdf /= Z  # column-wise PDF
        img[:, col] = col_pdf
    images_per_param.append(img)
    gt_sorted_per_param.append(gt_sorted)

# Choose shared vmin/vmax
if SHARE_VSCALE:
    # Pool every cell of every parameter image and take a high quantile as vmax
    # so that a few very peaked columns do not wash out the rest.
    all_vals = np.concatenate([img.ravel() for img in images_per_param])
    vmin = 0.0
    vmax = float(np.quantile(all_vals, VSCALE_QUANTILE)) if VSCALE_QUANTILE < 1.0 else float(all_vals.max())
    if vmax <= 0:
        # Guard against a degenerate (all-zero) scale.
        vmax = 1.0
else:
    vmin, vmax = 0.0, None  # autoscale per-axes


In [None]:
### FIGURE CHECK # 2.2
# Comparing posterior distribution to ground truth, this time for each held out simulation (distribution density as color scale on the y axis, ground-truth as black line)

# ---- Plot (with markers for HELDOUT_PICK) ----
# One row per parameter. Each image column is one held-out simulation (columns
# sorted by that parameter's ground truth); image rows are the NBINS_Y bins.
fig, axes = plt.subplots(D, 1, figsize=(FIG_W, max(3.0, D*FIG_H_PER_ROW)), squeeze=False)
mappables = []

# map heldout_idx (sim id) -> its row index i_pick in arrays
try:
    i_pick = heldout_ids.index(heldout_idx)
except ValueError:
    i_pick = None
    print(f"Warning: heldout_idx={heldout_idx} not found in heldout_ids; markers will be skipped.")

# Horizontal image extent. imshow spreads the N columns evenly over the extent,
# so [-0.5, N-0.5] centres column k exactly on x = k. That is where the
# ground-truth line (x = 0..N-1) and the marker (x = col_pos) are drawn; an
# extent of [0, N-1] would shift them against the image and is degenerate
# (zero width) when N == 1.
col_left, col_right = -0.5, N - 0.5

for j, pname in enumerate(param_names):
    ax = axes[j, 0]
    img = images_per_param[j]

    im = ax.imshow(
        img, aspect='auto', origin='lower',
        extent=[col_left, col_right, 0.0, 1.0],
        cmap=CMAP, vmin=vmin, vmax=vmax,
        interpolation=None,  # None = matplotlib's default interpolation (not the string 'none')
    )
    mappables.append(im)

    # Ground-truth line (y = sorted GT)
    gt_j = GT[:, j]
    order_j = np.argsort(gt_j)
    gt_sorted = gt_j[order_j]
    ax.plot(np.arange(N), gt_sorted, color='black', lw=1.4, alpha=0.9, label="Ground truth")

    # Marker for the chosen held-out sim (column depends on sorting for THIS parameter)
    if i_pick is not None:
        # column where this sim lands after sorting by parameter j
        col_pos = int(np.where(order_j == i_pick)[0][0])
        y_val = gt_sorted[col_pos]
        ax.scatter(
            [col_pos], [y_val],
            s=65, marker='D',
            facecolor='none', edgecolor='black', linewidth=1.2, zorder=5,
            label=None
        )

    ax.set_xlim(col_left, col_right); ax.set_ylim(0.0, 1.0)
    ax.set_ylabel(pname)
    if j == 0:
        ax.set_title("Posterior density across held-out simulations (light theme; white→low, color→high)")

    # Best-guess metrics vs unsorted GT (metrics independent of sort)
    mae_best  = float(np.mean(np.abs(BEST_GUESS[:, j] - GT[:, j])))
    rmse_best = float(np.sqrt(np.mean((BEST_GUESS[:, j] - GT[:, j])**2)))
    ss_res = np.sum((BEST_GUESS[:, j] - GT[:, j])**2)
    ss_tot = np.sum((GT[:, j] - GT[:, j].mean())**2)
    r2_best  = float(1.0 - ss_res/ss_tot) if ss_tot > 0 else np.nan
    # Pearson correlation between best guess and ground truth; undefined (NaN)
    # when either series is constant or there is a single simulation.
    if GT[:, j].std(ddof=0) > 0 and BEST_GUESS[:, j].std(ddof=0) > 0 and GT.shape[0] > 1:
        pear_r, _ = pearsonr(BEST_GUESS[:, j], GT[:, j])
        pear_r = float(pear_r)
    else:
        pear_r = np.nan

    # Expected baselines under Uniform[0,1]: what a guess U ~ Uniform[0,1]
    # would score on average against these ground-truth values.
    a = GT[:, j]
    E_abs = a**2 - a + 0.5                         # E|U-a|
    E_sq  = (0.5 - a)**2 + (1.0/12.0)              # E[(U-a)^2]
    exp_mae_prior  = float(np.mean(E_abs))
    exp_rmse_prior = float(np.sqrt(np.mean(E_sq)))
    muY, varY = float(a.mean()), float(a.var(ddof=0))
    # mean(E_sq) = varY + 1/12 + (muY - 0.5)^2, divided by the GT variance.
    exp_r2_prior = (1.0 - (varY + 1.0/12.0 + (muY - 0.5)**2) / varY) if varY > 0 else np.nan

    txt = (f"Best-guess vs Ground Truth:\n"
           f"  Mean Abs. Error  = {mae_best:.3f}\n"
           f"  RMSE = {rmse_best:.3f}\n"
           f"  R²   = {r2_best:.3f}\n"
           f"  Pearson r        = {pear_r:.3f}\n"
           f"Baseline (Uniform prior, expected):\n"
           f"  E[MAE]  = {exp_mae_prior:.3f}\n"
           f"  E[RMSE] = {exp_rmse_prior:.3f}\n"
           f"  E[R²]   = {exp_r2_prior:.3f}")
    ax.text(0.01, 0.99, txt, transform=ax.transAxes, va='top', ha='left',
            fontsize=8, color='black',
            bbox=dict(facecolor='white', alpha=0.75, edgecolor='none'))

    if j == 0:
        ax.legend(loc='lower right', fontsize=8, frameon=False)

# Label only the bottom panel; address it explicitly instead of relying on the
# loop variable that happens to be left over.
axes[-1, 0].set_xlabel("Held-out simulations (sorted by ground truth per row)")

# Colorbar(s)
if COLORBAR_MODE == "shared":
    cbar = fig.colorbar(mappables[-1], ax=axes.ravel().tolist(), fraction=0.035, pad=0.02)
    cbar.set_label("Column PDF (density)")
else:
    for ax, im in zip(axes.ravel(), mappables):
        cbar = fig.colorbar(im, ax=ax, fraction=0.035, pad=0.02)
        cbar.set_label("Column PDF")

if SAVEFIG:
    # os.path.join keeps the output inside SAVEFIG on every OS; a literal
    # backslash is only a path separator on Windows.
    plt.savefig(os.path.join(SAVEFIG, "Accuracy_posterior_density_light.png"), dpi=300)
    plt.savefig(os.path.join(SAVEFIG, "Accuracy_posterior_density_light.svg"))
plt.show()


In [None]:
### FIGURE CHECK # 2.3
# Comparing posterior distribution to ground truth, specifically at the highlighted held-out idx (corresponds to a "slice" of the plots generated in FIGURE CHECK # 2.2)

# ===========================================
# “Slice” figure for HELDOUT_PICK
#  - One panel per parameter
#  - KDE of posterior samples for that sim
#  - Vertical lines: GT (black solid), Best-guess (colored dashed)
# ===========================================

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

# Settings for the per-parameter slice figure.
SLICE_BASE_COLOR = BASE_COLOR    # reuse color from above for consistency
SLICE_BW         = 0.04          # KDE bandwidth in parameter units [0,1]
SLICE_POINTS     = 401           # grid resolution
SLICE_XLIM       = (0.0, 1.0)    # parameters are normalized to [0,1]
# Fixed width; height grows with the number of parameters D (minimum 3.0).
FIGSIZE_SLICE    = (10, max(3.0, 2.2*D))

def kde_1d(samples, xgrid, bw):
    """Gaussian KDE of ``samples`` evaluated on ``xgrid`` with bandwidth ``bw``.

    Samples are first clipped to ``SLICE_XLIM``. The averaged kernel curve is
    rescaled so that its Riemann sum over the grid (using the mean grid
    spacing) equals 1 whenever that sum is positive. An empty sample set
    yields zeros shaped like ``xgrid``.
    """
    lo, hi = SLICE_XLIM[0], SLICE_XLIM[1]
    pts = np.clip(np.asarray(samples, float), lo, hi)
    if pts.size == 0:
        return np.zeros_like(xgrid)

    norm_const = 1.0 / np.sqrt(2*np.pi*bw*bw)
    # offsets[i, k] = xgrid[k] - pts[i]
    offsets = xgrid[None, :] - pts[:, None]
    kernels = norm_const * np.exp(-0.5 * (offsets / bw)**2)

    # Average over samples, then renormalise to unit area on the grid.
    curve = kernels.mean(axis=0)
    step = np.diff(xgrid).mean()
    total = (curve * step).sum()
    if total > 0:
        return curve / total
    return curve

# Map heldout_idx to its array row index i_pick (if not done already)
# Unlike the density figure, this figure cannot be drawn without the picked
# simulation, so a missing id is an error rather than a warning.
try:
    i_pick = heldout_ids.index(heldout_idx)
except ValueError:
    raise RuntimeError(f"heldout_idx={heldout_idx} not found in heldout_ids.")

# Build figure
xgrid = np.linspace(SLICE_XLIM[0], SLICE_XLIM[1], SLICE_POINTS)
fig, axes = plt.subplots(D, 1, figsize=FIGSIZE_SLICE, sharex=True)

# plt.subplots returns a bare Axes (not an array) for a single row.
if D == 1:
    axes = np.array([axes])

curve_color = SLICE_BASE_COLOR
for j, pname in enumerate(param_names):
    ax = axes[j]

    # Posterior samples for this sim & parameter
    samples_j = SAMPLES_PER_SIM[i_pick][:, j]
    dens = kde_1d(samples_j, xgrid, SLICE_BW)

    # Lines
    ax.plot(xgrid, dens, color=curve_color, lw=2.0, alpha=0.95, label="Posterior (KDE)")
    ax.axvline(GT[i_pick, j], color='black', lw=1.5, label="Ground truth")
    ax.axvline(BEST_GUESS[i_pick, j], color=curve_color, lw=1.5, ls='--', label="Best guess")

    ax.set_ylabel(pname)
    ax.set_xlim(*SLICE_XLIM)
    # pretty y-limit with a bit of headroom
    ymax = float(dens.max()) if np.isfinite(dens).all() else 1.0
    if not ymax > 0:
        # An all-zero density (e.g. no samples) would give identical lower and
        # upper limits; fall back to a unit-height axis instead.
        ymax = 1.0
    ax.set_ylim(0, ymax * 1.08)

axes[-1].set_xlabel("Parameter value (normalized)")

# Top legend (single, clean)
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper right", frameon=False)

plt.tight_layout()
if SAVEFIG:
    # os.path.join keeps the output inside SAVEFIG on every OS; a literal
    # backslash is only a path separator on Windows.
    plt.savefig(os.path.join(SAVEFIG, "Accuracy_slice_KDE.png"), dpi=300)
    plt.savefig(os.path.join(SAVEFIG, "Accuracy_slice_KDE.svg"))
plt.show()


In [None]:
### FIGURE CHECK # 3
# Comparing posterior distribution to ground truth in feature-space (and comparing it to prior), with different projection styles

# %% SBI held-out check — projection (with full D axes) + 2D scatter (fixed contours)

# ------------------- TUNING KNOBS -------------------
HELDOUT_PICK = 19 # Held out idx to look at
# Normalization for projection & scatter: "original", "unit", or "zscore"
proj_normalization    = "unit"
scatter_normalization = "unit"

prior_color = "#888888"
heldout_condition_color = "#C2388D"

# --- Figure 1 (D->2 arrow projection, show ALL basis axes) ---
# Any PROJ_MODE value other than "mix2", "axes" or "manual" is handled by the
# seeded random-plane branch of make_projection_basis.
PROJ_MODE    = "manual"                 # "mix2" | "random" | "axes" | "manual"
# Plane definition if PROJ_MODE == "mix2":
# u_raw = cos(θ) e_i + sin(θ) e_j  ; v_raw = cos(φ) e_k + sin(φ) e_l
MIX2_U = (0, 1, 35.0)                 # (i, j, angle_deg)
MIX2_V = (0, 2, -20.0)                # (k, l, angle_deg)
# Fallbacks for other modes
ORIENT_SEED    = 7
ORIENT_AXES    = (0, 1)               # for PROJ_MODE=="axes"
# ORIENT_B_MANUAL = None                # np.ndarray shape (D,2) if PROJ_MODE=="manual"
# One row per parameter. make_projection_basis requires shape (D, 2) — i.e.
# D == 5 for this literal — and orthonormalises the two columns with a QR
# decomposition, so the rows are not used verbatim as arrow directions.
ORIENT_B_MANUAL = np.array([
    [ 1.00,  0.00],   # θ1
    [ 0.31,  0.95],   # θ2
    [-0.81,  0.59],   # θ3
    [ 0.62, -0.78],   # θ4
    [-0.96,  0.28],   # θ5
])

# How many arrows to show (subset to avoid clutter)
n_post_proj  = 20
n_prior_proj = 50

# Arrow styling
arrow_lw_gt       = 3.0
arrow_lw_post     = 2
arrow_lw_prior    = 2
arrow_alpha_gt    = 1.0
arrow_alpha_post  = 0.2
arrow_alpha_prior = 0.2
# Small arrowheads (relative — the plane is unit-ish)
arrow_head_width  = 0.02
arrow_head_length = 0.03

# Basis (parameter) axis arrows (projected) — black
basis_axis_color = "#000000"
basis_axis_lw    = 1.4
basis_axis_alpha = 0.8
basis_head_w     = 0.018
basis_head_l     = 0.035

# Limits padding for projection figure
proj_axis_pad_frac = 0.06
label_params_on_axes = True           # show θ_i labels near arrow tips

# --- Figure 2 (2D scatter) ---
# param_x / param_y are looked up by name in theta_colnames.
param_x = "common_input_std"
param_y = "disynpatic_inhib_connections_desired_MN_MN"
n_post_scatter  = 2500
n_prior_scatter = 11500
dot_size_prior  = 30
dot_alpha_prior = 0.05
dot_size_post   = 30
dot_alpha_post  = 0.1
cross_size      = 130
cross_edge_lw   = 1.2
scatter_axis_pad_frac = 0.08

# 2D KDE (posterior) — mass-contour lines 30/60/90%
kde_levels_mass = (0.3, 0.6, 0.9)
kde_grid        = 200
kde_lw          = 1.4
kde_color       = "#000000"
kde_bw_scale    = None                # None→Scott; or float

# Save paths
# Reuse path_to_save when it is already defined; otherwise fall back to the
# current working directory. The directory is created if missing.
path_to_save = globals().get("path_to_save", ".")
os.makedirs(path_to_save, exist_ok=True)
FIG1_PATH = os.path.join(path_to_save, "SBI_check_projection_theta.svg")
FIG2_PATH = os.path.join(path_to_save, "SBI_check_scatter_theta.svg")

# Helper: tensor → NumPy conversion that never calls tensor.numpy()
def to_np(x):
    """Return ``x`` as a float NumPy array without calling ``Tensor.numpy()``.

    A tensor is detached, moved to CPU and converted through nested Python
    lists; any other array-like goes directly through ``np.asarray``.
    """
    is_tensor = isinstance(x, torch.Tensor)
    payload = x.detach().cpu().tolist() if is_tensor else x
    return np.asarray(payload, dtype=float)

# ----------------------------------------------------
# Context (same as before)
# theta_colnames must already exist as a global (KeyError otherwise).
theta_colnames = list(globals()["theta_colnames"])
D = len(theta_colnames)
if "theta_raw" not in globals():
    # Rebuild the parameter matrix from the summary dataframe; the .tolist()
    # round-trip hands plain Python lists to torch.tensor.
    theta_raw = torch.tensor(
        df_simulation_summary[theta_colnames].to_numpy(dtype=float).tolist(),
        dtype=torch.float32,
    )

# Prior bounds (original scale)
# Element [0] of each prior entry is taken as the lower bound, [1] as the upper.
if "low_original" not in globals() or "high_original" not in globals():
    low_original  = torch.tensor([priors_per_parameters_to_infer[name][0] for name in theta_colnames], dtype=torch.float32)
    high_original = torch.tensor([priors_per_parameters_to_infer[name][1] for name in theta_colnames], dtype=torch.float32)

def theta_to_unit(theta):
    """Rescale ``theta`` so that ``low_original`` maps to 0 and ``high_original`` to 1.

    Reads the module-level ``low_original`` / ``high_original`` bounds.
    """
    span = high_original - low_original
    return (theta - low_original) / span

def unit_to_theta(theta_unit):
    """Inverse of ``theta_to_unit``: map unit-scaled values back to the original scale.

    Reads the module-level ``low_original`` / ``high_original`` bounds.
    """
    span = high_original - low_original
    return theta_unit * span + low_original

# Per-parameter mean / population std of all simulated θ (used for z-scoring);
# the std is clamped away from zero to keep the division safe.
theta_mu = theta_raw.mean(dim=0)
theta_sd = theta_raw.std(dim=0, unbiased=False).clamp_min(1e-12)

# Pick held-out index
heldout_idx = int(hold_out_idx[HELDOUT_PICK] if isinstance(HELDOUT_PICK, int) else HELDOUT_PICK)

# Observation vector for this held-out
# NOTE(review): obs_hold_out is indexed by the position HELDOUT_PICK while
# theta_raw is indexed by heldout_idx — assumes obs_hold_out is ordered like
# hold_out_idx; confirm.
x_obs = obs_hold_out[HELDOUT_PICK] # sim_obs[heldout_idx]

# Ground truth θ
theta_gt_original = theta_raw[heldout_idx]
theta_gt_unit     = theta_to_unit(theta_gt_original)

# Posterior samples (unit)
# A single draw of max(n_post_proj, n_post_scatter) samples serves both figures.
N_POST = int(max(n_post_proj, n_post_scatter))
with torch.no_grad():
    post_samples_unit = posterior_net_held_out_data.sample((N_POST,), x=x_obs)

# Subsets
# Fixed seed so the displayed subsets are reproducible between runs.
rng = np.random.default_rng(123)
idx_post_proj   = rng.choice(post_samples_unit.shape[0], size=min(n_post_proj, post_samples_unit.shape[0]), replace=False)
idx_post_scatt  = rng.choice(post_samples_unit.shape[0], size=min(n_post_scatter, post_samples_unit.shape[0]), replace=False)

# "Prior" points are rows of the simulated parameter table, sub-sampled
# without replacement (capped at the table length).
N_PRIOR_PROJ   = min(n_prior_proj,   len(df_simulation_summary))
N_PRIOR_SCATT  = min(n_prior_scatter, len(df_simulation_summary))
idx_prior_proj  = rng.choice(len(df_simulation_summary), size=N_PRIOR_PROJ,  replace=False)
idx_prior_scatt = rng.choice(len(df_simulation_summary), size=N_PRIOR_SCATT, replace=False)

# make them safe for PyTorch indexing
idx_post_proj   = idx_post_proj.tolist()
idx_post_scatt  = idx_post_scatt.tolist()
idx_prior_proj  = idx_prior_proj.tolist()
idx_prior_scatt = idx_prior_scatt.tolist()

prior_subset_original_proj  = theta_raw[idx_prior_proj]
prior_subset_original_scatt = theta_raw[idx_prior_scatt]

# ---------- helpers ----------
def make_plane_mix2(D, mix_u, mix_v):
    """Build a (D, 2) plane basis from two "mix" specifications.

    Each spec is ``(i, j, angle_deg)`` and describes the raw direction
    ``cos(angle) * e_i + sin(angle) * e_j``. The first direction is normalised
    to give ``u``; the second is made orthogonal to ``u`` (Gram-Schmidt) and
    normalised to give ``v``. If the second direction is (numerically)
    collinear with ``u``, the first canonical axis with a usable component
    orthogonal to ``u`` is used instead.

    Returns an array whose columns are ``u`` and ``v``.
    """
    def _mixed_direction(first, second, degrees):
        theta = np.deg2rad(float(degrees))
        vec = np.zeros((D,), dtype=float)
        vec[first] = np.cos(theta)
        vec[second] = np.sin(theta)
        return vec

    u_raw = _mixed_direction(*mix_u)
    v_raw = _mixed_direction(*mix_v)

    u = u_raw / (np.linalg.norm(u_raw) + 1e-12)
    # Remove the component of v_raw along u.
    v = v_raw - (u @ v_raw) * u
    v_norm = np.linalg.norm(v)

    if not (v_norm < 1e-8):
        v = v / v_norm
    else:
        # Degenerate plane: walk the canonical axes until one is not collinear with u.
        for axis in range(D):
            e_axis = np.zeros(D)
            e_axis[axis] = 1.0
            residual = e_axis - (u @ e_axis) * u
            res_norm = np.linalg.norm(residual)
            if res_norm > 1e-8:
                v = residual / res_norm
                break

    return np.column_stack([u, v])  # (D,2)

def make_projection_basis(D, mode="mix2", seed=7, axes=(0,1), B_manual=None, mix_u=None, mix_v=None):
    """Return a (D, 2) matrix used to project D-dimensional vectors onto a plane.

    Modes:
      * ``"mix2"``   - delegate to ``make_plane_mix2(D, mix_u, mix_v)``; both
        specs are required (``ValueError`` otherwise).
      * ``"axes"``   - select the two canonical axes given by ``axes``.
      * ``"manual"`` - orthonormalise the columns of ``B_manual`` (must have
        shape ``(D, 2)``, ``ValueError`` otherwise) with a QR decomposition.
      * anything else - orthonormalise a Gaussian random (D, 2) matrix drawn
        from ``np.random.default_rng(seed)``.
    """
    if mode == "mix2":
        if mix_u is None or mix_v is None:
            raise ValueError("Provide MIX2_U and MIX2_V when PROJ_MODE='mix2'.")
        return make_plane_mix2(D, mix_u, mix_v)

    if mode == "axes":
        basis = np.zeros((D, 2))
        basis[axes[0], 0] = 1.0
        basis[axes[1], 1] = 1.0
        return basis

    # Both remaining modes orthonormalise a (D, 2) matrix; they differ only in
    # where that matrix comes from.
    if mode == "manual":
        raw = np.asarray(B_manual, float)
        if raw.shape != (D, 2):
            raise ValueError("ORIENT_B_MANUAL must be (D,2)")
    else:
        raw = np.random.default_rng(seed).normal(size=(D, 2))

    ortho = np.linalg.qr(raw)[0]
    return ortho[:, :2]

def project(V, B):
    """Project ``V`` (shape ``(..., D)``) onto the plane basis ``B`` (shape ``(D, 2)``)."""
    projected = V @ B
    return projected

def axis_limits_with_padding(xy, pad_frac=0.06):
    """Return ``(xmin, xmax, ymin, ymax)`` covering the points in ``xy`` plus padding.

    ``xy`` is indexed as ``xy[:, 0]`` / ``xy[:, 1]``; NaNs are ignored. Each
    axis is widened on both sides by ``pad_frac`` times its data range. An
    axis with no spread is widened by 1 on each side instead, so the limits
    are never identical.
    """
    bounds = []
    for column in (0, 1):
        lo = np.nanmin(xy[:, column])
        hi = np.nanmax(xy[:, column])
        extent = hi - lo
        lo -= extent*pad_frac
        hi += extent*pad_frac
        if extent <= 0:
            # Degenerate axis: give it an arbitrary unit margin.
            lo -= 1
            hi += 1
        bounds.extend((lo, hi))
    return tuple(bounds)

def draw_arrow(ax, start, end, lw=1.5, color="#000", alpha=1.0,
               head_w=0.02, head_len=0.04, z=1):
    """Draw an arrow on ``ax`` from ``start`` to ``end`` (both ``(x, y)`` pairs).

    The head is counted as part of the arrow length, so the tip lands exactly
    on ``end``. Face and edge share ``color``; ``z`` is the zorder.
    """
    (x0, y0), (x1, y1) = start, end
    arrow_style = dict(
        head_width=head_w,
        head_length=head_len,
        length_includes_head=True,
        linewidth=lw,
        facecolor=color,
        edgecolor=color,
        alpha=alpha,
        zorder=z,
    )
    ax.arrow(x0, y0, x1 - x0, y1 - y0, **arrow_style)

def to_mode(arr_torch, mode):
    """Convert original-scale parameters to the requested normalisation as NumPy.

    ``mode`` is ``"original"`` (no change), ``"unit"`` (via ``theta_to_unit``)
    or ``"zscore"`` (using the module-level ``theta_mu`` / ``theta_sd``).
    Any other value raises ``ValueError``.
    """
    if mode not in ("original", "unit", "zscore"):
        raise ValueError("mode must be 'original'|'unit'|'zscore'")

    if mode == "unit":
        converted = theta_to_unit(arr_torch)
    elif mode == "zscore":
        converted = (arr_torch - theta_mu) / theta_sd
    else:
        converted = arr_torch
    return to_np(converted)


# 2D KDE + mass levels
def kde2d_grid(x, y, xlim, ylim, grid=200, bw_scale=None):
    """Evaluate a 2D Gaussian KDE (axis-aligned kernel) of points (x, y) on a grid.

    Parameters
    ----------
    x, y : array-like
        Paired coordinates; they must contain the same number of values. A
        point is dropped when EITHER coordinate is non-finite, so the
        remaining x/y values stay paired.
    xlim, ylim : (low, high)
        Extent of the evaluation grid along each axis.
    grid : int
        Number of grid nodes per axis.
    bw_scale : float or None
        Multiplier applied to the per-axis sample standard deviation to get
        the bandwidths. ``None`` uses Scott's factor for two dimensions,
        ``n ** (-1/6)``.

    Returns
    -------
    (XX, YY, dens)
        Meshgrid coordinate arrays and the density on that grid, each of shape
        ``(grid, grid)``. With fewer than two usable points the density is all
        zeros.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not contain the same number of values.
    """
    x = np.asarray(x, float).ravel()
    y = np.asarray(y, float).ravel()
    if x.size != y.size:
        raise ValueError("x and y must contain the same number of values")

    gx = np.linspace(*xlim, grid); gy = np.linspace(*ylim, grid)
    XX, YY = np.meshgrid(gx, gy)

    # Joint mask: filtering each coordinate on its own would shift the two
    # arrays against each other whenever only one coordinate of a point is
    # NaN/inf.
    paired = np.isfinite(x) & np.isfinite(y)
    x = x[paired]; y = y[paired]
    n = x.size
    if n < 2:
        return XX, YY, np.zeros_like(XX)

    sx = np.std(x, ddof=1); sy = np.std(y, ddof=1)
    # Degenerate spread along one axis: substitute a tiny fraction of the grid
    # extent for THAT axis only, leaving a valid spread on the other untouched.
    if sx <= 0:
        sx = (xlim[1]-xlim[0]) * 1e-3 or 1e-6
    if sy <= 0:
        sy = (ylim[1]-ylim[0]) * 1e-3 or 1e-6

    factor = n ** (-1/6) if bw_scale is None else float(bw_scale)
    bx = sx * factor
    by = sy * factor

    dens = np.zeros_like(XX, dtype=float)
    # Process the points in chunks so the (grid, grid, chunk) temporaries stay
    # bounded at roughly 2e5 elements.
    chunk = max(1, int(2e5 // XX.size))
    for s in range(0, n, chunk):
        e = min(n, s+chunk)
        dx = (XX[...,None] - x[None,None,s:e]) / bx
        dy = (YY[...,None] - y[None,None,s:e]) / by
        dens += np.exp(-0.5*(dx*dx + dy*dy)).sum(axis=2)
    dens /= (n * (2*np.pi*bx*by))
    return XX, YY, dens

def mass_contour_levels(dens, levels_mass=(0.3,0.6,0.9), XX=None, YY=None):
    """Turn probability-mass fractions into density thresholds for contouring.

    Grid cells are visited from the highest density downwards while their mass
    (density times cell area, taken from the spacing of ``XX`` / ``YY``) is
    accumulated. For every requested mass ``m`` the threshold is the density
    of the first cell at which the accumulated mass reaches ``m`` (or the
    lowest density if it never does).

    Returns the thresholds as a strictly increasing list with duplicates
    removed, which is the form Matplotlib's ``contour`` expects for ``levels``.
    """
    cell_area = (XX[0,1] - XX[0,0]) * (YY[1,0] - YY[0,0])
    descending = np.sort(dens.ravel())[::-1]
    cumulative = np.cumsum(descending) * cell_area
    last = descending.size - 1

    picked = [
        float(descending[np.clip(np.searchsorted(cumulative, m, side="left"), 0, last)])
        for m in levels_mass
    ]
    # Matplotlib requires strictly increasing levels
    return np.unique(np.sort(picked)).tolist()

# ------------- Build datasets in chosen normalization -------------
# Projection data
# Each dict literal below evaluates all three normalisations eagerly and then
# keeps the one named by proj_normalization (KeyError on an unknown name).
gt_vec = {
    "original": theta_gt_original,
    "unit":     theta_gt_unit,
    "zscore":   (theta_gt_original - theta_mu) / theta_sd,
}[proj_normalization]
gt_vec = to_np(gt_vec)

# Posterior samples are stored in unit scale, so they are mapped back with
# unit_to_theta before using the original or z-scored space.
post_proj = {
    "original": unit_to_theta(post_samples_unit[idx_post_proj]),
    "unit":     post_samples_unit[idx_post_proj],
    "zscore":   (unit_to_theta(post_samples_unit[idx_post_proj]) - theta_mu)/theta_sd
}[proj_normalization]
# Prior subsets are stored in the original scale.
prior_proj = {
    "original": prior_subset_original_proj,
    "unit":     theta_to_unit(prior_subset_original_proj),
    "zscore":   (prior_subset_original_proj - theta_mu)/theta_sd
}[proj_normalization]
post_proj_np  = to_np(post_proj)
prior_proj_np = to_np(prior_proj)

# Scatter data
def extract_xy_any(obj, xname, yname, mode="original"):
    """Pull the ``xname`` / ``yname`` parameter columns out of ``obj`` as NumPy arrays.

    Column positions are looked up in the module-level ``theta_colnames``.
    For a tensor, ``mode`` selects the normalisation applied before
    extraction: ``"original"`` (none), ``"unit"`` (``theta_to_unit``) or
    ``"zscore"`` (``theta_mu`` / ``theta_sd``); other values raise
    ``ValueError``. Non-tensor inputs are converted with ``np.asarray`` and
    returned as-is — ``mode`` is ignored for them.
    """
    ix = theta_colnames.index(xname)
    iy = theta_colnames.index(yname)

    if not isinstance(obj, torch.Tensor):
        arr = np.asarray(obj, dtype=float)
        return arr[:, ix], arr[:, iy]

    if mode == "original":
        cols = (obj[:, ix], obj[:, iy])
    elif mode == "unit":
        scaled = theta_to_unit(obj)
        cols = (scaled[:, ix], scaled[:, iy])
    elif mode == "zscore":
        cols = tuple((obj[:, k] - theta_mu[k]) / theta_sd[k] for k in (ix, iy))
    else:
        raise ValueError("mode must be 'original'|'unit'|'zscore'")
    return to_np(cols[0]), to_np(cols[1])

# Bring both point clouds to the original scale first; extract_xy_any applies
# the requested normalisation starting from original-scale tensors.
post_scatt_original = unit_to_theta(post_samples_unit[idx_post_scatt])
prior_scatt_original= prior_subset_original_scatt

def to_xy(obj_torch, mode):
    """Return the (param_x, param_y) columns of ``obj_torch`` in the given normalisation."""
    return extract_xy_any(obj_torch, param_x, param_y, mode=mode)

px_post, py_post = to_xy(post_scatt_original,  scatter_normalization)
px_prior,py_prior= to_xy(prior_scatt_original, scatter_normalization)
# The ground truth is a single vector; add a leading axis so it can be
# column-indexed like the sample matrices, then unwrap to plain floats.
gt_point_original = theta_gt_original.unsqueeze(0)
gx, gy = to_xy(gt_point_original, scatter_normalization)
gx, gy = float(gx[0]), float(gy[0])

# ----- FIGURE 1: D→2 projection with basis axes -----
# Every θ vector is drawn as an arrow from the origin to its 2D projection.
B = make_projection_basis(
    D, mode=PROJ_MODE, seed=ORIENT_SEED, axes=ORIENT_AXES,
    B_manual=ORIENT_B_MANUAL, mix_u=MIX2_U, mix_v=MIX2_V
)  # (D,2)

gt_xy     = project(gt_vec,           B).reshape(1,2)
post_xy   = project(post_proj_np,     B)
prior_xy  = project(prior_proj_np,    B)

# Basis axes: project each canonical basis e_i
E = np.eye(D)
E_proj = E @ B  # (D,2)

# Axis limits must contain every arrow tip, including the basis-axis arrows.
stack_all = np.vstack([gt_xy, post_xy, prior_xy, E_proj])
xmin,xmax,ymin,ymax = axis_limits_with_padding(stack_all, pad_frac=proj_axis_pad_frac)

fig1, ax1 = plt.subplots(figsize=(7.4, 6.2))
# Draw basis axes (black), from origin
for i in range(D):
    end = E_proj[i]
    draw_arrow(ax1, (0,0), (end[0], end[1]),
               lw=basis_axis_lw, color=basis_axis_color, alpha=basis_axis_alpha,
               head_w=basis_head_w, head_len=basis_head_l, z=3)
    if label_params_on_axes:
        # Place the label slightly beyond the arrow tip.
        ax1.text(end[0]*1.05, end[1]*1.05, f"θ:{theta_colnames[i]}",
                 fontsize=9, color=basis_axis_color, alpha=basis_axis_alpha)

# Prior arrows
for p in prior_xy:
    draw_arrow(ax1, (0,0), (p[0], p[1]), lw=arrow_lw_prior, color=prior_color,
               alpha=arrow_alpha_prior, head_w=arrow_head_width, head_len=arrow_head_length, z=1)
# Posterior arrows
for p in post_xy:
    draw_arrow(ax1, (0,0), (p[0], p[1]), lw=arrow_lw_post, color=heldout_condition_color,
               alpha=arrow_alpha_post, head_w=arrow_head_width, head_len=arrow_head_length, z=2)
# Ground-truth arrow
draw_arrow(ax1, (0,0), (gt_xy[0,0], gt_xy[0,1]), lw=arrow_lw_gt, color=heldout_condition_color,
           alpha=arrow_alpha_gt, head_w=arrow_head_width, head_len=arrow_head_length, z=5)

# ax1.axhline(0, color="0.85", lw=1)
# ax1.axvline(0, color="0.85", lw=1)
ax1.set_xlim(xmin, xmax); ax1.set_ylim(ymin, ymax)
ax1.set_aspect("equal", adjustable="box")
ax1.set_title(f"Held-out θ projection ({proj_normalization} space) • plane={PROJ_MODE}")
ax1.set_xlabel("proj u"); ax1.set_ylabel("proj v")
ax1.set_box_aspect(1)
# plt.tight_layout()
# Saved twice: as SVG (FIG1_PATH) and as a PNG with the same base name.
plt.savefig(FIG1_PATH, dpi=180)
plt.savefig(FIG1_PATH.replace(".svg",".png"), dpi=300)
plt.show()

# ----- FIGURE 2: 2D scatter + fixed mass-contours (sorted levels) -----
def axis_limits_with_padding_xy(px_prior, py_prior, px_post, py_post, gx, gy, pad_frac=0.08):
    """Padded axis limits that enclose the prior points, posterior points and ground truth.

    Stacks all points into one (n, 2) array and defers to
    ``axis_limits_with_padding``; returns ``(xmin, xmax, ymin, ymax)``.
    """
    prior_pts = np.column_stack([px_prior, py_prior])
    post_pts = np.column_stack([px_post, py_post])
    gt_pt = np.array([[gx, gy]])
    stacked = np.vstack([prior_pts, post_pts, gt_pt])
    return axis_limits_with_padding(stacked, pad_frac=pad_frac)

# Limits cover prior, posterior and ground truth so nothing is clipped.
xmin,xmax,ymin,ymax = axis_limits_with_padding_xy(px_prior, py_prior, px_post, py_post, gx, gy, pad_frac=scatter_axis_pad_frac)

fig2, ax2 = plt.subplots(figsize=(7.2, 6.2))
ax2.scatter(px_prior, py_prior, s=dot_size_prior, alpha=dot_alpha_prior,
            color=prior_color, edgecolors='none', label="prior")
ax2.scatter(px_post, py_post, s=dot_size_post, alpha=dot_alpha_post,
            color=heldout_condition_color, edgecolors='none', label="posterior")
ax2.scatter([gx], [gy], s=cross_size, marker='X', color=heldout_condition_color,
            edgecolor='k', linewidth=cross_edge_lw, zorder=8, label="ground truth")

# Posterior mass-contours (30/60/90%), with increasing levels (fix)
# The KDE is evaluated on the plotted window, so mass is accumulated over that
# window only.
XX, YY, dens = kde2d_grid(px_post, py_post, (xmin, xmax), (ymin, ymax), grid=kde_grid, bw_scale=kde_bw_scale)
thr = mass_contour_levels(dens, kde_levels_mass, XX=XX, YY=YY)  # sorted & unique
# Skip the contours when no finite threshold could be computed.
if len(thr) >= 1 and np.isfinite(thr).all():
    ax2.contour(XX, YY, dens, levels=thr, colors=kde_color, linewidths=kde_lw)

ax2.set_xlim(xmin, xmax); ax2.set_ylim(ymin, ymax)
ax2.set_xlabel(param_x + f" ({scatter_normalization})")
ax2.set_ylabel(param_y + f" ({scatter_normalization})")
ax2.set_aspect("equal", adjustable="box")
ax2.legend(frameon=False, loc="best")
ax2.set_title("Posterior vs prior (2D slice) with posterior mass-contours")
ax2.set_box_aspect(1)
# plt.tight_layout()
# Saved twice: as SVG (FIG2_PATH) and as a PNG with the same base name.
plt.savefig(FIG2_PATH, dpi=180)
plt.savefig(FIG2_PATH.replace(".svg",".png"), dpi=300)
plt.show()

print("Saved:")
print("  ", FIG1_PATH)
print("  ", FIG2_PATH)


In [None]:
### FIGURE CHECK # 4
# Posterior density estimator calibration
# 4.1 = calculating calibration

# === Alpha-star step functions (HPD via log_prob), plus marginal alpha* per parameter ===
# Needs (for JOINT):
# - posterior_log_probs_for_held_out_sims: {sim_id: torch.Tensor [n_draws]}   (log p(theta_samples | x))
# - ground_truth_thetas_for_held_out_sims: {sim_id: torch.Tensor [1,D] or [D]}
# Optional (faster):
# - posterior_true_log_prob_for_held_out_sims: {sim_id: float or 1D tensor}
# Fallback if recompute needed:
# - posterior_net_held_out_data  (sbi posterior)
# - sim_obs  (indexable by sim_id)
#
# Needs (for MARGINAL α*):
# - posterior_samples_thetas_for_held_out_sims: {sim_id: torch.Tensor [n_draws, D]}
#   If missing, we will attempt to resample them (requires posterior_net_held_out_data, sim_obs).

def to_np(x):
    """Return ``x`` as a float NumPy array without calling ``Tensor.numpy()``.

    Tensors are detached, moved to CPU and turned into nested Python lists
    before conversion; other inputs are passed to ``np.asarray`` unchanged.
    """
    if isinstance(x, torch.Tensor):
        # Going through Python lists avoids the tensor → NumPy bridge.
        x = x.detach().cpu().tolist()
    return np.asarray(x, dtype=float)

def _to1d_cpu(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a detached, flattened (exactly 1-D) CPU tensor.

    ``reshape(-1)`` is used instead of ``squeeze()`` so the result is 1-D for
    every input: a ``[1, 1]`` or 0-d tensor becomes ``[1]`` (``squeeze`` would
    give a 0-d tensor without a ``shape[0]``), and an ``[n, m]`` tensor is
    flattened rather than returned as 2-D. Plain Python numbers and sequences
    are accepted as well and converted with ``torch.as_tensor``.
    """
    return torch.as_tensor(x).detach().reshape(-1).cpu()

def _to2d_cpu(x: torch.Tensor) -> torch.Tensor:
    """Detach ``x`` and move it to CPU, promoting a 1-D tensor to shape ``[1, n]``.

    Tensors with any other number of dimensions keep their shape.
    """
    detached = x.detach()
    promoted = detached.unsqueeze(0) if detached.ndim == 1 else detached
    return promoted.cpu()

def _alpha_star_joint_from_logps(logp_samples_1d: torch.Tensor, logp_true_scalar: float) -> float:
    """Joint alpha*: share of posterior draws whose log-density is >= that of the true θ.

    The draws stand in for posterior mass, so this estimates the mass of the
    highest-density region whose boundary passes through the true parameter.
    """
    sample_logps = to_np(_to1d_cpu(logp_samples_1d))  # conversion avoids .numpy()
    at_least_as_dense = sample_logps >= logp_true_scalar
    return float(np.mean(at_least_as_dense))

def _alpha_star_marginal_central(samples_1d: torch.Tensor, theta_true_scalar: float) -> float:
    """Marginal alpha* for one parameter, based on nested CENTRAL intervals.

    With ``u`` the empirical CDF of the posterior draws at the true value, the
    smallest central credible interval that still contains the true value has
    level ``alpha* = 1 - 2*min(u, 1-u) = 2*|u - 0.5|``.

    ``u`` is computed from the rank of the true value among the draws; draws
    exactly equal to it are split by a uniformly random tie-break drawn from
    NumPy's global random state.
    """
    draws = to_np(_to1d_cpu(samples_1d))  # conversion avoids .numpy()
    n_total = draws.shape[0]

    n_below = np.sum(draws < theta_true_scalar)
    n_tied = np.sum(draws == theta_true_scalar)

    rank = n_below
    if n_tied > 0:
        # break ties uniformly
        rank = n_below + np.random.randint(0, n_tied + 1)

    # The +0.5 / +1 offsets keep u strictly inside (0, 1).
    u = (rank + 0.5) / (n_total + 1.0)
    return float(2.0 * abs(u - 0.5))

# ---- Collect held-out ids & basic info
# NOTE: this rebinds the global heldout_ids to the raw dict keys, in dict
# order (not cast to int and not sorted).
heldout_ids = list(posterior_log_probs_for_held_out_sims.keys())
assert len(heldout_ids) > 0, "posterior_log_probs_for_held_out_sims is empty."

# Number of posterior draws, taken from the first held-out simulation; it is
# only used if samples have to be redrawn.
n_draws = int(_to1d_cpu(posterior_log_probs_for_held_out_sims[heldout_ids[0]]).shape[0])

# Param names (optional, for later plotting titles)
has_param_names = 'input_sim_parameters_to_infer' in globals()
if has_param_names:
    param_names = input_sim_parameters_to_infer
else:
    # Try to infer D from a ground-truth vector
    any_gt = ground_truth_thetas_for_held_out_sims[heldout_ids[0]]
    D = int(_to2d_cpu(any_gt).shape[1])
    param_names = [f"theta[{j}]" for j in range(D)]

# Determine dimensionality D
# D is always taken from the first ground-truth vector, regardless of how
# param_names was obtained.
first_gt = _to2d_cpu(ground_truth_thetas_for_held_out_sims[heldout_ids[0]])
D = int(first_gt.shape[1])

# Optional caches present?
has_true_lp_cache = 'posterior_true_log_prob_for_held_out_sims' in globals()
has_sample_cache  = 'posterior_samples_thetas_for_held_out_sims' in globals()

# Prepare outputs
alpha_stars_joint = []                     # shape [n_sims]
alpha_stars_marginal = np.zeros((D, len(heldout_ids)), dtype=float)  # [D, n_sims]

# Loop
# For every held-out simulation compute one joint alpha* and D marginal alpha*.
for col, sim_id in enumerate(heldout_ids):
    if col % 50 == 0:
        print(f"Processing held-out sim_id={sim_id} ({col+1}/{len(heldout_ids)})...")
    # --- JOINT alpha* from log_prob thresholds
    logp_samps = _to1d_cpu(posterior_log_probs_for_held_out_sims[sim_id])

    if has_true_lp_cache and sim_id in posterior_true_log_prob_for_held_out_sims:
        # Cached log p(theta_true | x): no network evaluation needed.
        lp_true = float(_to1d_cpu(posterior_true_log_prob_for_held_out_sims[sim_id]).item())
    else:
        # Not cached: evaluate the posterior density at the true parameters.
        assert 'posterior_net_held_out_data' in globals(), "Need posterior_net_held_out_data to compute log_prob(theta_true|x)."
        assert 'sim_obs' in globals(), "Need sim_obs[sim_id] to compute log_prob(theta_true|x)."
        th_true_vec = ground_truth_thetas_for_held_out_sims[sim_id]
        th_true_vec = th_true_vec if th_true_vec.ndim == 2 else th_true_vec.unsqueeze(0)  # [1, D]
        x_obs = sim_obs[sim_id].unsqueeze(0)
        lp_true = float(posterior_net_held_out_data.log_prob(th_true_vec, x=x_obs).detach().cpu().item())

    alpha_stars_joint.append(_alpha_star_joint_from_logps(logp_samps, lp_true))

    # --- MARGINAL alpha* (per parameter) using posterior samples (central intervals)
    if has_sample_cache and sim_id in posterior_samples_thetas_for_held_out_sims:
        samps = _to2d_cpu(posterior_samples_thetas_for_held_out_sims[sim_id])  # [N, D]
    else:
        # resample to match n_draws
        assert 'posterior_net_held_out_data' in globals() and 'sim_obs' in globals(), \
            "Need posterior samples or the ability to sample them (posterior_net_held_out_data + sim_obs)."
        x_obs = sim_obs[sim_id].unsqueeze(0)
        samps = posterior_net_held_out_data.sample((n_draws,), x_obs, show_progress_bars=False).detach().cpu()  # [N, D]

    th_true = to_np(
        _to2d_cpu(ground_truth_thetas_for_held_out_sims[sim_id]).squeeze(0))  # [D] as NumPy array

    # One marginal alpha* per parameter, stored column-wise per simulation.
    for j in range(D):
        alpha_stars_marginal[j, col] = _alpha_star_marginal_central(samps[:, j], th_true[j])

alpha_stars_joint = np.asarray(alpha_stars_joint)  # [n_sims]
# At this point you have:
# - alpha_stars_joint: shape [n_sims]
# - alpha_stars_marginal: shape [D, n_sims]
# - param_names: list of D labels


In [None]:
### FIGURE CHECK # 4
# Posterior density estimator calibration
# 4.2 = displaying calibration curves

from math import ceil

# ----- JOINT plot (same style as before) -----
# Evaluate the calibration curve on a regular grid of nominal coverage levels.
alpha_grid = np.linspace(0.0, 1.0, 1001)
# One step function per held-out sim: 0 while alpha < alpha*, 1 once alpha >= alpha*.
steps_joint = (alpha_grid[None, :] >= alpha_stars_joint[:, None]).astype(float)  # [n_sims, 1001]
# Averaging the steps over sims gives, for each alpha, the fraction of sims with alpha* <= alpha.
mean_joint = steps_joint.mean(axis=0)
# Population SD of the steps at each alpha (computed here; not drawn in this cell).
std_joint  = steps_joint.std(axis=0, ddof=0)

fig, ax = plt.subplots(figsize=(7.0, 5.2))
ax.plot(alpha_grid, mean_joint, lw=2.0, color="C0", label="Mean step curve")
# Diagonal reference: the curve the mean step function is compared against.
ax.plot(alpha_grid, alpha_grid, ls='--', lw=1.2, color="black", label="Ideal")

ax.set_xlim(0, 1); ax.set_ylim(0, 1)
ax.set_xlabel("Nominal coverage (α)")
ax.set_ylabel("Proportion covered (mean of steps)")
ax.set_title(f"JOINT calibration via α* step functions | sims={len(alpha_stars_joint)}")

plt.tight_layout()
# NOTE(review): the path is built with a literal backslash separator (Windows-style).
figname_to_save = f"{path_to_save}\\CALIBRATION_test_full_posterior"
plt.savefig(f"{figname_to_save}.png", dpi=300)
plt.savefig(f"{figname_to_save}.svg")
plt.show()

# Optional: ECE and histogram of alpha* (joint)
# ECE here = mean absolute gap between the mean step curve and the diagonal over alpha_grid.
ece_joint = float(np.mean(np.abs(mean_joint - alpha_grid)))
print(f"[JOINT] Expected Calibration Error (ECE): {ece_joint:.4f}")

# Histogram of the per-sim alpha* values (shown only, not saved to disk).
plt.figure(figsize=(6.0, 3.5))
plt.hist(alpha_stars_joint, bins=30, density=True, edgecolor='black', linewidth=0.5)
plt.axhline(1.0, ls='--', lw=1.2)  # Uniform(0,1) density
plt.xlim(0, 1)
plt.xlabel("α (joint)")
plt.ylabel("Density")
plt.title("Distribution of α* (joint) across held-out sims")
plt.tight_layout()
plt.show()

# ----- MARGINAL plots (one small panel per parameter) -----
# alpha_stars_marginal is indexed [parameter, sim]; lay the panels out on a grid
# with at most 4 columns.
D = alpha_stars_marginal.shape[0]
cols = min(4, D)
rows = ceil(D / cols)

# squeeze=False keeps `axes` two-dimensional even for a single row/column.
fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3.6*rows), squeeze=False)

for j in range(D):
    r = j // cols
    c = j % cols
    ax = axes[r, c]

    # Same step-function construction as for the joint curve, for parameter j only.
    steps_j = (alpha_grid[None, :] >= alpha_stars_marginal[j, :, None]).astype(float)  # [n_sims, 1001]
    mean_j  = steps_j.mean(axis=0)
    std_j   = steps_j.std(axis=0, ddof=0)

    # stack of step functions (transparent)
    # for i in range(steps_j.shape[0]):
    #     ax.plot(alpha_grid, steps_j[i], lw=0.6, color="C1", alpha=0.05)

    # Mean ± 1 SD band clipped to [0, 1]; only used by the disabled fill_between below.
    lo = np.clip(mean_j - std_j, 0, 1)
    hi = np.clip(mean_j + std_j, 0, 1)
    # ax.fill_between(alpha_grid, lo, hi, alpha=0.18, color="C1", label="Mean ± 1 SD")
    ax.plot(alpha_grid, mean_j, lw=1.8, color="C1", label="Mean step curve")
    ax.plot(alpha_grid, alpha_grid, ls='--', lw=1.0, color="black", label="Ideal")

    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    ax.set_xlabel("α")
    ax.set_ylabel("Proportion covered")
    ax.set_title(f"{param_names[j]} — marginal α* (central)")

    # if j == 0:
    #     ax.legend(loc="lower right", fontsize=8)

# Hide any unused axes
for j in range(D, rows*cols):
    r = j // cols; c = j % cols
    axes[r, c].axis('off')

plt.tight_layout()
# NOTE(review): the path is built with a literal backslash separator (Windows-style).
figname_to_save = f"{path_to_save}\\CALIBRATION_test_per_param"
plt.savefig(f"{figname_to_save}.png", dpi=300)
plt.savefig(f"{figname_to_save}.svg")
plt.show()

# Optional: per-parameter summaries
# The mean step curve is rebuilt per parameter and compared with the diagonal
# (mean absolute gap over alpha_grid), as for the joint ECE.
for j in range(D):
    mean_curve_j = (alpha_grid[None, :] >= alpha_stars_marginal[j, :, None]).mean(axis=0)
    ece_j = float(np.mean(np.abs(mean_curve_j - alpha_grid)))
    print(f"[MARGINAL] {param_names[j]} — ECE: {ece_j:.4f}")

# Optional: histograms of marginal alpha*
# Same grid layout as above; this figure is shown but not saved.
fig, axes = plt.subplots(rows, cols, figsize=(4.5*cols, 3.2*rows), squeeze=False)
for j in range(D):
    r = j // cols; c = j % cols
    ax = axes[r, c]
    ax.hist(alpha_stars_marginal[j], bins=30, density=True, edgecolor='black', linewidth=0.5)
    ax.axhline(1.0, ls='--', lw=1.0)  # Uniform(0,1) density reference
    ax.set_xlim(0, 1)
    ax.set_title(f"{param_names[j]} — dist. of α* (marginal)")
    ax.set_xlabel("α"); ax.set_ylabel("Density")
# Hide unused axes
for j in range(D, rows*cols):
    r = j // cols; c = j % cols
    axes[r, c].axis('off')

plt.tight_layout()
plt.show()


In [None]:
### FIGURE CHECK # 5
# Posterior density estimator "resolution" = what is the average difference in a parameter (ground-truth) so that the estimated posteriors are different with probability > 0.9?
# 5.1 = Resolution calculations

# Inputs expected:
# - posterior_samples_thetas_for_held_out_sims: {sim_id: torch.Tensor [n_draws, D]}
# - ground_truth_thetas_for_held_out_sims:     {sim_id: torch.Tensor [1,D] or [D]}
# - input_sim_parameters_to_infer:             list[str]

# Config
NUM_BINS   = 256     # histogram bins for PDF/CDF
BIN_MARGIN = 0.01    # extend min/max by this fraction to avoid edge clipping

# Stable ordering of sims
# Sorting the dict keys fixes the row order used by every [N, ...] array built below.
heldout_ids = sorted(posterior_samples_thetas_for_held_out_sims.keys())
N = len(heldout_ids)
# Pairwise comparisons need at least one pair of sims.
assert N > 1, "Need at least two held-out sims."

# Infer D and param names
# D is read from the second axis of the first sim's sample tensor.
first = posterior_samples_thetas_for_held_out_sims[heldout_ids[0]]
D = int(first.shape[1])
# Fall back to generic "theta[j]" labels when the parameter-name list is not defined in this session.
param_names = input_sim_parameters_to_infer if 'input_sim_parameters_to_infer' in globals() else [f"theta[{j}]" for j in range(D)]

def to_np(x):
    """Return ``x`` as a float NumPy array without calling ``tensor.numpy()``.

    Torch tensors are detached, moved to CPU and round-tripped through plain
    Python lists, which sidesteps NumPy/PyTorch bridge incompatibilities.
    Any other array-like is handed to ``np.asarray`` directly.
    """
    payload = x.detach().cpu().tolist() if isinstance(x, torch.Tensor) else x
    return np.asarray(payload, dtype=float)

# Stack ground-truth into [N, D]
gt = np.zeros((N, D), dtype=np.float64)
for i, sid in enumerate(heldout_ids):
    th = ground_truth_thetas_for_held_out_sims[sid]   # torch tensor [1,D] or [D]
    th_np = to_np(th)                                 # -> NumPy array
    if th_np.ndim == 2:
        th_np = th_np[0]                              # drop the leading singleton axis
    gt[i, :] = th_np

flattened_by_param = []  # for plotting cell; list of dicts {name, delta, pdelta}

for j in tqdm(range(D), desc="Parameters"):
    pname = param_names[j]

    # ---- 1) Global range for this parameter across all sims' posterior samples
    # A single shared grid is required so that every sim's histogram uses the same bins.
    global_min = +np.inf
    global_max = -np.inf
    for sid in heldout_ids:
        x = posterior_samples_thetas_for_held_out_sims[sid][:, j]
        x = to_np(x)   # -> NumPy array
        if x.size == 0: continue
        xmin, xmax = x.min(), x.max()
        if xmin < global_min: global_min = xmin
        if xmax > global_max: global_max = xmax
    xrng = global_max - global_min
    if not np.isfinite(xrng) or xrng <= 0:
        # degenerate safety: widen a zero-width range so the bin edges stay distinct
        global_min -= 0.5; global_max += 0.5; xrng = global_max - global_min
    edges = np.linspace(global_min - BIN_MARGIN*xrng, global_max + BIN_MARGIN*xrng, NUM_BINS + 1)

    # ---- 2) Build PDFs and CDFs on the common grid (with a progress bar)
    # pdf[i, m] = fraction of sim i's samples falling in bin m (rows sum to 1).
    pdf = np.zeros((N, NUM_BINS), dtype=np.float32)
    for i, sid in enumerate(tqdm(heldout_ids, desc=f"[{pname}] histograms", leave=False)):
        x = posterior_samples_thetas_for_held_out_sims[sid][:, j]
        x = to_np(x)   # -> NumPy array
        counts, _ = np.histogram(x, bins=edges)
        pdf[i, :] = counts.astype(np.float32) / max(1, x.shape[0])
    cdf = np.cumsum(pdf, axis=1, dtype=np.float32)  # [N, M], inclusive of bin m

    # ---- 3) All-pairs p_delta via matrix multiply: p = CDF_mid_A @ PDF_B^T
    # pdelta[a, b] estimates Pr(theta_B > theta_A).
    # The inclusive CDF would count every "same bin" tie entirely in favour of B,
    # giving pdelta[a, b] + pdelta[b, a] = 1 + sum_m pdf_a[m] * pdf_b[m] > 1 and a
    # value above 0.5 for two identical posteriors. Subtracting half of the bin's
    # own mass (mid-bin CDF) weights ties by 0.5, which removes that upward bias
    # and makes the estimate antisymmetric around 0.5.
    cdf_mid = cdf - 0.5 * pdf
    pdelta = (cdf_mid @ pdf.T).astype(np.float32)  # [N, N]

    # ---- 4) Ground-truth Δθ (B − A)
    gt_j = gt[:, j]
    dtheta = gt_j[None, :] - gt_j[:, None]      # [N, N], entry [a, b] = gt[b] - gt[a]

    # ---- 5) Extract upper-triangle unique pairs and keep only Δθ > 0 (no pΔ filter)
    iu, ju = np.triu_indices(N, k=1)
    x = dtheta[iu, ju].astype(np.float64)
    y = pdelta[iu, ju].astype(np.float64)

    keep = x > 0.0
    x = x[keep]
    y = y[keep]

    flattened_by_param.append({'name': pname, 'delta': x, 'pdelta': y})
    print(f"[{pname}] kept {x.size:,} pairs out of {N*(N-1)//2:,} unique pairs; filter: Δθ>0 only.")


In [None]:
### FIGURE CHECK # 5
# Posterior density estimator "resolution" = what is the average difference in a parameter (ground-truth) so that the estimated posteriors are different with probability > 0.9?
# 5.2 = Resolution figures display

from math import ceil

# Config
p_threshold = 0.90   # your decision threshold for pΔ crossing
# Width of the moving-average window, as a fraction of the Δθ range
# (passed as `width_fraction` to moving_average_by_width in the plotting loop).
moving_average_fraction = 0.03

def moving_average_by_width(x, y, width_fraction=0.10, step_fraction=0.02, min_pts=50):
    """
    Sliding-window mean of ``y`` over ``x`` with a window defined in x units.

    Parameters
    ----------
    x, y : array-like
        Paired coordinates (same length). They do not need to be sorted.
    width_fraction : float
        Window width as a fraction of the x range, i.e.
        ``width = width_fraction * (x_max - x_min)``.
    step_fraction : float
        Spacing between consecutive window centers as a fraction of the
        window *width* (not of the x range).
    min_pts : int
        Minimum number of points a window must contain; windows with fewer
        points get NaN.

    Returns
    -------
    centers, means : np.ndarray
        Window centers and the mean of ``y`` inside each window (bounds
        inclusive). Empty input yields two empty arrays; a zero-width x range
        yields a single center at ``x_min`` with ``nanmean(y)``.
    """
    x = np.asarray(x); y = np.asarray(y)
    if x.size == 0:
        # np.min/np.max raise on empty arrays; report "no curve" instead so
        # callers (e.g. a parameter with no retained pairs) keep working.
        return np.array([], dtype=float), np.array([], dtype=float)
    xmin, xmax = np.min(x), np.max(x)
    x_rng = xmax - xmin
    if x_rng <= 0:
        return np.array([xmin]), np.array([np.nanmean(y)])
    width = width_fraction * x_rng
    step  = max(width * step_fraction, 1e-12)  # floor avoids a zero step in np.arange
    # Centers are placed so that every window lies fully inside [xmin, xmax].
    centers = np.arange(xmin + 0.5*width, xmax - 0.5*width + step, step)
    means = np.full(centers.size, np.nan, dtype=float)
    for i, c in enumerate(centers):
        lo, hi = c - 0.5*width, c + 0.5*width
        m = (x >= lo) & (x <= hi)
        if m.sum() >= min_pts:
            means[i] = y[m].mean()
    return centers, means

def first_crossing_x(cx, my, y0):
    """
    Locate the leftmost x at which the curve ``my(cx)`` reaches ``y0`` (>=).

    Non-finite curve values are dropped first, so the linear interpolation is
    done between the two nearest finite points around the crossing.
    Returns a float, or None when the curve has fewer than two points, has no
    finite value, or never reaches ``y0``.
    """
    if cx.size < 2:
        return None
    finite = np.isfinite(my)
    if not finite.any():
        return None

    xs = cx[finite]
    ys = my[finite]

    reached = np.flatnonzero(ys >= y0)
    if reached.size == 0:
        return None

    hit = int(reached[0])  # first finite point at or above the threshold
    if hit == 0:
        # The curve already starts at/above y0: nothing to interpolate from.
        return float(xs[0])

    x_left, y_left = xs[hit - 1], ys[hit - 1]
    x_right, y_right = xs[hit], ys[hit]
    if y_right == y_left:
        return float(x_right)

    # Fraction of the way from the left point to the right point, kept in [0, 1].
    frac = np.clip((y0 - y_left) / (y_right - y_left), 0.0, 1.0)
    return float(x_left + frac * (x_right - x_left))

# Plot settings
max_points_to_scatter = 20_000   # cap on scattered points per panel (plot responsiveness only)
alpha_scatter = 0.05
s_scatter = 5.0

# One panel per parameter, at most 3 columns.
D = len(flattened_by_param)
cols = min(3, D)
rows = ceil(D / cols)

fig, axes = plt.subplots(rows, cols, figsize=(5.6*cols, 4.8*rows), squeeze=False)

# (name, crossing) per parameter, filled while plotting so that the printed
# summary below reports exactly the crossing drawn in the figure.
crossings_by_param = []

for k, item in enumerate(flattened_by_param):
    r = k // cols
    c = k % cols
    ax = axes[r, c]
    x = item['delta']
    y = item['pdelta']

    # Optional downsample for plot responsiveness
    # (affects the scatter only; the moving average below uses all pairs)
    if x.size > max_points_to_scatter:
        rng = np.random.default_rng(0)
        idx = rng.choice(x.size, size=max_points_to_scatter, replace=False)
        xs, ys = x[idx], y[idx]
    else:
        xs, ys = x, y

    # Scatter
    ax.scatter(xs, ys, s=s_scatter, color='C0', alpha=alpha_scatter)

    # Moving average (window width = moving_average_fraction of the Δθ range)
    cx, my = moving_average_by_width(x, y, width_fraction=moving_average_fraction, step_fraction=moving_average_fraction/3, min_pts=50)
    ax.plot(cx, my, lw=2.0, color='blue', label=f'Moving avg ({moving_average_fraction*100}% width)')

    # Reference lines: chance level (0.5) and the decision threshold
    ax.axhline(0.5, ls='--', lw=2.0, color='black', alpha=0.6)
    ax.axhline(p_threshold, ls='--', lw=2.0, color='red', alpha=0.8)

    # Crossing Δ where moving-average hits p_threshold
    x_cross = first_crossing_x(cx, my, p_threshold)
    crossings_by_param.append((item['name'], x_cross))
    if x_cross is not None:
        ax.axvline(x_cross, lw=2.0, color='red', alpha=0.3)
        ax.text(x_cross+0.05, p_threshold-0.05 if p_threshold < 0.95 else (p_threshold-0.04),
                f"Δ≈{x_cross:.3g}", va='top', ha='left', fontsize=10, weight='bold', color='black') # rotation=90, va='top', ha='right', fontsize=9)

    ax.set_xlabel(r"$\Delta \theta$ (ground truth)")
    ax.set_ylabel(r"$p_\Delta = \Pr(\theta_B > \theta_A)$")
    ax.set_title(item['name'])
    ax.set_ylim(-0.03, 1.03)
    ax.legend(loc='lower right', fontsize=8)

# Hide any empty axes
for k in range(D, rows*cols):
    r = k // cols; c = k % cols
    axes[r, c].axis('off')

plt.tight_layout()
figname_to_save = f"{path_to_save}\\SENSITIVITY_test"
plt.savefig(f"{figname_to_save}.png", dpi=300)
plt.savefig(f"{figname_to_save}.svg")
plt.show()

# Also print a clean numeric summary of the crossing per parameter.
# The values come from the same moving-average curve as the figure; previously
# they were recomputed with a hard-coded 10% window, so the printed Δθ could
# disagree with the Δ annotated on the plot.
for name, xc in crossings_by_param:
    if xc is None:
        print(f"{name}: moving average never reaches pΔ ≥ {p_threshold:.2f}.")
    else:
        print(f"{name}: detectable Δθ at pΔ ≥ {p_threshold:.2f} ≈ {xc:.4g}")


# SBI ON EXPERIMENTAL DATA
### If you re-run this, even with the same training data (downloaded from the repository), the posteriors may differ from those reported in the paper (estimators vary with the random initial seed)

In [None]:
# Re-train a network, this time on 100% of the training data

# ── Build Torch tensors & prior for SBI ──
# Collect the feature-summary column names, in the order of features_for_inference:
#   * a feature that is itself a simulation parameter maps to one column named after it;
#   * any other feature expands to one column per summary statistic, "<feature>_<stat>".
summary_colnames = []
for feat in features_for_inference:
    if feat in input_sim_parameters_as_features:
        summary_colnames.append(f"{feat}")
    else:
        summary_colnames.extend(f"{feat}_{summary_stat}" for summary_stat in summary_funcs.keys())

theta_colnames = input_sim_parameters_to_infer

# Python lists → robust to NumPy bridge issues
x_sim = torch.tensor(
    df_simulation_summary[summary_colnames].to_numpy(dtype=float).tolist(),
    dtype=torch.float32,
)
theta_raw = torch.tensor(
    df_simulation_summary[theta_colnames].to_numpy(dtype=float).tolist(),
    dtype=torch.float32,
)

# ── Normalize the training θ's (computed once; the cell used to do this twice) ──
theta_unit = theta_to_unit(theta_raw)
x_o = torch.randn((1, len(summary_colnames)))  # default observation

# ── Build low/high tensors from the per-parameter prior bounds ──
# (the empty-list placeholders that used to precede these were dead code:
#  each name was overwritten before being read)
low_original  = torch.tensor(
    [priors_per_parameters_to_infer[name][0] for name in theta_colnames],
    dtype=torch.float32,
)
high_original = torch.tensor(
    [priors_per_parameters_to_infer[name][1] for name in theta_colnames],
    dtype=torch.float32,
)

# ── Define a uniform prior on [0,1]^d ──
# The transformed bounds are only used for their shape/dtype below.
low_unit  = theta_to_unit(low_original)   # should be ~0
high_unit = theta_to_unit(high_original)  # should be ~1

prior = sbi_utils.BoxUniform(
    low=torch.zeros_like(low_unit),
    high=torch.ones_like(high_unit),
)

# At this point:
#   • `x_sim`   : (n_sims, n_features)
#   • `theta_unit`: (n_sims, n_theta)
#   • `prior`   : BoxUniform over [0,1]^d

In [None]:
# Train the network on the entire simulated data summary data frame ###############
if rerun_network_training_and_sampling:
    N = df_simulation_summary.shape[0]

    # Train SNPE on all N simulations (no held-out subset this time)
    inference = sbi_inference.SNPE(prior=prior, density_estimator=sbi_density_estimator)
    inference.append_simulations(theta_unit, x_sim)
    inference.train(
        num_atoms=network_training_hyperparameters['num_atoms'],
        force_first_round_loss=True,    # start fresh
        training_batch_size=network_training_hyperparameters['training_batch_size'],
        learning_rate=network_training_hyperparameters['learning_rate'],
        validation_fraction=network_training_hyperparameters['validation_fraction'],
        max_num_epochs=network_training_hyperparameters['max_num_epochs'],  # upper bound on epochs
        # patience: epochs without validation improvement before training stops
        stop_after_epochs=network_training_hyperparameters['stop_after_epochs'],
        show_train_summary=True,
    )
    signature = inspect.signature(inference.train)
    print(f"Inference method signature = {signature}")

    # Plot the training results
    # (was `plt.plot(figsize=(8,6))`, which draws nothing and ignores figsize;
    #  a new figure is what sets the size)
    plt.figure(figsize=(8, 6))
    plt.plot(inference.summary['training_loss'],
             label="Training loss", linewidth=3, color='blue', alpha=0.5)
    plt.plot(inference.summary['validation_loss'],
             label="Validation loss", linewidth=3, color='red', alpha=0.5)
    plt.xlabel("epochs_trained")
    plt.ylabel("Loss - negative log-probability\n(neg-log-prob of observation,\ngiven inferred generative parameters)")
    plt.title("Training results on SBI neural net\nwith 100% of the simulated observations")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{path_to_save}\\network_training_loss_no_held_out_sims_for_training.png")
    plt.show()

In [None]:
# SAVE OR LOAD POSTERIOR
# When training was re-run, build the posterior from the freshly trained `inference`
# object and pickle it under path_to_save; otherwise unpickle a previously saved
# posterior from path_to_load.
if rerun_network_training_and_sampling:
    posterior_network_pickle_path = f"{path_to_save}\\posterior_infering_network.pkl"
    # ── build & save the posterior object ──────────────────────────────────────
    posterior = inference.build_posterior()
    # a sampling method may be necessary if sampling_algorithm == 'mcmc', 'vi' or 'si'
    # vi_method = "rKL"  # or fKL
    # posterior = inference.build_posterior(sample_with=sampling_algorithm,
    #                                      vi_method=vi_method)
    # Save posterior
    with open(posterior_network_pickle_path, "wb") as f:
        pickle.dump(posterior, f)
    print(f"✅ Posterior-infering network saved to '{posterior_network_pickle_path}'")
else:
    # Load the existing network
    # NOTE(review): pickle.load executes arbitrary code from the file; only load
    # posterior files that come from a trusted source.
    posterior_network_pickle_path = f"{path_to_load}\\posterior_infering_network.pkl"
    with open(posterior_network_pickle_path, 'rb') as f:
        posterior = pickle.load(f)
    print(f"✅ Posterior-infering network loaded from '{posterior_network_pickle_path}'")

## POSTERIOR ESTIMATE PER PARTICIPANT
### (not reported in paper)
### One posterior distribution per [Muscle_pair × Intensity × Participant]

In [None]:
# If single-muscle posteriors are used as features (len(input_sim_parameters_as_features) >= 1),
# those feature columns should, at this point, hold zeros as placeholder values.
# This cell augments df_experiment_summary by duplicating each row N times and filling the
# duplicated rows with N samples of the previously estimated posterior parameters,
# so that the neural network infers N posteriors per condition.
# N = previously_estimated_posterior_samples_for_experimental_data_SBI (already defined)
# Columns used to match an experimental row with its pool of previous posterior samples.
match_by = ['subject', 'muscle', 'intensity']
# NOTE(review): the mapping below is commented out here but is used in the rename() call
# further down, so it has to be defined elsewhere before this cell runs — confirm.
# map_strings_of_posterior_estimated_parameters_to_param_used_as_features = {
#     "disynpatic_inhib_connections_desired_MN_MN": "disynpatic_inhib_connections_desired_MN_MN_self",
#     "common_input_high_freq_middle_of_range": "common_input_high_freq_middle_of_range_self",
#     "common_input_high_freq_half_width_range": "common_input_high_freq_half_width_range_self",
#     "common_input_std": "common_input_std_self"
# }
df_experiment_summary_augmented = df_experiment_summary.copy()
if len(input_sim_parameters_as_features) >= 1:
    # 'muscle' = the part of 'muscle_pair' before the "<->" separator.
    df_experiment_summary_augmented['muscle'] = df_experiment_summary['muscle_pair'].str.split("<->").str[0]
    # Load csv with previously estimated posterior
    # (directory and file name are concatenated without a separator)
    df_previous_posterior_each_subject = pd.read_csv(f"{previously_estimated_posterior_results_path}{previously_estimated_posterior_each_subject_csv}")
    df_previous_posterior_each_subject['muscle'] = df_previous_posterior_each_subject['muscle_pair'].str.split("<->").str[0]
    # Rename the posterior-parameter columns to the names they carry as input features.
    df_previous_posterior_each_subject.rename(columns=map_strings_of_posterior_estimated_parameters_to_param_used_as_features, inplace=True)
    # Apply standardization (determined earlier, ctrl+F "apply_standardization()") to the posterior samples of features to be used as features
    # NOTE(review): the return value is ignored, so this presumably modifies the frame in place — confirm.
    apply_standardization(df_previous_posterior_each_subject, norm_stats, input_sim_parameters_as_features)

    # Build a lookup dict of posterior pools per (subject, muscle, intensity)
    posterior_pool = {}
    for gkey, gdf in df_previous_posterior_each_subject.groupby(match_by, dropna=False):
        posterior_pool[gkey] = gdf.reset_index(drop=True)

    # Duplicate each experimental row N times and fill with matched posterior draws
    aug_rows = []
    rng = np.random.default_rng()  # set a seed here if you want reproducibility, e.g., np.random.default_rng(123)

    for idx, row in df_experiment_summary_augmented.iterrows():
        # Key of this row in the same (subject, muscle, intensity) order as the pools.
        gkey = tuple(row[k] for k in match_by)
        pool = posterior_pool.get(gkey, None)

        # Make N copies of the experimental row
        block = pd.DataFrame([row.values] * previously_estimated_posterior_samples_for_experimental_data_SBI,
                             columns=df_experiment_summary_augmented.columns)
        block['posterior_draw'] = np.arange(previously_estimated_posterior_samples_for_experimental_data_SBI,
                                            dtype=int)  # useful to track Monte Carlo runs

        if (pool is None) or pool.empty:
            # No matching posterior samples: keep the existing placeholder values
            # (or NaN for columns that do not exist yet). No warning is emitted
            # unless the print below is re-enabled.
            for c in input_sim_parameters_as_features:
                # keep existing placeholders in df_experiment_summary_augmented or explicitly set NaN
                if c not in block.columns:
                    block[c] = np.nan
            # (optional) print/log a warning
            # print(f"[WARN] No posterior samples for key={gkey}; leaving {input_sim_parameters_as_features} as-is.")
        else:
            # Sample N posterior rows with replacement for these features
            # (each call gets its own integer seed drawn from `rng`)
            sampled = pool.sample(n=previously_estimated_posterior_samples_for_experimental_data_SBI,
                                  replace=True, random_state=rng.integers(0, 1_000_000))[input_sim_parameters_as_features].reset_index(drop=True)

            # Ensure destination columns exist, then assign
            for c in input_sim_parameters_as_features:
                if c not in block.columns:
                    block[c] = np.nan
            block.loc[:, input_sim_parameters_as_features] = sampled.values

        aug_rows.append(block)

    # Stack all per-row blocks into one frame with a fresh 0..n-1 index.
    df_experiment_summary_augmented = pd.concat(aug_rows, ignore_index=True)

In [None]:
df_experiment_summary_augmented # Check result of previous cell

In [None]:
# --- Choose which dataframe to use ---
# The augmented frame (rows duplicated with posterior draws) is used only when
# posterior estimates serve as input features.
df_for_inference = (
    df_experiment_summary_augmented
    if len(input_sim_parameters_as_features) >= 1
    else df_experiment_summary
)

# --- Config ---
keys_cols    = ['subject', 'muscle_pair', 'intensity']   # condition identity
feature_cols = summary_colnames                          # observed + posterior-as-features cols
# Number of posterior samples per observation row, looked up by use case.
num_samples  = (num_posterior_samples['experiment_with_posterior_estimates_as_features']
                if len(input_sim_parameters_as_features) >= 1
                else num_posterior_samples['experiment'])
best_method  = best_posterior_estimate_method            # "logp" or "knn"
# NOTE(review): the per-participant run in this cell passes literal values
# (num_samples=100, best_method="knn") instead of the two settings above.
device       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_np(x):
    """Return ``x`` as a float NumPy array without calling ``tensor.numpy()``.

    Torch tensors are detached, moved to CPU and converted through plain
    Python lists, which avoids NumPy/PyTorch incompatibilities. Any other
    array-like goes straight to ``np.asarray``.
    """
    if not isinstance(x, torch.Tensor):
        return np.asarray(x, dtype=float)
    as_lists = x.detach().cpu().tolist()
    return np.asarray(as_lists, dtype=float)

def _densest_by_knn(samples_np, k=5):
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(samples_np)
    dist, _ = nbrs.kneighbors(samples_np)
    kth = dist[:, k]   # distance to k-th non-self neighbor
    return int(np.argmin(kth))

def infer_posteriors_by_condition_looped(
    df_obs,
    posterior,
    feature_cols,
    keys_cols,
    num_samples=1000,
    best_method="logp",          # "logp" or "knn"
    show_progress=True,
    transform_unit_to_theta=True,
):
    """
    Sample a posterior for every condition (group of rows) in ``df_obs``.

    Rows are grouped by ``keys_cols``. For each row of a group, ``num_samples``
    draws are taken from ``posterior`` conditioned on that row's features; the
    draws of all rows in the group are concatenated, giving a mixture across
    the (possibly duplicated) rows of that condition.

    Parameters
    ----------
    df_obs : pandas.DataFrame
        Observations; must contain ``keys_cols`` and ``feature_cols``.
    posterior : object
        Must provide ``sample(shape, x=..., show_progress_bars=...)`` and, for
        ``best_method="logp"``, ``log_prob(samples, x=...)``.
    feature_cols : list[str]
        Columns forming the observation vector x, in network input order.
    keys_cols : list[str]
        Columns identifying a condition.
    num_samples : int
        Draws per row.
    best_method : {"logp", "knn"}
        How the single "best" sample per condition is picked: highest
        log-probability, or densest point by k-NN distance.
    show_progress : bool
        Show tqdm progress bars and a per-condition message.
    transform_unit_to_theta : bool
        Pass the draws through ``unit_to_theta`` before storing them.

    Returns
    -------
    posterior_samples_dict : dict
        group key -> array of shape (num_samples * n_rows, n_theta).
    posterior_logp_dict : dict
        group key -> log-probabilities of those samples ("logp"), or an empty
        array ("knn").
    best_samples_dict : dict
        group key -> the selected best sample.

    Raises
    ------
    ValueError
        If ``best_method`` is not "logp" or "knn". This is checked before any
        sampling so a typo does not cost a full (potentially very slow)
        sampling pass.
    """
    if best_method not in ("logp", "knn"):
        raise ValueError(f"Unknown best_method='{best_method}'")
    use_logp = best_method == "logp"

    posterior_samples_dict = {}
    posterior_logp_dict    = {}
    best_samples_dict      = {}

    # NOTE(review): x is placed on CUDA whenever it is available; this assumes the
    # posterior accepts inputs on that device — confirm for CPU-trained networks.
    device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    grouped = df_obs.groupby(keys_cols, dropna=False)

    outer_iter = grouped
    total_groups = grouped.ngroups if hasattr(grouped, "ngroups") else None
    if show_progress:
        outer_iter = tqdm(grouped, total=total_groups, desc="Conditions", leave=True)

    for gkey, gdf in outer_iter:
        # Print condition + number of duplicate rows
        if show_progress:
            outer_iter.write(f"Condition {gkey}: {len(gdf)} rows, sampling {num_samples} / row")

        all_samps = []
        all_logp  = []   # only filled when best_method == "logp"

        inner_iter = gdf.iterrows()
        if show_progress:
            inner_iter = tqdm(inner_iter, total=len(gdf), desc=f"{gkey}", leave=False)

        for _, row in inner_iter:
            # x row as Python list → torch tensor (no from_numpy)
            x_vec = row[feature_cols].to_numpy(dtype=float).tolist()  # length Dx
            x_t   = torch.tensor([x_vec], dtype=torch.float32, device=device)  # shape (1, Dx)

            with torch.no_grad():
                # sample for this single row
                s = posterior.sample((num_samples,), x=x_t, show_progress_bars=False)  # (N, Dθ)
                if use_logp:
                    # evaluated on the untransformed draws, before unit_to_theta
                    lp = posterior.log_prob(s, x=x_t)  # (N,)

            if transform_unit_to_theta:
                s = unit_to_theta(s)  # your helper

            # Store as NumPy via safe helper
            all_samps.append(to_np(s))   # (N, Dθ)
            if use_logp:
                all_logp.append(to_np(lp))   # (N,)

        samples_flat = np.vstack(all_samps)      # (N * #rows, Dθ)

        # Decide best index and logp storage depending on method
        if use_logp:
            logp_flat = np.concatenate(all_logp)  # (N * #rows,)
            posterior_logp_dict[gkey] = logp_flat
            best_idx = int(np.argmax(logp_flat))
        else:  # "knn"
            posterior_logp_dict[gkey] = np.array([])
            best_idx = _densest_by_knn(samples_flat, k=5)

        posterior_samples_dict[gkey] = samples_flat
        best_samples_dict[gkey]      = samples_flat[best_idx]

    return posterior_samples_dict, posterior_logp_dict, best_samples_dict


# --- Run inference ---
# One posterior per (subject, muscle_pair, intensity). The sample count and the
# best-sample method are passed as literals here (100 draws per row, k-NN) to keep
# the run fast, instead of the `num_samples` / `best_method` config values.
posterior_samples_dict, posterior_logp_dict, best_samples_dict = infer_posteriors_by_condition_looped(
    df_obs=df_for_inference,
    posterior=posterior,
    feature_cols=feature_cols,
    keys_cols=keys_cols,
    num_samples=100, # 100 # num_samples # 100 to make things faster
    best_method="knn", # "knn" # best_method # "knn" to make things faster
    show_progress=True,
    transform_unit_to_theta=True,   # set False if you prefer to transform later
)
# ^ This can be extremely slow if some experimental observations are outside the coverage of the density estimator
# Replaced by num_samples=100 and best_method="knn"
# With best_method="knn", posterior_logp_dict holds an empty array for every condition.

# (Optional) If you prefer to delay unit->theta transform, you can run it here:
# for k, arr in posterior_samples_dict.items():
#     arr_torch = torch.tensor(arr.tolist(), dtype=torch.float32)
#     posterior_samples_dict[k] = to_np(unit_to_theta(arr_torch))
#
# for k, arr in best_samples_dict.items():
#     arr_torch = torch.tensor(arr.tolist(), dtype=torch.float32)
#     best_samples_dict[k] = to_np(unit_to_theta(arr_torch))


In [None]:
# Saving the posteriors (per-participant)
# When sampling was re-run, bundle the three result dicts and pickle them under
# path_to_save; otherwise load a previously saved bundle from path_to_load.

if rerun_network_training_and_sampling:
    posterior_estimates_pickle_path = f"{path_to_save}\\posterior_estimates_each_subject.pkl"
    # Save the samples, the log probabilities, and the best sample per condition
    posterior_estimates = {
        "posterior_samples": posterior_samples_dict,
        "posterior_logp": posterior_logp_dict,
        "best_samples": best_samples_dict,
    }
    with open(posterior_estimates_pickle_path, "wb") as f:
        pickle.dump(posterior_estimates, f)
    print(f"✅ Posterior estimate samples saved to '{posterior_estimates_pickle_path}'")
else:
    # NOTE(review): pickle.load executes arbitrary code from the file; only load
    # files that come from a trusted source.
    posterior_estimates_pickle_path = f"{path_to_load}\\posterior_estimates_each_subject.pkl"
    with open(posterior_estimates_pickle_path, "rb") as f:
        posterior_estimates = pickle.load(f)
    print(f"✅ Posterior estimate samples loaded from '{posterior_estimates_pickle_path}'")

## POSTERIOR ESTIMATE, PARTICIPANTS POOLED
### (results reported in paper = experimental observations of all participants' motor units pooled together)
### One posterior distribution per [Muscle_pair × Intensity]

In [None]:
# # SAMPLING - same as before, but with subjects grouped together now
# DOES NOT SUPPORT 'input_sim_parameters_as_features' YET - CHECK IF IT STILL WORKS FOR MUSCLE PAIRS INFERENCE

# --- Choose which dataframe to use ---
# (the conditional choice below is disabled; the grouped-subjects frame is always used)
# df_for_inference = (
#     df_experiment_summary_grouped_subjects
#     if len(input_sim_parameters_as_features) >= 1
#     else df_experiment_summary
# )
df_for_inference = df_experiment_summary_grouped_subjects

# --- Config ---
keys_cols    = ['muscle_pair', 'intensity']   # condition identity
feature_cols = summary_colnames               # observed + posterior-as-features cols
# num_samples  = (num_posterior_samples['experiment_with_posterior_estimates_as_features']
#                 if len(input_sim_parameters_as_features) >= 1
#                 else num_posterior_samples['experiment'])
num_samples = num_posterior_samples['experiment']
best_method  = best_posterior_estimate_method            # "logp" or "knn"
device       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Run inference ---
# One posterior per (muscle_pair, intensity); unlike the per-participant run,
# the configured num_samples and best_method are used here.
posterior_samples_dict_subjects_grouped, posterior_logp_dict_subjects_grouped, best_samples_dict_subjects_grouped = infer_posteriors_by_condition_looped(
    df_obs=df_for_inference,
    posterior=posterior,
    feature_cols=feature_cols,
    keys_cols=keys_cols,
    num_samples=num_samples,
    best_method=best_method,
    show_progress=True,
    transform_unit_to_theta=True,   # set False if you prefer to transform later
)

# (Optional) If you prefer to delay unit->theta transform, you can run it here:
# for k, arr in posterior_samples_dict.items():
#     posterior_samples_dict[k] = unit_to_theta(torch.from_numpy(arr)).numpy()
# for k, arr in best_samples_dict.items():
#     best_samples_dict[k] = unit_to_theta(torch.from_numpy(arr)).numpy()


In [None]:
# Save (after a fresh run) or load (otherwise) the grouped-subjects posterior estimates.
if rerun_network_training_and_sampling:
    posterior_estimates_pickle_path = f"{path_to_save}\\posterior_estimates_subjects_grouped.pkl"
    # Save the samples, the log probabilities, and the best sample per condition
    posterior_estimates_subjects_grouped = {
        "posterior_samples": posterior_samples_dict_subjects_grouped,
        "posterior_logp": posterior_logp_dict_subjects_grouped,
        "best_samples": best_samples_dict_subjects_grouped,
    }
    with open(posterior_estimates_pickle_path, "wb") as f:
        pickle.dump(posterior_estimates_subjects_grouped, f)
    print(f"✅ Posterior estimate samples (for grouped subjects) saved to '{posterior_estimates_pickle_path}'")
else:
    # NOTE(review): pickle.load executes arbitrary code from the file; only load
    # files that come from a trusted source.
    posterior_estimates_pickle_path = f"{path_to_load}\\posterior_estimates_subjects_grouped.pkl"
    with open(posterior_estimates_pickle_path, "rb") as f:
        posterior_estimates_subjects_grouped = pickle.load(f)
    print(f"✅ Posterior estimate samples (for grouped subjects) loaded from '{posterior_estimates_pickle_path}'")

### Posterior estimates - Generate plots
#### Not inline; saves the plots in path_to_save folder
#### These are only diagnostic plots to see what's going on; the final plots are generated by another script (based on the saved .csv files containing the posterior samples)

In [None]:
# PATCHING sbi plt_hist_1d function used in pairplot() to fix numpy VS Pytorch compatibility issue
import copy
import numpy as np
from scipy.stats import iqr
import sbi.analysis.plot as sbi_plot  # where pairplot & plt_hist_1d live

def safe_plt_hist_1d(ax, samples, limits, diag_kwargs):
    """
    Draw a 1D marginal histogram on `ax`; stand-in for sbi.analysis.plot.plt_hist_1d.

    Bin edges are always produced by np.linspace from plain Python floats, so
    `limits` may be a torch tensor without NumPy ever receiving a tensor.

    Args:
        ax: object exposing a matplotlib-style `hist(samples, **kwargs)` method.
        samples: sample values for this panel (given to `iqr`, `len` and `ax.hist`).
        limits: indexable (low, high) pair; both entries go through `float(...)`.
        diag_kwargs: dict holding a required "mpl_kwargs" dict (deep-copied, then
            forwarded to `ax.hist`) and an optional "bin_heuristic" entry
            (defaults to "Freedman-Diaconis").

    Bin selection:
        - no "bins" entry (or None): Freedman–Diaconis bin count clipped to
          [5, 200]; 50 bins when the bin width is non-finite/non-positive or
          when another heuristic name is given;
        - "bins" is an int: that many equal-width bins between the limits;
        - anything else: forwarded to `ax.hist` unchanged.
    """
    fallback_bin_count = 50

    # Work on a private copy so the caller's kwargs never gain a "bins" entry.
    mpl_options = copy.deepcopy(diag_kwargs["mpl_kwargs"])

    # Limits as plain floats (works for torch scalars as well as numbers).
    lower, upper = float(limits[0]), float(limits[1])

    edges = mpl_options.get("bins", None)

    if edges is None:
        bin_count = fallback_bin_count
        if diag_kwargs.get("bin_heuristic", "Freedman-Diaconis") == "Freedman-Diaconis":
            # Freedman–Diaconis bin width: 2 * IQR * n^(-1/3)
            width = 2.0 * iqr(samples) * (len(samples) ** (-1.0 / 3.0))
            if np.isfinite(width) and width > 0.0:
                extent = upper - lower
                if extent <= 0:
                    extent = 1.0
                bin_count = int(np.clip(np.ceil(extent / width), 5, 200))
        edges = np.linspace(lower, upper, bin_count + 1)
    elif isinstance(edges, int):
        # An explicit bin count is turned into edges spanning the float limits.
        edges = np.linspace(lower, upper, edges + 1)

    mpl_options["bins"] = edges
    ax.hist(samples, **mpl_options)


# Monkey-patch sbi: rebind the module attribute so that code looking up
# sbi.analysis.plot.plt_hist_1d at call time gets safe_plt_hist_1d instead.
# NOTE(review): this only takes effect if pairplot() resolves plt_hist_1d through
# the module namespace — confirm for the installed sbi version.
sbi_plot.plt_hist_1d = safe_plt_hist_1d


### Posterior estimates - per participant
(not showing them inline in the jupyter notebook)

In [None]:
# Per-participant posterior pair plots: one PNG per observation key, written to disk (not shown inline).
path_to_save_obs_pairplots = os.path.join(path_to_save, "obs_pairplots_each_subject")
os.makedirs(path_to_save_obs_pairplots, exist_ok=True)

n_dim = len(input_sim_parameters_to_infer)

# Loop-invariant plot settings, built once instead of once per observation.
pairplot_limits = [[low_original[i].item(), high_original[i].item()] for i in range(n_dim)]
best_linestyle = dict(color="#000000", linestyle="-", linewidth=1.5, alpha=1)
best_marker    = dict(marker="X", s=50, edgecolor="#000000",
                      facecolor="white", linewidth=1.0)

for obs_key in posterior_estimates['posterior_samples']:
    # Element 1 of the key selects the colors for this observation
    muscle_pair_i = obs_key[1]
    diagonal_density_color = muscle_colors_dict[muscle_pair_i]
    off_diagonal_2d_density_colormap = muscle_colormaps_dict[muscle_pair_i]

    # Hand NumPy (not torch) to sbi/matplotlib, for the samples AND for the best
    # estimate, mirroring the pooled-subjects plotting cell.
    posterior_samples = to_np(posterior_estimates['posterior_samples'][obs_key])  # (N, D)
    best_theta = to_np(posterior_estimates['best_samples'][obs_key]).reshape(-1)

    # 1) draw the base pairplot
    fig, axes = pairplot(
        posterior_samples,
        limits=pairplot_limits,
        diag_kwargs={
            "mpl_kwargs": {
                "color": diagonal_density_color,
                # no 'bins' here; our patched plt_hist_1d will handle it
            },
            # You can keep using the default FD heuristic:
            "bin_heuristic": "Freedman-Diaconis",
        },
        upper_kwargs={"mpl_kwargs": {"cmap": off_diagonal_2d_density_colormap}},
        labels=input_sim_parameters_to_infer,
        figsize=(7, 7),
    )

    # 2) overlay best estimate (best_theta) on every subplot
    for i in range(n_dim):
        # diagonal: vertical line at best_theta[i]; only the first panel carries
        # the legend label because the legend is drawn on axes[0, 0]
        ax = axes[i, i]
        if i == 0:
            ax.axvline(best_theta[i], **best_linestyle,
                       label="θ̂ (posterior with highest prob)")
        else:
            ax.axvline(best_theta[i], **best_linestyle)

        # off-diagonal (upper triangle only)
        for j in range(i + 1, n_dim):
            ax_off = axes[i, j]
            ax_off.scatter(best_theta[j], best_theta[i], **best_marker)

    axes[0, 0].legend(loc="upper right", fontsize="small")

    fig.suptitle(f"Posterior estimate for {obs_key}", fontsize=14)
    fig.tight_layout()

    # Build a file name from the key: drop every character outside [0-9A-Za-z._-],
    # then remove the 'np.float64' residue left by the repr of NumPy scalars.
    sanitized_filename = re.sub(r'[^0-9A-Za-z._-]+', '', str(obs_key))
    sanitized_filename = sanitized_filename.replace('np.float64', '')
    out_path = os.path.join(path_to_save_obs_pairplots, f"{sanitized_filename}.png")
    # Save and close this specific figure rather than the implicit "current" one
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


In [None]:
# Long-format table of all per-subject posterior samples (one row per sample),
# tagged with the subject / muscle pair / intensity taken from the dictionary key.
list_dfs_temp = []

for (subject, muscle_pair, intensity), arr in posterior_estimates['posterior_samples'].items():
    # Plain NumPy array (not a torch.Tensor) before handing the data to pandas
    arr_np = to_np(arr)

    # One column per inferred parameter, followed by the three identifier columns
    df_i = pd.DataFrame(arr_np, columns=input_sim_parameters_to_infer).assign(
        subject=subject,
        muscle_pair=muscle_pair,
        intensity=intensity,
    )
    list_dfs_temp.append(df_i)

# Stack every observation into a single data frame with a fresh running index
df_kde_posterior_samples = pd.concat(list_dfs_temp, ignore_index=True)

# Write the table to disk as CSV (without the index column)
df_kde_posterior_samples_path = f"{path_to_save}\\posterior_samples_each_subject_df.csv"
df_kde_posterior_samples.to_csv(df_kde_posterior_samples_path, index=False)

print(f"✅ Saved posterior samples to: {df_kde_posterior_samples_path}")


### Posterior estimates - participants pooled
(not showing them inline in the jupyter notebook)

In [None]:
### PLOTTING THE POSTERIOR ESTIMATES FOR EACH EXPERIMENTAL OBSERVATION, WITH SUBJECTS GROUPED

# Create folder to save experimental observations pairplots
path_to_save_obs_pairplots = os.path.join(path_to_save, "obs_pairplots_subjects_grouped")
os.makedirs(path_to_save_obs_pairplots, exist_ok=True)

n_dim = len(input_sim_parameters_to_infer)

# Loop-invariant plot settings, built once instead of once per observation.
pairplot_limits = [[low_original[i].item(), high_original[i].item()] for i in range(n_dim)]
best_linestyle = dict(color="#000000",  linestyle="-", linewidth=1.5, alpha=1)
best_marker    = dict(marker="X", s=50, edgecolor="#000000", facecolor="white", linewidth=1)

for obs_key in posterior_estimates_subjects_grouped['posterior_samples']:
    # Element 0 of the key selects the colors for this observation
    muscle_pair_i = obs_key[0]
    diagonal_density_color = muscle_colors_dict[muscle_pair_i]
    off_diagonal_2d_density_colormap = muscle_colormaps_dict[muscle_pair_i]

    # Ensure NumPy arrays (avoid torch→numpy inside sbi.pairplot)
    posterior_samples = posterior_estimates_subjects_grouped['posterior_samples'][obs_key]
    posterior_samples_np = to_np(posterior_samples)

    # Best-θ as a flat NumPy vector as well
    best_theta = posterior_estimates_subjects_grouped['best_samples'][obs_key]
    best_theta = to_np(best_theta).reshape(-1)

    # 1) draw the base pairplot (pass NumPy, not torch)
    fig, axes = pairplot(
        posterior_samples_np,
        limits=pairplot_limits,
        diag_kwargs={
            "mpl_kwargs": {
                "color": diagonal_density_color,
            },
            "bin_heuristic": "Freedman-Diaconis",
        },
        upper_kwargs={"mpl_kwargs": {"cmap": off_diagonal_2d_density_colormap}},
        labels=input_sim_parameters_to_infer,
        figsize=(7, 7),
    )

    # 2) overlay best estimate (best_theta) on every subplot
    for i in range(n_dim):
        # Only the first diagonal panel carries the legend label, because the
        # legend is drawn on axes[0, 0]
        ax = axes[i, i]
        if i == 0:
            ax.axvline(best_theta[i], **best_linestyle, label="θ̂ (posterior with highest prob)")
        else:
            ax.axvline(best_theta[i], **best_linestyle)

        # only upper triangle
        for j in range(i + 1, n_dim):
            ax_off = axes[i, j]
            ax_off.scatter(best_theta[j], best_theta[i], **best_marker)

    axes[0, 0].legend(loc="upper right", fontsize="small")

    fig.suptitle(f"Posterior estimate for {obs_key}", fontsize=14)
    fig.tight_layout()

    # Build a file name from the key: drop every character outside [0-9A-Za-z._-],
    # then remove the 'np.float64' residue left by the repr of NumPy scalars.
    sanitized_filename = re.sub(r'[^0-9A-Za-z._-]+', '', str(obs_key))
    sanitized_filename = sanitized_filename.replace('np.float64', '')
    # os.path.join (as in the per-subject plotting cell) instead of a hard-coded "\\"
    out_path = os.path.join(path_to_save_obs_pairplots, f"{sanitized_filename}.png")
    # Save and close this specific figure rather than the implicit "current" one
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


In [None]:
# Long-format table of the pooled-subject posterior samples, used for kernel density
# estimates and plotting of the latent parameters of interest.
# These are full marginals (= the full distribution over the other latent parameters is kept).
list_dfs_temp = []
for (muscle_pair, intensity), arr in posterior_estimates_subjects_grouped['posterior_samples'].items():
    # Plain NumPy array (avoids the torch→numpy conversion inside pandas)
    arr_np = to_np(arr)
    # One column per inferred parameter, followed by the two identifier columns
    df_i = pd.DataFrame(arr_np, columns=input_sim_parameters_to_infer).assign(
        muscle_pair=muscle_pair,
        intensity=intensity,
    )
    list_dfs_temp.append(df_i)

# Stack every condition into a single data frame with a fresh running index
df_kde_posterior_samples_subjects_grouped = pd.concat(list_dfs_temp, ignore_index=True)

# Write the table to disk as CSV (without the index column)
df_kde_posterior_samples_subjects_grouped_path = f"{path_to_save}\\posterior_samples_subjects_grouped_df.csv"
df_kde_posterior_samples_subjects_grouped.to_csv(df_kde_posterior_samples_subjects_grouped_path, index=False)

In [None]:
# One-row-per-observation table of the best sample (highest probability density = posterior mode),
# per subject, per muscle pair, per intensity.
list_dfs_temp = []
for key, arr in posterior_estimates['best_samples'].items():
    # The key provides the three identifier columns
    df_i = pd.DataFrame(dict(
        subject=[key[0]],
        muscle_pair=[key[1]],
        intensity=[key[2]],
    ))
    # Append one column per inferred parameter, in parameter order
    for i, param_name in enumerate(input_sim_parameters_to_infer):
        df_i[param_name] = arr[i]
    list_dfs_temp.append(df_i)
# Rows are stacked without resetting the index, so every row keeps index label 0
df_best_sample = pd.concat(list_dfs_temp)

##### POSTERIOR FIGURES ARE GENERATED IN A SEPARATE SCRIPT FROM THE SAVED .csv FILES

# Posterior predictive checks
### CHECK IF ESTIMATED POSTERIOR PARAMETERS, WHEN USED AS SIMULATION PARAMETERS, REPRODUCE THE EXPERIMENTAL DATA
GET A FEW SAMPLES FROM THE POSTERIOR AND SIMULATE FROM IT = check if this reproduces the experimental data it is supposed to match

In [None]:
# Libraries and helper functions
# Import libraries and modules
import sys
from pathlib import Path
# Add the parent directory of this notebook to sys.path
parent_dir = Path().resolve().parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))
from simulator import SimulationParameters, run_simulation

from brian2 import *
import copy
from copy import deepcopy
from dataclasses import fields, dataclass, asdict
# Import for parallelization
import getpass
import psutil
from joblib import Parallel, delayed
from brian2 import prefs, device
import logging
from threading import Thread, Event
import time
from datetime import datetime
from os import listdir
from os.path import isfile, join

In [None]:
# Helper function needed when posterior-based simulation of 2 pools (muscle pairs) and not 1
def expand_params(params: list[str]) -> list[str]:
    """
    Expand single-pool parameter names into their two-pool counterparts.

    Rules, applied to each name in order:
      - names starting with 'disynpatic_inhib_' (legacy spelling) and ending with
        '_self' or '_other_pool' are replaced by two fixed 'disynaptic_inhib_...'
        names (note the corrected spelling in the output);
      - other names starting with 'disynpatic_inhib_' are kept unchanged;
      - any other name ending with '_self' becomes '<base>_pool0' and '<base>_pool1';
      - everything else is kept unchanged.

    Returns a new list; the input list is not modified.
    """
    legacy_prefix = 'disynpatic_inhib_'
    # suffix -> replacement names for legacy-spelled inhibition parameters
    legacy_expansions = {
        '_self': (
            'disynaptic_inhib_self_connectivity_pool0',
            'disynaptic_inhib_self_connectivity_pool1',
        ),
        '_other_pool': (
            'disynaptic_inhib_connectivity_pool0_to_pool1',
            'disynaptic_inhib_connectivity_pool1_to_pool0',
        ),
    }

    expanded = []
    for name in params:
        if name.startswith(legacy_prefix):
            for suffix, replacements in legacy_expansions.items():
                if name.endswith(suffix):
                    expanded.extend(replacements)
                    break
            else:
                # legacy-spelled name without a recognised suffix: pass through
                expanded.append(name)
            continue

        if name.endswith('_self'):
            stem = name[:-len('_self')]
            expanded.extend((f'{stem}_pool0', f'{stem}_pool1'))
        else:
            expanded.append(name)
    return expanded


In [None]:
sim_parallel_cpus = 16
n_sims_per_condition = 1 # 100 # Number of samples (parameter sets) to draw from the posterior. All those parameter sets will be used for posterior-predictive simulations.
muscle_pairs_posterior_predictive_checks = False # Set to True for the BETWEEN MUSCLES CONFIG
# SINGLE MUSCLE CONFIG
if not muscle_pairs_posterior_predictive_checks:
    priors_filename = "C:\\Users\\franc\\Documents\\GitHub\\Mapping_Recurrent_Inhibition\\Simulation_parameters\\FIG_4_Simulation_based_inference_training_dataset\\single muscles training data\\Simulation_single_muscle_priors_batch0.pkl" # used to extract the fixed parameters used for the simulations
# MUSCLE PAIRS CONFIG
else:
    folder_sims_param_muscle_pairs = "C:\\Users\\franc\\Documents\\GitHub\\Mapping_Recurrent_Inhibition\\Simulation_parameters\\FIG_4_Simulation_based_inference_training_dataset\\muscle pairs training data"
    muscle_pairs_param_files = [f for f in listdir(folder_sims_param_muscle_pairs) if isfile(join(folder_sims_param_muscle_pairs, f))]
    # Map file name -> full path for every file whose name does not contain 'inference'.
    # Can be any file from the paired-muscles simulations, since the muscle-pair specific priors are replaced with input_sim_parameters_as_features.
    # The path is built with os.path.join (already used for the isfile check above) instead of a literal "//".
    # NOTE(review): the dictionary keys are the raw file names (extension included, if any), whereas
    # find_priors_entry() looks up keys of the form "A-B_intensityN" — confirm that the file names match.
    priors_filepath_each = {
        f: join(folder_sims_param_muscle_pairs, f)
        for f in muscle_pairs_param_files
        if 'inference' not in f
    }

# Select whether the posterior-predictive simulations are drawn per subject or from the pooled-subject posteriors
simulate_per_subject_or_subjects_grouped = "subjects_grouped" # "per_subject" or "subjects_grouped"

# Names of the parameters that vary between simulations: inferred parameters followed by the feature parameters
variable_parameters = input_sim_parameters_to_infer + input_sim_parameters_as_features
if muscle_pairs_posterior_predictive_checks: # In the case of simulating 2 pools, the variable parameters have to be expanded
    variable_parameters = expand_params(variable_parameters)

# Create folder to save the new simulations coming from the most likely posteriors
path_to_save_posterior_predictive_check_sims = f"{path_to_save}\\posterior_predictive_checks_{simulate_per_subject_or_subjects_grouped}"
# path_to_save_posterior_predictive_check_sims = f"{simulated_data_path_for_priors_of_posterior_predictive_checks}\\posterior_predictive_checks_{simulate_per_subject_or_subjects_grouped}"
os.makedirs(path_to_save_posterior_predictive_check_sims, exist_ok=True)

# Get fixed parameters for the simulation = take the first entry of 'params_prior_list' in the loaded priors file
# NOTE(review): pickle.load executes arbitrary code from the file; only load prior files you produced yourself.
# SINGLE MUSCLE CASE
if not muscle_pairs_posterior_predictive_checks:
    with open(priors_filename, "rb") as f:
        batch_sim_priors = pickle.load(f)
    # variable_parameters = [str(k) for k, v in batch_sim_priors['free_parameter_bounds'].items()]
    fixed_parameters = batch_sim_priors['params_prior_list'][0] # The full list or the first instance [0] of SimulationParameters objects that were simulated in the batch used for network training for SBI
    # fixed_parameters = batch_sim_priors['fixed_parameters']
# BETWEEN MUSCLE PAIRS CASE
else:
    batch_sim_priors_each = {}
    for key, filename_path in priors_filepath_each.items():
        # Expand batch_sim_priors_each to have both directions (VM-VL_intensityXX and VL_VM_intensityXX for instance) by looping into filename_path twice, and just reversing all pool0 and pool1 values from the priors_from_posterior dataframe
        with open(filename_path, "rb") as f:
            batch_sim_priors_each[key] = pickle.load(f)
            fixed_parameters = batch_sim_priors_each[key]['params_prior_list'][0] # The baseline fixed parameters are the same for all simulations, so could come from any loaded batch of priors


In [None]:
# MANY other helper functions used in the case of simulating pairs of muscles
from typing import Dict, Tuple, Optional

# Column renames applied to the priors data frame (via DataFrame.rename in
# build_expanded_between_pool_posteriors): name found in the priors file -> name used afterwards.
ALIAS_COLUMNS = {
    "std_of_second_common_input_pool0": "common_input_std_pool0",
    "std_of_second_common_input_pool1": "common_input_std_pool1",
}
# Suffixes marking pool-specific parameter names.
# NOTE(review): not referenced by the helpers defined in this cell — confirm it is used elsewhere.
POOL_SUFFIXES = ("_pool0", "_pool1")

def split_dir_pair(pair_str: str) -> Tuple[str, str]:
    """
    Split a pair label such as "A<->B" or "A-B" into its two stripped names.

    Raises ValueError when neither separator is present (and, through tuple
    unpacking, when the separator occurs more than once).
    """
    # "<->" has to be probed first because it contains a dash itself.
    for separator in ("<->", "-"):
        if separator in pair_str:
            first, second = pair_str.split(separator)
            return first.strip(), second.strip()
    raise ValueError(f"Unrecognized pair string: {pair_str}")

def undirected_key(a: str, b: str) -> str:
    """
    Join two names with a dash ("a-b").

    The names are joined in the order given; they are not sorted, so
    (a, b) and (b, a) yield different keys.
    """
    return "{}-{}".format(a, b)

def batch_key_for_priors(a: str, b: str, intensity: int) -> str:
    """Build the lookup key "<a>-<b>_intensity<N>", with the intensity coerced through int()."""
    level = int(intensity)
    return f"{a}-{b}_intensity{level}"

def find_priors_entry(batch_sim_priors_each: Dict, a: str, b: str, intensity: int):
    """
    Look up the priors entry for a muscle pair at a given intensity.

    The key for (a, b) is tried first, then the key for (b, a).

    Returns:
        (entry, order) where order is "ab" if the entry was stored under the
        (a, b) key and "ba" if it was stored under the reversed key.

    Raises:
        KeyError: if neither key exists in `batch_sim_priors_each`.
    """
    candidates = (
        (batch_key_for_priors(a, b, intensity), "ab"),
        (batch_key_for_priors(b, a, intensity), "ba"),
    )
    for lookup_key, order in candidates:
        if lookup_key in batch_sim_priors_each:
            return batch_sim_priors_each[lookup_key], order
    raise KeyError(f"No priors found for {a}-{b} (or {b}-{a}) @ intensity {intensity}")

def swap_pool_cols_inplace(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` in which every '<name>_pool0' / '<name>_pool1' column pair
    has exchanged its values.

    Despite the name, the input frame is left untouched: the swap is done on a copy.
    Columns whose counterpart is missing are not modified.
    """
    swapped = df.copy()
    available = set(swapped.columns)

    # Every swappable pair is identified by its '_pool0' member.
    pool0_names = [name for name in available if name.endswith("_pool0")]

    for pool0_name in pool0_names:
        pool1_name = pool0_name[:-len("_pool0")] + "_pool1"
        if pool1_name not in available:
            continue
        # Copy both columns before overwriting either of them
        pool1_values = swapped[pool1_name].copy()
        pool0_values = swapped[pool0_name].copy()
        swapped[pool0_name] = pool1_values
        swapped[pool1_name] = pool0_values
    return swapped

def filter_priors_df(df: pd.DataFrame, intensity: int, subject: Optional[str]) -> pd.DataFrame:
    """
    Return the rows of a copy of `df` that match the requested intensity and subject.

    - The intensity filter is applied only if an "intensity" column exists; both the
      column and the requested value are compared as integers.
    - The subject filter is applied only if `subject` is not None and a "subject"
      column exists.

    Raises:
        ValueError: if no row is left after filtering.
    """
    selected = df.copy()

    if "intensity" in selected.columns:
        selected = selected.loc[selected["intensity"].astype(int) == int(intensity)]

    wants_subject = subject is not None
    if wants_subject and "subject" in selected.columns:
        selected = selected.loc[selected["subject"] == subject]

    if not selected.empty:
        return selected
    raise ValueError(f"priors_from_posterior_df empty after filtering (intensity={intensity}, subject={subject})")

def build_expanded_between_pool_posteriors(
    posterior_estimates: Dict[str, Dict],
    batch_sim_priors_each: Dict,
    variable_parameters: list[str],
    sample_priors_with_replacement: bool = True,
    seed: Optional[int] = None,
    subject_by_condition: Optional[Dict[Tuple[str, int], str]] = None,
):
    """
    Build two-pool parameter matrices from directional posterior estimates and priors.

    For every (pair_str, intensity) key of `posterior_estimates`, the pair string is
    parsed with split_dir_pair() and stored under the key ("A-B", int(intensity)) as the
    "ab" direction. If the reversed pair string is also present among the keys, its
    array is attached as the "ba" direction; otherwise the "ab" array is reused for "ba".
    Note that each input direction produces its own output condition (when both
    "A<->B" and "B<->A" are given, both "A-B" and "B-A" appear in the output).

    Column usage of the posterior arrays:
      - column 0 -> excitatory_input_baseline_pool0 ("ab") / _pool1 ("ba")
      - column 1 -> disynaptic_inhib_connectivity_pool0_to_pool1 ("ab") / pool1_to_pool0 ("ba")
      - column 2 -> between_pool_excitatory_input_correlation (mean of "ab" and "ba")
    All other names in `variable_parameters` are filled from rows of the
    "priors_from_posterior_df" data frame of the matching priors entry.

    Args:
        posterior_estimates: dict with "posterior_samples" and "best_samples", each
            mapping (pair_str, intensity) to an array.
            NOTE(review): keys are unpacked as 2-tuples; dictionaries keyed by
            (subject, muscle_pair, intensity) would fail here — confirm the caller.
        batch_sim_priors_each: priors entries, looked up with find_priors_entry().
        variable_parameters: output column names, in output order.
        sample_priors_with_replacement: if True, draw one priors row (with replacement)
            per posterior sample; if False, take the priors rows in order, repeating
            the frame as often as needed to reach N rows.
        seed: seed for np.random.default_rng.
        subject_by_condition: optional mapping (pair, intensity) -> subject used to
            filter the priors rows.

    Returns:
      tensors = {
        'posterior_samples': { (pair, intensity): torch.FloatTensor [N, P] },
        'best_samples':      { (pair, intensity): torch.FloatTensor [P]     },
      }
      df_all : pandas.DataFrame over all conditions ("pair", "intensity", the
               variable parameters, and "subject_from_priors" when the priors
               frame has a "subject" column)

    Raises:
        KeyError: no priors entry for a pair/intensity (from find_priors_entry).
        ValueError: priors frame empty after filtering, priors columns missing, or
            NaN left in a sample matrix / best vector.
    """
    rng = np.random.default_rng(seed)

    # Collect directional arrays by condition; every input key is registered as direction "ab"
    dir_samples, dir_best = {}, {}
    for (pair_str, intens), arr in posterior_estimates["posterior_samples"].items():
        a, b = split_dir_pair(pair_str)
        key = (undirected_key(a, b), int(float(intens)))
        dir_samples.setdefault(key, {})["ab"] = arr
    for (pair_str, intens), arr in posterior_estimates["best_samples"].items():
        a, b = split_dir_pair(pair_str)
        key = (undirected_key(a, b), int(float(intens)))
        dir_best.setdefault(key, {})["ab"] = arr
    # Try to attach reverse if present: the array of "B<->A" becomes the "ba" side of condition "A-B"
    for (pair_str, intens), arr in posterior_estimates["posterior_samples"].items():
        b, a = split_dir_pair(pair_str)   # flip parsing to find reverse
        key = (undirected_key(a, b), int(float(intens)))
        if key in dir_samples and "ba" not in dir_samples[key]:
            dir_samples[key]["ba"] = arr
    for (pair_str, intens), arr in posterior_estimates["best_samples"].items():
        b, a = split_dir_pair(pair_str)
        key = (undirected_key(a, b), int(float(intens)))
        if key in dir_best and "ba" not in dir_best[key]:
            dir_best[key]["ba"] = arr

    tensors = {"posterior_samples": {}, "best_samples": {}}
    df_rows = []
    P = len(variable_parameters)
    # Parameter name -> column position in the output matrices
    col_idx = {v: i for i, v in enumerate(variable_parameters)}

    # Posterior-derived names (the rest come from priors)
    POST_NAMES = {
        "excitatory_input_baseline_pool0",
        "excitatory_input_baseline_pool1",
        "disynaptic_inhib_connectivity_pool0_to_pool1",
        "disynaptic_inhib_connectivity_pool1_to_pool0",
        "between_pool_excitatory_input_correlation",
    }

    for (pair, intensity), sides in dir_samples.items():
        a, b = pair.split("-")
        subj = None if subject_by_condition is None else subject_by_condition.get((pair, intensity), None)

        # Priors table for this condition: apply the column aliases, mirror the pool columns
        # when the entry was stored under the reversed pair, then filter by intensity/subject
        pri_entry, used_order = find_priors_entry(batch_sim_priors_each, a, b, intensity)
        pri_df = pri_entry["priors_from_posterior_df"]
        pri_df = pri_df.rename(columns=ALIAS_COLUMNS)
        pri_df = pri_df if used_order == "ab" else swap_pool_cols_inplace(pri_df)
        pri_df = filter_priors_df(pri_df, intensity=intensity, subject=subj)

        # A missing direction falls back to the other one
        arr_ab = sides.get("ab")
        arr_ba = sides.get("ba")
        if arr_ab is None and arr_ba is None:
            continue
        if arr_ab is None: arr_ab = arr_ba
        if arr_ba is None: arr_ba = arr_ab

        # Truncate both directions to the same number of samples; rows are paired by position
        N = min(arr_ab.shape[0], arr_ba.shape[0])
        arr_ab = arr_ab[:N, :]
        arr_ba = arr_ba[:N, :]

        # Start from NaN so that any column left unfilled is caught by the guard below
        X = np.full((N, P), np.nan, dtype=np.float32)

        # Posterior-derived fills
        if "excitatory_input_baseline_pool0" in col_idx:
            X[:, col_idx["excitatory_input_baseline_pool0"]] = arr_ab[:, 0]
        if "excitatory_input_baseline_pool1" in col_idx:
            X[:, col_idx["excitatory_input_baseline_pool1"]] = arr_ba[:, 0]
        if "disynaptic_inhib_connectivity_pool0_to_pool1" in col_idx:
            X[:, col_idx["disynaptic_inhib_connectivity_pool0_to_pool1"]] = arr_ab[:, 1]
        if "disynaptic_inhib_connectivity_pool1_to_pool0" in col_idx:
            X[:, col_idx["disynaptic_inhib_connectivity_pool1_to_pool0"]] = arr_ba[:, 1]
        if "between_pool_excitatory_input_correlation" in col_idx:
            # Shared parameter: average of the two directional estimates
            X[:, col_idx["between_pool_excitatory_input_correlation"]] = 0.5 * (arr_ab[:, 2] + arr_ba[:, 2])

        # Priors-derived fills: sample with replacement per row
        pri_cols_needed = [v for v in variable_parameters if v not in POST_NAMES]
        missing_pri = [v for v in pri_cols_needed if v not in pri_df.columns]
        if missing_pri:
            raise ValueError(
                f"Priors DF missing needed columns for {pair}@{intensity}: {missing_pri}\n"
                f"Available: {sorted(pri_df.columns)}"
            )

        if sample_priors_with_replacement:
            idx = rng.integers(0, len(pri_df), size=N)
            sampled = pri_df.iloc[idx]
        else:
            # deterministic: take the priors rows in order, repeating the frame if it has fewer than N rows
            # (defensive check: filter_priors_df already raises on an empty frame)
            if len(pri_df) == 0:
                raise ValueError("No rows in priors DF after filtering.")
            rep = int(np.ceil(N / len(pri_df)))
            sampled = pd.concat([pri_df]*rep, ignore_index=True).iloc[:N]

        for v in pri_cols_needed:
            X[:, col_idx[v]] = sampled[v].to_numpy(dtype=float)

        # Best sample vector (posterior best + one sampled priors row)
        best_sides = dir_best.get((pair, intensity), {})
        best_ab = best_sides.get("ab")
        best_ba = best_sides.get("ba")
        if best_ab is None and best_ba is None:
            # No best sample available for this condition: fall back to the first expanded sample
            best_vec = X[0].copy()
        else:
            if best_ab is None: best_ab = best_ba
            if best_ba is None: best_ba = best_ab
            best_vec = np.full((P,), np.nan, dtype=np.float32)
            if "excitatory_input_baseline_pool0" in col_idx:
                best_vec[col_idx["excitatory_input_baseline_pool0"]] = best_ab[0]
            if "excitatory_input_baseline_pool1" in col_idx:
                best_vec[col_idx["excitatory_input_baseline_pool1"]] = best_ba[0]
            if "disynaptic_inhib_connectivity_pool0_to_pool1" in col_idx:
                best_vec[col_idx["disynaptic_inhib_connectivity_pool0_to_pool1"]] = best_ab[1]
            if "disynaptic_inhib_connectivity_pool1_to_pool0" in col_idx:
                best_vec[col_idx["disynaptic_inhib_connectivity_pool1_to_pool0"]] = best_ba[1]
            if "between_pool_excitatory_input_correlation" in col_idx:
                best_vec[col_idx["between_pool_excitatory_input_correlation"]] = 0.5 * (best_ab[2] + best_ba[2])

            # sample one priors row for the remaining fields
            # (this draw comes after the per-sample draw above, so the RNG call order matters for reproducibility)
            row = pri_df.iloc[rng.integers(0, len(pri_df))]
            for v in pri_cols_needed:
                best_vec[col_idx[v]] = float(row[v])

        # Guard: every parameter must have been filled from either the posterior or the priors
        if np.isnan(X).any():
            miss = [variable_parameters[j] for j in np.where(np.isnan(X).any(axis=0))[0]]
            raise ValueError(f"Incomplete matrix for {pair}@{intensity}: {miss}")
        if np.isnan(best_vec).any():
            miss = [variable_parameters[j] for j in np.where(np.isnan(best_vec))[0]]
            raise ValueError(f"Incomplete best_vec for {pair}@{intensity}: {miss}")

        tensors["posterior_samples"][(pair, intensity)] = torch.tensor(X, dtype=torch.float32)
        tensors["best_samples"][(pair, intensity)] = torch.tensor(best_vec, dtype=torch.float32)

        # Same matrix as a data frame block, prefixed with the condition identifiers
        df_block = pd.DataFrame(X, columns=variable_parameters)
        df_block.insert(0, "pair", pair)
        df_block.insert(1, "intensity", intensity)
        if "subject" in pri_df.columns:
            # carry over sampled subjects (optional, useful for debugging); NaN in the deterministic mode
            df_block["subject_from_priors"] = sampled["subject"].to_numpy() if sample_priors_with_replacement else np.nan
        df_rows.append(df_block)

    # Empty input yields an empty frame that still has the expected columns
    df_all = pd.concat(df_rows, ignore_index=True) if df_rows else pd.DataFrame(columns=["pair","intensity"]+variable_parameters)
    return tensors, df_all



In [None]:
### Find set of priors (torch tensors and data frame) (samples from previously estimated posterior) that will be used later to create "params_prior_list" = list of SimulationParameters variables (dataclasses) that will be used for the simulations

if simulate_per_subject_or_subjects_grouped == "subjects_grouped":
    if not muscle_pairs_posterior_predictive_checks: # single muscle case
        # Pooled-subject posteriors are used as they are
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from = posterior_estimates_subjects_grouped
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from_df = df_kde_posterior_samples_subjects_grouped
    else: # between muscles case
        # Expand the directional posteriors to the two-pool parameter set
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from, posterior_samples_to_draw_posterior_predictive_checks_sim_param_from_df = build_expanded_between_pool_posteriors(
            posterior_estimates=posterior_estimates_subjects_grouped,
            batch_sim_priors_each=batch_sim_priors_each,
            variable_parameters=variable_parameters,
            sample_priors_with_replacement=True,   # <- enables per-row sampling
            seed=42,                               # <- reproducibility
            )
elif simulate_per_subject_or_subjects_grouped == "per_subject":
    if not muscle_pairs_posterior_predictive_checks: # single muscle case
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from = posterior_estimates
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from_df = df_kde_posterior_samples
    else: # between muscles case
        # NOTE(review): build_expanded_between_pool_posteriors unpacks the dictionary keys as
        # (pair, intensity), while the per-subject posterior_estimates are unpacked elsewhere in this
        # notebook as (subject, muscle_pair, intensity) — confirm this combination is supported.
        posterior_samples_to_draw_posterior_predictive_checks_sim_param_from, posterior_samples_to_draw_posterior_predictive_checks_sim_param_from_df = build_expanded_between_pool_posteriors(
            posterior_estimates=posterior_estimates,
            batch_sim_priors_each=batch_sim_priors_each,
            variable_parameters=variable_parameters,
            sample_priors_with_replacement=True,   # <- enables per-row sampling
            seed=42,                               # <- reproducibility
            )
else:
    # Fail loudly: merely printing would leave the two variables above undefined (or stale from a
    # previous run), and the following cells would then fail or silently use outdated data.
    raise ValueError(
        'Please select ["subjects_grouped" or "per_subject"] for simulate_per_subject_or_subjects_grouped '
        f'(got {simulate_per_subject_or_subjects_grouped!r})'
    )


In [None]:
# Bare expression: shows the parameter table selected in the previous cell as the notebook cell output.
posterior_samples_to_draw_posterior_predictive_checks_sim_param_from_df # Checking results from previous cell

In [None]:
# # Make sure you already have:
# # - variable_parameters = expand_params(input_sim_parameters_to_infer + input_sim_parameters_as_features)
# # - batch_sim_priors_each loaded (your loop with pickle.load)
# # - posterior_estimates_subjects_grouped with keys 'posterior_samples' and 'best_samples'

# tensors_grouped, df_grouped = build_expanded_between_pool_posteriors(
#     posterior_estimates=posterior_estimates_subjects_grouped,
#     batch_sim_priors_each=batch_sim_priors_each,
#     variable_parameters=variable_parameters,
# )

# # Access:
# # tensors_grouped['posterior_samples'][( 'SOL-GM', 10 )]  -> torch.FloatTensor [N, len(variable_parameters)]
# # tensors_grouped['best_samples'][( 'SOL-GM', 10 )]      -> torch.FloatTensor [len(variable_parameters)]
# # df_grouped                                              -> one big table across all pairs & intensities


In [None]:
from brian2.units.fundamentalunits import Quantity, have_same_dimensions
from brian2.units import ms, nA, msiemens

def ensure_unit(v, unit, scale=1.0):
    """
    Make sure `v` carries physical units.

    - If `v` is already a Brian2 Quantity it is returned unchanged (no rescaling and
      no check against `unit` is performed here).
    - Otherwise `v` is treated as unitless (float/ndarray) and `(v * scale) * unit`
      is returned.
    """
    already_has_units = isinstance(v, Quantity)
    return v if already_has_units else (v * scale) * unit

# 1) A mapping of input-names → functions that take (value, init_kwargs, default)
#    and mutate init_kwargs appropriately.
def _make_unit_setter(field_name, unit, scale=1.0):
    """Return a setter (value, init_kwargs, default) that stores the value under `field_name` with units attached."""
    def _setter(v, kw, default):
        # `default` is accepted for signature compatibility and is not used
        kw.update({field_name: ensure_unit(v, unit, scale=scale)})
    return _setter


# Input name -> setter that writes the unit-carrying value into init_kwargs.
# Renshaw_to_MN_IPSP is the only entry with a non-unit scale (1e3, in nA);
# all other entries attach their unit with scale 1.0.
_special_setters = {
    field_name: _make_unit_setter(field_name, unit, scale)
    for field_name, unit, scale in (
        ('Renshaw_to_MN_IPSP', nA, 1e3),
        ('AHP_conductance_delta_after_spiking', msiemens, 1.0),
        ('tau_Renshaw', ms, 1.0),
        ('synaptic_IPSP_decay_time_constant', ms, 1.0),
        ('MN_RC_synpatic_delay', ms, 1.0),
    )
}


def build_posterior_predictive_check_simulation_parameters(fixed_parameters, free_priors, batch_name):
    """
    Build the SimulationParameters for one single-pool posterior-predictive simulation.

    Values come from `fixed_parameters`, are overridden by `free_priors`, and are
    passed to the SimulationParameters constructor (so its __post_init__ runs).

    Parameters
    ----------
    fixed_parameters : dataclass instance
        Source of the non-inferred values. Fields declared with init=False are
        dropped because they cannot be constructor arguments.
    free_priors : dict
        Parameter name -> scalar value (one posterior draw). Not mutated.
    batch_name : str
        Stored as `output_folder_name` of the returned object.

    Returns
    -------
    SimulationParameters

    Notes
    -----
    * common_input_std: only entry [pool 0][input 0] of the fixed value is kept and
      the posterior value is written to [pool 0, input 1]; all other entries are 0.
    * disynpatic_inhib_connections_desired_MN_MN: the fixed value is discarded. A
      posterior value under that name is written to entry [0, 0] of a copy of the
      class-default matrix; the matrix is only passed on if an entry was written.
    * common_input_high_freq_middle_of_range / ..._half_width_range are turned into
      the [low, high] band of pool 0, input 1 of frequency_range_of_common_input.
    * excitatory_input_baseline: a scalar is wrapped into a one-element list; a
      value that is already a sequence is passed through unchanged.
    * Names that are neither handled above nor init=True fields are ignored.
    """
    default = SimulationParameters()
    init_kwargs = {}

    # Partition dataclass fields into init vs non-init
    init_field_names     = {f.name for f in fields(SimulationParameters) if f.init}
    non_init_field_names = {f.name for f in fields(SimulationParameters) if not f.init}

    # asdict() also returns init=False fields; drop them from the fixed side.
    fixed_params = {
        k: v for k, v in asdict(fixed_parameters).items() if k not in non_init_field_names
    }
    free_params = free_priors.copy()

    # --- Resolve naming collisions / special cases ---

    # Split common_input_std: keep input0 from the fixed side, map the posterior to input1
    if 'common_input_std' in fixed_params:
        # keep only pool 0, input 0 (shape is [pool, input])
        first_val = np.asarray(fixed_params.pop('common_input_std'))[0][0]
        fixed_params['common_input_std_input0'] = float(first_val)

    if 'common_input_std' in free_params:
        free_params['common_input_std_input1'] = float(free_params.pop('common_input_std'))

    # The fixed-side matrix is not used; a within-pool value on the free side is
    # mapped to pool0 self-connectivity.
    fixed_params.pop('disynpatic_inhib_connections_desired_MN_MN', None)
    if 'disynpatic_inhib_connections_desired_MN_MN' in free_params:
        free_params['disynaptic_inhib_self_connectivity_pool0'] = float(
            free_params.pop('disynpatic_inhib_connections_desired_MN_MN')
        )

    # Merge into a single dict; fixed keys come first, free values win on collisions.
    all_priors = {**fixed_params, **free_params}

    # Staging buffers for multi-field assembly
    common_input_std_to_assign = np.array([[0.0, 0.0],
                                           [0.0, 0.0]], dtype=float)  # [pool, input]
    disyn = deepcopy(default.disynpatic_inhib_connections_desired_MN_MN)
    disyn_touched = False
    # parameter name -> (row, col) entry of the 2×2 disynaptic matrix
    disyn_slots = {
        'disynaptic_inhib_self_connectivity_pool0': (0, 0),
        'disynaptic_inhib_self_connectivity_pool1': (1, 1),
        'disynaptic_inhib_connectivity_pool0_to_pool1': (0, 1),
        'disynaptic_inhib_connectivity_pool1_to_pool0': (1, 0),
    }

    # Iterate once over all keys
    for name, val in all_priors.items():

        # 1) Unit-aware special setters
        if name in _special_setters:
            _special_setters[name](val, init_kwargs, default)
            continue

        # 2) Excitatory baseline: the constructor receives a per-pool list
        if name == 'excitatory_input_baseline':
            try:
                init_kwargs['excitatory_input_baseline'] = [float(val)]
            except TypeError:
                # Not a scalar: the value is already a per-pool sequence (this is
                # what the fixed side provides when the baseline is not inferred).
                init_kwargs['excitatory_input_baseline'] = val
            continue

        # 3) Frequency band of common input (middle & half-width → [low, high])
        if name in ("common_input_high_freq_middle_of_range",
                    "common_input_high_freq_half_width_range"):
            # Start from whatever is already staged (the fixed value, since fixed keys
            # are iterated first), else from the default's array.
            freq = deepcopy(init_kwargs.get("frequency_range_of_common_input",
                                            default.frequency_range_of_common_input))
            pool_idx  = 0  # we only use pool 0 for the posterior-predictive check here
            input_idx = 1  # use the second input as “high-frequency” slot
            # Whichever of the two values is missing falls back to the default's characteristics.
            mid  = float(all_priors.get('common_input_high_freq_middle_of_range',
                                        default.common_input_characteristics["Frequency_middle_of_range"]["pool_0"]["input_1"]))
            half = float(all_priors.get('common_input_high_freq_half_width_range',
                                        default.common_input_characteristics["Frequency_half_width_of_range"]["pool_0"]["input_1"]))
            freq[pool_idx, input_idx, 0] = mid - half
            freq[pool_idx, input_idx, 1] = mid + half
            init_kwargs["frequency_range_of_common_input"] = freq
            continue

        # 4) Common‐input STD per input (we assemble a [pool,input] array)
        if name.startswith("common_input_std_input"):
            pool_idx  = 0
            input_idx = int(name[-1])  # expects ...input0 or ...input1
            common_input_std_to_assign[pool_idx, input_idx] = float(val)
            init_kwargs["common_input_std"] = common_input_std_to_assign
            continue

        # 5) Disynaptic connectivity entries (assemble 2×2)
        if name in disyn_slots:
            disyn[disyn_slots[name]] = float(val)
            disyn_touched = True
            continue

        # 6) Plain assignment only if it's an init=True field
        if name in init_field_names:
            init_kwargs[name] = val
            continue

        # Anything else is ignored quietly.

    # Write back the disynaptic matrix if touched
    if disyn_touched:
        init_kwargs['disynpatic_inhib_connections_desired_MN_MN'] = disyn

    # Always override the output folder
    init_kwargs['output_folder_name'] = batch_name

    # FINAL GUARD: drop any accidental non-init fields before constructing
    init_kwargs = {k: v for k, v in init_kwargs.items() if k in init_field_names}

    # Build the dataclass (runs __post_init__)
    return SimulationParameters(**init_kwargs)

In [None]:
def all_values_filled(d):
    """
    Return True when no value in `d` is None, descending into nested dicts.

    Only dict values are searched recursively; any other value counts as filled
    unless it is None. An empty dict is considered filled.
    """
    return all(
        all_values_filled(value) if isinstance(value, dict) else value is not None
        for value in d.values()
    )

def build_posterior_predictive_check_simulation_parameters_two_pools(
    fixed_parameters,
    free_priors: dict,
    batch_name: str,
    baseline_jitter: float = 0.0,   # like your baseline_excit_input_to_add_to_posterior_sample
):
    """
    Build a SimulationParameters for 2 pools using expanded variable names.

    Values come from `fixed_parameters` (init=False fields dropped), are overridden
    by `free_priors`, and are passed to the SimulationParameters constructor.

    Accepts keys like:
      - excitatory_input_baseline_pool0 / _pool1
      - common_input_high_freq_middle_of_range_pool0/_pool1
      - common_input_high_freq_half_width_range_pool0/_pool1
      - common_input_std_pool0/_pool1        (written to input #1 == 'second' input)
      - disynaptic_inhib_*                    (pool0_to_pool1, pool1_to_pool0, self_connectivity_pool*)
      - any other init=True field name (e.g. between_pool_excitatory_input_correlation)

    Also supports legacy aliases if they sneak in:
      - std_of_second_common_input_pool0/_pool1  -> common_input_std_pool*

    Parameters
    ----------
    fixed_parameters : dataclass instance
        Source of the non-inferred values.
    free_priors : dict
        Parameter name -> scalar value (one posterior draw). Not mutated.
    batch_name : str
        Stored as `output_folder_name` of the returned object.
    baseline_jitter : float
        When > 0, a uniform random offset in [0, baseline_jitter) is added to each
        per-pool excitatory baseline taken from the free side.

    Returns
    -------
    SimulationParameters

    Notes
    -----
    * frequency_range_of_common_input and common_input_std start from the values of
      `fixed_parameters` (class defaults only if the fixed side lacks them); only
      the input-#1 entries supplied by the posterior are overwritten.
    * A frequency band is written for a pool only when both its middle and its
      half-width were supplied.
    * The disynaptic matrix starts from the class default and is only passed on
      when at least one entry was supplied.
    * Unrecognised names are ignored.
    """
    default = SimulationParameters()
    init_field_names = {f.name for f in fields(SimulationParameters) if f.init}
    non_init_field_names = {f.name for f in fields(SimulationParameters) if not f.init}

    # asdict() also returns init=False fields; drop them from the fixed side.
    fixed_params = {
        k: v for k, v in asdict(fixed_parameters).items() if k not in non_init_field_names
    }

    # normalize aliases on free side (be liberal in what we accept)
    free_params = free_priors.copy()
    alias_map = {
        "std_of_second_common_input_pool0": "common_input_std_pool0",
        "std_of_second_common_input_pool1": "common_input_std_pool1",
    }
    for old, new in alias_map.items():
        if old in free_params and new not in free_params:
            free_params[new] = free_params.pop(old)

    # merge (free overrides fixed); fixed keys are iterated first
    all_priors = {**fixed_params, **free_params}

    init_kwargs = {}

    # --- staging for multi-field parameters ---
    # posterior std of the second common input, per pool (applied after the loop)
    std_second_input = {}

    # 2x2 disynaptic matrix and the (row, col) entry addressed by each parameter name
    disyn = deepcopy(default.disynpatic_inhib_connections_desired_MN_MN)
    disyn_touched = False
    disyn_slots = {
        "disynaptic_inhib_self_connectivity_pool0": (0, 0),
        "disynaptic_inhib_self_connectivity_pool1": (1, 1),
        "disynaptic_inhib_connectivity_pool0_to_pool1": (0, 1),
        "disynaptic_inhib_connectivity_pool1_to_pool0": (1, 0),
    }

    # high-frequency band of input #1, per pool; written only when both parts are provided
    hf = {
        0: {"middle_of_range": None, "half_width": None},
        1: {"middle_of_range": None, "half_width": None},
    }

    # iterate once
    for name, val in all_priors.items():
        # 0) unit-aware special setters
        if name in _special_setters:
            _special_setters[name](val, init_kwargs, default)
            continue

        # 1) excitatory baseline per pool
        if name.startswith("excitatory_input_baseline_pool"):
            pool = 0 if name.endswith("pool0") else 1
            # start from the already-staged (fixed) baseline, else the default's
            arr = deepcopy(init_kwargs.get("excitatory_input_baseline", default.excitatory_input_baseline))
            # add optional small jitter like your training script
            jitter = np.random.uniform(0, baseline_jitter) if baseline_jitter > 0 else 0.0
            arr[pool] = float(val) + jitter
            init_kwargs["excitatory_input_baseline"] = arr
            continue

        # 2) high-frequency common input band per pool (input index = 1)
        if name.startswith("common_input_high_freq_middle_of_range_pool"):
            pool = 0 if name.endswith("pool0") else 1
            hf[pool]["middle_of_range"] = float(val)
            continue
        if name.startswith("common_input_high_freq_half_width_range_pool"):
            pool = 0 if name.endswith("pool0") else 1
            hf[pool]["half_width"] = float(val)
            continue

        # 3) common-input std (second input, per pool)
        if name in ("common_input_std_pool0", "common_input_std_pool1"):
            pool = 0 if name.endswith("pool0") else 1
            std_second_input[pool] = float(val)
            continue

        # 4) disynaptic connectivity entries
        if name in disyn_slots:
            disyn[disyn_slots[name]] = float(val)
            disyn_touched = True
            continue

        # 5) plain assignment if it’s a real dataclass field (e.g., between_pool_excitatory_input_correlation)
        if name in init_field_names:
            init_kwargs[name] = val
            continue

        # otherwise ignore silently (keeps it robust to stray keys)

    # Frequency ranges, shape [pool, input, 2]. Seeded from the fixed side so that the
    # bands not supplied by the posterior keep the fixed values instead of reverting
    # to class defaults.
    freq_range = np.array(
        fixed_params.get("frequency_range_of_common_input", default.frequency_range_of_common_input),
        dtype=float,
    )
    for pool in (0, 1):
        if all_values_filled(hf[pool]):
            mid  = hf[pool]["middle_of_range"]
            half = hf[pool]["half_width"]
            # input index 1
            freq_range[pool, 1, 0] = mid - half
            freq_range[pool, 1, 1] = mid + half
    init_kwargs["frequency_range_of_common_input"] = freq_range

    # common_input_std is [pool, input]; seeded from the fixed side for the same
    # reason, then input #1 ("second") is overwritten for the pools that were supplied.
    if std_second_input:
        cis = np.array(
            fixed_params.get("common_input_std", default.common_input_std),
            dtype=float,
        )
        for pool, std in std_second_input.items():
            cis[pool, 1] = std
        init_kwargs["common_input_std"] = cis

    if disyn_touched:
        init_kwargs["disynpatic_inhib_connections_desired_MN_MN"] = disyn

    # output folder
    init_kwargs["output_folder_name"] = batch_name

    # final guard: keep only init=True fields
    init_kwargs = {k: v for k, v in init_kwargs.items() if k in init_field_names}

    return SimulationParameters(**init_kwargs)


In [None]:
# # # Utility/Helper functions to create the new folders for the posterior predictive checks simulations
import re
import numbers

def sanitize_str(s: str) -> str:
    """
    Make `s` usable as a folder-name component.

    Every run of Windows-illegal filename characters (< > : " / \\ | ? *) or
    whitespace is removed (not replaced), then leading/trailing '-' and '.'
    are stripped.
    """
    without_illegal_chars = re.sub(r'[<>:"/\\|?*\s]+', '', s)
    return without_illegal_chars.strip('-.')

def tuple_to_folder_name(tup):
    """
    Turn a tuple such as a (label, intensity) condition key into a folder name.

    Numbers are written without a trailing ".0" when they are integer-valued;
    everything else goes through sanitize_str(str(x)). The parts are joined
    with underscores.
    """
    def _as_part(item):
        if not isinstance(item, numbers.Number):
            return sanitize_str(str(item))
        # integer-valued numbers (e.g. 10.0) are written as "10"
        return str(int(item)) if float(item).is_integer() else str(item)

    return "_".join(_as_part(item) for item in tup)


In [None]:
# Helper used in the cell below if sanity_check_plots == True and muscle_pairs_posterior_predictive_checks == True
def data_limits(X, pad_frac: float = 0.05):
    """
    Compute padded per-column plotting limits.

    X        : (N, D) torch.Tensor or np.ndarray
    pad_frac : fraction of each column's range added on both sides
    returns  : list of [low, high], one per column

    NaNs are ignored when taking the column min/max. A column whose min or max
    is not finite falls back to [0, 1]; a zero-width column is widened (by 1.0
    around 0, otherwise by 0.1 % of its magnitude) before the padding is applied.
    """
    values = X.detach().cpu().numpy() if hasattr(X, "detach") else np.asarray(X)
    column_min = np.nanmin(values, axis=0)
    column_max = np.nanmax(values, axis=0)

    def _padded_interval(low, high):
        if not (np.isfinite(low) and np.isfinite(high)):
            low, high = 0.0, 1.0
        if high == low:  # degenerate -> widen a hair
            half_width = abs(low) * 1e-3 if low != 0 else 1.0
            low, high = low - half_width, high + half_width
        margin = pad_frac * (high - low)
        return [low - margin, high + margin]

    return [_padded_interval(low, high) for low, high in zip(column_min, column_max)]

In [None]:
# Build, for every condition, the SimulationParameters of
#   (a) `n_sims_per_condition` randomly chosen posterior samples, and
#   (b) the highest-likelihood ("best") sample,
# and optionally save a pairplot of the posterior with the simulated draws marked on it.
dict_of_posterior_predictive_sims_param = {}
dict_of_posterior_predictive_sims_params_with_highest_likelihood = {}
sanity_check_plots = True  # True
folder_path_highest_likelihood = f"{path_to_save_posterior_predictive_check_sims}\\_highest_likelihood_sims"
os.makedirs(folder_path_highest_likelihood, exist_ok=True)

# Single-muscle and muscle-pair runs differ only in the builder that turns a
# {parameter name: value} dict into a SimulationParameters.
if muscle_pairs_posterior_predictive_checks:
    build_sim_parameters = build_posterior_predictive_check_simulation_parameters_two_pools
else:
    build_sim_parameters = build_posterior_predictive_check_simulation_parameters

for current_condition, posterior_samples_current in posterior_samples_to_draw_posterior_predictive_checks_sim_param_from['posterior_samples'].items():
    condition_folder_name = tuple_to_folder_name(current_condition)
    current_path = f"{path_to_save_posterior_predictive_check_sims}\\{condition_folder_name}"
    os.makedirs(current_path, exist_ok=True)
    current_path_highest_likelihood = f"{folder_path_highest_likelihood}\\{condition_folder_name}"
    os.makedirs(current_path_highest_likelihood, exist_ok=True)

    print(f"Loading posterior predictive checks parameters (samples on the inferred posterior) for simulation {current_condition}")

    # Keep the original tensor for indexing / .item()
    posterior_samples_tensor = posterior_samples_current

    # NumPy version for sbi.pairplot / data_limits
    posterior_samples_np = to_np(posterior_samples_tensor)

    N = posterior_samples_tensor.shape[0]
    g = torch.Generator().manual_seed(42)  # reproducible; remove for random each run

    # choose indices (without replacement if possible; with replacement otherwise)
    if n_sims_per_condition <= N:
        choice_idx = torch.randperm(N, generator=g)[:n_sims_per_condition]
    else:
        choice_idx = torch.randint(N, (n_sims_per_condition,), generator=g)

    print(
        f"    Total posterior samples: {posterior_samples_tensor.shape} ([samples, n_parameters]).\n"
        f"    {n_sims_per_condition} randomly selected samples will be simulated.\n "
    )

    # Best-parameter combination (highest likelihood), as plain floats
    best_values = [
        param_val.item()
        for param_val in posterior_samples_to_draw_posterior_predictive_checks_sim_param_from['best_samples'][current_condition]
    ]
    free_params_with_highest_likelihood_for_current_condition = {
        variable_parameters[param_idx]: param_val
        for param_idx, param_val in enumerate(best_values)
    }

    dict_of_posterior_predictive_sims_params_with_highest_likelihood[current_condition] = build_sim_parameters(
        fixed_parameters,
        free_params_with_highest_likelihood_for_current_condition,
        current_path_highest_likelihood,
    )
    dict_of_posterior_predictive_sims_params_with_highest_likelihood[current_condition].output_plots = True

    if sanity_check_plots:
        muscle_pair_i = current_condition[0]

        # Single-muscle runs are plotted within the original prior bounds; muscle-pair
        # runs use limits computed from the samples themselves.
        if muscle_pairs_posterior_predictive_checks:
            limits = data_limits(posterior_samples_np, pad_frac=0.05)
        else:
            limits = [[low_original[i].item(), high_original[i].item()] for i in range(n_dim)]

        fig, axes = pairplot(
            posterior_samples_np,
            limits=limits,
            diag_kwargs={"mpl_kwargs": {"color": muscle_colors_dict[muscle_pair_i], "linewidth": 2}},
            upper_kwargs={"mpl_kwargs": {"cmap": muscle_colormaps_dict[muscle_pair_i]}},
            labels=variable_parameters,
            figsize=(3 * posterior_samples_np.shape[1], 3 * posterior_samples_np.shape[1]),
        )
        plt.suptitle(f"{current_condition}\nDraws from posterior for posterior predictive checks simulations")

        # Highlight the highest-likelihood sample exactly once per condition. This must
        # not depend on which sample indices were drawn: the previous test
        # `posterior_sample_draw_idx == 0` only fired when sample #0 happened to be
        # among the randomly selected indices.
        for param_i_idx, param_i_highest_logp_val in enumerate(best_values):
            axes[param_i_idx, param_i_idx].axvline(
                param_i_highest_logp_val, linestyle="-", linewidth=1.5, color="r", alpha=1, zorder=10
            )
            for param_j_idx in range(param_i_idx + 1, len(best_values)):
                axes[param_i_idx, param_j_idx].scatter(
                    x=best_values[param_j_idx],
                    y=param_i_highest_logp_val,
                    s=100,
                    marker="X",
                    facecolor="#FFFFFF",
                    edgecolor="r",
                    linewidth=1,
                    alpha=1,
                    zorder=10,
                )

    # Draw individual posterior samples for simulations
    dict_of_posterior_predictive_sims_param[current_condition] = []
    for posterior_sample_draw_idx in choice_idx.tolist():
        draw_i = posterior_samples_tensor[posterior_sample_draw_idx]  # one torch row
        draw_values = [param_val.item() for param_val in draw_i]      # scalar floats

        free_parameters = {
            variable_parameters[param_idx]: param_val
            for param_idx, param_val in enumerate(draw_values)
        }

        if sanity_check_plots:
            for param_i_idx, param_i_val in enumerate(draw_values):
                # diagonal histograms: vertical line for each draw
                axes[param_i_idx, param_i_idx].axvline(
                    param_i_val, linestyle="-", linewidth=2, color="k", alpha=0.1, zorder=-1
                )
                # upper-triangle panels: one marker per draw
                for param_j_idx in range(param_i_idx + 1, len(draw_values)):
                    axes[param_i_idx, param_j_idx].scatter(
                        x=draw_values[param_j_idx],
                        y=param_i_val,
                        s=30,
                        color="#FFFFFF",
                        edgecolor="k",
                        alpha=0.7,
                    )

        # Fill dict entry iteratively, with plotting enabled for the simulation
        dict_of_posterior_predictive_sims_param[current_condition].append(
            build_sim_parameters(fixed_parameters, free_parameters, current_path)
        )
        dict_of_posterior_predictive_sims_param[current_condition][-1].output_plots = True

    if sanity_check_plots:
        plt.tight_layout()
        plt.savefig(
            f"{path_to_save_posterior_predictive_check_sims}\\{tuple_to_folder_name(current_condition)}_pairplot_of_posterior_samples_simulated.png",
            dpi=300,
        )
        plt.show()

# Flat lists consumed by the parallel simulation cells
flattened_list_of_posterior_predictive_sims_param = [
    item for sublist in dict_of_posterior_predictive_sims_param.values() for item in sublist
]
flattened_list_of_highest_likelihood_predictive_sims_param = list(
    dict_of_posterior_predictive_sims_params_with_highest_likelihood.values()
)


In [None]:
#### PARALLELIZATION

# # # Function to make sure to terminate any Python process that runs in the background (this can happen when the kernel crashes during the parallelized computations)
def kill_other_python_processes():
    """
    Kill the current user's other Python processes (e.g. workers left behind by a
    crashed parallel run).

    A process is killed when its owner is the current user and its executable
    name starts with "python" (python, pythonw, python3, ...). The current
    process and its ancestors are never killed: an ancestor can itself be a
    Python process (for instance the server that launched this kernel), and
    killing it would take the current session down too.

    Processes that vanish or cannot be accessed while iterating are skipped.
    """
    me = os.getpid()
    user = getpass.getuser()

    # PIDs that must survive: ourselves and every process up the parent chain.
    protected_pids = {me}
    try:
        protected_pids.update(parent.pid for parent in psutil.Process(me).parents())
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        pass  # best effort: fall back to protecting only the current process

    for proc in psutil.process_iter(['pid', 'name', 'username']):
        try:
            info = proc.info
            # psutil reports 'DOMAIN\\user' on Windows whereas getpass.getuser() returns
            # the bare login name, so compare only the part after the last backslash.
            # Attributes psutil could not read are None.
            owner = (info['username'] or '').rsplit('\\', 1)[-1]
            if owner != user:
                continue
            # match python, pythonw, python3, etc
            name = (info['name'] or '').lower()
            if not name.startswith('python'):
                continue
            if info['pid'] in protected_pids:
                continue
            proc.kill()   # or proc.terminate()
            print(f"{name} process terminated")
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process exited in the meantime or belongs to a protected context.
            pass
if __name__ == '__main__':
    kill_other_python_processes()
    # now safe to start joblib Parallel(...)

# --- Helper to wrap a single simulation and then reset Brian2 ---
def _run_and_reset(params):
    """
    Run one simulation and re-initialise the Brian2 device before returning.

    params : SimulationParameters instance handed to run_simulation().
    Returns whatever run_simulation() returned.

    The device is re-initialised and re-activated after the run so that the next
    simulation executed by the same worker starts fresh.
    """
    simulation_output = run_simulation(params)
    device.reinit()
    device.activate()
    return simulation_output

# # # PARALLEL SIMULATIONS
def parallel_simulate(params_prior_list, n_jobs=8):
    """
    Run one simulation per entry of `params_prior_list` through joblib.

    params_prior_list : list of SimulationParameters
    n_jobs            : number of parallel workers
    Returns the list of per-simulation outputs produced by joblib's Parallel call.
    """
    # No backend/prefer argument is given, so joblib's default backend is used.
    tasks = (delayed(_run_and_reset)(sim_params) for sim_params in params_prior_list)
    worker_pool = Parallel(n_jobs=n_jobs)
    return worker_pool(tasks)

# State of the (single) background log-tailing thread.
_tail_thread = None
_tail_stop = threading.Event()

# Log-level marker ("INFO: ", "ERROR: ", ...) from which a printed line starts.
_TAIL_LEVEL_RE = re.compile(r'\b(?:DEBUG|INFO|WARNING|ERROR|CRITICAL)\b:\s*')


def _print_tail_line(line, start_dt):
    """Print `line` unless its leading timestamp is older than `start_dt`.

    The first two space-separated tokens are parsed as "%Y-%m-%d %H:%M:%S,%f".
    Lines without such a timestamp are always printed. Everything before the
    log-level marker is stripped when a marker is present.
    """
    try:
        ts_str = " ".join(line.split(" ")[:2]).rstrip(",")
        ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S,%f")
    except ValueError:
        ts = start_dt  # force print for non‐timestamped lines

    if ts < start_dt:
        return

    marker = _TAIL_LEVEL_RE.search(line)
    print(line[marker.start():] if marker else line, end="")


def start_tail(logfile="simulations_progress_log.log", poll_interval=1.0):
    """
    Spawn a daemon thread that prints the lines appended to `logfile` from now on.

    logfile       : path of the log file; created if it does not exist.
    poll_interval : seconds to wait between polls when no new line is available.

    Only lines whose timestamp is ≥ the moment start_tail() was called (or that
    carry no timestamp) are printed, stripped of everything before the log level
    (INFO:, WARNING:, ERROR:, etc.). Only one tail runs at a time: if one is
    already running, a hint is printed and nothing else happens.
    """
    global _tail_thread

    # make sure the file exists (touch it)
    open(logfile, "a").close()

    # remember "now" and clear any previous stop flag
    start_dt = datetime.now()
    _tail_stop.clear()

    if _tail_thread is not None and _tail_thread.is_alive():
        print("Tail already running; call `stop_tail()` first if you want to restart.")
        return

    # Open and seek to the end *here*, not in the thread, so that lines written right
    # after start_tail() returns cannot be skipped by a late seek.
    log_file = open(logfile, "r", encoding="utf-8")
    log_file.seek(0, 2)

    def _tail_loop():
        with log_file:
            while not _tail_stop.is_set():
                line = log_file.readline()
                if line:
                    _print_tail_line(line, start_dt)
                else:
                    # Event.wait() returns as soon as stop_tail() is called, so stopping
                    # does not have to sit out a full poll interval.
                    _tail_stop.wait(poll_interval)
            # Final pass: print whatever was appended but not yet read.
            for line in iter(log_file.readline, ""):
                _print_tail_line(line, start_dt)

    _tail_thread = threading.Thread(target=_tail_loop, daemon=True)
    _tail_thread.start()


def stop_tail():
    """Stop the background tail thread (if any) and wait until it has finished."""
    _tail_stop.set()
    if _tail_thread:
        _tail_thread.join()

In [None]:
# Run parallel simulations - N iterations for each posterior predictive check sample drawn
print(f"Simulating {len(flattened_list_of_posterior_predictive_sims_param)} posterior predictive checks - simulations with parameters drawn from inferred posteriors\n")
stop_tail() # just in case one is already running
start_tail()
try:
    simulation_output_files = parallel_simulate(flattened_list_of_posterior_predictive_sims_param, n_jobs=sim_parallel_cpus)
finally:
    # Stop the log tail even when a simulation raised, so the background thread
    # does not keep printing into the notebook.
    time.sleep(1)  # give it a moment to print the last lines
    stop_tail()

In [None]:
# Run parallel simulations - Best posterior estimate of sim parameters (highest likelihood) for each condition
print(f"Simulating {len(flattened_list_of_highest_likelihood_predictive_sims_param)} sims - one for each condition. Sims are the parameter combinations with the highest likelihood from the posterior\n")
stop_tail() # just in case one is already running
start_tail()
try:
    simulation_output_files = parallel_simulate(flattened_list_of_highest_likelihood_predictive_sims_param, n_jobs=sim_parallel_cpus)
finally:
    # Stop the log tail even when a simulation raised, so the background thread
    # does not keep printing into the notebook.
    time.sleep(1)  # give it a moment to print the last lines
    stop_tail()

################################################################################################################################################

From here, the analyses of the posterior-generated simulations can be performed directly with the notebook "analyze_batch.ipynb".