In [None]:
# Import general libraries
import numpy as np
import pandas as pd
import os
import glob
from pathlib import Path
import pickle
import json
import warnings
import h5py

from simulator import SimulationParameters, run_simulation
from analyzer import AnalyzesParams, analyze_data

In [None]:
# ----------------- configuration -----------------
# Root folder of the simulation batch. Set the RI_SIMULATION_BATCH_DIR environment variable to
# point the notebook at another batch (or another machine) without editing this cell.
parent_folder_to_load_from = os.environ.get(
    'RI_SIMULATION_BATCH_DIR',
    r'C:\Users\franc\Documents\Mapping_Recurrent_Inhibition_minimal_datasets_to_run_scripts\Example_small_simulation_batch',
)

# Which pool pairs to keep in the final table:
#   "all"     -> keep every pool-pair type
#   "between" -> only pairs of two different pools (poolX<->poolY); drop poolX<->poolX
#   "within"  -> only pairs of a pool with itself (poolX<->poolX); drop poolX<->poolY
within_between_or_all_muscle_pairs = "all"
# Fail fast here instead of after the (slow) loading cells.
if within_between_or_all_muscle_pairs not in ("all", "within", "between"):
    raise ValueError(
        f"within_between_or_all_muscle_pairs must be 'all', 'within' or 'between', "
        f"got {within_between_or_all_muscle_pairs!r}"
    )

nb_of_MUs_for_COH = 10  # group size (nb of motor units) for the band-mean coherence column, e.g. 5
coh_band = (15, 35)  # coherence frequency band in Hz (both bounds inclusive)

# Must be True to add the per-MN electrophysiological properties (MNprop_* columns) read from each
# simulation's HDF5 output to the big csv; loading them takes some time.
load_hdf5_sim_output_for_MNs_electrophysiological_properties = True


In [None]:
# ----------------- utilities -----------------
def deep_equal(a, b):
    """
    Recursively compare ``a`` and ``b`` for equality and always return a plain ``bool``.

    Supports nested dicts (same key set, equal values), lists/tuples (same length, element-wise
    equal; a list may equal a tuple), numpy arrays and scalars.

    - If either side is an ``np.ndarray``, both sides are converted with ``np.asarray`` and
      compared by shape and values, so an array equals a list holding the same values. For
      numeric arrays, NaNs at the same positions count as equal.
    - Two float NaN scalars count as equal (plain ``==`` would report a mismatch).
    - Values whose ``==`` does not yield a single truth value (or raises) compare unequal
      instead of raising.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        try:
            arr_a, arr_b = np.asarray(a), np.asarray(b)
        except Exception:  # e.g. ragged nested lists cannot be turned into an array
            return False
        if arr_a.shape != arr_b.shape:
            return False
        try:
            return bool(np.array_equal(arr_a, arr_b, equal_nan=True))
        except TypeError:
            # equal_nan needs np.isnan, which is unsupported for string/object dtypes.
            return bool(np.array_equal(arr_a, arr_b))
    if isinstance(a, dict) and isinstance(b, dict):
        if a.keys() != b.keys():
            return False
        return all(deep_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(deep_equal(x, y) for x, y in zip(a, b))
    if (isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating))
            and np.isnan(a) and np.isnan(b)):
        return True
    try:
        return bool(a == b)
    except (TypeError, ValueError):
        # ValueError: `==` returned an array-like whose truth value is ambiguous.
        return False

def _load_priors(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

def _select_prior_from_files(files, folder_label):
    """
    Given a list of prior-pkl files in one folder, run consistency checks.
    Returns the chosen prior object (first file) and prints/warns as needed.
    """
    files = sorted(files)
    if len(files) == 0:
        return None

    if len(files) == 1:
        return _load_priors(files[0])

    warnings.warn(
        f"Multiple prior‐pkls found in {folder_label!r}; loading all of them and comparing.",
        UserWarning
    )
    priors_list = [_load_priors(p) for p in files]

    compare_keys = [
        'fixed_parameters',
        'free_parameter_bounds',
        'use_posterior_as_priors',
    ]
    extra_keys = [
        'override_free_parameter_bounds_from_posterior',
        'muscle_pair_posterior',
        'intensity_posterior',
        'path_for_posterior_samples',
    ]

    mismatch_found = False

    # compare the standard keys
    for key in compare_keys:
        vals = [d.get(key) for d in priors_list]
        if not all(deep_equal(v, vals[0]) for v in vals[1:]):
            mismatch_found = True
            warnings.warn(
                f"Mismatch in key {key!r} across prior‐pkls in {folder_label!r}:\n"
                + "\n".join(
                    f"  • {os.path.basename(files[i])}: {vals[i]!r}"
                    for i in range(len(vals))
                ),
                UserWarning
            )

    # if all use posterior-as-priors, compare extra keys too
    if all(d.get('use_posterior_as_priors') for d in priors_list):
        for key in extra_keys:
            vals = [d.get(key) for d in priors_list]
            if not all(deep_equal(v, vals[0]) for v in vals[1:]):
                mismatch_found = True
                warnings.warn(
                    f"Mismatch in key {key!r} across prior‐pkls in {folder_label!r}:\n"
                    + "\n".join(
                        f"  • {os.path.basename(files[i])}: {vals[i]!r}"
                        for i in range(len(vals))
                    ),
                    UserWarning
                )

    if not mismatch_found:
        print(f"✅ All prior‐pkls in {folder_label!r} match on compared keys; using the first one.")

    return priors_list[0]

# ----------------- main logic -----------------
# Collect the prior definitions of the batch into `priors`:
#   1) prior pickles ("*prior*.pkl") directly inside the parent folder -> one entry keyed by the
#      parent folder path;
#   2) otherwise, one entry per immediate subfolder containing prior pickles, keyed by the
#      subfolder *name*.
# Raises FileNotFoundError when neither location holds a prior pickle.
priors = {}

# glob.escape: folder names may contain glob metacharacters ('[', ']', '*', '?') which would
# otherwise be interpreted as a pattern and silently match nothing.
pattern_direct = os.path.join(glob.escape(parent_folder_to_load_from), "*prior*.pkl")
matching_files_direct = glob.glob(pattern_direct)

if matching_files_direct:
    chosen = _select_prior_from_files(matching_files_direct, parent_folder_to_load_from)
    if chosen is not None:
        priors[f'{parent_folder_to_load_from}'] = chosen
else:
    # Close the directory handle deterministically and visit subfolders in sorted order so the
    # insertion order of `priors` does not depend on filesystem enumeration order.
    with os.scandir(parent_folder_to_load_from) as entries:
        subfolders = sorted(entry.path for entry in entries if entry.is_dir())

    found_any = False
    for subfolder in subfolders:
        files = glob.glob(os.path.join(glob.escape(subfolder), "*prior*.pkl"))
        if not files:
            continue
        found_any = True
        chosen = _select_prior_from_files(files, subfolder)
        if chosen is not None:
            priors[os.path.basename(subfolder)] = chosen

    if not found_any:
        raise FileNotFoundError(
            f"No .pkl file with 'prior' in its name found in {parent_folder_to_load_from!r} "
            f"or its immediate subfolders."
        )


In [None]:
def _read_motoneuron_properties(h5_path):
    """
    Read the 'motoneurons_properties' group of the HDF5 file at ``h5_path``.

    Returns ``{dataset_name: squeezed np.ndarray}`` for every dataset in the group (non-dataset
    members such as sub-groups are skipped). Returns None when the file is absent, the group is
    missing, or reading fails; the last two cases print a warning.
    """
    if not h5_path.exists():
        return None
    try:
        with h5py.File(h5_path, 'r') as hf:
            grp = hf.get('motoneurons_properties')
            if grp is None:
                print(f"Warning: 'motoneurons_properties' not found in {h5_path}")
                return None
            return {
                name: np.asarray(item[()]).squeeze()
                for name, item in grp.items()
                if isinstance(item, h5py.Dataset)
            }
    except Exception as e:  # best effort: one unreadable file must not abort the batch load
        print(f"Warning: failed reading {h5_path}: {e}")
        return None


def load_analyses(parent_folder, filename='analysis_output.pkl',
                  params_file='sim_parameters.json',
                  recursive=False,
                  load_hdf5_props=False,
                  hdf5_filename='simulation_output.h5'
                  ):
    """
    Load every ``filename`` pickle under ``parent_folder`` together with its side files.

    For each pickle, the JSON file ``params_file`` from the same folder is attached under
    ``data['sim_parameters']``, or None with a printed warning if it is missing. When
    ``load_hdf5_props`` is True, the datasets of group 'motoneurons_properties' in
    ``hdf5_filename`` (same folder) are attached under ``data['motoneurons_properties']`` as
    ``{prop_name: squeezed np.ndarray}``. The value is None when the file/group is absent or
    unreadable.

    Parameters
    ----------
    parent_folder : str or Path
        Root folder of the batch.
    filename : str
        Name of the analysis pickle to look for.
    params_file : str
        Name of the JSON parameters file expected next to each pickle.
    recursive : bool
        True: search the entire hierarchy; False: only the direct subfolders of ``parent_folder``.
    load_hdf5_props : bool
        Whether to also read the motoneuron properties from the HDF5 output.
    hdf5_filename : str
        Name of the HDF5 simulation output expected next to each pickle.

    Returns
    -------
    dict
        ``{key: data}`` where ``key`` is the name of the folder holding the pickle. Folders are
        visited in sorted path order. If two folders share a name (possible with
        ``recursive=True``), the later one is keyed by its path relative to ``parent_folder``
        (POSIX style, e.g. 'b/run_0') instead of overwriting the first one.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code: only point this at trusted simulation outputs.
    """
    parent = Path(parent_folder)
    analysis_output = {}

    # Sorted for a reproducible load (and therefore DataFrame row) order across filesystems.
    if recursive:
        pkl_paths = sorted(parent.rglob(filename))
    else:
        pkl_paths = sorted(
            subdir / filename
            for subdir in parent.iterdir()
            if subdir.is_dir() and (subdir / filename).exists()
        )

    print("Loading analysis outputs...")
    for pkl_path in pkl_paths:
        key = pkl_path.parent.name
        if key in analysis_output:
            unique_key = pkl_path.parent.relative_to(parent).as_posix()
            print(f"Warning: folder name '{key}' already loaded; "
                  f"using '{unique_key}' as key for {pkl_path.parent}")
            key = unique_key

        # 1) load the pickle
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)

        # 2) load the JSON params
        params_path = pkl_path.parent / params_file
        if params_path.exists():
            with open(params_path, 'r', encoding='utf-8') as jf:
                data['sim_parameters'] = json.load(jf)
        else:
            data['sim_parameters'] = None
            print(f"Warning: '{params_file}' not found in {pkl_path.parent}")

        # 3) (optional) load HDF5 motoneuron properties
        if load_hdf5_props:
            data['motoneurons_properties'] = _read_motoneuron_properties(pkl_path.parent / hdf5_filename)

        analysis_output[key] = data

    mode = 'entire hierarchy' if recursive else 'direct subfolders'
    count = len(pkl_paths)
    print(f"Loaded {count} file{'s' if count!=1 else ''} by searching {mode} of '{parent_folder}'")

    return analysis_output


In [None]:
# Load every simulation of the batch (whole folder hierarchy), optionally with the per-MN
# electrophysiological properties stored in each simulation's HDF5 output.
analysis_output = load_analyses(
    parent_folder=parent_folder_to_load_from,
    load_hdf5_props=load_hdf5_sim_output_for_MNs_electrophysiological_properties,
    recursive=True,
)

In [None]:
import pandas as pd
import numpy as np

def build_dataframe(
    analysis_output,
    sim_params_keys=None,
    firing_rate_stats=None,
    connectivity_keys=None,
    graph_measures_keys=None,
    coherence_keys=None,
    include_cross_hist=False,
    # mean coherence options ---
    coherence_band: tuple[float, float] | None = None,   # e.g. (15, 35)
    coherence_group_N: int | None = None,                # e.g. 5 (nb of motor units)
    coherence_source: str = "Coherence_total",           # where to read from in analysis_output['Coherence']
):
    """
    Build a DataFrame of MN rows, optionally exploded by Cross_histograms,
    and include:
      - inhibited_by_ground_truth_receiving
      - inhibiting_ground_truth_delivering
      - asymmetry_ground_truth_diff
      - asymmetry_ground_truth_ratio
      - ISI_cov (new column, pulled from data['Firing_rates']['isi_cov_MN'])
    """
    sim_params_keys     = sim_params_keys     or []
    firing_rate_stats   = firing_rate_stats   or []
    connectivity_keys   = connectivity_keys   or []
    graph_measures_keys = graph_measures_keys or []
    coherence_keys      = coherence_keys      or []

    CROSS_DIRECTIONS = ['inhibited','inhibiting','combined']
    FWD_BWD_MEASURES = ['raw_area','corrected_area','z_score','p_val','null_mean','null_std']
    ASYM_MEASURES    = [
        'raw_area_asym_ratio','corrected_area_asym_ratio',
        'raw_area_asym_diff','corrected_area_asym_diff'
    ]
    OTHER_MEASURES   = [
        'sync_height','sync_time',
        'delay_forward_IPSP','delay_backward_IPSP',
        'r2_full','r2_base','n_spikes',
        'hist_plateau_duration','proportion_of_prob_within_plateau_duration',
    ]

    rows = []
    for sim_name, data in analysis_output.items():
        params      = data.get('sim_parameters', {})
        total_nb    = params.get('total_nb_motoneurons', 0)
        nb_per_pool = params.get('nb_motoneurons_per_pool')
        try:
            nb_per_pool = int(nb_per_pool)
        except Exception:
            nb_per_pool = None

        fr_dict    = data.get('Firing_rates', {}).get('MN', {})        # e.g. {'mean': { 'MN_0': …, … }, …}
        cov_mn_dict = data.get('Firing_rates', {}).get('isi_cov_MN', {})  # e.g. { 'MN_0': 0.31, 'MN_1': NaN, … }
        conn       = data.get('Ground_truth_RI_connectivity', {})
        graph      = data.get('Graph_theory_connectivity_measures', {})
        coh        = data.get('Coherence', {})
        cross      = data.get('Cross_histograms', {})

        per_rec = conn.get('per_pool_received', {})
        per_del = conn.get('per_pool_delivered', {})

        # --- Prepare mean coherence in a band for groups of N MUs ---
        # We’ll compute per (simulation, pool-pair) and cache results to avoid repeating work.
        coh_colname = None
        compute_band_mean = None  # function(pool_key: str) -> float

        if (coherence_band is not None) and (coherence_group_N is not None) and coh:
            fmin, fmax = coherence_band
            coh_block = coh.get(coherence_source, {})
            freqs = np.asarray(coh_block.get('frequencies', []), dtype=float)

            # Build frequency mask safely
            if freqs.size == 0:
                freq_mask = None
            else:
                freq_mask = (freqs >= fmin) & (freqs <= fmax)
                if not np.any(freq_mask):
                    freq_mask = None

            # Prepare output column name once
            coh_colname = f"coherence_{int(fmin)}-{int(fmax)}hz_{int(coherence_group_N)}MN"

            cache = {}
            def compute_band_mean(pool_key_with_dash: str) -> float:
                """
                pool_key_with_dash: e.g. 'pool_0-pool_1' or 'pool_0-pool_0'
                returns mean coherence over [fmin,fmax] for group size N, or np.nan
                """
                if freq_mask is None:
                    return np.nan
                if pool_key_with_dash in cache:
                    return cache[pool_key_with_dash]

                # Try as-is, then reversed order if missing (for between-pool pairs)
                spec = coh_block.get(pool_key_with_dash)
                if spec is None and '-' in pool_key_with_dash:
                    a, b = pool_key_with_dash.split('-', 1)
                    spec = coh_block.get(f"{b}-{a}")

                if spec is None:
                    cache[pool_key_with_dash] = np.nan
                    return np.nan

                arr_by_N = spec.get(int(coherence_group_N))
                if arr_by_N is None:
                    cache[pool_key_with_dash] = np.nan
                    return np.nan

                vals = np.asarray(arr_by_N, dtype=float)
                if vals.shape[0] != freqs.shape[0]:
                    # shape mismatch → safer to bail out
                    cache[pool_key_with_dash] = np.nan
                    return np.nan

                val = float(np.nanmean(vals[freq_mask]))
                cache[pool_key_with_dash] = val
                return val
        # --- end of coherence per sim and pool pair ---


        if include_cross_hist and cross:
            # “explode” each cross‐hist entry
            for pool_pair, neuron_map in cross.items():
                # pool_pair is like "pool_0<->pool_1", so pool_B is the second pool
                if '<->' not in pool_pair:
                    continue # case when cross-histograms are saved: there is another item in the dict that is not of the format 'pool_A<->pool_B'
                _, pool_B = pool_pair.split('<->')

                for neuron_idx, dir_dict in neuron_map.items():
                    # build base row for this MN in this pool_pair
                    recv = per_rec.get(pool_B, np.full(total_nb, np.nan))[neuron_idx] \
                           if pool_B in per_rec else np.nan
                    delv = per_del.get(pool_B, np.full(total_nb, np.nan))[neuron_idx] \
                           if pool_B in per_del else np.nan

                    # pull that MN’s ISI‐cov if it exists, otherwise NaN
                    # the keys in isi_cov_MN dict are strings like 'MN_0', 'MN_1', …
                    cov_key = f"MN_{neuron_idx}"
                    isi_cov_mn = cov_mn_dict.get(cov_key, np.nan)

                    base = {
                        'sim_name':                             sim_name,
                        'MN_index':                             int(neuron_idx),
                        'pool':                                 (neuron_idx // nb_per_pool) if nb_per_pool else np.nan,
                        'idx_within_pool':                      (neuron_idx % nb_per_pool)  if nb_per_pool else np.nan,
                        'inhibited_by_ground_truth_receiving':   recv,
                        'inhibiting_ground_truth_delivering':    delv,
                        'asymmetry_ground_truth_diff':           delv - recv,
                        'asymmetry_ground_truth_ratio':          (delv / recv) if (recv not in (0, np.nan)) else np.nan,
                        'ISI_cov':                               isi_cov_mn
                    }
                    # sim parameters
                    for key in sim_params_keys:
                        base[key] = params.get(key, np.nan)
                    # firing‐rate stats (mean, std, etc.) per MN_i
                    for stat in firing_rate_stats:
                        m = fr_dict.get(stat, {})
                        base[f"Firing_rates_{stat}"] = m.get(f"MN_{neuron_idx}", np.nan)
                    # connectivity (e.g. 'MN_delivered_total', 'MN_received_total', etc.)
                    for ck in connectivity_keys:
                        arr = conn.get(ck)
                        base[ck] = (arr[neuron_idx] if (isinstance(arr,(list,np.ndarray)) and neuron_idx < len(arr)) else np.nan)
                    # graph‐measures
                    for gm in graph_measures_keys:
                        arr = graph.get(gm)
                        base[f"Graph_theory_{gm}"] = (arr[neuron_idx] if (isinstance(arr,(list,np.ndarray)) and neuron_idx < len(arr)) else np.nan)
                    # coherence
                    for ckey in coherence_keys:
                        arr = coh.get(ckey)
                        base[f"Coherence_{ckey}"] = (arr[neuron_idx] if (isinstance(arr,(list,np.ndarray)) and neuron_idx < len(arr)) else np.nan)

                    # per-MN motoneuron properties from HDF5 (if present) ---
                    mn_props = data.get('motoneurons_properties') or {}
                    if isinstance(mn_props, dict) and len(mn_props):
                        for prop_name, arr in mn_props.items():
                            try:
                                base[f"MNprop_{prop_name}"] = float(arr[int(neuron_idx)])
                            except Exception:
                                base[f"MNprop_{prop_name}"] = np.nan

                    # now flatten each direction (inhibited/inhibiting/combined)
                    for direction in CROSS_DIRECTIONS:
                        row = base.copy()
                        row['pool_pair'] = pool_pair
                        row['direction'] = direction
                        dd = dir_dict.get(direction, {})

                        # forward/backward measures
                        for m in FWD_BWD_MEASURES:
                            row[f"{m}_fwd"] = dd.get('forward',{}).get(m, np.nan)
                            row[f"{m}_bwd"] = dd.get('backward',{}).get(m, np.nan)
                        # asymmetry
                        for m in ASYM_MEASURES:
                            row[m] = dd.get('asymmetry',{}).get(m, np.nan)
                        # other scalar measures
                        for m in OTHER_MEASURES:
                            row[m] = dd.get(m, np.nan)
                    
                        # --- Attach mean coherence for this pool_pair (same for all rows of the pair) ---
                        if coh_colname is not None and compute_band_mean is not None:
                            # The coherence keys use '-' (dash) not '<->'
                            pair_dash = pool_pair.replace('<->', '-')
                            row[coh_colname] = compute_band_mean(pair_dash)
                        # --- end mean coherence ---

                        rows.append(row)

        else:
            # No cross‐hist → one row per MN (everything else is NaN for cross‐hist columns)
            for i in range(total_nb):
                # pull that MN’s ISI‐cov if present
                cov_key = f"MN_{i}"
                isi_cov_mn = cov_mn_dict.get(cov_key, np.nan)

                row = {
                    'sim_name':                        sim_name,
                    'MN_index':                        i,
                    'pool':                            (i // nb_per_pool) if nb_per_pool else np.nan,
                    'idx_within_pool':                 (i % nb_per_pool) if nb_per_pool else np.nan,
                    'pool_pair':                       np.nan,
                    'direction':                       np.nan,
                    'inhibited_by_ground_truth_receiving':   np.nan,
                    'inhibiting_ground_truth_delivering':    np.nan,
                    'asymmetry_ground_truth_diff':           np.nan,
                    'asymmetry_ground_truth_ratio':          np.nan,
                    'ISI_cov':                           isi_cov_mn
                }
                # sim parameters
                for key in sim_params_keys:
                    row[key] = params.get(key, np.nan)
                # firing‐rate stats
                for stat in firing_rate_stats:
                    m = fr_dict.get(stat, {})
                    row[f"Firing_rates_{stat}"] = m.get(f"MN_{i}", np.nan)
                # connectivity, graph, coherence
                for ck in connectivity_keys:
                    arr = conn.get(ck)
                    row[ck] = (arr[i] if (isinstance(arr,(list,np.ndarray)) and i < len(arr)) else np.nan)
                for gm in graph_measures_keys:
                    arr = graph.get(gm)
                    row[f"Graph_theory_{gm}"] = (arr[i] if (isinstance(arr,(list,np.ndarray)) and i < len(arr)) else np.nan)
                for ckey in coherence_keys:
                    arr = coh.get(ckey)
                    row[f"Coherence_{ckey}"] = (arr[i] if (isinstance(arr,(list,np.ndarray)) and i < len(arr)) else np.nan)

                # set all forward/backward and asymmetry/other to NaN
                for m in FWD_BWD_MEASURES:
                    row[f"{m}_fwd"] = np.nan
                    row[f"{m}_bwd"] = np.nan
                for m in ASYM_MEASURES + OTHER_MEASURES:
                    row[m] = np.nan

                # --- attach mean coherence for the MN's own pool (within-pool) ---
                if coh_colname is not None and compute_band_mean is not None and nb_per_pool:
                    pool_idx = int(i // nb_per_pool)
                    pool_name = f"pool_{pool_idx}"
                    pair_dash = f"{pool_name}-{pool_name}"  # within-pool coherence
                    row[coh_colname] = compute_band_mean(pair_dash)
                # --- end mean coherence ---

                # per-MN motoneuron properties from HDF5 (if present) ---
                mn_props = data.get('motoneurons_properties') or {}
                if isinstance(mn_props, dict) and len(mn_props):
                    for prop_name, arr in mn_props.items():
                        try:
                            row[f"MNprop_{prop_name}"] = float(arr[int(i)])
                        except Exception:
                            row[f"MNprop_{prop_name}"] = np.nan

                rows.append(row)

    return pd.DataFrame(rows)


In [None]:
# Build the per-MN table: one row per (pool_pair, MN, direction) when cross-histograms exist.
df = build_dataframe(
    analysis_output,
    include_cross_hist=True,
    # values in the parameter space. NOTE(review): the spellings 'disynpatic' / 'synpatic'
    # presumably mirror the keys written to sim_parameters.json; confirm before renaming.
    sim_params_keys=[
        'total_nb_motoneurons',
        'nb_motoneurons_per_pool',
        'nb_pools',
        'mean_soma_diameter',
        'excitatory_input_baseline',
        'common_input_std',
        'disynpatic_inhib_connections_desired_MN_MN',
        'common_input_characteristics',
        'independent_to_common_input_ratio',
        'frequency_range_of_common_input',
        'between_pool_excitatory_input_correlation',
        'synaptic_IPSP_membrane_or_user_defined_time_constant',
        'synaptic_IPSP_decay_time_constant',
        'MN_RC_synpatic_delay',
    ],
    # the ISI CoVs are loaded alongside these firing-rate statistics
    firing_rate_stats=['mean', 'std', 'max', 'min'],
    connectivity_keys=['MN_delivered_total', 'MN_received_total'],
    graph_measures_keys=['density'],
    coherence_keys=None,
    # --- Set the three values below to None to leave mean coherence out of the dataframe ---
    coherence_band=coh_band,
    coherence_group_N=nb_of_MUs_for_COH,
    coherence_source="Coherence_total",
)

In [None]:
def reorganize_dataframe(df):
    """
    Return a copy of ``df`` with direction-independent estimation columns appended.

    Each cross-histogram row was computed from one reference 'direction', so its forward/backward
    areas mean different things from row to row. The new columns give them fixed meanings:

      - perspective: 'inhibited' -> 'other_MUs_as_ref', 'inhibiting' -> 'MU_as_ref',
        'combined' -> 'combined' (anything else -> NaN)
      - inhibited_by_estimation_{raw,corrected}: the forward area on 'inhibited' rows, the
        backward area on all other rows
      - inhibiting_estimation_{raw,corrected}: the backward area on 'inhibited' rows, the
        forward area on all other rows
      - asymmetry_estimation_diff_{raw,corrected}: ``<variant>_area_asym_diff``, sign-flipped on
        'inhibited' rows
      - asymmetry_estimation_ratio_{raw,corrected}: ``<variant>_area_asym_ratio``, inverted (1/x)
        on 'inhibited' rows

    Required input columns: 'direction', '{raw,corrected}_area_{fwd,bwd}' and
    '{raw,corrected}_area_asym_{diff,ratio}'. The input frame is left untouched.
    """
    out = df.copy()

    out['perspective'] = out['direction'].map({
        'inhibited':  'other_MUs_as_ref',
        'inhibiting': 'MU_as_ref',
        'combined':   'combined',
    })

    seen_from_inhibited = out['direction'].eq('inhibited')

    # forward/backward areas -> "inhibited by" / "inhibiting" estimates
    for variant in ('raw', 'corrected'):
        forward = out[f'{variant}_area_fwd']
        backward = out[f'{variant}_area_bwd']
        out[f'inhibited_by_estimation_{variant}'] = np.where(seen_from_inhibited, forward, backward)
        out[f'inhibiting_estimation_{variant}'] = np.where(seen_from_inhibited, backward, forward)

    # asymmetry expressed from the MU's own point of view
    for variant in ('raw', 'corrected'):
        diff = out[f'{variant}_area_asym_diff']
        ratio = out[f'{variant}_area_asym_ratio']
        out[f'asymmetry_estimation_diff_{variant}'] = diff.mask(seen_from_inhibited, -diff)
        out[f'asymmetry_estimation_ratio_{variant}'] = ratio.mask(seen_from_inhibited, 1 / ratio)

    return out


In [None]:
# Start from the explicit-mapping columns of reorganize_dataframe. `df` itself is not modified in
# this cell, so re-running the cell rebuilds df_new from scratch.
df_new = reorganize_dataframe(df)

### Add a new perspective which duplicates the perspective ("MU_as_ref" or "other_MUs_as_ref") which has the most spikes ###
# NOTE(review): as implemented, the max is taken per (sim_name, MN_index, direction), i.e. across
# the MN's pool_pairs *within* each direction. The two perspectives are not compared with each
# other. Each non-combined direction contributes its own 'most_spikes' row(s), several on ties.
# Confirm this matches the intent stated in the header above.

# 1) Max per (sim_name, MN_index, direction), ignoring 'combined'
# Rows with a NaN direction (MNs without cross-histograms) pass the filter, but groupby drops NaN
# keys by default, so they get no group.
group_max = (
    df_new[df_new.perspective != 'combined']
      .groupby(['sim_name','MN_index','direction'])['n_spikes']
      .max()
      .rename('max_noncombined_spikes')
)

# 2) Merge it back (now aligned on direction too)
# Left merge keeps every row. 'combined' rows and rows without a group get NaN as the max.
df_new = df_new.merge(
    group_max,
    on=['sim_name','MN_index','direction'],
    how='left'
)

# 3) Flag the winner _per direction_
# NaN n_spikes never compare equal to the max, so such rows are never flagged.
df_new['perspective_with_most_spikes'] = (
    (df_new.perspective != 'combined')
  & (df_new.n_spikes == df_new.max_noncombined_spikes)
)

# 4) Extract _both_ winners (one per direction), relabel & append them
# The winners are duplicated: the originals keep their perspective, and the appended copies are
# labelled 'most_spikes'.
df_most = df_new[df_new.perspective_with_most_spikes].copy()
df_most['perspective'] = 'most_spikes'

df_new = pd.concat([df_new, df_most], ignore_index=True)

# 5) Drop the helper columns
df_new = df_new.drop(columns=['max_noncombined_spikes','perspective_with_most_spikes'])

# 6) Set the IPSP time constant from membrane τ_m when requested
# For simulations configured with a 'membrane' IPSP time constant, overwrite the decay time
# constant with the MN's membrane time constant read from HDF5 (MNprop_membrane_time_constant).
# NOTE(review): assumes both columns use the same units; confirm.
# If 'synaptic_IPSP_decay_time_constant' is not a column yet, .loc creates it (NaN elsewhere).
if load_hdf5_sim_output_for_MNs_electrophysiological_properties:
    needed = {
        'synaptic_IPSP_membrane_or_user_defined_time_constant',
        'MNprop_membrane_time_constant'
    }
    if needed.issubset(df_new.columns):
        # build a robust mask (handles stray spaces / case)
        mask = (
            df_new['synaptic_IPSP_membrane_or_user_defined_time_constant']
            .astype(str).str.strip().str.lower()
            .eq('membrane')
        )

        # assign only where mask is True; preserves existing values elsewhere
        df_new.loc[mask, 'synaptic_IPSP_decay_time_constant'] = \
            df_new.loc[mask, 'MNprop_membrane_time_constant']
    else:
        missing = needed - set(df_new.columns)
        print(f"[info] Skipping IPSP overwrite: missing columns {sorted(missing)}")

# Check what it looks like
df_new


In [None]:
# Filter according to the desired pool_pair types
def filter_by_pool_pair(df, mode="all"):
    """
    Keep only the rows whose 'pool_pair' matches the requested pair type.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a 'pool_pair' column with values like 'pool_0<->pool_1'.
    mode : {'all', 'within', 'between'}
        'all'     -> return ``df`` unchanged (same object, no copy);
        'within'  -> rows whose two sides are the same pool (poolX<->poolX);
        'between' -> rows whose two sides differ (poolX<->poolY).

    Rows whose 'pool_pair' is missing (NaN, e.g. MNs without cross-histograms) or not of the
    'A<->B' form are neither 'within' nor 'between', so those two modes drop them.

    Returns
    -------
    pandas.DataFrame
        A filtered copy, or ``df`` itself for 'all'.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the accepted values.
    """
    # Validate first so a typo fails with a clear message whatever the data look like.
    if mode not in ("all", "within", "between"):
        raise ValueError(f"Unknown mode {mode!r}, must be 'all'|'within'|'between'")
    if mode == "all":
        return df

    def pair_type(pool_pair):
        # Classify per value: NaN entries would break the `.str` accessor or count as "between".
        if not isinstance(pool_pair, str) or "<->" not in pool_pair:
            return None
        left, right = pool_pair.split("<->", 1)
        return "within" if left == right else "between"

    return df[df["pool_pair"].map(pair_type) == mode].copy()

# Keep only the pool-pair types selected in the configuration cell.
df_new = df_new.pipe(filter_by_pool_pair, mode=within_between_or_all_muscle_pairs)


In [None]:
# os.path.join instead of a hard-coded '\\' separator, so the export also works on Linux/macOS.
df_new.to_csv(os.path.join(parent_folder_to_load_from, "___general_analysis_of_simulations.csv"))

In [None]:
# checking results: keep rows with enough spikes and good regression fits
# (thresholds are hard-coded here; rows with NaN in any of these columns are dropped)
df_filtered = df_new.loc[
    (df_new['n_spikes'] > 10_000)
    & (df_new['r2_base'] > 0.1)
    & (df_new['r2_full'] > 0.75)
].copy()
df_filtered