# Two-wave RPS Algorithm Walk-Through

In [31]:
import numpy as np
from rashomon.hasse import enumerate_policies
import allocation
from data_gen import generate_outcomes

## 1. First-wave allocation

In [2]:
M = 3
R = [4,3,3]
H = 12  # sparsity parameter
n1 = 100  # total first‐wave sample size

In [3]:
R_vec = np.full(M, R) if np.isscalar(R) else np.array(R) # allow for heterogeneity in levels
assert R_vec.shape == (M,)
policies = np.array(enumerate_policies(M, R))

In [4]:
K = len(policies)
print(f"Found K = {K} policies (each policy is an {M}-tuple).")

Found K = 36 policies (each policy is an 3-tuple).


### 1.1. Compute initial boundary probabilities

In [5]:
# MODULAR version: boundary_probs_1 = allocation.compute_initial_boundary_probs(policies, R, H)

K = len(policies)
if np.isscalar(R):
    R_arr = np.array([R] * M, dtype=int)
else:
    R_arr = np.array(R, dtype=int)
    if R_arr.size != M:
        raise ValueError(f"R must be an int or list/array of length M={M}, instead got size {R_arr.size}.")

boundary_probs_1 = np.zeros(K, dtype=float)
for idx in range(K):
    v = policies[idx]

    M = len(v)
    # from formula in paper
    term = 1.0
    for i in range(M):
        R_i = R_arr[i]
        ratio = 2 * min(int(v[i]), R_i - 1 - int(v[i])) / (R_i - 1)
        term *= (1 - ratio) ** (H - 1)
    boundary_probs_1[idx] = 1 - term

### 1.2 Allocate observations to policies

In [6]:
# MODULAR version: n1_alloc = allocation.allocate_wave(boundary_probs_1, n1)

total_prob = boundary_probs_1.sum()
if total_prob == 0:
    raise ValueError("Sum of boundary probabilities is zero. Check H or R.")
normalized = boundary_probs_1 / total_prob

# get direct allocation (could be fractional)
dir_alloc = normalized * n1

# floor to integers
n1_alloc = np.floor(dir_alloc).astype(int)
remainder = dir_alloc - n1_alloc
shortage = n1 - int(n1_alloc.sum())

# distribute the remaining samples to the largest remainders
if shortage > 0:
    idx_sorted = np.argsort(remainder)
    top_indices = idx_sorted[-shortage:]
    n1_alloc[top_indices] += 1

print(f"First‐wave allocation sums to {int(n1_alloc.sum())} (should be {n1}).")

First‐wave allocation sums to 100 (should be 100).


## 2. Simulating first-wave outcomes

### 2.1. Get underlying true outcomes for each policy

We generate a np.array `beta` of true effects for each node. We pass our lattice `policies`, `M` and `R`, and then specify a `kind` of underlying causal model.

There are a range of options, all of which are continuous and non-trivial: they exhibit locally correlated effects and avoid brittle cancellations in effects. The options range from simple (polynomial, gaussian, basic interaction) to complex (radial basis function, mimic of a simple neural-net-like function)

In [8]:
# beta = get_beta_underlying_causal(policies, M, R, kind="gauss")
def phi(policy):
    """Ground truth function (edit as needed for your scenario)."""
    x = np.array(policy)
    return 2 * x[0] + 0.5 * x[1]**2 + x[0]*x[2] + 0.2 * x[1] * x[2]
beta = np.array([phi(p) for p in policies])

We do not, as we did previously in the original Rashomon module, use a different distribution for each true pool from a random 'true' partition sigma_true. This is not used in this simulation due to our specifications on the underlying causal model (e.g. continuous, locally correlated effects, etc), but we nevertheless include first-pass code at generating a true partition and getting a piecewise beta if we did wish to include it.


In [9]:
# partition_seed = 123
# sigma_true, pi_pools_true, pi_policies_true = generate_true_partition(policies, R,random_seed=partition_seed)
# beta = get_beta_piecewise(policies, sigma_true, pi_pools_true, pi_policies_true, 0.5, 1, 10)

### 2.2. Get simulated outcomes

We now track the first-wave assignment and generate the outcomes with additional noise

In [10]:
# MODULAR version: D1 = allocation.assign_treatments(n1_alloc)

total = int(n1_alloc.sum())
D1 = np.zeros(total, dtype=int)
pos = 0
for idx, count in enumerate(n1_alloc):
    if count > 0:
        D1[pos : pos + count] = idx
        pos += count

print("D shape:", D1.shape)

D shape: (100,)


In [11]:
# generate outcomes y1
sigma_noise = 5
outcome_seed = 53
y1 = generate_outcomes(D=D1, beta=beta, sigma_noise=sigma_noise, random_seed=outcome_seed)
print("Overall mean outcome:", np.mean(y1))
print("Overall std outcome:", np.std(y1))

Overall mean outcome: 5.5550266752544255
Overall std outcome: 5.898318566561542


## 3. Construct the RPS

In [12]:
import numpy as np
from rashomon.hasse import enumerate_profiles, policy_to_profile
from rashomon.aggregate import (
    RAggregate_profile,
    subset_data,
    find_profile_lower_bound,
)
from rashomon import loss

In [13]:
lambda_r = 0.01
eps = 0.05 # chosen tolerance
verbose = True

### 3.1. Mapping profiles

We build a map between profiles and policies to observe which profiles contain data and which are as yet unobserved.

In [14]:
N = len(D1)
M = len(policies[0])

# Build profiles and maps between policies
profiles, profile_map = enumerate_profiles(M)
profile_to_policies = {k: [] for k in range(len(profiles))}
profile_to_indices = {k: [] for k in range(len(profiles))}
for i, pol in enumerate(policies):
    pid = profile_map[policy_to_profile(pol)]
    profile_to_policies[pid].append(pol)
    profile_to_indices[pid].append(i)

### 3.2. Find losses and get threshold theta

We then find the lower bound losses for each profile. We subset the relevant data for the profile, mask and get the corresponding policies, and then compute the policy means to find the profile loss lower bound. We find the best loss overall and calculate the total sum of lower bounds.

In [15]:
# Get just the profiles and profile_ids with data and track losses
valid_pids = []
lower_bounds = []
for profile_id, profile in enumerate(profiles):
    Dk, yk = subset_data(D1, y1, profile_to_indices[profile_id]) # using rashomon.aggregate, get correct subset of data
    if Dk is None:
        continue
    mask = np.array(profile, dtype=bool)
    # corresponding policies for this profile id
    reduced_policies = [tuple(np.array(p)[mask]) for p in profile_to_policies[profile_id]]

    # get losses and track lower bounds
    pm = loss.compute_policy_means(Dk, yk, len(reduced_policies))
    profile_lb = find_profile_lower_bound(Dk, yk, pm)
    lower_bounds.append(profile_lb / N)
    valid_pids.append(profile_id)SD

lower_bounds = np.array(lower_bounds)
best_loss = lower_bounds.min()
print(f"best_loss = {best_loss:.5f}")
lower_bounds = np.array(lower_bounds)
total_lb = lower_bounds.sum()

best_loss = 0.68255


In [16]:
# calculate rashomon threshold
theta_global = total_lb * (1 + eps) # get loss threshold in absolute reference to sum of lower bounds
if verbose:
    print(f"theta_global = {theta_global:.5f} from sum of lower bounds {total_lb:.5f}")

theta_global = 20.12003 from sum of lower bounds 19.16194


### 3.3. Construct the RPS for each profile

In [17]:
R_profiles = []
nonempty_profile_ids = []
for i, profile_id in enumerate(valid_pids):
    profile_mask = np.array(profiles[profile_id], dtype=bool)
    M_k = profile_mask.sum()

    # Compute reduced policies using only active features for this profile
    reduced_policies = [tuple(np.array(p)[profile_mask]) for p in profile_to_policies[profile_id]]

    # Compute number of levels for each local (profile) feature
    R_k = np.array([len(set([p[feat] for p in reduced_policies])) for feat in range(M_k)])

    # Remap each feature in reduced_policies to contiguous 0-based values
    reduced_policies_arr = np.array(reduced_policies)
    for j in range(reduced_policies_arr.shape[1]):
        _, reduced_policies_arr[:, j] = np.unique(reduced_policies_arr[:, j], return_inverse=True)
    reduced_policies = [tuple(row) for row in reduced_policies_arr]
    R_k = np.array([len(np.unique(reduced_policies_arr[:, j])) for j in range(M_k)])

    # Value-mapping-based subsetting and remapping
    # This gives Dk (local indices) and yk (outcomes)
    policy_indices_this_profile = profile_to_indices[profile_id]
    mask = np.isin(D1, policy_indices_this_profile)
    Dk = D1[mask]
    yk = y1[mask]

    # Now remap Dk from global policy indices to local indices in reduced_policies
    Dk = np.asarray(Dk).reshape(-1)
    policy_map = {idx: j for j, idx in enumerate(policy_indices_this_profile)}
    assert all(ix in policy_map for ix in Dk), f"Found Dk values not in mapping for profile {profile_id}"
    Dk_local = np.vectorize(policy_map.get)(Dk)      # map to local indices, shape (n,)
    assert yk.shape[0] == Dk_local.shape[0], "Dk_local and yk must have the same length"

    # Need to have Dk as a 1D array for the loss functions
    # Compute policy means with local indices
    pm = loss.compute_policy_means(Dk_local, yk, len(reduced_policies))
    assert pm.shape[0] == len(reduced_policies), "policy_means length mismatch"

    # get profile threshold
    theta_k = max(0.0, theta_global - (total_lb - lower_bounds[i]))

    # Need to reshape np array because the RAggregate_profile expects shape (n,1)
    Dk_local = Dk_local.reshape(-1, 1)
    yk = yk.reshape(-1, 1)
    # get rashomon set for each profile
    rashomon_profile = RAggregate_profile(
        M=M_k,
        R=R_k,
        H=H,
        D=Dk_local,  # Already local indices
        y=yk,
        theta=theta_k,
        profile=tuple(profiles[profile_id]),
        reg=lambda_r,
        policies=reduced_policies,
        policy_means=pm,
        normalize=N
    )

    # calculate losses for non-empty profiles and add to list of profiles
    Dk = np.asarray(Dk).reshape(-1) # loss functions again want a 1d array for D, but keep y 2d
    if len(rashomon_profile) > 0:
        rashomon_profile.calculate_loss(Dk_local, yk, reduced_policies, pm, lambda_r, normalize=N)
        R_profiles.append(rashomon_profile)
        nonempty_profile_ids.append(profile_id)
    if verbose:
        print(f"Profile {profile_id}: M_k={M_k}, #policies={len(reduced_policies)}, theta_k={theta_k:.5f}, RPS size={len(rashomon_profile)}")

Profile 1: M_k=1, #policies=2, theta_k=2.34215, RPS size=1
Profile 2: M_k=1, #policies=2, theta_k=2.50604, RPS size=1
Profile 3: M_k=2, #policies=4, theta_k=5.30727, RPS size=1
Profile 4: M_k=1, #policies=3, theta_k=1.64065, RPS size=1
Profile 5: M_k=2, #policies=6, theta_k=2.83749, RPS size=2
Profile 6: M_k=2, #policies=6, theta_k=3.59373, RPS size=2
Profile 7: M_k=3, #policies=12, theta_k=7.64129, RPS size=2


In [18]:
# if none of the profiles have a rashomon set, observed RPS is empty
if len(R_profiles) == 0:
    if verbose:
        print("No profiles have feasible Rashomon sets; global RPS is empty.")

In [19]:
excluded_profiles = [profile_id for profile_id in valid_pids if profile_id not in nonempty_profile_ids]
if verbose:
    if len(excluded_profiles) > 0:
        print(f"Skipped profile number due to empty Rashomon set: {excluded_profiles}")
    else:
        print("All profiles with data have non-empty Rashomon sets.")

All profiles with data have non-empty Rashomon sets.


In [20]:
# Visual check on losses
for idx, rp in enumerate(R_profiles):
    losses = np.array(rp.loss)
    print(f"Profile {nonempty_profile_ids[idx]}: min loss = {losses.min():.4f}, max loss = {losses.max():.4f}")

Profile 1: min loss = 1.4041, max loss = 1.4041
Profile 2: min loss = 1.5679, max loss = 1.5679
Profile 3: min loss = 4.3892, max loss = 4.3892
Profile 4: min loss = 0.7126, max loss = 0.7126
Profile 5: min loss = 1.9394, max loss = 5.0033
Profile 6: min loss = 2.6956, max loss = 4.2876
Profile 7: min loss = 6.8032, max loss = 8.9012


### 3.4. Construct full RPS across all observed profiles

In [23]:
# Assemble observed RPS via find_feasible_combinations from RAggregate
from rashomon.aggregate import find_feasible_combinations
R_set = find_feasible_combinations(R_profiles, 30, H)
if verbose:
    print(f"RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

RPS has: 8 feasible partitions over 7 observed profiles.


We note that directly constructing the RPS from the original Rashomon module does not work here. In particular, not all profiles necessarily have observations, and the original module return an empty RPS in that case. In particular, we need to correctly subset and mask the profiles, compute loss, and track indices, each of which have slightly different assumptions in the original module.

## 5. Second-wave allocation

### (5.0.) Quick validity checks from first-wave

In [24]:
print(f"Number of profiles: {len(profiles)}")
print(f"Number of R_profiles (number of profiles with data): {len(R_profiles)}")
print(f"Final RPS size (len(R_set)): {len(R_set)}")

assert len(D1) == len(y1), "Mismatch: D and y must have same length."
print("CHECKED: Assignment and outcome vectors are same length.")

K, M = np.array(policies).shape
assert K > 0 and M > 0, "Policies array shape invalid."
print(f"CHECKED: Lattice has {K} policies of {M} features each.")

assert len(R_profiles) > 0, "No nonempty profiles in R_profiles!"
assert len(R_set) > 0, "RPS is empty! No feasible partitions found."
print("CHECKED: All observed profiles are nonempty, and RPS is nonempty.")

for pid, indices in profile_to_indices.items():
    assert all(0 <= ix < K for ix in indices), f"Profile {pid} has out-of-bounds policy indices."
print("CHECKED: All profile_to_indices entries are valid global policy indices.")

for idx, rp in enumerate(R_profiles):
    assert hasattr(rp, 'sigma'), f"R_profile {idx} missing 'sigma'."
    assert hasattr(rp, 'loss'), f"R_profile {idx} missing 'loss'."
    assert len(rp.sigma) == len(rp.loss), f"R_profile {idx}: sigma/loss length mismatch."
    assert np.all(np.isfinite(rp.loss)), f"R_profile {idx}: loss contains NaN or inf."
print("CHECKED: All R_profiles have matching, finite loss and partition arrays.")

if hasattr(R_profiles[0], 'profile'):
    for rp in R_profiles:
        pid = rp.profile if hasattr(rp, 'profile') else None
        if pid is not None:
            assert pid in profile_to_indices, f"R_profile profile id {pid} not in profile_to_indices."
print("CHECKED: All R_profiles have valid profile IDs.")

#--- RPS partition indices in range ---
for partition in R_set:
    assert all(0 <= idx < len(R_profiles[k].sigma) for k, idx in enumerate(partition)), \
        f"Partition {partition} has out-of-range index."
print("CHECKED: All RPS partitions refer to valid indices in R_profiles.")

Number of profiles: 8
Number of R_profiles (number of profiles with data): 7
Final RPS size (len(R_set)): 8
CHECKED: Assignment and outcome vectors are same length.
CHECKED: Lattice has 36 policies of 3 features each.
CHECKED: All observed profiles are nonempty, and RPS is nonempty.
CHECKED: All profile_to_indices entries are valid global policy indices.
CHECKED: All R_profiles have matching, finite loss and partition arrays.
CHECKED: All R_profiles have valid profile IDs.
CHECKED: All RPS partitions refer to valid indices in R_profiles.


### 5.1. Next wave allocation weights

In [25]:
from rashomon.extract_pools import lattice_edges, extract_pools

In [26]:
# precompute all lattice edges
lattice_ed = lattice_edges(policies)  # policies holds the full enumerated lattice

K = len(policies)
boundary_counts = np.zeros(K, float)

# compute posterior weights for each RPS partition
Q = np.array([
    sum(R_profiles[k].loss[part[k]] for k in range(len(R_profiles)))
    for part in R_set
])
post_weights = np.exp(-Q - Q.min())
post_weights /= post_weights.sum()

In [27]:
for part, w_i in zip(R_set, post_weights):
    pi_policies_profiles = {}
    for k, rp in enumerate(R_profiles):
        profile_id = nonempty_profile_ids[k]
        profile_mask = np.array(profiles[profile_id], dtype=bool)

        # Mask and remap as in RPS construction
        local_policies = [tuple(np.array(p)[profile_mask]) for p in profile_to_policies[profile_id]]
        if len(local_policies) == 0:
            continue  # skip empty
        arr = np.array(local_policies)
        for j in range(arr.shape[1]):
            _, arr[:, j] = np.unique(arr[:, j], return_inverse=True)
        local_policies_remap = [tuple(row) for row in arr]

        # Use the same remapped local_policies as when RPS was constructed
        sigma = rp.sigma[part[k]]
        _, pi_policies_local = extract_pools(local_policies_remap, sigma)

        # Map from local index to global index
        for local_idx, global_idx in enumerate(profile_to_indices[profile_id]):
            pi_policies_profiles[global_idx] = pi_policies_local[local_idx]

    # Build full K-vector: -1 for not-in-any-profile
    pi_policies = np.full(K, -1, dtype=int)
    for global_idx, pool_id in pi_policies_profiles.items():
        pi_policies[global_idx] = pool_id

    # Only count boundaries between *observed* nodes
    for u, v in lattice_ed:
        if pi_policies[u] != -1 and pi_policies[v] != -1 and pi_policies[u] != pi_policies[v]:
            boundary_counts[u] += w_i
            boundary_counts[v] += w_i

if boundary_counts.sum() == 0:
    raise ValueError("No boundaries detected in second-wave allocation. Check RPS content.")

boundary_probs_2 = boundary_counts / boundary_counts.sum()
print(f"Second-wave boundary allocation probabilities sum to {boundary_probs_2.sum():.4f}")
print(f"Nonzero probabilities: {np.count_nonzero(boundary_probs_2)} of {K} policies.")

Second-wave boundary allocation probabilities sum to 1.0000
Nonzero probabilities: 35 of 36 policies.


### 5.2. Allocate and simulate second-wave outcomes

We now use the modular versions for concision. See first section for exact code.

In [28]:
n2 = 500

# allocate nodes to each policy and get D2
alloc2 = allocation.allocate_wave(boundary_probs_2, n2)
assert alloc2.sum() == n2

D2 = allocation.assign_treatments(alloc2)

We then simulate outcomes, again using our true outcomes.

In [29]:
# still the same beta for the underlying causal function
y2 = generate_outcomes(D=D2, beta=beta, sigma_noise=sigma_noise, random_seed=55)

### ... see two_wave_sim.ipynb for final construction of RPS, now using a modular approach