In [1]:
import numpy as np
import matplotlib.pyplot as plt

# !pip install scikit-learn
from sklearn import linear_model
from sklearn.metrics import mean_squared_error

from collections import deque

from rashomon import loss
from rashomon import counter
from rashomon.aggregate import RAggregate
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache
from rashomon import tva
from rashomon.extract_pools import extract_pools


%load_ext autoreload
%autoreload 2
# %matplotlib inline

### Worst case scenario

In [184]:
# Partition matrix used as the data-generating ("true") pooling: one row per
# feature (M rows). Judging by the pools printed below, a 1 appears to merge
# two adjacent levels of that feature and a 0 to split them — confirm against
# extract_pools.
sigma = np.array(
    [
        [1, 1, 0],
        [0, 1, 0],
    ],
    dtype=np.float64,
)
M, n = sigma.shape

# With R = [5, 5] there are (5 - 1) * (5 - 1) = 16 policies.
R = np.array([5, 5])
num_policies = np.prod(R - 1)

policies = tva.enumerate_policies(M, R)
pi_pools, pi_policies = extract_pools(policies, sigma)

# Show every pool as the list of policies it contains.
for pool_id, pool in pi_pools.items():
    pool_members = [policies[i] for i in pool]
    print(pool_id, ":", pool_members)

0 : [(1, 1), (2, 1), (3, 1)]
1 : [(1, 2), (1, 3), (2, 2), (2, 3), (3, 2), (3, 3)]
2 : [(1, 4), (2, 4), (3, 4)]
3 : [(4, 1)]
4 : [(4, 2), (4, 3)]
5 : [(4, 4)]


### Generate data

In [300]:
# Fix the seed so the simulated outcomes (and every number printed below)
# are reproducible.
np.random.seed(3)

num_pools = len(pi_pools)

# True mean outcome and outcome variance of each pool, indexed by pool id.
mu = np.array([0, 1.5, 3, 3, 6, 4.5])
var = [1] * num_pools
assert len(mu) == num_pools, "mu must provide exactly one mean per pool"

# Number of simulated observations drawn for each policy.
n_per_pol = 100

num_data = num_policies * n_per_pol
# The loop below writes one slice of n_per_pol rows per policy, so these
# buffers are fully overwritten (assuming `policies` has num_policies entries).
X = np.empty(shape=(num_data, M))
D = np.empty(shape=(num_data, 1), dtype='int_')
y = np.empty(shape=(num_data, 1))

for idx, policy in enumerate(policies):
    pool_i = pi_policies[idx]
    mu_i = mu[pool_i]
    var_i = var[pool_i]
    # np.random.normal takes the standard deviation (not the variance) as its
    # scale argument. With var_i == 1 this draws exactly the same samples as
    # passing var_i directly, but stays correct if the variances are changed.
    y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

    start_idx = idx * n_per_pol
    end_idx = (idx + 1) * n_per_pol

    X[start_idx:end_idx] = policy  # feature levels of this policy
    D[start_idx:end_idx] = idx     # policy id of each observation
    y[start_idx:end_idx] = y_i     # simulated outcomes
    

In [301]:
# Per-policy outcome means are computed from (D, y) only, so this call is
# needed just once per dataset.
policy_means = loss.compute_policy_means(D, y, num_policies)

# Pool means are derived from the pools, so this must be recomputed every time
# the pools change.
mu_pools = loss.compute_pool_means(policy_means, pi_pools)

# Q is evaluated for a specific partition (sigma), so it must also be
# recomputed every time the pools change. The trailing 0.1 sits in the
# position that a later cell passes as `reg=0.1`.
Q = loss.compute_Q(D, y, sigma, policies, policy_means, 0.1)

print(Q)

1.6215777927122845


In [302]:
# Evaluate B at position (0, 0).
# NOTE(review): (i, j) presumably index an entry of sigma — confirm against
# loss.compute_B. The trailing 0.1 matches the value used for Q above.
i, j = 0, 0

B = loss.compute_B(D, y, sigma, i, j, policies, policy_means, 0.1)
print(B)

1.6189993231597828


### RAggregate

In [292]:
def find_min_dosage(policy_ids, policies):
    """Return the ids in `policy_ids` whose policy has the smallest total dosage.

    The dosage of a policy is the sum of its entries (`np.sum(policies[id])`).

    Args:
        policy_ids: Iterable of indices into `policies`.
        policies: Indexable collection of policies (sequences of numbers).

    Returns:
        List of the ids that attain the minimum dosage, in the order they
        appear in `policy_ids`. Empty if `policy_ids` is empty.
    """
    lowest_total = np.inf
    minimisers = []
    for pid in policy_ids:
        total = np.sum(policies[pid])
        if total < lowest_total:
            # Strictly better: start a fresh list of minimisers.
            lowest_total = total
            minimisers = [pid]
        elif total == lowest_total:
            # Tie with the current best: keep both.
            minimisers.append(pid)
    return minimisers

def check_membership(true_x, est_x):
    """Return True if `true_x` and `est_x` have at least one element in common.

    Args:
        true_x: Iterable of hashable reference items (e.g. policy ids).
        est_x: Iterable of hashable estimated items.

    Returns:
        bool: True when the two collections overlap, False otherwise
        (including when either collection is empty).
    """
    true_set = set(true_x)
    est_set = set(est_x)
    # isdisjoint avoids materialising the intersection just to test emptiness.
    return not true_set.isdisjoint(est_set)

def min_dosage_best_policy(true_best, est_best):
    """Unfinished placeholder: builds sets from both arguments and returns None.

    TODO: implement the intended comparison or remove this function. Note that
    a later cell rebinds the name `min_dosage_best_policy` to a list, which
    hides this function from then on.
    """
    _true_ids, _est_ids = set(true_best), set(est_best)
    return None

def find_best_policy_diff(y_true, y_est):
    """Return max(y_true) - max(y_est).

    A positive value means the estimated maximum falls short of the true
    maximum; a negative value means it overshoots.
    """
    best_true = np.max(y_true)
    best_est = np.max(y_est)
    return best_true - best_est

In [303]:
# Build the collection of candidate partitions with RAggregate, using the
# same reg=0.1 as the Q computations in this notebook.
# NOTE(review): the positional arguments 8 and 2 are not explained here —
# presumably a cap on the number of pools and a loss threshold; confirm
# against RAggregate's signature.
P_set = RAggregate(M, R, 8, D, y, 2, reg=0.1)
print(P_set.size)
# Displayed as the cell output. Presumably reports whether the data-generating
# partition `sigma` is contained in P_set — confirm against the class.
P_set.seen(sigma)

6


True

In [304]:
# Per-policy outcome means, computed from (D, y) only.
pol_means = loss.compute_policy_means(D, y, num_policies)

# Policy ids of the pool with the largest true mean, and the lowest-dosage
# policies among them.
true_best = pi_pools[np.argmax(mu)]
# NOTE: this assignment rebinds the name of the `min_dosage_best_policy`
# function defined earlier; later cells read this name as a list of ids.
min_dosage_best_policy = find_min_dosage(true_best, policies)

# Score every partition in P_set against the data and the known truth.
for s_i in P_set:
    pi_pools_i, pi_policies_i = extract_pools(policies, s_i)
    pool_means_i = loss.compute_pool_means(pol_means, pi_pools_i)

    Q = loss.compute_Q(D, y, s_i, policies, pol_means, reg=0.1)
    y_pred = loss.make_predictions(D, pi_policies_i, pool_means_i)
    sqrd_err = mean_squared_error(y, y_pred)
    pol_max = loss.find_best_policies(D, y_pred)
    iou = loss.intersect_over_union(set(true_best), set(pol_max))

    # Does the estimated best set contain a true minimum-dosage best policy?
    min_dosage_present = check_membership(min_dosage_best_policy, pol_max)

    # Gap between the true best pool mean and the best estimated pool mean.
    best_pol_diff = find_best_policy_diff(mu, pool_means_i)

    print(f"Num pools: {len(pi_pools_i)}, Squared loss: {sqrd_err:0.5f}, Q: {Q:0.5f}")
    print(f"Best policy IOU: {iou:.3f}")
    print(f"Min dosage: {min_dosage_present}")
    print(f"Best policy error: {best_pol_diff}")
    print("---")

Num pools: 4, Squared loss: 1.33086, Q: 1.73086
Best policy IOU: 0.667
Min dosage: True
Best policy error: 0.48808069259229026
---
Num pools: 6, Squared loss: 1.33013, Q: 1.93013
Best policy IOU: 0.667
Min dosage: True
Best policy error: 0.48808069259229026
---
Num pools: 6, Squared loss: 1.33006, Q: 1.93006
Best policy IOU: 0.667
Min dosage: True
Best policy error: 0.48808069259229026
---
Num pools: 6, Squared loss: 1.24759, Q: 1.84759
Best policy IOU: 0.500
Min dosage: True
Best policy error: 0.029118080070194452
---
Num pools: 6, Squared loss: 1.02158, Q: 1.62158
Best policy IOU: 1.000
Min dosage: True
Best policy error: -0.009341137959020251
---
Num pools: 8, Squared loss: 1.02089, Q: 1.82089
Best policy IOU: 0.500
Min dosage: False
Best policy error: -0.04780035598823318
---


### LASSO - Beta -> alpha

In [305]:
# Build the matrix G from the policy list, then turn the policy ids in D into
# the design matrix used by the LASSO fit below.
# NOTE(review): judging by the later `np.matmul(G, alpha_est)` cell, G maps a
# coefficient vector to one value per policy — confirm against tva.alpha_matrix.
G = tva.alpha_matrix(M, R, policies)
# print(G)
D_matrix = tva.get_dummy_matrix(D, G, num_policies)

In [327]:
# Fit a LASSO on the design matrix with the same regularisation weight (0.1)
# that is used for Q elsewhere in this notebook.
reg_param = 1e-1
mod1 = linear_model.Lasso(alpha=reg_param, fit_intercept=False)
mod1.fit(D_matrix, y)
alpha_est = mod1.coef_
y_tva = mod1.predict(D_matrix)

# mean_squared_error expects (y_true, y_pred); keep that order for consistency
# with the rest of the notebook (the value itself is symmetric).
sqrd_err = mean_squared_error(y, y_tva)
print(sqrd_err)

# Two penalised losses for comparison: MSE + reg * ||alpha||_1 and
# MSE + reg * ||alpha||_0 (number of non-zero coefficients).
# NOTE: scikit-learn's Lasso minimises (1 / (2 * n_samples)) * ||y - Xw||^2
# + alpha * ||w||_1, so L1_tva is a reporting metric, not the exact objective
# value that the fit optimised.
L1_tva = sqrd_err + reg_param * np.linalg.norm(alpha_est, ord=1)
Q_tva = sqrd_err + reg_param * np.linalg.norm(alpha_est, ord=0)
print(Q_tva, L1_tva)

1.3474809575770588
1.8474809575770588 1.8596037994018215


In [328]:
# Evaluate the LASSO predictions with the same three metrics that are used for
# the RAggregate partitions.
tva_best = loss.find_best_policies(D, y_tva)
iou_tva = loss.intersect_over_union(set(true_best), set(tva_best))
print(iou_tva)

# `min_dosage_best_policy` is the list of ids assigned in the RAggregate
# evaluation cell, not the function of the same name.
min_dosage_present_tva = check_membership(min_dosage_best_policy, tva_best)
print(min_dosage_present_tva)

# Gap between the true best pool mean and the largest LASSO prediction.
best_policy_error_tva = find_best_policy_diff(mu, y_tva)
print(best_policy_error_tva)

0.0
False
0.8787715817523711


In [331]:
# Side-by-side view: policy ids LASSO picks as best vs. the ids of the truly best pool.
tva_best, true_best

(array([15]), [13, 14])

In [330]:
# Largest LASSO prediction, i.e. the estimated outcome of the policy it ranks best.
np.max(y_tva)

5.121228418247629

In [310]:
# Show the fitted coefficients, then the per-policy values they imply via G.
print(alpha_est)
G @ alpha_est

[ 0.          1.89264672  0.07407503  0.17692615  0.          0.
  0.          0.          0.          0.          0.          0.
  2.83731825  0.14026227  0.         -0.        ]


array([0.        , 1.89264672, 1.96672174, 2.1436479 , 0.        ,
       1.89264672, 1.96672174, 2.1436479 , 0.        , 1.89264672,
       1.96672174, 2.1436479 , 2.83731825, 4.87022724, 4.94430227,
       5.12122842])

In [164]:
# Display G for inspection (0/1 entries; one column per coefficient).
G

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.],
       [1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0.,

In [288]:
# Hand-picked coefficient vector with the same shape and dtype as the LASSO
# estimate. Per the recorded output, G @ alpha_des reproduces the true pool
# mean of every policy.
alpha_des = np.zeros_like(alpha_est)
alpha_des[[1, 3, 12, 13, 15]] = [1.5, 1.5, 3, 1.5, -3]
print(alpha_des)

G @ alpha_des

[ 0.   1.5  0.   1.5  0.   0.   0.   0.   0.   0.   0.   0.   3.   1.5
  0.  -3. ]


array([0. , 1.5, 1.5, 3. , 0. , 1.5, 1.5, 3. , 0. , 1.5, 1.5, 3. , 3. ,
       6. , 6. , 4.5])