# Archive: Multi-wave RPS Algorithm (2+ waves)

In [4]:
import numpy as np
from allocation import compute_initial_boundary_probs, compute_wave_boundary_probs, allocate_wave, assign_treatments
from data_gen import get_beta_underlying_causal, generate_outcomes
from rashomon.hasse import enumerate_policies
from construct_RPS import construct_RPS

### 1. First-wave allocation

In [5]:
# get lattice
M = 4
R = [2,3,3,4]
H = 20
n1 = 500

R_vec = np.full(M, R) if np.isscalar(R) else np.array(R) # allow for heterogeneity in levels
policies = enumerate_policies(M, R)
policies = np.array(enumerate_policies(M, R))
K = len(policies[0])

boundary_probs = compute_initial_boundary_probs(policies, R, H)
n1_alloc = allocate_wave(boundary_probs, n1)

beta = get_beta_underlying_causal(policies, M, R, kind="gauss")

D = assign_treatments(n1_alloc)

sigma_noise = 5
outcome_seed = 53

y = generate_outcomes(D=D, beta=beta, sigma_noise=sigma_noise, random_seed=outcome_seed)

### 2. Construct first-wave RPS

In [6]:
lambda_r = 0.01
eps = 0.05
verbose = True

In [7]:
R_set, R_profiles, theta_global, policies, profiles, profile_to_policies, profile_to_indices, nonempty_profile_ids = construct_RPS(policies, M, R, D, y, H, eps, lambda_r, verbose=False)

print(f"RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

RPS has: 57 feasible partitions over 14 observed profiles.


### 3. Next-wave allocation and simulation

In [5]:
# n2 = 500
#
# boundary_probs_2 = compute_wave_boundary_probs(R_set, R_profiles, policies, profiles, profile_to_policies, profile_to_indices, nonempty_profile_ids)
#
# alloc2 = allocate_wave(boundary_probs_2, n2)
# D2 = assign_treatments(alloc2)
# y2 = generate_outcomes(D=D2, beta=beta, sigma_noise=sigma_noise, random_seed=55)
# D = np.vstack([D1, D2])
# y = np.concatenate([y1, y2])

In [8]:
num_waves = 2
n_per_wave = [500, 500] # includes first wave

for wave in range(2, num_waves+1):
    print(f"--- Wave {wave}: allocation, simulation, RPS update ---")

    # get boundary probabilities
    boundary_probs = compute_wave_boundary_probs(
        R_set, R_profiles, policies, profiles, profile_to_policies, profile_to_indices, nonempty_profile_ids
    )

    # allocate number of observations to each policy
    alloc = allocate_wave(boundary_probs, n_per_wave[wave-1])

    # give each observation its policy number
    D_wave = assign_treatments(alloc)

    # simulate outcomes
    y_wave = generate_outcomes(D=D_wave, beta=beta, sigma_noise=sigma_noise, random_seed=1 + wave)


    # Ensure everything is 1D first
    D = D.reshape(-1)
    D_wave = D_wave.reshape(-1)
    y = y.reshape(-1)
    y_wave = y_wave.reshape(-1)

    # Concatenate
    D = np.concatenate([D, D_wave])
    y = np.concatenate([y, y_wave])

    # Recompute RPS after new data
    R_set, R_profiles, theta_global, policies, profiles, profile_to_policies, profile_to_indices, nonempty_profile_ids = construct_RPS(
        policies, M, R, D, y, H, eps, lambda_r, verbose=False
    )
    # TODO check the recalculation of theta each time

    print(f"After wave {wave}: RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

--- Wave 2: allocation, simulation, RPS update ---
After wave 2: RPS has: 191 feasible partitions over 14 observed profiles.


### 4. Get final RPS predictions


In [7]:
from rashomon.loss import compute_policy_means, compute_pool_means
from rashomon.metrics import make_predictions
from rashomon.extract_pools import extract_pools

#### 4.1. Get MAP partition from RPS

In [8]:
# find the MAP partition from RPS
all_losses = []
for r_set in R_set:
    loss_r = 0.0
    for k, idx in enumerate(r_set):
        if R_profiles[k] is not None:
            loss_r += R_profiles[k].loss[idx]
    all_losses.append(loss_r)
all_losses = np.array(all_losses)
MAP_idx = np.argmin(all_losses)
MAP_r_set = R_set[MAP_idx]
print(f"MAP partition index: {MAP_idx}, loss: {all_losses[MAP_idx]:.4f}")

MAP partition index: 94, loss: 26.0773


#### Policy means, pool means, and predictions

We now get the policy means, pool means, and predictions. We again have to remap the local/global indices as done in our RPS construction (or the original rashomon tutorial).

In [9]:
K = len(policies)
pi_policies_MAP = np.full(K, -1, dtype=int)  # initialize all as unassigned

for k, rp in enumerate(R_profiles):
    profile_id = nonempty_profile_ids[k]
    profile_mask = np.array(profiles[profile_id], dtype=bool)

    # Remap local policies as done in RPS construction
    local_policies = [tuple(np.array(p)[profile_mask]) for p in profile_to_policies[profile_id]]
    if len(local_policies) == 0:
        continue  # skip empty
    arr = np.array(local_policies)
    for j in range(arr.shape[1]):
        _, arr[:, j] = np.unique(arr[:, j], return_inverse=True)
    local_policies_remap = [tuple(row) for row in arr]

    sigma = rp.sigma[MAP_r_set[k]]
    _, pi_policies_local = extract_pools(local_policies_remap, sigma)

    # Map back to global indices
    for local_idx, global_idx in enumerate(profile_to_indices[profile_id]):
        pi_policies_MAP[global_idx] = pi_policies_local[local_idx]

# Now get the set of pools from these assignments
pool_ids = np.unique(pi_policies_MAP[pi_policies_MAP != -1])
pi_pools_MAP = {pool_id: np.where(pi_policies_MAP == pool_id)[0].tolist() for pool_id in pool_ids}

n_pools = len(pi_pools_MAP)
print(f"Final pools in MAP partition: {n_pools}")

# Compute policy means, pool means, predictions
y = y.reshape(-1, 1)
D = D.reshape(-1, 1)
policy_means = compute_policy_means(D, y, len(policies))
pool_means = compute_pool_means(policy_means, pi_pools_MAP)
y_pred = make_predictions(D, pi_policies_MAP, pool_means)

Final pools in MAP partition: 12
