In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cmasher as cmr
import os
import re
from matplotlib.colors import to_rgb
from scipy.stats import gaussian_kde
from scipy.ndimage import zoom
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import itertools

In [None]:
# Root folder of the saved SBI density-estimator outputs. This is a machine-specific
# absolute path: set the SBI_DENSITY_ESTIMATORS_DIR environment variable to point
# the notebook somewhere else without editing code.
DENSITY_ESTIMATORS_DIR = os.environ.get(
    "SBI_DENSITY_ESTIMATORS_DIR",
    "C:\\Users\\franc\\Documents\\GitHub\\Mapping_Recurrent_Inhibition"
    "\\Simulation_based_inference\\saved_posterior_density_estimators",
)

# One entry per posterior source. "path" is the folder holding that source's
# posterior CSV files; the two "posterior_*" slots start out empty (None).
all_files = {
    "single_muscles": {
        "path": f"{DENSITY_ESTIMATORS_DIR}\\single_muscles_density_estimator",
        "posterior_by_subject": None,
        "posterior_subjects_aggregated": None,
    },
    "muscle_pairs": {
        "path": f"{DENSITY_ESTIMATORS_DIR}\\paired_muscles_density_estimator",
        "posterior_by_subject": None,
        "posterior_subjects_aggregated": None,
    },
}

# Output folder for every figure, relative to the notebook's working directory.
# makedirs(exist_ok=True) keeps the cell re-runnable and creates missing parents.
save_path = "Posterior inference output figures"
os.makedirs(save_path, exist_ok=True)

# Label convention for between-muscle pairs: "inhibited" turns 'A<->B' into
# 'A<-B', "inhibiting" into 'A->B' (applied via _read_and_tag_csv / _derive_muscle_label).
direction = "inhibited"  # or "inhibiting"

# Posterior-sample CSV names looked up inside every all_files[...]["path"] folder.
posterior_by_subject_filename = "posterior_samples_each_subject_df.csv"
posterior_subjects_aggregated_filename = "posterior_samples_subjects_grouped_df.csv"

# Canonical parameter names. The "disynpatic" spelling is kept verbatim because
# these strings are used as column names (presumably matching the CSV headers).
dims = [
    "disynpatic_inhib_connections_desired_MN_MN",
    "common_input_std",
    "excitatory_input_baseline",
    "between_pool_excitatory_input_correlation",
]

# Alternative column names per canonical name: every dim lists itself, except the
# inhibition parameter, whose alternative column is the "_other_pool" variant.
dim_aliases = {dim: [dim] for dim in dims}
dim_aliases["disynpatic_inhib_connections_desired_MN_MN"] = [
    "disynpatic_inhib_connections_desired_MN_MN_other_pool",
]

# Default pair of parameters to display: the first two canonical dims.
dims_to_display = dims[:2]

# Prior bounds (low, high) per column, used to rescale values to [0, 1] relative
# to the priors. The "_other_pool" alias gets its own entry because
# normalize_columns_inplace matches these keys against column names directly.
lims_by_dim = {
    "disynpatic_inhib_connections_desired_MN_MN": (0, 3),
    "common_input_std": (0, 7_000.0),
    "excitatory_input_baseline": (20_000.0, 70_000.0),
    "disynpatic_inhib_connections_desired_MN_MN_other_pool": (0, 3),
    "between_pool_excitatory_input_correlation": (0, 1),
}

####  DISPLAY SETTINGS  ####

# Colour (hex) for every label that may appear in the "muscle" or "muscle_pair"
# columns. multi_muscle_pairplot_scatter_kde only draws hue values that have an
# entry here, so a label missing from this dict is silently left out of plots.
colors_dict = {
    # Raw pair labels as stored in the "muscle_pair" column ('A<->B').
    "VL<->VL": "#D62728",
    "VL<->VM": "#FF5101", # "#FF9201",
    # "VM<->VL": "#FFC400", # "#FF9201",
    "VM<->VM": "#FFC400",
    "TA<->TA": "#00C71B",
    "FDI<->FDI": "#14BFA8",
    "GM<->GM": "#2489DC",
    "GM<->SOL": "#5243FF", # "#7D74EC",
    # "SOL<->GM": "#BB86ED", # "#7D74EC",
    "SOL<->SOL": "#BB86ED",

    # Single-muscle labels ('A<->A' collapses to 'A' in _derive_muscle_label).
    "VL": "#D62728",
    "VM": "#FFC400",
    "TA": "#00C71B",
    "FDI": "#14BFA8",
    "GM": "#2489DC",
    "SOL": "#BB86ED",

    # Directed between-muscle labels. The enabled entries ("GM<-SOL", "VL<-VM")
    # are what _derive_muscle_label produces for "GM<->SOL" / "VL<->VM" with
    # direction="inhibited"; the other directions are kept as commented presets
    # and would need enabling when switching `direction`.
    "GM<-SOL": "#5243FF", # "#7D74EC",
    "GM inhibited by SOL": "#5243FF", # "#7D74EC",
    # "SOL<-GM": "#5243FF", # "#7D74EC",
    # "SOL inhibited by GM": "#5243FF", # "#7D74EC",
    # "GM->SOL": "#5243FF", # "#7D74EC",
    # "GM inhibiting SOL": "#5243FF", # "#7D74EC",
    # "SOL->GM": "#5243FF", # "#7D74EC",
    # "SOL inhibiting GM": "#5243FF", # "#7D74EC",

    "VL<-VM": "#FF5101", # "#FF9201",
    "VL inhibited by VM": "#FF5101", # "#FF9201",
    # "VM<-VL": "#FF5101", # "#FF9201",
    # "VM inhibited by VL": "#FF5101", # "#FF9201",
    # "VL->VM": "#FF5101", # "#FF9201",
    # "VL inhibiting VM": "#FF5101", # "#FF9201",
    # "VM->VL": "#FF5101", # "#FF9201",
    # "VM inhibiting VL": "#FF5101", # "#FF9201",
}

# Maps every label variant (single 'A', 'A<->A', 'A<-A', 'A->A', and directed
# between-muscle labels) to a human-readable display name.
# NOTE(review): the insertion order here presumably also sets the axis order of
# the p-delta heatmaps — verify against the heatmap code before reordering keys.
order_display = { # Only used for p delta heatmap plots
    "VL": "VL",
    "VL<->VL": "VL",
    "VL<-VL": "VL",
    "VL->VL": "VL",

    "VM<-VL": "VM inhibited by VL",
    "VL->VM": "VL inhibiting VM",

    "VM->VL": "VM inhibiting VL",
    "VL<-VM": "VL inhibited by VM",

    "VM": "VM",
    "VM<->VM": "VM",
    "VM<-VM": "VM",
    "VM->VM": "VM",

    "TA": "TA",
    "TA<->TA": "TA",
    "TA<-TA": "TA",
    "TA->TA": "TA",

    "FDI": "FDI",
    "FDI<->FDI": "FDI",
    "FDI<-FDI": "FDI",
    "FDI->FDI": "FDI",

    "GM": "GM",
    "GM<->GM": "GM",
    "GM<-GM": "GM",
    "GM->GM": "GM",

    "SOL<-GM": "SOL inhibited by GM",
    "GM->SOL": "GM inhibiting SOL",

    "GM<-SOL": "GM inhibited by SOL",
    "SOL->GM": "SOL inhibiting GM",

    "SOL": "SOL",
    "SOL<->SOL": "SOL",
    "SOL<-SOL": "SOL",
    "SOL->SOL": "SOL",
}


In [None]:
# -----------------------------------------
# Helpers
# -----------------------------------------
def _derive_muscle_label(muscle_pair: str, direction: str) -> str:
    """
    from 'A<->B' and a direction ('inhibited' or 'inhibiting'),
    return:
      - 'A' if A == B
      - 'A<-B' if direction == 'inhibited' and A != B
      - 'A->B' if direction == 'inhibiting' and A != B
    """
    if not isinstance(muscle_pair, str) or "<->" not in muscle_pair:
        return str(muscle_pair)

    left, right = muscle_pair.split("<->", 1)
    left, right = left.strip(), right.strip()

    if left == right:
        return left
    if direction == "inhibited":
        return f"{left}<-{right}"
    elif direction == "inhibiting":
        return f"{left}->{right}"
    else:
        # fallback: keep original
        return f"{left}<->{right}"

def _read_and_tag_csv(path, fname, source_key, level, direction):
    """
    Load one posterior-samples CSV and add provenance/identity columns.

    Columns added (in this order): source_key, level, direction (the arguments,
    verbatim); subject = "ALL" if the file has none; muscle_pair = NaN if the
    file has none; muscle (label from muscle_pair via _derive_muscle_label);
    scope ("single" for 'A<->A', "pair" for 'A<->B', "unknown" otherwise).
    An existing "intensity" column is coerced to numeric (unparseable -> NaN).

    Returns the DataFrame, or None (after printing a warning) if the file is missing.
    """
    csv_file = os.path.join(path, fname)
    if not os.path.exists(csv_file):
        print(f"Warning: {csv_file!r} not found, skipping")
        return None

    table = pd.read_csv(csv_file)

    # Provenance tags.
    table["source_key"] = source_key
    table["level"] = level
    table["direction"] = direction

    # Identity columns, with fallbacks when the file lacks them.
    if "subject" not in table.columns:
        table["subject"] = "ALL"
    if "intensity" in table.columns:
        table["intensity"] = pd.to_numeric(table["intensity"], errors="coerce")
    if "muscle_pair" not in table.columns:
        table["muscle_pair"] = np.nan

    def _scope_from_pair(pair):
        # "single" when both sides name the same muscle, "pair" otherwise.
        if not (isinstance(pair, str) and "<->" in pair):
            return "unknown"
        left, _, right = pair.partition("<->")
        return "single" if left.strip() == right.strip() else "pair"

    # astype(str) turns NaN into the string "nan", which maps to muscle "nan"
    # and scope "unknown".
    pair_strings = table["muscle_pair"].astype(str)
    table["muscle"] = pair_strings.apply(lambda pair: _derive_muscle_label(pair, direction))
    table["scope"] = pair_strings.apply(_scope_from_pair)
    return table

def normalize_columns_inplace(df, lims_by_dim, suffix=None, clip=True):
    """
    Rescale selected columns of `df` to the unit interval, in place.

    For every (col, (lo, hi)) in `lims_by_dim` whose column exists in `df`,
    computes (df[col] - lo) / (hi - lo), clipped to [0, 1] when `clip` is true.
    Bounds with a None end or zero width are skipped. The result overwrites
    `col`, or is written to `col + suffix` when `suffix` is given (e.g. '_unit').
    Returns the same (mutated) `df`.
    """
    for col, (lo, hi) in lims_by_dim.items():
        if col not in df.columns:
            continue
        if lo is None or hi is None:
            continue
        width = hi - lo
        if width == 0:
            continue

        scaled = (df[col].astype(float) - lo) / width
        if clip:
            scaled = scaled.clip(0.0, 1.0)

        target = col + suffix if suffix else col
        df[target] = scaled
    return df

# -----------------------------------------
# Build a single unified DF
# -----------------------------------------
# Each source folder may hold two CSVs: per-subject posteriors and posteriors
# pooled across subjects. Both are loaded and tagged by _read_and_tag_csv.
posterior_csvs_by_level = (
    ("by_subject", posterior_by_subject_filename),
    ("subjects_aggregated", posterior_subjects_aggregated_filename),
)

frames_all = []
for key_i, val_i in all_files.items():
    path = val_i.get("path")
    if not path:
        continue
    for level_i, fname_i in posterior_csvs_by_level:
        df_level = _read_and_tag_csv(path, fname_i, source_key=key_i,
                                     level=level_i, direction=direction)
        if df_level is not None:
            frames_all.append(df_level)

if not frames_all:
    # Without this, an empty table only fails later with an opaque KeyError on
    # the value_counts() summary below.
    raise FileNotFoundError(
        "No posterior CSV files were found under the 'path' entries of `all_files`; "
        "check those folders and the posterior_*_filename settings."
    )

df_all_posteriors = pd.concat(frames_all, ignore_index=True, sort=False)

# Rescale the prior-bounded parameters to [0, 1], overwriting the originals.
# Pass suffix="_unit" instead of None to keep the raw values alongside.
normalize_columns_inplace(df_all_posteriors, lims_by_dim, suffix=None, clip=True)

# Put identifying/meta columns first, then the remaining columns in file order.
meta_first = ["source_key", "level", "scope", "direction", "subject", "muscle", "muscle_pair", "intensity"]
cols = [c for c in meta_first if c in df_all_posteriors.columns] + \
       [c for c in df_all_posteriors.columns if c not in meta_first]
df_all_posteriors = df_all_posteriors[cols]

print(df_all_posteriors.shape, "rows x cols")
print(df_all_posteriors[["source_key", "level", "scope"]].value_counts())

In [None]:
# Quick look at the unified table: first rows only (full shape is printed above).
df_all_posteriors.head()

In [None]:
# Split the unified table: posteriors pooled across subjects vs. everything else
# (per-subject posteriors).
_is_pooled = df_all_posteriors["level"] == "subjects_aggregated"
df_pooled_subjects = df_all_posteriors[_is_pooled]
df_each_subject = df_all_posteriors[~_is_pooled]

# Pairplots (2-D joint densities + 1-D marginal densities)

In [None]:
def multi_muscle_pairplot_scatter_kde(
    df: pd.DataFrame,
    *,
    # Column harmonization:
    dim_aliases: dict[str, list[str]],
    dims_to_display: list[str],          # EXACTLY [X_dim, Y_dim] → X is bottom-right, Y is top-left
    # Coloring & grouping:
    hue_by: str,                          # e.g., "muscle"
    muscle_colors: dict[str, str],        # map hue value -> hex
    stratify_by: list[str] | None = None, # e.g., ["intensity"]
    # Prior bounds for normalization:
    lims_by_dim: dict[str, tuple] | None = None,
    normalize_to_unit: bool = True,
    assume_pre_normalized: bool = False,
    # Style knobs:
    show_scatter: bool = True,            # False → density-only view with filled contours
    point_size: float = 26,
    scatter_alpha: float = 0.30,
    kde1d_lw: float = 2.2,
    kde2d_levels: tuple = (0.3, 0.6, 0.9),  # fractions of each hue's peak density (NOT mass levels)
    kde2d_lw: float = 1.6,
    kde2d_alpha: float = 1.0,
    kde2d_grid: int = 220,
    darker_mix: float = 0.5,              # 0..1: 0 = original color, 1 = black (for contour lines)
    # Smoothing controls:
    kde1d_bw: float | str | None = "scott",  # gaussian_kde bw_method for the 1-D marginals
    kde2d_bw: float | str | None = "scott",  # gaussian_kde bw_method for the 2-D joint KDE
    # Density-only fills (when show_scatter=False):
    kde2d_fill_alpha: float = 0.10,       # alpha of each stacked fill pass
    # Axes control (optional):
    xlim: tuple | None = None,
    ylim: tuple | None = None,
    # IO:
    save_dir: str = ".",
    filename_suffix: str = "posterior_pairs",
    dpi: int = 300,
    show_legend: bool = True,
):
    """
    Draw one 2×2 pairplot per stratum of `df`, save each as SVG, and show it.

    Panels:
      TL: 1-D KDE of the Y dimension, one curve per hue value
      TR: joint view per hue value — optional scatter plus 2-D KDE contour lines
          (hue colour mixed toward black by `darker_mix`); with
          show_scatter=False the contour regions are also filled, with denser
          (inner) regions more opaque
      BL: legend (one proxy handle per hue value)
      BR: 1-D KDE of the X dimension, one curve per hue value

    Parameters (selected)
    ---------------------
    df : long-format table, one posterior sample per row.
    dim_aliases : {canonical_name: [alias, ...]}. For each canonical dim, each
        row takes the first non-NaN of [canonical, *aliases] among the columns
        present (all-NaN if none is present).
    dims_to_display : exactly [X_dim, Y_dim] (canonical names or aliases).
    hue_by, muscle_colors : column to colour by, and {hue value: colour}.
        Hue values missing from `muscle_colors` are not drawn.
    stratify_by : columns splitting the data into separate figures
        (None → a single figure titled "All").
    lims_by_dim, normalize_to_unit, assume_pre_normalized : if
        normalize_to_unit is True and a displayed dim has bounds (lo, hi),
        values are rescaled by (v - lo) / (hi - lo) (no rescaling when
        assume_pre_normalized) and the axis spans [0, 1]. Otherwise raw values
        are plotted on an axis spanning the 1st–99th percentile of all rows,
        padded by 5 % of that span.
    kde2d_levels : contour heights as fractions of each hue's own peak density
        (the gridded density is min–max rescaled to [0, 1]). These are NOT
        highest-density-region probability-mass levels.
    kde1d_bw, kde2d_bw : forwarded as `bw_method` to scipy.stats.gaussian_kde
        ("scott", "silverman", a scalar factor, or a callable).
    xlim, ylim : explicit axis limits overriding the automatic ones.

    Returns
    -------
    list[str] : paths of the SVG files written, one per stratum. Each figure
    is closed after being shown so long stratification loops do not
    accumulate open figures.

    Raises
    ------
    ValueError : dims_to_display does not have exactly two entries.
    KeyError : a requested dim is absent after alias resolution.
    """
    import os, re
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.colors import to_rgb
    from matplotlib.gridspec import GridSpec
    from matplotlib.lines import Line2D
    from scipy.stats import gaussian_kde

    os.makedirs(save_dir, exist_ok=True)

    # ---------- helpers ----------
    def _build_alias_map(dim_aliases):
        """Map each canonical name, and each of its aliases, to the canonical name."""
        alias_to_canon = {}
        for canon, aliases in dim_aliases.items():
            alias_to_canon[canon] = canon
            for a in aliases:
                alias_to_canon[a] = canon
        return alias_to_canon

    def _alias_columns_to_canonical(df, dim_aliases):
        """Copy of df with one column per canonical name: per row, the first
        non-NaN among [canonical, *aliases] present; all-NaN if none present."""
        df = df.copy()
        searchlists = {canon: [canon] + list(aliases)
                       for canon, aliases in dim_aliases.items()}
        for canon, search in searchlists.items():
            present = [col for col in search if col in df.columns]
            if not present:
                df[canon] = np.nan
            else:
                df[canon] = df[present].bfill(axis=1).iloc[:, 0]
        return df

    def _robust_limits(arr, pad=0.05):
        """1st–99th percentile range of the finite values, padded by pad × span;
        (0, 1) when no finite values exist."""
        arr = np.asarray(arr, float)
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return (0.0, 1.0)
        lo = float(np.nanpercentile(arr, 1))
        hi = float(np.nanpercentile(arr, 99))
        if not np.isfinite(lo) or not np.isfinite(hi):
            lo, hi = np.nanmin(arr), np.nanmax(arr)
        if lo == hi:
            lo, hi = lo - 0.5, hi + 0.5
        span = hi - lo
        return (lo - pad * span, hi + pad * span)

    def _safe_kde_1d(x, grid, bw_method=None):
        """1-D KDE evaluated on `grid`; None for < 3 finite points, zero spread,
        or a KDE failure (e.g. singular data)."""
        x = np.asarray(x, float)
        x = x[np.isfinite(x)]
        if x.size < 3 or np.nanstd(x) == 0:
            return None
        try:
            kde = gaussian_kde(x, bw_method=bw_method)
            return kde(grid)
        except Exception:
            return None

    def _darker(hexcol: str, mix: float = 0.5):
        """Blend a colour toward black by `mix` (clipped to [0, 1])."""
        r, g, b = to_rgb(hexcol)
        m = np.clip(float(mix), 0.0, 1.0)
        return (r*(1-m), g*(1-m), b*(1-m))

    def _kde2d_norm(X, Y, xlim, ylim, grid_n=220, bw_method=None):
        """Joint KDE on a grid_n × grid_n mesh over xlim × ylim, min–max rescaled
        to [0, 1]; (None, None, None) when it cannot be computed."""
        X = np.asarray(X, float); Y = np.asarray(Y, float)
        mask = np.isfinite(X) & np.isfinite(Y)
        X, Y = X[mask], Y[mask]
        if X.size < 3 or np.nanstd(X) == 0 or np.nanstd(Y) == 0:
            return None, None, None
        gx = np.linspace(xlim[0], xlim[1], int(grid_n))
        gy = np.linspace(ylim[0], ylim[1], int(grid_n))
        XX, YY = np.meshgrid(gx, gy)
        try:
            kde = gaussian_kde(np.vstack([X, Y]), bw_method=bw_method)
            dens = kde(np.vstack([XX.ravel(), YY.ravel()])).reshape(XX.shape)
        except Exception:
            return None, None, None
        dmin, dmax = float(np.nanmin(dens)), float(np.nanmax(dens))
        if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
            return None, None, None
        dens_norm = (dens - dmin) / (dmax - dmin)
        return XX, YY, dens_norm

    def _pretty_group_title(keys, vals):
        """Figure title such as 'intensity=20' (or 'All' without stratification)."""
        if not keys: return "All"
        if isinstance(vals, tuple): return ", ".join(f"{k}={v!s}" for k, v in zip(keys, vals))
        return f"{keys[0]}={vals!s}"

    def _fname_id(keys, vals):
        """Filesystem-safe stratum id, e.g. 'intensity-20.0'."""
        if not keys: return "all"
        if not isinstance(vals, tuple): vals = (vals,)
        raw = "_".join(f"{k}-{v}" for k, v in zip(keys, vals))
        return re.sub(r"[^\w\-\.]", "_", raw)

    def _make_square(ax):
        # Axes.set_box_aspect only exists in Matplotlib >= 3.3; on older versions
        # keep the default aspect instead of failing (other errors still surface).
        try:
            ax.set_box_aspect(1)
        except AttributeError:
            pass

    # ---------- dims & aliasing ----------
    if len(dims_to_display) != 2:
        raise ValueError("dims_to_display must be exactly [X_dim, Y_dim].")
    X_req, Y_req = dims_to_display  # explicit: first is X (BR), second is Y (TL)

    lims_by_dim = lims_by_dim or {}
    alias_to_canon = _build_alias_map(dim_aliases)

    # Prior bounds re-keyed by canonical name (an alias key maps onto its canonical dim).
    bounds_by_canon = {}
    for k, v in lims_by_dim.items():
        bounds_by_canon[alias_to_canon.get(k, k)] = tuple(v)

    df_c = _alias_columns_to_canonical(df, dim_aliases)

    xdim_canon = alias_to_canon.get(X_req, X_req)
    ydim_canon = alias_to_canon.get(Y_req, Y_req)
    if xdim_canon not in df_c.columns or ydim_canon not in df_c.columns:
        raise KeyError(f"Requested dims not found after aliasing: {xdim_canon}, {ydim_canon}")

    # Build plot columns (normalized or raw) together with default axis limits.
    def _prep_plot_col(d):
        col = df_c[d].to_numpy(dtype=float)
        if normalize_to_unit and (d in bounds_by_canon):
            if assume_pre_normalized:
                vals = col; lims = (0.0, 1.0)
            else:
                lo, hi = bounds_by_canon[d]
                span = (hi - lo) if (hi > lo) else 1.0
                vals = (col - lo) / span
                lims = (0.0, 1.0)
            return vals, lims
        else:
            vals = col
            lims = _robust_limits(vals)
            return vals, lims

    xvals_all, xlims_def = _prep_plot_col(xdim_canon)
    yvals_all, ylims_def = _prep_plot_col(ydim_canon)

    # Limits come from ALL rows, so every stratum's figure shares the same axes.
    xlim_use = xlims_def if xlim is None else xlim
    ylim_use = ylims_def if ylim is None else ylim

    # Grouping
    if stratify_by:
        grouped = df_c.assign(__x=xvals_all, __y=yvals_all).groupby(stratify_by, dropna=False)
    else:
        grouped = [((), df_c.assign(__x=xvals_all, __y=yvals_all))]

    out_paths = []

    # ---------- plotting per group ----------
    for group_key, sub in grouped:
        fig = plt.figure(figsize=(8.4, 8.4), constrained_layout=False)
        gs  = GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1],
                       wspace=0.07, hspace=0.07, figure=fig)
        ax_tl = fig.add_subplot(gs[0, 0])  # Y marginal (KDE)
        ax_tr = fig.add_subplot(gs[0, 1])  # joint
        ax_bl = fig.add_subplot(gs[1, 0])  # legend
        ax_br = fig.add_subplot(gs[1, 1])  # X marginal (KDE)

        # Only hue values with a configured colour are drawn.
        hue_vals = [hv for hv in sub[hue_by].dropna().unique() if hv in muscle_colors]

        # Per-hue (x, y) samples, extracted once and reused by all three panels.
        per_hue = {}
        for hv in hue_vals:
            rows = sub[hue_by] == hv
            per_hue[hv] = (sub.loc[rows, "__x"].to_numpy(dtype=float),
                           sub.loc[rows, "__y"].to_numpy(dtype=float))

        gx = np.linspace(xlim_use[0], xlim_use[1], 300)
        gy = np.linspace(ylim_use[0], ylim_use[1], 300)

        # ---- TL: Y marginal KDEs ----
        for hv in hue_vals:
            dens = _safe_kde_1d(per_hue[hv][1], gy, bw_method=kde1d_bw)
            if dens is not None:
                ax_tl.plot(gy, dens, color=muscle_colors[hv], lw=kde1d_lw, label=str(hv))
        ax_tl.set_xlim(ylim_use[0], ylim_use[1])
        ax_tl.set_xlabel(ydim_canon); ax_tl.set_ylabel("density")
        _make_square(ax_tl)

        # ---- BR: X marginal KDEs ----
        for hv in hue_vals:
            dens = _safe_kde_1d(per_hue[hv][0], gx, bw_method=kde1d_bw)
            if dens is not None:
                ax_br.plot(gx, dens, color=muscle_colors[hv], lw=kde1d_lw, label=str(hv))
        ax_br.set_xlim(xlim_use[0], xlim_use[1])
        ax_br.set_xlabel(xdim_canon); ax_br.set_ylabel("density")
        _make_square(ax_br)

        # ---- TR: joint (scatter + 2D KDE) ----
        levels_sorted_inc = tuple(sorted(kde2d_levels))
        levels_sorted_dec = tuple(sorted(kde2d_levels, reverse=True))

        for hv in hue_vals:
            col = muscle_colors[hv]
            x, y = per_hue[hv]
            m = np.isfinite(x) & np.isfinite(y)
            if not np.any(m):
                continue
            if show_scatter:
                ax_tr.scatter(x[m], y[m], s=point_size, c=col, edgecolor='none', alpha=scatter_alpha, label=str(hv))
            XX, YY, densN = _kde2d_norm(x[m], y[m], xlim_use, ylim_use, grid_n=kde2d_grid, bw_method=kde2d_bw)
            if densN is None:
                continue
            if not show_scatter:
                # The region above the i-th highest level receives i+1 translucent
                # passes; nested regions accumulate passes, so denser (inner)
                # regions end up more opaque. 1.01 caps above the max of 1.0.
                for i, lev in enumerate(levels_sorted_dec):
                    for _ in range(i+1):
                        ax_tr.contourf(XX, YY, densN, levels=[lev, 1.01],
                                       colors=[col], alpha=kde2d_fill_alpha, antialiased=True)
            ax_tr.contour(XX, YY, densN, levels=levels_sorted_inc, colors=[_darker(col, darker_mix)],
                          linewidths=kde2d_lw, alpha=kde2d_alpha)

        ax_tr.set_xlim(xlim_use[0], xlim_use[1]); ax_tr.set_ylim(ylim_use[0], ylim_use[1])
        ax_tr.set_xlabel(xdim_canon); ax_tr.set_ylabel(ydim_canon)
        _make_square(ax_tr)

        # ---- Legend (proxy artists, so it works without scatter too) ----
        ax_bl.axis("off")
        if show_legend and hue_vals:
            proxies, labels = [], []
            for hv in hue_vals:
                col = muscle_colors[hv]
                if show_scatter:
                    h = Line2D([0], [0], color=col, lw=2.0,
                               marker='o', markersize=6, markerfacecolor=col, markeredgecolor='none')
                else:
                    h = Line2D([0], [0], color=col, lw=2.0)
                proxies.append(h); labels.append(str(hv))
            if proxies:
                ax_bl.legend(proxies, labels, frameon=False, loc="center", ncol=1)

        title = _pretty_group_title(stratify_by or [], group_key)
        fig.suptitle(title, y=0.98, fontsize=12)

        fname_id = _fname_id(stratify_by or [], group_key)
        out_svg = os.path.join(save_dir, f"{filename_suffix}_{fname_id}.svg")
        fig.savefig(out_svg, dpi=dpi)
        plt.show()
        # Release the figure: many strata/subjects would otherwise pile up open figures.
        plt.close(fig)
        out_paths.append(out_svg)

    return out_paths


### Subjects pooled (what is reported in the paper)

In [None]:
# Pooled-subject posteriors: one figure per stimulation intensity.
# The parameter columns of df_all_posteriors were already rescaled in place by
# normalize_columns_inplace, so no further rescaling is requested here.
out_paths = multi_muscle_pairplot_scatter_kde(
    df=df_pooled_subjects,
    # --- what to plot ---
    dim_aliases=dim_aliases,
    dims_to_display=["common_input_std", "disynpatic_inhib_connections_desired_MN_MN"],  # [X, Y]
    hue_by="muscle",
    muscle_colors=colors_dict,
    stratify_by=["intensity"],
    # --- scaling ---
    lims_by_dim=lims_by_dim,
    normalize_to_unit=False,
    assume_pre_normalized=False,
    # --- density-only view: filled 2-D KDE regions, no points ---
    # (point_size / scatter_alpha only take effect when show_scatter=True)
    show_scatter=False,
    point_size=20,
    scatter_alpha=0.01,
    kde1d_lw=2,
    kde2d_levels=(0.3, 0.6, 0.9),
    kde2d_lw=1.4,
    kde2d_alpha=1.0,
    kde2d_grid=220,
    darker_mix=0,
    kde2d_fill_alpha=0.1,
    # --- smoothing: scalar gaussian_kde bandwidth factors (larger = smoother;
    # "scott" / "silverman" are the automatic alternatives) ---
    kde1d_bw=0.2,
    kde2d_bw=0.2,
    # --- output ---
    save_dir=save_path,
    filename_suffix="posterior_pairs_onlykde",
    dpi=300,
    show_legend=True,
)


### Per-subject pairplots

In [None]:
# Per-subject figures go in a sub-folder of the main figure folder.
# os.path.join yields a real sub-directory on every OS (a literal "\\" is only a
# separator on Windows), and makedirs also creates the parent if needed.
save_path_per_subject = os.path.join(save_path, "Individual participants pairplots")
os.makedirs(save_path_per_subject, exist_ok=True)

In [None]:
# One set of figures per participant, using the same settings as the pooled figures.
# Output paths are collected per subject (previously `out_paths` was overwritten on
# every iteration, keeping only the last subject's files).
out_paths_by_subject = {}
for subject_i in df_each_subject["subject"].unique():
    print(f"Processing {subject_i}...")
    df_subject_i = df_each_subject[df_each_subject["subject"] == subject_i]
    out_paths = multi_muscle_pairplot_scatter_kde(
        df=df_subject_i,
        dim_aliases=dim_aliases,
        dims_to_display=["common_input_std", "disynpatic_inhib_connections_desired_MN_MN"],  # [X, Y]
        hue_by="muscle",
        muscle_colors=colors_dict,
        stratify_by=["intensity"],
        lims_by_dim=lims_by_dim,
        # Columns were already rescaled in place by normalize_columns_inplace.
        normalize_to_unit=False,
        assume_pre_normalized=False,

        # Density-only view (point_size / scatter_alpha apply only with show_scatter=True).
        show_scatter=False,
        point_size=20,
        scatter_alpha=0.01,
        kde1d_lw=2,
        kde2d_levels=(0.3, 0.6, 0.9),
        kde2d_lw=1.4,
        kde2d_alpha=1.0,
        kde2d_grid=220,
        darker_mix=0,
        kde2d_fill_alpha=0.1,

        # Scalar gaussian_kde bandwidth factors (larger = smoother).
        kde1d_bw=0.2,
        kde2d_bw=0.2,

        save_dir=save_path_per_subject,
        filename_suffix=f"{subject_i}_posterior",
        dpi=300,
        show_legend=True,
    )
    out_paths_by_subject[subject_i] = out_paths


# Pairwise posterior comparisons: between muscles and between intensities

In [None]:
import numpy as np
import pandas as pd
import itertools
from typing import Optional

# --------- Column coalescing ----------
def coalesce_param_series(
    df: pd.DataFrame,
    param: str,
    dim_aliases: dict[str, list[str]] | None = None,
) -> pd.Series:
    """
    Return a single Series with the parameter values, coalescing the canonical
    column `param` and any alias columns listed in `dim_aliases[param]`.
    If `param` is itself an alias in someone else's list, we also include those canonicals.
    """
    candidates = [param]
    if dim_aliases and param in dim_aliases:
        candidates += dim_aliases[param]
    if dim_aliases:
        for canon, alist in dim_aliases.items():
            if param in alist and canon not in candidates:
                candidates.append(canon)
    candidates = [c for c in candidates if c in df.columns]
    if not candidates:
        return pd.Series(index=df.index, dtype=float)
    out = df[candidates[0]].copy()
    for c in candidates[1:]:
        out = out.where(~out.isna(), df[c])
    return out


# --------- pδ at one intensity ----------
def compare_groups_pdelta_at_intensity(
    df: pd.DataFrame,
    intensity: float,
    param: str,
    group_col: str = "muscle_pair",
    min_samples: int = 1,
    dim_aliases: dict[str, list[str]] | None = None,
    drop_empty_groups: bool = False,
    by_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Pairwise posterior comparison of `param` between groups at one intensity.

    Rows at `intensity` are kept, `param` is coalesced with its aliases
    (coalesce_param_series), and NaN values are dropped. For every unordered
    pair (A, B) of groups, taken in sorted order, the samples are differenced
    element-wise: Δ = A[:n] − B[:n] with n = min(len(A), len(B)).

    Output columns per pair: intensity, group_A, group_B,
      p_delta_pos       — fraction of Δ strictly > 0,
      ci_lower/ci_upper — 2.5th / 97.5th percentiles of Δ (95 % ETI),
      nA, nB            — group sizes before truncation,
    plus the stratum values when `by_cols` is given (pairs are then formed
    within each stratum). Pairs where either group has fewer than
    `min_samples` samples are skipped.

    NOTE(review): pairing by row position (and truncating the longer group)
    assumes each group's samples are independent draws in no particular order
    — confirm the CSVs are not sorted by parameter value.
    """
    at_intensity = df[df["intensity"] == intensity].copy()
    at_intensity["_param_eff"] = coalesce_param_series(at_intensity, param, dim_aliases)
    at_intensity = at_intensity[~at_intensity["_param_eff"].isna()]

    if by_cols:
        strata = at_intensity.groupby(by_cols, dropna=False)
    else:
        strata = [((), at_intensity)]

    records = []
    for stratum_key, stratum in strata:
        if drop_empty_groups:
            counts_by_group = stratum.groupby(group_col)["_param_eff"].size()
            candidates = [g for g, c in counts_by_group.items() if c >= min_samples]
        else:
            candidates = stratum[group_col].dropna().unique().tolist()

        for A, B in itertools.combinations(sorted(candidates), 2):
            a_vals = stratum.loc[stratum[group_col] == A, "_param_eff"].to_numpy()
            b_vals = stratum.loc[stratum[group_col] == B, "_param_eff"].to_numpy()
            if min(len(a_vals), len(b_vals)) < min_samples:
                continue

            n_common = min(len(a_vals), len(b_vals))
            delta = a_vals[:n_common] - b_vals[:n_common]
            ci_lo, ci_hi = np.percentile(delta, [2.5, 97.5])

            record = {
                "intensity": intensity,
                "group_A": A, "group_B": B,
                "p_delta_pos": float(np.mean(delta > 0)),
                "ci_lower": ci_lo, "ci_upper": ci_hi,
                "nA": int(len(a_vals)), "nB": int(len(b_vals)),
            }
            if by_cols:
                key_tuple = stratum_key if isinstance(stratum_key, tuple) else (stratum_key,)
                record.update(zip(by_cols, key_tuple))
            records.append(record)

    return pd.DataFrame(records)


# --------- Δd at one intensity ----------
def _normalize_series_for_param(
    s: pd.Series,
    param: str,
    *,
    lims_by_dim: dict[str, tuple] | None = None,
    normalize_to_unit: bool = False,
    assume_pre_normalized: bool = False,
) -> pd.Series:
    if not normalize_to_unit:
        return s.astype(float)
    if assume_pre_normalized:
        return s.astype(float)
    if lims_by_dim is None or param not in lims_by_dim:
        raise KeyError(f"normalize_to_unit=True but no bounds for '{param}' in lims_by_dim.")
    lo, hi = lims_by_dim[param]
    span = (hi - lo) if (hi > lo) else 1.0
    return (s.astype(float) - lo) / span

def compare_groups_deltad_at_intensity(
    df: pd.DataFrame,
    intensity: float,
    param: str,
    group_col: str = "muscle_pair",
    dim_aliases: dict[str, list[str]] | None = None,
    min_samples: int = 1,
    drop_empty_groups: bool = False,
    by_cols: list[str] | None = None,
    *,

    # normalization controls (same semantics as your pairplot code)
    lims_by_dim: dict[str, tuple] | None = None,
    normalize_to_unit: bool = False,
    assume_pre_normalized: bool = False,

    # Δd definition
    delta_d_method: str = "median_diff",   # "median_diff" or "hl"
    hl_max_pairs: int = 200_000,
    hl_rng_seed: int = 0,
) -> pd.DataFrame:
    """
    Location shift Δd of `param` between every unordered pair of groups at one intensity.

    Pipeline: keep rows at `intensity`; coalesce `param` with its aliases
    (coalesce_param_series); drop NaN values; optionally rescale to the prior
    range (_normalize_series_for_param); then, within each stratum of
    `by_cols` (or the whole table), compare every sorted pair (A, B) of groups
    that have at least `min_samples` finite values.

    delta_d_method:
      - "median_diff": Δd(A,B) = median(A) − median(B)
      - "hl":          Hodges–Lehmann shift median_{i,j}(a_i − b_j), delegated to
                       hodges_lehmann(..., max_pairs=hl_max_pairs, rng_seed=hl_rng_seed)
      Any other value is treated as "median_diff".

    Returns one row per pair with columns intensity, group_A, group_B, delta_d,
    median_A, median_B, nA, nB (finite-sample counts), plus the `by_cols`
    values; an empty DataFrame when no rows remain at this intensity.

    NOTE(review): hodges_lehmann is not defined in the visible part of this
    notebook — confirm it exists before using delta_d_method="hl".
    """
    rows_here = df[df["intensity"] == intensity].copy()
    rows_here["_param_eff"] = coalesce_param_series(rows_here, param, dim_aliases)
    rows_here = rows_here[~rows_here["_param_eff"].isna()]
    if rows_here.empty:
        return pd.DataFrame([])

    # Working column on which Δd is computed (possibly rescaled to prior units).
    rows_here["__val"] = _normalize_series_for_param(
        rows_here["_param_eff"], param,
        lims_by_dim=lims_by_dim,
        normalize_to_unit=normalize_to_unit,
        assume_pre_normalized=assume_pre_normalized,
    )

    if by_cols:
        strata = rows_here.groupby(by_cols, dropna=False)
    else:
        strata = [((), rows_here)]

    records = []
    for stratum_key, stratum in strata:
        if drop_empty_groups:
            counts_by_group = stratum.groupby(group_col)["__val"].size()
            candidates = [g for g, c in counts_by_group.items() if c >= min_samples]
        else:
            candidates = stratum[group_col].dropna().unique().tolist()

        # Finite values, median and count per candidate group.
        finite_vals = {}
        for g in candidates:
            raw = stratum.loc[stratum[group_col] == g, "__val"].to_numpy(float)
            finite_vals[g] = raw[np.isfinite(raw)]
        medians = {g: (np.nanmedian(v) if v.size else np.nan) for g, v in finite_vals.items()}
        n_finite = {g: int(v.size) for g, v in finite_vals.items()}

        # Only groups with enough finite samples take part in comparisons.
        eligible = sorted(g for g in candidates if n_finite[g] >= min_samples)

        for A, B in itertools.combinations(eligible, 2):
            if delta_d_method == "hl":
                shift = hodges_lehmann(finite_vals[A], finite_vals[B],
                                       max_pairs=hl_max_pairs, rng_seed=hl_rng_seed)
            else:
                shift = float(medians[A] - medians[B])

            record = {
                "intensity": intensity,
                "group_A": A, "group_B": B,
                "delta_d": shift,
                "median_A": float(medians[A]) if np.isfinite(medians[A]) else np.nan,
                "median_B": float(medians[B]) if np.isfinite(medians[B]) else np.nan,
                "nA": n_finite[A], "nB": n_finite[B],
            }
            if by_cols:
                key_tuple = stratum_key if isinstance(stratum_key, tuple) else (stratum_key,)
                record.update(zip(by_cols, key_tuple))
            records.append(record)

    return pd.DataFrame(records)


# --------- intensity strip: pδ high vs low ----------
def compare_intensities_for_group(
    df: pd.DataFrame,
    group_value: str,
    param: str,
    group_col: str = "muscle_pair",
    low_intensity: float | None = None,
    high_intensity: float | None = None,
    min_samples: int = 1,
    dim_aliases: dict[str, list[str]] | None = None,
    skip_if_missing: bool = False,
    by_filters: dict[str, object] | None = None,
) -> pd.DataFrame:
    """
    Posterior probability that `param` is larger at the high intensity than at the low one.

    Rows are restricted to `group_col == group_value` and, for every item of
    `by_filters`, `column == value`. The parameter values come from
    `coalesce_param_series(sub, param, dim_aliases)` (defined elsewhere in the
    notebook); rows where it is missing are dropped. Low- and high-intensity
    samples are paired by position, truncated to the shorter of the two, and

        p_delta_pos        = mean(high_k - low_k > 0)
        ci_lower, ci_upper = 2.5th / 97.5th percentiles of (high_k - low_k)

    If `low_intensity` or `high_intensity` is None, the smallest and largest
    intensities present in the filtered rows are used.

    Returns
    -------
    One-row DataFrame with columns
    [group_col, intensity_low, intensity_high, p_delta_pos, ci_lower, ci_upper,
    n_low, n_high] plus the `by_filters` keys. The metric columns are NaN when
    fewer than two intensities are available (auto mode), or when either side
    has fewer than `min_samples` samples. With `skip_if_missing=True` an empty
    DataFrame is returned in those cases instead.
    """
    extra = dict(by_filters) if by_filters else {}

    def _no_result(lo_int, hi_int, n_lo, n_hi):
        # Row reported when the probability cannot be computed.
        if skip_if_missing:
            return pd.DataFrame([])
        return pd.DataFrame([{
            group_col: group_value,
            "intensity_low": lo_int, "intensity_high": hi_int,
            "p_delta_pos": np.nan, "ci_lower": np.nan, "ci_upper": np.nan,
            "n_low": n_lo, "n_high": n_hi,
            **extra,
        }])

    keep = df[group_col] == group_value
    for col, wanted in extra.items():
        keep &= df[col] == wanted
    sub = df.loc[keep].copy()

    sub["_param_eff"] = coalesce_param_series(sub, param, dim_aliases)
    sub = sub[sub["_param_eff"].notna()]

    levels = sorted(sub["intensity"].unique())
    if low_intensity is None or high_intensity is None:
        if len(levels) < 2:
            return _no_result(np.nan, np.nan, 0, 0)
        low_intensity, high_intensity = min(levels), max(levels)

    low_samples = sub.loc[sub["intensity"] == low_intensity, "_param_eff"].to_numpy()
    high_samples = sub.loc[sub["intensity"] == high_intensity, "_param_eff"].to_numpy()
    n_lo, n_hi = len(low_samples), len(high_samples)
    if min(n_lo, n_hi) < min_samples:
        return _no_result(low_intensity, high_intensity, int(n_lo), int(n_hi))

    paired = min(n_lo, n_hi)
    diffs = high_samples[:paired] - low_samples[:paired]
    prob_positive = float(np.mean(diffs > 0))
    ci_lo, ci_hi = np.percentile(diffs, [2.5, 97.5])

    result = {
        group_col: group_value,
        "intensity_low": low_intensity,
        "intensity_high": high_intensity,
        "p_delta_pos": prob_positive, "ci_lower": ci_lo, "ci_upper": ci_hi,
        "n_low": int(n_lo), "n_high": int(n_hi),
    }
    result.update(extra)
    return pd.DataFrame([result])

# --------- Hodges–Lehmann estimator (median of pairwise differences) ----------
def hodges_lehmann(x, y, max_pairs=200_000, rng_seed=0):
    """
    Two-sample Hodges–Lehmann shift estimate: HL(x, y) = median_{i,j} (x_i - y_j).

    Non-finite values (NaN, ±inf) are dropped from both inputs first. The sign
    convention is x minus y.

    Args:
        x, y:      array-likes of samples.
        max_pairs: budget of pairwise differences. If len(x) * len(y) <= max_pairs,
                   the exact estimate over all pairs is returned. Otherwise the median
                   is approximated from `max_pairs` (i, j) index pairs, each drawn
                   independently and uniformly (with replacement). This gives a Monte-Carlo
                   estimate whose error shrinks like 1/sqrt(max_pairs).
        rng_seed:  seed for the pair sampler, so the approximation is reproducible.

    Returns:
        float HL estimate, or NaN if either input has no finite values.

    Raises:
        ValueError: if max_pairs < 1.
    """
    max_pairs = int(max_pairs)
    if max_pairs < 1:
        raise ValueError(f"max_pairs must be >= 1, got {max_pairs}")

    x = np.asarray(x, float); x = x[np.isfinite(x)]
    y = np.asarray(y, float); y = y[np.isfinite(y)]
    if x.size == 0 or y.size == 0:
        return np.nan

    total_pairs = x.size * y.size
    if total_pairs <= max_pairs:
        diffs = (x[:, None] - y[None, :]).ravel()
        return float(np.median(diffs))

    # Draw max_pairs independent (i, j) pairs. Sampling sqrt(max_pairs) values per side
    # and taking the full grid would reuse each draw many times, so the estimate would
    # effectively rest on only ~sqrt(max_pairs) values per sample.
    rng = np.random.default_rng(rng_seed)
    ix = rng.integers(0, x.size, size=max_pairs)
    iy = rng.integers(0, y.size, size=max_pairs)
    return float(np.median(x[ix] - y[iy]))


# --------- Intensity strip: Δd (high vs low) with selectable definition ----------
def compare_intensities_deltad_for_group(
    df: pd.DataFrame,
    group_value: str,
    param: str,
    group_col: str = "muscle_pair",
    low_intensity: float | None = None,
    high_intensity: float | None = None,
    min_samples: int = 1,
    dim_aliases: dict[str, list[str]] | None = None,
    skip_if_missing: bool = False,
    by_filters: dict[str, object] | None = None,
    *,
    # normalization controls (same semantics as your pairplot code)
    lims_by_dim: dict[str, tuple] | None = None,
    normalize_to_unit: bool = False,
    assume_pre_normalized: bool = False,
    # Δd definition
    delta_d_method: str = "median_diff",   # "median_diff" or "hl"
    hl_max_pairs: int = 200_000,
    hl_rng_seed: int = 0,
) -> pd.DataFrame:
    """
    Shift Δd of `param` between the high- and low-intensity samples of one group.

    Rows are restricted to `group_col == group_value` and, for each item of
    `by_filters`, `column == value`. The parameter values come from
    `coalesce_param_series` and rows where it is missing are dropped. The values
    are then passed through `_normalize_series_for_param`; both helpers are
    defined elsewhere in the notebook. Non-finite values are discarded, then:

      - delta_d_method == "hl": Hodges–Lehmann shift median_{i,j}(high_i − low_j)
      - any other value:        median(high) − median(low)

    A positive Δd therefore means larger values at the high intensity. If either
    intensity is None, the smallest and largest intensities present are used.

    Normalization is delegated to `_normalize_series_for_param`. As originally
    documented, normalize_to_unit=True maps values to [0, 1] using the
    lims_by_dim bounds for `param`. assume_pre_normalized=True treats values as
    already being in [0, 1].

    Returns
    -------
    One-row DataFrame with columns [group_col, intensity_low, intensity_high,
    delta_d, n_low, n_high] plus the `by_filters` keys. delta_d is NaN when fewer
    than two intensities exist (auto mode), or when either side has fewer than
    `min_samples` finite values. With `skip_if_missing=True` an empty DataFrame
    is returned in those cases instead.
    """
    extra = dict(by_filters) if by_filters else {}

    def _placeholder(lo_int, hi_int, n_lo, n_hi):
        # Row reported when Δd cannot be computed.
        if skip_if_missing:
            return pd.DataFrame([])
        return pd.DataFrame([{
            group_col: group_value,
            "intensity_low": lo_int, "intensity_high": hi_int,
            "delta_d": np.nan, "n_low": n_lo, "n_high": n_hi,
            **extra,
        }])

    keep = df[group_col] == group_value
    for col, wanted in extra.items():
        keep &= df[col] == wanted
    sub = df.loc[keep].copy()

    sub["_param_eff"] = coalesce_param_series(sub, param, dim_aliases)
    sub = sub[sub["_param_eff"].notna()]

    levels = sorted(sub["intensity"].unique())
    if low_intensity is None or high_intensity is None:
        if len(levels) < 2:
            return _placeholder(np.nan, np.nan, 0, 0)
        low_intensity, high_intensity = min(levels), max(levels)

    sub["__val"] = _normalize_series_for_param(
        sub["_param_eff"], param,
        lims_by_dim=lims_by_dim,
        normalize_to_unit=normalize_to_unit,
        assume_pre_normalized=assume_pre_normalized,
    )

    def _finite_at(level):
        # Finite (normalized) values recorded at one intensity level.
        arr = sub.loc[sub["intensity"] == level, "__val"].to_numpy(dtype=float)
        return arr[np.isfinite(arr)]

    low_vals = _finite_at(low_intensity)
    high_vals = _finite_at(high_intensity)
    n_low, n_high = int(low_vals.size), int(high_vals.size)
    if n_low < min_samples or n_high < min_samples:
        return _placeholder(low_intensity, high_intensity, n_low, n_high)

    if delta_d_method == "hl":
        shift = hodges_lehmann(high_vals, low_vals, max_pairs=hl_max_pairs, rng_seed=hl_rng_seed)
    else:
        shift = float(np.nanmedian(high_vals) - np.nanmedian(low_vals))

    result = {
        group_col: group_value,
        "intensity_low": low_intensity, "intensity_high": high_intensity,
        "delta_d": shift, "n_low": n_low, "n_high": n_high,
    }
    result.update(extra)
    return pd.DataFrame([result])


In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, TwoSlopeNorm
from matplotlib.cm import ScalarMappable
import numpy as np
import pandas as pd

def plot_pairwise_heatmaps_single_metric(
    *,
    df_low: pd.DataFrame,         # pairwise comparisons at low intensity (e.g., 10%)
    df_high: pd.DataFrame,        # pairwise comparisons at high intensity (e.g., 40%)
    df_intensity: pd.DataFrame,   # per-group comparisons high vs low (strip)
    metric_col: str,              # 'p_delta_pos' or 'delta_d'
    metric_label: str,            # e.g., "pδ = P(row > col)" or "Δd = median(row) − median(col)"
    pair_col_A: str = "group_A",
    pair_col_B: str = "group_B",
    # groups / labels / colors
    group_values: list[str] | None = None,
    display_map: dict[str, str] | None = None,
    muscle_colors: dict[str, str] | None = None,
    group_col_label: str = "group",
    mp_col_in_strip: str | None = None,     # column in df_intensity that holds the group id
    intensity_strip_label: str = "High vs Low",
    # symmetry fill for the opposite triangle (optional)
    symmetric_fill: bool = True,            # rewrite M[later, earlier] from M[earlier, later]
    symmetric_rule: str | None = None,      # None|'one_minus'|'negate'|'copy'  (auto if None)
    # color scaling
    cmap: str = "RdBu_r",                   # colormap name or Colormap instance
    vmin: float | None = None,
    vmax: float | None = None,
    vcenter: float | None = None,           # set 0 for Δd; None for pδ; or 0.5 for pδ if you want diverging
    # layout
    title_low: str = "Low intensity",
    title_high: str = "High intensity",
    figsize=(24, 10),
    savepath: str = ".",
    savename_suffix: str = "",
    title_suffix: str = "",
    dpi: int = 300,
    annotate: bool = True,
    annotate_fontsize: int = 11,
    annotate_fontweight: str = "regular",
):
    """
    Plot ONE comparison metric as two pairwise heatmaps (low / high intensity)
    plus a one-column "high vs low" strip, all sharing a single colour scale.

    Panels
    ------
    * Left / middle: square matrix M with M[row, col] = `metric_col` of the row
      whose (pair_col_A, pair_col_B) == (row, col) in `df_low` / `df_high`.
      A pair given only in the reverse direction is mirrored with `symmetric_rule`
      ('one_minus': 1 − v, 'negate': −v, any other value: copied). With
      `symmetric_fill=True` the entry M[later, earlier] (order of the plotted
      groups) is always rewritten from M[earlier, later], so the matrix is exactly
      consistent even if both directions were supplied. The diagonal is 0.5 for
      'p_delta_pos' and 0 otherwise. Pairs that are still missing stay NaN and are
      drawn light grey.
    * Right: `df_intensity[metric_col]` per group. The group id is read from
      `mp_col_in_strip`, auto-detected as the first non-metric column when None.
      A missing table or column gives an all-NaN strip.

    Groups plotted: `group_values` (caller order, duplicates and groups without
    any data dropped), or every group found in the inputs, sorted, when None.
    If no group remains, a warning is issued and nothing is drawn.

    Colour scale
    ------------
    * vcenter is None → Normalize(vmin, vmax), with defaults 0 and 1.
    * otherwise → TwoSlopeNorm around `vcenter`. A missing vmin / vmax becomes
      vcenter ∓ max|v − vcenter| over all plotted values. It falls back to a bound
      of 1.0 when that is 0 or no finite value exists, so the norm is always valid.

    Tick labels are coloured from `muscle_colors`, looked up by group id first,
    then by display label, then grey.

    The figure is saved as PNG and SVG under `savepath` unless it is "" or ".",
    then shown. Returns None.
    """
    import warnings

    diag_value = 0.5 if metric_col == "p_delta_pos" else 0.0
    strip_df = df_intensity if df_intensity is not None else pd.DataFrame()

    def _groups_in(frame, col):
        # Non-NaN unique values of `col`; [] when the frame or the column is absent.
        if frame is None or col not in frame.columns:
            return []
        return frame[col].dropna().unique().tolist()

    # --- groups present in the inputs
    pair_groups = set()
    for frame in (df_low, df_high):
        for col in (pair_col_A, pair_col_B):
            pair_groups.update(_groups_in(frame, col))

    if mp_col_in_strip is None:
        reserved = {"intensity_low", "intensity_high",
                    "p_delta_pos", "delta_d", "ci_lower", "ci_upper",
                    "n_low", "n_high"}
        candidates = [c for c in strip_df.columns if c not in reserved]
        mp_col_in_strip = candidates[0] if candidates else "group"
    strip_groups = set(_groups_in(strip_df, mp_col_in_strip))

    all_groups_present = sorted(pair_groups | strip_groups)
    if group_values is None:
        mps = all_groups_present
    else:
        present = set(all_groups_present)
        mps = list(dict.fromkeys(g for g in group_values if g in present))

    if not mps:
        warnings.warn("plot_pairwise_heatmaps_single_metric: no groups to plot; nothing drawn.")
        return None

    disp = {g: (display_map.get(g, g) if display_map else g) for g in mps}

    # --- rule used to fill the opposite direction of a pair
    if symmetric_rule is None:
        symmetric_rule = "one_minus" if metric_col == "p_delta_pos" else "negate"
    mirror = {
        "one_minus": lambda v: 1.0 - v,
        "negate": lambda v: -v,
    }.get(symmetric_rule, lambda v: v)  # 'copy' (or any other value): copy as-is

    def make_matrix_bi(df_pairwise):
        """Square metric matrix over `mps`; reverse-only pairs are mirrored."""
        M = pd.DataFrame(np.nan, index=mps, columns=mps, dtype=float)
        needed = [pair_col_A, pair_col_B, metric_col]
        if (df_pairwise is not None and not df_pairwise.empty
                and all(c in df_pairwise.columns for c in needed)):
            for A, B, val in df_pairwise[needed].itertuples(index=False, name=None):
                if (A in M.index) and (B in M.columns) and pd.notna(val):
                    M.at[A, B] = float(val)

        for i, A in enumerate(mps):
            for B in mps[i + 1:]:
                v_ab, v_ba = M.at[A, B], M.at[B, A]
                if pd.notna(v_ab) and (symmetric_fill or pd.isna(v_ba)):
                    M.at[B, A] = mirror(float(v_ab))
                elif pd.isna(v_ab) and pd.notna(v_ba):
                    M.at[A, B] = mirror(float(v_ba))

        # Set the diagonal cell by cell (writing through .values fails under Copy-on-Write).
        for g in mps:
            M.at[g, g] = diag_value
        return M

    Mlow = make_matrix_bi(df_low)
    Mhigh = make_matrix_bi(df_high)

    # --- strip values aligned to the plotted groups (all-NaN if unavailable)
    if mp_col_in_strip in strip_df.columns and metric_col in strip_df.columns:
        strip_series = pd.to_numeric(
            strip_df.set_index(mp_col_in_strip)[metric_col].reindex(mps), errors="coerce"
        )
    else:
        strip_series = pd.Series(np.nan, index=mps, dtype=float)

    # --- colour scaling shared by all three panels
    cmap_obj = plt.get_cmap(cmap).copy()
    cmap_obj.set_bad("lightgrey")

    if vcenter is None:
        norm = Normalize(0.0 if vmin is None else vmin, 1.0 if vmax is None else vmax)
    else:
        if vmin is None or vmax is None:
            pooled = np.concatenate([
                Mlow.to_numpy(dtype=float).ravel(),
                Mhigh.to_numpy(dtype=float).ravel(),
                strip_series.to_numpy(dtype=float),
            ])
            pooled = pooled[np.isfinite(pooled)]
            bound = float(np.max(np.abs(pooled - vcenter))) if pooled.size else 0.0
            if not bound > 0.0:
                # TwoSlopeNorm requires vmin < vcenter < vmax.
                bound = 1.0
            if vmin is None:
                vmin = vcenter - bound
            if vmax is None:
                vmax = vcenter + bound
        norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

    # --- drawing helpers
    tick_pos = np.arange(len(mps))
    tick_text = [disp[g] for g in mps]

    def _color_ticklabels(texts):
        # Colour labels by group id (how the caller keys the dict), then by display label.
        if not muscle_colors:
            return
        for txt, g in zip(texts, mps):
            txt.set_color(muscle_colors.get(g, muscle_colors.get(disp[g], "grey")))

    def _annotate(ax, x, y, v):
        if metric_col == "p_delta_pos":
            color = "white" if abs(v - 0.5) > 0.25 else "black"
        else:
            color = "black"
        ax.text(x, y, f"{v:.2f}", ha="center", va="center",
                fontsize=annotate_fontsize, fontweight=annotate_fontweight, color=color)

    def draw_panel(ax, M: pd.DataFrame, title_text: str):
        ax.imshow(M.to_numpy(dtype=float), cmap=cmap_obj, norm=norm, origin="lower")
        ax.invert_xaxis()
        ax.set_xticks(tick_pos)
        _color_ticklabels(ax.set_xticklabels(tick_text, rotation=90, weight="bold"))
        ax.set_yticks(tick_pos)
        _color_ticklabels(ax.set_yticklabels(tick_text, weight="bold"))
        ax.set_xlabel(f"... compared to this {group_col_label}?")
        ax.set_ylabel(f"This {group_col_label} ...")
        ax.set_title(title_text)
        if annotate:
            for i in range(len(mps)):
                for j in range(len(mps)):
                    v = M.iat[i, j]
                    if pd.notna(v):
                        _annotate(ax, j, i, v)

    # --- figure
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    draw_panel(axes[0], Mlow, title_low)
    draw_panel(axes[1], Mhigh, title_high)

    ax = axes[2]
    strip_values = strip_series.to_numpy(dtype=float)
    ax.imshow(strip_values.reshape(-1, 1), cmap=cmap_obj, norm=norm, origin="lower")
    ax.set_xticks([0])
    ax.set_xticklabels([intensity_strip_label], weight="bold")
    ax.set_yticks(tick_pos)
    _color_ticklabels(ax.set_yticklabels(tick_text, weight="bold"))
    ax.set_title(f"{intensity_strip_label} — {metric_label}")
    if annotate:
        for i, v in enumerate(strip_values):
            if pd.notna(v):
                _annotate(ax, 0, i, v)

    # Lay out the subplots first, then add the free-standing colorbar axes, so
    # tight_layout does not warn about (or try to move) an incompatible axes.
    fig.tight_layout(rect=[0, 0, 0.9, 1])
    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(ScalarMappable(norm=norm, cmap=cmap_obj), cax=cax, label=metric_label)
    if title_suffix:
        fig.suptitle(title_suffix, fontsize=15, y=1.02)

    if savepath and savepath != ".":
        fn = f"pairwise_heatmaps_{metric_col}_{title_suffix.replace(' ','_')}_{savename_suffix}"
        fig.savefig(f"{savepath}/{fn}.png", bbox_inches="tight", dpi=dpi)
        fig.savefig(f"{savepath}/{fn}.svg", bbox_inches="tight", dpi=dpi)
    plt.show()


In [None]:
# Manual diverging colormap helper function
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, to_rgb
def make_diverging_cmap_from_lists(
    color_min_list,
    color_max_list,
    color_mid="#FFFFFF",
    *,
    name="custom_div_from_lists",
    bad_color="lightgrey",
):
    """
    Create a smooth diverging colormap that interpolates linearly:
        color_min_list[0] -> ... -> color_min_list[-1] -> color_mid (at 0.5)
        -> color_max_list[0] -> ... -> color_max_list[-1]

    The midpoint colour sits *exactly* at 0.5 even if the two lists have
    different lengths. Colours on each side are spaced evenly within their half
    (left side starting at 0.0, right side ending at 1.0).

    If one side is empty, that half of the map is held flat at `color_mid`.
    LinearSegmentedColormap requires stops at exactly 0.0 and 1.0, so an empty
    side cannot simply be left without stops.

    Args
    ----
    color_min_list : list[str] | str   # e.g. ["blue", "cyan"]; a single colour string is allowed
    color_max_list : list[str] | str   # e.g. ["orange", "red"]
    color_mid      : str               # e.g. "white"
    name           : str               # colormap name (for debugging/registration)
    bad_color      : str               # color for NaNs

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap

    Raises
    ------
    ValueError : if both colour lists are empty.
    """
    def _as_list(colors):
        # A bare colour string is one colour, not a sequence of characters.
        if isinstance(colors, str):
            return [colors]
        return list(colors or [])

    low_colors = _as_list(color_min_list)
    high_colors = _as_list(color_max_list)

    if not low_colors and not high_colors:
        raise ValueError("Provide at least one color on either side.")

    # Left side: nL stops evenly spread over [0, 0.5), first stop at 0.0.
    if low_colors:
        low_stops = np.linspace(0.0, 0.5, len(low_colors) + 1)[:-1].tolist()
    else:
        low_stops, low_colors = [0.0], [color_mid]

    # Right side: nR stops evenly spread over (0.5, 1], last stop at 1.0.
    if high_colors:
        high_stops = np.linspace(0.5, 1.0, len(high_colors) + 1)[1:].tolist()
    else:
        high_stops, high_colors = [1.0], [color_mid]

    stops = [*low_stops, 0.5, *high_stops]
    cols = [*low_colors, color_mid, *high_colors]

    cmap = LinearSegmentedColormap.from_list(name, list(zip(stops, cols)))
    cmap.set_bad(bad_color)
    return cmap



In [None]:
# --- Inputs
param_to_use = "disynpatic_inhib_connections_desired_MN_MN"  # or "common_input_std"
group_col = "muscle"
int_low, int_high = 10.0, 40.0

# Δd definition, used for the pairwise panels AND the high-vs-low strip:
#   "median_diff": Δd = median(A) − median(B)
#   "hl":          Δd = Hodges–Lehmann shift, median_{i,j}(A_i − B_j)  (more robust)
delta_d_method = "median_diff"
# Δd units: True → normalized (0–1 relative to the prior limits in lims_by_dim), False → raw units
USE_NORMALIZED = True
annotate_fontsize = 18
annotate_fontweight = "regular"

# Diverging colormaps, white at the centre of each scale
cmap_pdelta = make_diverging_cmap_from_lists(   # pδ in [0, 1]
    color_min_list=["#3C5FD4", "#65A9F7"],
    color_max_list=["#FF9B8E", "#DF463B"],
    color_mid="#FFFFFF",
)
cmap_deltad = make_diverging_cmap_from_lists(   # signed Δd
    color_min_list=["#AF46C4", "#D87EDB"],
    color_max_list=["#FFD856", "#FFAE00"],
    color_mid="#FFFFFF",
)

df_use = df_all_posteriors[df_all_posteriors["level"] == "subjects_aggregated"].copy()

# Order & labels: groups with data, a display name and a colour, in display order,
# then reversed (y-axis top-to-bottom and x axis left-to-right)
groups_with_data = set(df_use[group_col].unique())
group_vals = [g for g in order_display if g in groups_with_data and g in colors_dict]
group_vals.reverse()
disp_map = {g: order_display[g] for g in group_vals}

title_low, title_high = f"{int_low:g}%", f"{int_high:g}%"
group_col_label = "pair" if group_col == "muscle_pair" else "direction"


def _high_vs_low_table(compare_fn, **kwargs) -> pd.DataFrame:
    """Run a high-vs-low intensity comparison for every plotted group.

    Groups without usable data are skipped. The group-id column is renamed to
    "group", which the strip panel reads via mp_col_in_strip="group".
    """
    frames = [
        compare_fn(
            df_use, g, param=param_to_use, group_col=group_col,
            low_intensity=int_low, high_intensity=int_high,
            dim_aliases=dim_aliases, skip_if_missing=True, **kwargs,
        )
        for g in group_vals
    ]
    frames = [f for f in frames if not f.empty]
    if not frames:
        return pd.DataFrame(columns=["group"])
    return pd.concat(frames, ignore_index=True).rename(columns={group_col: "group"})


# Plot settings shared by both metrics
common_plot_kwargs = dict(
    pair_col_A="group_A", pair_col_B="group_B",
    group_values=group_vals,
    display_map=disp_map,
    muscle_colors=colors_dict,
    group_col_label=group_col_label,
    mp_col_in_strip="group",
    symmetric_fill=True,
    title_low=title_low, title_high=title_high,
    savepath=save_path,
    annotate_fontsize=annotate_fontsize,
    annotate_fontweight=annotate_fontweight,
)

# -------------------------------
# A) pδ metric
# -------------------------------
df_pdelta_low = compare_groups_pdelta_at_intensity(
    df=df_use, intensity=int_low, param=param_to_use,
    group_col=group_col, dim_aliases=dim_aliases,
    min_samples=1, drop_empty_groups=True,
)
df_pdelta_high = compare_groups_pdelta_at_intensity(
    df=df_use, intensity=int_high, param=param_to_use,
    group_col=group_col, dim_aliases=dim_aliases,
    min_samples=1, drop_empty_groups=True,
)
df_pdelta_strip = _high_vs_low_table(compare_intensities_for_group)

plot_pairwise_heatmaps_single_metric(
    df_low=df_pdelta_low,
    df_high=df_pdelta_high,
    df_intensity=df_pdelta_strip,
    metric_col="p_delta_pos",
    metric_label="pδ = P(row > col)",
    symmetric_rule="one_minus",        # mirrored cells show 1 − pδ
    cmap=cmap_pdelta,
    vmin=0.0, vmax=1.0, vcenter=None,
    title_suffix=f"{param_to_use} — pdelta ({group_col})",
    **common_plot_kwargs,
)

# -------------------------------
# B) Δd metric (raw or normalized)
# -------------------------------
if delta_d_method == "median_diff":
    metric_label_deltad = "Δd = median(row) − median(col)"
else:
    metric_label_deltad = "Δd = Hodges–Lehmann (median_{i,j} (x_i - y_j))"

# Same Δd options for the pairwise panels and the strip, so all three panels use one definition
deltad_options = dict(
    lims_by_dim=lims_by_dim,             # only needed if normalize_to_unit=True
    normalize_to_unit=USE_NORMALIZED,
    assume_pre_normalized=True,          # True if the DF is already in [0, 1]
    delta_d_method=delta_d_method,
)
df_deltad_low = compare_groups_deltad_at_intensity(
    df=df_use, intensity=int_low, param=param_to_use,
    group_col=group_col, dim_aliases=dim_aliases,
    min_samples=1, drop_empty_groups=True,
    **deltad_options,
)
df_deltad_high = compare_groups_deltad_at_intensity(
    df=df_use, intensity=int_high, param=param_to_use,
    group_col=group_col, dim_aliases=dim_aliases,
    min_samples=1, drop_empty_groups=True,
    **deltad_options,
)
df_deltad_strip = _high_vs_low_table(compare_intensities_deltad_for_group, **deltad_options)

plot_pairwise_heatmaps_single_metric(
    df_low=df_deltad_low,
    df_high=df_deltad_high,
    df_intensity=df_deltad_strip,
    metric_col="delta_d",
    metric_label=metric_label_deltad,
    symmetric_rule="negate",           # mirrored cells show −Δd
    cmap=cmap_deltad,
    # fixed ±0.5 range; use vmin=None, vmax=None for an automatic ±max|Δd| range
    vcenter=0.0, vmin=-0.5, vmax=0.5,
    title_suffix=f"{param_to_use} — deltad ({group_col}) - {delta_d_method}",
    **common_plot_kwargs,
)
