# Two-wave RPS Algorithm

In [4]:
import numpy as np
from rashomon.hasse import enumerate_policies
from rashomon.aggregate import RAggregate
from first_wave import compute_boundary_probs, allocate_wave, assign_first_wave_treatments
from data_gen import get_beta_underlying_causal, generate_outcomes

## 1. First-wave allocation

In [5]:
# --- Design constants for the lattice and the first wave ---
M = 4      # number of features
R = 3      # levels per feature; a length-M sequence is also accepted below
H = 5      # sparsity parameter used inside compute_boundary_probs (TODO: justify this choice)
n1 = 500   # total first-wave sample size

# Expand R into a per-feature vector so heterogeneous level counts are allowed.
if np.isscalar(R):
    R_vec = np.full(M, R)
else:
    R_vec = np.array(R)
assert R_vec.shape == (M,)

# Enumerate the lattice; K is the number of policies on it.
policies = enumerate_policies(M, R)
K = len(policies)
print(f"Found K = {K} policies (each policy is an {M}-tuple).")

Found K = 81 policies (each policy is an 4-tuple).


**Compute first‐wave allocation**: We need R_i for each feature (here R_i = R for i=0,…,M-1), then we call `compute_boundary_probs` -> `allocate_wave`. We get `n1_alloc`: an array of length K summing to n1.

In [6]:
# Boundary probabilities for the lattice (presumably one per policy — confirm in
# first_wave.compute_boundary_probs), computed with sparsity parameter H.
boundary_probs = compute_boundary_probs(policies, R, H)
# Split the first-wave budget n1 according to those probabilities.
n1_alloc = allocate_wave(boundary_probs, n1)
# Sanity check: the allocation should use exactly the first-wave budget.
print(f"First‐wave allocation sums to {int(n1_alloc.sum())} (should be {n1}).")

First‐wave allocation sums to 500 (should be 500).


## 2. Simulating first-wave outcomes

We generate a np.array `beta` of true effects for each node. We pass our lattice `policies`, `M` and `R`, and then specify a `kind` of underlying causal model.

There is a range of options, all of which are continuous and non-trivial: they exhibit locally correlated effects and avoid brittle cancellations in effects. The options range from simple (polynomial, Gaussian, basic interaction) to complex (radial basis function, a mimic of a simple neural-net-like function).

In [7]:
# True effects for the lattice nodes under the "gauss_sin" underlying causal model.
beta = get_beta_underlying_causal(policies, M, R, kind="gauss_sin")

In [8]:
# Not in use: different distribution for each true pool from a random 'true' partition sigma_true. Not used in this simulation due to our specifications on the underlying causal model (e.g. continuous, locally correlated effects, etc). Also needs changes on how it constructs a true partition.

# partition_seed = 123
# sigma_true, pi_pools_true, pi_policies_true = generate_true_partition(policies, R,random_seed=partition_seed)
# beta = get_beta_piecewise(policies, sigma_true, pi_pools_true, pi_policies_true, 0.5, 1, 10)

**Get outcomes**: we now track the first-wave assignment and generate the outcomes with additional noise

In [9]:
# Build the first-wave assignment vector D1.
# NOTE: `policies` is rebound here to a NumPy array of the enumerated policies.
policies = np.array(enumerate_policies(M, R))  # (K, M)
D1 = assign_first_wave_treatments(n1_alloc)  # (N1,) — the recorded output shows shape (500,); used downstream as policy indices
print("D1 shape:", D1.shape)
N1 = D1.shape[0]  # number of first-wave observations
print("Length of D1:", N1)  # should equal sum n1_alloc == n1

D1 shape: (500,)
Length of D1: 500


In [10]:
# Generate first-wave outcomes y1 from the assignments D1 and the true effects beta.
sigma_noise = 5  # noise level passed to generate_outcomes
outcome_seed = 53  # passed as random_seed (presumably fixes the noise draw — confirm in data_gen)
y1 = generate_outcomes(D=D1, beta=beta, sigma_noise=sigma_noise, random_seed=outcome_seed)
# Quick summary of the simulated outcomes.
print("Overall mean outcome:", np.mean(y1))
print("Overall std outcome:", np.std(y1))

Overall mean outcome: -0.06186879481060036
Overall std outcome: 5.345327463311591


## 3. RPS for profiles with data

We now search for the optimal theta, as given by a normalized loss and the chosen epsilon. H and the regularization parameter must already be specified at this point.

In [11]:
lambda_r = 0.3  # regularization parameter, passed as `reg` to the RPS routines below
eps = 0.05 # chosen tolerance: the global threshold is (1 + eps) times the summed lower bounds

In [12]:
import numpy as np
from rashomon.hasse import enumerate_policies, enumerate_profiles, policy_to_profile
from rashomon.aggregate import (
    RAggregate_profile,
    subset_data,
    find_profile_lower_bound
)
from rashomon import loss

In [13]:
# Step 1: Enumerate the profiles and group every policy under its profile.
profiles, profile_map = enumerate_profiles(M)
all_policies = enumerate_policies(M, R_vec)

# Both dicts are keyed by profile id and filled in lockstep:
#   profile_to_policies[pid] -> the policy tuples belonging to that profile
#   profile_to_indices[pid]  -> their positions in all_policies
profile_to_policies, profile_to_indices = {}, {}
for i, pol in enumerate(all_policies):
    pid = profile_map[policy_to_profile(pol)]
    if pid not in profile_to_policies:
        profile_to_policies[pid] = []
        profile_to_indices[pid] = []
    profile_to_policies[pid].append(pol)
    profile_to_indices[pid].append(i)

We now keep only the profiles that have at least one observation.

In [14]:
# Step 2: Filter profiles with data and compute normalized lower-bound losses.
valid_pids = []  # ids of the profiles that have at least one first-wave observation
lb_k = []        # normalized lower-bound loss per valid profile (same order as valid_pids)

for pid in range(len(profiles)):
    Dk, yk = subset_data(D1, y1, profile_to_indices[pid])
    if Dk is None:
        # No observations fall in this profile, so it is skipped.
        continue
    # Only the number of policies in the profile is needed here, so the reduced
    # (profile-local) policy tuples are not materialised.
    pm = loss.compute_policy_means(Dk, yk, len(profile_to_policies[pid]))
    raw_lb = find_profile_lower_bound(Dk, yk, pm)
    lb_k.append(raw_lb / N1)  # normalize by the total first-wave sample size
    valid_pids.append(pid)

if not valid_pids:
    # Fail with a clear message instead of NumPy's empty-reduction error on .min().
    raise ValueError("No profile has any first-wave data; cannot compute lower bounds.")

lb_k = np.array(lb_k)                   # array of normalized lower bounds
best_loss = lb_k.min()                  # smallest per-profile lower bound
total_lb = lb_k.sum()
theta_global = total_lb * (1 + eps)     # threshold on the total loss, not a relative value
print(f"best_loss = {best_loss:.5f}")
print(f"theta_global = {theta_global:.5f}")

best_loss = 0.09952
theta_global = 25.23455


We now construct the RPS for each profile with data from our first allocation.

In [15]:
# Build one Rashomon set per profile that has data. Only nonempty sets are kept,
# together with the arguments needed later to evaluate their losses.
R_profiles = []
loss_args = []

for i, pid in enumerate(valid_pids):
    # Boolean mask of the features selected by this profile.
    profile_mask = np.array(profiles[pid], dtype=bool)
    M_k = profile_mask.sum()
    R_k = R_vec[profile_mask]

    Dk, yk = subset_data(D1, y1, profile_to_indices[pid])
    # Policies of this profile restricted to the features selected by the mask.
    reduced_policies = [tuple(np.array(p)[profile_mask]) for p in profile_to_policies[pid]]
    pm = loss.compute_policy_means(Dk, yk, len(reduced_policies))

    # Loss budget for this profile: the global threshold minus the lower bounds of
    # all other profiles, floored at zero. lb_k is indexed by position in valid_pids.
    theta_k = max(0.0, theta_global - (total_lb - lb_k[i]))

    print(f"Calling RAggregate_profile on profile {pid}, M_k={M_k}, len(policies)={len(reduced_policies)}, theta_k={theta_k:.5f}")
    print(f": lower_bound: {lb_k[i]:.5f}, theta_k: {theta_k:.5f}")

    rp = RAggregate_profile(
        M=M_k,
        R=R_k,
        H=H,
        D=Dk,
        y=yk,
        theta=theta_k,
        profile=tuple(profiles[pid]),
        reg=lambda_r,
        policies=reduced_policies,
        policy_means=pm,
        normalize=N1
    )

    print(f": RPS size for profile {pid}: {len(rp)}")
    # Keep R_profiles and loss_args aligned: both only hold nonempty sets.
    if len(rp) > 0:
        R_profiles.append(rp)
        loss_args.append((Dk, yk, reduced_policies, pm))

Calling RAggregate_profile on profile 1, M_k=1, len(policies)=2, theta_k=1.85380
: lower_bound: 0.65216, theta_k: 1.85380
: RPS size for profile 1: 2
Calling RAggregate_profile on profile 2, M_k=1, len(policies)=2, theta_k=1.55939
: lower_bound: 0.35774, theta_k: 1.55939
: RPS size for profile 2: 2
Calling RAggregate_profile on profile 3, M_k=2, len(policies)=4, theta_k=2.28029
: lower_bound: 1.07864, theta_k: 2.28029
: RPS size for profile 3: 4
Calling RAggregate_profile on profile 4, M_k=1, len(policies)=2, theta_k=1.30117
: lower_bound: 0.09952, theta_k: 1.30117
: RPS size for profile 4: 2
Calling RAggregate_profile on profile 5, M_k=2, len(policies)=4, theta_k=2.28977
: lower_bound: 1.08812, theta_k: 2.28977
: RPS size for profile 5: 4
Calling RAggregate_profile on profile 6, M_k=2, len(policies)=4, theta_k=1.79353
: lower_bound: 0.59188, theta_k: 1.79353
: RPS size for profile 6: 4
Calling RAggregate_profile on profile 7, M_k=3, len(policies)=8, theta_k=3.94933
: lower_bound: 2.74

In [16]:
# Compute loss only for nonempty profile RPSs.
# R_profiles and loss_args were filled in lockstep, so zip pairs each set with its own data.
for rp, (Dk, yk, policies_k, pm_k) in zip(R_profiles, loss_args):
    rp.calculate_loss(Dk, yk, policies_k, pm_k, lambda_r, normalize=N1)

## 4. Construct the full RPS

We demonstrate the creation of the full RPS from the profile-specific partitions.

In [17]:
# Step 1: validate the inputs.
# D must be a flat vector of policy indices.
D = np.asarray(D1)
y = np.asarray(y1)
if D.ndim != 1:
    raise ValueError(f"D should be 1D (policy indices), got shape {D.shape}")
# Flatten y when it is not already 1D.
y = y if y.ndim == 1 else y.ravel()
N = D.shape[0]
if y.shape[0] != N:
    raise ValueError(f"y and D must have same length: got {len(y)} and {N}")

# Enumerate all policies and profiles, then build two lookups in one pass:
#   policy_to_pid[policy tuple] -> profile id
#   policy_to_indices[pid]      -> positions in all_policies of that profile's policies
profiles, profile_map = enumerate_profiles(M)
all_policies = enumerate_policies(M, R_vec)
policy_to_pid = {}
policy_to_indices = {k: [] for k in range(len(profiles))}
for i, pol in enumerate(all_policies):
    pid = profile_map[policy_to_profile(pol)]
    policy_to_pid[tuple(pol)] = pid
    policy_to_indices[pid].append(i)


In [18]:
# Step 2: Identify active profiles and compute lower bounds.
# lb_k gets one entry per profile id (0.0 when unobserved); valid_pids lists the observed ones.
valid_pids = []
lb_k = []
for pid, profile in enumerate(profiles):
    indices = policy_to_indices[pid]
    idx_mask = np.isin(D, indices)
    if not idx_mask.any():
        # No observation falls in this profile: it adds nothing to the total bound.
        lb_k.append(0.0)
        continue

    # Global policy indices and outcomes of the observations in this profile.
    Dk_policyidx = D[idx_mask]
    yk = y[idx_mask]
    profile_mask = np.array(profile, dtype=bool)

    # Reduce each observed policy to the features selected by the profile mask.
    policies_k = []
    for pol_idx in Dk_policyidx:
        policies_k.append(tuple(np.array(all_policies[pol_idx])[profile_mask]))
    policies_k_unique = sorted(set(policies_k))

    # Re-index the observations by position in the sorted list of observed reduced policies.
    tuple_to_local_idx = dict(zip(policies_k_unique, range(len(policies_k_unique))))
    Dk_local = np.array([tuple_to_local_idx[p] for p in policies_k])

    pm = loss.compute_policy_means(Dk_local, yk, len(policies_k_unique))
    raw_lb = find_profile_lower_bound(Dk_local, yk.reshape(-1, 1), pm)
    lb_k.append(raw_lb / N)
    valid_pids.append(pid)

lb_k_arr = np.array(lb_k)
total_lb = lb_k_arr.sum()
theta_global = total_lb * (1 + eps)
print(f"Observed profiles: {valid_pids}")
print(f"Threshold: {theta_global:.5f}")
print(f"Per-profile lower bounds: {lb_k_arr}")

Observed profiles: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Threshold: 25.23455
Per-profile lower bounds: [0.         0.65215607 0.35774498 1.07864121 0.09952396 1.08812371
 0.5918832  2.74768663 0.2035705  0.89286711 0.989239   1.65771603
 1.06501584 1.77526929 2.68874073 8.14472582]


In [19]:
# Step 3: Build per-profile RashomonSets (always using local indices!)
R_profiles = []
for pid in valid_pids:
    # Observations whose policy belongs to this profile.
    indices = policy_to_indices[pid]
    idx_mask = np.isin(D, indices)
    Dk_policyidx = D[idx_mask]
    yk = y[idx_mask]

    # Restrict to the features selected by the profile.
    profile_mask = np.array(profiles[pid], dtype=bool)
    M_k = profile_mask.sum()
    R_k = R_vec[profile_mask]

    # Reduced policies observed in this profile, re-indexed by their sorted position.
    policies_k = [tuple(np.array(all_policies[pol_idx])[profile_mask]) for pol_idx in Dk_policyidx]
    policies_k_unique = list(sorted(set(policies_k)))
    tuple_to_local_idx = {p: j for j, p in enumerate(policies_k_unique)}
    Dk_local = np.array([tuple_to_local_idx[p] for p in policies_k])
    pm = loss.compute_policy_means(Dk_local, yk, len(policies_k_unique))

    # Loss budget for this profile: the global threshold minus the lower bounds of
    # all other profiles, floored at zero. lb_k_arr is indexed by profile id. It is
    # used instead of `lb_k` because an earlier cell binds `lb_k` to an array that
    # is indexed by position in valid_pids, which would give wrong budgets here.
    theta_k = max(0.0, theta_global - (total_lb - lb_k_arr[pid]))
    print(f"Profile {pid}: M_k={M_k}, #policies={len(policies_k_unique)}, theta_k={theta_k:.5f}")
    rp = RAggregate_profile(
        M=M_k,
        R=R_k,
        H=H,
        D=Dk_local.reshape(-1, 1),          # (n_k, 1) column of local policy indices
        y=yk.reshape(-1, 1),                # (n_k, 1) column of outcomes
        theta=theta_k,
        profile=tuple(profiles[pid]),
        reg=lambda_r,
        policies=policies_k_unique,
        policy_means=pm,
        normalize=N
    )
    print(f": RPS size: {len(rp)}")
    # Empty sets are kept too, so R_profiles stays aligned with valid_pids.
    R_profiles.append(rp)

Profile 1: M_k=1, #policies=1, theta_k=1.85380
: RPS size: 2
Profile 2: M_k=1, #policies=1, theta_k=1.55939
: RPS size: 2
Profile 3: M_k=2, #policies=3, theta_k=2.28029
: RPS size: 4
Profile 4: M_k=1, #policies=1, theta_k=1.30117
: RPS size: 2
Profile 5: M_k=2, #policies=3, theta_k=2.28977
: RPS size: 4
Profile 6: M_k=2, #policies=3, theta_k=1.79353
: RPS size: 4
Profile 7: M_k=3, #policies=7, theta_k=3.94933
: RPS size: 0
Profile 8: M_k=1, #policies=1, theta_k=1.40522
: RPS size: 2
Profile 9: M_k=2, #policies=3, theta_k=2.09451
: RPS size: 4
Profile 10: M_k=2, #policies=3, theta_k=2.19088
: RPS size: 4
Profile 11: M_k=3, #policies=7, theta_k=2.85936
: RPS size: 0
Profile 12: M_k=2, #policies=3, theta_k=2.26666
: RPS size: 4
Profile 13: M_k=3, #policies=7, theta_k=2.97691
: RPS size: 0
Profile 14: M_k=3, #policies=7, theta_k=3.89039
: RPS size: 0
Profile 15: M_k=4, #policies=15, theta_k=9.34637
: RPS size: 0


In [20]:
from itertools import product

# Step 4: Cross-product of valid RashomonSets (partial Rashomon set)
# nonempty_idx holds positions within R_profiles (not profile ids) of the nonempty sets.
nonempty_idx = [pos for pos in range(len(R_profiles)) if len(R_profiles[pos]) > 0]
nonempty_profiles = [R_profiles[pos] for pos in nonempty_idx]

# Each element of R_set_partial picks one member index from every nonempty per-profile set.
R_set_partial = list(product(*map(range, map(len, nonempty_profiles))))

print(f"Found {len(R_set_partial)} Rashomon sets across {len(nonempty_profiles)} nonempty profiles.")
print(f"Nonempty profile indices: {nonempty_idx}")
print(f"Total profiles (partitions): {len(R_profiles)}")
print(f"Profiles with nonempty RPS: {len(nonempty_profiles)}")


Found 65536 Rashomon sets across 10 nonempty profiles.
Nonempty profile indices: [0, 1, 2, 3, 4, 5, 7, 8, 9, 11]
Total profiles (partitions): 15
Profiles with nonempty RPS: 10


We can now set this up in a larger wrapper function too:

In [22]:
from observed_RPS import observed_rps

# Single wrapper call that returns the partial Rashomon set, the per-profile sets,
# the nonempty positions and the profiles. The notebook's `eps` is passed rather than
# a re-typed literal, so the wrapper follows the tolerance configured earlier.
R_set_partial, R_profiles, nonempty_idx, profiles = observed_rps(
    M, R_vec, H, D1, y1, lambda_r, eps=eps
)

Profile 1: M_k=1, #policies=1, theta_k=1.85380
Profile 2: M_k=1, #policies=1, theta_k=1.55939
Profile 3: M_k=2, #policies=3, theta_k=2.28029
Profile 4: M_k=1, #policies=1, theta_k=1.30117
Profile 5: M_k=2, #policies=3, theta_k=2.28977
Profile 6: M_k=2, #policies=3, theta_k=1.79353
Profile 7: M_k=3, #policies=7, theta_k=3.94933
Profile 8: M_k=1, #policies=1, theta_k=1.40522
Profile 9: M_k=2, #policies=3, theta_k=2.09451
Profile 10: M_k=2, #policies=3, theta_k=2.19088
Profile 11: M_k=3, #policies=7, theta_k=2.85936
Profile 12: M_k=2, #policies=3, theta_k=2.26666
Profile 13: M_k=3, #policies=7, theta_k=2.97691
Profile 14: M_k=3, #policies=7, theta_k=3.89039
Profile 15: M_k=4, #policies=15, theta_k=9.34637


We demonstrate a main wrapper call, as we would use from the original Rashomon module, but we note that the function operates unexpectedly because we don't have data for a number of the profiles. (We end up with an empty Rashomon Partition Set).

In [None]:
# Doesn't work!! Because only have data on a small subspace, and quits out when we can't make a pooling decision
# Call main RAggregate function
# NOTE(review): running this cell rebinds `R_profiles`, replacing the per-profile sets
# returned by observed_rps; re-run that cell afterwards if they are still needed.
R_set, R_profiles = RAggregate(
    M=M,
    R=R_vec,
    H=H,
    D=D1,
    y=y1,
    theta=theta_global,
    reg=lambda_r,
    verbose=True,
)

## 5. Second-wave allocation