# Causal Inference Demo Notebook

Load necessary libraries

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from stochtree import BCFModel
from sklearn.model_selection import train_test_split

Generate sample data

In [None]:
# Seeded generator so the simulated dataset is reproducible
random_seed = 101
rng = np.random.default_rng(random_seed)

# Simulation dimensions: n observations, p_X uniform(0, 1) covariates
n = 1000
p_X = 5
X = rng.uniform(low=0, high=1, size=(n, p_X))

# Propensity score is linear in the first covariate and lies in [0.25, 0.75];
# the binary treatment indicator is drawn from it and stored as float
pi_X = 0.5 * X[:, 0] + 0.25
Z = rng.binomial(n=1, p=pi_X, size=n).astype(float)

# Prognostic function is a multiple of the propensity; the treatment effect
# is linear in the second covariate
# (a nonlinear alternative: tau_X = np.sin(X[:,1]*2*np.pi))
mu_X = 5 * pi_X
tau_X = 2 * X[:, 1]

# Outcome = prognostic term + treatment effect for treated units + N(0, 1) noise
epsilon = rng.normal(loc=0, scale=1, size=n)
y = mu_X + Z * tau_X + epsilon

Train-test split

In [None]:
# Hold out half of the observations for evaluation. The split is seeded with
# random_seed so the train/test partition is reproducible across runs, in
# line with the seeded data generation.
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)

# Covariates, treatment indicator and observed outcome
X_train = X[train_inds,:]
X_test = X[test_inds,:]
Z_train = Z[train_inds]
Z_test = Z[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]

# True propensity, prognostic function and treatment effect, kept so that
# model estimates can be compared against the simulated ground truth
pi_train = pi_X[train_inds]
pi_test = pi_X[test_inds]
mu_train = mu_X[train_inds]
mu_test = mu_X[test_inds]
tau_train = tau_X[train_inds]
tau_test = tau_X[test_inds]

Run BCF

In [None]:
# Construct a BCF model with default settings and sample it on the training
# split. Arguments are passed positionally: training covariates, treatment,
# outcome and (true, simulated) propensity, followed by the test covariates,
# treatment and propensity.
# NOTE(review): positional order must match BCFModel.sample's signature in the
# installed stochtree version — confirm when upgrading.
bcf_model = BCFModel()
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test)

Inspect the MCMC (BART) samples

In [None]:
# Compare observed test outcomes with the average predicted outcome.
# Columns from index num_gfr onward are kept (presumably this drops the
# initial warm-start draws and retains the MCMC draws — confirm against the
# stochtree documentation).
forest_preds_y_mcmc = bcf_model.y_hat_test[:,bcf_model.num_gfr:]
# Average over retained draws; keepdims gives an (n_test, 1) column
forest_pred_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)
# This frame holds outcomes (y), not treatment effects, so it is labelled as
# such; the previous "tau" labels were copied from the treatment-effect plot.
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(y_test,1), forest_pred_avg_mcmc), axis = 1),
    columns=["True outcome", "Average estimated outcome"],
)
sns.scatterplot(data=forest_pred_df_mcmc, x="True outcome", y="Average estimated outcome")
# Dashed 45-degree reference line: points on it are predicted exactly
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
# Compare the simulated treatment effect on the test set with the average
# estimated treatment effect over the draws from index num_gfr onward.
forest_preds_tau_mcmc = bcf_model.tau_hat_test[:, bcf_model.num_gfr:]
tau_draws = np.squeeze(forest_preds_tau_mcmc)
# keepdims leaves an (n_test, 1) column for side-by-side stacking
forest_pred_avg_mcmc = tau_draws.mean(axis=1, keepdims=True)
true_tau_column = np.expand_dims(tau_test, 1)
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((true_tau_column, forest_pred_avg_mcmc), axis=1),
    columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(
    data=forest_pred_df_mcmc,
    x="True tau",
    y="Average estimated tau",
)
# Dashed identity line as a visual reference for perfect recovery
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
# Trace plot of the global error-scale parameter over the draws from index
# num_gfr onward.
# NOTE(review): the plotted attribute is global_var_samples, so the "Sigma"
# column may actually hold variance draws — confirm against stochtree docs.
num_retained = bcf_model.num_samples - bcf_model.num_gfr
sample_index = np.expand_dims(np.arange(num_retained), axis=1)
retained_var_draws = np.expand_dims(
    bcf_model.global_var_samples[bcf_model.num_gfr:], axis=1
)
sigma_df_mcmc = pd.DataFrame(
    np.concatenate((sample_index, retained_var_draws), axis=1),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()

In [None]:
# Trace plot of the b0 and b1 parameter draws from index num_gfr onward.
# Both series share one axes, so each gets a legend label and the y-axis gets
# a neutral title; otherwise the axis is labelled only "Beta_1" and the two
# series cannot be told apart.
b_df_mcmc = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1),
            np.expand_dims(bcf_model.b0_samples[bcf_model.num_gfr:],axis=1),
            np.expand_dims(bcf_model.b1_samples[bcf_model.num_gfr:],axis=1),
        ),
        axis = 1,
    ),
    columns=["Sample", "Beta_0", "Beta_1"],
)
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0", label="Beta_0")
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1", label="Beta_1")
plt.ylabel("Sampled value")
plt.legend()
plt.show()