In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [3]:
import pickle
import econml

from rashomon import tva
from rashomon import loss
from rashomon import counter
from rashomon import metrics
from rashomon import extract_pools
from rashomon.aggregate import (RAggregate_profile, RAggregate,
    find_profile_lower_bound, find_feasible_combinations, remove_unused_poolings, subset_data)
from rashomon.sets import RashomonSet, RashomonProblemCache, RashomonSubproblemCache

from econml.grf import CausalForest
from sklearn.metrics import mean_squared_error
from sklearn import linear_model

from copy import deepcopy

%load_ext autoreload
%autoreload 2

In [1]:
def generate_data(mu, var, n_per_pol, all_policies, pi_policies, M):
    num_data = num_policies * n_per_pol
    X = np.zeros(shape=(num_data, M))
    D = np.zeros(shape=(num_data, 1), dtype='int_')
    y = np.zeros(shape=(num_data, 1))  + np.inf
    mu_true = np.zeros(shape=(num_data, 1))

    idx_ctr = 0
    for k, profile in enumerate(profiles):
        policies_k = policies_profiles[k]

        for idx, policy in enumerate(policies_k):
            policy_idx = [i for i, x in enumerate(all_policies) if x == policy]
            
            if pi_policies[k] is None and np.isnan(mu[k]):
                continue
                
            pool_id = pi_policies[k][idx]
            mu_i = mu[k][pool_id]
            var_i = var[k][pool_id]
            y_i = np.random.normal(mu_i, var_i, size=(n_per_pol, 1))

            start_idx = idx_ctr * n_per_pol
            end_idx = (idx_ctr + 1) * n_per_pol

            X[start_idx:end_idx, ] = policy
            D[start_idx:end_idx, ] = policy_idx[0]
            y[start_idx:end_idx, ] = y_i
            mu_true[start_idx:end_idx, ] = mu_i

            idx_ctr += 1

    absent_idx = np.where(np.isinf(y))[0]
    X = np.delete(X, absent_idx, 0)
    y = np.delete(y, absent_idx, 0)
    D = np.delete(D, absent_idx, 0)
    mu_true = np.delete(mu_true, absent_idx, 0)

    return X, D, y, mu_true

# Setup

## Parameters

In [4]:
# M = 2
# R = np.array([2, 4])


# # (0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # (0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([0])
# var_1 = np.array([1])


# # (1, 0)
# sigma_2 = None
# mu_2 = np.array([0])
# var_2 = np.array([1])


# # (1, 1)
# sigma_3 = np.array([[1, 1],
#                     [0, 1]])
# mu_3 = np.array([6, 4,])
# var_3 = np.array([1, 1])


# sigma = [sigma_0, sigma_1, sigma_2, sigma_3]
# mu = [mu_0, mu_1, mu_2, mu_3]
# var = [var_0, var_1, var_2, var_3]

In [5]:
# M = 3
# R = np.array([4, 4, 4])

# # Fix the partitions
# # Profile 0: (0, 0)
# sigma_0 = None
# mu_0 = np.array([0])
# var_0 = np.array([1])

# # Profile 1: (0, 0, 1)
# sigma_1 = np.array([[1, 1]])
# mu_1 = np.array([1])
# var_1 = np.array([1])

# # Profile 2: (0, 1, 0)
# sigma_2 = np.array([[1, 0]])
# mu_2 = np.array([0.5, 3.8])
# var_2 = np.array([1, 1])

# # Profile 3: (0, 1, 1)
# sigma_3 = np.array([[1, 1],
#                     [1, 1]])
# mu_3 = np.array([1.5])
# var_3 = np.array([1])

# # Profile 4: (1, 0, 0)
# sigma_4 = np.array([[1, 1]])
# mu_4 = np.array([1])
# var_4 = np.array([1])

# # Profile 5: (1, 0, 1)
# sigma_5 = np.array([[0, 1],
#                     [1, 1]])
# mu_5 = np.array([2, 3])
# var_5 = np.array([1, 1])

# # Profile 6: (1, 1, 0)
# sigma_6 = np.array([[1, 1],
#                     [1, 1]])
# mu_6 = np.array([2.5])
# var_6 = np.array([1])

# # Profile 1: (1, 1, 1)
# sigma_7 = np.array([[1, 1],
#                     [1, 1],
#                     [1, 0]])
# mu_7 = np.array([3, 4])
# var_7 = np.array([1, 1])

# sigma = [sigma_0, sigma_1, sigma_2, sigma_3, sigma_4, sigma_5, sigma_6, sigma_7]
# mu = [mu_0, mu_1, mu_2, mu_3, mu_4, mu_5, mu_6, mu_7]
# var = [var_0, var_1, var_2, var_3, var_4, var_5, var_6, var_7]


In [4]:
# M = 4
# R = np.array([2, 5, 5, 5])


# # (0, 1, 1, 1)
# sigma_0 = np.array([[1, 1, 1],
#                     [1, 1, 1],
#                     [1, 1, 1]
#                    ])
# mu_0 = np.array([0])
# var_0 = np.array([0])

# # (1, 1, 1)
# sigma_1 = np.array([[np.inf, np.inf, np.inf],
#                     [1, 1, 1],
#                     [1, 0, 1],
#                     [1, 0, 0],
#                    ])
# mu_1 = np.array([2, 4, 7, 5, 6, 3])
# var_1 = np.array([1, 1, 1, 1, 1, 1]) * 0.5


# interested_profiles = [(0, 1, 1, 1), (1, 1, 1, 1)]

# sigma_tmp = [sigma_0, sigma_1]
# mu_tmp = [mu_0, mu_1]
# var_tmp = [var_0, var_1]

M = 4
R = np.array([2, 4, 5, 5])


# (0, 1, 1, 1)
sigma_0 = np.array([[1, 1, 1],
                    [1, 1, 1],
                    [1, 1, 1]
                   ])
mu_0 = np.array([0])
var_0 = np.array([0])

# (1, 1, 1)
sigma_1 = np.array([[np.inf, np.inf, np.inf],
                    [0, 0, np.inf],
                    [1, 0, 1],
                    [1, 1, 0],
                   ])
mu_1 = np.array([2, 4, 2, 0, 3, 5, 7, 1, 1, -1, -1, -2])
var_1 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) * 0.5


interested_profiles = [(0, 1, 1, 1), (1, 1, 1, 1)]

sigma_tmp = [sigma_0, sigma_1]
mu_tmp = [mu_0, mu_1]
var_tmp = [var_0, var_1]

\begin{align*}
    \text{Control} & \\
    (0,\ 1:3,\  1:4,\  1:4) &= 0 \\
    \text{Treatment} & \\
    (1,\  1,\  1:2,\  1:3) &= 2 \\
    (1,\  1,\  1:2,\  4) &= 4 \\
    (1,\  1,\  3:4,\  1:3) &= 2 \\
    (1,\  1,\  3:4,\  4) &= 0 \\
    (1,\  2,\  1:2,\  1:3) &= 3 \\
    (1,\  2,\  1:2,\  4) &= 5 \\
    (1,\  2,\  3:4,\  1:3) &= 7 \\
    (1,\  2,\  3:4,\  4) &= 1 \\
    (1,\  3,\  1:2,\  1:3) &= 1 \\
    (1,\  3,\  1:2,\  4) &= -1 \\
    (1,\  3,\  3:4,\  1:3) &= -1 \\
    (1,\  3,\  3:4,\  4) &= -2
\end{align*}

In [5]:
num_profiles = 2**M
profiles, profile_map = tva.enumerate_profiles(M)
all_policies = tva.enumerate_policies(M, R)
num_policies = len(all_policies)

interested_profile_idx = []
sigma = []
mu = []
var = []
for k, profile in enumerate(profiles):
    sigma_k = None
    mu_k = np.nan
    var_k =  np.nan
    for i, p in enumerate(interested_profiles):
        if p == profile:
            sigma_k = sigma_tmp[i]
            mu_k = mu_tmp[i]
            var_k = var_tmp[i]
            break
    sigma.append(sigma_k)
    mu.append(mu_k)
    var.append(var_k)

In [6]:
# Identify the pools
policies_profiles = {}
policies_profiles_masked = {}
policies_ids_profiles = {}
pi_policies = {}
pi_pools = {}
for k, profile in enumerate(profiles):

    policies_temp = [(i, x) for i, x in enumerate(all_policies) if tva.policy_to_profile(x) == profile]
    unzipped_temp = list(zip(*policies_temp))
    policies_ids_k = list(unzipped_temp[0])
    policies_k = list(unzipped_temp[1])
    policies_profiles[k] = deepcopy(policies_k)
    policies_ids_profiles[k] = policies_ids_k

    profile_mask = list(map(bool, profile))

    # Mask the empty arms
    for idx, pol in enumerate(policies_k):
        policies_k[idx] = tuple([pol[i] for i in range(M) if profile_mask[i]])
    policies_profiles_masked[k] = policies_k


    # profile_idx = None
    # for idx, p in enumerate(interested_profiles):
    #     if p == profile:
    #         profile_idx = idx
    # if profile_idx is None:
    #     pi_policies[k] = None
    #     pi_pools[k] = None
    #     continue
    if sigma[k] is None:
        pi_policies[k] = None
        pi_pools[k] = None
        continue

    if np.sum(profile) > 0:
        pi_pools_k, pi_policies_k = extract_pools.extract_pools(policies_k, sigma[k])
        if len(pi_pools_k.keys()) != mu[k].shape[0]:
            print(pi_pools_k)
            print(f"Profile {k}. Expected {len(pi_pools_k.keys())} pools. Received {mu[k].shape[0]} means.")
        pi_policies[k] = pi_policies_k
        # pi_pools_k has indicies that match with policies_profiles[k]
        # Need to map those indices back to all_policies
        pi_pools[k] = {}
        for x, y in pi_pools_k.items():
            y_full = [policies_profiles[k][i] for i in y]
            y_agg = [all_policies.index(i) for i in y_full]
            pi_pools[k][x] = y_agg
    else:
        pi_policies[k] = {0: 0}
        pi_pools[k] = {0: [0]}

best_per_profile = [np.max(mu_k) for mu_k in mu]
true_best_profile = np.nanargmax(best_per_profile)
true_best_profile_idx = int(true_best_profile)
true_best_effect = np.max(mu[true_best_profile])
true_best = pi_pools[true_best_profile][np.argmax(mu[true_best_profile])]
min_dosage_best_policy = metrics.find_min_dosage(true_best, all_policies)

# The transformation matrix for Lasso
G = tva.alpha_matrix(all_policies)

In [9]:
for pi, l in pi_pools[15].items():
    print(pi)
    print([all_policies[x] for x in l])

0
[(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 3), (1, 1, 2, 1), (1, 1, 2, 2), (1, 1, 2, 3)]
1
[(1, 1, 1, 4), (1, 1, 2, 4)]
2
[(1, 1, 3, 1), (1, 1, 3, 2), (1, 1, 3, 3), (1, 1, 4, 1), (1, 1, 4, 2), (1, 1, 4, 3)]
3
[(1, 1, 3, 4), (1, 1, 4, 4)]
4
[(1, 2, 1, 1), (1, 2, 1, 2), (1, 2, 1, 3), (1, 2, 2, 1), (1, 2, 2, 2), (1, 2, 2, 3)]
5
[(1, 2, 1, 4), (1, 2, 2, 4)]
6
[(1, 2, 3, 1), (1, 2, 3, 2), (1, 2, 3, 3), (1, 2, 4, 1), (1, 2, 4, 2), (1, 2, 4, 3)]
7
[(1, 2, 3, 4), (1, 2, 4, 4)]
8
[(1, 3, 1, 1), (1, 3, 1, 2), (1, 3, 1, 3), (1, 3, 2, 1), (1, 3, 2, 2), (1, 3, 2, 3)]
9
[(1, 3, 1, 4), (1, 3, 2, 4)]
10
[(1, 3, 3, 1), (1, 3, 3, 2), (1, 3, 3, 3), (1, 3, 4, 1), (1, 3, 4, 2), (1, 3, 4, 3)]
11
[(1, 3, 3, 4), (1, 3, 4, 4)]


## Generate data

In [10]:
np.random.seed(3)

n_per_pol = 10

# Generate data
X, D, y, mu_true = generate_data(mu, var, n_per_pol, all_policies, pi_policies, M)
policy_means = loss.compute_policy_means(D, y, num_policies)
# The dummy matrix for Lasso
D_matrix = tva.get_dummy_matrix(D, G, num_policies)

trt_idx = 0
feature_idx = list(np.arange(0, trt_idx)) + list(np.arange(trt_idx+1, M))

T = np.zeros(shape=y.shape)
T[X[:, trt_idx] > 0] = 1

y_0d = y.reshape((-1,))
X_cf = X[:, feature_idx]

X_trt_subset = X[X[:, trt_idx] > 0, :]
X_trt_subset = X_trt_subset[:, feature_idx]
y_trt_subset = y[X[:, trt_idx] > 0]


D_trt_subset = D[X[:, trt_idx] > 0]
D_matrix_trt_subset = D_matrix[X[:, trt_idx] > 0, :]

# Estimation

## Causal Forests

In [8]:
# https://econml.azurewebsites.net/_autosummary/econml.grf.CausalForest.html?highlight=causalforest#econml.grf.CausalForest
est = CausalForest(criterion="het", n_estimators=100,
                   min_samples_leaf=1,
                   # max_depth=None,
                   min_samples_split=2,
                   random_state=3,
                  )

# est.fit(X_cf, X[:, trt_idx], y_0d)
est.fit(X_cf, T, y_0d)
# est.fit(X, 1 + np.zeros(T.shape), y_0d)

In [9]:
treatment_effects = est.predict(X_trt_subset)
# treatment_effects = est.predict(X)

# # Confidence intervals via Bootstrap-of-Little-Bags for forests
# lb, ub = est.effect_interval(X, alpha=0.05)

mse_cf = mean_squared_error(y_trt_subset, treatment_effects)
print(mse_cf)

0.2312209137557294


In [12]:
cf_res = metrics.compute_all_metrics(
                    y_trt_subset, treatment_effects,
    D_trt_subset, true_best, all_policies, profile_map, min_dosage_best_policy, true_best_effect)

cf_res

{'sqrd_err': 0.2312209137557294,
 'iou': 0.14285714285714285,
 'best_prof': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
 'min_dos_inc': False,
 'best_pol_diff': -0.1461451815562338}

## Rashomon Sets

In [22]:
H = np.inf


theta = 2
reg = 1e-1
R_set, rashomon_profiles = RAggregate(M, R, H, D, y, theta, reg,
                                     verbose=True,
                                     )
# theta = 3
# reg = 1e-1
# R_set, rashomon_profiles = RAggregate(M, R, H, D_trt_subset, y_trt_subset, theta, reg,
#                                      verbose=True,
#                                      )

print(len(R_set))

Skipping profile (0, 0, 0, 0)
Skipping profile (0, 0, 0, 1)
Skipping profile (0, 0, 1, 0)
Skipping profile (0, 0, 1, 1)
Skipping profile (0, 1, 0, 0)
Skipping profile (0, 1, 0, 1)
Skipping profile (0, 1, 1, 0)
(0, 1, 1, 1) 1.8850601130757645
220
Skipping profile (1, 0, 0, 0)
Skipping profile (1, 0, 0, 1)
Skipping profile (1, 0, 1, 0)
Skipping profile (1, 0, 1, 1)
Skipping profile (1, 1, 0, 0)
Skipping profile (1, 1, 0, 1)
Skipping profile (1, 1, 1, 0)
(1, 1, 1, 1) 2.0
22
Finding feasible combinations
Min = 1.4216438521207115, Max = 3.7961839981481615
209


### Find poolings across Treatment and Control

In [291]:
def ids_to_policies_profiles(pi_pools_0, policies_ids_profiles, all_policies, profile_id):
    pi_pools = {}
    policies_list = policies_ids_profiles[profile_id]
    for pi_id, pol_idx_list in pi_pools_0.items():
        pol_univ_list = [policies_list[idx] for idx in pol_idx_list]
        pi_pools[pi_id] = pol_univ_list
    return pi_pools


def remove_arm(policies_list, idx_to_skip):
    policies_list = [tuple(pol[:idx_to_skip] + pol[(idx_to_skip+1):]) for pol in policies_list]
    return policies_list


def get_intersection_matrix(ctl_pools, trt_pools, all_policies, trt_arm_idx):
    n_ctl_pools = len(ctl_pools.keys())
    n_trt_pools = len(trt_pools.keys())
    
    sigma_int = np.zeros(shape=(n_ctl_pools, n_trt_pools)) + np.inf
    for ci, pi_ci in ctl_pools.items():
        pi_ci = [all_policies[i] for i in pi_ci]
        pi_ci = remove_arm(pi_ci, trt_arm_idx)
        pi_ci = set(pi_ci)
        for ti, pi_ti in trt_pools.items():
            pi_ti = [all_policies[i] for i in pi_ti]
            pi_ti = remove_arm(pi_ti, trt_arm_idx)
            pi_ti = set(pi_ti)
            if len(pi_ci.intersection(pi_ti)) > 0:
                sigma_int[ci, ti] = 0
    
    return sigma_int

def get_trt_ctl_pooled_partition(trt_pools, ctl_pools, sigma_int):
    pools_tmp = {
        "trt": trt_pools.copy(),
        "ctl": ctl_pools.copy(),
        "mix": {}
    }
    
    mixed_indices = np.where(sigma_int == 1)
    
    for ctl_i, trt_i in zip(mixed_indices[0], mixed_indices[1]):
        mix_trt_i_pols = pools_tmp["trt"].pop(trt_i)
        mix_ctl_i_pols = pools_tmp["ctl"].pop(ctl_i)
        mix_i_pols = list(set(mix_trt_i_pols + mix_ctl_i_pols))
        mix_id = len(pools_tmp["mix"])
        pools_tmp["mix"][mix_id] = mix_i_pols
    
    sigma_pools = {}
    sigma_policies = {}
    pool_counter = 0
    for _, dict_i in pools_tmp.items():
        for _, policies_ij in dict_i.items():
            sigma_pools[pool_counter] = policies_ij
            for p in policies_ij:
                sigma_policies[p] = pool_counter
            pool_counter += 1

    return sigma_pools, sigma_policies

def compute_het_Q(D_tc, y_tc, sigma_int, trt_pools, ctl_pools, policy_means, reg=1, normalize=0):
    """
    Compute the loss after pooling across treatment and control as per sigma_int
    D_tc indices need not be re-indexed. The indicies should match policy_means
    policy_means is for the entire dataset
    """
    
    sigma_pools, sigma_policies = get_trt_ctl_pooled_partition(trt_pools, ctl_pools, sigma_int)
    mu_pools = loss.compute_pool_means(policy_means, sigma_pools)
    D_tc_pool = [sigma_policies[pol_id] for pol_id in D_tc[:, 0]]
    mu_D = mu_pools[D_tc_pool]
    mse = mean_squared_error(y_tc[:, 0], mu_D)
    
    if normalize > 0:
        mse = mse * D.shape[0] / normalize

    h = mu_pools.shape[0]
    Q = mse + reg * h

    return Q

def het_partition_solver(D_tc, y_tc, policy_means, trt_pools, ctl_pools, col, sigma_int, theta, P_qe, reg):
    ncols = sigma_int.shape[1]
    if col == ncols:
        return P_qe

    Q_0 = compute_het_Q(D_tc, y_tc, sigma_int, trt_pools, ctl_pools, policy_means, reg)
    if Q_0 <= theta:
        P_qe.insert(sigma_int.copy())
        P_qe.Q = np.append(P_qe.Q, Q_0)
        P_qe = het_partition_solver(
            D_tc, y_tc, policy_means, trt_pools, ctl_pools, col+1, sigma_int, theta, P_qe, reg)
    
    zero_loc = np.where(sigma_int[:, col] == 0)[0]
    nz = len(zero_loc)
    for i in range(nz):
        row = zero_loc[i]
        
        sigma_tmp = sigma_int.copy()
        sigma_tmp[row, :] = np.inf
        sigma_tmp[:, col] = np.inf
        sigma_tmp[row, col] = 1
        
        Q_i = compute_het_Q(D_tc, y_tc, sigma_tmp, trt_pools, ctl_pools, policy_means, reg=1e-1)
        if Q_i <= theta:
            P_qe.insert(sigma_tmp)
            P_qe.Q = np.append(P_qe.Q, Q_i)
            P_qe = het_partition_solver(
                D_tc, y_tc, policy_means, trt_pools, ctl_pools, col+1, sigma_int, theta, P_qe, reg)

    return P_qe

In [283]:
policy_means = loss.compute_policy_means(D, y, num_policies)

In [295]:
trt_idx = 0

ctl_profile_idx = 7
trt_profile_idx = 15

ctl_profile = (0, 1, 1, 1)
trt_profile = (1, 1, 1, 1)

# Subset data for interested profiles
trt_policies_ids = policies_ids_profiles[trt_profile_idx]
ctl_policies_ids = policies_ids_profiles[ctl_profile_idx]
tc_policies_ids = trt_policies_ids + ctl_policies_ids

mask = np.isin(D, tc_policies_ids)
D_tc = D[mask].reshape((-1,1))
y_tc = y[mask].reshape((-1,1))

policy_means_tc = policy_means[tc_policies_ids, :]

# Look at a pair of treatment, control partitions

R_set_idx = 15
R_set_i = R_set[R_set_idx]

# Get treatment and control partitions
sigma_trt_R_set_idx = R_set_i[trt_profile_idx]
sigma_trt_i = rashomon_profiles[trt_profile_idx].sigma[sigma_trt_R_set_idx]
sigma_ctl_R_set_idx = R_set_i[ctl_profile_idx]
sigma_ctl_i = rashomon_profiles[ctl_profile_idx].sigma[sigma_ctl_R_set_idx]


trt_pools_0, _ = extract_pools.extract_pools(policies_profiles_masked[trt_profile_idx], sigma_trt_i)
ctl_pools_0, _ = extract_pools.extract_pools(policies_profiles_masked[ctl_profile_idx], sigma_ctl_i)

trt_pools = ids_to_policies_profiles(trt_pools_0, policies_ids_profiles, all_policies, trt_profile_idx)
ctl_pools = ids_to_policies_profiles(ctl_pools_0, policies_ids_profiles, all_policies, ctl_profile_idx)

sigma_int = get_intersection_matrix(ctl_pools, trt_pools, all_policies, trt_idx)

P_qe = RashomonSet(sigma_int.shape)
P_qe = het_partition_solver(
    D_tc, y_tc, policy_means, trt_pools, ctl_pools, col, sigma_int, theta, P_qe, reg)

print(P_qe.size)

4


### Analyze results from Rashomon sets

In [216]:

sim_i = 0
current_results = []
best_loss = np.inf

no_data_profiles = 14

for idx, r_set in enumerate(R_set):
    # print(idx)

    # MSE
    pi_policies_profiles_r = {}
    for k, profile in enumerate(profiles):
        if rashomon_profiles[k].sigma[0] is None:
            pi_policies_profiles_r[k] = { idx: 0 for idx in range(len(policies_profiles_masked[k])) }
            continue
        _, pi_policies_r_k = extract_pools.extract_pools(
            policies_profiles_masked[k],
            rashomon_profiles[k].sigma[r_set[k]]
        )
        pi_policies_profiles_r[k] = pi_policies_r_k

    pi_pools_r, pi_policies_r = extract_pools.aggregate_pools(
        pi_policies_profiles_r, policies_ids_profiles)
    pool_means_r = loss.compute_pool_means(policy_means, pi_pools_r)
    y_r_est = metrics.make_predictions(D, pi_policies_r, pool_means_r)

    r_set_results = metrics.compute_all_metrics(
        y, y_r_est, D, true_best, all_policies, profile_map,
        min_dosage_best_policy, true_best_effect)
    sqrd_err = r_set_results["sqrd_err"]
    iou_r = r_set_results["iou"]
    best_profile_indicator = r_set_results["best_prof"]
    min_dosage_present = r_set_results["min_dos_inc"]
    best_pol_diff = r_set_results["best_pol_diff"]
    this_loss = sqrd_err + reg * len(pi_pools_r)

    this_list = [
        n_per_pol, sim_i, len(pi_pools_r)-14, sqrd_err, iou_r, min_dosage_present, best_pol_diff
        ]
    this_list += best_profile_indicator
    current_results.append(this_list)

    if this_loss < best_loss:
        best_loss = this_loss

    if best_profile_indicator[true_best_profile_idx] == 1:
        found_best_profile = True
        # print("Found", this_loss)


In [217]:
profiles_str = [str(prof) for prof in profiles]
rashomon_cols = ["n_per_pol", "sim_num", "num_pools", "MSE", "IOU", "min_dosage", "best_pol_diff"]
rashomon_cols += profiles_str


rashomon_df = pd.DataFrame(current_results, columns=rashomon_cols)

rashomon_df.head()

Unnamed: 0,n_per_pol,sim_num,num_pools,MSE,IOU,min_dosage,best_pol_diff,"(0, 0, 0, 0)","(0, 0, 0, 1)","(0, 0, 1, 0)",...,"(0, 1, 1, 0)","(0, 1, 1, 1)","(1, 0, 0, 0)","(1, 0, 0, 1)","(1, 0, 1, 0)","(1, 0, 1, 1)","(1, 1, 0, 0)","(1, 1, 0, 1)","(1, 1, 1, 0)","(1, 1, 1, 1)"
0,10,0,13,0.121644,1.0,True,0.03473,0,0,0,...,0,0,0,0,0,0,0,0,0,1
1,10,0,7,0.932975,0.75,True,1.525075,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2,10,0,4,1.238453,0.375,True,2.504818,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,10,0,7,1.077576,0.5,True,2.042294,0,0,0,...,0,0,0,0,0,0,0,0,0,1
4,10,0,9,0.916117,0.5,True,2.463649,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [218]:
print(np.mean(np.sort(rashomon_df["MSE"])[:10]))
print(np.mean(np.abs(rashomon_df["best_pol_diff"][np.argsort(rashomon_df["MSE"])[:10]])))

0.12164385212071123
0.03472959659401731


In [219]:
np.mean(rashomon_df["IOU"][np.argsort(rashomon_df["MSE"])[:10]])

1.0

In [221]:
np.sort(rashomon_df["MSE"])
# np.array(rashomon_df["num_pools"])[33]
# np.sort(rashomon_df["MSE"] + reg * rashomon_df["num_pools"])
# np.sort(rashomon_df["num_pools"])

array([0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164385,
       0.12164385, 0.12164385, 0.12164385, 0.12164385, 0.12164

## Lasso

In [19]:
lasso_reg = 1e-3


lasso = linear_model.Lasso(lasso_reg, fit_intercept=False)
lasso.fit(D_matrix_trt_subset, y_trt_subset)
# lasso.fit(X_trt_subset, y_trt_subset)
alpha_est = lasso.coef_
y_tva = lasso.predict(D_matrix_trt_subset)
# y_tva = lasso.predict(X_trt_subset)

tva_results = metrics.compute_all_metrics(
                    y_trt_subset, y_tva,
    D_trt_subset, true_best, all_policies, profile_map, min_dosage_best_policy, true_best_effect)

In [20]:
tva_results

{'sqrd_err': 0.23522551170666117,
 'iou': 0.16666666666666666,
 'best_prof': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
 'min_dos_inc': False,
 'best_pol_diff': -0.03413597934468626}

In [None]:
alpha_est

In [None]:
D_mat_unique = np.unique(D_matrix_trt_subset, axis=0)

y_est = np.matmul(D_mat_unique, alpha_est)

In [None]:
y_est

In [None]:
np.unique(D_trt_subset)