In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [7]:
import pickle
import econml

from rashomon import tva
from rashomon import loss
from rashomon import counter
from rashomon import metrics
from rashomon import extract_pools
from rashomon.aggregate import (RAggregate_profile, RAggregate,
    find_profile_lower_bound, find_feasible_combinations, remove_unused_poolings, subset_data)
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache

from copy import deepcopy

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [270]:
def generate_data(mu, var, n_per_pol, all_policies, pi_policies, M,
                  policies_profiles=None):
    """Simulate ``n_per_pol`` noisy outcomes for every policy of every modelled profile.

    Parameters
    ----------
    mu, var : sequence indexed by profile
        ``mu[k][pool]`` and ``var[k][pool]`` are passed to
        ``np.random.normal`` as ``loc`` and ``scale`` for the policies of
        profile ``k`` that belong to ``pool``. Entries of skipped profiles
        (see ``pi_policies``) are never read, so any placeholder is fine.
    n_per_pol : int
        Number of rows generated per policy.
    all_policies : sequence of policies
        Every policy; a policy's position in this sequence is the id
        written to ``D``.
    pi_policies : mapping
        ``pi_policies[k][j]`` is the pool of the ``j``-th policy of profile
        ``k``. ``pi_policies[k] is None`` marks a profile for which no data
        is generated.
    M : int
        Number of columns of ``X`` (the length of a policy).
    policies_profiles : mapping or sequence, optional
        ``policies_profiles[k]`` lists the policies of profile ``k``, with
        keys/indices ``0 .. len - 1``. When omitted, the notebook-level
        variable of the same name is used, which keeps existing calls working.

    Returns
    -------
    X : ndarray, shape (n, M)
        The policy of each row.
    D : ndarray of int, shape (n, 1)
        Index of each row's policy in ``all_policies``.
    y : ndarray, shape (n, 1)
        Simulated outcomes.
    mu_true : ndarray, shape (n, 1)
        The ``loc`` each outcome was drawn with.
    """
    if policies_profiles is None:
        # Backward compatible fallback to the notebook-level mapping.
        policies_profiles = globals()["policies_profiles"]

    # Upper bound on the number of rows: every policy observed n_per_pol times.
    num_data = len(all_policies) * n_per_pol
    X = np.zeros(shape=(num_data, M))
    D = np.zeros(shape=(num_data, 1), dtype='int_')
    y = np.zeros(shape=(num_data, 1))
    mu_true = np.zeros(shape=(num_data, 1))

    # First position of each policy in all_policies (replaces a linear scan per policy).
    policy_ids = {}
    for i, policy in enumerate(all_policies):
        policy_ids.setdefault(tuple(policy), i)

    n_filled = 0
    for k in range(len(policies_profiles)):
        pools_k = pi_policies[k]
        if pools_k is None:
            # Profile is not modelled. Its mu/var placeholder may be None or
            # NaN, so the decision must not depend on mu[k].
            continue

        for idx, policy in enumerate(policies_profiles[k]):
            pool_id = pools_k[idx]
            mu_i = mu[k][pool_id]
            var_i = var[k][pool_id]
            # NOTE(review): np.random.normal's second argument is the standard
            # deviation; `var` is passed through unchanged to keep existing
            # simulations reproducible - confirm whether sqrt(var) was intended.
            y_i = np.random.normal(mu_i, var_i, size=(n_per_pol, 1))

            rows = slice(n_filled, n_filled + n_per_pol)
            X[rows, :] = policy
            D[rows, :] = policy_ids[tuple(policy)]
            y[rows, :] = y_i
            mu_true[rows, :] = mu_i

            n_filled += n_per_pol

    # Rows are filled contiguously from the top; drop the unused tail.
    return X[:n_filled], D[:n_filled], y[:n_filled], mu_true[:n_filled]

# Setup

## Parameters

In [None]:
# M = 2
# R = np.array([2, 4])


# # (0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # (0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([0])
# var_1 = np.array([1])


# # (1, 0)
# sigma_2 = None
# mu_2 = np.array([0])
# var_2 = np.array([1])


# # (1, 1)
# sigma_3 = np.array([[1, 1],
#                     [0, 1]])
# mu_3 = np.array([6, 4,])
# var_3 = np.array([1, 1])


# sigma = [sigma_0, sigma_1, sigma_2, sigma_3]
# mu = [mu_0, mu_1, mu_2, mu_3]
# var = [var_0, var_1, var_2, var_3]

In [None]:
# M = 3
# R = np.array([4, 4, 4])

# # Fix the partitions
# # Profile 0: (0, 0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # Profile 1: (0, 0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([1])
# var_1 = np.array([1])

# # Profile 2: (0, 1, 0)
# sigma_2 = np.array([[1, 0]])
# mu_2 = np.array([0.5, 3.8])
# var_2 = np.array([1, 1])

# # Profile 3: (0, 1, 1)
# sigma_3 = np.array([[1, 1],
#                     [1, 1]])
# mu_3 = np.array([1.5])
# var_3 = np.array([1])

# # Profile 4: (1, 0, 0)
# sigma_4 = np.array([[1, 1]])
# mu_4 = np.array([1])
# var_4 = np.array([1])

# # Profile 5: (1, 0, 1)
# sigma_5 = np.array([[0, 1],
#                     [1, 1]])
# mu_5 = np.array([2, 3])
# var_5 = np.array([1, 1])

# # Profile 6: (1, 1, 0)
# sigma_6 = np.array([[1, 1],
#                     [1, 1]])
# mu_6 = np.array([2.5])
# var_6 = np.array([1])

# # Profile 7: (1, 1, 1)
# sigma_7 = np.array([[1, 1],
#                     [1, 1],
#                     [1, 0]])
# mu_7 = np.array([3, 4])
# var_7 = np.array([1, 1])

# sigma = [sigma_0, sigma_1, sigma_2, sigma_3, sigma_4, sigma_5, sigma_6, sigma_7]
# mu = [mu_0, mu_1, mu_2, mu_3, mu_4, mu_5, mu_6, mu_7]
# var = [var_0, var_1, var_2, var_3, var_4, var_5, var_6, var_7]


In [252]:
# Problem size: M features; R is later handed to tva.enumerate_policies(M, R)
# (presumably the number of levels per feature - confirm against tva).
M = 3
R = np.array([2, 4, 4])

# Each modelled profile gets a partition matrix (sigma) and one mean/variance
# per pool that the partition induces.

# Profile (0, 1, 1): a single pool.
sigma_0 = np.array([[1, 1], [1, 1]])
mu_0, var_0 = np.array([0]), np.array([0.1])

# Profile (1, 1, 1): two pools. The first row is filled with np.inf
# (looks like a placeholder for a feature with nothing to split - TODO confirm).
sigma_1 = np.array([
    [np.inf, np.inf],
    [1, 1],
    [1, 0],
])
mu_1, var_1 = np.array([2, 5]), np.array([1, 1])

# Position i of the three *_tmp lists describes interested_profiles[i].
interested_profiles = [(0, 1, 1), (1, 1, 1)]
sigma_tmp, mu_tmp, var_tmp = [sigma_0, sigma_1], [mu_0, mu_1], [var_0, var_1]

In [278]:
num_profiles = 2 ** M
profiles, profile_map = tva.enumerate_profiles(M)
all_policies = tva.enumerate_policies(M, R)
num_policies = len(all_policies)

interested_profile_idx = []

# Expand the parameters of the profiles of interest to one entry per
# enumerated profile. Profiles that are not of interest get sigma = None and
# NaN for mu / var.
sigma, mu, var = [], [], []
for k, profile in enumerate(profiles):
    # Position of the first matching profile of interest, if any.
    match = next(
        (pos for pos, candidate in enumerate(interested_profiles) if candidate == profile),
        None,
    )
    if match is None:
        sigma_k, mu_k, var_k = None, np.nan, np.nan
    else:
        sigma_k, mu_k, var_k = sigma_tmp[match], mu_tmp[match], var_tmp[match]

    sigma.append(sigma_k)
    mu.append(mu_k)
    var.append(var_k)

In [281]:
# Identify the pools
# For every profile: collect its policies (and their indices in all_policies)
# and, when a partition matrix sigma[k] is available, the pools it induces.
policies_profiles = {}          # k -> policies of profile k (full length M)
policies_profiles_masked = {}   # k -> same policies restricted to the active features
policies_ids_profiles = {}      # k -> indices of those policies in all_policies
pi_policies = {}                # k -> pool assignment from extract_pools, or None
pi_pools = {}                   # k -> {pool id: indices in all_policies}, or None
for k, profile in enumerate(profiles):

    policies_temp = [(i, x) for i, x in enumerate(all_policies) if tva.policy_to_profile(x) == profile]
    # Two comprehensions instead of zip(*...) so that a profile without any
    # policy yields empty lists rather than an IndexError.
    policies_ids_k = [i for i, _ in policies_temp]
    policies_k = [x for _, x in policies_temp]
    policies_profiles[k] = deepcopy(policies_k)
    policies_ids_profiles[k] = policies_ids_k

    profile_mask = list(map(bool, profile))

    # Mask the empty arms
    policies_masked_k = [
        tuple(pol[i] for i in range(M) if profile_mask[i]) for pol in policies_k
    ]
    policies_profiles_masked[k] = policies_masked_k

    # No partition matrix: this profile is not modelled.
    if sigma[k] is None:
        pi_policies[k] = None
        pi_pools[k] = None
        continue

    if np.sum(profile) > 0:
        pi_pools_k, pi_policies_k = extract_pools.extract_pools(policies_masked_k, sigma[k])
        if len(pi_pools_k.keys()) != mu[k].shape[0]:
            print(pi_pools_k)
            print(f"Profile {k}. Expected {len(pi_pools_k.keys())} pools. Received {mu[k].shape[0]} means.")
        pi_policies[k] = pi_policies_k
        # pi_pools_k has indices that match with policies_profiles[k].
        # Map those indices back to all_policies via policies_ids_k.
        # The loop variables are deliberately not called x / y: a module-level
        # `for x, y in ...` would overwrite the notebook's outcome vector `y`
        # whenever this cell is re-run after the data has been generated.
        pi_pools[k] = {}
        for pool_id, pool_members in pi_pools_k.items():
            pi_pools[k][pool_id] = [policies_ids_k[i] for i in pool_members]
    else:
        # Profile with no active feature: a single policy in a single pool.
        # NOTE(review): index 0 assumes that policy is enumerated first in
        # all_policies - confirm against tva.enumerate_policies.
        pi_policies[k] = {0: 0}
        pi_pools[k] = {0: [0]}

# Locate the profile, and the pool inside it, with the largest mean.
# Profiles that are not modelled carry NaN, hence the NaN-aware argmax.
best_per_profile = list(map(np.max, mu))
true_best_profile = np.nanargmax(best_per_profile)
true_best_profile_idx = int(true_best_profile)

mu_best_profile = mu[true_best_profile]
true_best_effect = np.max(mu_best_profile)
true_best = pi_pools[true_best_profile][np.argmax(mu_best_profile)]
min_dosage_best_policy = metrics.find_min_dosage(true_best, all_policies)

# The transformation matrix for Lasso
G = tva.alpha_matrix(all_policies)

In [255]:
# X_test = np.array(all_policies)
# # mu_true = 
# print(X_test)

In [256]:
# Show every pool of profile 7 as the list of policies it contains.
for pool_id, member_ids in pi_pools[7].items():
    print(pool_id)
    print([all_policies[member] for member in member_ids])

0
[(1, 1, 1), (1, 1, 2), (1, 2, 1), (1, 2, 2), (1, 3, 1), (1, 3, 2)]
1
[(1, 1, 3), (1, 2, 3), (1, 3, 3)]


## Generate data

In [271]:
np.random.seed(3)

n_per_pol = 10

# Simulate the data set and summarise it per policy.
X, D, y, mu_true = generate_data(mu, var, n_per_pol, all_policies, pi_policies, M)
policy_means = loss.compute_policy_means(D, y, num_policies)
# The dummy matrix for Lasso
D_matrix = tva.get_dummy_matrix(D, G, num_policies)

# Causal-forest inputs: column `trt_idx` of X is the treatment, every other
# column is a covariate.
trt_idx = 0
feature_idx = list(np.delete(np.arange(M), trt_idx))

# Binary treatment indicator (1.0 where the treatment column is non-zero),
# with the same shape as y.
T = (X[:, trt_idx] > 0).astype(float).reshape(y.shape)

X_cf = X[:, feature_idx]
y_0d = y.reshape(-1)

# Estimation

## Causal Forests

In [177]:
from econml.grf import CausalForest
from sklearn.linear_model import LassoCV
from sklearn.metrics import mean_squared_error

In [272]:
# https://econml.azurewebsites.net/_autosummary/econml.grf.CausalForest.html?highlight=causalforest#econml.grf.CausalForest
# Forest settings (min_samples_leaf=1 was tried and is left at its default).
forest_kwargs = {
    "criterion": "het",
    "n_estimators": 100,
    "random_state": 3,
}
est = CausalForest(**forest_kwargs)

# Fit on the covariates X_cf with the binary indicator T as treatment.
# Alternatives that were tried:
#   est.fit(X_cf, X[:, trt_idx], y_0d)
#   est.fit(X, 1 + np.zeros(T.shape), y_0d)
est.fit(X_cf, T, y_0d)

In [273]:
# In-sample prediction: the forest is queried on X_cf, the same covariate
# matrix it was fitted on.
# (presumably one treatment-effect estimate per row - see the econml docs)
treatment_effects = est.predict(X_cf)
# treatment_effects = est.predict(X)

# # Confidence intervals via Bootstrap-of-Little-Bags for forests
# lb, ub = est.effect_interval(X, alpha=0.05)

In [274]:
# Mean squared error of the forest predictions against the simulated outcomes
# (mse_cf) and against the means those outcomes were drawn around (mse_cf_mu).
# NOTE(review): this compares treatment-effect predictions with outcome
# levels - confirm that is the intended benchmark.
mse_cf, mse_cf_mu = (
    mean_squared_error(target, treatment_effects) for target in (y_0d, mu_true)
)
print(mse_cf)
print(mse_cf_mu)

5.848258168681055
5.451506548224624
