Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export(rootResetRandomEffectsModel)
export(rootResetRandomEffectsTracker)
export(sampleGlobalErrorVarianceOneIteration)
export(sampleLeafVarianceOneIteration)
export(sample_without_replacement)
export(saveBARTModelToJson)
export(saveBARTModelToJsonFile)
export(saveBARTModelToJsonString)
Expand Down
29 changes: 22 additions & 7 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default: `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#'
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand All @@ -58,7 +59,7 @@
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand All @@ -73,6 +74,7 @@
#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
Expand All @@ -98,6 +100,7 @@
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#'
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL,
Expand All @@ -114,7 +117,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
sigma2_global_shape = 0, sigma2_global_scale = 0,
variable_weights = NULL, random_seed = -1,
keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
num_chains = 1, verbose = FALSE
num_chains = 1, verbose = FALSE,
probit_outcome_model = FALSE
)
general_params_updated <- preprocessParams(
general_params_default, general_params
Expand All @@ -127,7 +131,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
keep_vars = NULL, drop_vars = NULL,
probit_outcome_model = FALSE
num_features_subsample = NULL
)
mean_forest_params_updated <- preprocessParams(
mean_forest_params_default, mean_forest_params
Expand All @@ -141,7 +145,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
var_forest_leaf_init = NULL,
var_forest_prior_shape = NULL,
var_forest_prior_scale = NULL,
keep_vars = NULL, drop_vars = NULL
keep_vars = NULL, drop_vars = NULL,
num_features_subsample = NULL
)
variance_forest_params_updated <- preprocessParams(
variance_forest_params_default, variance_forest_params
Expand All @@ -162,6 +167,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
keep_every <- general_params_updated$keep_every
num_chains <- general_params_updated$num_chains
verbose <- general_params_updated$verbose
probit_outcome_model <- general_params_updated$probit_outcome_model

# 2. Mean forest parameters
num_trees_mean <- mean_forest_params_updated$num_trees
Expand All @@ -175,7 +181,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
keep_vars_mean <- mean_forest_params_updated$keep_vars
drop_vars_mean <- mean_forest_params_updated$drop_vars
probit_outcome_model <- mean_forest_params_updated$probit_outcome_model
num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample

# 3. Variance forest parameters
num_trees_variance <- variance_forest_params_updated$num_trees
Expand All @@ -189,6 +195,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
b_forest <- variance_forest_params_updated$var_forest_prior_scale
keep_vars_variance <- variance_forest_params_updated$keep_vars
drop_vars_variance <- variance_forest_params_updated$drop_vars
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample

# Check if there are enough GFR samples to seed num_chains samplers
if (num_gfr > 0) {
Expand Down Expand Up @@ -373,6 +380,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
variable_weights_variance <- variable_weights_variance[original_var_indices]*variable_weights_adj
variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0
}

# Set num_features_subsample to default, ncol(X_train), if not already set
if (is.null(num_features_subsample_mean)) {
num_features_subsample_mean <- ncol(X_train)
}
if (is.null(num_features_subsample_variance)) {
num_features_subsample_variance <- ncol(X_train)
}

# Convert all input data to matrices if not already converted
if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) {
Expand Down Expand Up @@ -633,15 +648,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension,
alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean,
leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale,
cutpoint_grid_size=cutpoint_grid_size)
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_mean)
forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config)
}
if (include_variance_forest) {
forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train),
num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1,
alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance,
max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest,
cutpoint_grid_size=cutpoint_grid_size)
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_variance)
forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
}

Expand Down
34 changes: 27 additions & 7 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param treatment_effect_forest_params (Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand All @@ -78,6 +79,7 @@
#' - `delta_max` Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Default: `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand All @@ -92,6 +94,7 @@
#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
Expand Down Expand Up @@ -171,7 +174,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
min_samples_leaf = 5, max_depth = 10,
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
keep_vars = NULL, drop_vars = NULL
keep_vars = NULL, drop_vars = NULL,
num_features_subsample = NULL
)
prognostic_forest_params_updated <- preprocessParams(
prognostic_forest_params_default, prognostic_forest_params
Expand All @@ -183,8 +187,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
min_samples_leaf = 5, max_depth = 5,
sample_sigma2_leaf = FALSE, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
keep_vars = NULL, drop_vars = NULL,
delta_max = 0.9
keep_vars = NULL, drop_vars = NULL, delta_max = 0.9,
num_features_subsample = NULL
)
treatment_effect_forest_params_updated <- preprocessParams(
treatment_effect_forest_params_default, treatment_effect_forest_params
Expand All @@ -198,7 +202,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
variance_forest_init = NULL,
var_forest_prior_shape = NULL,
var_forest_prior_scale = NULL,
keep_vars = NULL, drop_vars = NULL
keep_vars = NULL, drop_vars = NULL,
num_features_subsample = NULL
)
variance_forest_params_updated <- preprocessParams(
variance_forest_params_default, variance_forest_params
Expand Down Expand Up @@ -238,6 +243,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale
keep_vars_mu <- prognostic_forest_params_updated$keep_vars
drop_vars_mu <- prognostic_forest_params_updated$drop_vars
num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample

# 3. Tau forest parameters
num_trees_tau <- treatment_effect_forest_params_updated$num_trees
Expand All @@ -252,6 +258,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars
drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars
delta_max <- treatment_effect_forest_params_updated$delta_max
num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample

# 4. Variance forest parameters
num_trees_variance <- variance_forest_params_updated$num_trees
Expand All @@ -265,6 +272,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
b_forest <- variance_forest_params_updated$var_forest_prior_scale
keep_vars_variance <- variance_forest_params_updated$keep_vars
drop_vars_variance <- variance_forest_params_updated$drop_vars
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample

# Check if there are enough GFR samples to seed num_chains samplers
if (num_gfr > 0) {
Expand Down Expand Up @@ -477,6 +485,17 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
X_test_raw <- X_test
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)

# Set num_features_subsample to default, ncol(X_train), if not already set
if (is.null(num_features_subsample_mu)) {
num_features_subsample_mu <- ncol(X_train)
}
if (is.null(num_features_subsample_tau)) {
num_features_subsample_tau <- ncol(X_train)
}
if (is.null(num_features_subsample_variance)) {
num_features_subsample_variance <- ncol(X_train)
}

# Convert all input data to matrices if not already converted
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
Z_train <- as.matrix(as.numeric(Z_train))
Expand Down Expand Up @@ -899,20 +918,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest,
alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu,
leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu,
cutpoint_grid_size=cutpoint_grid_size)
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_mu)
forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train),
num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest,
alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau,
leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau,
cutpoint_grid_size=cutpoint_grid_size)
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_tau)
forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu, global_model_config)
forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau, global_model_config)
if (include_variance_forest) {
forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train),
num_observations=nrow(X_train), variable_weights=variable_weights_variance,
leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance,
min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance,
leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size)
leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size,
num_features_subsample=num_features_subsample_variance)
forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
}

Expand Down
Loading
Loading