In [3]:
import numpy as np
import matplotlib.pyplot as plt

from collections import deque

from rashomon import loss
from rashomon import counter
from rashomon.aggregate import RAggregate
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache
from rashomon.tva import enumerate_policies
from rashomon.extract_pools import extract_pools


# Reload edited modules (e.g. the rashomon package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Not needed in modern Jupyter, where inline rendering is already the default:
# %matplotlib inline

### Extracting policy pools from a $\Sigma$ matrix

In [187]:
# sigma has M rows; here R is derived from its column count, so every row
# shares the same scalar R.
sigma = np.array(
    [[1, 1, 0],
     [0, 1, 1]],
    dtype='float64',
)

M, n = sigma.shape
R = n + 2

# With a scalar R, each of the M coordinates takes R - 1 values.
num_policies = (R - 1) ** M
policies = enumerate_policies(M, R)
pi_pools, pi_policies = extract_pools(policies, sigma)

# List the policy indices that fall into each pool.
for pool_id in pi_pools:
    print(f"{pool_id} : {pi_pools[pool_id]}")

0 : [0, 4, 8]
1 : [1, 2, 3, 5, 6, 7, 9, 10, 11]
2 : [12]
3 : [13, 14, 15]


### Generate data

In [188]:
np.random.seed(3)

# One mean per pool, drawn uniformly from [0, 4); every pool has variance 1.
num_pools = len(pi_pools)
mu = np.random.uniform(0, 4, size=num_pools)
var = [1] * num_pools

n_per_pol = 10  # observations simulated for every policy

# The arrays below are sized from num_policies. If that disagreed with the
# number of enumerated policies, some rows of X, D, y would stay uninitialised.
assert len(policies) == num_policies, "num_policies does not match enumerate_policies"

num_data = num_policies * n_per_pol
X = np.empty(shape=(num_data, M))
D = np.empty(shape=(num_data, 1), dtype='int_')
y = np.empty(shape=(num_data, 1))

for idx, policy in enumerate(policies):
    pool_i = pi_policies[idx]
    mu_i = mu[pool_i]
    var_i = var[pool_i]
    # np.random.normal's `scale` is the standard deviation, not the variance.
    y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

    start_idx = idx * n_per_pol
    end_idx = (idx + 1) * n_per_pol

    X[start_idx:end_idx, ] = policy
    D[start_idx:end_idx, ] = idx
    y[start_idx:end_idx, ] = y_i
    

In [189]:
# compute_policy_means only takes the data (D, y) and num_policies,
# so it needs to be called only once.
policy_means = loss.compute_policy_means(D, y, num_policies)

# Pool means must be recomputed every time the pools change.
mu_pools = loss.compute_pool_means(policy_means, pi_pools)

# compute_Q takes sigma, so Q must be recomputed every time the pooling changes.
# `reg` is passed by keyword, as in the later compute_Q / compute_B cells.
Q = loss.compute_Q(D, y, sigma, policies, policy_means, reg=1)

print(Q)

4.895456436046057


In [190]:
# (i, j) passed to compute_B; presumably an entry of sigma. TODO confirm.
i, j = 0, 0

# `reg` is passed by keyword, as in the later compute_B cell.
B = loss.compute_B(D, y, sigma, i, j, policies, policy_means, reg=1)
print(B)

4.859163572719872


### RAggregate

In [191]:
# Pass the M and R defined above instead of re-typing them (2 and 5), so this
# call stays consistent with sigma if it is edited. The other positional
# arguments (4 and 5) are unchanged. TODO name them once their meaning in
# RAggregate is confirmed.
P_set = RAggregate(M, R, 4, D, y, 5, reg=1)
print(P_set.size)
# Presumably checks whether the data-generating sigma is in the set.
P_set.seen(sigma)

20


True

### Varying R (a separate R for each row of $\Sigma$)

In [192]:
# Open idea: use np.inf in sigma to encode an arbitrary dosage.
# Another sigma worth trying is all zeros with np.inf as the last entry of
# the second row, i.e. [[0, 0, 0], [0, 0, np.inf]].
sigma = np.array(
    [[1, 0, 1],
     [0, 1, np.inf]],
    dtype='float64',
)

M, n = sigma.shape
# One R per row of sigma, so each coordinate can take a different number of values.
R = np.array([5, 4])

num_policies = np.prod(R - 1)
policies = enumerate_policies(M, R)
pi_pools, pi_policies = extract_pools(policies, sigma)

# Show the policy tuples themselves, not just their indices, for each pool.
for pool_id, members in pi_pools.items():
    member_policies = [policies[k] for k in members]
    print(pool_id, ":", member_policies)

0 : [(1, 1), (2, 1)]
1 : [(1, 2), (1, 3), (2, 2), (2, 3)]
2 : [(3, 1), (4, 1)]
3 : [(3, 2), (3, 3), (4, 2), (4, 3)]


In [193]:
# Cross-check the pool count computed from sigma against the pools actually
# extracted above. num_pools returns a float (4.0), so compare by value.
n_pools_from_sigma = counter.num_pools(sigma)
assert n_pools_from_sigma == len(pi_pools), "counter.num_pools disagrees with extract_pools"
n_pools_from_sigma

4.0

In [194]:
# First argument is 4, matching the pool count above. Presumably this is the
# number of pools to count admissible poolings for. TODO confirm in rashomon.counter.
target_num_pools = 4
counter.num_admissible_poolings(target_num_pools, M, R)

6

In [195]:
np.random.seed(3)

# One mean per pool, drawn uniformly from [0, 4); every pool has variance 1.
num_pools = len(pi_pools)
mu = np.random.uniform(0, 4, size=num_pools)
var = [1] * num_pools

n_per_pol = 10  # observations simulated for every policy

# The arrays below are sized from num_policies. If that disagreed with the
# number of enumerated policies, some rows of X, D, y would stay uninitialised.
assert len(policies) == num_policies, "num_policies does not match enumerate_policies"

num_data = num_policies * n_per_pol
X = np.empty(shape=(num_data, M))
D = np.empty(shape=(num_data, 1), dtype='int_')
y = np.empty(shape=(num_data, 1))

for idx, policy in enumerate(policies):
    pool_i = pi_policies[idx]
    mu_i = mu[pool_i]
    var_i = var[pool_i]
    # np.random.normal's `scale` is the standard deviation, not the variance.
    y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

    start_idx = idx * n_per_pol
    end_idx = (idx + 1) * n_per_pol

    X[start_idx:end_idx, ] = policy
    D[start_idx:end_idx, ] = idx
    y[start_idx:end_idx, ] = y_i

In [196]:
# Per-policy statistics for the varying-R data: one row per policy (12 here).
# In the output, column 1 equals n_per_pol (10) and column 0 looks like the
# sum of y for that policy. TODO confirm against loss.compute_policy_means.
policy_means = loss.compute_policy_means(D, y, num_policies)
policy_means

array([[22.3306393 , 10.        ],
       [26.09289987, 10.        ],
       [26.9056259 , 10.        ],
       [14.77744099, 10.        ],
       [27.87404021, 10.        ],
       [29.12185777, 10.        ],
       [14.06728243, 10.        ],
       [21.12531734, 10.        ],
       [15.25551251, 10.        ],
       [12.69909817, 10.        ],
       [21.11531527, 10.        ],
       [19.32885624, 10.        ]])

In [197]:
# Q for the varying-R sigma and its simulated data, with reg=1.
Q = loss.compute_Q(D, y, sigma, policies, policy_means, reg=1)
Q

4.9889861406550375

In [198]:
# compute_B for the varying-R setup, evaluated at (i, j) = (1, 1).
i, j = 1, 1
B = loss.compute_B(D, y, sigma, i, j, policies, policy_means, reg=1)
print(B)

6.975889428946971


In [199]:
# Pass M instead of a hard-coded 2 so this call stays tied to sigma's shape;
# R is the per-row array defined above. The other positional arguments
# (4 and 1.4) are unchanged. TODO name them once their meaning in
# RAggregate is confirmed.
P_set = RAggregate(M, R, 4, D, y, 1.4, reg=0.1)
print(P_set.size)
# Presumably checks whether the data-generating sigma is in the set.
P_set.seen(sigma)

5


True