In [94]:
import numpy as np
import matplotlib.pyplot as plt

from rashomon.tva import enumerate_policies
from rashomon.extract_pools import extract_pools
from rashomon import count_pools# import num_admissible_poolings, num_pools

# %matplotlib inline

### Pool policies according to the $\Sigma$ matrix

In [39]:
# Sigma matrix: one row per arm; each entry is a pooling decision (1 = pooled, 0 = split)
sigma = np.array(
    [
        [1, 1, 0],
        [0, 1, 1],
    ],
    dtype='float64',
)

# M arms and n pooling decisions per arm; presumably each arm has R - 1 levels
# (so R - 2 = n boundaries between them) -- TODO confirm against enumerate_policies
M, n = sigma.shape
R = n + 2
num_policies = (R-1)**M

policies = enumerate_policies(M, R)
pi_pools, pi_policies = extract_pools(policies, sigma)

for pool_id, pool in pi_pools.items():
    print(pool_id, ":", pool)

0 : [0, 4, 8]
1 : [1, 2, 3, 5, 6, 7, 9, 10, 11]
2 : [12]
3 : [13, 14, 15]


### Generate data

In [4]:
np.random.seed(3)

# True mean outcome and outcome variance of each pool
num_pools = len(pi_pools)
mu = np.random.uniform(0, 4, size=num_pools)
var = [1] * num_pools

# Number of simulated observations per policy
n_per_pol = 10

num_data = num_policies * n_per_pol
X = np.ndarray(shape=(num_data, M))                # policy vector of each row
D = np.ndarray(shape=(num_data, 1), dtype='int_')  # policy id of each row
y = np.ndarray(shape=(num_data, 1))                # simulated outcome of each row

for idx, policy in enumerate(policies):
    pool_i = pi_policies[idx]
    mu_i = mu[pool_i]
    var_i = var[pool_i]
    # np.random.normal's `scale` is a standard deviation, not a variance;
    # for var_i = 1 the two coincide, so the simulated data is the same
    y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

    # Each policy occupies a contiguous run of n_per_pol rows
    start_idx = idx * n_per_pol
    end_idx = (idx + 1) * n_per_pol

    X[start_idx:end_idx, ] = policy
    D[start_idx:end_idx, ] = idx
    y[start_idx:end_idx, ] = y_i
    

### Functions to compute the objective Q

In [5]:
def compute_policy_means(D, y, num_policies):
    """
    Aggregate the outcome sum and observation count of every policy.

    Despite the name, this returns sums and counts rather than means, so that
    rows can later be added together over a pool before dividing.

    Arguments:
    D -- integer policy ids, shape (N, 1) or (N,); D[k, 0] is the policy of row k
    y -- outcomes, shape (N, 1) or (N,), aligned with D
    num_policies -- number of policies; ids outside [0, num_policies) are ignored

    Returns: policy_means
    policy_means is a np.ndarray of size (num_policies, 2)
    policy_means[i, 0] = sum of y[k, 0] over all rows k with D[k, 0] = i
    policy_means[i, 1] = number of rows k with D[k, 0] = i
    """
    policy_ids = np.asarray(D).reshape(-1)
    outcomes = np.asarray(y, dtype='float64').reshape(-1)

    # Drop ids outside [0, num_policies): they never match any policy id,
    # and np.bincount would raise on negative values
    in_range = (policy_ids >= 0) & (policy_ids < num_policies)
    policy_ids = policy_ids[in_range].astype(np.intp)
    outcomes = outcomes[in_range]

    policy_means = np.empty(shape=(num_policies, 2))
    # A single O(N) pass instead of one np.where scan per policy (O(N * num_policies))
    policy_means[:, 0] = np.bincount(policy_ids, weights=outcomes, minlength=num_policies)
    policy_means[:, 1] = np.bincount(policy_ids, minlength=num_policies)
    return policy_means

def compute_pool_means(policy_means, pi_pools):
    """
    Average outcome of every pool.

    Arguments:
    policy_means -- array whose rows are (outcome sum, count) per policy
    pi_pools -- dict mapping pool id -> list of policy ids in that pool;
                pool ids are used as row indices, so presumably 0..H-1

    Returns: mu_pools
    mu_pools is a np.ndarray of size (H,) where H is the number of pools
    mu_pools[h] = (summed outcomes of pool h) / (summed counts of pool h)
    """
    totals = np.empty((len(pi_pools), 2))
    for pool_idx in pi_pools:
        members = pi_pools[pool_idx]
        # Column-wise sum over the pool's policies: (total outcome, total count)
        totals[pool_idx] = policy_means[members].sum(axis=0)
    pool_sums, pool_counts = totals[:, 0], totals[:, 1]
    return pool_sums / pool_counts


def compute_Q(D, y, mu_pools, pi_policies, reg=1):
    """
    Compute the loss Q.

    Q = mean over rows k of (y[k, 0] - mu_pools[pool of D[k, 0]])**2
        + reg * H, where H = number of pools (length of mu_pools).

    Arguments:
    D -- array of shape (N, 1) with the policy id of each row
    y -- array of shape (N, 1) with the outcome of each row
    mu_pools -- array of pool means, indexed by pool id
    pi_policies -- maps policy id -> pool id (indexed with pi_policies[policy_id])
    reg -- weight on the number of pools
    """
    num_pools = mu_pools.shape[0]

    pool_of_row = [pi_policies[policy_id] for policy_id in D[:, 0]]
    residuals = y[:, 0] - mu_pools[pool_of_row]
    mse = np.mean(residuals**2)

    return mse + reg*num_pools

In [6]:
# Per-policy sums and counts depend only on the data: compute them once
policy_means = compute_policy_means(D, y, num_policies)

# Pool means and Q must be recomputed every time the pools change
mu_pools = compute_pool_means(policy_means, pi_pools)
Q = compute_Q(D, y, mu_pools, pi_policies, reg=1)

print(Q)

4.895456436046057


### Function to compute B

In [95]:
def partition_sigma(sigma, i, j):
    """
    Maximally split policies in arm i starting at dosage j.

    Returns a copy of sigma with sigma[i, j:] set to 0; all other entries
    (i.e. all other existing splits/poolings) are kept. The input is not modified.
    """
    split_sigma = np.array(sigma, copy=True)
    split_sigma[i, j:] = 0
    return split_sigma


def compute_B(D, y, sigma, i, j, policies, policy_means, reg=1):
    r"""
    The B function in Theorem 6.3 \ref{thm:rashomon-equivalent-bound}

    Arguments:
    D -- array of shape (N, 1) with the policy id of each row
    y -- array of shape (N, 1) with the outcome of each row
    sigma -- current Sigma matrix (not modified; partition_sigma copies it)
    i, j -- split arm i maximally starting at dosage j (column j of sigma)
    policies -- enumerated policies, as passed to extract_pools
    policy_means -- per-policy (outcome sum, count) array from compute_policy_means
    reg -- weight on the number of pools

    Returns: B = (mean squared loss when arm i is split maximally from dosage j)
             + reg * H, where H is the number of pools after re-pooling
             sigma[i, j+1:] (sigma[i, j] stays split).
    """
    # Split maximally in arm i from dosage j
    sigma_fix = partition_sigma(sigma, i, j)
    pi_fixed_pools, pi_fixed_policies = extract_pools(policies, sigma_fix)

    # Squared loss for this maximal split: this is B minus the regularization
    # term. Reuse compute_Q with reg=0 so the loss is computed exactly as for Q.
    mu_fixed_pools = compute_pool_means(policy_means, pi_fixed_pools)
    B = compute_Q(D, y, mu_fixed_pools, pi_fixed_policies, reg=0)

    # The least number of pools:
    # the number of pools when the splittable policies are pooled maximally.
    # sigma_fix[i, j] remains 0 (split); only the columns after j are re-pooled.
    sigma_fix[i, (j+1):] = 1
    H = count_pools.num_pools(sigma_fix)

    B += reg*H

    return B

In [106]:
# Evaluate B when arm 0 is split maximally starting at column 0 of sigma
i, j = 0, 0

B = compute_B(D, y, sigma, i, j, policies, policy_means, reg=1)
print(B)

4.859163572719872
