In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [322]:
import pickle
import econml

from rashomon import tva
from rashomon import loss
from rashomon import counter
from rashomon import metrics
from rashomon import extract_pools
from rashomon.aggregate import (RAggregate_profile, RAggregate,
    find_profile_lower_bound, find_feasible_combinations, remove_unused_poolings, subset_data)
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache

from econml.grf import CausalForest
from sklearn.metrics import mean_squared_error

from copy import deepcopy

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [308]:
def generate_data(mu, var, n_per_pol, all_policies, pi_policies, M):
    """Simulate `n_per_pol` outcomes for every policy of every pooled profile.

    Besides its arguments, this function reads two notebook-level globals that
    must already exist when it is called: ``profiles`` (iterated to obtain the
    profile index ``k``) and ``policies_profiles`` (``k`` -> list of policies in
    that profile).

    Parameters
    ----------
    mu : sequence
        ``mu[k][pool_id]`` is the mean outcome of pool ``pool_id`` in profile ``k``.
    var : sequence
        ``var[k][pool_id]`` is handed to ``np.random.normal`` as its second
        argument. NOTE(review): that argument is the *scale* (standard
        deviation), not the variance -- confirm which one the name is meant
        to denote.
    n_per_pol : int
        Number of samples drawn per policy.
    all_policies : list
        Every policy; a policy's position in this list is the id stored in ``D``.
    pi_policies : mapping
        ``pi_policies[k]`` maps a policy's position inside profile ``k`` to its
        pool id, or is ``None`` when the profile has no pooling and is skipped.
    M : int
        Number of columns of ``X`` (one per feature of a policy).

    Returns
    -------
    X : ndarray, shape (n, M)
        The policy assigned to each sample.
    D : ndarray of int, shape (n, 1)
        Index of that policy in ``all_policies``.
    y : ndarray, shape (n, 1)
        Simulated outcomes.
    mu_true : ndarray, shape (n, 1)
        Mean used to draw each outcome.
    """
    # Size the buffers from the argument instead of the global `num_policies`.
    num_data = len(all_policies) * n_per_pol
    X = np.zeros(shape=(num_data, M))
    D = np.zeros(shape=(num_data, 1), dtype='int_')
    # +inf marks rows that never receive a sample; they are dropped below.
    y = np.zeros(shape=(num_data, 1)) + np.inf
    mu_true = np.zeros(shape=(num_data, 1))

    idx_ctr = 0
    for k, profile in enumerate(profiles):
        # A profile without a pooling has nothing to sample from. Checking only
        # `pi_policies[k]` avoids falling through to `None[idx]` when `mu[k]`
        # happens to be a real array rather than NaN.
        if pi_policies[k] is None:
            continue

        for idx, policy in enumerate(policies_profiles[k]):
            policy_id = all_policies.index(policy)

            pool_id = pi_policies[k][idx]
            mu_i = mu[k][pool_id]
            var_i = var[k][pool_id]
            y_i = np.random.normal(mu_i, var_i, size=(n_per_pol, 1))

            start_idx = idx_ctr * n_per_pol
            end_idx = (idx_ctr + 1) * n_per_pol

            X[start_idx:end_idx, ] = policy
            D[start_idx:end_idx, ] = policy_id
            y[start_idx:end_idx, ] = y_i
            mu_true[start_idx:end_idx, ] = mu_i

            idx_ctr += 1

    # Rows belonging to skipped profiles were never written: remove them.
    absent_idx = np.where(np.isinf(y))[0]
    X = np.delete(X, absent_idx, 0)
    y = np.delete(y, absent_idx, 0)
    D = np.delete(D, absent_idx, 0)
    mu_true = np.delete(mu_true, absent_idx, 0)

    return X, D, y, mu_true

# Setup

## Parameters

In [None]:
# M = 2
# R = np.array([2, 4])


# # (0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # (0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([0])
# var_1 = np.array([1])


# # (1, 0)
# sigma_2 = None
# mu_2 = np.array([0])
# var_2 = np.array([1])


# # (1, 1)
# sigma_3 = np.array([[1, 1],
#                     [0, 1]])
# mu_3 = np.array([6, 4,])
# var_3 = np.array([1, 1])


# sigma = [sigma_0, sigma_1, sigma_2, sigma_3]
# mu = [mu_0, mu_1, mu_2, mu_3]
# var = [var_0, var_1, var_2, var_3]

In [None]:
# M = 3
# R = np.array([4, 4, 4])

# # Fix the partitions
# # Profile 0: (0, 0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # Profile 1: (0, 0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([1])
# var_1 = np.array([1])

# # Profile 2: (0, 1, 0)
# sigma_2 = np.array([[1, 0]])
# mu_2 = np.array([0.5, 3.8])
# var_2 = np.array([1, 1])

# # Profile 3: (0, 1, 1)
# sigma_3 = np.array([[1, 1],
#                     [1, 1]])
# mu_3 = np.array([1.5])
# var_3 = np.array([1])

# # Profile 4: (1, 0, 0)
# sigma_4 = np.array([[1, 1]])
# mu_4 = np.array([1])
# var_4 = np.array([1])

# # Profile 5: (1, 0, 1)
# sigma_5 = np.array([[0, 1],
#                     [1, 1]])
# mu_5 = np.array([2, 3])
# var_5 = np.array([1, 1])

# # Profile 6: (1, 1, 0)
# sigma_6 = np.array([[1, 1],
#                     [1, 1]])
# mu_6 = np.array([2.5])
# var_6 = np.array([1])

# # Profile 7: (1, 1, 1)
# sigma_7 = np.array([[1, 1],
#                     [1, 1],
#                     [1, 0]])
# mu_7 = np.array([3, 4])
# var_7 = np.array([1, 1])

# sigma = [sigma_0, sigma_1, sigma_2, sigma_3, sigma_4, sigma_5, sigma_6, sigma_7]
# mu = [mu_0, mu_1, mu_2, mu_3, mu_4, mu_5, mu_6, mu_7]
# var = [var_0, var_1, var_2, var_3, var_4, var_5, var_6, var_7]


In [480]:
# M = 4
# R = np.array([2, 5, 5, 5])


# # (0, 1, 1, 1)
# sigma_0 = np.array([[1, 1, 1],
#                     [1, 1, 1],
#                     [1, 1, 1]
#                    ])
# mu_0 = np.array([0])
# var_0 = np.array([0])

# # (1, 1, 1, 1)
# sigma_1 = np.array([[np.inf, np.inf, np.inf],
#                     [1, 1, 1],
#                     [1, 0, 1],
#                     [1, 0, 0],
#                    ])
# mu_1 = np.array([2, 4, 7, 5, 6, 3])
# var_1 = np.array([1, 1, 1, 1, 1, 1]) * 0.5


# interested_profiles = [(0, 1, 1, 1), (1, 1, 1, 1)]

# sigma_tmp = [sigma_0, sigma_1]
# mu_tmp = [mu_0, mu_1]
# var_tmp = [var_0, var_1]

# Three features; R is passed to tva.enumerate_policies together with M.
M = 3
R = np.array([2, 5, 5])

# Profile (0, 1, 1): all-ones sigma with a single mean.
# The zero "variance" is used as the noise scale, so outcomes are exactly 0.
sigma_0 = np.ones((2, 3), dtype=int)
mu_0 = np.array([0])
var_0 = np.array([0])

# Profile (1, 1, 1): six means, one per pool. The first row is filled with
# inf -- presumably because the first feature has a single active level and
# therefore nothing to split; verify against extract_pools.
sigma_1 = np.array([
    [np.inf, np.inf, np.inf],
    [1, 0, 1],
    [1, 0, 0],
])
mu_1 = np.array([2, 4, 7, 5, 6, 3])
var_1 = np.full(6, 0.5)

# Only these profiles receive parameters; the i-th entry of each *_tmp list
# belongs to the i-th profile listed here.
interested_profiles = [(0, 1, 1), (1, 1, 1)]

sigma_tmp = [sigma_0, sigma_1]
mu_tmp = [mu_0, mu_1]
var_tmp = [var_0, var_1]

In [302]:
num_profiles = 2**M
profiles, profile_map = tva.enumerate_profiles(M)
all_policies = tva.enumerate_policies(M, R)
num_policies = len(all_policies)

# Expand the per-profile parameters to one entry per enumerated profile.
# Profiles that are not listed in `interested_profiles` get the placeholders
# (None, nan, nan).
interested_profile_idx = []
sigma, mu, var = [], [], []
for k, profile in enumerate(profiles):
    # Position of the first interested profile equal to this one, if any.
    match = next(
        (i for i, p in enumerate(interested_profiles) if p == profile),
        None,
    )
    if match is None:
        sigma_k, mu_k, var_k = None, np.nan, np.nan
    else:
        sigma_k, mu_k, var_k = sigma_tmp[match], mu_tmp[match], var_tmp[match]
    sigma.append(sigma_k)
    mu.append(mu_k)
    var.append(var_k)

In [303]:
# Identify the pools.
# For every profile, collect its policies (full and masked to the active
# features) and, when a pooling `sigma[k]` is given, the pool assignment.
policies_profiles = {}
policies_profiles_masked = {}
policies_ids_profiles = {}
pi_policies = {}
pi_pools = {}
for k, profile in enumerate(profiles):

    policies_temp = [(i, x) for i, x in enumerate(all_policies) if tva.policy_to_profile(x) == profile]
    unzipped_temp = list(zip(*policies_temp))
    policies_ids_k = list(unzipped_temp[0])
    policies_k = list(unzipped_temp[1])
    # Keep an untouched copy: `policies_k` is overwritten with masked tuples below.
    policies_profiles[k] = deepcopy(policies_k)
    policies_ids_profiles[k] = policies_ids_k

    profile_mask = list(map(bool, profile))

    # Mask the empty arms
    for idx, pol in enumerate(policies_k):
        policies_k[idx] = tuple([pol[i] for i in range(M) if profile_mask[i]])
    policies_profiles_masked[k] = policies_k

    # No pooling supplied for this profile: mark it as absent.
    if sigma[k] is None:
        pi_policies[k] = None
        pi_pools[k] = None
        continue

    if np.sum(profile) > 0:
        pi_pools_k, pi_policies_k = extract_pools.extract_pools(policies_k, sigma[k])
        if len(pi_pools_k.keys()) != mu[k].shape[0]:
            print(pi_pools_k)
            print(f"Profile {k}. Expected {len(pi_pools_k.keys())} pools. Received {mu[k].shape[0]} means.")
        pi_policies[k] = pi_policies_k
        # pi_pools_k has indices that match with policies_profiles[k]
        # Need to map those indices back to all_policies
        pi_pools[k] = {}
        # These loop variables leak into the notebook namespace, so they must
        # not be called `x`/`y`: re-running this cell after data generation
        # would otherwise overwrite the outcome vector `y`.
        for pool_id, members_in_profile in pi_pools_k.items():
            members_full = [policies_profiles[k][i] for i in members_in_profile]
            members_global = [all_policies.index(i) for i in members_full]
            pi_pools[k][pool_id] = members_global
    else:
        # All-zero profile: a single pool holding policy 0.
        pi_policies[k] = {0: 0}
        pi_pools[k] = {0: [0]}

# Ground truth used by the evaluation metrics: the profile and pool with the
# largest mean (NaN placeholders of unused profiles are ignored by nanargmax).
best_per_profile = [np.max(mu_k) for mu_k in mu]
true_best_profile = np.nanargmax(best_per_profile)
true_best_profile_idx = int(true_best_profile)
true_best_effect = np.max(mu[true_best_profile])
true_best = pi_pools[true_best_profile][np.argmax(mu[true_best_profile])]
min_dosage_best_policy = metrics.find_min_dosage(true_best, all_policies)

# The transformation matrix for Lasso
G = tva.alpha_matrix(all_policies)

In [304]:
# X_test = np.array(all_policies)
# # mu_true = 
# print(X_test)

In [305]:
# Show the members of each pool of profile index 7 -- the last profile,
# (1, 1, 1), in the current M = 3 setup.
for pool_id, policy_ids in pi_pools[7].items():
    print(pool_id)
    print([all_policies[policy_id] for policy_id in policy_ids])

0
[(1, 1, 1), (1, 1, 2), (1, 2, 1), (1, 2, 2)]
1
[(1, 1, 3), (1, 2, 3)]
2
[(1, 1, 4), (1, 2, 4)]
3
[(1, 3, 1), (1, 3, 2), (1, 4, 1), (1, 4, 2)]
4
[(1, 3, 3), (1, 4, 3)]
5
[(1, 3, 4), (1, 4, 4)]


## Generate data

In [455]:
# Seed the legacy global RNG that generate_data draws from.
np.random.seed(3)

n_per_pol = 50

# Simulate the data set and the per-policy sample means.
X, D, y, mu_true = generate_data(mu, var, n_per_pol, all_policies, pi_policies, M)
policy_means = loss.compute_policy_means(D, y, num_policies)
# The dummy matrix for Lasso
D_matrix = tva.get_dummy_matrix(D, G, num_policies)

# Feature `trt_idx` acts as the treatment; every other column is a covariate.
trt_idx = 0
feature_idx = list(np.delete(np.arange(M), trt_idx))

# A sample is treated when its treatment feature is at a non-zero level.
treated_rows = X[:, trt_idx] > 0
T = treated_rows.astype(float).reshape(y.shape)

y_0d = y.reshape(-1)
X_cf = X[:, feature_idx]

# Covariates and outcomes restricted to the treated samples.
X_trt_subset = X[treated_rows, :][:, feature_idx]
y_trt_subset = y[treated_rows]

# Estimation

## Causal Forests

In [469]:
# API reference:
# https://econml.azurewebsites.net/_autosummary/econml.grf.CausalForest.html?highlight=causalforest#econml.grf.CausalForest
#
# Leaves must contain at least as many samples as were drawn per policy.
est = CausalForest(
    criterion="het",
    n_estimators=100,
    min_samples_leaf=n_per_pol,
    min_samples_split=2,
    random_state=3,
)

# Covariates are the non-treatment features; T is the 0/1 treatment indicator.
est.fit(X_cf, T, y_0d)

In [470]:
# Forest predictions for the treated samples only.
treatment_effects = est.predict(X_trt_subset)

# NOTE(review): the predictions are compared with the observed treated
# outcomes. That presumably relies on the untreated profile having mean 0 and
# zero noise scale (mu_0 / var_0 in the parameter cell), so that effect and
# outcome mean coincide -- confirm before changing those parameters.
mse_cf = mean_squared_error(y_true=y_trt_subset, y_pred=treatment_effects)
print(mse_cf)

1.3963168492544538


## Rashomon Sets

In [473]:
# Inspect D from generate_data: one row per sample, holding the index of its policy in all_policies.
D

array([[ 6],
       [ 6],
       [ 6],
       ...,
       [49],
       [49],
       [49]])

In [476]:
# Inputs for RAggregate.
# NOTE(review): H and theta are passed straight to RAggregate; their meaning
# is not visible in this notebook (presumably a pool-count cap and the
# Rashomon loss threshold) -- confirm against the rashomon package.
H = np.inf
theta = 1.5
# Also reused in the evaluation cell as the per-pool penalty added to the squared error.
reg = 1e-1

# Policy ids of the treated samples (treatment feature at a non-zero level).
D_trt_subset = D[X[:, trt_idx] > 0]

# Alternative kept for reference: fit on every sample instead of the treated subset.
# R_set, rashomon_profiles = RAggregate(M, R, H, D, y, theta, reg,
#                                      verbose=True,
#                                      )
R_set, rashomon_profiles = RAggregate(M, R, H, D_trt_subset, y_trt_subset, theta, reg,
                                     verbose=True,
                                     )

# Number of entries in R_set.
print(len(R_set))

Skipping profile (0, 0, 0)
Skipping profile (0, 0, 1)
Skipping profile (0, 1, 0)
Skipping profile (0, 1, 1)
Skipping profile (1, 0, 0)
Skipping profile (1, 0, 1)
Skipping profile (1, 1, 0)
(1, 1, 1) 1.5
14
Finding feasible combinations
Min = 0.8586263597272497, Max = 1.498065906066686
14


In [477]:

# Evaluate every member of the Rashomon set against the simulated ground truth
# and collect one result row per member in `current_results`.
sim_i = 0
current_results = []
best_loss = np.inf
# Initialise the flag so it is defined even when no member recovers the true
# best profile (it used to be assigned only inside the conditional below).
found_best_profile = False

for idx, r_set in enumerate(R_set):

    # Pool assignment of every profile under this member of the set.
    pi_policies_profiles_r = {}
    for k, profile in enumerate(profiles):
        if rashomon_profiles[k].sigma[0] is None:
            # No pooling stored for this profile: put all its policies in pool 0.
            pi_policies_profiles_r[k] = {
                pol_pos: 0 for pol_pos in range(len(policies_profiles_masked[k]))
            }
            continue
        _, pi_policies_r_k = extract_pools.extract_pools(
            policies_profiles_masked[k],
            rashomon_profiles[k].sigma[r_set[k]]
        )
        pi_policies_profiles_r[k] = pi_policies_r_k

    # Merge the per-profile pools, then predict each sample from its pool mean.
    pi_pools_r, pi_policies_r = extract_pools.aggregate_pools(
        pi_policies_profiles_r, policies_ids_profiles)
    pool_means_r = loss.compute_pool_means(policy_means, pi_pools_r)
    y_r_est = metrics.make_predictions(D, pi_policies_r, pool_means_r)

    r_set_results = metrics.compute_all_metrics(
        y, y_r_est, D, true_best, all_policies, profile_map,
        min_dosage_best_policy, true_best_effect)
    sqrd_err = r_set_results["sqrd_err"]
    iou_r = r_set_results["iou"]
    best_profile_indicator = r_set_results["best_prof"]
    min_dosage_present = r_set_results["min_dos_inc"]
    best_pol_diff = r_set_results["best_pol_diff"]
    # Penalised loss: squared error plus `reg` per pool.
    this_loss = sqrd_err + reg * len(pi_pools_r)

    this_list = [
        n_per_pol, sim_i, len(pi_pools_r), sqrd_err, iou_r, min_dosage_present, best_pol_diff
        ]
    this_list += best_profile_indicator
    current_results.append(this_list)

    if this_loss < best_loss:
        best_loss = this_loss

    if best_profile_indicator[true_best_profile_idx] == 1:
        found_best_profile = True
        # print("Found", this_loss)


In [478]:
# Column layout mirrors the rows appended to `current_results`: seven summary
# fields followed by one indicator column per profile.
profiles_str = list(map(str, profiles))
rashomon_cols = [
    "n_per_pol",
    "sim_num",
    "num_pools",
    "MSE",
    "IOU",
    "min_dosage",
    "best_pol_diff",
] + profiles_str

rashomon_df = pd.DataFrame(current_results, columns=rashomon_cols)

rashomon_df.head()

Unnamed: 0,n_per_pol,sim_num,num_pools,MSE,IOU,min_dosage,best_pol_diff,"(0, 0, 0)","(0, 0, 1)","(0, 1, 0)","(0, 1, 1)","(1, 0, 0)","(1, 0, 1)","(1, 1, 0)","(1, 1, 1)"
0,50,0,13,0.129313,1.0,True,0.020945,0,0,0,0,0,0,0,1
1,50,0,15,0.128954,1.0,True,0.020945,0,0,0,0,0,0,0,1
2,50,0,11,0.346718,1.0,True,0.020945,0,0,0,0,0,0,0,1
3,50,0,16,0.128829,0.5,True,-0.037661,0,0,0,0,0,0,0,1
4,50,0,16,0.128908,1.0,True,0.020945,0,0,0,0,0,0,0,1


In [479]:
# Average of the "MSE" column over every row collected in rashomon_df.
np.mean(rashomon_df["MSE"])

0.22738463574810247

In [468]:
# Sorted values of the "MSE" column.
# NOTE(review): this cell's execution count (468) is lower than those of the
# cells that build rashomon_df above (476-478), so the stored output may come
# from an earlier run -- re-run to refresh it.
np.sort(rashomon_df["MSE"])

array([0.12882926, 0.12890828, 0.12895421, 0.12895421, 0.12895421,
       0.12895421, 0.12895421, 0.12895421, 0.12895421, 0.12931318,
       0.12931318, 0.12931318, 0.12931318, 0.12931318, 0.12931318,
       0.12931318, 0.12931318, 0.12931318, 0.12931318, 0.12931318,
       0.12931318, 0.12931318, 0.12931318, 0.12931318, 0.12931318,
       0.12931318, 0.12931318, 0.12931318, 0.12931318, 0.12931318,
       0.12931318, 0.12931318, 0.12931318, 0.29938626, 0.29938626,
       0.29938626, 0.29938626, 0.29938626, 0.29938626, 0.29938626,
       0.29938626, 0.29938626, 0.29938626, 0.29938626, 0.29938626,
       0.29938626, 0.34631901, 0.34631901, 0.34631901, 0.34631901,
       0.34631901, 0.34631901, 0.34631901, 0.34646159, 0.34646159,
       0.34646159, 0.34646159, 0.34646159, 0.34646159, 0.34646159,
       0.34671753, 0.34671753, 0.34671753, 0.34671753, 0.34671753,
       0.34671753, 0.34671753, 0.34671753, 0.34671753, 0.34671753,
       0.34671753, 0.34671753, 0.34671753, 0.34671753, 0.34671