# Two-wave RPS Algorithm

In [1]:
import numpy as np
from rashomon.hasse import enumerate_policies
from allocation import compute_boundary_probs, allocate_wave, assign_treatments
from data_gen import get_beta_underlying_causal, generate_outcomes

## 1. First-wave allocation

In [19]:
# --- Experiment configuration ---
M = 4             # number of features
R = [2, 3, 3, 4]  # number of levels per feature (a scalar is also accepted below)
# H: sparsity parameter. Passed to compute_boundary_probs for the allocation and,
# further down, to RAggregate_profile and find_feasible_combinations.
# TODO: justify this choice.
H = 5
n1 = 500          # total first-wave sample size

# Normalise R to a length-M vector so heterogeneous level counts are supported.
if np.isscalar(R):
    R_vec = np.full(M, R)
else:
    R_vec = np.array(R)
assert R_vec.shape == (M,)

# Enumerate the policy lattice and report its size.
policies = enumerate_policies(M, R)
K = len(policies)
print(f"Found K = {K} policies (each policy is an {M}-tuple).")

Found K = 72 policies (each policy is an 4-tuple).


**Compute first-wave allocation**: We need the number of levels R_i for each feature (here `R = [2, 3, 3, 4]`, so R_i = R[i] for i=0,…,M-1), then we call `compute_boundary_probs` followed by `allocate_wave`. We get `n1_alloc`: an array of length K summing to `n1`.

In [20]:
# Boundary probabilities for the policies, given the per-feature levels R and sparsity H.
boundary_probs = compute_boundary_probs(policies, R, H)
# Turn the boundary probabilities into a first-wave allocation with total sample size n1.
n1_alloc = allocate_wave(boundary_probs, n1)
# Sanity check: the printed sum should match n1.
print(f"First‐wave allocation sums to {int(n1_alloc.sum())} (should be {n1}).")

First‐wave allocation sums to 500 (should be 500).


## 2. Simulating first-wave outcomes

We generate a np.array `beta` of true effects for each node. We pass our lattice `policies`, `M` and `R`, and then specify a `kind` of underlying causal model.

There is a range of options, all of which are continuous and non-trivial: they exhibit locally correlated effects and avoid brittle cancellations in effects. The options range from simple (polynomial, gaussian, basic interaction) to complex (radial basis function, mimic of a simple neural-net-like function).

In [21]:
# True effects over the lattice, generated under the "gauss" underlying causal model.
beta = get_beta_underlying_causal(policies, M, R, kind="gauss")

In [22]:
# Not in use: different distribution for each true pool from a random 'true' partition sigma_true. Not used in this simulation due to our specifications on the underlying causal model (e.g. continuous, locally correlated effects, etc). Also needs changes on how it constructs a true partition.

# partition_seed = 123
# sigma_true, pi_pools_true, pi_policies_true = generate_true_partition(policies, R,random_seed=partition_seed)
# beta = get_beta_piecewise(policies, sigma_true, pi_pools_true, pi_policies_true, 0.5, 1, 10)

**Get outcomes**: we now track the first-wave assignment and generate the outcomes with additional noise

In [23]:
# Build the first-wave assignment vector D from the allocation.
# `policies` is re-enumerated here and stored as an array so rows can be indexed later.
policies = np.array(enumerate_policies(M, R))  # (K, M)
# D is 1-D with one entry per unit (the recorded output shows shape (500,)).
# Presumably each entry is a global policy index: it is matched against
# profile_to_indices with np.isin further down -- confirm in assign_treatments.
D = assign_treatments(n1_alloc)  # (N1,)
print("D shape:", D.shape)

D shape: (500,)


In [24]:
# Noise level and seed used when simulating the first-wave outcomes.
sigma_noise = 5
outcome_seed = 53

# Simulate outcomes y for the assignment D under the true effects beta.
y = generate_outcomes(
    D=D,
    beta=beta,
    sigma_noise=sigma_noise,
    random_seed=outcome_seed,
)

# Quick summary of the simulated outcomes.
print("Overall mean outcome:", np.mean(y))
print("Overall std outcome:", np.std(y))

Overall mean outcome: -0.07506026338360178
Overall std outcome: 5.3521106652966255


## 3. RPS for profiles with data

We now compute the Rashomon threshold theta from the normalized loss lower bounds and a chosen tolerance epsilon. H and the regularization parameter must already be specified at this point.

In [25]:
# Regularization strength; passed as `reg` to RAggregate_profile and to calculate_loss below.
lambda_r = 0.01
# Relative tolerance: the global threshold is the sum of lower bounds times (1 + eps).
eps = 0.05 # chosen tolerance

In [26]:
import numpy as np
from rashomon.hasse import enumerate_policies, enumerate_profiles, policy_to_profile
from rashomon.aggregate import (
    RAggregate_profile,
    subset_data,
    find_profile_lower_bound,
)
from rashomon import loss

### RAggregate_observed walk through

In [27]:
# Toggle the diagnostic prints in the cells below (threshold, per-profile summaries, RPS size).
verbose = True

In [28]:
# Sample size and feature count are read back from the data and the lattice.
N = len(D)
M = len(policies[0])

# Enumerate the profiles, then bucket every policy (and its global index)
# under its profile id. Every profile id gets an entry, even if it stays empty.
profiles, profile_map = enumerate_profiles(M)
profile_to_policies = {k: [] for k, _ in enumerate(profiles)}
profile_to_indices = {k: [] for k, _ in enumerate(profiles)}
for i, pol in enumerate(policies):
    pid = profile_map[policy_to_profile(pol)]
    profile_to_indices[pid].append(i)
    profile_to_policies[pid].append(pol)

In [29]:
# Keep only the profiles that have first-wave data and record, for each one,
# a loss lower bound normalised by the total sample size N.
valid_pids = []
lower_bounds = []
for profile_id, profile in enumerate(profiles):
    # Subset of the data assigned to policies of this profile (rashomon.aggregate).
    Dk, yk = subset_data(D, y, profile_to_indices[profile_id])
    if Dk is None:
        # No data for this profile: skip it.
        continue
    mask = np.array(profile, dtype=bool)
    # Policies of this profile restricted to the features selected by the profile mask.
    reduced_policies = [tuple(np.array(p)[mask]) for p in profile_to_policies[profile_id]]

    # Policy means feed the profile lower bound; both bounds and ids are tracked in step.
    pm = loss.compute_policy_means(Dk, yk, len(reduced_policies))
    profile_lb = find_profile_lower_bound(Dk, yk, pm)
    lower_bounds.append(profile_lb / N)
    valid_pids.append(profile_id)

# Fail with a clear message instead of an opaque zero-size reduction error from .min().
if not valid_pids:
    raise ValueError("No profile has any observed data; cannot compute loss lower bounds.")

lower_bounds = np.array(lower_bounds)
best_loss = lower_bounds.min()
print(f"best_loss = {best_loss:.5f}")
# Sum of the per-profile lower bounds; the global threshold is derived from it.
total_lb = lower_bounds.sum()

best_loss = 0.25737


In [30]:
# Rashomon threshold: the sum of per-profile lower bounds inflated by the tolerance eps,
# i.e. an absolute loss threshold expressed relative to that sum.
theta_global = (1 + eps) * total_lb
if verbose:
    print(f"theta_global = {theta_global:.5f} from sum of lower bounds {total_lb:.5f}")

theta_global = 26.17569 from sum of lower bounds 24.92923


## 4. Construct the full RPS

In [31]:
# Build a Rashomon set for every profile that has data.
# R_profiles collects the non-empty sets; nonempty_profile_ids holds the
# matching profile ids in the same order.
R_profiles = []
nonempty_profile_ids = []
for i, profile_id in enumerate(valid_pids):
    profile_mask = np.array(profiles[profile_id], dtype=bool)
    M_k = profile_mask.sum()  # number of features selected by this profile

    # Policies of this profile restricted to the selected features.
    reduced_policies = [tuple(np.array(p)[profile_mask]) for p in profile_to_policies[profile_id]]

    # Remap each feature column to contiguous 0-based values, then count the
    # distinct levels per local feature.
    reduced_policies_arr = np.array(reduced_policies)
    for j in range(reduced_policies_arr.shape[1]):
        _, reduced_policies_arr[:, j] = np.unique(reduced_policies_arr[:, j], return_inverse=True)
    reduced_policies = [tuple(row) for row in reduced_policies_arr]
    R_k = np.array([len(np.unique(reduced_policies_arr[:, j])) for j in range(M_k)])

    # Select the observations whose assigned (global) policy index belongs to this profile.
    policy_indices_this_profile = profile_to_indices[profile_id]
    mask = np.isin(D, policy_indices_this_profile)
    Dk = np.asarray(D[mask]).reshape(-1)
    yk = y[mask]

    # Translate global policy indices into positions within reduced_policies.
    policy_map = {idx: j for j, idx in enumerate(policy_indices_this_profile)}
    assert all(ix in policy_map for ix in Dk), f"Found Dk values not in mapping for profile {profile_id}"
    Dk_local = np.vectorize(policy_map.get)(Dk)      # local indices, shape (n,)
    assert yk.shape[0] == Dk_local.shape[0], "Dk_local and yk must have the same length"

    # Policy means are computed from the 1-D local indices.
    pm = loss.compute_policy_means(Dk_local, yk, len(reduced_policies))
    assert pm.shape[0] == len(reduced_policies), "policy_means length mismatch"

    # Per-profile threshold: what remains of the global threshold once every
    # other profile is charged its lower bound (floored at zero).
    theta_k = max(0.0, theta_global - (total_lb - lower_bounds[i]))

    # RAggregate_profile is given column vectors of shape (n, 1).
    Dk_local = Dk_local.reshape(-1, 1)
    yk = yk.reshape(-1, 1)
    rashomon_profile = RAggregate_profile(
        M=M_k,
        R=R_k,
        H=H,
        D=Dk_local,  # already local indices
        y=yk,
        theta=theta_k,
        profile=tuple(profiles[profile_id]),
        reg=lambda_r,
        policies=reduced_policies,
        policy_means=pm,
        normalize=N
    )

    # Only non-empty Rashomon sets get their losses computed and are kept.
    if len(rashomon_profile) > 0:
        rashomon_profile.calculate_loss(Dk_local, yk, reduced_policies, pm, lambda_r, normalize=N)
        R_profiles.append(rashomon_profile)
        nonempty_profile_ids.append(profile_id)
    if verbose:
        print(f"Profile {profile_id}: M_k={M_k}, #policies={len(reduced_policies)}, theta_k={theta_k:.5f}, RPS size={len(rashomon_profile)}")

Profile 1: M_k=1, #policies=3, theta_k=2.63517, RPS size=2
Profile 2: M_k=1, #policies=2, theta_k=1.69617, RPS size=1
Profile 3: M_k=2, #policies=6, theta_k=2.62537, RPS size=2
Profile 4: M_k=1, #policies=2, theta_k=1.50383, RPS size=1
Profile 5: M_k=2, #policies=6, theta_k=3.00661, RPS size=2
Profile 6: M_k=2, #policies=4, theta_k=2.72556, RPS size=1
Profile 7: M_k=3, #policies=12, theta_k=5.52833, RPS size=2
Profile 9: M_k=2, #policies=3, theta_k=2.12428, RPS size=2
Profile 10: M_k=2, #policies=2, theta_k=1.65157, RPS size=1
Profile 11: M_k=3, #policies=6, theta_k=3.34851, RPS size=2
Profile 12: M_k=2, #policies=2, theta_k=1.51779, RPS size=1
Profile 13: M_k=3, #policies=6, theta_k=3.59946, RPS size=2
Profile 14: M_k=3, #policies=4, theta_k=3.03184, RPS size=1
Profile 15: M_k=4, #policies=12, theta_k=7.38522, RPS size=2


In [32]:
# if none of the profiles have a rashomon set, observed RPS is empty
if len(R_profiles) == 0:
    if verbose:
        print("No profiles have feasible Rashomon sets; global RPS is empty.")

In [33]:
# Profiles that had data but whose Rashomon set came back empty.
excluded_profiles = [p for p in valid_pids if not (p in nonempty_profile_ids)]
if verbose and excluded_profiles:
    print(f"Skipped profile number due to empty Rashomon set: {excluded_profiles}")
elif verbose:
    print("All profiles with data have non-empty Rashomon sets.")

All profiles with data have non-empty Rashomon sets.


In [34]:
# Report the range of losses inside each retained per-profile Rashomon set.
for idx, rp in enumerate(R_profiles):
    losses = np.array(rp.loss)
    print(
        f"Profile {nonempty_profile_ids[idx]}: "
        f"min loss = {losses.min():.4f}, max loss = {losses.max():.4f}"
    )

Profile 1: min loss = 1.4170, max loss = 1.4187
Profile 2: min loss = 0.4697, max loss = 0.4697
Profile 3: min loss = 1.4389, max loss = 1.4742
Profile 4: min loss = 0.2774, max loss = 0.2774
Profile 5: min loss = 1.8201, max loss = 2.0287
Profile 6: min loss = 1.5191, max loss = 1.5191
Profile 7: min loss = 4.4019, max loss = 5.2520
Profile 9: min loss = 0.8927, max loss = 0.9078
Profile 10: min loss = 0.4251, max loss = 0.4251
Profile 11: min loss = 2.1309, max loss = 2.1621
Profile 12: min loss = 0.2913, max loss = 0.2913
Profile 13: min loss = 2.4130, max loss = 3.0275
Profile 14: min loss = 1.8254, max loss = 1.8254
Profile 15: min loss = 6.2588, max loss = 6.4168


In [35]:
# Assemble observed RPS via find_feasible_combinations from RAggregate
# The per-profile Rashomon sets are passed together with the global loss
# threshold theta_global and the sparsity parameter H.
# NOTE(review): in the recorded run this returns 0 partitions even though the
# per-profile minimum losses printed above sum to less than theta_global.
# Presumably H is enforced as a total pool budget across all profiles, which
# H = 5 could not satisfy with 14 observed profiles -- confirm against the
# implementation of find_feasible_combinations.
from rashomon.aggregate import find_feasible_combinations
R_set = find_feasible_combinations(R_profiles, theta_global, H)
if verbose:
    print(f"RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

RPS has: 0 feasible partitions over 14 observed profiles.


We also show the main wrapper call as we would use it from the original Rashomon module (left commented out below), but we note that the function behaves unexpectedly because we don't have data for a number of the profiles. (We end up with an empty Rashomon Partition Set.)

In [36]:
# Old algorithm doesn't work here. It requires data on everything, and feasible_combinations and the loss functions throw an error when dealing with a smaller subset.
# NOTE(review): D1 and y1 are not defined in the cells shown (the first-wave data are named D and y),
# and RAggregate is not imported; both need fixing before this call can be re-enabled.
# R_set_empty, R_profiles_empty = RAggregate(
#     M=M,
#     R=R_vec,
#     H=H,
#     D=D1,
#     y=y1,
#     theta=theta_global,
#     reg=lambda_r,
#     verbose=True,
# )

## 5. Second-wave allocation

In [None]:
# See data_gen and allocation files again