# Demo of the `StochTree` Prototype Interface

While the functions `bart()` and `bcf()` provide simple and performant 
interfaces for supervised learning / causal inference, `stochtree` also 
offers access to many of the "low-level" data structures that are typically 
implemented in C++.
This low-level interface is not designed for performance or even
simplicity; rather, the intent is to provide a "prototype" interface
to the C++ code that doesn't require modifying any C++.

To illustrate when such a prototype interface might be useful, consider
the classic BART algorithm:

**INPUT**: $y$, $X$, $\tau$, $\nu$, $\lambda$, $\alpha$, $\beta$

**OUTPUT**: $m$ samples of a decision forest with $k$ trees and global variance parameter $\sigma^2$

Initialize $\sigma^2$ via a default or a data-dependent calibration exercise

Initialize "forest 0" with $k$ trees with a single root node, referring to tree $j$'s prediction vector as $f_{0,j}$

Compute residual as $r = y - \sum_{j=1}^k f_{0,j}$

**FOR** $i$ **IN** $\left\{1,\dots,m\right\}$:

- Initialize forest $i$ from forest $i-1$
- **FOR** $j$ **IN** $\left\{1,\dots,k\right\}$:
    - Add predictions for tree $j$ to residual: $r = r + f_{i,j}$
    - Update tree $j$ via Metropolis-Hastings with $r$ and $X$ as data and tree priors depending on ($\tau$, $\sigma^2$, $\alpha$, $\beta$)
    - Sample leaf node parameters for tree $j$ via Gibbs (leaf node prior is $N\left(0,\tau\right)$)
    - Subtract (updated) predictions for tree $j$ from residual: $r = r - f_{i,j}$
- Sample $\sigma^2$ via Gibbs (prior is $IG(\nu/2,\nu\lambda/2)$)
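To make the residual bookkeeping concrete, here is a minimal NumPy sketch of the loop above under a strong simplifying assumption: every tree stays a single root node, so the Metropolis-Hastings step is dropped and each tree contributes one constant to every observation. Only the residual updates and the two Gibbs steps (leaf values under $N(0,\tau)$, $\sigma^2$ under $IG(\nu/2,\nu\lambda/2)$) are shown. The function `root_only_backfitting` and its default arguments are illustrative and not part of `stochtree`.

```python
import numpy as np

def root_only_backfitting(y, k=10, m=200, tau=0.1, nu=3.0, lamb=1.0, seed=0):
    """Toy version of the BART loop in which every tree is a single root node.

    With no splits there is no Metropolis-Hastings tree update; only the
    residual bookkeeping and the two Gibbs steps remain.
    """
    rng = np.random.default_rng(seed)
    n = y.shape[0]
    sigma2 = 1.0                      # default initialization of sigma^2
    leaf = np.zeros(k)                # root-node value of each of the k trees
    r = y - leaf.sum()                # r = y - sum_j f_{0,j}
    sigma2_draws = np.empty(m)
    for i in range(m):
        for j in range(k):
            r = r + leaf[j]           # add tree j's predictions back to the residual
            # Gibbs draw for the root leaf under a N(0, tau) prior
            post_prec = n / sigma2 + 1.0 / tau
            post_mean = (r.sum() / sigma2) / post_prec
            leaf[j] = rng.normal(post_mean, np.sqrt(1.0 / post_prec))
            r = r - leaf[j]           # subtract the updated predictions
        # Gibbs draw for sigma^2 under an IG(nu/2, nu*lamb/2) prior
        shape = (nu + n) / 2.0
        rate = (nu * lamb + np.sum(r ** 2)) / 2.0
        sigma2 = 1.0 / rng.gamma(shape, 1.0 / rate)  # numpy's gamma takes a scale
        sigma2_draws[i] = sigma2
    return leaf, r, sigma2_draws

rng = np.random.default_rng(1)
y = rng.normal(3.0, 1.0, 500)
leaf, r, sigma2_draws = root_only_backfitting(y)
```

Because $r$ is updated incrementally rather than recomputed, each inner step costs $O(n)$ regardless of $k$, and the invariant $r = y - \sum_j f_{i,j}$ holds after every tree update.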

While the algorithm itself is conceptually simple, much of the core 
computation is carried out in low-level languages such as C or C++ 
because repeatedly building and modifying tree data structures is slow 
in interpreted languages. As a result, any change to this algorithm, 
such as supporting heteroskedasticity (@pratola2020heteroscedastic), 
categorical outcomes (@murray2021log), or causal effect estimation 
(@hahn2020bayesian), requires modifying low-level code. 

The prototype interface exposes the core components of the 
loop above at the Python level, making it possible to combine 
C++ computation for steps like "update tree $j$ via Metropolis-Hastings" 
with Python computation for a custom variance model, other user-specified 
additive mean model components, and so on.

## Scenario 1: Supervised Learning

Load necessary libraries

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from stochtree import Dataset, Residual, RNG, ForestSampler, ForestContainer, GlobalVarianceModel, LeafVarianceModel

Generate sample data

In [None]:
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis
n = 1000
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))

# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0], 
        np.where(
            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0], 
            np.where(
                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0], 
                7.5 * W[:,0]
            )
        )
    )

# Generate outcome
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon

# Standardize outcome
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y-y_bar)/y_std
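The nested `np.where` in `outcome_mean` encodes a step function: the slope on `W[:,0]` is $-7.5$, $-2.5$, $2.5$ or $7.5$ depending on which quarter of $[0, 1)$ `X[:,0]` falls into. As a sketch, assuming covariates drawn from $[0, 1)$ as above, the same function can be written with `np.digitize`; the name `outcome_mean_digitize` is ours.

```python
import numpy as np

def outcome_mean_digitize(X, W):
    # Slopes for the intervals [0, .25), [.25, .5), [.5, .75), [.75, 1)
    slopes = np.array([-7.5, -2.5, 2.5, 7.5])
    # np.digitize returns 0..3: the index of the interval containing X[:, 0]
    bins = np.digitize(X[:, 0], [0.25, 0.5, 0.75])
    return slopes[bins] * W[:, 0]

X_demo = np.array([[0.1], [0.25], [0.3], [0.6], [0.9]])
W_demo = np.ones((5, 1))
print(outcome_mean_digitize(X_demo, W_demo))
```

Unlike the nested version, this form makes it easy to see that only `X[:,0]` and `W[:,0]` drive the mean, so the remaining nine covariates are pure noise for the forest to ignore.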

Set some sampling parameters

In [None]:
alpha = 0.9
beta = 1.25
min_samples_leaf = 1
num_trees = 100
cutpoint_grid_size = 100
global_variance_init = 1.
tau_init = 0.5
leaf_prior_scale = np.array([[tau_init]], order='C')
nu = 4.
lamb = 0.5
a_leaf = 2.
b_leaf = 0.5
leaf_regression = True
feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric
var_weights = np.repeat(1/p_X, p_X)
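Assuming `stochtree` follows the standard BART parameterization, `alpha` and `beta` define the tree prior under which a node at depth $d$ is split with probability $\alpha(1+d)^{-\beta}$, and `nu` and `lamb` define the $IG(\nu/2,\nu\lambda/2)$ prior on $\sigma^2$ from the algorithm above. The helper names below are hypothetical; the sketch only shows what these particular settings imply.

```python
import numpy as np

def split_prob(depth, alpha, beta):
    # Prior probability that a node at the given depth is split (root has depth 0)
    return alpha * (1.0 + depth) ** (-beta)

def inv_gamma_mean(nu, lamb):
    # Mean of IG(nu/2, nu*lamb/2); finite only when nu/2 > 1
    shape, scale = nu / 2.0, nu * lamb / 2.0
    return scale / (shape - 1.0) if shape > 1.0 else np.inf

depths = np.arange(4)
print(split_prob(depths, alpha=0.9, beta=1.25))  # split probability at depths 0..3
print(inv_gamma_mean(nu=4.0, lamb=0.5))          # prints 1.0
```

With `nu = 4` and `lamb = 0.5` the prior on $\sigma^2$ is centered at 1, which matches the unit variance of the standardized outcome `resid`.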

Convert data from numpy to `StochTree` representation

In [None]:
# Dataset (covariates and basis)
dataset = Dataset()
dataset.add_covariates(X)
dataset.add_basis(W)

# Residual
residual = Residual(resid)

Initialize tracking and sampling classes

In [None]:
forest_container = ForestContainer(num_trees, W.shape[1], False)
forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)
cpp_rng = RNG(random_seed)
global_var_model = GlobalVarianceModel()
leaf_var_model = LeafVarianceModel()

Prepare to run the sampler

In [None]:
num_warmstart = 10
num_mcmc = 100
num_samples = num_warmstart + num_mcmc
global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples)))
leaf_scale_samples = np.concatenate((np.array([tau_init]), np.repeat(0, num_samples)))
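Each of these arrays has `num_samples + 1` entries: index 0 stores the initial value, and the draw made in iteration `i` is written to index `i + 1`. The sketch below uses made-up stand-in values purely to illustrate that indexing, and how the GFR and MCMC draws are recovered once the initial value is dropped.

```python
import numpy as np

num_warmstart, num_mcmc = 10, 100
num_samples = num_warmstart + num_mcmc
init_value = 1.0

# Slot 0 holds the initial value; the draw from iteration i goes to slot i + 1
draws = np.concatenate((np.array([init_value]), np.zeros(num_samples)))
for i in range(num_samples):
    draws[i + 1] = i  # stand-in for a sampled value

posterior = draws[1:]                  # drop the initial value
gfr_draws = posterior[:num_warmstart]  # grow-from-root (XBART) draws
mcmc_draws = posterior[num_warmstart:] # MCMC (BART) draws
```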

Run the "grow-from-root" (XBART) sampler

In [None]:
for i in range(num_warmstart):
  forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, True, False)
  global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)
  leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)
  leaf_prior_scale[0,0] = leaf_scale_samples[i+1]

Run the MCMC (BART) sampler, initialized at the last XBART sample

In [None]:
for i in range(num_warmstart, num_samples):
  forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, False, False)
  global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)
  leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)
  leaf_prior_scale[0,0] = leaf_scale_samples[i+1]

Extract mean function and error variance posterior samples

In [None]:
# Forest predictions
forest_preds = forest_container.predict(dataset)*y_std + y_bar
forest_preds_gfr = forest_preds[:,:num_warmstart]
forest_preds_mcmc = forest_preds[:,num_warmstart:num_samples]

# Global error standard deviation (index 0 of global_var_samples is the initial value, so drop it)
sigma_samples = np.sqrt(global_var_samples[1:])*y_std
sigma_samples_gfr = sigma_samples[:num_warmstart]
sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]
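The sampler works on the standardized outcome `resid = (y-y_bar)/y_std`, which is why forest predictions are multiplied by `y_std` and shifted by `y_bar`, while the error standard deviation is only multiplied by `y_std` (a shift does not change spread). A small stand-alone check of that reasoning, using made-up predictions on the standardized scale:

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(5.0, 3.0, 1000)
y_bar, y_std = np.mean(y), np.std(y)
z = (y - y_bar) / y_std             # standardized outcome

z_hat = 0.5 * z                     # stand-in predictions on the standardized scale
sigma_z = np.std(z - z_hat)         # residual sd on the standardized scale

# Means undo both the scaling and the shift; standard deviations only the scaling
y_hat = z_hat * y_std + y_bar
sigma_y = sigma_z * y_std
```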

Inspect the GFR (XBART) samples

In [None]:
forest_pred_avg_gfr = forest_preds_gfr.mean(axis = 1, keepdims = True)
forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(y, axis=1), forest_pred_avg_gfr), axis = 1), columns=["True y", "Average predicted y"])
sns.scatterplot(data=forest_pred_df_gfr, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
sigma_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_warmstart),axis=1), np.expand_dims(sigma_samples_gfr,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_gfr, x="Sample", y="Sigma")
plt.show()

Inspect the MCMC (BART) samples

In [None]:
forest_pred_avg_mcmc = forest_preds_mcmc.mean(axis = 1, keepdims = True)
forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y, axis=1), forest_pred_avg_mcmc), axis = 1), columns=["True y", "Average predicted y"])
sns.scatterplot(data=forest_pred_df_mcmc, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_samples - num_warmstart),axis=1), np.expand_dims(sigma_samples_mcmc,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()

## Scenario 2: Causal Inference

Generate sample data

In [None]:
# RNG
random_seed = 101
rng = np.random.default_rng(random_seed)

# Generate covariates, propensity scores, and treatment assignment
n = 1000
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5*X[:,0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X*5
# tau_X = np.sin(X[:,1]*2*np.pi)
tau_X = X[:,1]*2

# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X*Z + epsilon

# Standardize outcome
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y-y_bar)/y_std
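Because the propensity `pi_X` and the prognostic function `mu_X = pi_X*5` both depend on `X[:,0]`, treated units tend to have larger `mu_X`, so a raw difference in means overstates the average treatment effect (a short calculation puts the bias at $5/12 \approx 0.42$). The simulation below draws a larger sample from the same data-generating process to make this visible; the `_big` names are ours.

```python
import numpy as np

# Same data-generating process as the cell above, with a much larger n
rng = np.random.default_rng(2024)
n_big = 200_000
X_big = rng.uniform(0, 1, (n_big, 5))
pi_big = 0.25 + 0.5 * X_big[:, 0]            # propensity depends on X[:, 0]
Z_big = rng.binomial(1, pi_big).astype(float)
mu_big = pi_big * 5                           # prognostic function also depends on X[:, 0]
tau_big = X_big[:, 1] * 2                     # treatment effect depends only on X[:, 1]
y_big = mu_big + tau_big * Z_big + rng.normal(0, 1, n_big)

true_ate = np.mean(tau_big)                   # close to E[2 * X_1] = 1
naive_diff = y_big[Z_big == 1].mean() - y_big[Z_big == 0].mean()
print(true_ate, naive_diff)
```

This confounding is the reason the prognostic forest below is given `pi_X` as an extra covariate.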

Set some sampling parameters

In [None]:
# Prognostic forest parameters
alpha_mu = 0.95
beta_mu = 2.0
min_samples_leaf_mu = 1
num_trees_mu = 200
cutpoint_grid_size_mu = 100
tau_init_mu = 1/200
leaf_prior_scale_mu = np.array([[tau_init_mu]], order='C')
a_leaf_mu = 3.
b_leaf_mu = 1/200
leaf_regression_mu = False
feature_types_mu = np.repeat(0, p_X + 1).astype(int) # 0 = numeric; includes pi_X as an extra covariate
var_weights_mu = np.repeat(1/(p_X + 1), p_X + 1)

# Treatment forest parameters
alpha_tau = 0.25
beta_tau = 3.
min_samples_leaf_tau = 1
num_trees_tau = 50
cutpoint_grid_size_tau = 100
tau_init_tau = 1/50
leaf_prior_scale_tau = np.array([[tau_init_tau]], order='C')
a_leaf_tau = 3.
b_leaf_tau = 1/50
leaf_regression_tau = True
feature_types_tau = np.repeat(0, p_X).astype(int) # 0 = numeric
var_weights_tau = np.repeat(1/p_X, p_X)

# Global parameters
nu = 2.
lamb = 0.5
global_variance_init = 1.

Convert data from numpy to `StochTree` representation

In [None]:
# Prognostic Forest Dataset (covariates)
dataset_mu = Dataset()
dataset_mu.add_covariates(np.c_[X,pi_X])

# Treatment Forest Dataset (covariates and treatment variable)
dataset_tau = Dataset()
dataset_tau.add_covariates(X)
dataset_tau.add_basis(Z)

# Residual
residual = Residual(resid)

Initialize tracking and sampling classes

In [None]:
# Prognostic forest sampling classes
forest_container_mu = ForestContainer(num_trees_mu, 1, True)
forest_sampler_mu = ForestSampler(dataset_mu, feature_types_mu, num_trees_mu, n, alpha_mu, beta_mu, min_samples_leaf_mu)
leaf_var_model_mu = LeafVarianceModel()

# Treatment forest sampling classes
forest_container_tau = ForestContainer(num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False)
forest_sampler_tau = ForestSampler(dataset_tau, feature_types_tau, num_trees_tau, n, alpha_tau, beta_tau, min_samples_leaf_tau)
leaf_var_model_tau = LeafVarianceModel()

# Global classes
cpp_rng = RNG(random_seed)
global_var_model = GlobalVarianceModel()

Prepare to run the sampler

In [None]:
num_warmstart = 10
num_mcmc = 500
num_samples = num_warmstart + num_mcmc
global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples)))
leaf_scale_samples_mu = np.concatenate((np.array([tau_init_mu]), np.repeat(0, num_samples)))
leaf_scale_samples_tau = np.concatenate((np.array([tau_init_tau]), np.repeat(0, num_samples)))
leaf_prior_scale_mu = np.array([[tau_init_mu]])
leaf_prior_scale_tau = np.array([[tau_init_tau]])
b_0_init = -0.5
b_1_init = 0.5
b_0_samples = np.concatenate((np.array([b_0_init]), np.repeat(0, num_samples)))
b_1_samples = np.concatenate((np.array([b_1_init]), np.repeat(0, num_samples)))
tau_basis = (1-Z)*b_0_init + Z*b_1_init
dataset_tau.update_basis(tau_basis)
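The basis `tau_basis` implements the adaptive coding of the treatment used in BCF (@hahn2020bayesian): the treatment forest's raw output is multiplied by `b_0` for control units and by `b_1` for treated units, so for a unit with raw forest output $\tau(x)$ the implied treatment effect is $(b_1 - b_0)\,\tau(x)$. This is also why `treatment_coding_samples = b_1 - b_0` appears when the forest predictions are rescaled. A minimal check with a made-up forest output (the helper `adaptive_basis` is ours):

```python
import numpy as np

def adaptive_basis(Z, b_0, b_1):
    # Basis used by the treatment forest: b_0 for control units, b_1 for treated units
    return (1 - Z) * b_0 + Z * b_1

tau_raw = 0.8                 # hypothetical raw treatment-forest output for one unit
b_0, b_1 = -0.5, 0.5          # the initial values used above
contribution_treated = adaptive_basis(1.0, b_0, b_1) * tau_raw
contribution_control = adaptive_basis(0.0, b_0, b_1) * tau_raw
effect = contribution_treated - contribution_control  # equals (b_1 - b_0) * tau_raw
```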

Run the "grow-from-root" (XBART) sampler

In [None]:
for i in range(num_warmstart):
  # Sample the prognostic forest
  forest_sampler_mu.sample_one_iteration(forest_container_mu, dataset_mu, residual, cpp_rng, feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, global_var_samples[i], 0, True, False)
  leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
  leaf_prior_scale_mu[0,0] = leaf_scale_samples_mu[i+1]
  mu_x = forest_container_mu.predict_raw_single_forest(dataset_mu, i)

  # Sample the treatment effect forest
  forest_sampler_tau.sample_one_iteration(forest_container_tau, dataset_tau, residual, cpp_rng, feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, global_var_samples[i], 1, True, False)
  # leaf_scale_samples_tau[i+1] = leaf_var_model_tau.sample_one_iteration(forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)
  # leaf_prior_scale_tau[0,0] = leaf_scale_samples_tau[i+1]
  tau_x = np.squeeze(forest_container_tau.predict_raw_single_forest(dataset_tau, i))
  s_tt0 = np.sum(tau_x*tau_x*(Z==0))
  s_tt1 = np.sum(tau_x*tau_x*(Z==1))
  partial_resid_mu = resid - np.squeeze(mu_x)
  s_ty0 = np.sum(tau_x*partial_resid_mu*(Z==0))
  s_ty1 = np.sum(tau_x*partial_resid_mu*(Z==1))
  # Conjugate normal draws for the adaptive coding parameters b_0 and b_1
  b_0_samples[i+1] = rng.normal(loc = (s_ty0/(s_tt0 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt0 + 2*global_var_samples[i])))
  b_1_samples[i+1] = rng.normal(loc = (s_ty1/(s_tt1 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt1 + 2*global_var_samples[i])))
  tau_basis = (1-Z)*b_0_samples[i+1] + Z*b_1_samples[i+1]
  dataset_tau.update_basis(tau_basis)
  
  # Sample global variance
  global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)
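The updates for `b_0` and `b_1` are conjugate normal draws. Conditional on both forests, the partial residual `resid - mu_x` for units with `Z == z` is modeled as $b_z\,\tau(x)$ plus $N(0,\sigma^2)$ noise, and the loop's mean $s_{ty}/(s_{tt} + 2\sigma^2)$ and variance $\sigma^2/(s_{tt} + 2\sigma^2)$ correspond to a $N(0, 1/2)$ prior on each $b_z$. The sketch below writes the general update for a prior variance `prior_var` and checks it against those expressions on simulated data; the function name is hypothetical.

```python
import numpy as np

def coding_param_posterior(tau_x, partial_resid, mask, sigma2, prior_var=0.5):
    """Posterior mean and variance of b in partial_resid = b * tau_x + N(0, sigma2),
    using only units where mask is True, under the prior b ~ N(0, prior_var)."""
    s_tt = np.sum(tau_x * tau_x * mask)
    s_ty = np.sum(tau_x * partial_resid * mask)
    post_prec = s_tt / sigma2 + 1.0 / prior_var
    return (s_ty / sigma2) / post_prec, 1.0 / post_prec

# Simulated inputs standing in for tau_x, partial_resid_mu and Z from the loop
rng = np.random.default_rng(7)
tau_x = rng.normal(size=200)
Z = rng.binomial(1, 0.5, size=200)
partial_resid = 0.5 * tau_x + rng.normal(size=200)
sigma2 = 0.9

mean_0, var_0 = coding_param_posterior(tau_x, partial_resid, Z == 0, sigma2)
s_tt0 = np.sum(tau_x * tau_x * (Z == 0))
s_ty0 = np.sum(tau_x * partial_resid * (Z == 0))
b_0_draw = rng.normal(loc=mean_0, scale=np.sqrt(var_0))
```

Writing the update with an explicit `prior_var` makes it easy to try a different prior on the coding parameters without re-deriving the loop's formulas.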

Run the MCMC (BART) sampler, initialized at the last XBART sample

In [None]:
for i in range(num_warmstart, num_samples):
  # Sample the prognostic forest
  forest_sampler_mu.sample_one_iteration(forest_container_mu, dataset_mu, residual, cpp_rng, feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, global_var_samples[i], 0, False, False)
  leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
  leaf_prior_scale_mu[0,0] = leaf_scale_samples_mu[i+1]
  mu_x = forest_container_mu.predict_raw_single_forest(dataset_mu, i)

  # Sample the treatment effect forest
  forest_sampler_tau.sample_one_iteration(forest_container_tau, dataset_tau, residual, cpp_rng, feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, global_var_samples[i], 1, False, False)
  # leaf_scale_samples_tau[i+1] = leaf_var_model_tau.sample_one_iteration(forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)
  # leaf_prior_scale_tau[0,0] = leaf_scale_samples_tau[i+1]
  tau_x = np.squeeze(forest_container_tau.predict_raw_single_forest(dataset_tau, i))
  s_tt0 = np.sum(tau_x*tau_x*(Z==0))
  s_tt1 = np.sum(tau_x*tau_x*(Z==1))
  partial_resid_mu = resid - np.squeeze(mu_x)
  s_ty0 = np.sum(tau_x*partial_resid_mu*(Z==0))
  s_ty1 = np.sum(tau_x*partial_resid_mu*(Z==1))
  # Conjugate normal draws for the adaptive coding parameters b_0 and b_1
  b_0_samples[i+1] = rng.normal(loc = (s_ty0/(s_tt0 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt0 + 2*global_var_samples[i])))
  b_1_samples[i+1] = rng.normal(loc = (s_ty1/(s_tt1 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt1 + 2*global_var_samples[i])))
  tau_basis = (1-Z)*b_0_samples[i+1] + Z*b_1_samples[i+1]
  dataset_tau.update_basis(tau_basis)
  
  # Sample global variance
  global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)

In [None]:
forest_container_tau.predict_raw(dataset_tau)

Extract mean function and error variance posterior samples

In [None]:
# Forest predictions
forest_preds_mu = forest_container_mu.predict(dataset_mu)*y_std + y_bar
forest_preds_mu_gfr = forest_preds_mu[:,:num_warmstart]
forest_preds_mu_mcmc = forest_preds_mu[:,num_warmstart:num_samples]
treatment_coding_samples = (b_1_samples[1:] - b_0_samples[1:])
forest_preds_tau = (forest_container_tau.predict_raw(dataset_tau)*y_std*np.expand_dims(treatment_coding_samples, axis=(0,2)))
forest_preds_tau_gfr = forest_preds_tau[:,:num_warmstart]
forest_preds_tau_mcmc = forest_preds_tau[:,num_warmstart:num_samples]

# Global error standard deviation (index 0 of global_var_samples is the initial value, so drop it)
sigma_samples = np.sqrt(global_var_samples[1:])*y_std
sigma_samples_gfr = sigma_samples[:num_warmstart]
sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]

# Adaptive coding parameters
b_1_samples_gfr = b_1_samples[1:(num_warmstart+1)]*y_std
b_0_samples_gfr = b_0_samples[1:(num_warmstart+1)]*y_std
b_1_samples_mcmc = b_1_samples[(num_warmstart+1):]*y_std
b_0_samples_mcmc = b_0_samples[(num_warmstart+1):]*y_std

Inspect the GFR (XBART) samples

In [None]:
forest_preds_tau_avg_gfr = np.squeeze(forest_preds_tau_gfr).mean(axis = 1, keepdims = True)
forest_pred_tau_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_preds_tau_avg_gfr), axis = 1), columns=["True tau", "Average estimated tau"])
sns.scatterplot(data=forest_pred_tau_df_gfr, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
forest_pred_avg_gfr = np.squeeze(forest_preds_mu_gfr).mean(axis = 1, keepdims = True)
forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_gfr), axis = 1), columns=["True mu", "Average estimated mu"])
sns.scatterplot(data=forest_pred_df_gfr, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
sigma_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_warmstart),axis=1), np.expand_dims(sigma_samples_gfr,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_gfr, x="Sample", y="Sigma")
plt.show()

In [None]:
b_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_warmstart),axis=1), np.expand_dims(b_0_samples_gfr,axis=1), np.expand_dims(b_1_samples_gfr,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
sns.scatterplot(data=b_df_gfr, x="Sample", y="Beta_0")
sns.scatterplot(data=b_df_gfr, x="Sample", y="Beta_1")
plt.show()

Inspect the MCMC (BART) samples

In [None]:
forest_pred_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)
forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_pred_avg_mcmc), axis = 1), columns=["True tau", "Average estimated tau"])
sns.scatterplot(data=forest_pred_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
forest_pred_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)
forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_mcmc), axis = 1), columns=["True mu", "Average estimated mu"])
sns.scatterplot(data=forest_pred_df_mcmc, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

In [None]:
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_samples - num_warmstart),axis=1), np.expand_dims(sigma_samples_mcmc,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()

In [None]:
b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(num_samples - num_warmstart),axis=1), np.expand_dims(b_0_samples_mcmc,axis=1), np.expand_dims(b_1_samples_mcmc,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
plt.show()