In [241]:
import torch
torch.set_default_tensor_type(torch.FloatTensor) 
import copy
import os 
import copy
import sys
import os

notebook_dir = os.getcwd()
parent_dir = os.path.dirname(notebook_dir)
# Add the parent directory to the Python path
sys.path.append(parent_dir)

from rct_data_generator import *
from outcome_models import *
from plotting_functions import *
from mcmc_bayes_update import *
from eig_comp_utils import *
from research_exp_utils import *
# from rct_data_generator import generate_host_and_mirror


## Set-up

In [242]:
n_host_sample = 200 
sigma_error = 0.01
d = 5
A = torch.randn((d,d))
A = 1/(torch.det(A)) * A

T_allocation_host = torch.randn(d)
T_allocation_host = 100/torch.norm(T_allocation_host)*T_allocation_host

mu_nc = torch.randn(d)
mu_nc = 1/torch.norm(mu_nc)*mu_nc

mu_c = torch.randn(d)
mu_c = 1/torch.norm(mu_c)*mu_c

mu = torch.concat([mu_nc,mu_c])

In [243]:

X_host_no_T = (torch.randn((n_host_sample,d)) @ A ) 
T_host = torch.bernoulli(torch.sigmoid(X_host_no_T@ T_allocation_host))
X_host_times_T = (T_host.unsqueeze(dim=0).T * X_host_no_T)
X_host = torch.concat([X_host_no_T,X_host_times_T],dim=1)

Y_host = X_host @ mu
Y_host = (1/Y_host.std()) * (Y_host-Y_host.mean()) + sigma_error * torch.randn_like(Y_host)


In [244]:
sigma_error = 1
prior_mean = torch.zeros(2 * d)
beta_0, sigma_0_sq,inv_cov_0 = prior_mean, sigma_error,torch.eye(2*d)
prior_hyperparameters = {'beta_0': beta_0, 'sigma_0_sq': sigma_0_sq,"inv_cov_0":inv_cov_0}
bayesian_regression = BayesianLinearRegression(prior_hyperparameters)
bayesian_regression.set_causal_index(10)

In [245]:
bayesian_regression.fit(X_host,Y_host)

{'posterior_mean': tensor([ 0.0555,  0.0224,  0.4691,  0.5081, -0.2084,  0.6515, -0.0024, -0.2707,
          0.1645, -0.4364]),
 'posterior_cov_matrix': tensor([[ 1.5545e-02,  3.5238e-02, -7.4757e-05, -6.5244e-03, -3.9536e-03,
          -1.3799e-02, -2.2439e-02,  1.4084e-03,  5.5372e-03,  6.6257e-03],
         [ 3.5238e-02,  2.5921e-01,  2.6774e-02, -2.1178e-02,  5.7271e-02,
          -2.3771e-02, -1.6101e-01, -1.5678e-02,  1.4489e-02, -3.3829e-02],
         [-7.4753e-05,  2.6774e-02,  1.0645e-02,  4.4833e-03,  5.8146e-03,
           1.1573e-03, -1.6633e-02, -9.3536e-03, -5.0107e-03, -3.3667e-03],
         [-6.5244e-03, -2.1178e-02,  4.4833e-03,  1.2536e-02, -6.2431e-03,
           5.4575e-03,  1.2895e-02, -5.2503e-03, -1.1768e-02,  4.2567e-03],
         [-3.9536e-03,  5.7271e-02,  5.8146e-03, -6.2432e-03,  4.1220e-02,
           6.3215e-03, -3.3757e-02, -3.1206e-03,  4.6583e-03, -3.4808e-02],
         [-1.3799e-02, -2.3771e-02,  1.1573e-03,  5.4575e-03,  6.3215e-03,
           2.7753e

In [246]:
# bayesian_regression.closed_form_obs_EIG(X_host),bayesian_regression.closed_form_causal_EIG(X_host)

In [247]:
# bayesian_regression.samples_obs_EIG(X_host,100,10),bayesian_regression.samples_causal_EIG(X_host,100,10)

In [248]:
T_zero = torch.zeros_like(T_host)
T_one = 1 - T_zero

X_host_times_T_zero = (T_zero.unsqueeze(dim=0).T * X_host_no_T)
X_host_times_T_one = (T_one.unsqueeze(dim=0).T * X_host_no_T)

X_host_T_zero = torch.concat([X_host_no_T,X_host_times_T_zero],dim=1)
X_host_T_one = torch.concat([X_host_no_T,X_host_times_T_one],dim=1)

Y_host = X_host @ mu
Y_host = (1/Y_host.norm()) * Y_host + sigma_error * torch.randn_like(Y_host)

In [249]:
# bayesian_regression.closed_form_obs_EIG(X_host_T_zero),bayesian_regression.closed_form_causal_EIG(X_host_T_zero)

In [250]:
# bayesian_regression.samples_obs_EIG(X_host_T_zero,100,500),bayesian_regression.samples_causal_EIG(X_host_T_zero,100,500)

In [251]:
# bayesian_regression.closed_form_obs_EIG(X_host_T_one),bayesian_regression.closed_form_causal_EIG(X_host_T_one)

In [252]:
# bayesian_regression.samples_obs_EIG(X_host_T_one,100,150),bayesian_regression.samples_causal_EIG(X_host_T_one,100,150)

## Experiment

In [253]:
n_host_sample = 200 
sigma_error = 0
d = 10 
n_repeats = 40


## BART

In [254]:
from outcome_models import BayesianCausalForest
from tqdm import tqdm

In [255]:
X_host_np = np.array(X_host,dtype=np.float32)
T_host_np = np.array(T_host,dtype=np.int32)
Y_host_np = np.array(Y_host,dtype=np.float32)

In [None]:
prior_hyperparameters = {'sigma_0_sq':1, 'p_categorical_pr':0, 'p_categorical_trt':0 }
bcf = BayesianCausalForest(prior_hyperparameters,predictive_model_parameters={"num_trees_pr":200,"num_trees_trt":100},conditional_model_param={"num_trees_pr":200})
bcf.store_train_data(X=X_host_np,T=T_host_np,Y=Y_host_np)

n_samples_inner_expectation_obs=100
n_samples_outer_expectation_obs=50
n_samples_inner_expectation_caus=400
n_samples_outer_expectation_caus=50


sampling_parameters = {'n_samples_inner_expectation_obs':n_samples_inner_expectation_obs, 'n_samples_outer_expectation_obs':n_samples_outer_expectation_obs, \
                       'n_samples_inner_expectation_caus':n_samples_inner_expectation_caus, 'n_samples_outer_expectation_caus':n_samples_outer_expectation_caus}

bcf.joint_EIG_calc(X_host_np,T_host_np,sampling_parameters=sampling_parameters
                   )

In [None]:
bcf.store_train_data(X=X_host_np,T=T_host_np,Y=Y_host_np)

In [None]:
n_samples_inner_expectation=10
n_samples_outer_expectation=10
X = X_host_np
T = T_host_np

n_samples = n_samples_outer_expectation*(n_samples_inner_expectation+1)


In [None]:
predicitions = bcf.posterior_sample_predictions(X=X, T=T,  n_samples=n_samples)

In [None]:
from eig_comp_utils import predictions_in_EIG_obs_form,compute_EIG_obs_from_samples

In [None]:
# pred_in_form = predictions_in_EIG_obs_form(predicitions,n_outer_expectation=n_samples_outer_expectation,m_inner_expectation=n_samples_inner_expectation)

In [None]:
# compute_EIG_obs_from_samples(pred_in_form,1)

In [None]:
# bcf.samples_obs_EIG(X_host_np,T_host_np,n_samples_inner_expectation=50,n_samples_outer_expectation=50)

In [None]:
bcf.joint_EIG_calc(X_host_np[:50],T_host_np[:50],
                   n_samples_inner_expectation_obs=100,n_samples_outer_expectation_obs=50,
                   n_samples_inner_expectation_caus=100,n_samples_outer_expectation_caus=50)

TypeError: BayesianCausalForest.joint_EIG_calc() got an unexpected keyword argument 'n_samples_inner_expectation_obs'

In [None]:
bcf.samples_obs_EIG(X_host_np,T_host_np,n_samples_inner_expectation=50,n_samples_outer_expectation=50)

KeyboardInterrupt: 

In [None]:
predicitions,tau = bcf.posterior_sample_predictions(X=X, T=T,  n_samples=50,return_tau=True)

In [None]:
((Y_host_np - predicitions.mean(axis=0))**2).mean()

0.9937388924972309

In [None]:
Y = Y_host_np

In [None]:
causal_sample = []
print("Getting conditional samples")
for i in tqdm(range(len(tau))):
    Y_resid = Y - tau[i]
    conditional_predictions = bcf.posterior_conditional_predictions(X=X,T=T,Y_residuals=Y_resid,n_samples=200)
    causal_sample.append((predicitions[i],conditional_predictions))

Getting conditional samples


 36%|███▌      | 18/50 [00:34<01:01,  1.92s/it]


KeyboardInterrupt: 

In [None]:
causal_sample = []
print("Getting conditional samples")
for i in tqdm(range(len(tau))):
    Y_resid = Y - tau[i]
    conditional_predictions = bcf.posterior_conditional_predictions(X=X,T=T,Y_residuals=Y_resid,n_samples=200)
    causal_sample.append((predicitions[i],conditional_predictions))

In [None]:
from eig_comp_utils import compute_EIG_causal_from_samples

In [None]:
compute_EIG_causal_from_samples(pred_in_form,causal_sample,1)

2.013498975721575

In [None]:
preds,tau = bcf.posterior_conditional_predictions(X=X,T=T,Y_residuals=Y_resid,n_samples=200,return_tau=True)

  self.params["tau_trt"] = 0.1 * np.var(y) / self.params["num_trees_trt"]


In [None]:
preds.shape

(100, 200)

In [None]:
import xbcausalforest as xbcf

In [None]:
NUM_TREES_PR = 200 
NUM_TREES_TRT = 80

In [None]:
cf = XBCF(
    #model="Normal",
    parallel=True, 
    num_sweeps=200, 
    burnin=100,
    max_depth=250,
    num_trees_pr=NUM_TREES_PR,
    num_trees_trt=NUM_TREES_TRT,
    num_cutpoints=100,
    Nmin=1,
    #mtry_pr=X1.shape[1], # default 0 seems to be 'all'
    #mtry_trt=X.shape[1], 
    tau_pr = 0.6 * np.var(Y_host_np)/NUM_TREES_PR, #0.6 * np.var(y) / /NUM_TREES_PR,
    tau_trt = 0.1 * np.var(Y_host_np)/NUM_TREES_TRT, #0.1 * np.var(y) / /NUM_TREES_TRT,
    alpha_pr= 0.95, # shrinkage (splitting probability)
    beta_pr= 2, # shrinkage (tree depth)
    alpha_trt= 0.95, # shrinkage for treatment part
    beta_trt= 2,
    p_categorical_pr = 0,
    p_categorical_trt = 0,
    # standardize y and unstandardize for prediction
         )

In [None]:
cf.fit(
    x_t=X_host_np,
    x= X_host_np,
    y=Y_host_np,
    z=T_host_np
)

XBCF(num_sweeps = 200, burnin = 100, max_depth = 250, Nmin = 1, num_cutpoints = 100, no_split_penality = 4.605170185988092, mtry_pr = 10, mtry_trt = 10, p_categorical_pr = 0, p_categorical_trt = 0, num_trees_pr = 200, alpha_pr = 0.95, beta_pr = 2.0, tau_pr = 0.0029744889736175533, kap_pr = 16.0, s_pr = 4.0, pr_scale = False, num_trees_trt = 80, alpha_trt = 0.95, beta_trt = 2.0, tau_trt = 0.0012393704056739808, kap_trt = 16.0, s_trt = 4.0, trt_scale = False, verbose = False, parallel = True, set_random_seed = False, random_seed = 0, sample_weights_flag = True, a_scaling = True, b_scaling = True)

In [None]:
predicitions = cf.predict(X_host_np,X_host_np,return_mean=False,return_muhat=True)

In [None]:
b = cf.b
b_adj = (b/( np.expand_dims(b[:,1]-b[:,0],axis= 1)))

tau_adj = predicitions[0]* (b_adj.T[T_host_np])
preds = (tau_adj + predicitions[1])

In [None]:
Y_adj = (np.expand_dims(Y_host_np,axis=1) - tau_adj).T[0]

In [None]:
cf = XBCF(
    #model="Normal",
    parallel=True, 
    num_sweeps=200, 
    burnin=100,
    max_depth=250,
    num_trees_pr=NUM_TREES_PR,
    num_trees_trt=0,
    num_cutpoints=100,
    Nmin=1,
    #mtry_pr=X1.shape[1], # default 0 seems to be 'all'
    #mtry_trt=X.shape[1], 
    tau_pr = 0.6 * np.var(Y_host_np)/NUM_TREES_PR, #0.6 * np.var(y) / /NUM_TREES_PR,
    tau_trt = 0.1 * np.var(Y_host_np)/NUM_TREES_TRT, #0.1 * np.var(y) / /NUM_TREES_TRT,
    alpha_pr= 0.95, # shrinkage (splitting probability)
    beta_pr= 2, # shrinkage (tree depth)
    alpha_trt= 0.95, # shrinkage for treatment part
    beta_trt= 2,
    p_categorical_pr = 0,
    p_categorical_trt = 0,
    # standardize y and unstandardize for prediction
         )

cf.fit(
    x_t = np.zeros_like(X_host_np),
    x = X_host_np,
    y = Y_adj,
    z = T_host_np
)

  self.params["tau_trt"] = 0.1 * np.var(y) / self.params["num_trees_trt"]


XBCF(num_sweeps = 200, burnin = 100, max_depth = 250, Nmin = 1, num_cutpoints = 100, no_split_penality = 4.605170185988092, mtry_pr = 10, mtry_trt = 10, p_categorical_pr = 0, p_categorical_trt = 0, num_trees_pr = 200, alpha_pr = 0.95, beta_pr = 2.0, tau_pr = 0.0027849969323278634, kap_pr = 16.0, s_pr = 4.0, pr_scale = False, num_trees_trt = 0, alpha_trt = 0.95, beta_trt = 2.0, tau_trt = inf, kap_trt = 16.0, s_trt = 4.0, trt_scale = False, verbose = False, parallel = True, set_random_seed = False, random_seed = 0, sample_weights_flag = True, a_scaling = True, b_scaling = True)

In [None]:
predicitions = cf.predict(np.zeros_like(X_host_np),X_host_np,return_mean=False,return_muhat=True)

In [None]:
predicitions

(array([[ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        ...,
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.]]),
 array([[ 0.062448  ,  0.36889832, -0.02182688, ...,  0.40745114,
          0.08932317, -0.01795297],
        [ 0.54190005,  0.55311663,  0.31037042, ...,  0.55837775,
          0.84239695,  0.46283443],
        [-1.37596441, -1.23473215, -1.13664206, ..., -1.50755527,
         -1.16834067, -1.10944947],
        ...,
        [-0.91072695, -0.66860846, -1.09486624, ..., -0.77695684,
         -1.15556815, -1.04539682],
        [-0.10020598, -0.42414434, -0.16395756, ..., -0.48657292,
         -0.1416557 ,  0.10174316],
        [ 0.14673005,  0.2849502 ,  0.0815876 , ...,  0.22618989,
         -0.09158059,  0.06720227]]))

In [None]:
b = cf.b
b_adj = (b/( np.expand_dims(b[:,1]-b[:,0],axis= 1)))

tau_adj = predicitions[0]* (b_adj.T[T_host_np])
preds = (tau_adj + predicitions[1])

In [None]:
(Y_adj**2).mean()

0.9287824086259364

In [None]:
((Y_adj-preds.mean(axis=1))**2).mean()

0.015277002089668952

In [None]:
tau_adj

array([[ 0.,  0., -0., ...,  0., -0.,  0.],
       [-0., -0., -0., ..., -0.,  0.,  0.],
       [ 0.,  0., -0., ...,  0., -0.,  0.],
       ...,
       [ 0.,  0., -0., ...,  0., -0.,  0.],
       [-0., -0., -0., ..., -0.,  0.,  0.],
       [-0., -0., -0., ..., -0.,  0.,  0.]])

In [None]:
predicitions

(array([[ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        ...,
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.],
        [ 0.,  0.,  0., ...,  0., -0.,  0.]]),
 array([[ 0.062448  ,  0.36889832, -0.02182688, ...,  0.40745114,
          0.08932317, -0.01795297],
        [ 0.54190005,  0.55311663,  0.31037042, ...,  0.55837775,
          0.84239695,  0.46283443],
        [-1.37596441, -1.23473215, -1.13664206, ..., -1.50755527,
         -1.16834067, -1.10944947],
        ...,
        [-0.91072695, -0.66860846, -1.09486624, ..., -0.77695684,
         -1.15556815, -1.04539682],
        [-0.10020598, -0.42414434, -0.16395756, ..., -0.48657292,
         -0.1416557 ,  0.10174316],
        [ 0.14673005,  0.2849502 ,  0.0815876 , ...,  0.22618989,
         -0.09158059,  0.06720227]]))

In [None]:
((preds.mean(axis=1)-Y_host_np)**2).mean()

0.024142782849967778

In [None]:
preds

array([[ 0.062448  ,  0.36889832, -0.02182688, ...,  0.40745114,
         0.08932317, -0.01795297],
       [ 0.54190005,  0.55311663,  0.31037042, ...,  0.55837775,
         0.84239695,  0.46283443],
       [-1.37596441, -1.23473215, -1.13664206, ..., -1.50755527,
        -1.16834067, -1.10944947],
       ...,
       [-0.91072695, -0.66860846, -1.09486624, ..., -0.77695684,
        -1.15556815, -1.04539682],
       [-0.10020598, -0.42414434, -0.16395756, ..., -0.48657292,
        -0.1416557 ,  0.10174316],
       [ 0.14673005,  0.2849502 ,  0.0815876 , ...,  0.22618989,
        -0.09158059,  0.06720227]])

In [None]:
Y_host_np

array([-1.0027681e-02,  3.3784306e-01, -1.2541382e+00,  3.1256959e-01,
       -4.7415957e-01,  4.5077965e-01,  2.4700131e+00, -6.7570603e-01,
       -1.4678315e+00, -5.0529778e-01, -2.2587560e-01, -5.4501891e-02,
       -1.3739359e+00, -1.4328986e-01, -3.6615300e-01, -2.5677934e-02,
        9.6452802e-01,  5.3417987e-01,  1.0382600e+00,  1.7982579e-03,
        1.0459331e+00, -3.3536810e-01,  1.2266842e+00, -1.4169722e+00,
        5.5360150e-02, -1.7924825e+00, -6.7738706e-01, -1.5795844e+00,
       -7.6324731e-02,  7.2360235e-01, -2.9838771e-01,  1.4375595e+00,
       -1.3860397e-01, -7.7572024e-01,  2.6350304e-01,  4.5717143e-02,
       -2.2918837e+00, -4.2746985e-01, -1.2554047e+00, -3.5815999e-01,
        1.5554879e+00,  1.1510255e+00, -1.6451495e+00, -9.9534285e-01,
       -1.4149719e-01,  2.4940895e-01, -2.8302634e-01, -2.0933665e-01,
        5.3453165e-01, -2.0229363e+00,  2.2036110e-01,  3.7880665e-01,
       -5.5487955e-01,  3.8057946e-02, -5.8195662e-01, -1.9363624e-01,
      

In [None]:
zero_one = b_adj
zero_one[:,0] = 0
zero_one[:,1] = 1

In [None]:
zero_one

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.

In [None]:
(preds * (zero_one.T[T_host_np]))

array([[ 0.062448  ,  0.36889832, -0.02182688, ...,  0.40745114,
         0.08932317, -0.01795297],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [-1.37596441, -1.23473215, -1.13664206, ..., -1.50755527,
        -1.16834067, -1.10944947],
       ...,
       [-0.91072695, -0.66860846, -1.09486624, ..., -0.77695684,
        -1.15556815, -1.04539682],
       [-0.        , -0.        , -0.        , ..., -0.        ,
        -0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
        -0.        ,  0.        ]])

In [None]:
T_host_np

array([1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1,
       1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1,
       0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0], dtype=int32)

In [None]:
b_adj.T[T_host_np].shape

(100, 200)

In [None]:
1/(b[:,1]-b[:,0])

array([   1.        ,    1.79263863,   16.85242175,    1.27232554,
          1.95618683,   -0.81307613,   -2.90397183,   -0.88551136,
          0.81156604,    2.34651251,    0.85757365,   -0.45146257,
          1.57298871,   -1.04484782,   -0.84375621,    0.70354786,
          6.5773465 ,   78.16367847,    1.09559407,    1.04183923,
          0.96606927,    2.46862989,   -6.17447733,    0.45069441,
         -8.30438462,    3.30564216,   -1.75249179,    1.36468276,
          1.63963622,   -3.60870995,   -4.35719005,   -0.89769291,
         55.39759803,   -1.01059817,   -1.72822839,   -1.38167704,
          0.65313602,   -0.96055083,   -0.50261003,   -0.87914368,
         -4.37855778,    1.12313459,   -9.39800577,   -0.74272738,
          0.73814445,   -0.81535814,    3.11883491,   -1.31026382,
         13.39771173,   -4.6314851 ,   -0.93088136,    3.0983337 ,
         -6.36670132,   -2.74815534,    0.65435172,   -8.08609789,
         -1.30175958,    1.6412005 ,    0.48467231,   -1.53750