# SIM 4.0: Adaptive RPS algorithm, comparison with traditional allocation strategies

In [None]:
# ----- imports & style
import numpy as np, pandas as pd, seaborn as sns, matplotlib.pyplot as plt
import visualizations

sns.set_theme(style="whitegrid")

# Rashomon core
from rashomon.hasse import enumerate_policies, enumerate_profiles, policy_to_profile
from allocation import compute_policy_variances
from rashomon.loss import compute_policy_means, compute_pool_means
# runner
from simulate_rps_partition import run_twowave_experiment, build_rps_and_map

In [None]:
# --- experiment setup ---
# Global configuration read by the experiment cell (run_twowave_experiment)
# and by the final RPS rebuild / plotting cells further down.
R = np.array([5,4,3,2])
M = 4
# Informational: number of enumerated policies for (M, R); compare with max_alloc below.
num_policies = len(enumerate_policies(M, R))
H = 20 # complexity parameter
H_true = 20 # true number of pools for data-generating partition
lambda_reg = 0.0000001

# for constructing RPS
# NOTE: theta_init / theta_init_step are reassigned later by the final-rebuild cell.
theta_init = 0.01
theta_init_step = 0.005
min_rset_size = 100
num_workers = 4

# wave setup
allocation_rule_wave1 = "minimax"
# Only these algorithms are run; the plotting cells iterate over the longer
# `algo_order` list, so some facets/rows may have no data.
wave2_algorithms = ["uniform","minimax", "thompson"]
# wave2_algorithms = ["uniform","neyman","minimax", "ucb", "thompson"]

within_pool_rule = "minimax"

max_alloc = 600 # must exceed number of policies
sig = 0.2 # sig**2 is used as the noise variance in the metrics/plots below
random_seed = 3

verbose = True

micro_batch_size = 60 # resolution of computing regret
rps_refresh_every = 4 # recompute RPS MAP every x micro-batches

In [None]:
# --- Truth (pool mean) controls ---
# mean_spec = ("normal", 0.0, 1.0) - e.g. Gaussian
# mean_spec = ("uniform", -0.6, 0.6) - e.g. uniform in [a, b]
mean_spec = ("normal", 0, 2)

# Optionally force a unique best pool with a gap (best_gap) over the runner-up
# (presumably only applied when force_unique_best is True -- confirm in the runner)
force_unique_best = False
best_gap = 0.4
best_pool_choice = "smallest" # or "largest" or an explicit pool id (int)
mean_order = "id_asc" # for linspace/custom sequence assignment

In [None]:
# run experiment
# All settings are gathered in one mapping so the runner call is a single unpack.
experiment_kwargs = dict(
    R=R, H=H, H_true=H_true,
    lambda_reg=lambda_reg,
    theta_init=theta_init,
    theta_init_step=theta_init_step,
    min_rset_size=min_rset_size,
    num_workers=num_workers,
    allocation_rule_wave1=allocation_rule_wave1,
    wave2_algorithms=wave2_algorithms,
    within_pool_rule=within_pool_rule,
    max_alloc=max_alloc,
    sig=sig,
    random_seed=random_seed,
    profile_separate_truth=True,
    micro_batch_size=micro_batch_size,
    rps_refresh_every=rps_refresh_every,
    verbose=verbose,
    mean_spec=mean_spec,
    force_unique_best=force_unique_best,
    best_gap=best_gap,
    best_pool_choice=best_pool_choice,
    mean_order=mean_order,
)
out = run_twowave_experiment(**experiment_kwargs)

# get metrics from experiment: one row per (strategy, algorithm, t),
# sorted so the per-run cumulative sum follows time order
run_keys = ["strategy", "algorithm"]
metrics = pd.DataFrame(out["metrics"]).sort_values(run_keys + ["t"])
metrics["cum_regret"] = metrics.groupby(run_keys)["regret"].cumsum()

# optional display quick table: last row of every (strategy, algorithm) run
if verbose:
    summary_cols = ["t","regret","cum_regret","mse_oracle_partition","mse_pool_est","iou_topk"]
    last_rows = metrics.groupby(run_keys).tail(1)
    display(last_rows[summary_cols])

In [None]:
# Final RPS rebuild on full data after full experiment
# NOTE: these assignments rebind the module-level theta_init / theta_init_step
# that the experiment cell also passes to run_twowave_experiment; re-run the
# setup cell before re-running the experiment.
theta_init = 0.35
theta_init_step = 0.01

def rebuild_rps_at_end(out, H, lambda_reg, theta_init, theta_init_step, min_rset_size, num_workers, verbose=True):
    """Rebuild the RPS / MAP partition once from the experiment's stored data.

    Policies in ``out["all_policies"]`` are grouped by profile; for every
    profile the matching policy ids and the policies restricted to the
    profile's active coordinates are collected and handed, together with
    ``out["D"]`` and ``out["y"]``, to ``build_rps_and_map``.

    Parameters
    ----------
    out : mapping with keys "D", "y" and "all_policies" (the experiment output).
    H, lambda_reg, theta_init, theta_init_step, min_rset_size :
        forwarded unchanged to ``build_rps_and_map``.
    num_workers, verbose : forwarded as keyword arguments.

    Returns
    -------
    Whatever ``build_rps_and_map`` returns (the caller reads the keys
    "theta", "R_set" and "map_pools" from it).

    Notes
    -----
    Reads the module-level ``M`` and ``R``; they are not parameters.
    """
    D, y = out["D"], out["y"]
    all_policies = out["all_policies"]

    profiles, _ = enumerate_profiles(M)

    # Compute each policy's profile once, instead of once per (profile, policy) pair.
    policy_profiles = [policy_to_profile(p) for p in all_policies]

    policies_ids_profiles = {}
    policies_profiles_masked = {}
    for k, profile in enumerate(profiles):
        ids = [i for i, prof in enumerate(policy_profiles) if prof == profile]
        # Coordinates that are "on" (truthy) in this profile.
        active = [j for j in range(M) if profile[j]]
        policies_ids_profiles[k] = ids
        policies_profiles_masked[k] = [
            tuple(all_policies[i][j] for j in active) for i in ids
        ]

    return build_rps_and_map(
        M, R, H, D, y, lambda_reg,
        theta_init, theta_init_step, min_rset_size,
        profiles, policies_profiles_masked, policies_ids_profiles,
        num_workers=num_workers, verbose=verbose
    )

# Rebuild with the (reassigned) theta settings above; note that verbose is
# passed as a literal True here, independent of the global `verbose` flag.
rps_final = rebuild_rps_at_end(
    out=out, H=H,
    lambda_reg=lambda_reg,
    theta_init=theta_init, theta_init_step=theta_init_step, min_rset_size=min_rset_size,
    num_workers=num_workers, verbose=True
)

# One-line summary; |G_MAP| is reported as 0 when map_pools is empty/None.
if verbose:
    print(f"[FINAL RPS] θ={rps_final['theta']:.5f}, |R_set|={len(rps_final['R_set'])}, |G_MAP|={len(rps_final['map_pools']) if rps_final['map_pools'] else 0}")

In [None]:
# Fixed colour per strategy and a fixed algorithm ordering shared by the plots below.
# algo_order can list algorithms that were not run (see wave2_algorithms).
strategy_palette = {"policy_only": "#2b6cb0", "rps_assisted": "#b22222"}  # blue vs red
algo_order = ["uniform","neyman","minimax","ucb","thompson"]

def finalize_metrics_df(metrics_df, sig, group_cols=("strategy","algorithm")):
    """Return a sorted copy of ``metrics_df`` with derived columns filled in.

    The input frame is never modified. On the copy:

    * rows are ordered by ``group_cols`` then ``t`` (only if a ``t`` column exists);
    * ``cum_regret`` is added as the within-group running sum of ``regret``
      (only if ``regret`` exists and ``cum_regret`` does not);
    * ``noise_var`` is added as the constant ``sig**2`` (only if absent).
    """
    keys = list(group_cols)
    result = metrics_df.copy()

    if "t" in result.columns:
        result = result.sort_values([*keys, "t"])

    needs_cumulative = "regret" in result.columns and "cum_regret" not in result.columns
    if needs_cumulative:
        result["cum_regret"] = result.groupby(keys)["regret"].cumsum()

    if "noise_var" not in result.columns:
        result["noise_var"] = float(sig**2)

    return result

# cum_regret already exists at this point, so this call only re-sorts and adds noise_var.
metrics = finalize_metrics_df(metrics, sig, group_cols=("strategy","algorithm"))

In [None]:
# Add absolute versions and a Wave-1 vertical marker (t = #Wave-1 obs)
for signed_col in ("regret", "regret_true"):
    metrics[f"{signed_col}_abs"] = metrics[signed_col].abs()
abs_regret_by_run = metrics.groupby(["strategy","algorithm"])["regret_abs"]
metrics["cum_regret_abs"] = abs_regret_by_run.cumsum()

# If you have wave-1 count available, use it; else infer as #policies (one per policy)
n_policies_fallback = len(out["all_policies"])
t_wave1 = out.get("t_wave1", n_policies_fallback)

In [None]:
# Tracks: a compact table at final t per (strategy, algorithm). Feel free to customize columns.
end_tab_columns = [
    "t", "regret", "regret_true", "cum_regret",
    "mse_oracle_partition", "mse_pool_est", "iou_topk",
    "map_num_pools", "rps_size",
]
ordered_metrics = metrics.sort_values(["strategy","algorithm","t"])
# last (largest-t) row of each run
end_tab = ordered_metrics.groupby(["strategy","algorithm"]).tail(1)[end_tab_columns]
display(end_tab)

In [None]:
# Tracks: absolute instantaneous and cumulative regret, with Wave-1 break shown.
fig, ax = plt.subplots(1,2, figsize=(13,4), sharex=True)

# One entry per panel: (y column, title, y label, extra lineplot kwargs)
regret_panels = [
    ("regret_abs", "Instantaneous |regret| (true − estimated best)", "|regret|", {}),
    ("cum_regret_abs", "Cumulative |regret|", "cum |regret|", {"lw": 1.6}),
]
for panel_ax, (y_col, panel_title, y_label, extra_kws) in zip(ax, regret_panels):
    sns.lineplot(
        data=metrics, x="t", y=y_col,
        hue="strategy", style="algorithm",
        hue_order=["policy_only","rps_assisted"], style_order=algo_order,
        palette=strategy_palette, dashes=True, ax=panel_ax, **extra_kws
    )
    # dashed vertical line at the end of wave 1
    panel_ax.axvline(t_wave1, ls="--", color="gray", alpha=0.6)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("t")
    panel_ax.set_ylabel(y_label)

plt.tight_layout()
plt.show()

# (Optional) If you still want the signed versions too, run your original cells, or:
# sns.lineplot(..., y="regret") and sns.lineplot(..., y="cum_regret")

In [None]:
# Tracks: in-sample residual MSE under the true partition, vs. σ² and an OOS envelope ≈ σ² + |G*|σ²/t.
# Idea: ensure we’re approaching noise floor; RPS should typically drop faster.

# Number of true pools and the distinct time points present in the metrics.
G_true = len(out["true_pools"])
t_vals = np.sort(metrics["t"].unique())
# Reference curve σ² + |G*|σ²/t; t is floored at 1 to avoid division by zero.
mse_test_envelope = pd.DataFrame({
    "t": t_vals, "mse_test_envelope": sig**2 + (G_true * (sig**2)) / np.maximum(t_vals, 1)
})

fig, ax = plt.subplots(1,2, figsize=(13,4), sharex=True)

# Left panel: MSE under the true partition with the σ² noise floor.
sns.lineplot(
    data=metrics, x="t", y="mse_oracle_partition",
    hue="strategy", style="algorithm",
    hue_order=["policy_only","rps_assisted"], style_order=algo_order,
    palette=strategy_palette, dashes=True, ax=ax[0]
)
ax[0].axhline(sig**2, ls="--", lw=1, color="gray")
ax[0].set_title("MSE* under true partition (in-sample)"); ax[0].set_xlabel("t"); ax[0].set_ylabel("MSE*")
# Shaded ±5% band around σ².
# NOTE(review): this label is attached after the lineplot call above; it may not
# show up in the legend unless the legend is rebuilt -- confirm.
ax[0].fill_between([t_vals.min(), t_vals.max()], sig**2*0.95, sig**2*1.05,
                   color="gray", alpha=0.08, label="σ² ±5%")

# Right panel: the same MSE curves drawn on top of the envelope reference curve.
sns.lineplot(data=mse_test_envelope, x="t", y="mse_test_envelope",
             color="black", linestyle=":", ax=ax[1], label="≈ σ² + |G*|σ²/t")
sns.lineplot(
    data=metrics, x="t", y="mse_oracle_partition",
    hue="strategy", style="algorithm",
    hue_order=["policy_only","rps_assisted"], style_order=algo_order,
    palette=strategy_palette, dashes=True, ax=ax[1]
)
ax[1].axhline(sig**2, ls="--", lw=1, color="gray")
ax[1].set_title("MSE* vs simple OOS envelope"); ax[1].set_xlabel("t"); ax[1].set_ylabel("MSE*")

plt.tight_layout(); plt.show()

In [None]:
# Tracks: parameter estimation error over true pools, and set-overlap quality of top-k policies.
fig, ax = plt.subplots(1,2, figsize=(13,4), sharex=True)

# One entry per panel: (y column, title, y label)
quality_panels = [
    ("mse_pool_est", "Pool-mean parameter MSE (true pools)", "param MSE"),
    ("iou_topk", "Top-k IoU vs oracle", "IoU"),
]
for panel_ax, (y_col, panel_title, y_label) in zip(ax, quality_panels):
    sns.lineplot(
        data=metrics, x="t", y=y_col,
        hue="strategy", style="algorithm",
        hue_order=["policy_only","rps_assisted"], style_order=algo_order,
        palette=strategy_palette, dashes=True, ax=panel_ax
    )
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("t")
    panel_ax.set_ylabel(y_label)

plt.tight_layout()
plt.show()

In [None]:
# Tracks: for each algorithm, juxtapose policy-only vs RPS-assisted as two clean lines.
def facet_two_lines(metrics, y, title, ylabel):
    """Plot column ``y`` over ``t`` in one facet per algorithm, two lines per facet.

    Each facet shows the ``policy_only`` and ``rps_assisted`` strategies in
    fixed colours. Facets follow the module-level ``algo_order``. A single
    figure-level legend replaces the per-facet legends.

    Parameters
    ----------
    metrics : DataFrame with "algorithm", "strategy", "t" and the ``y`` column.
    y : name of the column drawn on the vertical axis.
    title : figure-level title.
    ylabel : y-axis label applied to every facet.

    Shows the figure; returns None.
    """
    d = metrics.copy()
    g = sns.FacetGrid(d, col="algorithm", col_order=algo_order, height=3.1, aspect=1.2, sharex=True, sharey=True)

    def _draw(data, color, **kws):
        # `color` / `kws` come from FacetGrid and are ignored on purpose:
        # the two strategies always use the same fixed colours.
        sns.lineplot(data=data[data["strategy"] == "policy_only"], x="t", y=y,
                     color="#2b6cb0", lw=1.6, label="policy_only")
        sns.lineplot(data=data[data["strategy"] == "rps_assisted"], x="t", y=y,
                     color="#b22222", lw=1.6, label="rps_assisted")

    g.map_dataframe(_draw)

    # Collect the labelled line handles BEFORE dropping the per-facet legends.
    # Previously the shared legend was gated on a legend still being attached
    # to the first axis, which was never the case after the removal below.
    legend_entries = {}
    for ax, algo in zip(g.axes.flat, algo_order):
        ax.set_title(algo)
        ax.set_xlabel("t"); ax.set_ylabel(ylabel)
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
        if ax.get_legend():
            ax.get_legend().remove()

    # show a single legend
    if legend_entries:
        g.fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                     loc="upper center", ncol=2, frameon=False)
    g.fig.suptitle(title, y=1.02)
    plt.tight_layout(); plt.show()

# One faceted figure per metric: (column, figure title, y-axis label)
signed_facet_specs = [
    ("regret", "Instantaneous regret — policy vs RPS-assisted", "regret"),
    ("mse_oracle_partition", "MSE* under true partition", "MSE*"),
    ("mse_pool_est", "Pool-parameter MSE", "param MSE"),
    ("iou_topk", "Top-k IoU", "IoU"),
]
for facet_col, facet_title, facet_ylabel in signed_facet_specs:
    facet_two_lines(metrics, y=facet_col, title=facet_title, ylabel=facet_ylabel)

In [None]:
# Tracks: per-algorithm small multiples; two lines (policy vs RPS). Works for any y-column.
def facet_two_lines(metrics, y, title, ylabel, t_marker=None):
    """Plot column ``y`` over ``t`` in one facet per algorithm, two lines per facet.

    Each facet shows the ``policy_only`` and ``rps_assisted`` strategies in
    fixed colours. Facets follow the module-level ``algo_order``. A single
    figure-level legend replaces the per-facet legends.

    Parameters
    ----------
    metrics : DataFrame with "algorithm", "strategy", "t" and the ``y`` column.
    y : name of the column drawn on the vertical axis.
    title : figure-level title.
    ylabel : y-axis label applied to every facet.
    t_marker : optional x position; when given, a dashed vertical line is
        drawn there in every facet.

    Shows the figure; returns None.
    """
    d = metrics.copy()
    g = sns.FacetGrid(d, col="algorithm", col_order=algo_order, height=3.1, aspect=1.2, sharex=True, sharey=True)

    def _draw(data, color, **kws):
        # `color` / `kws` come from FacetGrid and are ignored on purpose:
        # the two strategies always use the same fixed colours.
        sns.lineplot(data=data[data["strategy"] == "policy_only"], x="t", y=y,
                     color="#2b6cb0", lw=1.6, label="policy_only")
        sns.lineplot(data=data[data["strategy"] == "rps_assisted"], x="t", y=y,
                     color="#b22222", lw=1.6, label="rps_assisted")

    g.map_dataframe(_draw)

    # Collect the labelled line handles BEFORE dropping the per-facet legends.
    # Previously the shared legend was gated on a legend still being attached
    # to the first axis, which was never the case after the removal below.
    legend_entries = {}
    for ax, algo in zip(g.axes.flat, algo_order):
        if t_marker is not None:
            ax.axvline(t_marker, ls="--", color="gray", alpha=0.5)
        ax.set_title(algo); ax.set_xlabel("t"); ax.set_ylabel(ylabel)
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
        if ax.get_legend(): ax.get_legend().remove()

    # single legend
    if legend_entries:
        g.fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                     loc="upper center", ncol=2, frameon=False)
    g.fig.suptitle(title, y=1.02); plt.tight_layout(); plt.show()

# Now run: one faceted figure per metric, each with the Wave-1 marker.
# Entries are (column, figure title, y-axis label).
marked_facet_specs = [
    ("regret_abs", "Instantaneous |regret| — policy vs RPS-assisted", "|regret|"),
    ("mse_oracle_partition", "MSE* under true partition", "MSE*"),
    ("mse_pool_est", "Pool-parameter MSE", "param MSE"),
    ("iou_topk", "Top-k IoU", "IoU"),
]
for facet_col, facet_title, facet_ylabel in marked_facet_specs:
    facet_two_lines(metrics, y=facet_col, title=facet_title, ylabel=facet_ylabel, t_marker=t_wave1)

In [None]:
# Tracks: how structure evolves; |G_MAP| should approach truth; |R_set| stability.
fig, ax = plt.subplots(1,2, figsize=(13,4), sharex=True)

# Left panel: number of MAP pools over time, with the true pool count as a
# dashed reference line. Skipped (axis hidden) when the column is absent.
if "map_num_pools" in metrics:
    sns.lineplot(
        data=metrics, x="t", y="map_num_pools",
        hue="strategy", style="algorithm",
        hue_order=["policy_only","rps_assisted"], style_order=algo_order,
        palette=strategy_palette, dashes=True, ax=ax[0]
    )
    ax[0].axhline(len(out["true_pools"]), ls="--", lw=1, color="gray")
    ax[0].set_title("MAP number of pools over time"); ax[0].set_xlabel("t"); ax[0].set_ylabel("|G_MAP|")
else:
    ax[0].axis("off"); ax[0].set_title("No map_num_pools in metrics")

# Right panel: Rashomon set size over time. Skipped when the column is absent.
if "rps_size" in metrics:
    sns.lineplot(
        data=metrics, x="t", y="rps_size",
        hue="strategy", style="algorithm",
        hue_order=["policy_only","rps_assisted"], style_order=algo_order,
        palette=strategy_palette, dashes=True, ax=ax[1]
    )
    ax[1].set_title("|R_set| over time"); ax[1].set_xlabel("t"); ax[1].set_ylabel("|R_set|")
else:
    ax[1].axis("off"); ax[1].set_title("No rps_size in metrics")

plt.tight_layout(); plt.show()

In [None]:
# Tracks: where predictions land at the end; should tighten toward the y=x line for RPS-assisted.

def _pools_to_vector(pools, K):
    v = np.empty(K, dtype=int)
    for pid, members in pools.items():
        v[np.array(members, dtype=int)] = int(pid)
    return v

def _pred_vector(D, y, K, pools=None):
    """Return one predicted mean per policy from the observations ``(D, y)``.

    Without pools (``None`` or empty) the prediction is column 0 of the
    per-policy stats divided by the count in column 1, with the count floored
    at 1. With pools, the pool means from ``compute_pool_means`` are indexed
    by each policy's pool id, so every policy receives its pool's value.
    """
    stats = compute_policy_means(D, y, K)               # [:,0]=sum_y, [:,1]=count

    use_pools = pools and len(pools) > 0
    if not use_pools:
        observed = np.maximum(stats[:,1], 1)            # floor avoids division by zero
        return stats[:,0] / observed                    # per-policy empirical means

    pool_means = compute_pool_means(stats, pools)       # pool means by weighted sums
    assignment = _pools_to_vector(pools, K)
    return pool_means[assignment]                       # broadcast pool means to policies

# Build a long table of per-policy predictions (final snapshot) vs the truth,
# one chunk per (strategy, algorithm) that actually has a snapshot.
K = len(out["all_policies"]); mu_true = out["oracle_outcomes"]
final_snapshots = out["final_snapshots"]
rows = []
for algo in algo_order:
    # policy-only: algo_order may list algorithms that were not run, so look
    # the snapshot up with .get() (as the Spearman cell does) instead of
    # indexing, which raised KeyError for missing algorithms.
    snap_pol = final_snapshots.get(("policy_only", algo))
    if snap_pol is not None:
        mu_hat_pol = _pred_vector(snap_pol["D"], snap_pol["y"], K, pools=None)
        rows.append(pd.DataFrame({
            "policy": np.arange(K), "algorithm": algo, "strategy": "policy_only",
            "mu_true": mu_true, "mu_hat": mu_hat_pol
        }))
    # rps-assisted (may be absent if no RPS)
    snap_rps = final_snapshots.get(("rps_assisted", algo))
    if snap_rps is not None and snap_rps.get("pools") is not None:
        mu_hat_rps = _pred_vector(snap_rps["D"], snap_rps["y"], K, pools=snap_rps["pools"])
        rows.append(pd.DataFrame({
            "policy": np.arange(K), "algorithm": algo, "strategy": "rps_assisted",
            "mu_true": mu_true, "mu_hat": mu_hat_rps
        }))

policy_err = pd.concat(rows, ignore_index=True)
policy_err["abs_err"] = np.abs(policy_err["mu_hat"] - policy_err["mu_true"])

# Scatter
g = sns.FacetGrid(policy_err, col="strategy", hue="algorithm",
                  hue_order=algo_order, sharex=True, sharey=True, height=4, aspect=1.1, palette="tab10")
g.map_dataframe(sns.scatterplot, x="mu_true", y="mu_hat", alpha=0.7, edgecolor="black")
# Common y = x reference line; its extent is the same for every facet, so it
# is computed once outside the loop.
lo = min(policy_err["mu_true"].min(), policy_err["mu_hat"].min())
hi = max(policy_err["mu_true"].max(), policy_err["mu_hat"].max())
for ax in g.axes.flat:
    ax.plot([lo,hi],[lo,hi], ls="--", color="gray")
    ax.set_xlabel("true μ"); ax.set_ylabel("pred μ")
g.fig.suptitle("Per-policy predicted vs true (final)"); plt.tight_layout(); plt.show()

# Violin of |error|
plt.figure(figsize=(10,4))
sns.violinplot(
    data=policy_err, x="algorithm", y="abs_err",
    hue="strategy", hue_order=["policy_only","rps_assisted"],
    split=True, palette={"policy_only":"#2b6cb0","rps_assisted":"#b22222"},
    inner="quartile"
)
plt.title("|pred − true| per policy (final)"); plt.xlabel("algorithm"); plt.ylabel("|error|")
plt.tight_layout(); plt.show()

In [None]:
# Tracks: side-by-side final-time improvements (RPS − Policy) per algorithm (negative = RPS better).
final_metric_names = ["cum_regret_abs", "mse_oracle_partition", "mse_pool_est"]
ordered_metrics = metrics.sort_values(["strategy","algorithm","t"])
finals = ordered_metrics.groupby(["strategy","algorithm"]).tail(1)[
    ["strategy", "algorithm"] + final_metric_names
]

# Pivot to a stable two-column layout per metric
pv = finals.pivot(index="algorithm", columns="strategy")
# Ensure both strategies’ columns exist (missing ones become NaN) and fix their order
need = [(metric, strategy)
        for metric in final_metric_names
        for strategy in ("policy_only", "rps_assisted")]
pv = pv.reindex(columns=pd.MultiIndex.from_tuples(need))

# Flatten names to metric_strategy (e.g., cum_regret_abs_policy_only)
pv.columns = ["_".join(pair) for pair in pv.columns]

# Differences RPS − Policy at the final time step
pv["delta_cum_regret_abs"] = pv["cum_regret_abs_rps_assisted"] - pv["cum_regret_abs_policy_only"]
pv["delta_mse_star"]       = pv["mse_oracle_partition_rps_assisted"] - pv["mse_oracle_partition_policy_only"]
pv = pv.reindex(algo_order)

fig, ax = plt.subplots(1,2, figsize=(13,4))
# One entry per panel: (delta column, title, y label)
delta_panels = [
    ("delta_cum_regret_abs", "Δ cumulative |regret| (RPS − Policy)", "Δ cum |regret|"),
    ("delta_mse_star", "Δ MSE* (RPS − Policy)", "Δ MSE*"),
]
for panel_ax, (delta_col, panel_title, y_label) in zip(ax, delta_panels):
    sns.barplot(data=pv.reset_index(), x="algorithm", y=delta_col, ax=panel_ax, color="#888888")
    panel_ax.axhline(0, color="black", lw=1)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel(y_label)
plt.tight_layout()
plt.show()

In [None]:
# Tracks: rank alignment of final predicted policy ordering with oracle ordering, per (strategy, algorithm).
from scipy.stats import spearmanr
from rashomon.loss import compute_policy_means, compute_pool_means

def _pools_to_vec(pools, K):
    v = np.empty(K, dtype=int)
    for pid, mems in pools.items():
        v[np.asarray(mems, int)] = int(pid)
    return v

def _pred_vector(D, y, K, pools=None):
    """Return one predicted mean per policy from the observations ``(D, y)``.

    Without pools (``None`` or empty) the prediction is column 0 of the
    per-policy stats divided by the count in column 1, with the count floored
    at 1. With pools, the pool means from ``compute_pool_means`` are indexed
    by each policy's pool id, so every policy receives its pool's value.
    """
    stats = compute_policy_means(D, y, K)

    use_pools = pools and len(pools) > 0
    if not use_pools:
        observed = np.maximum(stats[:,1], 1)   # floor avoids division by zero
        return stats[:,0] / observed

    pool_means = compute_pool_means(stats, pools)
    assignment = _pools_to_vec(pools, K)
    return pool_means[assignment]

def rank_corr(mu_true, mu_hat):
    """Return the Spearman rank correlation of the two sequences as a float.

    Calls ``scipy.stats.spearmanr`` with ``nan_policy="omit"``.
    """
    result = spearmanr(mu_true, mu_hat, nan_policy="omit")
    rho = result.correlation
    return float(rho)

# One Spearman coefficient per (strategy, algorithm) that has a final snapshot.
rows = []
K = len(out["all_policies"])
mu_true = out["oracle_outcomes"]

for algo in algo_order:
    for strategy in ("policy_only", "rps_assisted"):
        snapshot = out["final_snapshots"].get((strategy, algo))
        if snapshot is None:
            continue
        # only the RPS-assisted prediction is pooled; policy-only uses raw per-policy means
        snapshot_pools = snapshot.get("pools") if strategy == "rps_assisted" else None
        mu_hat = _pred_vector(snapshot["D"], snapshot["y"], K, pools=snapshot_pools)
        rows.append({
            "strategy": strategy,
            "algorithm": algo,
            "spearman_r": rank_corr(mu_true, mu_hat),
        })

spearman_table = pd.DataFrame(rows).sort_values(["algorithm","strategy"])
display(spearman_table)

In [None]:
# Tracks: where policies move between true pools and the final MAP (after rebuild).
# Requires rps_final from your rebuild block.
if 'rps_final' in globals() and rps_final.get("map_pools") is not None:
    def _pools_to_vec(pools, K):
        """Map each policy index to its pool id; -1 marks policies in no pool.

        The vector is pre-filled with -1 rather than left uninitialised, so
        unassigned policies show up as an explicit "-1" row/column in the
        cross-tabulation instead of as arbitrary pool ids.
        """
        v = np.full(K, -1, dtype=int)
        for pid, members in pools.items():
            v[np.asarray(members, dtype=int)] = int(pid)
        return v
    K = len(out["all_policies"])
    v_true = _pools_to_vec(out["true_pools"], K)
    v_map  = _pools_to_vec(rps_final["map_pools"], K)
    # Count of policies for every (true pool, MAP pool) pair.
    df_cm = pd.crosstab(pd.Series(v_true, name="True"), pd.Series(v_map, name="MAP"))
    plt.figure(figsize=(7,5))
    sns.heatmap(df_cm, cmap="Blues", cbar=True, annot=False)
    plt.title("True vs FINAL MAP pools (policy counts)")
    plt.tight_layout(); plt.show()

In [None]:
# Tracks: how much information each MAP pool accumulated; useful to explain why RPS helps.

# Pick one algorithm’s RPS-assisted final snapshot (they’re similar structurally)
# -> the first ("rps_assisted", *) snapshot whose "pools" entry is not None, else None.
snap = next((out["final_snapshots"][k] for k in out["final_snapshots"]
             if k[0]=="rps_assisted" and out["final_snapshots"][k].get("pools") is not None), None)
if snap is not None:
    pools = snap["pools"]
    Dr, yr = snap["D"], snap["y"]
    K = len(out["all_policies"])
    stats = compute_policy_means(Dr, yr, K)
    # counts_ is not used below; observation counts are read from stats[:,1] instead.
    vars_, counts_ = compute_policy_variances(Dr, yr, K)
    sigmas = np.sqrt(vars_)

    # pool sizes
    pool_sizes = pd.Series({int(pid): len(members) for pid, members in pools.items()}, name="pool_size")
    # pool counts & SE proxy
    from collections import defaultdict
    pool_counts = defaultdict(int)
    pool_se = defaultdict(float)
    for pid, members in pools.items():
        m = np.array(members, dtype=int)
        # total number of observations over the pool's policies
        n_tot = int(stats[m,1].sum())
        # conservative SE proxy: average per-policy sigma / sqrt(total count)
        # NOTE(review): np.mean propagates NaN if any member's variance is NaN
        # (e.g. possibly for unobserved policies) -- confirm against compute_policy_variances.
        se_proxy = np.mean(sigmas[m]) / np.sqrt(max(n_tot,1))
        pool_counts[int(pid)] = n_tot
        pool_se[int(pid)] = se_proxy

    # One row per pool, in the order of pool_sizes' index.
    df_pool = pd.DataFrame({
        "pool_id": list(pool_sizes.index),
        "pool_size": pool_sizes.values,
        "pool_count": [pool_counts[pid] for pid in pool_sizes.index],
        "se_proxy": [pool_se[pid] for pid in pool_sizes.index]
    })

    # Three panels: policies per pool, observations per pool, SE proxy vs observations.
    fig, ax = plt.subplots(1,3, figsize=(15,4))
    sns.barplot(data=df_pool.sort_values("pool_size", ascending=False),
                x="pool_id", y="pool_size", ax=ax[0], color="#94a3b8")
    ax[0].set_title("Pool sizes (policies per pool)"); ax[0].set_xlabel("pool"); ax[0].set_ylabel("# policies")

    sns.barplot(data=df_pool.sort_values("pool_count", ascending=False),
                x="pool_id", y="pool_count", ax=ax[1], color="#60a5fa")
    ax[1].set_title("Total observations per pool"); ax[1].set_xlabel("pool"); ax[1].set_ylabel("# obs")

    sns.scatterplot(data=df_pool, x="pool_count", y="se_proxy", ax=ax[2])
    ax[2].set_title("SE proxy vs observations"); ax[2].set_xlabel("# obs in pool"); ax[2].set_ylabel("SE proxy")
    plt.tight_layout(); plt.show()