# SIM 3.0: Adaptive RPS algorithm, two-wave allocation


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import visualizations
from rashomon.hasse import policy_to_profile, enumerate_policies, enumerate_profiles
from rashomon.aggregate import RAggregate
from datagen import (
    phi_basic, phi_linear_interact, phi_grouped_smooth, phi_grouped_coarse,
    phi_grouped_smooth2, phi_peak, generate_data_from_assignments
)
from allocation import (
    compute_policy_variances,
    allocate_wave1,
    allocate_wave2,
    allocate_wave2_pools,
    create_assignments_from_alloc
)
from rashomon.extract_pools import extract_pools, aggregate_pools
from rashomon.loss import compute_pool_means, compute_policy_means
from rashomon.metrics import make_predictions


In [None]:
# === Config ===
# Experiment-wide settings read by the cells below.
M = 3  # passed to enumerate_policies / enumerate_profiles / RAggregate (presumably the number of features — confirm)
R = np.array([5, 5, 3])  # one entry per feature, passed alongside M (presumably levels per feature — confirm)
lambda_reg = 0.01                  # regularization parameter (forwarded to RAggregate as reg=)
epsilon = 0.05                      # tolerance off MAP (used by visualizations if needed)
H = 100  # forwarded to RAggregate; the budget cell requires H >= number of profiles

# Rule names forwarded to allocate_wave1 / allocate_wave2_pools.
# NOTE(review): confirm "minimax" is an accepted rule name in the allocation module.
allocation_rule_wave1 = "minimax"
allocation_rule_wave2 = "minimax"
within_pool_rule = "minimax"

max_alloc = 300  # total observation budget summed over all waves
feasible_waves = 4                  # total number of waves run; NOTE(review): the notebook title says "two-wave" — confirm intended count
sig = 0.2  # forwarded to generate_data_from_assignments as sig= (presumably outcome noise scale — confirm)

verbose = True  # enables progress prints and the per-wave plots
top_k = 10  # size of the oracle "best policies" set used for reporting and IoU
n_preview = 30  # N passed to the preview plots
num_workers = 2  # forwarded to RAggregate as num_workers=

# Sweep parameters for RPS size tuning (coarse-to-fine)
theta_init = 0.1  # theta value the wave-1 sweep starts from
theta_init_step = 0.05  # initial theta increment; refined between sweeps
min_rset_size = 100  # sweeps after the first stop once the Rashomon set has at least this many partitions


In [None]:
# === Enumerate policies and profiles ===
# Builds the global policy list and, for every profile k, three parallel views:
# the global indices of its policies, the policies themselves, and the policies
# restricted to the features that are active in that profile.
all_policies = enumerate_policies(M, R)
num_policies = len(all_policies)
profiles, profile_map = enumerate_profiles(M)
num_profiles = len(profiles)

# Profile of every policy, computed once up front so the per-profile loop below
# only compares cached values instead of calling policy_to_profile
# num_profiles * num_policies times.
policy_profile_lookup = [policy_to_profile(p) for p in all_policies]

# profile index mappings (all keyed by the profile's position k in `profiles`)
policies_profiles = {}          # k -> list of policies in profile k
policies_ids_profiles = {}      # k -> indices of those policies in all_policies
policies_profiles_masked = {}   # masked policies hold the active features only
for k, profile in enumerate(profiles):
    ids = [i for i, pol_profile in enumerate(policy_profile_lookup) if pol_profile == profile]
    policies_ids_profiles[k] = ids
    policies_profiles[k] = [all_policies[i] for i in ids]

    profile_mask = [bool(v) for v in profile]  # t/f map of which features are active
    active_features = [i for i in range(M) if profile_mask[i]]
    masked_policies = [tuple(pol[i] for i in active_features) for pol in policies_profiles[k]]
    policies_profiles_masked[k] = masked_policies


In [None]:
# === Wave budgets ===
# Wave 1 is sized at exactly one observation per policy; whatever is left of
# max_alloc is shared as evenly as possible among the remaining waves.
n = [num_policies]
remaining_alloc = max_alloc - num_policies
if remaining_alloc < 1:
    raise Exception(f"Need at least {num_policies} observations for one-per-policy coverage in Wave 1.")
if H < num_profiles:
    raise Exception(f"Need H ≥ #profiles ({num_profiles}) for initial per-profile pooling space.")
adaptive_waves = feasible_waves - 1

# Even split of the leftover budget: the first `remainder` adaptive waves take
# one extra observation each so that the totals add up to max_alloc.
if adaptive_waves > 0:
    base, remainder = divmod(remaining_alloc, adaptive_waves)
    n += [base + 1] * remainder + [base] * (adaptive_waves - remainder)
if verbose:
    print(f"Per-wave allocation: {n}, total={sum(n)}")


In [None]:
# === Oracle outcomes (ground truth generator) ===
# True outcome of every policy, in all_policies order.
oracle_outcomes = np.array([phi_grouped_coarse(p, R) for p in all_policies])

# Descending order of outcomes: entry r is the index of the policy ranked r (0 = best).
oracle_rank_to_policy = np.argsort(-oracle_outcomes)
# Inverse permutation of the above: entry i is the rank of policy i.
oracle_policy_to_rank = np.argsort(oracle_rank_to_policy)

# The top_k best policies according to the oracle.
top_k_indices = oracle_rank_to_policy[:top_k]
top_k_values = oracle_outcomes[top_k_indices]
top_k_policies = [all_policies[i] for i in top_k_indices]

if verbose:
    print("Top-k best policies and their profiles:")
    for rank, (idx, policy) in enumerate(zip(top_k_indices, top_k_policies), start=1):
        profile = policy_to_profile(policy)
        levels = [int(level) for level in policy]
        print(f"Rank {rank}: Policy idx {idx}, Policy {levels}, Profile {profile}")


In [None]:
# === Containers ===
# Accumulators that the wave loop appends to; both data arrays start with zero
# rows and a single column so they can be row-stacked with each wave's data.
pools_for_next_wave = None   # set from the MAP partition once a wave's RPS is built
metrics_per_wave = []        # one summary dict per completed wave
D = np.zeros((0, 1), dtype=int)      # policy index of each observation so far
y = np.zeros((0, 1), dtype=float)    # outcome of each observation so far


In [None]:
# === Wave loop ===
# Per wave: choose an allocation over policies, simulate outcomes, refresh the
# per-policy statistics, tune theta and enumerate the Rashomon partition set
# (RPS), score every partition against the oracle, and cache the MAP
# partition's pools so that the next wave can allocate over them.
for wave_number in range(1, feasible_waves + 1):
    np.random.seed(wave_number)  # reseed NumPy's global RNG at the start of each wave

    # ---------------------------
    # Decide allocation D_wave
    # ---------------------------
    if wave_number == 1:
        # Policy-level design with coverage+variance-inflation logic
        policy_variances, policy_counts = compute_policy_variances(D, y, num_policies)   # empty -> all counts 0
        sigmas = np.sqrt(policy_variances)

        n_wave = n[0]
        # NOTE(review): the rule names listed below differ from the configured
        # value; confirm the names accepted by allocate_wave1.
        alloc = allocate_wave1(
            allocation_rule_wave1,     # "neyman_policy" | "minimax_policy" | "best_arm"
            sigmas=sigmas,
            counts=policy_counts,
            N=n_wave
        )
        D_wave = create_assignments_from_alloc(alloc).reshape(-1, 1)
    else:
        # Pool→Policy two-stage design using pools learned from previous wave (MAP)
        if pools_for_next_wave is None:
            raise RuntimeError("Wave-2 requires pools from Wave-1. Ensure Wave-1 RPS completed.")
        policy_variances, policy_counts = compute_policy_variances(D, y, num_policies)
        policy_sigmas = np.sqrt(policy_variances)

        n_wave = n[wave_number - 1]
        alloc_policy = allocate_wave2_pools(
            rule=allocation_rule_wave2,          # "neyman_pool" | "minimax_pool" | "best_pool"
            pool_to_policies=pools_for_next_wave,
            policy_sigmas=policy_sigmas,
            N=n_wave,
            policy_counts=policy_counts,
            pool_weights=None,                   # equal weights in each pool
            pool_gaps=None,                      # supply if using 'best_pool'
            within_rule=within_pool_rule         # "neyman" / "minimax" / "uniform"
        )
        D_wave = create_assignments_from_alloc(alloc_policy).reshape(-1, 1)

    # ---------------------------
    # Generate outcomes this wave
    # ---------------------------
    X_wave, y_wave = generate_data_from_assignments(D_wave, all_policies, oracle_outcomes, sig=sig)

    # Accumulate full data so far
    D = np.vstack([D, D_wave])
    y = np.vstack([y, y_wave])

    # ---------------------------
    # Policy stats (means/vars)
    # ---------------------------
    policy_stats = compute_policy_means(D, y, num_policies)
    policy_variances, policy_counts = compute_policy_variances(D, y, num_policies)
    policy_sigmas = np.sqrt(policy_variances)

    # ---------------------------
    # Enumerate Rashomon set for this wave (coarse-to-fine theta)
    # ---------------------------
    # theta and theta_step are initialised only in wave 1; later waves resume
    # from wherever the previous wave's sweep ended.
    if wave_number == 1:
        theta = theta_init
        theta_step = theta_init_step

    max_steps = 50   # cap on RAggregate calls per sweep
    num_sweeps = 3   # sweep 0 finds any non-empty set; later sweeps target min_rset_size
    for sweep in range(num_sweeps):
        steps = 0
        while steps < max_steps:
            if verbose:
                print(f"Trying theta: {theta:.4f}")
            R_set, R_profiles = RAggregate(
                M, R, H, D, y, theta,
                reg=lambda_reg, num_workers=num_workers, verbose=False
            )
            if verbose:
                print(f"Theta: {theta:.4f} -- RPS size: {len(R_set)}")

            rps_size = len(R_set)
            is_last_sweep = sweep == num_sweeps - 1
            found_nonempty = sweep == 0 and rps_size > 0
            large_enough = rps_size >= min_rset_size

            # coarse-to-fine stepping around smallest theta with non-empty set:
            # back off one step and continue the search with a finer step.
            if found_nonempty or (large_enough and not is_last_sweep):
                theta -= theta_step
                theta_step = theta_step / (M*2)
                break
            # Final sweep succeeded: keep theta (it matches R_set) and undo the
            # step refinements so the next wave starts with the coarse step.
            if large_enough and is_last_sweep:
                theta_step = theta_step * ((M*2)**(num_sweeps-1))
                break

            theta += theta_step
            steps += 1
        # NOTE(review): if a sweep exhausts max_steps without breaking, theta
        # has been advanced one step past the value used for the last R_set.

    if len(R_set) == 0:
        # Nothing downstream (MAP selection, pools for the next wave, plots)
        # can be computed from an empty set, so stop with an explicit error
        # rather than failing later on an argmin over an empty array.
        raise RuntimeError(
            f"Wave {wave_number}: no feasible Rashomon set found within the theta range "
            f"searched (theta ended at {theta:.4f}); cannot compute posterior metrics "
            "or pools for the next wave."
        )
    if verbose:
        print(f"End theta: {theta:.4f}, RPS size: {len(R_set)}")
        print(f"Wave {wave_number} Rashomon set: {len(R_set)} feasible global partitions (combinations of per-profile poolings).")

    # ---------------------------
    # Posterior + metrics & cache pools for next wave
    # ---------------------------
    num_partitions = len(R_set)

    regrets = []
    best_pred_indices_all = []
    policy_indices_all = []
    policy_means_all = []

    sorted_idx_all = []
    sorted_means_all = []

    partition_losses = np.zeros(num_partitions)

    posterior_mse = []
    posterior_best_mse = []
    posterior_iou = []

    pi_policies_r_list = []
    pi_pools_r_list = []       # pools of every partition, so the MAP partition's pools can be retrieved
    pool_means_r_list = []

    # Loop-invariant reference values for the per-partition metrics.
    best_oracle_value = oracle_outcomes[oracle_rank_to_policy[0]]
    oracle_top_k_set = set(top_k_indices)

    # track pools for each profile k, for each rashomon set r
    for r, partition_r in enumerate(R_set):
        pi_policies_profiles_r = {}

        for k, profile in enumerate(profiles):
            sigma_k = R_profiles[k].sigma[partition_r[k]]
            if sigma_k is None:
                # entire profile is a single pool
                n_policies_profile = len(policies_profiles_masked[k])
                pi_policies_r_k = {i: 0 for i in range(n_policies_profile)}
            else:
                _, pi_policies_r_k = extract_pools(policies_profiles_masked[k], sigma_k)

            pi_policies_profiles_r[k] = pi_policies_r_k

        # aggregate into global partition structures
        pi_pools_r, pi_policies_r = aggregate_pools(pi_policies_profiles_r, policies_ids_profiles)
        pool_means_r = compute_pool_means(policy_stats, pi_pools_r)

        pi_pools_r_list.append(pi_pools_r)
        pi_policies_r_list.append(pi_policies_r)
        pool_means_r_list.append(pool_means_r)

        # Partition loss: sum of the chosen per-profile losses
        partition_losses[r] = sum(R_profiles[k].loss[choice] for k, choice in enumerate(partition_r))

        # Predictions by policy: each policy is predicted by the mean of its pool
        policy_indices = np.array(list(pi_policies_r.keys()))
        policy_means = np.array([pool_means_r[pi_policies_r[idx]] for idx in policy_indices])
        order = np.argsort(-policy_means)
        sorted_idx = policy_indices[order]
        sorted_means = policy_means[order]

        # Store results
        policy_indices_all.append(policy_indices)
        policy_means_all.append(policy_means)
        sorted_idx_all.append(sorted_idx)
        sorted_means_all.append(sorted_means)

        # Regret: oracle value lost by picking this partition's top-predicted policy
        best_pred = sorted_idx[0]
        regret = float(best_oracle_value - oracle_outcomes[best_pred])
        regrets.append(regret)
        best_pred_indices_all.append(best_pred)

        # Posterior-weighted metrics
        y_r_est = make_predictions(D, pi_policies_r, pool_means_r)

        mse = np.mean((y_r_est - y) ** 2)  # mse on outcomes
        best_mse = regret ** 2             # squared gap between the oracle best and the predicted best
        predicted_top_k = set(sorted_idx[:top_k])
        iou = len(predicted_top_k & oracle_top_k_set) / len(predicted_top_k | oracle_top_k_set)

        posterior_mse.append(mse)
        posterior_best_mse.append(best_mse)
        posterior_iou.append(iou)

    # Posterior weights and expected metrics.
    # Weights are proportional to exp(-loss). Shifting by the minimum loss
    # leaves the normalised weights unchanged but keeps the largest weight at
    # exactly 1, so large losses cannot underflow every weight to zero.
    posterior_weights = np.exp(-(partition_losses - partition_losses.min()))
    weight_total = posterior_weights.sum()
    if weight_total > 0:
        posterior_weights /= weight_total
    map_idx = int(np.argmin(partition_losses))   # MAP partition = smallest loss
    map_loss = float(partition_losses[map_idx])

    policy_indices = policy_indices_all[map_idx]
    policy_means = policy_means_all[map_idx]
    order = np.argsort(-policy_means)

    pi_policies_r = pi_policies_r_list[map_idx]
    pi_pools_r = pi_pools_r_list[map_idx]
    pool_means_r = pool_means_r_list[map_idx]

    # pi_pools_r is read as a mapping (pool id -> iterable of policy indices,
    # presumably global indices into all_policies); normalise it to plain
    # int keys and list values for the next wave's allocation.
    pools_for_next_wave = {int(g): list(v) for g, v in pi_pools_r.items()}

    sorted_idx = sorted_idx_all[map_idx]
    sorted_means = sorted_means_all[map_idx]
    oracle_values = oracle_outcomes[sorted_idx]
    oracle_ranks = oracle_policy_to_rank[sorted_idx]
    is_topk = [i in top_k_indices for i in sorted_idx]

    expected_mse = float(np.dot(posterior_weights, posterior_mse))
    expected_best_mse = float(np.dot(posterior_weights, posterior_best_mse))
    expected_iou = float(np.dot(posterior_weights, posterior_iou))
    expected_regret = float(np.dot(posterior_weights, regrets))

    metrics_per_wave.append({
        "wave": wave_number,
        "theta": float(theta),
        "rps_size": int(len(R_set)),
        "expected_regret": expected_regret,
        "expected_mse": expected_mse,
        "expected_best_mse": expected_best_mse,
        "expected_iou": expected_iou,
        "map_loss": map_loss
    })

    if verbose:
        # MAP Summary and regret plots --
        # NOTE(review): this call passes the full oracle_policy_to_rank array as
        # oracle_ranks, while the "observed" call below passes oracle_ranks
        # (already indexed by sorted_idx). Confirm which form each plot expects.
        df_map = visualizations.plot_map_true_vs_predicted_bar_topk(
            sorted_idx=sorted_idx,
            sorted_means=sorted_means,
            oracle_beta=oracle_outcomes,
            oracle_ranks=oracle_policy_to_rank,
            top_k_indices=top_k_indices,
            N=n_preview
        )
        visualizations.plot_map_regret_bar(sorted_idx, oracle_outcomes, oracle_rank_to_policy[0], N=n_preview)
        visualizations.plot_map_regression(df_map)
        visualizations.plot_oracle_ordered_bar(df_map, top_k_indices, oracle_outcomes, all_policies, N=n_preview)

        df_map_obs = visualizations.plot_map_true_vs_predicted_bar_observed(
            sorted_idx, sorted_means, oracle_outcomes, oracle_ranks,
            np.where(policy_stats[:, 1] > 0)[0],   # policies whose policy_stats column 1 is positive (presumably the observation count — confirm)
            N=num_policies
        )
        visualizations.plot_oracle_ordered_bar(df_map_obs, top_k_indices, oracle_outcomes, all_policies, N=num_policies)
        visualizations.plot_map_regression_observed(df_map_obs)

        profile_losses = [rp.loss for rp in R_profiles]
        visualizations.plot_minimax_risk_matrix(profile_losses, map_idx=map_idx)
        print(f"Number of possible partitions per profile: {[len(p) for p in profile_losses]}")

        # (A) Pool sizes
        visualizations.plot_pool_size_bar(pools_for_next_wave)

        # (B) Pool SE bars
        visualizations.plot_pool_se_bar(pools_for_next_wave, policy_sigmas, policy_counts)

        # (C) Pool mean vs SE proxy (use MAP pool means you already computed)
        visualizations.plot_pool_mean_vs_se(pool_means_r, pools_for_next_wave, policy_sigmas, policy_counts)