Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Description: Stochastic tree ensembles (XBART and BART) for supervised learning
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
LinkingTo:
cpp11, BH
Suggests:
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(preprocessBartParams)
export(preprocessBcfParams)
export(preprocessParams)
export(preprocessPredictionData)
export(preprocessPredictionDataFrame)
export(preprocessPredictionMatrix)
Expand Down
230 changes: 119 additions & 111 deletions R/bart.R

Large diffs are not rendered by default.

300 changes: 169 additions & 131 deletions R/bcf.R

Large diffs are not rendered by default.

120 changes: 96 additions & 24 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,35 +1,78 @@
#' Merge user-supplied overrides into a list of default parameters.
#'
#' Only entries of `user_params` whose names already appear in
#' `default_params` are considered; any other names are silently ignored.
#' An entry whose value is `NULL` leaves the corresponding default untouched.
#'
#' @param default_params List of parameters with default values set.
#' @param user_params (Optional) User-supplied overrides to `default_params`.
#'
#' @return Parameter list with defaults overridden by values supplied in `user_params`
#' @export
preprocessParams <- function(default_params, user_params = NULL) {
  # Nothing to merge when no overrides were supplied
  if (is.null(user_params)) {
    return(default_params)
  }

  # Keep only the user-supplied names that correspond to a known default
  supplied_names <- names(user_params)
  recognized_names <- supplied_names[supplied_names %in% names(default_params)]

  # Replace each recognized default, skipping overrides that are NULL
  for (param_name in recognized_names) {
    override <- user_params[[param_name]]
    if (!is.null(override)) {
      default_params[[param_name]] <- override
    }
  }

  default_params
}

#' Preprocess BART parameter list. Override defaults with any provided parameters.
#'
#' @param params Parameter list
#' @param general_params List of any non-forest-specific parameters
#' @param mean_forest_params List of any mean forest parameters
#' @param variance_forest_params List of any variance forest parameters
#'
#' @return Parameter list with defaults overriden by values supplied in `params`
#' @return Parameter list with defaults overridden by values supplied in parameter lists
#' @export
preprocessBartParams <- function(params) {
preprocessBartParams <- function(general_params, mean_forest_params, variance_forest_params) {
# Default parameter values
processed_params <- list(
cutpoint_grid_size = 100, sigma_leaf_init = NULL,
cutpoint_grid_size = 100,
alpha_mean = 0.95, beta_mean = 2.0,
min_samples_leaf_mean = 5, max_depth_mean = 10,
variable_weights_mean = NULL, num_trees_mean = 200,
alpha_variance = 0.95, beta_variance = 2.0,
min_samples_leaf_variance = 5, max_depth_variance = 10,
a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL,
a_forest = NULL, b_forest = NULL, variance_scale = 1,
sigma2_init = NULL, variance_forest_init = NULL,
pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1,
variable_weights_mean = NULL, variable_weights_variance = NULL,
num_trees_mean = 200, num_trees_variance = 0,
sample_sigma_global = T, sample_sigma_leaf = F,
variable_weights_variance = NULL, num_trees_variance = 0,
sample_sigma2_global = T, sigma2_global_init = NULL,
sigma2_global_shape = 0, sigma2_global_scale = 0,
sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
var_forest_prior_shape = NULL, var_forest_prior_scale = NULL,
variance_forest_init = NULL,
sample_sigma_global = T, sample_sigma2_leaf_mean = F,
random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1,
num_chains = 1, standardize = T, verbose = F
)

# Override defaults
for (key in names(params)) {
if (!key %in% names(processed_params)) {
stop("Variable ", key, " is not a valid BART model parameter")
# Override defaults from general_params
for (key in names(general_params)) {
if (key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[key]] <- val
}
}

# Override defaults from mean_forest_params
for (key in names(mean_forest_params)) {
modified_key <- paste0(key, "_mean")
if (modified_key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[modified_key]] <- val
}
}

# Override defaults from variance_forest_params
for (key in names(variance_forest_params)) {
modified_key <- paste0(key, "_variance")
if (modified_key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[modified_key]] <- val
}
val <- params[[key]]
if (!is.null(val)) processed_params[[key]] <- val
}

# Return result
Expand All @@ -38,9 +81,12 @@ preprocessBartParams <- function(params) {

#' Preprocess BCF parameter list. Override defaults with any provided parameters.
#'
#' @param params Parameter list
#' @param general_params List of any non-forest-specific parameters
#' @param mu_forest_params List of any mu forest parameters
#' @param tau_forest_params List of any tau forest parameters
#' @param variance_forest_params List of any variance forest parameters
#'
#' @return Parameter list with defaults overriden by values supplied in `params`
#' @return Parameter list with defaults overridden by values supplied in parameter lists
#' @export
preprocessBcfParams <- function(params) {
# Default parameter values
Expand All @@ -57,19 +103,45 @@ preprocessBcfParams <- function(params) {
keep_vars_tau = NULL, drop_vars_tau = NULL, keep_vars_variance = NULL,
drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50,
num_trees_variance = 0, num_gfr = 5, num_burnin = 0, num_mcmc = 100,
sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F,
sample_sigma_global = T, sample_sigma2_leaf_mu = T, sample_sigma2_leaf_tau = F,
propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5,
rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F,
keep_every = 1, num_chains = 1, standardize = T, verbose = F
)

# Override defaults
for (key in names(params)) {
if (!key %in% names(processed_params)) {
stop("Variable ", key, " is not a valid BART model parameter")
if (key %in% names(processed_params)) {
val <- params[[key]]
if (!is.null(val)) processed_params[[key]] <- val
}
}

# Override defaults from mu_forest_params
for (key in names(mu_forest_params)) {
modified_key <- paste0(key, "_mu")
if (modified_key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[modified_key]] <- val
}
}

# Override defaults from tau_forest_params
for (key in names(tau_forest_params)) {
modified_key <- paste0(key, "_tau")
if (modified_key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[modified_key]] <- val
}
}

# Override defaults from variance_forest_params
for (key in names(variance_forest_params)) {
modified_key <- paste0(key, "_variance")
if (modified_key %in% names(processed_params)) {
val <- general_params[[key]]
if (!is.null(val)) processed_params[[modified_key]] <- val
}
val <- params[[key]]
if (!is.null(val)) processed_params[[key]] <- val
}

# Return result
Expand Down
3 changes: 2 additions & 1 deletion demo/notebooks/causal_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@
"outputs": [],
"source": [
"bcf_model = BCFModel()\n",
"bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, params={\"keep_every\": 5})"
"general_params = {\"keep_every\": 5}\n",
"bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, general_params=general_params)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions demo/notebooks/supervised_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@
"outputs": [],
"source": [
"bart_model = BARTModel()\n",
"param_dict = {\"num_chains\": 3}\n",
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, params=param_dict)"
"general_params = {\"num_chains\": 3}\n",
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, general_params=general_params)"
]
},
{
Expand Down
Loading
Loading