From fd9cc63a5358bc0786113e9ed226273e99f0a0f9 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Tue, 1 Oct 2024 08:45:23 -0400
Subject: [PATCH] Added functionality to update partial residual from R (for
 example, when sampling an additive linear model term in addition to a forest)

---
 R/cpp11.R                           |   8 +
 R/data.R                            |  34 ++++
 include/stochtree/data.h            |   3 +
 man/Outcome.Rd                      |  42 +++++
 man/bart.Rd                         |   2 +-
 man/bcf.Rd                          |   2 +-
 src/R_data.cpp                      |  26 ++++
 src/cpp11.cpp                       |  18 +++
 src/data.cpp                        |  24 +++
 tools/debug/additive_lm.R           | 173 +++++++++++++++++++++
 vignettes/CustomSamplingRoutine.Rmd | 233 +++++++++++++++++++++++++++-
 11 files changed, 562 insertions(+), 3 deletions(-)
 create mode 100644 tools/debug/additive_lm.R

diff --git a/R/cpp11.R b/R/cpp11.R
index c718ca1c..dd0b5986 100644
--- a/R/cpp11.R
+++ b/R/cpp11.R
@@ -44,6 +44,14 @@ create_column_vector_cpp <- function(outcome) {
     .Call(`_stochtree_create_column_vector_cpp`, outcome)
 }
 
+add_to_column_vector_cpp <- function(outcome, update_vector) {
+    invisible(.Call(`_stochtree_add_to_column_vector_cpp`, outcome, update_vector))
+}
+
+subtract_from_column_vector_cpp <- function(outcome, update_vector) {
+    invisible(.Call(`_stochtree_subtract_from_column_vector_cpp`, outcome, update_vector))
+}
+
 get_residual_cpp <- function(vector_ptr) {
     .Call(`_stochtree_get_residual_cpp`, vector_ptr)
 }
diff --git a/R/data.R b/R/data.R
index 0d640626..d7c7206d 100644
--- a/R/data.R
+++ b/R/data.R
@@ -106,6 +106,40 @@ Outcome <- R6::R6Class(
     #' @return R vector containing (copy of) the values in `Outcome` object
     get_data = function() {
       return(get_residual_cpp(self$data_ptr))
+    },
+
+    #' @description
+    #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector`
+    #' @param update_vector Vector to be added to outcome
+    #' @return NULL
+    add_vector = function(update_vector) {
+      if (!is.numeric(update_vector)) {
+        stop("update_vector must be a numeric vector or 2d matrix")
+      } else {
+        dim_vec <- dim(update_vector)
+        if (!is.null(dim_vec)) {
+          if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d")
+          update_vector <- as.numeric(update_vector)
+        }
+      }
+      add_to_column_vector_cpp(self$data_ptr, update_vector)
+    },
+
+    #' @description
+    #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector`
+    #' @param update_vector Vector to be subtracted from outcome
+    #' @return NULL
+    subtract_vector = function(update_vector) {
+      if (!is.numeric(update_vector)) {
+        stop("update_vector must be a numeric vector or 2d matrix")
+      } else {
+        dim_vec <- dim(update_vector)
+        if (!is.null(dim_vec)) {
+          if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d")
+          update_vector <- as.numeric(update_vector)
+        }
+      }
+      subtract_from_column_vector_cpp(self$data_ptr, update_vector)
+    }
   )
 )
diff --git a/include/stochtree/data.h b/include/stochtree/data.h
index 1d16e4fd..b2f76578 100644
--- a/include/stochtree/data.h
+++ b/include/stochtree/data.h
@@ -35,10 +35,13 @@ class ColumnVector {
   double GetElement(data_size_t row_num) {return data_(row_num);}
   void SetElement(data_size_t row_num, double value) {data_(row_num) = value;}
   void LoadData(double* data_ptr, data_size_t num_row);
+  void AddToData(double* data_ptr, data_size_t num_row);
+  void SubtractFromData(double* data_ptr, data_size_t num_row);
   inline data_size_t NumRows() {return data_.size();}
   inline Eigen::VectorXd& GetData() {return data_;}
  private:
   Eigen::VectorXd data_;
+  void UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op);
 };
 
 class ForestDataset {
diff --git a/man/Outcome.Rd b/man/Outcome.Rd
index 5edfc73c..37a5d922 100644
--- a/man/Outcome.Rd
+++ b/man/Outcome.Rd
@@ -23,6 +23,8 @@ of the outcome minus the predictions of every other model term
 \itemize{
 \item \href{#method-Outcome-new}{\code{Outcome$new()}}
 \item \href{#method-Outcome-get_data}{\code{Outcome$get_data()}}
+\item \href{#method-Outcome-add_vector}{\code{Outcome$add_vector()}}
+\item \href{#method-Outcome-subtract_vector}{\code{Outcome$subtract_vector()}}
 }
 }
 \if{html}{\out{<hr>}}
@@ -58,4 +60,44 @@ Extract raw data in R from the underlying C++ object
 R vector containing (copy of) the values in \code{Outcome} object
 }
 }
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Outcome-add_vector"></a>}}
+\if{latex}{\out{\hypertarget{method-Outcome-add_vector}{}}}
+\subsection{Method \code{add_vector()}}{
+Update the current state of the outcome (i.e. partial residual) data by adding the values of \code{update_vector}
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Outcome$add_vector(update_vector)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{update_vector}}{Vector to be added to outcome}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Outcome-subtract_vector"></a>}}
+\if{latex}{\out{\hypertarget{method-Outcome-subtract_vector}{}}}
+\subsection{Method \code{subtract_vector()}}{
+Update the current state of the outcome (i.e. partial residual) data by subtracting the values of \code{update_vector}
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Outcome$subtract_vector(update_vector)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{update_vector}}{Vector to be subtracted from outcome}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
 }
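For reference, a minimal round-trip sketch of how the two new `Outcome` methods compose with the existing `createOutcome()` and `get_data()` functions (the values here are hypothetical): subtracting a term's predictions forms a partial residual, and adding the same vector back restores the original outcome.

```r
library(stochtree)

# Hypothetical outcome and model-term predictions
y <- rnorm(10)
term_preds <- rep(0.5, 10)

# Wrap the outcome in a C++-backed Outcome object
outcome <- createOutcome(y)

# Remove a term's fit to form a partial residual, then restore it
outcome$subtract_vector(term_preds)
outcome$add_vector(term_preds)

# The round trip leaves the outcome unchanged
all.equal(outcome$get_data(), y)
```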
diff --git a/man/bart.Rd b/man/bart.Rd
index eb4dfe59..74e8858c 100644
--- a/man/bart.Rd
+++ b/man/bart.Rd
@@ -113,7 +113,7 @@ that were not in the training set.}
 
 \item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.}
 
-\item{sample_sigma}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(nu, nu*lambda)}. Default: T.}
+\item{sample_sigma}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: T.}
 
 \item{sample_tau}{Whether or not to update the \code{tau} leaf scale variance parameter based on \code{IG(a_leaf, b_leaf)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: T.}
diff --git a/man/bcf.Rd b/man/bcf.Rd
index af2fbf33..0108b46b 100644
--- a/man/bcf.Rd
+++ b/man/bcf.Rd
@@ -153,7 +153,7 @@ that were not in the training set.}
 
 \item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.}
 
-\item{sample_sigma_global}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(nu, nu*lambda)}. Default: T.}
+\item{sample_sigma_global}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: T.}
 
 \item{sample_sigma_leaf_mu}{Whether or not to update the \code{sigma_leaf_mu} leaf scale variance parameter in the prognostic forest based on \code{IG(a_leaf_mu, b_leaf_mu)}. Default: T.}
diff --git a/src/R_data.cpp b/src/R_data.cpp
index 74e95ee9..53b5c748 100644
--- a/src/R_data.cpp
+++ b/src/R_data.cpp
@@ -110,6 +110,32 @@ cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11:
     return cpp11::external_pointer<StochTree::ColumnVector>(vector_ptr_.release());
 }
 
+[[cpp11::register]]
+void add_to_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector) {
+  // Unpack pointers to data and dimensions
+  StochTree::data_size_t n = update_vector.size();
+  double* update_data_ptr = REAL(PROTECT(update_vector));
+  
+  // Add to the outcome data using the C++ API
+  outcome->AddToData(update_data_ptr, n);
+  
+  // Unprotect pointers to R data
+  UNPROTECT(1);
+}
+
+[[cpp11::register]]
+void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector) {
+  // Unpack pointers to data and dimensions
+  StochTree::data_size_t n = update_vector.size();
+  double* update_data_ptr = REAL(PROTECT(update_vector));
+  
+  // Subtract from the outcome data using the C++ API
+  outcome->SubtractFromData(update_data_ptr, n);
+  
+  // Unprotect pointers to R data
+  UNPROTECT(1);
+}
+
 [[cpp11::register]]
 cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr) {
     // Initialize output vector
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index d20046bb..db8d8a94 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -87,6 +87,22 @@ extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) {
   END_CPP11
 }
 // R_data.cpp
+void add_to_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector);
+extern "C" SEXP _stochtree_add_to_column_vector_cpp(SEXP outcome, SEXP update_vector) {
+  BEGIN_CPP11
+    add_to_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(update_vector));
+    return R_NilValue;
+  END_CPP11
+}
+// R_data.cpp
+void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector);
+extern "C" SEXP _stochtree_subtract_from_column_vector_cpp(SEXP outcome, SEXP update_vector) {
+  BEGIN_CPP11
+    subtract_from_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(update_vector));
+    return R_NilValue;
+  END_CPP11
+}
+// R_data.cpp
 cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr);
 extern "C" SEXP _stochtree_get_residual_cpp(SEXP vector_ptr) {
   BEGIN_CPP11
@@ -881,6 +897,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1},
     {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2},
     {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2},
+    {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2},
     {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7},
     {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2},
     {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1},
@@ -992,6 +1009,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5},
     {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2},
     {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2},
+    {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
     {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
     {"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5},
     {NULL, NULL, 0}
diff --git a/src/data.cpp b/src/data.cpp
index ea667bce..2d0a2612 100644
--- a/src/data.cpp
+++ b/src/data.cpp
@@ -43,6 +43,30 @@ void ColumnVector::LoadData(double* data_ptr, data_size_t num_row) {
   }
 }
 
+void ColumnVector::AddToData(double* data_ptr, data_size_t num_row) {
+  data_size_t num_existing_rows = NumRows();
+  CHECK_EQ(num_row, num_existing_rows);
+  // std::function<double(double, double)> op = std::plus<double>();
+  UpdateData(data_ptr, num_row, std::plus<double>());
+}
+
+void ColumnVector::SubtractFromData(double* data_ptr, data_size_t num_row) {
+  data_size_t num_existing_rows = NumRows();
+  CHECK_EQ(num_row, num_existing_rows);
+  // std::function<double(double, double)> op = std::minus<double>();
+  UpdateData(data_ptr, num_row, std::minus<double>());
+}
+
+void ColumnVector::UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op) {
+  double ptr_val;
+  double updated_val;
+  for (data_size_t i = 0; i < num_row; ++i) {
+    ptr_val = static_cast<double>(*(data_ptr + i));
+    updated_val = op(data_(i), ptr_val);
+    data_(i) = updated_val;
+  }
+}
+
 void LoadData(double* data_ptr, int num_row, int num_col, bool is_row_major, Eigen::MatrixXd& data_matrix) {
   data_matrix.resize(num_row, num_col);
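The debug script added below interleaves a Bayesian linear-model coefficient update with forest sampling. As background, a standalone sketch of the textbook conjugate draw for a single coefficient with prior beta ~ N(0, tau) and Gaussian errors (the helper name is hypothetical and not part of this patch):

```r
# Conjugate update for one coefficient beta ~ N(0, tau), given partial
# residuals r, a single covariate w, and error variance sigma2
sample_beta_conditional <- function(r, w, sigma2, tau) {
  post_prec <- sum(w * w) / sigma2 + 1 / tau  # posterior precision
  post_var <- 1 / post_prec                   # posterior variance
  post_mean <- post_var * sum(w * r) / sigma2 # posterior mean
  rnorm(1, mean = post_mean, sd = sqrt(post_var))
}
```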
diff --git a/tools/debug/additive_lm.R b/tools/debug/additive_lm.R
new file mode 100644
index 00000000..c0033d66
--- /dev/null
+++ b/tools/debug/additive_lm.R
@@ -0,0 +1,173 @@
+# Load library
+library(stochtree)
+
+# Generate the data
+n <- 500
+p_X <- 10
+p_W <- 1
+X <- matrix(runif(n*p_X), ncol = p_X)
+W <- matrix(runif(n*p_W), ncol = p_W)
+beta_W <- c(5)
+f_XW <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (-3) + 
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-1) + 
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (1) + 
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (3)
+)
+lm_term <- W %*% beta_W
+y <- lm_term + f_XW + rnorm(n, 0, 1)
+
+# Standardize outcome
+y_bar <- mean(y)
+y_std <- sd(y)
+resid <- (y-y_bar)/y_std
+
+# Set sampler parameters
+alpha_bart <- 0.9
+beta_bart <- 1.25
+min_samples_leaf <- 1
+max_depth <- 10
+num_trees <- 100
+cutpoint_grid_size = 100
+global_variance_init = 1.
+tau_init = 0.5
+leaf_prior_scale = matrix(c(tau_init), ncol = 1)
+nu <- 4
+lambda <- 0.5
+a_leaf <- 2.
+b_leaf <- 0.5
+leaf_regression <- F
+feature_types <- as.integer(rep(0, p_X)) # 0 = numeric
+var_weights <- rep(1/p_X, p_X)
+beta_tau <- 20
+
+# Initialize C++ objects
+# Data
+if (leaf_regression) {
+  forest_dataset <- createForestDataset(X, W)
+  outcome_model_type <- 1
+} else {
+  forest_dataset <- createForestDataset(X)
+  outcome_model_type <- 0
+}
+outcome <- createOutcome(resid)
+
+# Random number generator (std::mt19937)
+rng <- createRNG()
+
+# Sampling data structures
+forest_model <- createForestModel(forest_dataset, feature_types, 
+                                  num_trees, n, alpha_bart, beta_bart, 
+                                  min_samples_leaf, max_depth)
+
+# Container of forest samples
+if (leaf_regression) {
+  forest_samples <- createForestContainer(num_trees, 1, F)
+} else {
+  forest_samples <- createForestContainer(num_trees, 1, T)
+}
+
+# Sampler preparation
+num_warmstart <- 20
+num_mcmc <- 100
+num_samples <- num_warmstart + num_mcmc
+beta_init <- 0
+global_var_samples <- c(global_variance_init, rep(0, num_samples))
+leaf_scale_samples <- c(tau_init, rep(0, num_samples))
+beta_hat_samples <- c(beta_init, rep(0, num_samples))
+
+# GFR loop
+for (i in 1:num_warmstart) {
+  # Initialize vectors needed for posterior sampling
+  if (i == 1) {
+    beta_hat <- beta_init
+    yhat_forest <- rep(0, n)
+    partial_res <- resid - yhat_forest
+  } else {
+    yhat_forest <- forest_samples$predict_raw_single_forest(forest_dataset, (i-1)-1)
+    partial_res <- resid - yhat_forest
+    outcome$add_vector(W %*% beta_hat)
+  }
+  # Sample beta from bayesian linear model with gaussian prior
+  sigma2 <- global_var_samples[i]
+  beta_posterior_mean <- sum(partial_res*W[,1])/(sigma2 + sum(W[,1]*W[,1]))
+  beta_posterior_var <- (sigma2*beta_tau)/(sigma2 + sum(W[,1]*W[,1]))
+  beta_hat <- rnorm(1, beta_posterior_mean, sqrt(beta_posterior_var))
+  beta_hat_samples[i+1] <- beta_hat
+  # Update partial residual before sampling forest
+  outcome$subtract_vector(W %*% beta_hat)
+  
+  # Sample forest
+  forest_model$sample_one_iteration(
+    forest_dataset, outcome, forest_samples, rng, feature_types, 
+    outcome_model_type, leaf_prior_scale, var_weights, 
+    sigma2, cutpoint_grid_size, gfr = T
+  )
+  
+  # Sample global variance parameter
+  global_var_samples[i+1] <- sample_sigma2_one_iteration(
+    outcome, rng, nu, lambda
+  )
+}
+
+# MCMC Loop
+for (i in (num_warmstart+1):num_samples) {
+  # Initialize vectors needed for posterior sampling
+  if (i == 1) {
+    beta_hat <- beta_init
+    yhat_forest <- rep(0, n)
+    partial_res <- resid - yhat_forest
+  } else {
+    yhat_forest <- forest_samples$predict_raw_single_forest(forest_dataset, (i-1)-1)
+    partial_res <- resid - yhat_forest
+    outcome$add_vector(W %*% beta_hat)
+  }
+  # Sample beta from bayesian linear model with gaussian prior
+  sigma2 <- global_var_samples[i]
+  beta_posterior_mean <- sum(partial_res*W[,1])/(sigma2 + sum(W[,1]*W[,1]))
+  beta_posterior_var <- (sigma2*beta_tau)/(sigma2 + sum(W[,1]*W[,1]))
+  beta_hat <- rnorm(1, beta_posterior_mean, sqrt(beta_posterior_var))
+  beta_hat_samples[i+1] <- beta_hat
+  # Update partial residual before sampling forest
+  outcome$subtract_vector(W %*% beta_hat)
+  
+  # Sample forest
+  forest_model$sample_one_iteration(
+    forest_dataset, outcome, forest_samples, rng, feature_types, 
+    outcome_model_type, leaf_prior_scale, var_weights, 
+    global_var_samples[i], cutpoint_grid_size, gfr = F
+  )
+  
+  # Sample global variance parameter
+  global_var_samples[i+1] <- sample_sigma2_one_iteration(
+    outcome, rng, nu, lambda
+  )
+}
+
+# Extract samples
+# Linear model predictions
+lm_preds <- (sapply(1:num_samples, function(x) W[,1]*beta_hat_samples[x+1]))*y_std
+
+# Forest predictions
+forest_preds <- forest_samples$predict(forest_dataset)*y_std + y_bar
+
+# Overall predictions
+preds <- forest_preds + lm_preds
+
+# Global error variance
+sigma_samples <- sqrt(global_var_samples)*y_std
+
+# Inspect results
+# GFR
+plot(sigma_samples[1:num_warmstart], ylab="sigma")
+plot(beta_hat_samples[1:num_warmstart]*y_std, ylab="beta")
+plot(rowMeans(preds[,1:num_warmstart]), y, pch=16, 
+     cex=0.75, xlab = "pred", ylab = "actual")
+abline(0,1,col="red",lty=2,lwd=2.5)
+
+# MCMC
+plot(sigma_samples[(num_warmstart+1):num_samples], ylab="sigma")
+plot(beta_hat_samples[(num_warmstart+1):num_samples]*y_std, ylab="beta")
+plot(rowMeans(preds[,(num_warmstart+1):num_samples]), y, pch=16, 
+     cex=0.75, xlab = "pred", ylab = "actual")
+abline(0,1,col="red",lty=2,lwd=2.5)
diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd
index 5d399e5a..2d5f48b8 100644
--- a/vignettes/CustomSamplingRoutine.Rmd
+++ b/vignettes/CustomSamplingRoutine.Rmd
@@ -690,7 +690,238 @@ plot(rowMeans(forest_preds[,(num_warmstart+1):num_samples]), y, pch=16, 
 abline(0,1,col="red",lty=2,lwd=2.5)
 ```
 
-# Demo 4: Causal Inference
+# Demo 4: Supervised Learning with Additive Linear Model
+
+Instead of group random effects, here we show how to combine a stochastic forest 
+with an additive linear regression term. The model corresponds broadly to
+
+\begin{equation*}
+\begin{aligned}
+y &= W\beta + f(X) + \epsilon\\
+f(X) &\sim \text{BART}(c,d)\\
+\beta &\sim \mathcal{N}(0, \tau)\\
+\epsilon &\sim \mathcal{N}\left(0,\sigma^2\right)\\
+\sigma^2 &\sim \text{IG}(a,b)
+\end{aligned}
+\end{equation*}
+
+## Simulation
+
+Simulate a partitioned linear model with an additive linear model term
+
+```{r}
+# Generate the data
+n <- 500
+p_X <- 10
+p_W <- 1
+X <- matrix(runif(n*p_X), ncol = p_X)
+W <- matrix(runif(n*p_W), ncol = p_W)
+beta_W <- c(5)
+f_XW <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (-3) + 
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-1) + 
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (1) + 
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (3)
+)
+lm_term <- W %*% beta_W
+y <- lm_term + f_XW + rnorm(n, 0, 1)
+
+# Standardize outcome
+y_bar <- mean(y)
+y_std <- sd(y)
+resid <- (y-y_bar)/y_std
+```
+
+## Sampling
+
+Set some parameters that inform the forest and variance parameter samplers
+
+```{r}
+alpha_bart <- 0.9
+beta_bart <- 1.25
+min_samples_leaf <- 1
+max_depth <- 10
+num_trees <- 100
+cutpoint_grid_size = 100
+global_variance_init = 1.
+tau_init = 0.5
+leaf_prior_scale = matrix(c(tau_init), ncol = 1)
+nu <- 4
+lambda <- 0.5
+a_leaf <- 2.
+b_leaf <- 0.5
+leaf_regression <- F
+feature_types <- as.integer(rep(0, p_X)) # 0 = numeric
+var_weights <- rep(1/p_X, p_X)
+beta_tau <- 20
+```
+
+Initialize R-level access to the C++ classes needed to sample our model
+
+```{r}
+# Data
+if (leaf_regression) {
+  forest_dataset <- createForestDataset(X, W)
+  outcome_model_type <- 1
+} else {
+  forest_dataset <- createForestDataset(X)
+  outcome_model_type <- 0
+}
+outcome <- createOutcome(resid)
+
+# Random number generator (std::mt19937)
+rng <- createRNG()
+
+# Sampling data structures
+forest_model <- createForestModel(forest_dataset, feature_types, 
+                                  num_trees, n, alpha_bart, beta_bart, 
+                                  min_samples_leaf, max_depth)
+
+# Container of forest samples
+if (leaf_regression) {
+  forest_samples <- createForestContainer(num_trees, 1, F)
+} else {
+  forest_samples <- createForestContainer(num_trees, 1, T)
+}
+```
+
+Prepare to run the sampler
+
+```{r}
+num_warmstart <- 20
+num_mcmc <- 100
+num_samples <- num_warmstart + num_mcmc
+beta_init <- 0
+global_var_samples <- c(global_variance_init, rep(0, num_samples))
+leaf_scale_samples <- c(tau_init, rep(0, num_samples))
+beta_samples <- c(beta_init, rep(0, num_samples))
+```
+
+Run the grow-from-root sampler to "warm-start" BART
+
+```{r}
+for (i in 1:num_warmstart) {
+  # Initialize vectors needed for posterior sampling
+  if (i == 1) {
+    beta_hat <- beta_init
+    yhat_forest <- rep(0, n)
+    partial_res <- resid - yhat_forest
+  } else {
+    yhat_forest <- forest_samples$predict_raw_single_forest(forest_dataset, (i-1)-1)
+    partial_res <- resid - yhat_forest
+    outcome$add_vector(W %*% beta_hat)
+  }
+  # Sample beta from bayesian linear model with gaussian prior
+  sigma2 <- global_var_samples[i]
+  beta_posterior_mean <- sum(partial_res*W[,1])/(sigma2 + sum(W[,1]*W[,1]))
+  beta_posterior_var <- (sigma2*beta_tau)/(sigma2 + sum(W[,1]*W[,1]))
+  beta_hat <- rnorm(1, beta_posterior_mean, sqrt(beta_posterior_var))
+  beta_samples[i+1] <- beta_hat
+  # Update partial residual before sampling forest
+  outcome$subtract_vector(W %*% beta_hat)
+  
+  # Sample forest
+  forest_model$sample_one_iteration(
+    forest_dataset, outcome, forest_samples, rng, feature_types, 
+    outcome_model_type, leaf_prior_scale, var_weights, 
+    sigma2, cutpoint_grid_size, gfr = T
+  )
+  
+  # Sample global variance parameter
+  global_var_samples[i+1] <- sample_sigma2_one_iteration(
+    outcome, rng, nu, lambda
+  )
+}
+```
+
+Pick up from the last GFR forest (and associated global variance / leaf 
+scale parameters) with an MCMC sampler
+
+```{r}
+for (i in (num_warmstart+1):num_samples) {
+  # Initialize vectors needed for posterior sampling
+  if (i == 1) {
+    beta_hat <- beta_init
+    yhat_forest <- rep(0, n)
+    partial_res <- resid - yhat_forest
+  } else {
+    yhat_forest <- forest_samples$predict_raw_single_forest(forest_dataset, (i-1)-1)
+    partial_res <- resid - yhat_forest
+    outcome$add_vector(W %*% beta_hat)
+  }
+  # Sample beta from bayesian linear model with gaussian prior
+  sigma2 <- global_var_samples[i]
+  beta_posterior_mean <- sum(partial_res*W[,1])/(sigma2 + sum(W[,1]*W[,1]))
+  beta_posterior_var <- (sigma2*beta_tau)/(sigma2 + sum(W[,1]*W[,1]))
+  beta_hat <- rnorm(1, beta_posterior_mean, sqrt(beta_posterior_var))
+  beta_samples[i+1] <- beta_hat
+  # Update partial residual before sampling forest
+  outcome$subtract_vector(W %*% beta_hat)
+  
+  # Sample forest
+  forest_model$sample_one_iteration(
+    forest_dataset, outcome, forest_samples, rng, feature_types, 
+    outcome_model_type, leaf_prior_scale, var_weights, 
+    global_var_samples[i], cutpoint_grid_size, gfr = F
+  )
+  
+  # Sample global variance parameter
+  global_var_samples[i+1] <- sample_sigma2_one_iteration(
+    outcome, rng, nu, lambda
+  )
+}
+```
+
+Predict and rescale samples
+
+```{r}
+# Linear model predictions
+lm_preds <- (sapply(1:num_samples, function(x) W[,1]*beta_samples[x+1]))*y_std
+
+# Forest predictions
+forest_preds <- forest_samples$predict(forest_dataset)*y_std + y_bar
+
+# Overall predictions
+preds <- forest_preds + lm_preds
+
+# Global error variance
+sigma_samples <- sqrt(global_var_samples)*y_std
+
+# Regression parameter
+beta_samples <- beta_samples*y_std
+```
+
+## Results
+
+Inspect the initial samples obtained via grow-from-root and an additive linear model term
+
+```{r}
+plot(sigma_samples[1:num_warmstart], ylab="sigma")
+plot(beta_samples[1:num_warmstart], ylab="beta")
+plot(rowMeans(preds[,1:num_warmstart]), y, pch=16, 
+     cex=0.75, xlab = "pred", ylab = "actual")
+abline(0,1,col="red",lty=2,lwd=2.5)
+```
+
+Inspect the BART samples obtained after "warm-starting" plus an additive linear model term
+
+```{r}
+plot(sigma_samples[(num_warmstart+1):num_samples], ylab="sigma")
+plot(beta_samples[(num_warmstart+1):num_samples], ylab="beta")
+plot(rowMeans(preds[,(num_warmstart+1):num_samples]), y, pch=16, 
+     cex=0.75, xlab = "pred", ylab = "actual")
+abline(0,1,col="red",lty=2,lwd=2.5)
+```
+
+Now inspect the samples from the BART forest alone (without considering the additive linear model predictions)
+
+```{r}
+plot(rowMeans(forest_preds[,(num_warmstart+1):num_samples]), y, pch=16, 
+     cex=0.75, xlab = "pred", ylab = "actual")
+abline(0,1,col="red",lty=2,lwd=2.5)
+```
+
+# Demo 5: Causal Inference
 
 Here we show how to implement the Bayesian Causal Forest (BCF) model of @hahn2020bayesian 
 using `stochtree`'s prototype API, including demoing a non-trivial sampling step
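Taken together, the residual bookkeeping in this patch follows a single invariant: the `Outcome` object always holds the (standardized) outcome minus the predictions of every current model term. A schematic sketch of that pattern for a generic additive term sampled alongside a forest (the `sample_term()` helper is hypothetical; the `Outcome` methods are the ones added by this patch):

```r
# Schematic Gibbs step for one additive term sampled alongside a forest.
# Invariant: `outcome` holds y (standardized) minus all current term fits.
term_fit_old <- rep(0, n)
for (i in 1:num_samples) {
  # Fold the term's previous fit back into the partial residual
  outcome$add_vector(term_fit_old)
  # Draw the term conditional on everything else (hypothetical helper)
  term_fit_new <- sample_term(outcome$get_data())
  # Remove the new fit so the forest sampler sees the correct residual
  outcome$subtract_vector(term_fit_new)
  term_fit_old <- term_fit_new
}
```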