In [1]:
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import torch 
from sklearn.model_selection import train_test_split

In [2]:
coal = np.loadtxt(pm.get_data("coal.csv"))

In [3]:
# discretize data
years = int(coal.max() - coal.min())
bins = years // 4
hist, x_edges = np.histogram(coal, bins=bins)
# compute the location of the centers of the discretized data
x_centers = x_edges[:-1] + (x_edges[1] - x_edges[0]) / 2
# xdata needs to be 2D for BART
x_data = x_centers[:, None]
# express data as the rate number of disaster per year
y_data = hist

In [4]:
n_host_sample = 80 
sigma_error = 1 
d = 10 
A = torch.randn((d,d))
A = 1/(torch.det(A)) * A

T_allocation_host = torch.randn(d)
T_allocation_host = 100/torch.norm(T_allocation_host)*T_allocation_host

mu_nc = torch.randn(d)
mu_nc = 1/torch.norm(mu_nc)*mu_nc

mu_c = torch.randn(d)
mu_c = 1/torch.norm(mu_c)*mu_c

mu = torch.concat([mu_nc,mu_c])


X_host_no_T = (torch.randn((n_host_sample,d)) @ A ) 
T_host = torch.bernoulli(torch.sigmoid(X_host_no_T@ T_allocation_host))
X_host_times_T = (T_host.unsqueeze(dim=0).T * X_host_no_T)
X_host = torch.concat([X_host_no_T,X_host_times_T],dim=1)

Y_host = X_host @ mu
Y_host = (1/Y_host.norm()) * Y_host + sigma_error * torch.randn_like(Y_host)

## Causal Random Forests

In [5]:
from xbcausalforest import XBCF

In [6]:
n_host_sample = 80 
sigma_error = 1 
d = 10 
A = torch.randn((d,d))
A = 1/(torch.det(A)) * A

T_allocation_host = torch.randn(d)
T_allocation_host = 100/torch.norm(T_allocation_host)*T_allocation_host

mu_nc = torch.randn(d)
mu_nc = 1/torch.norm(mu_nc)*mu_nc

mu_c = torch.randn(d)
mu_c = 1/torch.norm(mu_c)*mu_c

mu = torch.concat([mu_nc,mu_c])


X_host_no_T = (torch.randn((n_host_sample,d)) @ A ) 
T_host = torch.bernoulli(torch.sigmoid(X_host_no_T@ T_allocation_host))
X_host_times_T = (T_host.unsqueeze(dim=0).T * X_host_no_T)
X_host = torch.concat([X_host_no_T,X_host_times_T],dim=1)

Y_host = X_host @ mu
Y_host = (1/Y_host.norm()) * Y_host + sigma_error * torch.randn_like(Y_host)

Y = np.array(Y_host,dtype=np.float32)
T = np.array(T_host,dtype=np.int32)
X = np.array(X_host_no_T,dtype=np.float32)

In [7]:
NUM_TREES_PR  = 200
NUM_TREES_TRT = 100

cf = XBCF(
    #model="Normal",
    parallel=True, 
    num_sweeps=50, 
    burnin=15,
    max_depth=250,
    num_trees_pr=NUM_TREES_PR,
    num_trees_trt=NUM_TREES_TRT,
    num_cutpoints=100,
    Nmin=1,
    #mtry_pr=X1.shape[1], # default 0 seems to be 'all'
    #mtry_trt=X.shape[1], 
    tau_pr = 0.6 * np.var(Y)/NUM_TREES_PR, #0.6 * np.var(y) / /NUM_TREES_PR,
    tau_trt = 0.1 * np.var(Y)/NUM_TREES_TRT, #0.1 * np.var(y) / /NUM_TREES_TRT,
    alpha_pr= 0.95, # shrinkage (splitting probability)
    beta_pr= 2, # shrinkage (tree depth)
    alpha_trt= 0.95, # shrinkage for treatment part
    beta_trt= 2,
    p_categorical_pr = 0,
    p_categorical_trt = 0,
    # standardize y and unstandardize for prediction
         )

In [8]:
cf = XBCF(
    # #model="Normal",
    # parallel=True, 
    num_sweeps=500, 
    burnin=15,
    max_depth=250,
    # num_trees_pr=NUM_TREES_PR,
    # num_trees_trt=NUM_TREES_TRT,
    # num_cutpoints=100,
    # Nmin=1,
    # #mtry_pr=X1.shape[1], # default 0 seems to be 'all'
    # #mtry_trt=X.shape[1], 
    # tau_pr = 0.6 * np.var(Y)/NUM_TREES_PR, #0.6 * np.var(y) / /NUM_TREES_PR,
    # tau_trt = 0.1 * np.var(Y)/NUM_TREES_TRT, #0.1 * np.var(y) / /NUM_TREES_TRT,
    # alpha_pr= 0.95, # shrinkage (splitting probability)
    # beta_pr= 2, # shrinkage (tree depth)
    # alpha_trt= 0.95, # shrinkage for treatment part
    # beta_trt= 2,
    p_categorical_pr = 0,
    p_categorical_trt = 0,
    # standardize_target=True, # standardize y and unstandardize for prediction
         )

In [9]:
cf.fit(
    x_t=np.zeros_like(X), # Covariates treatment effect
    x=X, # Covariates outcome (including propensity score)
    y=Y,  # Outcome
    z=T, # Treatment group
)

XBCF(num_sweeps = 500, burnin = 15, max_depth = 250, Nmin = 1, num_cutpoints = 100, no_split_penality = 4.605170185988092, mtry_pr = 10, mtry_trt = 10, p_categorical_pr = 0, p_categorical_trt = 0, num_trees_pr = 30, alpha_pr = 0.95, beta_pr = 1.25, tau_pr = 0.017154474258422852, kap_pr = 16.0, s_pr = 4.0, pr_scale = False, num_trees_trt = 10, alpha_trt = 0.25, beta_trt = 3.0, tau_trt = 0.008577237129211426, kap_trt = 16.0, s_trt = 4.0, trt_scale = False, verbose = False, parallel = False, set_random_seed = False, random_seed = 0, sample_weights_flag = True, a_scaling = True, b_scaling = True)

In [10]:
cf.predict(X,return_mean=False).shape

(80, 500)

In [11]:
def setattrs(_self, **kwargs):
    for k,v in kwargs.items():
        setattr(_self, k, v)

In [12]:
setattrs(cf, num_sweeps=50, burnin=15,)

As can be seen from below, normal predict returns causal predictions and has a predictive version

In [13]:
cf.sigma_draws

In [14]:
cf.predict(X,X1=X,return_mean=False,return_muhat=True)

array([[0.26170915, 0.02497418, 0.24699971, ..., 0.03796825, 0.05030634,
        0.04096697],
       [0.26170915, 0.02893849, 0.24699971, ..., 0.03796825, 0.10935631,
        0.04096697],
       [0.35265137, 0.02497418, 0.24699971, ..., 0.03796825, 0.09756663,
        0.04096697],
       ...,
       [0.37250956, 0.02893849, 0.24699971, ..., 0.03796825, 0.09756663,
        0.04096697],
       [0.35265137, 0.02497418, 0.24699971, ..., 0.03796825, 0.03851666,
        0.04096697],
       [0.24185096, 0.02497418, 0.24699971, ..., 0.03796825, 0.05030634,
        0.04096697]])

In [16]:
from xbart import XBART

In [18]:
xbt = XBART(num_trees=100, num_sweeps=80, burnin=15)
xbt.fit(X,T)

XBART(num_trees = 100, num_sweeps = 80, n_min = 1, num_cutpoints = 100, alpha = 0.95, beta = 1.25, tau = 0.01, burnin = 15, mtry = 10, max_depth_num = 250, kap = 16.0, s = 4.0, verbose = False, parallel = False, seed = 0, model_num = 0, no_split_penality = 4.605170185988092, sample_weights_flag = True, num_classes = 1)

In [20]:
T_pred = xbt.predict(X)

In [27]:
X.shape

(80, 10)

array([[0.40121827],
       [0.39763731],
       [0.76150547],
       [0.46239374],
       [0.32782563],
       [0.61289713],
       [0.68540092],
       [0.32389972],
       [0.32491379],
       [0.21437737],
       [0.20183061],
       [0.72825871],
       [0.4630052 ],
       [0.54614235],
       [0.84809466],
       [0.69020786],
       [0.27827221],
       [0.8299529 ],
       [0.71838772],
       [0.79360174],
       [0.44701289],
       [0.43966144],
       [0.43842493],
       [0.62883467],
       [0.88753249],
       [0.62805918],
       [0.92784613],
       [0.76074891],
       [0.78085155],
       [0.5915168 ],
       [0.25756157],
       [0.20812121],
       [1.03351158],
       [0.7047798 ],
       [0.17563739],
       [0.20022266],
       [0.78148528],
       [0.85210205],
       [0.27341944],
       [0.59632999],
       [0.31063355],
       [0.7298172 ],
       [0.43100984],
       [0.59697205],
       [0.55466014],
       [0.62123816],
       [0.34172467],
       [0.472

In [32]:
np.concatenate([X,T_pred.reshape(-1,1)],axis=1).shape

(80, 11)