In [1]:
import numpy as np
import matplotlib.pyplot as plt

from rashomon import loss
from rashomon import count_pools
from rashomon.cache import RashomonCache
from rashomon.tva import enumerate_policies
from rashomon.extract_pools import extract_pools

# %matplotlib inline

### Function to pool based on $\Sigma$ matrix

In [2]:
# Pooling matrix Sigma; its shape determines M and n, and R is derived from n.
sigma = np.array(
    [
        [1, 1, 0],
        [0, 1, 1],
    ],
    dtype='float64',
)
M, n = sigma.shape
R = n + 2

# Enumerate all policies, then let extract_pools group them using sigma.
policies = enumerate_policies(M, R)
num_policies = (R - 1) ** M
pi_pools, pi_policies = extract_pools(policies, sigma)

# Print the members of every pool, one pool per line.
for pool_id, pool in pi_pools.items():
    print(pool_id, ":", pool)

0 : [0, 4, 8]
1 : [1, 2, 3, 5, 6, 7, 9, 10, 11]
2 : [12]
3 : [13, 14, 15]


### Generate data

In [3]:
# Seed the legacy global RNG so the simulated data is reproducible.
np.random.seed(3)

# One mean per pool, drawn uniformly from [0, 4); every pool gets variance 1.
num_pools = len(pi_pools)
mu = np.random.uniform(0, 4, size=num_pools)
var = [1] * num_pools

# Number of observations simulated for each policy.
n_per_pol = 10

# Preallocate the outputs; every row is overwritten in the loop below.
num_data = num_policies * n_per_pol
X = np.empty(shape=(num_data, M))
D = np.empty(shape=(num_data, 1), dtype='int_')
y = np.empty(shape=(num_data, 1))

for idx, policy in enumerate(policies):
    # Look up the pool this policy belongs to and that pool's parameters.
    pool_i = pi_policies[idx]
    mu_i = mu[pool_i]
    var_i = var[pool_i]
    # np.random.normal expects the standard deviation (scale), not the
    # variance, so convert before sampling.
    y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

    # Rows [start_idx, end_idx) hold the observations for this policy.
    start_idx = idx * n_per_pol
    end_idx = (idx + 1) * n_per_pol

    X[start_idx:end_idx, ] = policy
    D[start_idx:end_idx, ] = idx
    y[start_idx:end_idx, ] = y_i
    

In [4]:
# This function needs to be called only once
policy_means = loss.compute_policy_means(D, y, num_policies)

# This function needs to be called every time the pools change
mu_pools = loss.compute_pool_means(policy_means, pi_pools)

# This function needs to be called every time the pools change
# NOTE(review): the trailing 1 is an unexplained positional argument - confirm its meaning against loss.compute_Q
Q = loss.compute_Q(D, y, mu_pools, pi_policies, 1)

print(Q)

4.895456436046057


In [5]:
# Position (i, j) handed to compute_B together with sigma.
i, j = 0, 0

# NOTE(review): the trailing 1 is an unexplained positional argument - confirm against loss.compute_B
B = loss.compute_B(D, y, sigma, i, j, policies, policy_means, 1)
print(B)

4.859163572719872


### Caching object

In [6]:
# Insert (sigma, i, j) into a fresh cache, then probe it with three matrices.
seen_sigma = RashomonCache()
seen_sigma.insert(sigma, i, j)

# Should be True: this is exactly the combination that was just inserted
print(seen_sigma.seen(sigma, i, j))

# Should be False: sigma2 is a copy of sigma with entry [1, 1] flipped (x -> 1 - x)
sigma2 = np.copy(sigma)
sigma2[1, 1] = 1 - sigma2[1, 1]
print(seen_sigma.seen(sigma2, i, j))

# Should be True: sigma3 is a copy of sigma with entry [0, 2] flipped, yet it still counts as seen
# NOTE(review): presumably the cache key ignores some entries of sigma relative to (i, j) - confirm against RashomonCache
sigma3 = np.copy(sigma)
sigma3[0, 2] = 1 - sigma3[0, 2]
print(seen_sigma.seen(sigma3, i, j))



True
False
True
