# Section 5: Performance Analysis - arXiv:2507.16022

This notebook reproduces the performance analysis from Section 5 of **"Sampler-free gravitational wave inference using matrix multiplication"**.

The analysis covers:
1. **Integration Performance**: Speed and efficiency of the dot-PE method vs traditional samplers
2. **Parameter Estimation Performance**: Accuracy and convergence behavior across different bank densities and parameter ranges

This notebook processes results from:
- Reference runs (dense banks) - establishing ground truth
- Convergence runs (regular banks) - testing parameter variations


In [None]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import warnings
from tqdm import tqdm
import pickle
import time

# Scientific computing
import scipy.stats
import scipy.ndimage
from scipy.interpolate import interp1d

# Import packages
from cogwheel import data, utils, gw_plotting, gw_utils
from dot_pe import inference

# Suppress warnings
warnings.filterwarnings("ignore")

# Plotting style
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")


In [None]:
# Names of the regular banks and of their dense counterparts; the number in
# each name is the tag used in the bank's directory name.
bank_names = [f"bank_mchirp_{x}" for x in (3, 5, 10, 20, 50, 100)]
dense_bank_names = [f"bank_mchirp_dense_{x}" for x in (3, 5, 10, 20, 50, 100)]
banks_homedir = Path(
    "/home/projects/barakz/Collaboration-gw/mushkin//magic_integral/banks/"
)

bank_folders = [banks_homedir / name for name in bank_names]
dense_bank_folders = [banks_homedir / name for name in dense_bank_names]

# One configuration dict per regular bank, read from its bank_config.json.
bank_configs = [
    utils.read_json(folder / "bank_config.json") for folder in bank_folders
]

# (min_mchirp, max_mchirp) of every regular bank, keyed by bank name.
mass_ranges = {
    name: (config["min_mchirp"], config["max_mchirp"])
    for name, config in zip(bank_names, bank_configs)
}

# LaTeX legend label of every regular bank, built from its chirp-mass range.
bank_labels = {
    bank_name: r"$\mathcal{M} \in ("
    + f"{mass_ranges[bank_name][0]:.3g},{mass_ranges[bank_name][1]:.3g}"
    + r"){\rm M}_{\odot}$"
    for bank_name in bank_names
}

# One marker and one tab10 color per bank, shared by all figures.
markers = ["o", "s", "D", "^", "v", "P"]
bank_markers = dict(zip(bank_names, markers))
color_map = plt.colormaps["tab10"]
bank_colors = {name: color_map(idx) for idx, name in enumerate(bank_names)}


# Define functions


In [None]:
def load_pickle(path):
    """Return the object pickled at *path*, or ``None`` if it cannot be loaded.

    A missing file, an unpickling error or an I/O error is reported with
    ``print`` and turned into a ``None`` return value; any other exception
    propagates to the caller.
    """
    # NOTE: pickle can execute arbitrary code; only load files you trust.
    loaded = None
    try:
        if not Path(path).exists():
            raise FileNotFoundError(f"File not found: {path}")
        with open(path, "rb") as fp:
            loaded = pickle.load(fp)
    except (pickle.UnpicklingError, FileNotFoundError, IOError) as e:
        print(f"Error loading pickle file: {e}")
    return loaded


def get_joint_df_of_injections():
    """Concatenate the injection tables of all banks in ``bank_names``.

    Each table is read from ``../<bank_name>/injections.feather`` and gets
    two extra columns: ``eventname`` (``injection_<bank_name>_<row:04>``)
    and ``bank_name``.  The result has a fresh 0..N-1 index.
    """

    def _load_one(bank_name):
        df = pd.read_feather(f"../{bank_name}/injections.feather")
        df["eventname"] = [f"injection_{bank_name}_{i:04}" for i in range(len(df))]
        df["bank_name"] = bank_name
        return df

    frames = [_load_one(name) for name in bank_names]
    return pd.concat(frames).reset_index(drop=True)


def update_on_eventname(df1, df2):
    """Merge two frames on ``eventname``, giving priority to ``df1``.

    Only event names present in both frames are kept.  For those, values
    come from ``df1`` and any missing value or column is filled from
    ``df2`` (``DataFrame.combine_first``).

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Frames with an ``eventname`` column.

    Returns
    -------
    pandas.DataFrame
        Combined frame with ``eventname`` restored as a column.  Events are
        ordered by their first appearance in ``df1``, so the output is
        reproducible from run to run.
    """
    shared = set(df1["eventname"]) & set(df2["eventname"])
    # Follow df1's order instead of the arbitrary iteration order of a set
    # of strings, which changes between interpreter runs.
    common = [
        name for name in df1["eventname"].drop_duplicates() if name in shared
    ]
    df1 = df1[df1["eventname"].isin(common)].set_index("eventname")
    df2 = df2[df2["eventname"].isin(common)].set_index("eventname")

    updated = df1.combine_first(df2)
    updated = updated.loc[common]  # ensure output is limited to common eventnames
    updated.reset_index(inplace=True)
    return updated


def fill_df_with_params(df, inplace=True):
    """Add the derived columns ``mchirp``, ``chieff``, ``eta`` and ``lnq``.

    The columns are computed from ``m1``, ``m2``, ``s1z`` and ``s2z`` with
    the ``gw_utils`` helpers and ``np.log``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the columns listed above.
    inplace : bool, default True
        If True, modify ``df`` and return ``None``; otherwise work on a copy
        and return it.
    """
    target = df if inplace else df.copy()
    m1, m2 = target["m1"], target["m2"]
    target["mchirp"] = gw_utils.m1m2_to_mchirp(m1, m2)
    target["chieff"] = gw_utils.chieff(m1, m2, target["s1z"], target["s2z"])
    target["eta"] = gw_utils.q_to_eta(m2 / m1)
    target["lnq"] = np.log(m2 / m1)
    return None if inplace else target


def sort_path_list(l):
    """Return the paths in *l* sorted by the integer after the last ``_`` of their stem."""

    def trailing_number(path):
        return int(path.stem.rsplit("_", 1)[-1])

    return sorted(l, key=trailing_number)


def load_condition(rundir):
    """Return True if *rundir* holds all files needed to load a run.

    The required files are ``run_kwargs.json``, ``summary_results.json`` and
    ``samples.feather``; only their existence is checked.
    """
    rundir = Path(rundir)
    required_files = (
        "run_kwargs.json",
        "summary_results.json",
        "samples.feather",
    )
    return all((rundir / filename).exists() for filename in required_files)


def get_mchirp_quantiles(
    rundir: Path, quantiles: list[float] = [0.05, 0.95], n_min: int = 20
) -> np.ndarray:
    """Return quantiles of the ``mchirp`` samples stored in a run directory.

    Parameters
    ----------
    rundir : Path
        Run directory expected to contain ``samples.feather``.
    quantiles : list of float
        Quantile levels to evaluate (never modified here).
    n_min : int
        Minimum number of samples required to report quantiles.

    Returns
    -------
    numpy.ndarray
        One value per requested level.  All entries are NaN when the samples
        file is missing, when it holds fewer than ``n_min`` rows, or when all
        computed quantiles coincide (zero-width interval).
    """
    rundir = Path(rundir)
    default_output = np.full(len(quantiles), np.nan, dtype=float)
    samples_path = rundir / "samples.feather"
    if not samples_path.exists():
        return default_output
    samples = pd.read_feather(samples_path)
    if len(samples) < n_min:
        return default_output
    values = samples["mchirp"].quantile(list(quantiles)).to_numpy(dtype=float)
    # The width test must look at the computed values; testing the requested
    # levels (as before) can never detect a collapsed interval.
    if values.size > 1 and np.all(np.diff(values) == 0):
        return default_output
    return values


def load_results(bank_names, inj_dfs=None):
    """Collect one row per completed run for every event of every bank.

    For each bank, event directories ``<bank_name>_events/injection_*`` are
    scanned (relative to the current directory).  Every ``run_*``
    sub-directory accepted by ``load_condition`` contributes a row merging,
    in increasing priority: its run kwargs (without ``seed``), its summary
    results, the injection parameters of the event, bookkeeping entries
    (``rundir``, ``eventname``, ``dense_bank_name``) and, when
    ``fisher_analysis/results.pkl`` can be loaded, the entries
    ``sigma_Mc_from_fisher`` and ``FIM_inversion_error``.

    Parameters
    ----------
    bank_names : list of str
        Bank names to process.
    inj_dfs : list of pandas.DataFrame, optional
        Injection tables, one per entry of ``bank_names``.  If None, each is
        read from ``<bank_name>/injections.feather``.  The tables are copied
        before use, so the caller's frames are left untouched.

    Returns
    -------
    pandas.DataFrame
        One row per loaded run.

    Notes
    -----
    ``dense_bank_name`` is set to the bank name for every bank passed in,
    whether or not it is a dense bank.
    """
    rows = []
    homedir = Path(".")  # Current directory contains the results
    fisher_keys = ["sigma_Mc_from_fisher", "FIM_inversion_error"]
    for b, bank_name in enumerate(bank_names):
        if inj_dfs is None:
            inj_df = pd.read_feather(homedir / bank_name / "injections.feather")
        else:
            # Copy: the columns added below must not leak into the caller's
            # frames, which are shared between dense and regular calls.
            inj_df = inj_dfs[b].copy()
        # Events are named after the regular bank even for dense banks.
        inj_df["eventname"] = [
            f"injection_{bank_name.replace('_dense_', '_')}_{i:04}"
            for i in range(len(inj_df))
        ]
        inj_df["bank_name"] = bank_name

        event_dirs = sort_path_list((homedir / f"{bank_name}_events").glob("injection_*"))
        for event_dir in tqdm(event_dirs, desc=bank_name):
            eventname = event_dir.name

            fisher_results = {}
            pickle_path = event_dir / "fisher_analysis" / "results.pkl"
            if pickle_path.exists():
                loaded = load_pickle(pickle_path)
                # load_pickle reports failures by returning None; treat an
                # unreadable file like a missing one instead of crashing.
                if loaded is not None:
                    fisher_results = {k: loaded[k] for k in fisher_keys}

            # Injection parameters are the same for all runs of an event;
            # look them up once, and only if some run is actually loaded.
            par_dic = None
            for rundir in sort_path_list(event_dir.glob("run_*")):
                if not load_condition(rundir):
                    continue
                summary_results = utils.read_json(rundir / "summary_results.json")

                run_kwargs = utils.read_json(rundir / "run_kwargs.json")
                run_kwargs.pop("seed", None)  # Remove seed if present

                if par_dic is None:
                    par_dic = (
                        inj_df.loc[inj_df.eventname == eventname].iloc[0].to_dict()
                    )
                row = (
                    run_kwargs
                    | summary_results
                    | par_dic
                    | dict(
                        rundir=str(rundir),
                        eventname=event_dir.name,
                        dense_bank_name=bank_name,
                    )
                    | fisher_results
                )
                rows.append(row)

    df = pd.DataFrame(rows)
    return df


def plot_2d_hist(
    ax, dfs, params, colors=None, smooth=1.0, labels=None, weights_col=None
):
    """
    Create a 2D histogram plot similar to MultiCornerPlot but for a single panel.

    Rows with a missing value in either plotted column are dropped, together
    with their weights, before the histogram is built.  Datasets left with
    no rows are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis to plot on
    dfs : list of pandas.DataFrame
        List of dataframes containing the data for each overlay
    params : tuple of str
        (x_param, y_param) names of columns to plot
    colors : list of str, optional
        Colors for each dataset. If None, will use default color cycle
    smooth : float, default=1.0
        Smoothing factor for the gaussian filter
    labels : list of str, optional
        Labels for each dataset for the legend
    weights_col : str, optional
        Name of the column containing weights. If None, no weights are used.

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was passed in.
    """

    if colors is None:
        # Generate colors exactly as in the original implementation
        tab20_colors = plt.cm.tab20.colors
        tab20b_colors = plt.cm.tab20b.colors
        all_colors = (
            tab20_colors[::2]
            + tab20_colors[1::2]
            + tab20b_colors[::2]
            + tab20b_colors[1::2]
        )
        colors = [all_colors[i % len(all_colors)] for i in range(len(dfs))]

    # Contour levels for 90% and 50% containment (in decreasing order as in original)
    contour_fractions = [0.9, 0.5]

    for i, df in enumerate(dfs):
        xpar, ypar = params

        # Drop incomplete rows first so that data and weights stay aligned.
        valid_mask = df[[xpar, ypar]].notna().all(axis=1)
        df = df[valid_mask]
        if len(df) == 0:
            continue

        # Get weights if specified (taken from the filtered frame)
        weights = df.get(weights_col) if weights_col else None

        # Calculate number of bins using Rice rule
        if weights is None:
            n_effective = len(df)
        else:
            # Calculate effective sample size for weighted data
            n_effective = np.square(np.sum(weights)) / np.sum(np.square(weights))
        n_bins = int(np.ceil(2 * np.cbrt(n_effective)))

        # Calculate 2D histogram
        hist2d, xedges, yedges = np.histogram2d(
            df[xpar], df[ypar], bins=n_bins, weights=weights
        )

        # Apply gaussian smoothing
        hist2d = scipy.ndimage.gaussian_filter(hist2d, smooth)

        # Get midpoints for interpolation
        x_mid = (xedges[1:] + xedges[:-1]) / 2
        y_mid = (yedges[1:] + yedges[:-1]) / 2

        # Interpolate to get PDF at edges
        pdf = scipy.interpolate.RectBivariateSpline(x_mid, y_mid, hist2d)(
            xedges, yedges
        ).T

        # Calculate levels for contours
        sorted_pdf = [0.0] + sorted(pdf.ravel())
        cdf = np.cumsum(sorted_pdf)
        cdf /= cdf[-1]
        ccdf = 1 - cdf
        levels = np.interp(contour_fractions, ccdf[::-1], sorted_pdf[::-1])
        print(f"  PDF max: {pdf.max()}, levels: {levels}")

        # Plot contours
        extent = (
            df[xpar].min(),
            df[xpar].max(),
            df[ypar].min(),
            df[ypar].max(),
        )

        # Draw contours
        contour = ax.contour(
            pdf,
            extent=extent,
            levels=levels,
            colors=[colors[i]],
            linestyles=["-"],
        )

        # Fill contours with transparency
        alphas = 1 - np.array(
            contour_fractions
        )  # Transparency increases with confidence
        next_levels = [*levels[1:], np.inf]
        for level_edges, alpha in zip(zip(levels, next_levels), alphas):
            ax.contourf(
                pdf,
                extent=extent,
                levels=level_edges,
                colors=[colors[i]],
                alpha=alpha,
            )

    if labels is not None:
        ax.legend(labels, frameon=False, loc="upper right")

    # Set axis labels
    ax.set_xlabel(params[0])
    ax.set_ylabel(params[1])

    # Add ticks on all sides and rotate them
    ax.tick_params(which="both", direction="in", right=True, top=True, rotation=45)

    return ax


def get_rms(x):
    """Return the root-mean-square of *x*."""
    mean_square = np.mean(x**2)
    return np.sqrt(mean_square)


def get_abs(x):
    """Return the absolute value of the single-element input *x* as a Python scalar."""
    magnitude = np.abs(x)
    return magnitude.item()


def return_self(x):
    """Return *x* unchanged (identity transform for ``process_subplot``'s ``func``)."""
    return x

def process_subplot(
    ax,
    df,
    x_col,
    y_col,
    bank_names,
    labels,
    markers,
    func=lambda x: x,
    min_counts_for_display=10,
    xscale="log",
    yscale="log",
):
    """Plot, per bank, the median and interquartile band of per-event curves.

    For each bank, the rows whose ``eventname`` contains ``"<bank_name>_"``
    are split by event.  Every event gives a curve of ``func(y_col)``
    against ``x_col``; the curves are linearly interpolated onto a common
    10-point geometric grid (NaN outside each curve's range), and the median
    and the 25th/75th percentiles across events are drawn.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    df : pandas.DataFrame
        Data with ``eventname``, ``x_col`` and ``y_col`` columns.
    x_col, y_col : str
        Column names of the abscissa and ordinate.
    bank_names : list of str
        Banks to draw; each must be a key of the module-level ``bank_colors``.
    labels : dict
        Legend label per bank name.
    markers : list
        Marker per bank, indexed by position in ``bank_names``.
    func : callable
        Applied element-wise to ``y_col``.
    min_counts_for_display : int
        Grid points backed by fewer curves than this are blanked (set to
        NaN).  A falsy value disables the blanking.
    xscale, yscale : str
        Axis scales applied at the end.
    """
    n_grid_points = 10
    for bank_index, bank_name in enumerate(bank_names):
        color = bank_colors[bank_name]
        in_bank = df["eventname"].str.contains(f"{bank_name}_", na=False)
        bank_rows = df[in_bank]

        # One (x, y) curve per event of this bank.
        curves = []
        for eventname in bank_rows["eventname"].unique():
            event_rows = bank_rows[bank_rows["eventname"] == eventname]
            if event_rows.empty:
                continue

            xs = event_rows[x_col].to_numpy()
            ys = event_rows[y_col].apply(func).to_numpy()
            if len(xs) == 0 or len(ys) == 0:
                continue

            curves.append((xs, ys))

        if not curves:
            continue

        x_common = np.geomspace(
            min(xs.min() for xs, _ in curves),
            max(xs.max() for xs, _ in curves),
            n_grid_points,
        )
        # Resample every curve on the common grid; NaN outside its range.
        y_interp = np.array(
            [
                interp1d(xs, ys, bounds_error=False, fill_value=np.nan)(x_common)
                for xs, ys in curves
            ]
        )

        y_median = np.nanmedian(y_interp, axis=0)
        y_q25 = np.nanpercentile(y_interp, 25, axis=0)
        y_q75 = np.nanpercentile(y_interp, 75, axis=0)

        valid_counts = np.sum(~np.isnan(y_interp), axis=0)
        if min_counts_for_display:
            too_few = valid_counts < min_counts_for_display
            for summary in (y_median, y_q25, y_q75):
                summary[too_few] = np.nan

        ax.plot(
            x_common,
            y_median,
            color=color,
            lw=2,
            label=labels[bank_name],
            marker=markers[bank_index],
        )
        ax.fill_between(x_common, y_q25, y_q75, color=color, alpha=0.1)

    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    ax.grid()


def get_crossing_points_from_ax(ax, y_threshold):
    """Return, per collection on *ax*, the x where its upper edge drops below a threshold.

    The first path of each collection is read as a closed band: the first
    half of its vertices is the lower edge (left to right) and the second
    half the upper edge (right to left) — presumably the polygons drawn by
    ``fill_between``; verify against the caller.  If the upper edge never
    falls below ``y_threshold``, the last x of the lower edge is used.
    Collections without ``get_paths`` or without paths are skipped.
    """
    crossings = []

    for collection in ax.collections:
        paths = collection.get_paths() if hasattr(collection, "get_paths") else None
        if not paths:
            continue

        vertices = paths[0].vertices
        half = len(vertices) // 2

        x = vertices[:half, 0]
        # Second half runs right-to-left; flip it to align with x.
        upper = vertices[half:, 1][::-1]

        is_below = upper < y_threshold
        index = np.argmax(is_below) if np.any(is_below) else len(x) - 1

        crossings.append(x[index])

    return crossings


## Create key dataframes

In [None]:
# One injection table per regular bank, in the order of bank_names.
injections = []
for bank_name in bank_names:
    injections.append(pd.read_feather(f"{bank_name}_events/injections.feather"))


In [None]:
# Load every dense-bank run, attach the 5% / 95% chirp-mass quantiles of its
# samples, and cache the table on disk.
dense_bank_results = load_results(dense_bank_names, injections)
dense_bank_results[["mchirp_005", "mchirp_095"]] = pd.DataFrame(
    [get_mchirp_quantiles(rundir) for rundir in dense_bank_results["rundir"]],
    columns=["mchirp_005", "mchirp_095"],
)
dense_bank_results.to_feather("dense_bank_results.feather")

In [None]:
# Print the number of events successfully completed in each bank
for bank_name in dense_bank_results["bank_name"].unique():
    cond = dense_bank_results["bank_name"] == bank_name
    print(bank_name, dense_bank_results.loc[cond, "eventname"].nunique(dropna=False))

In [None]:
## Convergence-run dataframe (regular banks; the dense banks are the reference)
# Takes > 10 minutes to run; prefer loading the cached feather file instead.
_t = time.time()
bank_results = load_results(bank_names, injections)

# Report the elapsed wall-clock time in minutes.
_t = time.time() - _t
print(f"passed {_t / 60:.3g} minutes")

# Cache the table so that later sessions can skip the slow load above.
bank_results.to_feather("standard_bank_results.feather")

In [None]:
## Load the key dataframes from the cached files

injections = [
    pd.read_feather(f"{bank_name}_events/injections.feather")
    for bank_name in bank_names
]

bank_results = pd.read_feather("standard_bank_results.feather")
dense_bank_results = pd.read_feather("dense_bank_results.feather")

# Drop bookkeeping columns that are not used below (ignored if absent).
bank_results.drop(
    columns=["n_draws", "inds_path", "delete_blocks"],
    inplace=True,
    errors="ignore",
)
dense_bank_results.drop(
    columns=["n_draws", "inds_path"], inplace=True, errors="ignore"
)


## Data Loading and Processing

This section loads results from both reference runs (dense banks) and convergence runs (regular banks), then processes them for performance analysis.


In [None]:
# Build the merged dataframe: convergence runs (regular banks) combined with
# the dense-bank reference runs, which serve as ground truth.

# Columns that jointly identify a run; used to drop repeated runs.
kwargs_setup_dense = ["bank_name", "eventname", "n_int", "n_ext"]
kwargs_setup_regular = ["bank_name", "eventname", "n_int", "n_ext"]

# Reference results: deduplicate, rename the quantities that would clash with
# the convergence runs, and keep only those columns to avoid confusion.
unique_reference_results = (
    dense_bank_results.drop_duplicates(subset=kwargs_setup_dense, keep="first")
    .rename(
        columns={
            "rundir": "rundir_dense_bank",
            "ln_evidence": "ln_evidence_reference",
            "n_i_inds_used": "n_i_inds_used_dense_bank",
            "n_effective": "n_effective_dense_bank",
            "n_effective_i": "n_effective_i_dense_bank",
            "n_effective_e": "n_effective_e_dense_bank",
        }
    )[
        [
            "rundir_dense_bank",
            "eventname",
            "bank_name",
            "ln_evidence_reference",
            "n_i_inds_used_dense_bank",
            "n_effective_dense_bank",
            "n_effective_i_dense_bank",
            "n_effective_e_dense_bank",
        ]
    ]
    .copy()
)

# Convergence results, deduplicated on the same identifying columns.
unique_bank_results = bank_results.drop_duplicates(
    subset=kwargs_setup_regular, keep="first"
).copy()

# Merge convergence results with reference results
bank_results_with_reference = update_on_eventname(
    df1=unique_bank_results, df2=unique_reference_results
)

# Derived quantities used by the figures; later entries build on earlier ones.
bank_results_with_reference = bank_results_with_reference.assign(
    n_effective_harmonic=lambda d: 2
    / (1 / d["n_effective_i"] + 1 / d["n_effective_e"]),
    ln_evidence_error=lambda d: d["ln_evidence"] - d["ln_evidence_reference"],
    abs_ln_evidence_error=lambda d: np.abs(d["ln_evidence_error"]),
    norm_error=lambda d: np.exp(d["ln_evidence_error"]) - 1,
    norm_abs_error=lambda d: np.abs(d["norm_error"]),
)

print(f"Created bank_results_with_reference with {len(bank_results_with_reference)} rows")
print("Available columns:", list(bank_results_with_reference.columns))


## Figure 2: Convergence Analysis

Now we recreate Figure 2 from the paper, showing convergence behavior of ln Z accuracy as a function of intrinsic samples (N_int), preselected intrinsic samples (N_int'), and extrinsic samples (N_ext).


In [None]:
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4), sharey=True)
axs = axs.flatten()
min_counts_for_display = 100
label_fontsize = 14
legend_fontsize = 10
legend_title_fontsize = 11

# Dashed guide at an absolute ln Z error of 1 on every panel.
for ax in axs:
    ax.axhline(1, ls="--", c="k", label=r"$\ln\mathcal{Z}=\ln\hat\mathcal{Z}\pm1$")

common_y_col = "abs_ln_evidence_error"
common_y_label = r"$\left\vert \ln\mathcal{Z} - \ln\hat{\mathcal{Z}} \right\vert$"

# Panel 1: ln Z accuracy against N_int, with N_ext fixed to 1024
ax = axs[0]
process_subplot(
    ax,
    bank_results_with_reference[bank_results_with_reference["n_ext"] == 1024],
    x_col="n_int",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=return_self,
    min_counts_for_display=min_counts_for_display,
)
ax.set_xlabel(r"$N_{\rm int.}$", fontsize=label_fontsize)
ax.set_ylabel(common_y_label, fontsize=label_fontsize)

# Panel 2: ln Z accuracy against N_int' (preselected), with N_ext fixed to 1024
ax = axs[1]
process_subplot(
    ax,
    bank_results_with_reference[bank_results_with_reference["n_ext"] == 1024],
    x_col="n_i_inds_used",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=return_self,
    min_counts_for_display=min_counts_for_display,
)
ax.set_xlabel(r"$N_{\rm int.}'$", fontsize=label_fontsize)

# Panel 3: ln Z accuracy against N_ext (up to 1024), with N_int fixed to 2^16
ax = axs[2]
process_subplot(
    ax,
    bank_results_with_reference[
        (bank_results_with_reference["n_int"] == 2**16)
        & (bank_results_with_reference["n_ext"] <= 1024)
    ],
    x_col="n_ext",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=return_self,
    min_counts_for_display=min_counts_for_display,
)
ax.set_xlabel(r"$N_{\rm ext.}$", fontsize=label_fontsize)
leg = ax.legend(fontsize=legend_fontsize, bbox_to_anchor=(1.05, 0.975))

fig.tight_layout()
fig.savefig("convergence_n_int_n_ext.pdf", format="pdf")

print("Figure 2: Convergence analysis complete")


## Figure 4: Effective Sample Size Analysis

Now we recreate Figure 4 from the paper, showing the relationship between effective sample sizes (N_eff_int, N_eff_ext, and their harmonic mean) and inference accuracy.


In [None]:
# Figure 4: Three-panel n_effective plot - effective samples vs ln Z accuracy
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(12, 4), sharey=True)

min_n_eff_i = 32
min_n_eff_e = 32
min_counts_for_display = 100
label_fontsize = 14
legend_fontsize = 10
legend_title_fontsize = 11
y_label = r"$\left \vert \mathcal{Z}/\hat\mathcal{Z}-1\right\vert$"

# Panel 1: n_effective_int vs accuracy (n_effective_ext >= min threshold)
# NOTE: the cut is on the effective extrinsic sample size, matching both the
# threshold name (min_n_eff_e) and the "N_eff. ext. >= ..." text drawn on this
# panel below; it was previously applied to the raw n_ext column.
ax = axs[0]
cond = bank_results_with_reference["n_effective_e"] >= min_n_eff_e
common_y_col = "norm_abs_error"
process_subplot(
    ax,
    bank_results_with_reference.loc[cond],
    x_col="n_effective_i",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=get_abs,
    min_counts_for_display=min_counts_for_display,
    xscale="log",
    yscale="log",
)
ax.set_xlabel(r"$N_{\rm eff. int.}$", fontsize=label_fontsize)
ax.set_ylabel(y_label, fontsize=label_fontsize)

# Panel 2: n_effective_ext vs accuracy (n_effective_int >= min threshold)
ax = axs[1]
process_subplot(
    ax,
    bank_results_with_reference.loc[
        bank_results_with_reference["n_effective_i"] >= min_n_eff_i
    ],
    x_col="n_effective_e",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=get_abs,
    min_counts_for_display=min_counts_for_display,
)
ax.set_xlabel(r"$N_{\rm eff. ext.}$", fontsize=label_fontsize)

# Panel 3: harmonic mean of n_effective vs accuracy (no cut applied)
ax = axs[2]
process_subplot(
    ax,
    bank_results_with_reference,
    x_col="n_effective_harmonic",
    y_col=common_y_col,
    bank_names=bank_names,
    labels=bank_labels,
    markers=markers,
    func=get_abs,
    min_counts_for_display=min_counts_for_display,
)
ax.set_xlabel(
    r"Harmonic Mean $(N_{\rm eff. int.},N_{\rm eff. ext.})$",
    fontsize=label_fontsize,
)

for ax in axs:
    ax.set_yscale("log")

# Add legend to rightmost panel
ax = axs[2]
leg = ax.legend(fontsize=legend_fontsize, bbox_to_anchor=(1.05, 0.975))
# Reuse the legend's frame style for the condition boxes drawn below.
frame = leg.get_frame()
bbox_dict = {
    "boxstyle": frame.get_boxstyle().__class__.__name__.lower(),
    "facecolor": frame.get_facecolor(),
    "edgecolor": frame.get_edgecolor(),
    "linewidth": frame.get_linewidth(),
}

# Add condition text to first two panels
ax = axs[0]
ax.text(
    0.60,
    0.95,
    r"$N_{\rm eff. ext.}\geq " + str(min_n_eff_e) + r"$",
    transform=ax.transAxes,
    fontsize=legend_title_fontsize,
    verticalalignment="top",
    bbox=bbox_dict,
)

ax = axs[1]
ax.text(
    0.60,
    0.95,
    r"$N_{\rm eff. int.}\geq" + str(min_n_eff_i) + r"$",
    transform=ax.transAxes,
    fontsize=legend_title_fontsize,
    verticalalignment="top",
    bbox=bbox_dict,
)

fig.tight_layout()
fig.savefig("n_effective_delta_ln_Z.pdf", format="pdf")

print("Figure 4: N_effective analysis complete")


## Figure 5: Selection and Effective Sample Size Analysis

This section processes the dense bank results to analyze the relationship between N'_int (the number of intrinsic samples kept after the incoherent-likelihood selection), the effective intrinsic sample size N_eff_int, and the prior integral over the chirp-mass (M) confidence interval.


In [None]:
# Helper functions for processing dense bank results

def make_event_dataframe(event_dir: Path):
    """Summarize every ``run_*`` directory of an event as one dataframe row.

    Runs are visited in increasing order of the integer after the last ``_``
    of their directory name.  A run with ``samples.feather``,
    ``run_kwargs.json`` and ``summary_results.json`` gives a row with its
    effective sample sizes, ``n_int_prime``, ``n_int``, number of samples,
    paths and ``ln_evidence``; a run missing any of those files gives an
    empty row.
    """
    event_dir = Path(event_dir)

    def run_number(path):
        return int(path.name.split("_")[-1])

    rows = []
    for rundir in sorted(event_dir.glob("run_*"), key=run_number):
        samples_path = rundir / "samples.feather"
        run_kwargs_path = rundir / "run_kwargs.json"
        summary_results_path = rundir / "summary_results.json"

        needed = (samples_path, run_kwargs_path, summary_results_path)
        if not all(path.exists() for path in needed):
            rows.append({})
            continue

        n_samples = len(pd.read_feather(samples_path))
        summary_results = utils.read_json(summary_results_path)
        rows.append(
            {
                "n_effective": summary_results["n_effective"],
                "n_effective_i": summary_results["n_effective_i"],
                "n_int_prime": summary_results["n_i_inds_used"],
                "n_int": utils.read_json(run_kwargs_path)["n_int"],
                "n_samples": n_samples,
                "rundir": str(rundir),
                "event_dir": str(event_dir),
                "ln_evidence": summary_results["ln_evidence"],
            }
        )

    return pd.DataFrame(rows)


def load_bank_results_dense(bank_folder: Path) -> tuple[pd.DataFrame, list[Path]]:
    """Pick the best run of every event found in a bank folder.

    Event directories matching ``injection_bank_mchirp*`` are processed in
    increasing order of the integer after the last ``_`` of their name.  For
    each event, the run with the most samples (ties broken by the largest
    ``n_effective``) is kept.

    Parameters
    ----------
    bank_folder : Path
        Folder containing the event directories.

    Returns
    -------
    results : pandas.DataFrame
        One row per event: ``eventname`` plus the columns produced by
        ``make_event_dataframe`` for the selected run.
    failed_events : list of Path
        Event directories that contain no usable run.
    """
    failed_events = []
    event_dirs = sorted(
        Path(bank_folder).glob("injection_bank_mchirp*"),
        key=lambda x: int(x.name.split("_")[-1]),
    )
    rows = []
    for event_dir in tqdm(event_dirs, total=len(event_dirs)):
        df = make_event_dataframe(event_dir)
        # size == 0 covers both "no runs" and "only incomplete runs", since
        # incomplete runs contribute rows without any column.
        if df.size == 0:
            failed_events.append(event_dir)
            continue
        best_run = df.sort_values(
            by=["n_samples", "n_effective"], ascending=False
        ).iloc[0]
        rows.append({"eventname": event_dir.name} | best_run.to_dict())
    return pd.DataFrame(rows), failed_events


In [None]:
# Load dense bank results for Figure 5 analysis
# Note: This may take several minutes to process all dense bank results

# First try to load from saved files
dense_banks_dfs = []
try:
    for dense_bank_name in dense_bank_names:
        df = pd.read_feather(
            f"{dense_bank_name}_results_for_ess_n_int_prime_relation.feather"
        )
        dense_banks_dfs.append(df)
    print("Loaded dense bank results from saved files")
except FileNotFoundError:
    print("Processing dense bank results from scratch...")
    # Some cached files may have been loaded before the missing one was hit.
    # Start over, otherwise the rebuilt frames are appended after them and
    # the list no longer lines up with dense_bank_names.
    dense_banks_dfs = []

    # Process dense bank results from scratch
    _t = time.time()

    failed_events = []
    for dense_bank_name in dense_bank_names:
        print(f"Processing {dense_bank_name}...")
        df, _failed_events = load_bank_results_dense(f"../{dense_bank_name}")
        failed_events.append(_failed_events)
        df[["mchirp_005", "mchirp_095"]] = pd.DataFrame(
            df["rundir"].apply(get_mchirp_quantiles).tolist(),
            columns=["mchirp_005", "mchirp_095"],
        )

        # Get bank configuration for mass range
        bank_config = utils.read_json(
            banks_homedir / dense_bank_name / "bank_config.json"
        )
        min_mchirp = bank_config.get("min_mchirp")
        max_mchirp = bank_config.get("max_mchirp")
        # Log-width of the 5%-95% chirp-mass interval relative to the
        # log-width of the bank's chirp-mass range.
        df["integral"] = np.log(df["mchirp_095"] / df["mchirp_005"]) / np.log(
            max_mchirp / min_mchirp
        )
        df["log10_n_int_prime"] = df["n_int_prime"].apply(np.log10)
        df["log10_n_effective_i"] = df["n_effective_i"].apply(np.log10)
        df["log10_integral"] = df["integral"].apply(np.log10)
        dense_banks_dfs.append(df)

    # Save results for future use
    for df, dense_bank_name in zip(dense_banks_dfs, dense_bank_names):
        df.to_feather(
            f"{dense_bank_name}_results_for_ess_n_int_prime_relation.feather"
        )

    _t = time.time() - _t
    print(f"Processing completed in {_t:.1f} seconds")


In [None]:
# Add derived quantities needed for Figure 5
for df in dense_banks_dfs:
    # log10 of the fraction of intrinsic samples kept (n_int_prime / n_int)
    df["log10_n_int_prime_fraction"] = np.log10(
        df["n_int_prime"] / df["n_int"]
    )
    # log10 of the intrinsic effective sample size per intrinsic sample
    df["log10_efficiency"] = np.log10(df["n_effective_i"] / df["n_int"])

print("Dense bank processing complete")
print("Sample counts per bank:")
# Number of distinct events available for each dense bank
for dense_bank_name, df in zip(dense_bank_names, dense_banks_dfs):
    print(f"{dense_bank_name}: {len(df.eventname.unique())} events")


In [None]:
# Figure 5: Corner plot showing the relationship between N'_int, N_eff_int, and M CI integral
# Create labels with event counts
# NOTE(review): the "/1024" denominator is hard-coded — presumably the number
# of injections per bank; confirm.
labels = [
    label + f" ({len(df.eventname.unique())}/1024)"
    for label, df in zip(bank_labels.values(), dense_banks_dfs)
]

# Define parameters for the corner plot
params = [
    "log10_integral",
    "log10_efficiency", 
    "log10_n_int_prime_fraction",
]

# Create the MultiCornerPlot (one overlay per dense bank)
mcp = gw_plotting.MultiCornerPlot(
    dense_banks_dfs,
    params=params,
    labels=labels,
    smooth=1,
)

# Define latex labels for the parameters
# ("log10_n_effective_i" is labelled here although it is not in `params`.)
latex_label = {
    "log10_n_int_prime_fraction": r"$\log_{10}\left[ {N_{\rm int.}'} / {N_{\rm int.}} \right]$",
    "log10_n_effective_i": r"$\log_{10}$ Intrinsic ESS",
    "log10_integral": r"$\log_{10} \left[\int_{\mathcal{M}_{5\%}}^{\mathcal{M}_{95\%}} \pi'(\mathcal{M}) {\rm d}\mathcal{M} \right]$",
    "log10_efficiency": r"$\log_{10}\left[{N_{\rm eff. int.}}/{N_{\rm int.}}\right]$",
}

# Register the labels on every underlying corner plot before drawing.
for cp in mcp.corner_plots:
    cp.latex_labels |= latex_label

mcp.plot()

# Add 1:1 lines to specific panels
fig = plt.gcf()
# Add diagonal lines to panels [3, 6, 7] (based on the original code)
# NOTE(review): presumably the three off-diagonal panels of the 3x3 corner
# grid — verify against the axes order that MultiCornerPlot creates.
for i in [3, 6, 7]:
    ax = fig.axes[i]
    x = ax.get_xlim()
    ax.plot(x, x, ls="-", c="k")

fig.tight_layout()
fig.savefig("selection_and_ESS.pdf", format="pdf", dpi=300)

print("Figure 5: Selection and ESS analysis complete")


## Figure 6: PP-Plot Analysis

Now we recreate Figure 6 from the paper, showing probability-probability plots that validate the statistical properties of our posterior estimates across different mass ranges.


In [None]:
# PP-plot helper functions and parameters
# Default parameters for which get_credible_intervals computes a credible
# interval.  "cums1r_s1z" and "cums2r_s2z" are derived quantities computed by
# get_cums1r_s1z_cums2r_s2z; the other names are looked up directly in the
# samples and in the injection parameters.
PARAMS_FOR_PP_PLOT = [
    "mchirp",
    "q", 
    "chieff",
    "cums1r_s1z",
    "cums2r_s2z",
    "ra",
    "dec",
    "d_luminosity",
]

def get_cumsr_sz(sx, sy, sz):
    """Return ``(sx**2 + sy**2) / (1 - sz**2)``.

    This is the squared in-plane spin magnitude divided by its largest
    possible value for the given ``sz`` (assuming a unit bound on the total
    spin magnitude — confirm).
    """
    return (sx**2 + sy**2) / (1 - sz**2)

def get_cums1r_s1z_cums2r_s2z(pars):
    """Return ``get_cumsr_sz`` evaluated for the first and for the second spin.

    *pars* is any mapping (dict or dataframe) providing ``s1x_n``, ``s1y_n``,
    ``s1z``, ``s2x_n``, ``s2y_n`` and ``s2z``.
    """
    first = get_cumsr_sz(sx=pars["s1x_n"], sy=pars["s1y_n"], sz=pars["s1z"])
    second = get_cumsr_sz(sx=pars["s2x_n"], sy=pars["s2y_n"], sz=pars["s2z"])
    return first, second

def get_credible_intervals(hh_min, hh_max, rundirs, params=None, min_required_samples=30):
    """
    Compute credible interval at which the injected value is recovered
    for multiple parameters and injections.

    For each run, the credible interval of a parameter is the fraction of
    posterior samples that are less than or equal to the injected value.

    Parameters
    ----------
    hh_min, hh_max : float
        Only samples whose ``hh`` column lies in ``[hh_min, hh_max]`` are
        used.  No cut is applied if the samples have no ``hh`` column.
    rundirs : iterable of path-like
        Run directories containing ``samples.feather`` and
        ``injection_parameters.json``.
    params : list of str, optional
        Parameters to process.  Defaults to ``PARAMS_FOR_PP_PLOT``.
    min_required_samples : int
        Runs with fewer samples after the ``hh`` cut are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per run and one column per parameter.  An entry is NaN when
        the run was skipped (missing files or too few samples) or when the
        parameter is absent from the samples or the injection parameters.
    """
    if params is None:
        params = PARAMS_FOR_PP_PLOT

    credible_intervals = {par: [] for par in params}

    def record_skipped_run():
        # Keep one entry per run in every column so that rows stay aligned
        # with `rundirs`.
        for par in params:
            credible_intervals[par].append(np.nan)

    spin_keys = ["s1x_n", "s1y_n", "s1z", "s2x_n", "s2y_n", "s2z"]
    needs_spin_cdfs = "cums1r_s1z" in params or "cums2r_s2z" in params

    for rundir in tqdm(rundirs):
        rundir = Path(rundir)
        samples_path = rundir / "samples.feather"
        injection_parameters_path = rundir / "injection_parameters.json"

        if not (samples_path.exists() and injection_parameters_path.exists()):
            record_skipped_run()
            continue

        samples = pd.read_feather(samples_path)

        # Filter samples by their hh value
        if "hh" in samples.columns:
            in_range = (samples["hh"] >= hh_min) & (samples["hh"] <= hh_max)
            # Explicit copy: derived columns are added below, and assigning
            # to a filtered slice is a chained assignment in pandas.
            valid_samples = samples[in_range].copy()
        else:
            valid_samples = samples

        if len(valid_samples) < min_required_samples:
            print(f"Skipping {rundir} with only {len(valid_samples)} samples with ⟨h∣h⟩ in ({hh_min},{hh_max}).")
            record_skipped_run()
            continue

        injection_pars = utils.read_json(injection_parameters_path)

        # Add derived parameters to both the injection and the samples
        if needs_spin_cdfs and all(k in injection_pars for k in spin_keys):
            cums1r_s1z, cums2r_s2z = get_cums1r_s1z_cums2r_s2z(injection_pars)
            injection_pars["cums1r_s1z"] = cums1r_s1z
            injection_pars["cums2r_s2z"] = cums2r_s2z

            cums1r_s1z_samples, cums2r_s2z_samples = get_cums1r_s1z_cums2r_s2z(valid_samples)
            valid_samples["cums1r_s1z"] = cums1r_s1z_samples
            valid_samples["cums2r_s2z"] = cums2r_s2z_samples

        for par in params:
            if par in injection_pars and par in valid_samples.columns:
                injected_value = injection_pars[par]
                parameter_samples = valid_samples[par].values

                # Fraction of samples at or below the injected value
                credible_interval = np.mean(parameter_samples <= injected_value)
                credible_intervals[par].append(credible_interval)
            else:
                credible_intervals[par].append(np.nan)

    return pd.DataFrame(credible_intervals)

def pp_plot(credible_intervals, params=None, ax=None, show_xy_labels=True, show_title=True, show_legend=True):
    """
    Make a probability-probability (PP) plot.

    For each parameter, the sorted credible-interval values are drawn
    against evenly spaced fractions in [0, 1]. A dotted diagonal is
    added as the reference line.

    Parameters
    ----------
    credible_intervals : pandas.DataFrame
        One column per parameter, one row per injection. Rows holding
        a NaN in any of the plotted parameters are left out.
    params : sequence of str, optional
        Columns of `credible_intervals` to plot. Defaults to all
        columns.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when omitted.
    show_xy_labels : bool
        Whether to label the x and y axes.
    show_title : bool
        Whether to set the title ``N = <rows plotted> / <rows given>``.
    show_legend : bool
        Whether to add a legend with one entry per parameter.

    Returns
    -------
    The axes that were drawn on.

    Raises
    ------
    KeyError
        If an entry of `params` is not a column of
        `credible_intervals`.
    """
    if ax is None:
        _, ax = plt.subplots()

    params = list(credible_intervals) if params is None else list(params)

    # Only the plotted columns decide whether a row is usable. A NaN in
    # a column that is not being plotted must not discard the injection
    # (nor shrink the N reported in the title).
    clean_credible_intervals = credible_intervals.dropna(subset=params)
    n_clean = len(clean_credible_intervals)

    # Every curve shares the same rows, hence the same y grid.
    fractions = np.linspace(0, 1, n_clean)
    for par in params:
        ax.plot(
            np.sort(clean_credible_intervals[par]),
            fractions,
            label=gw_plotting.CornerPlot.DEFAULT_LATEX_LABELS.get(par, par),
            lw=1,
        )
    ax.plot((0, 1), (0, 1), "k:")  # Diagonal reference line

    if show_title:
        ax.set_title(
            f"$N = {n_clean} / {len(credible_intervals)}$",
            fontsize="medium",
        )

    ax.tick_params(axis="x", direction="in", top=True)
    ax.tick_params(axis="y", direction="in", right=True)
    ax.grid(linestyle=":")
    ax.set_aspect("equal")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    if show_legend:
        ax.legend(
            fontsize=10,
            frameon=True,
            framealpha=0.5,
            labelspacing=0.25,
            loc="upper left",
            edgecolor="none",
            borderpad=0.3,
        )

    if show_xy_labels:
        ax.set_xlabel("Credible interval")
        ax.set_ylabel("Fraction of injections in credible interval")

    return ax


In [None]:
# PP-plot inputs: credible intervals of the injected values, one table per
# dense bank. The runs are the same dense-bank runs used for Figure 5.

print("Computing credible intervals for PP-plot analysis...")

# Selection cuts handed to get_credible_intervals
hh_min, hh_max = 70, 200  # accepted range of the samples' "hh" column
min_required_samples = 20  # runs with fewer surviving samples become NaN

ci_per_bank = {}
for bank_index, dense_bank_name in enumerate(dense_bank_names):
    print(f"Processing {dense_bank_name}...")

    # Assumes dense_banks_dfs is ordered like dense_bank_names -- the two
    # are paired purely by position.
    df = dense_banks_dfs[bank_index]
    rundirs = list(map(Path, df["rundir"].values.tolist()))

    ci_per_bank[dense_bank_name] = get_credible_intervals(
        hh_min,
        hh_max,
        rundirs,
        params=PARAMS_FOR_PP_PLOT,
        min_required_samples=min_required_samples,
    )

print("Credible interval computation complete")
for bank_key, ci_table in ci_per_bank.items():
    n_valid = len(ci_table.dropna())
    print(f"{bank_key}: {len(ci_table)} total, {n_valid} valid")

# Bank configurations, read for the subplot titles
config_paths = [
    banks_homedir / name / "bank_config.json" for name in dense_bank_names
]

bank_configs_dense = dict(
    zip(dense_bank_names, map(utils.read_json, config_paths))
)


In [None]:
# Figure 6: one PP-plot per dense bank on a 2x3 grid of shared axes
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
axes = axes.flatten()

for ax, (dense_bank_name, credible_intervals) in zip(axes, ci_per_bank.items()):
    # Only the panel of the first bank carries the legend
    show_legend = dense_bank_name == dense_bank_names[0]

    pp_plot(
        credible_intervals,
        ax=ax,
        show_legend=show_legend,
        show_xy_labels=True,
    )

    # Chirp-mass range of this bank, shown in the panel title
    bank_config = bank_configs_dense[dense_bank_name]
    mmin = bank_config["min_mchirp"]
    mmax = bank_config["max_mchirp"]

    title = (
        r"$\mathcal{M}^{\rm det}\in("
        + f"{mmin:.3g},{mmax:.3g}"
        + r"){\rm M}_{\odot}$, "
    )

    # pp_plot has just written a title of the form "$N = <valid> / <total>$".
    # Keep everything before the "/" and append a fixed denominator.
    # NOTE(review): the denominator is hard-coded to 1024 instead of being
    # taken from len(credible_intervals) -- confirm it matches the number of
    # injections per bank.
    count_part, slash, _ = ax.get_title().partition("/")
    if not slash:
        count_part = "N"
    ax.set_title(title + count_part + r"/1024$", fontsize=14)

    # Enlarge the legend entries on the panel that has a legend
    if show_legend:
        for legend_text in ax.get_legend().get_texts():
            legend_text.set_fontsize(12)

# Clear the per-panel axis labels written by pp_plot ...
for ax in axes:
    ax.set_xlabel("")
    ax.set_ylabel("")

# ... then label only the first panel of each row (y) ...
for row_start in (0, 3):
    axes[row_start].set_ylabel(
        "Fraction of injections in credible interval", fontsize=14
    )

# ... and the middle panel of each row (x).
for middle_panel in (1, 4):
    axes[middle_panel].set_xlabel("Credible interval", fontsize=14)

# Uniform font sizes for titles and tick labels
for ax in axes:
    ax.set_title(ax.get_title(), fontsize=14)
    for get_tick_labels in (ax.get_xticklabels, ax.get_yticklabels):
        for tick_label in get_tick_labels():
            tick_label.set_fontsize(12)

fig.tight_layout()
fig.savefig("pp_plots.pdf", bbox_inches="tight", format="pdf")

print("Figure 6: PP-plot analysis complete")
