# ALGORITHM 1: TWO-WAVE

In [1]:
import numpy as np
from rashomon.hasse import policy_to_profile, enumerate_policies, enumerate_profiles
from datagen import phi, generate_data_from_assignments
from boundary import (
    compute_boundary_probabilities, get_allocations,
    create_assignments_from_alloc, get_policy_neighbors, compute_global_boundary_matrix
)
from helpers_rps import (
    subset_wave_data_by_profile, compute_profile_policy_outcomes,
    build_global_wave_data, get_observed_profiles
)
from enumerate_rps import construct_RPS_adaptive
from evaluation import get_partition_losses

## Step 1: Enumerate policies and profiles and build index mappings

In [14]:
# --- Experiment configuration ---
M = 3  # number of features
R = np.array([4, 3, 3])  # levels per feature
lambda_reg = 0.1  # regularization parameter (passed as `reg` to construct_RPS_adaptive)
epsilon = 0.05  # tolerance off MAP; NOTE(review): not referenced in the visible part of this script
n1 = 100  # units for first wave
n2 = 100  # units for second wave
theta_start = 0.7  # NOTE(review): not referenced in the visible part of this script
verbose = False  # when True, each step prints diagnostics

# Enumerate all policies and profiles.
all_policies = enumerate_policies(M, R)
num_policies = len(all_policies)
profiles, profile_map = enumerate_profiles(M)
num_profiles = len(profiles)

# Profile of every policy, computed once up front. The grouping loop below
# compares each policy against each profile, so calling policy_to_profile
# inside it would repeat the same work num_profiles times per policy.
_policy_profile_lookup = [policy_to_profile(p) for p in all_policies]

# Profile index mappings, both keyed by the profile's position k in `profiles`:
#   policies_ids_profiles[k] -> global indices (into all_policies) of the policies in profile k
#   policies_profiles[k]     -> the corresponding policy objects, in the same order
policies_profiles = {}
policies_ids_profiles = {}
for k, profile in enumerate(profiles):
    policy_indices = [
        i for i, prof in enumerate(_policy_profile_lookup) if prof == profile
    ]
    policies_ids_profiles[k] = policy_indices
    policies_profiles[k] = [all_policies[i] for i in policy_indices]

# Pool-size bookkeeping: how many policies each profile contains, and the
# largest such count across profiles.
largest_profile = {k: len(members) for k, members in policies_profiles.items()}
max_pool_size = max(largest_profile.values())
if verbose:
    print(f"Profile-wise max pools: {largest_profile}")
    print(f"Max possible pool size for a profile: {max_pool_size}")

# Conservative choice: let H equal the largest possible pool size.
H = max_pool_size

# Per-profile "masked" policies: every policy of profile k reduced to the
# entries of the features that are active (truthy) in that profile. The
# result is keyed by profile index and is handed to the RPS construction and
# boundary-matrix helpers further down.
policies_profiles_masked = {}
for k, profile in enumerate(profiles):
    profile_mask = [bool(v) for v in profile]  # True where the feature is active
    active_features = [i for i in range(M) if profile_mask[i]]
    policies_profiles_masked[k] = [
        tuple(pol[i] for i in active_features) for pol in policies_profiles[k]
    ]

## Step 2: Get true outcomes and top arms

In [15]:
# True outcome of every policy, in all_policies order.
oracle_outcomes = np.array([phi(candidate) for candidate in all_policies])

# rank -> policy index: position 0 holds the index of the policy with the
# largest true outcome, position 1 the runner-up, and so on.
oracle_rank_to_policy = np.argsort(-oracle_outcomes)

# policy index -> 1-based oracle rank (entry i is the rank of policy i, 1 = best).
# This is the inverse permutation of oracle_rank_to_policy, shifted to start at 1.
oracle_policy_to_rank = np.argsort(oracle_rank_to_policy) + 1

# Top-k indices, values, and policies
top_k = 10
top_k_indices = oracle_rank_to_policy[:top_k]
top_k_values = oracle_outcomes[top_k_indices]
top_k_policies = [all_policies[i] for i in top_k_indices]

# Optional overview: rank, global index, and profile of each top policy.
if verbose:
    print("Top-k best policies and their profiles:")
    for rank, idx in enumerate(top_k_indices, start=1):
        policy = all_policies[idx]
        profile = policy_to_profile(policy)
        print(f"Rank {rank}: Policy idx {idx}, Policy {[int(i) for i in policy]}, Profile {profile}")

## Step 3: Get boundary probabilities and generate from assignments

In [16]:
# Wave 1: allocate n1 units across policies from the prior (theoretical)
# boundary probabilities, then simulate outcomes for those assignments.

# compute theoretical boundary probabilities from the policies, levels R, and pool bound H
boundary_probs = compute_boundary_probabilities(all_policies, R, H)

# integer allocation across all policies; expected to sum to n1 (checked in the verbose print)
alloc1 = get_allocations(boundary_probs, n1)
if verbose: print(f"Total allocated: {alloc1.sum()} (should be {n1})")

# generate assignments for wave 1
D1 = create_assignments_from_alloc(alloc1)  # presumably shape (n1, 1) of policy indices — TODO confirm
# NOTE(review): sig=1.0 is presumably the outcome noise scale — confirm against datagen
X1, y1 = generate_data_from_assignments(D1, all_policies, oracle_outcomes, sig=1.0)

if verbose:
    # first ten assigned policy indices, plus a size check against the allocation
    print(f"Wave 1 assignments (policy indices): {D1[:10].flatten()}")
    print(f"Total n_1: {len(D1)} (should match allocation sum: {alloc1.sum()})")

In [17]:
# Profiles that received at least one wave-1 assignment, together with the
# policies observed inside each of them.
observed_policies_per_profile, observed_profiles = get_observed_profiles(D1, all_policies)

# Largest pool size attainable when only observed policies are considered.
max_observed_pool_size = max(len(policies) for policies in observed_policies_per_profile.values())

# Coverage report for the upcoming RPS enumeration.
if verbose:
    # Count policies via .values(): iterating the dict directly would yield its
    # keys, so len() would measure the keys instead of the per-profile policy lists.
    num_observed_policies = sum(
        len(policies) for policies in observed_policies_per_profile.values()
    )
    print(f"Number of observed profiles out of total: {len(observed_profiles)} out of {num_profiles}")
    print(f"Number of observed policies out of total: {num_observed_policies} out of {num_policies}")
    print(f"The maximum possible pool size using just observed policies is now {max_observed_pool_size}.")

    print("\nAre all top-k best profiles observed?")
    for rank, idx in enumerate(top_k_indices, 1):
        policy = all_policies[idx]
        prof = policy_to_profile(policy)
        print(f"Best Policy {idx}: Profile {prof}, Observed? {prof in observed_profiles}")

## Step 4: Enumerate RPS with first-wave data

In [18]:
# Split the wave-1 assignments and outcomes by profile.
D1_profiles, y1_profiles, global_to_local1 = subset_wave_data_by_profile(
    D1, y1, policies_ids_profiles
)

# Per-profile outcome summaries for each policy.
profile_policy_outcomes1 = compute_profile_policy_outcomes(
    D1_profiles, y1_profiles, policies_profiles
)

# Report the largest observed mean within each profile.
if verbose:
    for k in profile_policy_outcomes1:
        pm = profile_policy_outcomes1[k]
        # presumably column 0 holds outcome sums and column 1 counts — TODO confirm;
        # the denominator is floored at 1 so unobserved policies do not divide by zero
        numerators = pm[:, 0]
        denominators = np.maximum(pm[:, 1], 1)
        means = numerators / denominators
        print(f"Profile {k}: Max observed mean = {means.max():.3f}")

# Reassemble the per-profile pieces into a single global data set.
D1_full, y1_full = build_global_wave_data(
    D1_profiles, y1_profiles, policies_ids_profiles
)

In [19]:
# Threshold handed to the RPS construction for both waves.
theta_global = 5

# Enumerate the Rashomon partition set (RPS) from the first-wave data only.
(
    R_set,
    R_profiles,
    theta_final,
    found_best,
    theta_trace,
    rps_size_trace,
) = construct_RPS_adaptive(
    M,
    R,
    H,
    D1_full,
    y1_full,
    top_k,
    policies_profiles_masked,
    policies_ids_profiles,
    profiles,
    all_policies,
    top_k_indices,
    theta_global,
    reg=lambda_reg,
    adaptive=False,
    verbose=verbose,
    recovery_type="arm",
)
if verbose:
    print(f"First-wave Rashomon set: {len(R_set)} feasible global partitions (combinations of per-profile poolings).")
    for k, rprof in enumerate(R_profiles):
        print(f"Profile {k}: {len(rprof)} poolings in RPS (if observed)")

In [20]:
# Losses and posterior weights of the partitions in the first-wave RPS.
partition_losses, posterior_weights = get_partition_losses(R_set, R_profiles)

# The MAP partition is the one with the smallest loss.
map_idx = np.argmin(partition_losses)
map_loss = partition_losses[map_idx]

if verbose:
    print(
        f"MAP partition loss: {map_loss:.6f}",
        f"Theta used for enumeration: {theta_global:.6f}",
        sep="\n",
    )

## Step 5: Calculate new boundary probabilities and generate next wave

In [21]:
# Neighbor structure over all policies, used to decide which policies sit on
# a pool boundary.
neighbors = get_policy_neighbors(all_policies)

# TODO switch around so also deals with being neighbors with someone outside of your profile
boundary_matrix_1 = compute_global_boundary_matrix(
    R_set,
    R_profiles,
    neighbors,
    profiles,
    policies_profiles_masked,
    policies_ids_profiles,
    all_policies,
)

# Collapse boundary counts into a 0/1 indicator (1.0 where the count is positive).
on_boundary = boundary_matrix_1 > 0
binary_boundary_matrix_1 = on_boundary.astype(float)

# Posterior-weighted average of the indicators along axis 0
# (presumably one row per partition in R_set — TODO confirm), rounded to
# 8 decimals to suppress small numerical errors.
partition_losses, posterior_weights_1 = get_partition_losses(R_set, R_profiles)
posterior_boundary_probs_1 = np.round(
    np.average(binary_boundary_matrix_1, axis=0, weights=posterior_weights_1),
    decimals=8,
)

In [22]:
# Sanity checks on the posterior boundary probabilities (only run when verbose).
if verbose:
    # 1. One entry per policy
    expected_shape = (len(all_policies),)
    assert posterior_boundary_probs_1.shape == expected_shape, "Posterior vector shape mismatch"
    # 2. Report the observed range
    print("Posterior boundary min/max:", posterior_boundary_probs_1.min(), posterior_boundary_probs_1.max())
    # 3. Every entry must be a valid probability
    within_unit_interval = (posterior_boundary_probs_1 >= 0) & (posterior_boundary_probs_1 <= 1)
    assert np.all(within_unit_interval), "Probabilities out of bounds"

In [23]:
# Wave 2: allocate n2 units across policies from the posterior boundary
# probabilities computed from the wave-1 RPS, then simulate outcomes.
# get allocations and create assignments, then generate data, for the next wave
alloc2 = get_allocations(posterior_boundary_probs_1, n2)
D2 = create_assignments_from_alloc(alloc2)
# same true outcomes and sig=1.0 as in wave 1
X2, y2 = generate_data_from_assignments(D2, all_policies, oracle_outcomes, sig=1.0)

if verbose:
    # first ten assigned policy indices, plus a size check against the allocation
    print(f"Second-wave assignments (policy indices): {D2[:10].flatten()}")
    print(f"Total n_2: {len(D2)} (should match allocation sum: {alloc2.sum()})")

## Step 6: Construct updated RPS

In [24]:
# Stack both waves row-wise (wave 1 first, then wave 2).
# NOTE(review): assumes D and y are 2-D column arrays; vstack would stack 1-D
# inputs as rows instead — confirm against generate_data_from_assignments.
D_total = np.vstack((D1, D2))
y_total = np.vstack((y1, y2))

# Split the pooled data by profile.
D_total_profiles, y_total_profiles, global_to_local_total = subset_wave_data_by_profile(D_total, y_total, policies_ids_profiles)

# Per-profile outcome summaries over both waves.
profile_policy_outcomes_total = compute_profile_policy_outcomes(
    D_total_profiles, y_total_profiles, policies_profiles
)

# Reassemble the per-profile pieces into global data for RPS construction.
D_total_full, y_total_full = build_global_wave_data(
    D_total_profiles, y_total_profiles, policies_ids_profiles
)

In [25]:
# TODO depending on RPS change, may need to redefine theta!
# Re-enumerate the RPS using the data from both waves.
(
    R_set_2,
    R_profiles_2,
    theta_final_2,
    found_best_2,
    theta_trace_2,
    rps_size_trace_2,
) = construct_RPS_adaptive(
    M,
    R,
    H,
    D_total_full,
    y_total_full,
    top_k,
    policies_profiles_masked,
    policies_ids_profiles,
    profiles,
    all_policies,
    top_k_indices,
    theta_global,
    reg=lambda_reg,
    adaptive=False,
    verbose=verbose,
    recovery_type="arm",
)
if verbose:
    print(f"Second-wave Rashomon set: {len(R_set_2)} feasible global partitions (with all data).")
    for k, rprof in enumerate(R_profiles_2):
        print(f"Profile {k}: {len(rprof)} poolings in RPS (if observed)")

# Losses and posterior weights of the updated (two-wave) RPS.
# Note: posterior_weights, map_idx and map_loss are rebound here, replacing
# their first-wave values.
partition_losses2, posterior_weights = get_partition_losses(R_set_2, R_profiles_2)

# The MAP partition is the one with the smallest loss.
map_idx = np.argmin(partition_losses2)
map_loss = partition_losses2[map_idx]

if verbose:
    print(
        f"MAP partition loss: {map_loss:.6f}",
        f"Theta used for enumeration: {theta_final_2:.6f}",
        sep="\n",
    )