In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [2]:
import pickle
import econml

from rashomon import tva
from rashomon import loss
from rashomon import counter
from rashomon import metrics
from rashomon import extract_pools
from rashomon.aggregate import (RAggregate_profile, RAggregate,
    find_profile_lower_bound, find_feasible_combinations, remove_unused_poolings, subset_data)
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache
from rashomon.aggregate import find_te_het_partitions

from econml.grf import CausalForest
from sklearn.metrics import mean_squared_error, confusion_matrix
from sklearn import linear_model

from copy import deepcopy

%load_ext autoreload
%autoreload 2

In [1]:
!pip install econml

Collecting econml
  Downloading econml-0.15.0.tar.gz (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting sparse (from econml)
  Obtaining dependency information for sparse from https://files.pythonhosted.org/packages/07/a3/22e031f6833d84edd54b0809087d910907358bddc1c92e56b7b2db30f5ed/sparse-0.15.1-py2.py3-none-any.whl.metadata
  Downloading sparse-0.15.1-py2.py3-none-any.whl.metadata (4.5 kB)
Collecting shap<0.44.0,>=0.38.1 (from econml)
  Obtaining dependency information for shap<0.44.0,>=0.38.1 from https://files.pythonhosted.org/packages/fb/99/2364cc073662517335383f68a10549c6b75486b99f0d671179e4dd8252d6/shap-0.43.0-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Downloading shap-0.43.0-cp311-cp311-macosx_11_0_arm64.

In [3]:
def generate_data(mu, var, n_per_pol, all_policies, pi_policies, M, policies_profiles=None):
    """Simulate ``n_per_pol`` noisy outcomes for every policy of every modelled profile.

    Parameters
    ----------
    mu, var : sequence indexed by profile id
        ``mu[k][pool_id]`` and ``var[k][pool_id]`` are the outcome mean and
        variance of each pool in profile ``k``. Entries of skipped profiles
        are never read.
    n_per_pol : int
        Number of observations drawn per policy.
    all_policies : list
        Every policy; the position of a policy in this list is the id stored in ``D``.
    pi_policies : mapping
        ``pi_policies[k][idx]`` is the pool id of the ``idx``-th policy of
        profile ``k``. ``None`` marks a profile that generates no data.
    M : int
        Number of features, i.e. number of columns of ``X``.
    policies_profiles : mapping or sequence, optional
        ``policies_profiles[k]`` lists the policies of profile ``k`` for
        ``k = 0 .. len(policies_profiles) - 1``. When omitted, the notebook-level
        variable of the same name is used (the previous, implicit behaviour).

    Returns
    -------
    X : ndarray, shape (n, M)
        Feature levels of the policy each observation was drawn from.
    D : ndarray of int, shape (n, 1)
        Index into ``all_policies`` for each observation.
    y : ndarray, shape (n, 1)
        Simulated outcomes.
    mu_true : ndarray, shape (n, 1)
        Noise-free pool mean behind each observation.
    """
    if policies_profiles is None:
        # Backward compatibility: earlier versions read this from notebook state.
        policies_profiles = globals()["policies_profiles"]

    # Upper bound on the number of rows: every policy observed n_per_pol times.
    num_data = len(all_policies) * n_per_pol
    X = np.zeros(shape=(num_data, M))
    D = np.zeros(shape=(num_data, 1), dtype='int_')
    y = np.zeros(shape=(num_data, 1))
    mu_true = np.zeros(shape=(num_data, 1))

    idx_ctr = 0
    for k in range(len(policies_profiles)):
        # A profile without a pool assignment contributes no data. Checking only
        # for None avoids evaluating np.isnan on an array-valued mu[k].
        if pi_policies[k] is None:
            continue

        for idx, policy in enumerate(policies_profiles[k]):
            policy_id = all_policies.index(policy)

            pool_id = pi_policies[k][idx]
            mu_i = mu[k][pool_id]
            var_i = var[k][pool_id]
            # np.random.normal takes the standard deviation, not the variance.
            y_i = np.random.normal(mu_i, np.sqrt(var_i), size=(n_per_pol, 1))

            rows = slice(idx_ctr * n_per_pol, (idx_ctr + 1) * n_per_pol)
            X[rows, ] = policy
            D[rows, ] = policy_id
            y[rows, ] = y_i
            mu_true[rows, ] = mu_i

            idx_ctr += 1

    # Rows are filled contiguously from the top, so everything past the last
    # written row belongs to skipped profiles and is dropped.
    n_rows = idx_ctr * n_per_pol
    return X[:n_rows], D[:n_rows], y[:n_rows], mu_true[:n_rows]

# Setup

## Parameters

In [4]:
# M = 4
# R = np.array([2, 5, 5, 5])


# # (0, 1, 1, 1)
# sigma_0 = np.array([[1, 1, 1],
#                     [1, 1, 1],
#                     [1, 1, 1]
#                    ])
# mu_0 = np.array([0])
# var_0 = np.array([0])

# # (1, 1, 1, 1)
# sigma_1 = np.array([[np.inf, np.inf, np.inf],
#                     [1, 1, 1],
#                     [1, 0, 1],
#                     [1, 0, 0],
#                    ])
# mu_1 = np.array([2, 4, 7, 5, 6, 3])
# var_1 = np.array([1, 1, 1, 1, 1, 1]) * 0.5


# interested_profiles = [(0, 1, 1, 1), (1, 1, 1, 1)]

# sigma_tmp = [sigma_0, sigma_1]
# mu_tmp = [mu_0, mu_1]
# var_tmp = [var_0, var_1]

# Number of features. This must equal len(R) and the length of every profile
# tuple below; the profile indices 7 and 15 used later also require 2**4 profiles.
M = 4
# Levels per feature; the table below this cell lists the active levels 1..R[i]-1.
R = np.array([2, 4, 5, 5])


# (0, 1, 1, 1)
sigma_0 = np.array([[1, 1, 1],
                    [1, 1, 1],
                    [1, 1, 1]
                   ])
mu_0 = np.array([0])
var_0 = np.array([1])

# (1, 1, 1, 1)
sigma_1 = np.array([[np.inf, np.inf, np.inf],
                    [0, 0, np.inf],
                    [1, 0, 1],
                    [1, 1, 0],
                   ])
# One mean / variance per pool, in the order listed in the table below this cell.
mu_1 = np.array([2, 4, 2, 0, 3, 5, 7, 1, 1, -1, -1, -2])
var_1 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


interested_profiles = [(0, 1, 1, 1), (1, 1, 1, 1)]

# Parallel lists: entry i belongs to interested_profiles[i].
sigma_tmp = [sigma_0, sigma_1]
mu_tmp = [mu_0, mu_1]
var_tmp = [var_0, var_1]

\begin{align*}
    \text{Control} & \\
    (0,\ 1:3,\  1:4,\  1:4) &= 0 \\
    \text{Treatment} & \\
    (1,\  1,\  1:2,\  1:3) &= 2 \\
    (1,\  1,\  1:2,\  4) &= 4 \\
    (1,\  1,\  3:4,\  1:3) &= 2 \\
    (1,\  1,\  3:4,\  4) &= 0 \\
    (1,\  2,\  1:2,\  1:3) &= 3 \\
    (1,\  2,\  1:2,\  4) &= 5 \\
    (1,\  2,\  3:4,\  1:3) &= 7 \\
    (1,\  2,\  3:4,\  4) &= 1 \\
    (1,\  3,\  1:2,\  1:3) &= 1 \\
    (1,\  3,\  1:2,\  4) &= -1 \\
    (1,\  3,\  3:4,\  1:3) &= -1 \\
    (1,\  3,\  3:4,\  4) &= -2
\end{align*}

In [5]:
# Enumerate every profile and every policy for the M features.
num_profiles = 2 ** M
profiles, profile_map = tva.enumerate_profiles(M)
all_policies = tva.enumerate_policies(M, R)
num_policies = len(all_policies)

interested_profile_idx = []

# Align the ground-truth (sigma, mu, var) with the full profile list.
# Profiles that are not in `interested_profiles` get (None, nan, nan).
sigma, mu, var = [], [], []
for profile in profiles:
    hit = next(
        (pos for pos, wanted in enumerate(interested_profiles) if wanted == profile),
        None,
    )
    if hit is None:
        sigma.append(None)
        mu.append(np.nan)
        var.append(np.nan)
    else:
        sigma.append(sigma_tmp[hit])
        mu.append(mu_tmp[hit])
        var.append(var_tmp[hit])

In [6]:
# Identify the pools
# For every profile k this cell records:
#   policies_profiles[k]        - the full policy tuples belonging to the profile
#   policies_ids_profiles[k]    - their indices in all_policies
#   policies_profiles_masked[k] - the same policies restricted to the active features
#   pi_policies[k]              - policy (profile-local index) -> pool id, or None
#   pi_pools[k]                 - pool id -> list of indices in all_policies, or None
policies_profiles = {}
policies_profiles_masked = {}
policies_ids_profiles = {}
pi_policies = {}
pi_pools = {}
for k, profile in enumerate(profiles):

    policies_temp = [(i, x) for i, x in enumerate(all_policies) if tva.policy_to_profile(x) == profile]
    unzipped_temp = list(zip(*policies_temp))
    policies_ids_k = list(unzipped_temp[0])
    policies_k = list(unzipped_temp[1])
    policies_profiles[k] = deepcopy(policies_k)
    policies_ids_profiles[k] = policies_ids_k

    profile_mask = list(map(bool, profile))

    # Mask the empty arms
    policies_k = [
        tuple(pol[i] for i in range(M) if profile_mask[i])
        for pol in policies_k
    ]
    policies_profiles_masked[k] = policies_k

    # Profiles without a ground-truth partition carry no pools.
    if sigma[k] is None:
        pi_policies[k] = None
        pi_pools[k] = None
        continue

    if np.sum(profile) > 0:
        pi_pools_k, pi_policies_k = extract_pools.extract_pools(policies_k, sigma[k])
        if len(pi_pools_k.keys()) != mu[k].shape[0]:
            print(pi_pools_k)
            print(f"Profile {k}. Expected {len(pi_pools_k.keys())} pools. Received {mu[k].shape[0]} means.")
        pi_policies[k] = pi_policies_k
        # pi_pools_k has indices that match with policies_profiles[k];
        # policies_ids_k maps those positions back to all_policies.
        # The loop names deliberately avoid `x` / `y`: at notebook scope a loop
        # variable called `y` would overwrite the outcome vector on a re-run.
        pi_pools[k] = {}
        for pool_id, local_ids in pi_pools_k.items():
            pi_pools[k][pool_id] = [policies_ids_k[i] for i in local_ids]
    else:
        pi_policies[k] = {0: 0}
        pi_pools[k] = {0: [0]}

# Best pool overall: profile with the largest pool mean, then the pool inside it.
best_per_profile = [np.max(mu_k) for mu_k in mu]
true_best_profile = np.nanargmax(best_per_profile)
true_best_profile_idx = int(true_best_profile)
true_best_effect = np.max(mu[true_best_profile])
true_best = pi_pools[true_best_profile][np.argmax(mu[true_best_profile])]
min_dosage_best_policy = metrics.find_min_dosage(true_best, all_policies)

# The transformation matrix for Lasso
G = tva.alpha_matrix(all_policies)

In [7]:
# Position of the treatment indicator inside a policy tuple.
trt_arm_idx = 0

# Control / treatment profiles together with their positions in `profiles`.
ctl_profile, ctl_profile_idx = (0, 1, 1, 1), 7
trt_profile, trt_profile_idx = (1, 1, 1, 1), 15

# Subset data for interested profiles
trt_policies_ids = policies_ids_profiles[trt_profile_idx]
ctl_policies_ids = policies_ids_profiles[ctl_profile_idx]
tc_policies_ids = trt_policies_ids + ctl_policies_ids

trt_policies = policies_profiles_masked[trt_profile_idx]
ctl_policies = policies_profiles_masked[ctl_profile_idx]

# Ground-truth pools of each profile, plus the policy -> pool lookup.
trt_pools, trt_pools_policies = extract_pools.extract_pools(
    trt_policies, sigma[trt_profile_idx]
)
ctl_pools, ctl_pools_policies = extract_pools.extract_pools(
    ctl_policies, sigma[ctl_profile_idx]
)

# Profile-local policy ids, and the true mean of the pool each one falls in.
D_trt = np.array(list(trt_pools_policies))
D_ctl = np.array(list(ctl_pools_policies))

D_trt_pooled = [trt_pools_policies[p] for p in D_trt]
D_ctl_pooled = [ctl_pools_policies[p] for p in D_ctl]
y_trt = mu[trt_profile_idx][D_trt_pooled]
y_ctl = mu[ctl_profile_idx][D_ctl_pooled]

# Feature levels of the treated policies with the first (treatment) column dropped.
X_trt = np.array(policies_profiles[trt_profile_idx])[:, 1:]

# True treatment effects and the policies that attain the largest one.
te_true = y_trt - y_ctl
max_te = np.max(te_true)
max_te_policies_p = D_trt[np.where(te_true == max_te)]
max_te_policies = [trt_policies_ids[p] for p in max_te_policies_p]
min_dosage_best_te = metrics.find_min_dosage(max_te_policies, all_policies)

In [8]:
# for pi, l in pi_pools[15].items():
#     print(pi)
#     print([all_policies[x] for x in l])

## Generate data

In [9]:
np.random.seed(3)

n_per_pol = 10

# Generate data
X, D, y, mu_true = generate_data(mu, var, n_per_pol, all_policies, pi_policies, M)
# Computed once here; the previous second, identical call further down is removed.
policy_means = loss.compute_policy_means(D, y, num_policies)
# The dummy matrix for Lasso
D_matrix = tva.get_dummy_matrix(D, G, num_policies)

# trt_arm_idx is set in the treatment/control selection cell above.
# Every feature except the treatment indicator.
feature_idx = [i for i in range(M) if i != trt_arm_idx]

# Row masks, computed once and reused for every subset below.
is_treated = X[:, trt_arm_idx] > 0
is_control = X[:, trt_arm_idx] == 0

# Binary treatment indicator with the same shape as y.
T = np.zeros(shape=y.shape)
T[is_treated] = 1

y_0d = y.reshape((-1,))
X_cf = X[:, feature_idx]

X_trt_subset = X[is_treated, :]
X_trt_subset = X_trt_subset[:, feature_idx]
y_trt_subset = y[is_treated]

D_trt_subset = D[is_treated]
D_matrix_trt_subset = D_matrix[is_treated, :]
D_matrix_ctl_subset = D_matrix[is_control, :]

# One dummy-matrix row per treatment / control policy (ids in all_policies).
D_trt_univ = np.array([policies_ids_profiles[trt_profile_idx][x] for x in D_trt]).reshape((-1, 1))
D_ctl_univ = np.array([policies_ids_profiles[ctl_profile_idx][x] for x in D_ctl]).reshape((-1, 1))
D_matrix_trt = tva.get_dummy_matrix(D_trt_univ, G, num_policies)
D_matrix_ctl = tva.get_dummy_matrix(D_ctl_univ, G, num_policies)

# Observations that belong to the treatment or the control profile.
mask = np.isin(D, tc_policies_ids)
D_tc = D[mask].reshape((-1,1))
y_tc = y[mask].reshape((-1,1))

In [10]:


# policy_means_tc = policy_means[tc_policies_ids, :]

# Estimation

## Causal Forests

In [36]:
# API reference:
# https://econml.azurewebsites.net/_autosummary/econml.grf.CausalForest.html?highlight=causalforest#econml.grf.CausalForest
est = CausalForest(
    n_estimators=100,
    criterion="het",
    min_samples_split=10,
    min_samples_leaf=10,
    random_state=3,
)

# Features exclude the treatment column; T is the binary treatment indicator.
est.fit(X_cf, T, y_0d)

In [37]:
# Forest predictions at the feature values of every treated policy.
te_cf = est.predict(X_trt)

# Score the predictions against the known treatment effects.
cf_res = metrics.compute_te_het_metrics(
    te_true,
    te_cf,
    max_te,
    max_te_policies,
    D_trt,
    policies_ids_profiles[trt_profile_idx],
)

cf_res

{'mse_te': 1.484211437485257,
 'max_te_est': 6.530341442900793,
 'max_te_err': 0.4696585570992067,
 'iou': 0.14285714285714285,
 'conf_matrix': array([[0.8, 0. , 0.2],
        [0. , 0. , 1. ],
        [0. , 0. , 1. ]])}

## Rashomon Sets

In [18]:
# Arguments for RAggregate, set together so they are easy to tune.
H, theta, reg = np.inf, 1.15, 1e-2

R_set, rashomon_profiles = RAggregate(M, R, H, D, y, theta, reg, verbose=True)

# Alternative run restricted to the treated observations:
# theta = 3
# reg = 1e-1
# R_set, rashomon_profiles = RAggregate(M, R, H, D_trt_subset, y_trt_subset, theta, reg,
#                                      verbose=True,
#                                      )

print(len(R_set))

Skipping profile (0, 0, 0, 0)
Skipping profile (0, 0, 0, 1)
Skipping profile (0, 0, 1, 0)
Skipping profile (0, 0, 1, 1)
Skipping profile (0, 1, 0, 0)
Skipping profile (0, 1, 0, 1)
Skipping profile (0, 1, 1, 0)
(0, 1, 1, 1) 0.6902404523030581
220
Skipping profile (1, 0, 0, 0)
Skipping profile (1, 0, 0, 1)
Skipping profile (1, 0, 1, 0)
Skipping profile (1, 0, 1, 1)
Skipping profile (1, 1, 0, 0)
Skipping profile (1, 1, 0, 1)
Skipping profile (1, 1, 1, 0)
(1, 1, 1, 1) 0.6848560440763305
5
Finding feasible combinations
Min = 1.13590072054853, Max = 1.3516820539585388
10


### Find poolings across Treatment and Control

In [28]:
# Evaluate every model in R_set: for each treatment/control pooling listed in
# te_partitions, compute the treatment-effect metrics and the overall MSE, and
# collect one result row per (R_set member, TE partition) pair.
# NOTE(review): `te_partitions` is not defined in any visible cell of this
# notebook (find_te_het_partitions is imported but never called), so this cell
# relies on leftover kernel state and will fail on Restart & Run All — TODO
# restore the cell that builds it.
conf_matrices = []
results_list = []
sim_i = 0

for idx, r_set in enumerate(R_set):
    conf_matrix_list_idx = []

    # Partition matrices chosen by this R_set member for the two profiles.
    sigma_trt_R_set_idx = r_set[trt_profile_idx]
    sigma_trt_i = rashomon_profiles[trt_profile_idx].sigma[sigma_trt_R_set_idx]
    sigma_ctl_R_set_idx = r_set[ctl_profile_idx]
    sigma_ctl_i = rashomon_profiles[ctl_profile_idx].sigma[sigma_ctl_R_set_idx]

    trt_pools_0, _ = extract_pools.extract_pools(trt_policies, sigma_trt_i)
    ctl_pools_0, _ = extract_pools.extract_pools(ctl_policies, sigma_ctl_i) # This gives dictionary of pools --> this is the partition

    # Right now the treatment and control profiles are not pooled together; need to check if it is possible to pool together treatment and control

    # Sanity check: after dropping the treatment arm, the i-th treatment policy
    # must equal the i-th control policy.
    for (ti, ci) in zip(trt_policies, ctl_policies):
        ti = tuple(ti[:trt_arm_idx] + ti[(trt_arm_idx+1):])
        if ti != ci:
            raise RuntimeError("Treatment and control pairs do not match!")

    # Presumably converts profile-local policy ids in the pools to indices in
    # all_policies — verify against tva.profile_ids_to_univ_ids.
    # Note: this rebinds trt_pools / ctl_pools from the earlier selection cell.
    trt_pools = tva.profile_ids_to_univ_ids(trt_pools_0, trt_policies_ids)
    ctl_pools = tva.profile_ids_to_univ_ids(ctl_pools_0, ctl_policies_ids)

    P_qe_idx = te_partitions[idx]

    for te_pool_id, sigma_int in enumerate(P_qe_idx.sigma):
        sigma_pools, sigma_policies = extract_pools.get_trt_ctl_pooled_partition(
            trt_pools, ctl_pools, sigma_int
        )
        mu_pools = loss.compute_pool_means(policy_means, sigma_pools)
        # Pool mean assigned to every treatment/control observation.
        D_tc_pool = [sigma_policies[pol_id] for pol_id in D_tc[:, 0]]
        mu_D = mu_pools[D_tc_pool]

        # Find TE
        D_trt_pooled_i = [sigma_policies[pol_id] for pol_id in trt_policies_ids]
        D_ctl_pooled_i = [sigma_policies[pol_id] for pol_id in ctl_policies_ids]
        y_trt_i = mu_pools[D_trt_pooled_i]
        y_ctl_i = mu_pools[D_ctl_pooled_i]

        te_i = y_trt_i - y_ctl_i

        metrics_results_i = metrics.compute_te_het_metrics(
            te_true, te_i,
            max_te, max_te_policies,
            D_trt, policies_ids_profiles[trt_profile_idx]
        )
        mse_te_i = metrics_results_i["mse_te"]
        max_te_err_i = metrics_results_i["max_te_err"]
        iou_i = metrics_results_i["iou"]
        conf_mat_i = metrics_results_i["conf_matrix"]

        # Compute overall MSE
        mse_i = mean_squared_error(y_tc[:, 0], mu_D)

        # Count number of pools
        num_pools_i = len(sigma_pools.keys())

        # Order must match rashomon_cols below.
        results_i = [
            n_per_pol, sim_i, idx, te_pool_id,
            mse_te_i, max_te_err_i, iou_i,
            mse_i, num_pools_i
        ]
        results_list.append(results_i)

        conf_matrix_list_idx.append(conf_mat_i)
        
    conf_matrices.append(conf_matrix_list_idx)

# One row per evaluated model.
rashomon_cols = [
            "n_per_pol", "sim_num", "idx", "te_idx",
            "MSE_TE", "max_te_diff", "IOU", "MSE", "num_pools"
        ]
rashomon_df = pd.DataFrame(results_list, columns=rashomon_cols)

In [39]:
# Average treatment-effect MSE over all evaluated Rashomon models.
rashomon_df["MSE_TE"].mean()

2.202740279341172

In [21]:
# Total number of treatment/control poolings across all R_set members.
num_models = sum(partition.size for partition in te_partitions)
print(num_models)

10


## Lasso

In [35]:
# L1 penalty strength for the Lasso baseline.
lasso_reg = 1e-1

# Regress the outcomes on the dummy matrix without an intercept.
lasso = linear_model.Lasso(alpha=lasso_reg, fit_intercept=False)
lasso.fit(D_matrix, y)

alpha_est = lasso.coef_


# y_tva = lasso.predict(D_matrix_trt_subset)
# # y_tva = lasso.predict(X_trt_subset)

# tva_results = metrics.compute_all_metrics(
#                     y_trt_subset, y_tva,
#     D_trt_subset, true_best, all_policies, profile_map, min_dosage_best_policy, true_best_effect)

In [36]:
# Lasso predictions for each treatment / control policy; their difference is the TE estimate.
y_trt_lasso = lasso.predict(D_matrix_trt)
y_ctl_lasso = lasso.predict(D_matrix_ctl)
te_lasso = y_trt_lasso - y_ctl_lasso

# Same scoring call as for the other estimators.
tva_results = metrics.compute_te_het_metrics(
    te_true,
    te_lasso,
    max_te,
    max_te_policies,
    D_trt,
    policies_ids_profiles[trt_profile_idx],
)

tva_results

{'mse_te': 2.65225707739319,
 'max_te_est': 3.347887373599511,
 'max_te_err': 3.652112626400489,
 'iou': 0.42857142857142855,
 'conf_matrix': array([[ 2,  0,  8],
        [ 0,  0,  2],
        [ 0,  0, 36]])}

In [20]:
# NOTE(review): the output saved under this cell (keys 'sqrd_err', 'best_prof',
# 'min_dos_inc', 'best_pol_diff') does not match the compute_te_het_metrics
# result shown for the previous cell, so it appears to be stale — re-run to refresh.
tva_results

{'sqrd_err': 0.23522551170666117,
 'iou': 0.16666666666666666,
 'best_prof': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
 'min_dos_inc': False,
 'best_pol_diff': -0.03413597934468626}

In [None]:
# Display the fitted Lasso coefficients (lasso.coef_).
alpha_est

In [None]:
# Distinct dummy-matrix rows among the treated observations ...
D_mat_unique = np.unique(D_matrix_trt_subset, axis=0)

# ... and the fitted value the Lasso coefficients give for each of them.
y_est = D_mat_unique @ alpha_est

In [None]:
# Display the fitted values for the distinct treated dummy-matrix rows.
y_est

In [None]:
# Distinct policy ids (indices into all_policies) observed among the treated rows.
np.unique(D_trt_subset)