# Two Step BART Modeling

Here we consider models of the form

$$y = X\beta + f(X, X\beta) + \epsilon$$

where $f(X, X\beta)$ is a stochastic tree model that takes both the covariates and the linear index as inputs, $\beta$ is a vector of linear regression coefficients, and $\epsilon \sim \mathcal{N}(0,\sigma^2)$.

The model is fit in two stages:

1. $\beta$ is estimated by fitting the linear regression $y = X\beta + \nu$, giving $\hat{\beta}$
2. $f$ is sampled from the residual model $y - X\hat{\beta} = f(X, X\hat{\beta}) + \epsilon$, where $\epsilon \sim \mathcal{N}(0,\sigma^2)$

To begin to investigate this, we load the necessary libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

from stochtree import BARTModel

## Demo 1: Supervised Learning

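Consider a nonlinear single-index model, in which the outcome mean depends on the covariates only through the scalar index $X\beta$: $y = g(X\beta) + \epsilon$ with $g(u) = 0.1u + \sin(2 - u) / (1 + |u - 2|)$.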

Generate sample data

In [None]:
# RNG: seeded explicitly so that "Restart & Run All" regenerates exactly the
# same simulated dataset (an unseeded generator gives new data on every run)
RANDOM_SEED = 2024
rng = np.random.default_rng(RANDOM_SEED)

# Generate covariates: n observations of p_X features, each drawn uniformly
# between -10 / p_X and 10 / p_X, so the per-feature range shrinks as the
# number of features grows
n = 5300
p_X = 5
X = rng.uniform(-10./p_X, 10./p_X, (n, p_X))


# Define the single index basis
def single_index_basis(X, beta = None, rng = None):
    """Project the covariates onto a single index, ``X @ beta``.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, p)
        Covariate matrix.
    beta : array-like of shape (p,) or (p, 1), optional
        Index coefficients. When omitted, they are drawn at random as
        ``2 * p * Dirichlet(1, ..., 1)``, i.e. nonnegative weights that
        sum to ``2 * p``.
    rng : numpy.random.Generator, optional
        Generator used to draw ``beta`` when it is omitted. Pass a seeded
        generator to make the random index reproducible. When ``None``, a
        fresh unseeded generator is created (the original behavior).

    Returns
    -------
    numpy.ndarray
        The product ``X @ beta`` with any length-one axes removed.
    """
    if beta is None:
        _, p_x = X.shape
        if rng is None:
            # Fall back to an unseeded generator so existing calls behave as before
            rng = np.random.default_rng()
        beta = 2.0*p_x*rng.dirichlet(alpha = np.ones(p_x, dtype=float))
    return np.squeeze(np.matmul(X, beta))


# Define the outcome mean function
def outcome_mean(basis):
    """Evaluate the true conditional mean at the given index value(s).

    The mean is a linear trend, ``0.1 * basis``, plus a sine wave in
    ``2 - basis`` whose amplitude is divided by ``1 + |basis - 2|``, so the
    oscillation is largest at ``basis == 2`` and shrinks away from it.

    Works elementwise on scalars and numpy arrays.
    """
    linear_trend = 0.1 * basis
    oscillation = np.sin(2.0 - basis)
    damping = 1.0 + np.abs(basis - 2.0)
    return linear_trend + oscillation / damping


# Generate outcome
# The two rng draws stay in their original order (noise first, then the
# index coefficients) so a given generator state produces the same data.
epsilon = rng.normal(loc=0, scale=1, size=n)
beta = rng.dirichlet(alpha=np.ones(p_X, dtype=float)) * (2.0 * p_X)

# True index and noiseless mean function
Xb = single_index_basis(X, beta)
f_x = outcome_mean(Xb)

# Scale the unit-variance noise so that sd(f_x) / sd(noise) equals `snr`
snr = 3
sig = f_x.std() / snr
y = sig * epsilon + f_x

Test-train split

In [None]:
# Hold out 300 observations for testing. `random_state` is fixed so that the
# same rows land in the test set on every "Restart & Run All"; without it the
# split (and every downstream metric) changes from run to run.
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=300, random_state=1234)

# Apply the same row split to the covariates, the true index, the observed
# outcome and the noiseless mean function
X_train = X[train_inds, :]
X_test = X[test_inds, :]
Xb_train = Xb[train_inds]
Xb_test = Xb[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]
f_x_train = f_x[train_inds]
f_x_test = f_x[test_inds]

Run single-step BART

In [None]:
# Mean-forest settings for the single-step fit: 500 trees, with the
# `sample_sigma2_leaf` option switched off
mean_params = dict(num_trees=500, sample_sigma2_leaf=False)

# Fit BART directly on the raw covariates (5 GFR iterations, 200 MCMC
# iterations), requesting predictions for the held-out rows as well
single_step_bart_model = BARTModel()
single_step_bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    mean_forest_params=mean_params,
    num_gfr=5,
    num_mcmc=200,
)

In [None]:
# Average the test predictions over axis 1 (presumably the posterior draws —
# confirm against the BARTModel docs) to get one prediction per test row
single_step_test_mean = np.mean(single_step_bart_model.y_hat_test, axis=1)

# Predictions on the x-axis against the noisy outcome and the noiseless mean
fig, ax = plt.subplots()
ax.scatter(single_step_test_mean, y_test, label="Outcome")
ax.scatter(single_step_test_mean, f_x_test, label="Mean function")
ax.set_xlabel("BART prediction")
ax.set_ylabel("Actual")
ax.legend()
# Dashed 45-degree reference line: points on it are predicted exactly
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

Run two-step lm + BART

In [None]:
# Stage 1: regress the outcome on the raw covariates
linear_model = LinearRegression().fit(X_train, y_train)

# Fitted linear component on both splits, with length-one axes dropped
yhat_lm_train = linear_model.predict(X_train).squeeze()
yhat_lm_test = linear_model.predict(X_test).squeeze()

In [None]:
# Stage 2 covariates: the raw features with the fitted linear index appended
# as one extra column
two_step_bart_covariate_train = np.c_[X_train, yhat_lm_train]
two_step_bart_covariate_test = np.c_[X_test, yhat_lm_test]

# The forest is trained on what the linear model left unexplained
two_step_resid_train = y_train - yhat_lm_train

# Mean-forest settings for the two-step fit: 50 trees, depth capped at 15
mean_params = dict(num_trees=50, max_depth=15)

two_step_bart_model = BARTModel()
two_step_bart_model.sample(
    X_train=two_step_bart_covariate_train,
    y_train=two_step_resid_train,
    X_test=two_step_bart_covariate_test,
    mean_forest_params=mean_params,
    num_gfr=10,
    num_mcmc=200,
)

Inspect the MCMC (BART) samples

In [None]:
# Retained test-set predictions from each sampler
single_step_preds_y_mcmc = single_step_bart_model.y_hat_test
two_step_preds_y_mcmc = two_step_bart_model.y_hat_test

# Average over axis 1, keeping a trailing length-one axis on each result
single_step_avg_mcmc = np.squeeze(single_step_preds_y_mcmc).mean(axis=1, keepdims=True)
# The two-step forest models the residual, so the linear fit is added back
two_step_avg_mcmc = yhat_lm_test[:, np.newaxis] + np.squeeze(two_step_preds_y_mcmc).mean(axis=1, keepdims=True)

# Both sets of predictions plotted against the true index
fig, ax = plt.subplots()
ax.scatter(Xb_test, single_step_avg_mcmc, label="Classic BART")
ax.scatter(Xb_test, two_step_avg_mcmc, label="Linear Augmented BART")
ax.set_xlabel("Index variable")
ax.set_ylabel("Model predictions")
ax.legend()
plt.show()

In [None]:
# Predictions from both models against the observed (noisy) test outcome
fig, ax = plt.subplots()
for model_preds, model_label in (
    (single_step_avg_mcmc, "Classic BART"),
    (two_step_avg_mcmc, "Linear Augmented BART"),
):
    ax.scatter(model_preds, y_test, label=model_label)
ax.set_xlabel("Model predictions")
ax.set_ylabel("True outcome")
ax.legend()
# Dashed 45-degree reference line
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

In [None]:
# Predictions from both models against the noiseless test mean function
fig, ax = plt.subplots()
for model_preds, model_label in (
    (single_step_avg_mcmc, "Classic BART"),
    (two_step_avg_mcmc, "Linear Augmented BART"),
):
    ax.scatter(model_preds, f_x_test, label=model_label)
ax.set_xlabel("Model predictions")
ax.set_ylabel("True mean function")
ax.legend()
# Dashed 45-degree reference line
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

Compute the root mean squared difference between $\hat{f}(X)$ and $f(X)$ on the test set

In [None]:
# RMSE of each model's averaged prediction against the true mean function.
# The predictions are squeezed to 1-D first so the subtraction is elementwise
# rather than broadcasting an (n, 1) array against an (n,) array.
single_step_est_sq_err = np.power(f_x_test - np.squeeze(single_step_avg_mcmc), 2)
two_step_est_sq_err = np.power(f_x_test - np.squeeze(two_step_avg_mcmc), 2)
single_step_est_rmse = np.sqrt(np.mean(single_step_est_sq_err))
two_step_est_rmse = np.sqrt(np.mean(two_step_est_sq_err))
print(f"Single step root mean squared estimation error: {single_step_est_rmse}\nTwo step root mean squared estimation error: {two_step_est_rmse}")

Compute the root mean squared difference between $\hat{f}(X)$ and $y$ on the test set

In [None]:
# RMSE of each model's averaged prediction against the observed test outcome.
# The predictions are squeezed to 1-D first so the subtraction is elementwise
# rather than broadcasting an (n, 1) array against an (n,) array.
single_step_pred_sq_err = np.power(y_test - np.squeeze(single_step_avg_mcmc), 2)
two_step_pred_sq_err = np.power(y_test - np.squeeze(two_step_avg_mcmc), 2)
single_step_pred_rmse = np.sqrt(np.mean(single_step_pred_sq_err))
two_step_pred_rmse = np.sqrt(np.mean(two_step_pred_sq_err))
print(f"Single step root mean squared prediction error: {single_step_pred_rmse}\nTwo step root mean squared prediction error: {two_step_pred_rmse}")