# Supervised Learning Demo Notebook

Load necessary libraries

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from stochtree import Dataset, Residual, RNG, ForestSampler, ForestContainer, GlobalVarianceModel, LeafVarianceModel

Generate sample data

In [None]:
# Seed a NumPy generator so the simulated data are reproducible
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Problem dimensions: n observations, p_X covariates, p_W basis columns
n, p_X, p_W = 1000, 10, 1

# Draw the covariates first and the basis second (order fixes the stream);
# both are uniform on [0, 1)
X = rng.uniform(low=0, high=1, size=(n, p_X))
W = rng.uniform(low=0, high=1, size=(n, p_W))

# Define the outcome mean function
def outcome_mean(X, W):
    """Return the piecewise-linear outcome mean for each row.

    The first column of ``X`` selects a slope for the first column of ``W``:

    * ``0.0 <= X[:, 0] < 0.25``  ->  ``-7.5 * W[:, 0]``
    * ``0.25 <= X[:, 0] < 0.5``  ->  ``-2.5 * W[:, 0]``
    * ``0.5 <= X[:, 0] < 0.75``  ->  ``2.5 * W[:, 0]``
    * anything else             ->  ``7.5 * W[:, 0]``

    Only column 0 of each input is read; the result is one-dimensional.
    """
    x0 = X[:, 0]
    w0 = W[:, 0]

    # The bands are mutually exclusive; rows matching none of them
    # (including values outside [0, 0.75)) take the default slope.
    bands = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    slope = np.select(bands, [-7.5, -2.5, 2.5], default=7.5)
    return slope * w0

# Generate outcome: mean function plus standard-normal noise
epsilon = rng.normal(loc=0, scale=1, size=n)
y = outcome_mean(X, W) + epsilon

# Standardize outcome; y_bar and y_std are kept so that predictions can be
# mapped back to the original scale later in the notebook
y_bar = y.mean()
y_std = y.std()
resid = (y - y_bar) / y_std

Set some sampling parameters

In [None]:
# Values handed to the ForestSampler constructor (num_trees is also used
# when the ForestContainer is built)
alpha = 0.9
beta = 1.25
min_samples_leaf = 1
num_trees = 100

# Per-iteration arguments of forest_sampler.sample_one_iteration
cutpoint_grid_size = 100
feature_types = np.repeat(0, p_X).astype(int)  # 0 = numeric
var_weights = np.repeat(1.0 / p_X, p_X)  # equal weight on every covariate

# Starting values for the global variance and the leaf scale; the leaf
# scale is also kept as a 1x1 C-ordered array that the sampling loops
# overwrite in place
global_variance_init = 1.0
tau_init = 0.5
leaf_prior_scale = np.array([[tau_init]], order="C")

# Arguments of global_var_model.sample_one_iteration
nu = 4.0
lamb = 0.5

# Arguments of leaf_var_model.sample_one_iteration
a_leaf = 2.0
b_leaf = 0.5

# Not read by any other cell visible in this notebook
leaf_regression = True

Convert data from numpy to `StochTree` representation

In [None]:
# Dataset (covariates and basis): register the covariate matrix X and the
# basis matrix W on a single stochtree Dataset. This object is what the
# forest sampler and the prediction call receive later in the notebook.
dataset = Dataset()
dataset.add_covariates(X)
dataset.add_basis(W)

# Residual: built from the standardized outcome `resid`. The same object
# is passed to every forest and variance sampling call below.
# NOTE(review): presumably the samplers update it in place -- confirm
# against the stochtree documentation.
residual = Residual(resid)

Initialize tracking and sampling classes

In [None]:
# Container that accumulates the sampled forests; sized by the number of
# trees and the number of basis columns.
# NOTE(review): the trailing False is an opaque positional flag here --
# presumably "leaves are not constant" for leaf regression; confirm.
forest_container = ForestContainer(num_trees, W.shape[1], False)
# Sampler configured from the dataset, feature types, number of trees,
# number of observations and the alpha / beta / min_samples_leaf settings
forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)
# stochtree-side random number generator, seeded with the same seed as the
# NumPy generator used to simulate the data
cpp_rng = RNG(random_seed)
# Samplers for the global error variance and the leaf scale; both are
# called once per iteration in the loops below
global_var_model = GlobalVarianceModel()
leaf_var_model = LeafVarianceModel()

Prepare to run the sampler

In [None]:
# Iteration counts: warm-start (grow-from-root) draws followed by MCMC draws
num_warmstart = 10
num_mcmc = 100
num_samples = num_warmstart + num_mcmc

# Parameter traces with num_samples + 1 slots: slot 0 holds the initial
# value and the remaining slots start at zero until the loops fill them
global_var_samples = np.repeat(0.0, num_samples + 1)
global_var_samples[0] = global_variance_init
leaf_scale_samples = np.repeat(0.0, num_samples + 1)
leaf_scale_samples[0] = tau_init

Run the "grow-from-root" (XBART) sampler

In [None]:
for i in range(num_warmstart):
    # Forest update using the most recent global variance; the final
    # positional flag is True for this warm-start phase
    forest_sampler.sample_one_iteration(
        forest_container,
        dataset,
        residual,
        cpp_rng,
        feature_types,
        cutpoint_grid_size,
        leaf_prior_scale,
        var_weights,
        global_var_samples[i],
        1,
        True,
    )

    # Draws from iteration i are stored in slot i + 1 (slot 0 is the
    # initial value)
    global_var_samples[i + 1] = global_var_model.sample_one_iteration(
        residual, cpp_rng, nu, lamb
    )
    leaf_scale_samples[i + 1] = leaf_var_model.sample_one_iteration(
        forest_container, cpp_rng, a_leaf, b_leaf, i
    )

    # Feed the new leaf scale into the next forest update
    leaf_prior_scale[0, 0] = leaf_scale_samples[i + 1]

Run the MCMC (BART) sampler, initialized at the last XBART sample

In [None]:
for i in range(num_warmstart, num_samples):
    # Forest update using the most recent global variance; the final
    # positional flag is False for this MCMC phase
    forest_sampler.sample_one_iteration(
        forest_container,
        dataset,
        residual,
        cpp_rng,
        feature_types,
        cutpoint_grid_size,
        leaf_prior_scale,
        var_weights,
        global_var_samples[i],
        1,
        False,
    )

    # Draws from iteration i are stored in slot i + 1 (slot 0 is the
    # initial value)
    global_var_samples[i + 1] = global_var_model.sample_one_iteration(
        residual, cpp_rng, nu, lamb
    )
    leaf_scale_samples[i + 1] = leaf_var_model.sample_one_iteration(
        forest_container, cpp_rng, a_leaf, b_leaf, i
    )

    # Feed the new leaf scale into the next forest update
    leaf_prior_scale[0, 0] = leaf_scale_samples[i + 1]

Extract mean function and error variance posterior samples

In [None]:
# Forest predictions, mapped back from the standardized scale to the
# original outcome scale. The first num_warmstart columns are treated as
# the GFR draws and the remaining columns as the MCMC draws.
forest_preds = forest_container.predict(dataset)*y_std + y_bar
forest_preds_gfr = forest_preds[:,:num_warmstart]
forest_preds_mcmc = forest_preds[:,num_warmstart:num_samples]

# Global error variance, converted to a standard deviation on the original
# outcome scale. global_var_samples has num_samples + 1 entries: entry 0 is
# the initial value and sampler iteration i wrote its draw to entry i + 1.
# The slices are therefore shifted by one so that the GFR slice holds the
# num_warmstart GFR draws (not the initial value) and the MCMC slice holds
# every MCMC draw, including the final one.
sigma_samples = np.sqrt(global_var_samples)*y_std
sigma_samples_gfr = sigma_samples[1:num_warmstart + 1]
sigma_samples_mcmc = sigma_samples[num_warmstart + 1:num_samples + 1]

Inspect the GFR (XBART) samples

In [None]:
# Average prediction per observation across the GFR draws (kept as a column)
forest_pred_avg_gfr = np.mean(forest_preds_gfr, axis=1, keepdims=True)

# Two-column frame: the outcome y next to its average prediction
forest_pred_df_gfr = pd.DataFrame(
    np.concatenate((y[:, np.newaxis], forest_pred_avg_gfr), axis=1),
    columns=["True y", "Average predicted y"],
)

# Scatter plot with a dashed line of slope 1 through the origin for reference
sns.scatterplot(data=forest_pred_df_gfr, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

In [None]:
# Trace of the GFR sigma draws: sample index in one column, sigma in the other
sigma_df_gfr = pd.DataFrame(
    np.concatenate(
        (
            np.arange(num_warmstart)[:, np.newaxis],
            sigma_samples_gfr[:, np.newaxis],
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_gfr, x="Sample", y="Sigma")
plt.show()

Inspect the MCMC (BART) samples

In [None]:
# Average prediction per observation across the MCMC draws (kept as a column)
forest_pred_avg_mcmc = np.mean(forest_preds_mcmc, axis=1, keepdims=True)

# Two-column frame: the outcome y next to its average prediction
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((y[:, np.newaxis], forest_pred_avg_mcmc), axis=1),
    columns=["True y", "Average predicted y"],
)

# Scatter plot with a dashed line of slope 1 through the origin for reference
sns.scatterplot(data=forest_pred_df_mcmc, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

In [None]:
# Trace of the MCMC sigma draws: sample index in one column, sigma in the other
sigma_df_mcmc = pd.DataFrame(
    np.concatenate(
        (
            np.arange(num_samples - num_warmstart)[:, np.newaxis],
            sigma_samples_mcmc[:, np.newaxis],
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()