Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ export(oneHotEncode)
export(oneHotInitializeAndEncode)
export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(preprocessBartParams)
export(preprocessBcfParams)
export(preprocessPredictionData)
export(preprocessPredictionDataFrame)
export(preprocessPredictionMatrix)
Expand Down
136 changes: 90 additions & 46 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,64 @@
#' We do not currently support (but plan to in the near future), test set evaluation for group labels
#' that were not in the training set.
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
#' @param sigma_leaf_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here.
#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
#' @param alpha_mean Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: 0.95.
#' @param beta_mean Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: 2.
#' @param min_samples_leaf_mean Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: 5.
#' @param max_depth_mean Maximum depth of any tree in the ensemble in the mean model. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
#' @param alpha_variance Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: 0.95.
#' @param beta_variance Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: 2.
#' @param min_samples_leaf_variance Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: 5.
#' @param max_depth_variance Maximum depth of any tree in the ensemble in the variance model. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
#' @param a_global Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
#' @param b_global Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees_mean` if not set here.
#' @param a_forest Shape parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2 + 0.5` if not set.
#' @param b_forest Scale parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2` if not set.
#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
#' @param sigma2_init Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set.
#' @param variance_forest_init Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance` if not set.
#' @param pct_var_sigma2_init Percentage of standardized outcome variance used to initialize global error variance parameter. Default: 1. Superseded by `sigma2_init`.
#' @param pct_var_variance_forest_init Percentage of standardized outcome variance used to initialize the root forest prediction in the conditional (heteroskedastic) error variance model. Default: 1. Superseded by `variance_forest_init`.
#' @param variance_scale Variance after the data have been scaled. Default: 1.
#' @param variable_weights_mean Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' @param variable_weights_variance Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' @param num_trees_mean Number of trees in the ensemble for the conditional mean model. Default: 200. If `num_trees_mean = 0`, the conditional mean will not be modeled using a forest and the function will only proceed if `num_trees_variance > 0`.
#' @param num_trees_variance Number of trees in the ensemble for the conditional variance model. Default: 0. Variance is only modeled using a tree / forest if `num_trees_variance > 0`.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
#' @param sample_sigma_global Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: T.
#' @param sample_sigma_leaf Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: F.
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE.
#' @param params The list of model parameters, each of which has a default value.
#'
#' **1. Global Parameters**
#'
#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider. Default: `100`.
#' - `sigma2_init` Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set.
#' - `pct_var_sigma2_init` Percentage of standardized outcome variance used to initialize global error variance parameter. Default: `1`. Superseded by `sigma2_init`.
#' - `variance_scale` Variance after the data have been scaled. Default: `1`.
#' - `a_global` Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`.
#' - `b_global` Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`.
#' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`.
#' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default: `FALSE`. Ignored if `num_mcmc = 0`.
#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default: `TRUE`. Ignored if `num_mcmc = 0`.
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#'
#' **2. Mean Forest Parameters**
#'
#' - `num_trees_mean` Number of trees in the ensemble for the conditional mean model. Default: `200`. If `num_trees_mean = 0`, the conditional mean will not be modeled using a forest, and the function will only proceed if `num_trees_variance > 0`.
#' - `sample_sigma_leaf` Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: `FALSE`.
#'
#' **2.1. Tree Prior Parameters**
#'
#' - `alpha_mean` Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: `0.95`.
#' - `beta_mean` Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: `2`.
#' - `min_samples_leaf_mean` Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: `5`.
#' - `max_depth_mean` Maximum depth of any tree in the ensemble in the mean model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
#'
#' **2.2. Leaf Model Parameters**
#'
#' - `variable_weights_mean` Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' - `sigma_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here.
#' - `a_leaf` Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: `3`.
#' - `b_leaf` Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees_mean` if not set here.
#'
#' **3. Conditional Variance Forest Parameters**
#'
#' - `num_trees_variance` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees_variance > 0`.
#' - `variance_forest_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance` if not set.
#' - `pct_var_variance_forest_init` Percentage of standardized outcome variance used to initialize the root forest prediction in the conditional (heteroskedastic) error variance model. Default: `1`. Superseded by `variance_forest_init`.
#'
#' **3.1. Tree Prior Parameters**
#'
#' - `alpha_variance` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `0.95`.
#' - `beta_variance` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `2`.
#' - `min_samples_leaf_variance` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`.
#' - `max_depth_variance` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
#'
#' **3.2. Leaf Model Parameters**
#'
#' - `variable_weights_variance` Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' - `sigma_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here.
#' - `a_forest` Shape parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2 + 0.5` if not set.
#' - `b_forest` Scale parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2` if not set.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
#'
Expand Down Expand Up @@ -91,19 +112,42 @@
bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
rfx_basis_train = NULL, X_test = NULL, W_test = NULL,
group_ids_test = NULL, rfx_basis_test = NULL,
cutpoint_grid_size = 100, sigma_leaf_init = NULL,
alpha_mean = 0.95, beta_mean = 2.0, min_samples_leaf_mean = 5,
max_depth_mean = 10, alpha_variance = 0.95, beta_variance = 2.0,
min_samples_leaf_variance = 5, max_depth_variance = 10,
a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL,
a_forest = NULL, b_forest = NULL, q = 0.9, sigma2_init = NULL,
variance_forest_init = NULL, pct_var_sigma2_init = 1,
pct_var_variance_forest_init = 1, variance_scale = 1,
variable_weights_mean = NULL, variable_weights_variance = NULL,
num_trees_mean = 200, num_trees_variance = 0,
num_gfr = 5, num_burnin = 0, num_mcmc = 100,
sample_sigma_global = T, sample_sigma_leaf = F, random_seed = -1,
keep_burnin = F, keep_gfr = F, verbose = F) {
params = list()) {
# Extract BART parameters
bart_params <- preprocessBartParams(params)
cutpoint_grid_size <- bart_params$cutpoint_grid_size
sigma_leaf_init <- bart_params$sigma_leaf_init
alpha_mean <- bart_params$alpha_mean
beta_mean <- bart_params$beta_mean
min_samples_leaf_mean <- bart_params$min_samples_leaf_mean
max_depth_mean <- bart_params$max_depth_mean
alpha_variance <- bart_params$alpha_variance
beta_variance <- bart_params$beta_variance
min_samples_leaf_variance <- bart_params$min_samples_leaf_variance
max_depth_variance <- bart_params$max_depth_variance
a_global <- bart_params$a_global
b_global <- bart_params$b_global
a_leaf <- bart_params$a_leaf
b_leaf <- bart_params$b_leaf
a_forest <- bart_params$a_forest
b_forest <- bart_params$b_forest
variance_scale <- bart_params$variance_scale
sigma2_init <- bart_params$sigma2_init
variance_forest_init <- bart_params$variance_forest_init
pct_var_sigma2_init <- bart_params$pct_var_sigma2_init
pct_var_variance_forest_init <- bart_params$pct_var_variance_forest_init
variable_weights_mean <- bart_params$variable_weights_mean
variable_weights_variance <- bart_params$variable_weights_variance
num_trees_mean <- bart_params$num_trees_mean
num_trees_variance <- bart_params$num_trees_variance
sample_sigma_global <- bart_params$sample_sigma_global
sample_sigma_leaf <- bart_params$sample_sigma_leaf
random_seed <- bart_params$random_seed
keep_burnin <- bart_params$keep_burnin
keep_gfr <- bart_params$keep_gfr
verbose <- bart_params$verbose

# Determine whether conditional mean, variance, or both will be modeled
if (num_trees_variance > 0) include_variance_forest = T
else include_variance_forest = F
Expand All @@ -121,7 +165,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
}

# Override tau sampling if there is no mean forest
if (!include_mean_forest) sample_tau <- F
if (!include_mean_forest) sample_sigma_leaf <- F

# Variable weight preprocessing (and initialization if necessary)
if (include_mean_forest) {
Expand Down
Loading
Loading