# Two-wave RPS Algorithm

In [1]:
import numpy as np
from allocation import compute_initial_boundary_probs, compute_wave_boundary_probs, allocate_wave, assign_treatments
from data_gen import get_beta_underlying_causal, generate_outcomes
from rashomon.hasse import enumerate_policies

from construct_RPS import construct_RPS

### 1. First-wave allocation and simulation

In [2]:
# --- Lattice and first-wave design settings ---------------------------------
M = 4          # forwarded to enumerate_policies / construct_RPS; presumably the number of features — confirm
R = [2,3,3,4]  # levels; either a scalar or one entry per position (see R_vec)
H = 20         # forwarded to compute_initial_boundary_probs and construct_RPS
n1 = 500       # first-wave budget handed to allocate_wave

R_vec = np.full(M, R) if np.isscalar(R) else np.array(R) # allow for heterogeneity in levels
# Guard against a list-valued R whose length disagrees with M (the scalar
# branch always yields length M, so the list branch should match it).
assert R_vec.shape == (M,), "R must be a scalar or a length-M sequence"

# Enumerate the policy lattice once; the earlier duplicate call whose result
# was immediately overwritten has been removed.
policies = np.array(enumerate_policies(M, R))
K = len(policies[0])
# NOTE(review): R_vec and K are not read by any visible cell; kept in case
# later cells depend on them.

# --- First-wave allocation ----------------------------------------------------
boundary_probs = compute_initial_boundary_probs(policies, R, H)
n1_alloc = allocate_wave(boundary_probs, n1)

# Coefficients used by generate_outcomes for both waves.
beta = get_beta_underlying_causal(policies, M, R, kind="gauss")

D1 = assign_treatments(n1_alloc)

# --- First-wave outcome simulation ---------------------------------------------
sigma_noise = 5    # passed as `sigma_noise` to generate_outcomes
outcome_seed = 53  # seed for the first-wave outcome draw

y1 = generate_outcomes(D=D1, beta=beta, sigma_noise=sigma_noise, random_seed=outcome_seed)

### 2. Construct first-wave RPS

In [3]:
# Settings forwarded to construct_RPS in the cells below.
eps: float = 0.05       # passed as `eps`; presumably the Rashomon threshold — confirm
lambda_r: float = 0.01  # passed as `lambda_r`; presumably a regularization weight — confirm
# NOTE(review): the visible construct_RPS calls hard-code verbose=False, so
# this flag is not read by them.
verbose: bool = True

In [4]:
# Build the RPS from first-wave data only (D1, y1).
# NOTE(review): `policies` is rebound to the value returned by construct_RPS,
# so re-running this cell feeds that returned value back in — confirm this is
# intended / idempotent.
(
    R_set,
    R_profiles,
    theta_global,
    policies,
    profiles,
    profile_to_policies,
    profile_to_indices,
    nonempty_profile_ids,
) = construct_RPS(
    policies, M, R, D1, y1, H, eps, lambda_r,
    verbose=False,
)

print(f"RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

RPS has: 57 feasible partitions over 14 observed profiles.


### 3. Second-wave allocation and simulation

In [5]:
# Second-wave budget handed to allocate_wave.
n2 = 500

# Boundary probabilities derived from the first-wave RPS outputs.
boundary_probs_2 = compute_wave_boundary_probs(
    R_set,
    R_profiles,
    policies,
    profiles,
    profile_to_policies,
    profile_to_indices,
    nonempty_profile_ids,
)

# Allocate the second wave and draw its treatment assignments.
alloc2 = allocate_wave(boundary_probs_2, n2)
D2 = assign_treatments(alloc2)

# Simulate second-wave outcomes with the same beta and noise level as wave one.
y2 = generate_outcomes(
    D=D2,
    beta=beta,
    sigma_noise=sigma_noise,
    random_seed=55,  # literal seed; differs from the first wave's outcome_seed
)

# Pool both waves: stack the assignments and append the outcomes.
D = np.vstack((D1, D2))
y = np.concatenate((y1, y2))

In [6]:
# Rebuild the RPS on the pooled data from both waves (D, y).
# NOTE(review): this overwrites the first-wave R_set / R_profiles / etc., so
# the wave-one results are no longer available after this cell runs.
(
    R_set,
    R_profiles,
    theta_global,
    policies,
    profiles,
    profile_to_policies,
    profile_to_indices,
    nonempty_profile_ids,
) = construct_RPS(
    policies, M, R, D, y, H, eps, lambda_r,
    verbose=False,
)
# TODO: Does theta recalculate each time?

print(f"RPS has: {len(R_set)} feasible partitions over {len(R_profiles)} observed profiles.")

RPS has: 247 feasible partitions over 14 observed profiles.
