In [18]:
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import torch 
from sklearn.model_selection import train_test_split

In [19]:
# Fetch the "coal.csv" example file through PyMC, then parse it into an array.
coal_csv = pm.get_data("coal.csv")
coal = np.loadtxt(coal_csv)

In [20]:
# Discretize the data: use one histogram bin for every four (whole) years
# spanned by the observations.
years = int(np.ptp(coal))
bins = years // 4
hist, x_edges = np.histogram(coal, bins)

# Bin centers: left edge of every bin shifted by half of the (uniform) bin width.
half_bin_width = (x_edges[1] - x_edges[0]) / 2
x_centers = x_edges[:-1] + half_bin_width

# BART expects a 2D design matrix, so turn the centers into a single column.
x_data = x_centers.reshape(-1, 1)

# Response: number of observations falling in each bin. Each bin covers about
# four years, so these are per-bin counts rather than per-year rates.
y_data = hist

In [21]:
# Simulation settings.
# NOTE(review): this whole simulation is repeated in a later cell, which
# overwrites every value produced here.
n_host_sample = 80
sigma_error = 1
d = 10


def _rescale_to_norm(vector, target_norm):
    """Return `vector` multiplied by `target_norm / ||vector||`."""
    return (target_norm / torch.norm(vector)) * vector


# Random d x d mixing matrix, multiplied by the reciprocal of its determinant.
A = torch.randn((d, d))
A = (1 / torch.det(A)) * A

# Coefficients of the treatment-assignment model, rescaled to norm 100.
T_allocation_host = _rescale_to_norm(torch.randn(d), 100)

# Unit-norm coefficient vectors for the plain covariates (mu_nc) and for the
# treatment-interacted covariates (mu_c), stacked into one vector.
mu_nc = _rescale_to_norm(torch.randn(d), 1)
mu_c = _rescale_to_norm(torch.randn(d), 1)
mu = torch.cat([mu_nc, mu_c])

# Covariates, Bernoulli treatment draws, and the design matrix [X, T * X].
X_host_no_T = torch.randn((n_host_sample, d)) @ A
assignment_probability = torch.sigmoid(X_host_no_T @ T_allocation_host)
T_host = torch.bernoulli(assignment_probability)
X_host_times_T = T_host.unsqueeze(1) * X_host_no_T
X_host = torch.cat([X_host_no_T, X_host_times_T], dim=1)

# Outcome: linear signal rescaled to unit norm, plus Gaussian noise.
Y_host = X_host @ mu
Y_host = (1 / Y_host.norm()) * Y_host + sigma_error * torch.randn_like(Y_host)

## Causal Random Forests

In [22]:
from xbcausalforest import XBCF

In [23]:
# Seed torch's global RNG so that re-running the notebook regenerates exactly
# the same synthetic data set (every draw below comes from torch).
SEED = 0
torch.manual_seed(SEED)

n_host_sample = 80  # number of simulated units
sigma_error = 1     # scale of the additive Gaussian noise on the outcome
d = 10              # number of covariates

# Random d x d mixing matrix, multiplied by the reciprocal of its determinant.
A = torch.randn((d, d))
A = 1/(torch.det(A)) * A

# Coefficients of the treatment-assignment model, rescaled to norm 100.
T_allocation_host = torch.randn(d)
T_allocation_host = 100/torch.norm(T_allocation_host)*T_allocation_host

# Unit-norm coefficients for the plain covariates ...
mu_nc = torch.randn(d)
mu_nc = 1/torch.norm(mu_nc)*mu_nc

# ... and for the treatment-interacted covariates.
mu_c = torch.randn(d)
mu_c = 1/torch.norm(mu_c)*mu_c

mu = torch.concat([mu_nc, mu_c])

# Covariates, Bernoulli treatment draws, and the design matrix [X, T * X].
X_host_no_T = torch.randn((n_host_sample, d)) @ A
T_host = torch.bernoulli(torch.sigmoid(X_host_no_T @ T_allocation_host))
X_host_times_T = T_host.unsqueeze(dim=0).T * X_host_no_T
X_host = torch.concat([X_host_no_T, X_host_times_T], dim=1)

# Outcome: linear signal rescaled to unit norm, plus Gaussian noise.
Y_host = X_host @ mu
Y_host = (1/Y_host.norm()) * Y_host + sigma_error * torch.randn_like(Y_host)

# NumPy copies handed to the forest: outcome, treatment indicator and the
# covariates without the treatment interactions.
Y = np.array(Y_host, dtype=np.float32)
T = np.array(T_host, dtype=np.int32)
X = np.array(X_host_no_T, dtype=np.float32)

In [24]:
# Number of trees in the "pr" and "trt" forests.
NUM_TREES_PR, NUM_TREES_TRT = 200, 100

# Both tau settings are a fraction of the outcome variance per tree.
outcome_variance = np.var(Y)

# NOTE(review): `cf` is rebuilt with different settings in the next cell.
xbcf_settings = dict(
    parallel=True,
    num_sweeps=50,
    burnin=15,
    max_depth=250,
    num_trees_pr=NUM_TREES_PR,
    num_trees_trt=NUM_TREES_TRT,
    num_cutpoints=100,
    Nmin=1,
    tau_pr=0.6 * outcome_variance / NUM_TREES_PR,
    tau_trt=0.1 * outcome_variance / NUM_TREES_TRT,
    alpha_pr=0.95,   # shrinkage (splitting probability)
    beta_pr=2,       # shrinkage (tree depth)
    alpha_trt=0.95,  # shrinkage for treatment part
    beta_trt=2,
    p_categorical_pr=0,
    p_categorical_trt=0,
)
cf = XBCF(**xbcf_settings)

In [37]:
# Rebuild the forest, overriding only the sweep/depth settings and the
# categorical-column counts; every other option keeps the library default.
sweep_settings = {"num_sweeps": 500, "burnin": 15, "max_depth": 250}
categorical_settings = {"p_categorical_pr": 0, "p_categorical_trt": 0}
cf = XBCF(**sweep_settings, **categorical_settings)

In [38]:
# NOTE(review): the treatment-effect covariates are an all-zero array here,
# whereas BayesianCausalForest further down passes the training covariates as
# x_t -- confirm the zeros are intentional.
treatment_effect_covariates = np.zeros_like(X)

# x_t: covariates for the treatment-effect part, x: covariates for the outcome
# part, y: outcome, z: treatment group.
cf.fit(x_t=treatment_effect_covariates, x=X, y=Y, z=T)

XBCF(num_sweeps = 500, burnin = 15, max_depth = 250, Nmin = 1, num_cutpoints = 100, no_split_penality = 4.605170185988092, mtry_pr = 10, mtry_trt = 10, p_categorical_pr = 0, p_categorical_trt = 0, num_trees_pr = 30, alpha_pr = 0.95, beta_pr = 1.25, tau_pr = 0.02693632364273071, kap_pr = 16.0, s_pr = 4.0, pr_scale = False, num_trees_trt = 10, alpha_trt = 0.25, beta_trt = 3.0, tau_trt = 0.013468161821365357, kap_trt = 16.0, s_trt = 4.0, trt_scale = False, verbose = False, parallel = False, set_random_seed = False, random_seed = 0, sample_weights_flag = True, a_scaling = True, b_scaling = True)

In [39]:
# Shape of the per-sweep prediction matrix for the training covariates.
per_sweep_predictions = cf.predict(X, return_mean=False)
per_sweep_predictions.shape

(80, 500)

In [40]:
def setattrs(_self, **kwargs):
    """Set every keyword argument as an attribute on `_self`.

    Each keyword name becomes the attribute name and its value the attribute
    value; attributes are assigned in the order the keywords were given.
    Returns None.
    """
    for attribute_name in kwargs:
        setattr(_self, attribute_name, kwargs[attribute_name])

In [41]:
# Overwrite the sweep settings as plain attributes on the fitted forest.
# NOTE(review): confirm that XBCF actually reads `num_sweeps`/`burnin` from
# instance attributes; if it keeps them elsewhere this has no effect.
sweep_overrides = {"num_sweeps": 50, "burnin": 15}
setattrs(cf, **sweep_overrides)

As can be seen below, the standard `predict` method returns causal predictions; it also has a predictive version.

In [42]:
# Inspect the `sigma_draws` attribute of the fitted forest
# (presumably the sampled noise-scale values -- TODO confirm against the XBCF docs).
cf.sigma_draws

In [43]:
# Per-sweep predictions for the training covariates: with return_mean=False the
# result has one column per sweep (cf. the (80, 500) shape shown earlier).
cf.predict(X,return_mean=False)

array([[-0.09439312, -0.36632751,  0.09668792, ...,  0.38641616,
        -0.11607723, -0.04188406],
       [-0.09439312, -0.47813647,  0.09668792, ...,  0.38641616,
        -0.11607723, -0.06905785],
       [-0.09439312, -0.47813647,  0.09668792, ...,  0.38641616,
        -0.11607723,  0.11227725],
       ...,
       [-0.09439312, -0.36632751,  0.09668792, ...,  0.38641616,
        -0.11607723, -0.22321917],
       [ 0.01981385, -0.47813647,  0.09668792, ...,  0.38641616,
        -0.11607723, -0.22321917],
       [ 0.01981385, -0.47813647,  0.09668792, ...,  0.38641616,
        -0.11607723, -0.22321917]])

In [52]:
from eig_comp_utils import predictions_in_EIG_obs_form,compute_EIG_obs_from_samples

ModuleNotFoundError: No module named 'eig_comp_utils'

In [47]:
def setattrs(self, **kwargs):
    """Assign each keyword argument to `self` as an attribute of the same name.

    Re-defines the helper of the same name from an earlier cell.
    Returns None.
    """
    for name_and_value in kwargs.items():
        setattr(self, *name_and_value)

class BayesianCausalForest():
    """Wrapper around an ``XBCF`` model that stores the training data and
    exposes per-sweep predictions and an observational-EIG estimate.

    Typical use: construct with the prior hyperparameters, attach the data
    with :meth:`store_train_data`, then call
    :meth:`posterior_sample_predictions` or :meth:`samples_obs_EIG`.
    """

    def __init__(self, prior_hyperparameters):
        """Read the hyperparameters and build the wrapped ``XBCF`` model.

        Parameters
        ----------
        prior_hyperparameters : mapping
            Must provide the keys ``'sigma_0_sq'``, ``'p_categorical_pr'``
            and ``'p_categorical_trt'``. The two ``p_categorical_*`` values
            are forwarded to ``XBCF``; ``sigma_0_sq`` is kept for
            :meth:`samples_obs_EIG`.

        Raises
        ------
        KeyError
            If one of the required keys is missing.
        """
        self.sigma_0_sq = prior_hyperparameters['sigma_0_sq']
        self.p_categorical_pr = prior_hyperparameters['p_categorical_pr']
        self.p_categorical_trt = prior_hyperparameters['p_categorical_trt']
        self.model = XBCF(
            p_categorical_pr=self.p_categorical_pr,
            p_categorical_trt=self.p_categorical_trt,
        )
        # Training data is attached later through store_train_data().
        self.X_train = None
        self.Y_train = None
        self.T_train = None

    def set_model_atrs(self, **kwargs):
        """Set each keyword argument as an attribute on the wrapped model."""
        for k, v in kwargs.items():
            setattr(self.model, k, v)

    def store_train_data(self, X, Y, T):
        """Remember covariates `X`, outcomes `Y` and treatments `T` for fitting."""
        self.X_train = X
        self.Y_train = Y
        self.T_train = T

    def posterior_sample_predictions(self, X, n_samples):
        """Refit the model on the stored data and return per-sweep predictions for `X`.

        ``num_sweeps`` is set to `n_samples` on the wrapped model before
        fitting, and the stored covariates are used both as ``x_t`` and ``x``.

        NOTE(review): confirm that XBCF picks up ``num_sweeps`` from an
        instance attribute set this way. Also, the earlier notebook output
        shows ``predict(..., return_mean=False)`` returning one column per
        sweep (shape (80, 500) with num_sweeps=500, burnin=15), so the
        returned draws presumably include the burn-in sweeps -- confirm.

        Raises
        ------
        AttributeError
            If no training data has been stored yet.
        """
        missing = [
            name
            for name in ("X_train", "Y_train", "T_train")
            if getattr(self, name, None) is None
        ]
        if missing:
            # Same exception type as the bare attribute lookup raised before,
            # but with a message that says what to do.
            raise AttributeError(
                "no training data stored (missing: " + ", ".join(missing)
                + "); call store_train_data(X, Y, T) first"
            )

        self.set_model_atrs(num_sweeps=n_samples)
        self.model.fit(
            x_t=self.X_train,  # Covariates treatment effect
            x=self.X_train,    # Covariates outcome
            y=self.Y_train,    # Outcome
            z=self.T_train,    # Treatment group
        )
        return self.model.predict(X, return_mean=False)

    def samples_obs_EIG(self, X, n_samples_outer_expectation, n_samples_inner_expectation):
        """Estimate the observational EIG at `X` from per-sweep predictions.

        Requests ``n_outer * (n_inner + 1)`` sweeps, reshapes them with
        ``predictions_in_EIG_obs_form`` and passes the result, together with
        ``sqrt(sigma_0_sq)``, to ``compute_EIG_obs_from_samples``.

        Both helpers are expected to come from the ``eig_comp_utils`` module
        imported in an earlier cell; that import must succeed for this method
        to work.
        """
        n_samples = n_samples_outer_expectation * (n_samples_inner_expectation + 1)
        predictions = self.posterior_sample_predictions(X=X, n_samples=n_samples)
        predictions_in_form = predictions_in_EIG_obs_form(
            predictions, n_samples_outer_expectation, n_samples_inner_expectation
        )
        return compute_EIG_obs_from_samples(predictions_in_form, self.sigma_0_sq**(1/2))
            


In [48]:
# Hyperparameters read by BayesianCausalForest.__init__.
prior_hyperparameters = dict(
    sigma_0_sq=1,
    p_categorical_pr=0,
    p_categorical_trt=0,
)
bcf = BayesianCausalForest(prior_hyperparameters)

In [49]:
# Attach the simulated covariates, outcomes and treatments to the wrapper.
bcf.store_train_data(X, Y, T)

In [50]:
# Refit on the stored data and request 100 sweeps of predictions at X.
bcf.posterior_sample_predictions(X=X, n_samples=100)

array([[ 0.13422537, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.06705811],
       [ 0.13422537, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.09174895],
       [ 0.13422537, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.09174895],
       ...,
       [ 0.06124566, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.18068693],
       [ 0.06124566, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.18068693],
       [ 0.13422537, -0.14250628, -0.1081355 , ...,  0.07980274,
         0.08245827, -0.18068693]])