diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index 5e3cd3a5..7260ac0e 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -34,9 +34,13 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::testthat, any::decor + extra-packages: any::testthat, any::decor, any::rcmdcheck + needs: check - - name: Run unit tests + - name: Create a CRAN-ready version of the R package run: | - Rscript cran-bootstrap.R - Rscript -e 'testthat::test_local("stochtree_cran")' \ No newline at end of file + Rscript cran-bootstrap.R 0 + + - uses: r-lib/actions/check-r-package@v2 + with: + working-directory: 'stochtree_cran' diff --git a/DESCRIPTION b/DESCRIPTION index 3aa06675..169aa5a3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: stochtree Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference -Version: 0.0.0.9000 +Version: 0.0.1 Authors@R: c( person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")), @@ -17,6 +17,7 @@ RoxygenNote: 7.3.2 LinkingTo: cpp11, BH Suggests: + testthat (>= 3.0.0), doParallel, foreach, ggplot2, @@ -26,12 +27,11 @@ Suggests: MASS, mvtnorm, rmarkdown, - testthat (>= 3.0.0), tgp VignetteBuilder: knitr SystemRequirements: C++17 Imports: R6, stats -URL: https://stochastictree.github.io/stochtree-r/ +URL: https://stochtree.ai Config/testthat/edition: 3 diff --git a/cpp_docs/Doxyfile b/Doxyfile similarity index 99% rename from cpp_docs/Doxyfile rename to Doxyfile index 366412c6..06c38d94 100644 --- a/cpp_docs/Doxyfile +++ b/Doxyfile @@ -548,7 +548,7 @@ EXTRACT_PACKAGE = NO # included in the documentation. # The default value is: NO. -EXTRACT_STATIC = NO +EXTRACT_STATIC = YES # If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined # locally in source files will be included in the documentation. If set to NO, @@ -588,7 +588,7 @@ RESOLVE_UNNAMED_PARAMS = YES # section is generated. This option has no effect if EXTRACT_ALL is enabled. # The default value is: NO. -HIDE_UNDOC_MEMBERS = NO +HIDE_UNDOC_MEMBERS = YES # If the HIDE_UNDOC_CLASSES tag is set to YES, Doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set @@ -687,7 +687,7 @@ INLINE_INFO = YES # name. If set to NO, the members will appear in declaration order. # The default value is: YES. -SORT_MEMBER_DOCS = YES +SORT_MEMBER_DOCS = NO # If the SORT_BRIEF_DOCS tag is set to YES then Doxygen will sort the brief # descriptions of file, namespace and class members alphabetically by member @@ -740,7 +740,7 @@ STRICT_PROTO_MATCHING = NO # list. This list is created by putting \todo commands in the documentation. # The default value is: YES. -GENERATE_TODOLIST = YES +GENERATE_TODOLIST = NO # The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test # list. This list is created by putting \test commands in the documentation. @@ -965,7 +965,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = +INPUT = include/stochtree # This tag can be used to specify the character encoding of the source files # that Doxygen parses. Internally Doxygen uses the UTF-8 encoding. Doxygen uses @@ -1081,7 +1081,10 @@ EXCLUDE_PATTERNS = */test/* \ # wildcard * is used, a substring. 
Examples: ANamespace, AClass, # ANamespace::AClass, ANamespace::*Test -EXCLUDE_SYMBOLS = StochTree::CommonC +EXCLUDE_SYMBOLS = StochTree::CommonC \ + StochTree::CategorySampleTracker \ + StochTree::ExtractMultipleFeaturesFromMemory \ + StochTree::ExtractSingleFeatureFromMemory # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include @@ -1805,7 +1808,7 @@ FORMULA_MACROFILE = # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. -USE_MATHJAX = NO +USE_MATHJAX = YES # With MATHJAX_VERSION it is possible to specify the MathJax version to be used. # Note that the different versions of MathJax have different requirements with diff --git a/NAMESPACE b/NAMESPACE index aa71c7fc..7c746a36 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -49,8 +49,6 @@ export(oneHotEncode) export(oneHotInitializeAndEncode) export(orderedCatInitializeAndPreprocess) export(orderedCatPreprocess) -export(preprocessBartParams) -export(preprocessBcfParams) export(preprocessParams) export(preprocessPredictionData) export(preprocessPredictionDataFrame) diff --git a/R/bart.R b/R/bart.R index 2f7e8333..851e36cf 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1015,9 +1015,9 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, -#' group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, -#' X_test = X_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, +#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +#' group_ids_train = group_ids_train, group_ids_test = group_ids_test, +#' rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' rfx_samples <- getRandomEffectSamples(bart_model) getRandomEffectSamples.bartmodel <- function(object, ...){ @@ -1180,7 +1180,7 @@ convertBARTStateToJson <- function(param_list, mean_forest = NULL, variance_fore jsonobj$add_forest(mean_forest) } if (param_list$include_variance_forest) { - jsonobj$add_forest(object$variance_forests) + jsonobj$add_forest(variance_forest) } # Add sampled parameters diff --git a/R/bcf.R b/R/bcf.R index 87f4359c..897debc5 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1395,36 +1395,30 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU rfx_basis_test <- matrix(rep(1, nrow(X_test)), ncol = 1) } - # Add propensities to any covariate set - if (bcf$model_params$propensity_covariate == "both") { - X_test_mu <- cbind(X_test, pi_test) - X_test_tau <- cbind(X_test, pi_test) - } else if (bcf$model_params$propensity_covariate == "mu") { - X_test_mu <- cbind(X_test, pi_test) - X_test_tau <- X_test - } else if (bcf$model_params$propensity_covariate == "tau") { - X_test_mu <- X_test - X_test_tau <- cbind(X_test, pi_test) + # Add propensities to covariate set if necessary + if (bcf$model_params$propensity_covariate != "none") { + X_test_combined <- cbind(X_test, pi_test) + } else { + X_test_combined <- X_test } # Create prediction datasets - prediction_dataset_mu <- createForestDataset(X_test_mu) - prediction_dataset_tau <- createForestDataset(X_test_tau, Z_test) + forest_dataset_pred <- createForestDataset(X_test_combined, Z_test) # Compute forest predictions num_samples <- bcf$model_params$num_samples y_std <- bcf$model_params$outcome_scale y_bar
<- bcf$model_params$outcome_mean initial_sigma2 <- bcf$model_params$initial_sigma2 - mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar + mu_hat_test <- bcf$forests_mu$predict(forest_dataset_pred)*y_std + y_bar if (bcf$model_params$adaptive_coding) { - tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau) + tau_hat_test_raw <- bcf$forests_tau$predict_raw(forest_dataset_pred) tau_hat_test <- t(t(tau_hat_test_raw) * (bcf$b_1_samples - bcf$b_0_samples))*y_std } else { - tau_hat_test <- bcf$forests_tau$predict_raw(prediction_dataset_tau)*y_std + tau_hat_test <- bcf$forests_tau$predict_raw(forest_dataset_pred)*y_std } if (bcf$model_params$include_variance_forest) { - s_x_raw <- bcf$variance_forests$predict(prediction_dataset) + s_x_raw <- bcf$variance_forests$predict(forest_dataset_pred) } # Compute rfx predictions (if needed) @@ -1520,14 +1512,16 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) +#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' rfx_samples <- getRandomEffectSamples(bcf_model) getRandomEffectSamples.bcf <- function(object, ...){ result = list() @@ -1607,14 +1601,16 @@ getRandomEffectSamples.bcf <- function(object, ...){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) +#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' # bcf_json <- convertBCFModelToJson(bcf_model) convertBCFModelToJson <- function(object){ jsonobj <- createCppJson() @@ -1749,14 +1745,16 @@ convertBCFModelToJson <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) 
+#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' # saveBCFModelToJsonFile(bcf_model, "test.json") saveBCFModelToJsonFile <- function(object, filename){ # Convert to Json @@ -1823,14 +1821,16 @@ saveBCFModelToJsonFile <- function(object, filename){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) +#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' # saveBCFModelToJsonString(bcf_model) saveBCFModelToJsonString <- function(object){ # Convert to Json @@ -1899,14 +1899,16 @@ saveBCFModelToJsonString <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) +#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' # bcf_json <- convertBCFModelToJson(bcf_model) #' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) createBCFModelFromJson <- function(json_object){ @@ -2045,14 +2047,16 @@ createBCFModelFromJson <- function(json_object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' mu_params <- list(sample_sigma_leaf = TRUE) +#' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' params = bcf_params) +#' mu_forest_params = mu_params, +#' tau_forest_params = tau_params) #' # saveBCFModelToJsonFile(bcf_model, "test.json") #' # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json") createBCFModelFromJsonFile <- function(json_filename){ diff --git a/R/cpp11.R b/R/cpp11.R index 3cc45075..bf6345b9 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -580,6 +580,22 @@ tree_prior_cpp <- function(alpha, beta, min_samples_leaf, max_depth) { .Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf, max_depth) } +update_alpha_tree_prior_cpp <- function(tree_prior_ptr, alpha) { + invisible(.Call(`_stochtree_update_alpha_tree_prior_cpp`, tree_prior_ptr, alpha)) +} + 
+update_beta_tree_prior_cpp <- function(tree_prior_ptr, beta) { + invisible(.Call(`_stochtree_update_beta_tree_prior_cpp`, tree_prior_ptr, beta)) +} + +update_min_samples_leaf_tree_prior_cpp <- function(tree_prior_ptr, min_samples_leaf) { + invisible(.Call(`_stochtree_update_min_samples_leaf_tree_prior_cpp`, tree_prior_ptr, min_samples_leaf)) +} + +update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) { + invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth)) +} + forest_tracker_cpp <- function(data, feature_types, num_trees, n) { .Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n) } diff --git a/R/kernel.R b/R/kernel.R index deaaab6f..fe21f8f1 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -159,7 +159,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU } # Preprocess forest indices - num_forests <- forest_container$num_samples() + num_forests <- model_object$model_params$num_samples if (is.null(forest_inds)) { forest_inds <- as.integer(1:num_forests) } else { diff --git a/R/model.R b/R/model.R index 90ddbb3c..8e34d32b 100644 --- a/R/model.R +++ b/R/model.R @@ -128,6 +128,38 @@ ForestModel <- R6::R6Class( #' @return NULL propagate_residual_update = function(residual) { propagate_trees_column_vector_cpp(self$tracker_ptr, residual$data_ptr) + }, + + #' @description + #' Update alpha in the tree prior + #' @param alpha New value of alpha to be used + #' @return NULL + update_alpha = function(alpha) { + update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) + }, + + #' @description + #' Update beta in the tree prior + #' @param beta New value of beta to be used + #' @return NULL + update_beta = function(beta) { + update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) + }, + + #' @description + #' Update min_samples_leaf in the tree prior + #' @param min_samples_leaf New value of min_samples_leaf to be used + #' @return NULL + update_min_samples_leaf = function(min_samples_leaf) { + update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, min_samples_leaf) + }, + + #' @description + #' Update max_depth in the tree prior + #' @param max_depth New value of max_depth to be used + #' @return NULL + update_max_depth = function(max_depth) { + update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) } ) ) diff --git a/R/serialization.R b/R/serialization.R index 4c6ec0cb..24205f9e 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -413,7 +413,7 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) invisible(output <- RandomEffectSamples$new()) - for (i in 1:length(json_object_list)) { + for (i in 1:length(json_string_list)) { json_string <- json_string_list[[i]] if (i == 1) { output$load_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) diff --git a/R/utils.R b/R/utils.R index 66c9b702..a1fc12a8 100644 --- a/R/utils.R +++ b/R/utils.R @@ -20,134 +20,6 @@ preprocessParams <- function(default_params, user_params = NULL) { return(default_params) } -#' Preprocess BART parameter list. Override defaults with any provided parameters. 
-#' -#' @param general_params List of any non-forest-specific parameters -#' @param mean_forest_params List of any mean forest parameters -#' @param variance_forest_params List of any variance forest parameters -#' -#' @return Parameter list with defaults overriden by values supplied in parameter lists -#' @export -preprocessBartParams <- function(general_params, mean_forest_params, variance_forest_params) { - # Default parameter values - processed_params <- list( - cutpoint_grid_size = 100, - alpha_mean = 0.95, beta_mean = 2.0, - min_samples_leaf_mean = 5, max_depth_mean = 10, - variable_weights_mean = NULL, num_trees_mean = 200, - alpha_variance = 0.95, beta_variance = 2.0, - min_samples_leaf_variance = 5, max_depth_variance = 10, - variable_weights_variance = NULL, num_trees_variance = 0, - sample_sigma2_global = T, sigma2_global_init = NULL, - sigma2_global_shape = 0, sigma2_global_scale = 0, - sample_sigma2_leaf = T, sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - var_forest_prior_shape = NULL, var_forest_prior_scale = NULL, - variance_forest_init = NULL, - sample_sigma_global = T, sample_sigma2_leaf_mean = F, - random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1, - num_chains = 1, standardize = T, verbose = F - ) - - # Override defaults from general_params - for (key in names(general_params)) { - if (key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[key]] <- val - } - } - - # Override defaults from mean_forest_params - for (key in names(mean_forest_params)) { - modified_key <- paste0(key, "_mean") - if (modified_key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[modified_key]] <- val - } - } - - # Override defaults from variance_forest_params - for (key in names(variance_forest_params)) { - modified_key <- paste0(key, "_variance") - if (modified_key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[modified_key]] <- val - } - } - - # Return result - return(processed_params) -} - -#' Preprocess BCF parameter list. Override defaults with any provided parameters. 
-#' -#' @param general_params List of any non-forest-specific parameters -#' @param mu_forest_params List of any mu forest parameters -#' @param tau_forest_params List of any tau forest parameters -#' @param variance_forest_params List of any variance forest parameters -#' -#' @return Parameter list with defaults overriden by values supplied in parameter lists -#' @export -preprocessBcfParams <- function(params) { - # Default parameter values - processed_params <- list( - cutpoint_grid_size = 100, sigma_leaf_mu = NULL, sigma_leaf_tau = NULL, - alpha_mu = 0.95, alpha_tau = 0.25, alpha_variance = 0.95, - beta_mu = 2.0, beta_tau = 3.0, beta_variance = 2.0, - min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, min_samples_leaf_variance = 5, - max_depth_mu = 10, max_depth_tau = 5, max_depth_variance = 10, - a_global = 0, b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, - b_leaf_tau = NULL, a_forest = NULL, b_forest = NULL, sigma2_init = NULL, - variance_forest_init = NULL, pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1, - variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, - keep_vars_tau = NULL, drop_vars_tau = NULL, keep_vars_variance = NULL, - drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, - num_trees_variance = 0, num_gfr = 5, num_burnin = 0, num_mcmc = 100, - sample_sigma_global = T, sample_sigma2_leaf_mu = T, sample_sigma2_leaf_tau = F, - propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, - rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, - keep_every = 1, num_chains = 1, standardize = T, verbose = F - ) - - # Override defaults - for (key in names(params)) { - if (key %in% names(processed_params)) { - val <- params[[key]] - if (!is.null(val)) processed_params[[key]] <- val - } - } - - # Override defaults from mu_forest_params - for (key in names(mu_forest_params)) { - modified_key <- paste0(key, "_mu") - if (modified_key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[modified_key]] <- val - } - } - - # Override defaults from tau_forest_params - for (key in names(tau_forest_params)) { - modified_key <- paste0(key, "_tau") - if (modified_key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[modified_key]] <- val - } - } - - # Override defaults from variance_forest_params - for (key in names(variance_forest_params)) { - modified_key <- paste0(key, "_variance") - if (modified_key %in% names(processed_params)) { - val <- general_params[[key]] - if (!is.null(val)) processed_params[[modified_key]] <- val - } - } - - # Return result - return(processed_params) -} - #' Preprocess covariates. DataFrames will be preprocessed based on their column #' types. Matrices will be passed through assuming all columns are numeric. #' @@ -774,7 +646,8 @@ oneHotEncode <- function(x_input, unique_levels) { #' @export #' #' @examples -#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") +#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", +#' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. 
Agree") #' preprocess_list <- orderedCatInitializeAndPreprocess(x) #' x_preprocessed <- preprocess_list$x_preprocessed orderedCatInitializeAndPreprocess <- function(x_input) { @@ -807,7 +680,8 @@ orderedCatInitializeAndPreprocess <- function(x_input) { #' @export #' #' @examples -#' x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", +#' x_levels <- c("1. Strongly disagree", "2. Disagree", +#' "3. Neither agree nor disagree", #' "4. Agree", "5. Strongly agree") #' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", #' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") diff --git a/_pkgdown.yml b/_pkgdown.yml index 195c0bb9..ff88143e 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -88,8 +88,7 @@ reference: - CppRNG - createRNG - calibrate_inverse_gamma_error_variance - - preprocessBartParams - - preprocessBcfParams + - preprocessParams - computeMaxLeafIndex - computeForestLeafIndices - computeForestLeafVariances diff --git a/cpp_docs/Makefile b/cpp_docs/Makefile deleted file mode 100644 index d0c3cbf1..00000000 --- a/cpp_docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/cpp_docs/README.md b/cpp_docs/README.md deleted file mode 100644 index 5017af8c..00000000 --- a/cpp_docs/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# C++ API Documentation - -## Building Documentation Locally - -We are still working out the best way to deploy the C++ documentation online alongside the R and Python documentation. -In the meantime, to build the C++ documentation locally, first ensure that you have [doxygen](https://www.doxygen.nl/index.html) installed. -On MacOS, this can be [done via homebrew](https://formulae.brew.sh/formula/doxygen) (i.e. `brew install doxygen`). -Next, you will need both the [Sphinx](https://www.sphinx-doc.org/en/master/) and [breathe](https://breathe.readthedocs.io/en/latest/dot_graphs.html) python packages - -Now, navigate to the python package's main directory (i.e. `cd [path/to/stochtree]`), build the C++ documentation via `doxygen` and then run `sphinx-build` as below - -``` -pip install --upgrade pip -pip install -r cpp_docs/requirements.txt -doxygen cpp_docs/Doxyfile -sphinx-build -M html cpp_docs/ cpp_docs/build/ -``` - -## Documentation Style - -Module (class, function, etc...) documentation follows the format prescribed by [doxygen](https://www.doxygen.nl/manual/docblocks.html) for C++ code. diff --git a/cpp_docs/conf.py b/cpp_docs/conf.py deleted file mode 100644 index 369f5010..00000000 --- a/cpp_docs/conf.py +++ /dev/null @@ -1,40 +0,0 @@ -# Configuration file for the Sphinx documentation builder. 
-# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# import os -# import sys -# sys.path.insert(0, os.path.abspath('../..')) -from pathlib import Path -CPP_DOC_PATH = Path(__file__).absolute().parent - -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information - -project = 'stochtree' -copyright = '2024, Drew Herren' -author = 'Drew Herren' -release = '0.0.1' - -# -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration - -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'breathe' -] - -templates_path = ['_templates'] -exclude_patterns = [] - -# Breathe Configuration -breathe_projects = {"StochTree": str(CPP_DOC_PATH / "doxyoutput" / "xml")} -breathe_default_project = "StochTree" - -# -- Options for HTML output ------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output - -html_theme = 'furo' -html_static_path = ['_static'] diff --git a/cpp_docs/dataset.rst b/cpp_docs/dataset.rst deleted file mode 100644 index 30179954..00000000 --- a/cpp_docs/dataset.rst +++ /dev/null @@ -1,42 +0,0 @@ -Dataset API -=========== - -Forest Dataset --------------- - -The ``ForestDataset`` class is a wrapper around data needed to sample one or more tree ensembles. -Its core elements are - -* **Covariates**: Features / variables used to partition the forests. Stored internally as a (column-major) ``Eigen::MatrixXd``. -* **Basis**: *[Optional]* basis vector used to define a "leaf regression" --- a partitioned linear model where covariates define the partitions and basis defines the regression variables. - Also stored internally as a (column-major) ``Eigen::MatrixXd``. -* **Sample Weights**: *[Optional]* case weights for every observation in a training dataset. These may be heteroskedastic variance parameters or simply survey / case weights. - Stored internally as an ``Eigen::VectorXd``. - -.. doxygenclass:: StochTree::ForestDataset - :project: StochTree - :members: - -Random Effects Dataset ----------------------- - -The ``RandomEffectsDataset`` class is a wrapper around data needed to sample one or more tree ensembles. -Its core elements are - -* **Basis**: Vector of variables that have group-specific random coefficients. In the simplest additive group random effects model, this is a constant intercept of all ones. - Stored internally as a (column-major) ``Eigen::MatrixXd``. -* **Group Indices**: Integer-valued indices of group membership. In a model with three groups, these indices would typically be 0, 1, and 2 (remapped from perhaps more descriptive labels in R or Python). - Stored internally as an ``std::vector`` of integers. -* **Sample Weights**: *[Optional]* case weights for every observation in a training dataset. These may be heteroskedastic variance parameters or simply survey / case weights. - Stored internally as an ``Eigen::VectorXd``. - -.. doxygenclass:: StochTree::RandomEffectsDataset - :project: StochTree - :members: - -Other Classes and Types ------------------------ - -.. 
doxygenenum:: StochTree::FeatureType - :project: StochTree - \ No newline at end of file diff --git a/cpp_docs/index.rst b/cpp_docs/index.rst deleted file mode 100644 index 10396d88..00000000 --- a/cpp_docs/index.rst +++ /dev/null @@ -1,10 +0,0 @@ -StochTree C++ API and Implementations -===================================== - -This page documents the data structures and interfaces that constitute the ``stochtree`` C++ core. -It may be useful to researchers building novel tree algorithms or users seeking a deeper understanding of the algorithms implemented in ``stochtree``. - -.. toctree:: - dataset - tracking - tree diff --git a/cpp_docs/make.bat b/cpp_docs/make.bat deleted file mode 100644 index dc1312ab..00000000 --- a/cpp_docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/cpp_docs/requirements.txt b/cpp_docs/requirements.txt deleted file mode 100644 index 44bfd94d..00000000 --- a/cpp_docs/requirements.txt +++ /dev/null @@ -1,39 +0,0 @@ -alabaster==0.7.13 -Babel==2.15.0 -beautifulsoup4==4.12.3 -breathe==4.35.0 -certifi==2024.2.2 -charset-normalizer==3.3.2 -docutils==0.20.1 -furo==2024.5.6 -idna==3.7 -imagesize==1.4.1 -importlib_metadata==7.1.0 -Jinja2==3.1.4 -joblib==1.4.2 -MarkupSafe==2.1.5 -numpy==1.24.4 -packaging==24.0 -pandas==2.0.3 -pybind11==2.12.0 -Pygments==2.18.0 -python-dateutil==2.9.0.post0 -pytz==2024.1 -requests==2.32.2 -scikit-learn==1.3.2 -scipy==1.10.1 -six==1.16.0 -snowballstemmer==2.2.0 -soupsieve==2.5 -Sphinx==7.1.2 -sphinx-basic-ng==1.0.0b2 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.5 -threadpoolctl==3.5.0 -tzdata==2024.1 -urllib3==2.2.1 -zipp==3.18.2 diff --git a/cpp_docs/tracking.rst b/cpp_docs/tracking.rst deleted file mode 100644 index 57f22f99..00000000 --- a/cpp_docs/tracking.rst +++ /dev/null @@ -1,36 +0,0 @@ -Forest Sampling Tracker API -=========================== - -A truly minimalist tree ensemble library only needs - -* A representation of a decision tree -* A container for grouping / storing ensembles of trees -* In-memory access to / representation of training data -* Routines / functions to construct the trees - -Most algorithms for optimizing or sampling tree ensembles frequently perform the following operations - -* Determine which leaf a training observation falls into for a decision tree (to compute its prediction and update the residual / outcome) -* Evaluate potential split candidates for a leaf of a decision - -With only the "minimalist" tools above, these two tasks proceed largely as follows - -* For every observation in the dataset, traverse the tree (runtime depends on the tree topology but in a fully 
balanced tree with :math:`k` nodes, this has time complexity :math:`O(\log (k))`). -* For every observation in the dataset, determine whether an observation falls into a given node and whether or not a proposed decision rule would be true - -These operations both perform unnecessary computation which can be avoided with some additional real-time tracking. Essentially, we want - -1. A mapping from dataset row index to leaf node id for every tree in an ensemble (so that we can skip the tree traversal during prediction) -2. A mapping from leaf node id to dataset row indices every tree in an ensemble (so that we can skip the full pass through the training data at split evaluation) - -.. 1. For every observation in a dataset, which leaf node of each tree does the sample fall into? -.. 2. For every leaf in a tree, which training set observations fall into that node? - -Forest Tracker --------------- - -The ``ForestTracker`` class is a wrapper around several implementations of the mappings discussed above. - -.. doxygenclass:: StochTree::ForestTracker - :project: StochTree - :members: diff --git a/cpp_docs/tree.rst b/cpp_docs/tree.rst deleted file mode 100644 index 3fe13dba..00000000 --- a/cpp_docs/tree.rst +++ /dev/null @@ -1,20 +0,0 @@ -Decision Tree API -================= - -Tree ----- - -The fundamental building block of the C++ tree interface is the ``Tree`` class. - -.. doxygenclass:: StochTree::Tree - :project: StochTree - :members: - -Tree Split ----------- - -Numeric and categorical splits are represented by a ``TreeSplit`` class. - -.. doxygenclass:: StochTree::TreeSplit - :project: StochTree - :members: diff --git a/cran-bootstrap.R b/cran-bootstrap.R index 295abc7c..4615ee27 100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -9,6 +9,32 @@ # https://github.com/microsoft/LightGBM/blob/master/build-cran-package.sh, # which is MIT licensed with the following copyright: # Copyright (c) Microsoft Corporation +# +# Includes one command line argument: +# include_vignettes : 1 to include the vignettes folder in the R package subfolder +# 0 to exclude vignettes +# +# Run this script from the command line via +# +# Explicitly include vignettes +# ---------------------------- +# Rscript cran-bootstrap.R 1 +# +# Explicitly exclude vignettes +# ---------------------------- +# Rscript cran-bootstrap.R 0 +# +# Exclude vignettes by default +# ---------------------------- +# Rscript cran-bootstrap.R + +# Unpack command line arguments +args <- commandArgs(trailingOnly = T) +if (length(args) > 0){ + include_vignettes <- as.logical(as.integer(args[1])) +} else{ + include_vignettes <- F +} # Create the stochtree_cran folder cran_dir <- "stochtree_cran" @@ -27,9 +53,13 @@ pkg_core_files <- c( list.files("man", recursive = TRUE, full.names = TRUE), "NAMESPACE", list.files("R", recursive = TRUE, full.names = TRUE), - r_src_files, - list.files("vignettes", pattern = ".(Rmd|bib)$", recursive = TRUE, full.names = TRUE) + r_src_files ) +if (include_vignettes) { + pkg_core_files <- c( + pkg_core_files, list.files("vignettes", pattern = ".(Rmd|bib)$", recursive = TRUE, full.names = TRUE) + ) +} pkg_core_files_dst <- file.path(cran_dir, pkg_core_files) # Handle tests separately (move from test/R/ folder to tests/ folder) test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE) @@ -69,6 +99,16 @@ makevars_lines <- readLines(cran_makevars) makevars_lines[grep("^(PKG_CPPFLAGS)", makevars_lines)] <- "PKG_CPPFLAGS= -I$(PKGROOT)/src/include $(STOCHTREE_CPPFLAGS)" writeLines(makevars_lines, 
cran_makevars) +# Remove vignette deps from DESCRIPTION if no vignettes +if (!include_vignettes) { + cran_description <- file.path(cran_dir, "DESCRIPTION") + description_lines <- readLines(cran_description) + suggestion_begin <- grep("Suggests:", description_lines) + 2 + suggestion_end <- grep("VignetteBuilder:", description_lines) + description_lines <- description_lines[-(suggestion_begin:suggestion_end)] + writeLines(description_lines, cran_description) +} + # Copy fast_double_parser header to an include/ subdirectory of src/ header_folders <- c("nlohmann", "stochtree") header_files_to_vendor_src <- c() diff --git a/include/stochtree/category_tracker.h b/include/stochtree/category_tracker.h index 04ee0b0a..e5817419 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -144,20 +144,11 @@ class CategorySampleTracker { /*! \brief Data indices for a given node */ std::vector<data_size_t>& NodeIndices(int category_id) { int32_t id = category_id_map_[category_id]; - // std::vector<data_size_t>::iterator start = indices_.begin() + category_begin_[id]; - // std::vector<data_size_t>::iterator end = indices_.begin() + category_begin_[id] + category_length_[id]; - // std::vector<data_size_t> output(start, end); - // return output; return node_index_vector_[id]; } /*! \brief Data indices for a given node */ std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) { -// int32_t id = category_id_map_[category_id]; - // std::vector<data_size_t>::iterator start = indices_.begin() + category_begin_[id]; - // std::vector<data_size_t>::iterator end = indices_.begin() + category_begin_[id] + category_length_[id]; - // std::vector<data_size_t> output(start, end); - // return output; return node_index_vector_[internal_category_id]; } diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 13fa098c..e5a3b8fe 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -20,19 +20,104 @@ namespace StochTree { +/*! + * \brief Container of `TreeEnsemble` forest objects. This is the primary (in-memory) storage interface for multiple + * "samples" of a decision tree ensemble in `stochtree`. + * \ingroup forest_group + */ class ForestContainer { public: + /*! + * \brief Construct a new ForestContainer object. + * + * \param num_trees Number of trees in each forest. + * \param output_dimension Dimension of the leaf node parameter in each tree of each forest. + * \param is_leaf_constant Whether or not the leaves of each tree are treated as "constant." If true, then predicting from an ensemble is simply a matter of determining which leaf node an observation falls into. If false, prediction will multiply a leaf node's parameter(s) for a given observation by a basis vector. + * \param is_exponentiated Whether or not the leaves of each tree are stored in log scale. If true, leaf predictions are exponentiated before their prediction is returned. + */ ForestContainer(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false); + /*! + * \brief Construct a new ForestContainer object. + * + * \param num_samples Initial size of a container of forest samples. + * \param num_trees Number of trees in each forest. + * \param output_dimension Dimension of the leaf node parameter in each tree of each forest. + * \param is_leaf_constant Whether or not the leaves of each tree are treated as "constant." If true, then predicting from an ensemble is simply a matter of determining which leaf node an observation falls into.
If false, prediction will multiply a leaf node's parameter(s) for a given observation by a basis vector. + * \param is_exponentiated Whether or not the leaves of each tree are stored in log scale. If true, leaf predictions are exponentiated before their prediction is returned. + */ ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false); ~ForestContainer() {} - + /*! + * \brief Remove a forest from a container of forest samples and delete the corresponding object, freeing its memory. + * + * \param sample_num Index of forest to be deleted. + */ void DeleteSample(int sample_num); + /*! + * \brief Add a new forest to the container by copying `forest`. + * + * \param forest Forest to be copied and added to the container of retained forest samples. + */ void AddSample(TreeEnsemble& forest); + /*! + * \brief Initialize a "root" forest of univariate trees as the first element of the container, setting all root node values in every tree to `leaf_value`. + * + * \param leaf_value Value to assign to the root node of every tree. + */ void InitializeRoot(double leaf_value); + /*! + * \brief Initialize a "root" forest of multivariate trees as the first element of the container, setting all root node values in every tree to `leaf_vector`. + * + * \param leaf_vector Vector of values to assign to the root node of every tree. + */ void InitializeRoot(std::vector<double>& leaf_vector); + /*! + * \brief Pre-allocate space for `num_samples` additional forests in the container. + * + * \param num_samples Number of (default-constructed) forests to allocate space for in the container. + */ void AddSamples(int num_samples); + /*! + * \brief Copy the forest stored at `previous_sample_id` to the forest stored at `new_sample_id`. + * + * \param new_sample_id Index of the new forest to be copied from an earlier sample. + * \param previous_sample_id Index of the previous forest to copy to `new_sample_id`. + */ void CopyFromPreviousSample(int new_sample_id, int previous_sample_id); + /*! + * \brief Predict from every forest in the container on every observation in the provided dataset. + * The resulting vector is "column-major", where every forest in a container defines the columns of a + * prediction matrix and every observation in the provided dataset defines the rows. The (`i`,`j`) element + * of this prediction matrix can be read from the `j * num_rows + i` element of the returned `std::vector`, + * where `num_rows` is equal to the number of observations in `dataset` (i.e. `dataset.NumObservations()`).
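+ * + * For example (an illustrative sketch; `forest_samples` is assumed to be an existing `ForestContainer` and `dataset` an existing `ForestDataset`): + * + * \code{.cpp} + * std::vector<double> preds = forest_samples.Predict(dataset); + * data_size_t num_rows = dataset.NumObservations(); + * data_size_t i = 10; // observation (row) index + * int j = 3; // forest sample (column) index + * double pred_ij = preds[j * num_rows + i]; // prediction for observation i under forest sample j + * \endcode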
+ * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights. + * \return std::vector<double> Vector of predictions for every forest in the container and every observation in `dataset`. + */ std::vector<double> Predict(ForestDataset& dataset); + /*! + * \brief Predict from every forest in the container on every observation in the provided dataset. + * The resulting vector stores a possibly three-dimensional array, where the dimensions are arranged as follows + * + * 1. Dimension of the leaf node's raw values (1 for GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, and LogLinearVarianceLeafModel, >1 for GaussianMultivariateRegressionLeafModel) + * 2. Observations in the provided dataset. + * 3. Forest samples in the container. + * + * If the leaf nodes have univariate values, then the "first dimension" is 1 and the resulting array has the exact same layout as in \ref Predict. + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights. + * \return std::vector<double> Vector of predictions for every forest in the container and every observation in `dataset`. + */ std::vector<double> PredictRaw(ForestDataset& dataset); std::vector<double> PredictRaw(ForestDataset& dataset, int forest_num); std::vector<double> PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num); diff --git a/include/stochtree/data.h b/include/stochtree/data.h index fa2d9494..c3bdb077 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -13,7 +13,25 @@ namespace StochTree { -/*! \brief Extract local features from memory */ +/*! + * \defgroup data_group Dataset API + * + * \brief Functions for loading, using, and modifying training / test data for forest samplers. + * + * \{ + */ + +/*! + * \brief Extract multiple features from the raw data loaded from a file into an `Eigen::MatrixXd`. + * Lightly modified from LightGBM's datasetloader interface to support `stochtree`'s use cases. + * \internal + * + * \param text_data Vector of data rows, read as strings from a file. + * \param parser Pointer to a parser object (i.e. `CSVParser`). + * \param column_indices Integer labels of columns to be extracted from `text_data` into `data`. + * \param data Eigen matrix into which `text_data` will be parsed and unpacked. + * \param num_rows Number of observations in the data being loaded. + */ static inline void ExtractMultipleFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, std::vector<int32_t>& column_indices, Eigen::MatrixXd& data, data_size_t num_rows) { @@ -45,7 +63,17 @@ static inline void ExtractMultipleFeaturesFromMemory(std::vector<std::string>* t text_data->clear(); } -/*! \brief Extract local features from memory */ +/*! + * \brief Extract a single feature from the raw data loaded from a file into an `Eigen::VectorXd`. + * Lightly modified from LightGBM's datasetloader interface to support `stochtree`'s use cases. + * \internal + * + * \param text_data Vector of data rows, read as strings from a file. + * \param parser Pointer to a parser object (i.e. `CSVParser`). + * \param column_index Integer label of the column to be extracted from `text_data` into `data`. + * \param data Eigen vector into which `text_data` will be parsed and unpacked. + * \param num_rows Number of observations in the data being loaded. + */ static inline void ExtractSingleFeatureFromMemory(std::vector<std::string>* text_data, const Parser* parser, int32_t column_index, Eigen::VectorXd& data, data_size_t num_rows) { std::vector<std::pair<int, double>> oneline_features; @@ -99,42 +127,156 @@ static inline std::vector Str2FeatureVec(const char* parameters) { return feature_vec; } +/*! + * \brief Internal wrapper around `Eigen::MatrixXd` interface for multidimensional floating point data. + */ class ColumnMatrix { public: ColumnMatrix() {} + /*! + * \brief Construct a new `ColumnMatrix` object from in-memory data buffer. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a matrix. + * \param num_row Number of rows in the matrix. + * \param num_col Number of columns / covariates in the matrix. + * \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion. + */ ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); + /*! + * \brief Construct a new `ColumnMatrix` object from a CSV file. + * + * \param filename Name of the file (including any necessary path prefixes).
+ * \param column_index_string Comma-delimited string listing columns to extract into covariates matrix. + * \param header Whether or not the file contains a header of column names / non-data. + * \param precise_float_parser Whether floating point numbers in the CSV should be parsed precisely. + */ ColumnMatrix(std::string filename, std::string column_index_string, bool header = true, bool precise_float_parser = false); ~ColumnMatrix() {} + /*! + * \brief Returns the value stored at (`row_num`, `col_num`) in the object's internal `Eigen::MatrixXd`. + * + * \param row_num Row number to query in the matrix + * \param col_num Column number to query in the matrix + */ double GetElement(data_size_t row_num, int32_t col_num) {return data_(row_num, col_num);} + /*! + * \brief Update an observation in the object's internal `Eigen::MatrixXd` to a new value. + * + * \param row_num Row number to be overwritten. + * \param col_num Column number to be overwritten. + * \param value New value to write in (`row_num`, `col_num`) in the object's internal `Eigen::MatrixXd`. + */ void SetElement(data_size_t row_num, int32_t col_num, double value) {data_(row_num, col_num) = value;} + /*! + * \brief Update the data in a `ColumnMatrix` object from an in-memory data buffer. This will erase the existing matrix. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a matrix. + * \param num_row Number of rows in the matrix. + * \param num_col Number of columns / covariates in the matrix. + * \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion. + */ void LoadData(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); + /*! \brief Number of rows in the object's internal `Eigen::MatrixXd`. */ inline data_size_t NumRows() {return data_.rows();} + /*! \brief Number of columns in the object's internal `Eigen::MatrixXd`. */ inline int NumCols() {return data_.cols();} + /*! \brief Return a reference to the object's internal `Eigen::MatrixXd`, for interfaces that require a raw matrix. */ inline Eigen::MatrixXd& GetData() {return data_;} private: Eigen::MatrixXd data_; }; +/*! + * \brief Internal wrapper around `Eigen::VectorXd` interface for univariate floating point data. + * The (frequently updated) full / partial residual used in sampling forests is stored internally + * as a `ColumnVector` by the sampling functions (see \ref sampling_group).
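+ * + * A sketch of the residual-update pattern this supports (illustrative only; `residual` is assumed to be a `ColumnVector` with `n` elements and `preds` a `std::vector<double>` holding one tree's predictions for the same observations): + * + * \code{.cpp} + * residual.SubtractFromData(preds.data(), n); // remove one tree's fit from the partial residual + * // ... resample the tree and refresh preds ... + * residual.AddToData(preds.data(), n); // add the updated fit back in + * \endcode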
+ */ class ColumnVector { public: ColumnVector() {} + /*! + * \brief Construct a new `ColumnVector` object from in-memory data buffer. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. + * \param num_row Number of rows / elements in the vector. + */ ColumnVector(double* data_ptr, data_size_t num_row); + /*! + * \brief Construct a new `ColumnVector` object from a CSV file. + * + * \param filename Name of the file (including any necessary path prefixes). + * \param column_index Integer index of the column in `filename` to be unpacked as a vector. + * \param header Whether or not the file contains a header of column names / non-data. + * \param precise_float_parser Whether floating point numbers in the CSV should be parsed precisely. + */ ColumnVector(std::string filename, int32_t column_index, bool header = true, bool precise_float_parser = false); ~ColumnVector() {} - double GetElement(data_size_t row_num) {return data_(row_num);} - void SetElement(data_size_t row_num, double value) {data_(row_num) = value;} + /*! + * \brief Returns the value stored at position `row` in the object's internal `Eigen::VectorXd`. + * + * \param row Row number to query in the vector + */ + double GetElement(data_size_t row) {return data_(row);} + /*! + * \brief Overwrite the value stored at position `row` in the object's internal `Eigen::VectorXd`. + * + * \param row Row number to be overwritten in the vector + * \param value New value to write to element `row` of the object's internal `Eigen::VectorXd`. + */ + void SetElement(data_size_t row, double value) {data_(row) = value;} + /*! + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer. This will erase the existing vector. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. + * \param num_row Number of rows / elements in the vector. + */ void LoadData(double* data_ptr, data_size_t num_row); + /*! + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by adding each value obtained + * in `data_ptr` to the existing values in the object's internal `Eigen::VectorXd`. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. + * \param num_row Number of rows / elements in the vector. + */ void AddToData(double* data_ptr, data_size_t num_row); + /*! + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by subtracting each value obtained + * in `data_ptr` from the existing values in the object's internal `Eigen::VectorXd`. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. + * \param num_row Number of rows / elements in the vector. + */ void SubtractFromData(double* data_ptr, data_size_t num_row); + /*! + * \brief Update the data in a `ColumnVector` object from an in-memory data buffer, by substituting each value obtained + * in `data_ptr` for the existing values in the object's internal `Eigen::VectorXd`. + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a vector. + * \param num_row Number of rows / elements in the vector. + */ void OverwriteData(double* data_ptr, data_size_t num_row); + /*! \brief Number of rows in the object's internal `Eigen::VectorXd`. */ inline data_size_t NumRows() {return data_.size();} + /*! \brief Return a reference to the object's internal `Eigen::VectorXd`, for interfaces that require a raw vector. */ inline Eigen::VectorXd& GetData() {return data_;} private: Eigen::VectorXd data_; void UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op); }; -/*! \brief API for loading and accessing data used to sample tree ensembles */ +/*! + * \brief API for loading and accessing data used to sample tree ensembles. + * The covariates / bases / weights used in sampling forests are stored internally + * as a `ForestDataset` by the sampling functions (see \ref sampling_group). + */ class ForestDataset { public: /*! \brief Default constructor. No data is loaded at construction time. */ @@ -303,6 +445,27 @@ class ForestDataset { var_weights_.SetElement(i, temp_value); } } + /*! + * \brief Update an observation in the internal covariate matrix to a new value + * + * \param row_id Row number to be overwritten in the covariate matrix + * \param col Column number to be overwritten in the covariate matrix + * \param new_value New covariate value + */ + void SetCovariateValue(data_size_t row_id, int col, double new_value) { + covariates_.SetElement(row_id, col, new_value); + } + /*!
+ * \brief Update an observation in the internal basis matrix to a new value + * + * \param row_id Row number to be overwritten in the basis matrix + * \param col Column number to be overwritten in the basis matrix + * \param new_value New basis value + */ + void SetBasisValue(data_size_t row_id, int col, double new_value) { + CHECK(has_basis_); + basis_.SetElement(row_id, col, new_value); + } /*! * \brief Update an observation in the internal variance weight vector to a new value * @@ -419,6 +582,8 @@ class RandomEffectsDataset { bool has_group_labels_{false}; }; +/*! \} */ // end of data_group + } // namespace StochTree #endif // STOCHTREE_DATA_H_ diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 5f4330d3..8ec37fbd 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -24,8 +24,26 @@ using json = nlohmann::json; namespace StochTree { +/*! + * \defgroup forest_group Forest API + * + * \brief Classes / functions for creating and modifying forests (i.e. ensembles of trees). + * + * \{ + */ + +/*! \brief Class storing a "forest," or an ensemble of decision trees. + */ class TreeEnsemble { public: + /*! + * \brief Initialize a new TreeEnsemble + * + * \param num_trees Number of trees in a forest + * \param output_dimension Dimension of the leaf node parameter + * \param is_leaf_constant Whether or not the leaves of each tree are treated as "constant." If true, then predicting from an ensemble is simply a matter of determining which leaf node an observation falls into. If false, prediction will multiply a leaf node's parameter(s) for a given observation by a basis vector. + * \param is_exponentiated Whether or not the leaves of each tree are stored in log scale. If true, leaf predictions are exponentiated before their prediction is returned. + */ TreeEnsemble(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { // Initialize trees in the ensemble trees_ = std::vector<std::unique_ptr<Tree>>(num_trees); @@ -39,6 +57,12 @@ class TreeEnsemble { is_leaf_constant_ = is_leaf_constant; is_exponentiated_ = is_exponentiated; } + + /*! + * \brief Initialize an ensemble based on the state of an existing ensemble + * + * \param ensemble `TreeEnsemble` used to initialize the current ensemble + */ TreeEnsemble(TreeEnsemble& ensemble) { // Unpack ensemble configurations num_trees_ = ensemble.num_trees_; @@ -56,31 +80,64 @@ class TreeEnsemble { this->CloneFromExistingTree(j, tree); } } + ~TreeEnsemble() {} + /*! + * \brief Return a pointer to a tree in the forest + * + * \param i Index (0-based) of a tree to be queried + * \return Tree* + */ inline Tree* GetTree(int i) { return trees_[i].get(); } + /*! + * \brief Reset a `TreeEnsemble` to all single-node "root" trees + */ inline void ResetRoot() { for (int i = 0; i < num_trees_; i++) { ResetInitTree(i); } } + /*! + * \brief Reset a single tree in an ensemble + * \todo Consider refactoring this and `ResetInitTree` + * + * \param i Index (0-based) of the tree to be reset + */ inline void ResetTree(int i) { trees_[i].reset(new Tree()); } + /*! + * \brief Reset a single tree in an ensemble + * \todo Consider refactoring this and `ResetTree` + * + * \param i Index (0-based) of the tree to be reset + */ inline void ResetInitTree(int i) { trees_[i].reset(new Tree()); trees_[i]->Init(output_dimension_, is_exponentiated_); } + /*!
+ * \brief Clone a single tree in an ensemble from an existing tree, overwriting current tree + * + * \param i Index of the tree to be overwritten + * \param tree Pointer to tree used to clone tree `i` + */ inline void CloneFromExistingTree(int i, Tree* tree) { return trees_[i]->CloneFromTree(tree); } + /*! + * \brief Reset an ensemble to clone another ensemble + * + * \param ensemble Reference to an existing `TreeEnsemble` + */ inline void ReconstituteFromForest(TreeEnsemble& ensemble) { // Delete old tree pointers trees_.clear(); @@ -462,6 +519,8 @@ class TreeEnsemble { bool is_exponentiated_; }; +/*! \} */ // end of forest_group + } // namespace StochTree #endif // STOCHTREE_ENSEMBLE_H_ diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index e0f0e50f..5711116d 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -23,6 +23,331 @@ namespace StochTree { +/*! + * \defgroup leaf_model_group Leaf Model API + * + * \brief Classes / functions for implementing leaf models. + * + * Stochastic tree algorithms are all essentially hierarchical + * models with an adaptive group structure defined by an ensemble + * of decision trees. Each novel model is governed by + * + * - A `LeafModel` class, defining the integrated likelihood and posterior, conditional on a particular tree structure + * - A `SuffStat` class that tracks and accumulates sufficient statistics necessary for a `LeafModel` + * + * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation. + * Any forest-based regression model involves an outcome, which we'll call \f$y\f$, and features (or "covariates"), which we'll call \f$X\f$. + * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. + * + * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model + * \f$y - X\beta = f(X)\f$, treating the residual \f$y - X\beta\f$ as the outcome data, and we are back to the general setting above. + * + * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, + * where \f$b\f$ is the number of trees in an ensemble, so that + * + * \f[ + * f(X) = f_1(X) + \dots + f_b(X) + * \f] + * + * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation + * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process + * are model-dependent, so now we introduce the "leaf node" models that `stochtree` supports. + * + * \section gaussian_constant_leaf_model Gaussian Constant Leaf Model + * + * The most standard and common tree ensemble is a sum of "constant leaf" trees, in which a leaf node's parameter uniquely determines the prediction + * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then + * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. 
Mathematically, + * for an observation \f$i\f$ this looks like + * + * \f[ + * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \mu_{\ell} + * \f] + * + * where \f$L\f$ denotes the indices of every leaf node, \f$\mu_{\ell}\f$ is the parameter attached to leaf node \f$\ell\f$, and \f$\mathbb{1}(X_i \in \ell)\f$ + * checks whether \f$X_i\f$ falls into leaf node \f$\ell\f$. + * + * The way that we make such a model "stochastic" is by attaching to the leaf node parameters \f$\mu_{\ell}\f$ a "prior" distribution. + * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) + * as well as its "XBART" extension (He and Hahn (2023)). + * We assign each leaf node parameter a prior + * + * \f[ + * \mu \sim \mathcal{N}\left(0, \tau\right) + * \f] + * + * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim \mathcal{N}\left(f(X_i),\sigma^2\right)\f$), + * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$, is given by + * + * \f[ + * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - n_{\ell}\log(\sigma) + \frac{1}{2} \log\left(\frac{\sigma^2}{n_{\ell} \tau + \sigma^2}\right) - \frac{s_{yy,\ell}}{2\sigma^2} + \frac{\tau s_{y,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} + * \f] + * + * where + * + * \f[ + * n_{\ell} = \sum_{i : X_i \in \ell} 1 + * \f] + * + * \f[ + * s_{y,\ell} = \sum_{i : X_i \in \ell} r_i + * \f] + * + * \f[ + * s_{yy,\ell} = \sum_{i : X_i \in \ell} r_i^2 + * \f] + * + * \f[ + * r_i = y_i - \sum_{k \neq j} f_k(X_i) + * \f] + * + * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, + * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for + * node \f$\ell\f$'s leaf parameter is similarly defined as: + * + * \f[ + * \mu_{\ell} \mid - \sim \mathcal{N}\left(\frac{\tau s_{y,\ell}}{n_{\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{n_{\ell} \tau + \sigma^2}\right) + * \f] + * + * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or + * individual-level variances ("heteroskedasticity"). These case weights transform the outcome distribution (and associated likelihood) to + * + * \f[ + * y_i \mid - \sim \mathcal{N}\left(f(X_i), \frac{\sigma^2}{w_i}\right) + * \f] + * + * This gives a modified log marginal likelihood of + * + * \f[ + * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - \frac{1}{2} \sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right) + \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) - \frac{s_{wyy,\ell}}{2\sigma^2} + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)} + * \f] + * + * where + * + * \f[ + * s_{w,\ell} = \sum_{i : X_i \in \ell} w_i + * \f] + * + * \f[ + * s_{wy,\ell} = \sum_{i : X_i \in \ell} w_i r_i + * \f] + * + * \f[ + * s_{wyy,\ell} = \sum_{i : X_i \in \ell} w_i r_i^2 + * \f] + * + * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, + * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values.
+ * + * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$ + * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored + * when evaluating splits or prunes and thus the reduced log marginal likelihood is + * + * \f[ + * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)} + * \f] + * + * So the \ref StochTree::GaussianConstantSuffStat "GaussianConstantSuffStat" class tracks a generalized version of these three statistics + * (which allows for each observation to have a weight \f$w_i \neq 1\f$): + * + * - \f$n_{\ell}\f$: `data_size_t n` + * - \f$s_{w,\ell}\f$: `double sum_w` + * - \f$s_{wy,\ell}\f$: `double sum_yw` + * + * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the + * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods. + * To give one example, below is the implementation of \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood": + * + * \code{.cpp} + * double left_log_ml = ( + * -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) + * ); + * + * double right_log_ml = ( + * -0.5*std::log(1 + tau_*(right_stat.sum_w/global_variance)) + ((tau_*right_stat.sum_yw*right_stat.sum_yw)/(2.0*global_variance*(tau_*right_stat.sum_w + global_variance))) + * ); + * + * return left_log_ml + right_log_ml; + * \endcode + * + * \section gaussian_multivariate_regression_leaf_model Gaussian Multivariate Regression Leaf Model + * + * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights + * that are multiplied by a "basis" \f$\Omega\f$ to determine the prediction for an observation. + * + * \f[ + * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \Omega_i \vec{\beta_{\ell}} + * \f] + * + * and we assign \f$\beta_{\ell}\f$ a prior of + * + * \f[ + * \vec{\beta_{\ell}} \sim \mathcal{N}\left(\vec{\beta_0}, \Sigma_0\right) + * \f] + * + * where \f$\vec{\beta_0}\f$ is typically a vector of zeros. The outcome likelihood is still + * + * \f[ + * y_i \sim \mathcal{N}\left(f(X_i), \sigma^2\right) + * \f] + * + * This gives a reduced log integrated likelihood of + * + * \f[ + * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \frac{\Sigma_0\Omega'\Omega}{\sigma^2}\right)\right) + \frac{1}{2}\frac{y'\Omega}{\sigma^2}\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\frac{\Omega'y}{\sigma^2} + * \f] + * + * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the number of columns (i.e. the basis dimension) of \f$\Omega\f$.
The posterior for \f$\vec{\beta_{\ell}}\f$ is + * + * \f[ + * \vec{\beta_{\ell}} \sim \mathcal{N}\left(\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\left(\frac{\Omega'y}{\sigma^2}\right),\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\right) + * \f] + * + * This is an extension of the single-tree model of Chipman et al (2002), with: + * + * - Support for using a different basis for the leaf model than for the partitioning (i.e. tree) model (i.e. \f$X \neq \Omega\f$) + * - Support for multiple trees and sampling via grow-from-root (GFR) or MCMC + * + * We can also enable heteroskedasticity by defining a (diagonal) covariance matrix for the outcome likelihood + * + * \f[ + * \Sigma_y = \text{diag}\left(\sigma^2 / w_1,\sigma^2 / w_2,\dots,\sigma^2 / w_n\right) + * \f] + * + * This updates the reduced log integrated likelihood to + * + * \f[ + * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \Sigma_{0}\Omega'\Sigma_y^{-1}\Omega\right)\right) + \frac{1}{2}y'\Sigma_{y}^{-1}\Omega\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\Omega'\Sigma_{y}^{-1}y + * \f] + * + * and a posterior for \f$\vec{\beta_{\ell}}\f$ of + * + * \f[ + * \vec{\beta_{\ell}} \sim \mathcal{N}\left(\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\left(\Omega'\Sigma_{y}^{-1}y\right),\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\right) + * \f] + * + * \section gaussian_univariate_regression_leaf_model Gaussian Univariate Regression Leaf Model + * + * This specializes the Gaussian Multivariate Regression Leaf Model for a univariate leaf basis, which allows for several computational speedups (replacing generalized matrix operations with simple summation or sum-product operations). + * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i : X_i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i : X_i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf + * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: + * + * \f[ + * \beta \sim \mathcal{N}\left(0, \tau\right) + * \f] + * + * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of + * + * \f[ + * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wyx,\ell}^2}{2\sigma^2(s_{wxx,\ell} \tau + \sigma^2)} + * \f] + * + * where + * + * \f[ + * s_{wxx,\ell} = \sum_{i : X_i \in \ell} w_i \omega_i^2 + * \f] + * + * \f[ + * s_{wyx,\ell} = \sum_{i : X_i \in \ell} w_i r_i \omega_i + * \f] + * + * and a posterior of + * + * \f[ + * \beta_{\ell} \mid - \sim \mathcal{N}\left(\frac{\tau s_{wyx,\ell}}{s_{wxx,\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) + * \f] + * + * \section inverse_gamma_leaf_model Inverse Gamma Leaf Model + * + * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model.
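+ * For instance, mirroring the constant leaf model code shown earlier, the reduced log marginal likelihood of the univariate regression model could be computed from its two sufficient statistics roughly as follows (a sketch rather than the library's exact implementation, assuming the `sum_xxw` and `sum_yxw` members of `GaussianUnivariateRegressionSuffStat` introduced below): + * + * \code{.cpp} + * // Equivalent to 0.5*log(sigma^2/(s_wxx*tau + sigma^2)) + tau*s_wyx^2/(2*sigma^2*(s_wxx*tau + sigma^2)) + * double log_ml = ( + * -0.5*std::log(1 + tau_*(suff_stat.sum_xxw/global_variance)) + ((tau_*suff_stat.sum_yxw*suff_stat.sum_yxw)/(2.0*global_variance*(tau_*suff_stat.sum_xxw + global_variance))) + * ); + * \endcode + *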
+ * The inverse gamma leaf model allows for forest-based heteroskedasticity modeling using an inverse gamma prior on the exponentiated leaf parameter, as discussed in Murray (2021). + * Define a variance function based on an ensemble of \f$b\f$ trees as + * + * \f[ + * \sigma^2(X) = \exp\left(s_1(X) + \dots + s_b(X)\right) + * \f] + * + * where each tree function \f$s_j(X)\f$ is defined as + * + * \f[ + * s_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \lambda_{\ell} + * \f] + * + * We reparameterize \f$\lambda_{\ell} = \log(\mu_{\ell})\f$ and we place an inverse gamma prior on \f$\mu_{\ell}\f$ + * + * \f[ + * \mu_{\ell} \sim \text{IG}\left(a, b\right) + * \f] + * + * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification + * of conjugate Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual + * \f$r_i = y_i - \sum_{k \neq j} f_k(X_i)\f$. + * However, this model is part of a broader class of models with convenient "blocked MCMC" sampling updates (another important example being multinomial classification). + * + * Under an outcome model + * + * \f[ + * y \sim \mathcal{N}\left(f(X), \sigma_0^2 \sigma^2(X)\right) + * \f] + * + * updates to \f$\mu_{\ell}\f$ for a given tree \f$j\f$ are based on a reduced log marginal likelihood of + * + * \f[ + * L(y) \propto a \log (b) - \log \Gamma (a) + \log \Gamma \left(a + \frac{n_{\ell}}{2}\right) - \left(a + \frac{n_{\ell}}{2}\right) \log\left(b + \frac{s_{\sigma,\ell}}{2\sigma_0^2}\right) + * \f] + * + * where + * + * \f[ + * n_{\ell} = \sum_{i : X_i \in \ell} 1 + * \f] + * + * \f[ + * s_{\sigma,\ell} = \sum_{i : X_i \in \ell} \frac{(y_i - f(X_i))^2}{\exp\left(\sum_{k \neq j} s_k(X_i)\right)} + * \f] + * + * and a posterior of + * + * \f[ + * \mu_{\ell} \mid - \sim \text{IG}\left( a + \frac{n_{\ell}}{2} , b + \frac{s_{\sigma,\ell}}{2\sigma_0^2} \right) + * \f] + * + * Thus, as above, we implement a sufficient statistic class (\ref StochTree::LogLinearVarianceSuffStat "LogLinearVarianceSuffStat"), which tracks + * + * - \f$n_{\ell}\f$: `data_size_t n` + * - \f$s_{\sigma,\ell}\f$: `double weighted_sum_ei` + * + * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the + * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. + * To give one example, below is the implementation of \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood": + * + * \code{.cpp} + * double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); + * double a_term = a_ + 0.5 * suff_stat.n; + * double b_term = b_ + ((0.5 * suff_stat.weighted_sum_ei) / global_variance); + * double log_b_term = std::log(b_term); + * double lgamma_a_term = boost::math::lgamma(a_term); + * double resid_term = a_term * log_b_term; + * double log_ml = prior_terms + lgamma_a_term - resid_term; + * return log_ml; + * \endcode + * + * \{ + */ + +/*! \brief Leaf models for the forest sampler: + * 1. `kConstantLeafGaussian`: Every leaf node has a zero-centered univariate normal prior and every leaf is constant.
+ * 2. `kUnivariateRegressionLeafGaussian`: Every leaf node has a zero-centered univariate normal prior and every leaf is a linear model, multiplying the leaf parameter by a (fixed) basis. + * 3. `kMultivariateRegressionLeafGaussian`: Every leaf node has a multivariate normal prior, centered around the zero vector, and every leaf is a linear model, matrix-multiplying the leaf parameters by a (fixed) basis vector. + * 4. `kLogLinearVariance`: Every leaf node has an inverse gamma prior and every leaf is constant. + */ enum ModelType { kConstantLeafGaussian, kUnivariateRegressionLeafGaussian, @@ -36,11 +361,23 @@ class GaussianConstantSuffStat { data_size_t n; double sum_w; double sum_yw; + /*! + * \brief Construct a new GaussianConstantSuffStat object, setting all sufficient statistics to zero + */ GaussianConstantSuffStat() { n = 0; sum_w = 0.0; sum_yw = 0.0; } + /*! + * \brief Accumulate data from observation `row_idx` into the sufficient statistics + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from the tree at index `tree_idx` + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param row_idx Index of the training data observation from which the sufficient statistics should be updated + * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update + */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { @@ -51,27 +388,55 @@ class GaussianConstantSuffStat { sum_yw += outcome(row_idx, 0); } } + /*! + * \brief Reset all of the sufficient statistics to zero + */ void ResetSuffStat() { n = 0; sum_w = 0.0; sum_yw = 0.0; } + /*! + * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void AddSuffStat(GaussianConstantSuffStat& lhs, GaussianConstantSuffStat& rhs) { n = lhs.n + rhs.n; sum_w = lhs.sum_w + rhs.sum_w; sum_yw = lhs.sum_yw + rhs.sum_yw; } + /*! + * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void SubtractSuffStat(GaussianConstantSuffStat& lhs, GaussianConstantSuffStat& rhs) { n = lhs.n - rhs.n; sum_w = lhs.sum_w - rhs.sum_w; sum_yw = lhs.sum_yw - rhs.sum_yw; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ data_size_t SampleSize() { return n; } @@ -80,15 +445,64 @@ /*!
\brief Marginal likelihood and posterior computation for Gaussian homoskedastic constant leaf outcome model */ class GaussianConstantLeafModel { public: + /*! + * \brief Construct a new GaussianConstantLeafModel object + * + * \param tau Leaf node prior scale parameter + */ GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianConstantLeafModel() {} + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + * + * \param left_stat Sufficient statistics of the left node formed by the proposed split + * \param right_stat Sufficient statistics of the right node formed by the proposed split + * \param global_variance Global error variance parameter + */ double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance); + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior mean. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior variance. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance); + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` + * \param tree Tree to be updated + * \param tree_num Integer index of tree to be updated + * \param global_variance Value of the global error variance parameter + * \param gen C++ random number generator + */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); + /*! + * \brief Set a new value for the leaf node scale parameter + * + * \param tau Leaf node prior scale parameter + */ void SetScale(double tau) {tau_ = tau;} + /*! + * \brief Whether this model requires a basis vector for posterior inference and prediction + */ inline bool RequiresBasis() {return false;} private: double tau_; @@ -101,11 +515,23 @@ class GaussianUnivariateRegressionSuffStat { data_size_t n; double sum_xxw; double sum_yxw; + /*! + * \brief Construct a new GaussianUnivariateRegressionSuffStat object, setting all sufficient statistics to zero + */ GaussianUnivariateRegressionSuffStat() { n = 0; sum_xxw = 0.0; sum_yxw = 0.0; } + /*!
+ * \brief Accumulate data from observation `row_idx` into the sufficient statistics + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from the tree at index `tree_idx` + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param row_idx Index of the training data observation from which the sufficient statistics should be updated + * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update + */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { @@ -116,27 +542,55 @@ class GaussianUnivariateRegressionSuffStat { sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0); } } + /*! + * \brief Reset all of the sufficient statistics to zero + */ void ResetSuffStat() { n = 0; sum_xxw = 0.0; sum_yxw = 0.0; } + /*! + * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void AddSuffStat(GaussianUnivariateRegressionSuffStat& lhs, GaussianUnivariateRegressionSuffStat& rhs) { n = lhs.n + rhs.n; sum_xxw = lhs.sum_xxw + rhs.sum_xxw; sum_yxw = lhs.sum_yxw + rhs.sum_yxw; } + /*! + * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void SubtractSuffStat(GaussianUnivariateRegressionSuffStat& lhs, GaussianUnivariateRegressionSuffStat& rhs) { n = lhs.n - rhs.n; sum_xxw = lhs.sum_xxw - rhs.sum_xxw; sum_yxw = lhs.sum_yxw - rhs.sum_yxw; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ data_size_t SampleSize() { return n; } @@ -147,10 +601,46 @@ class GaussianUnivariateRegressionLeafModel { public: GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianUnivariateRegressionLeafModel() {} + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + * + * \param left_stat Sufficient statistics of the left node formed by the proposed split + * \param right_stat Sufficient statistics of the right node formed by the proposed split + * \param global_variance Global error variance parameter + */ double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance); + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split.
+ * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior mean. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior variance. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` + * \param tree Tree to be updated + * \param tree_num Integer index of tree to be updated + * \param global_variance Value of the global error variance parameter + * \param gen C++ random number generator + */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(double tau) {tau_ = tau;} @@ -167,12 +657,26 @@ class GaussianMultivariateRegressionSuffStat { int p; Eigen::MatrixXd XtWX; Eigen::MatrixXd ytWX; + /*! + * \brief Construct a new GaussianMultivariateRegressionSuffStat object + * + * \param basis_dim Size of the basis vector that defines the leaf regression + */ GaussianMultivariateRegressionSuffStat(int basis_dim) { n = 0; XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim); ytWX = Eigen::MatrixXd::Zero(1, basis_dim); p = basis_dim; } + /*! + * \brief Accumulate data from observation `row_idx` into the sufficient statistics + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from the tree at index `tree_idx` + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param row_idx Index of the training data observation from which the sufficient statistics should be updated + * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update + */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; if (dataset.HasVarWeights()) { @@ -183,27 +687,55 @@ ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all))); } } + /*!
+ * \brief Reset all of the sufficient statistics to zero + */ void ResetSuffStat() { n = 0; XtWX = Eigen::MatrixXd::Zero(p, p); ytWX = Eigen::MatrixXd::Zero(1, p); } + /*! + * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void AddSuffStat(GaussianMultivariateRegressionSuffStat& lhs, GaussianMultivariateRegressionSuffStat& rhs) { n = lhs.n + rhs.n; XtWX = lhs.XtWX + rhs.XtWX; ytWX = lhs.ytWX + rhs.ytWX; } + /*! + * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void SubtractSuffStat(GaussianMultivariateRegressionSuffStat& lhs, GaussianMultivariateRegressionSuffStat& rhs) { n = lhs.n - rhs.n; XtWX = lhs.XtWX - rhs.XtWX; ytWX = lhs.ytWX - rhs.ytWX; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ data_size_t SampleSize() { return n; } @@ -212,12 +744,53 @@ /*! \brief Marginal likelihood and posterior computation for Gaussian homoskedastic multivariate regression leaf outcome model */ class GaussianMultivariateRegressionLeafModel { public: + /*! + * \brief Construct a new GaussianMultivariateRegressionLeafModel object + * + * \param Sigma_0 Prior covariance matrix, whose number of rows and columns must match the dimension of the basis vector for the multivariate regression problem + */ GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();} ~GaussianMultivariateRegressionLeafModel() {} + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + * + * \param left_stat Sufficient statistics of the left node formed by the proposed split + * \param right_stat Sufficient statistics of the right node formed by the proposed split + * \param global_variance Global error variance parameter + */ double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance); + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior mean.
+ * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior variance. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` + * \param tree Tree to be updated + * \param tree_num Integer index of tree to be updated + * \param global_variance Value of the global error variance parameter + * \param gen C++ random number generator + */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;} @@ -236,28 +809,65 @@ class LogLinearVarianceSuffStat { n = 0; weighted_sum_ei = 0.0; } + /*! + * \brief Accumulate data from observation `row_idx` into the sufficient statistics + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param outcome Data object containing the "full" residual net of all the model's mean terms + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param row_idx Index of the training data observation from which the sufficient statistics should be updated + * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update + */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx)); } + /*! + * \brief Reset all of the sufficient statistics to zero + */ void ResetSuffStat() { n = 0; weighted_sum_ei = 0.0; } + /*! + * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void AddSuffStat(LogLinearVarianceSuffStat& lhs, LogLinearVarianceSuffStat& rhs) { n = lhs.n + rhs.n; weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei; } + /*!
+ * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ void SubtractSuffStat(LogLinearVarianceSuffStat& lhs, LogLinearVarianceSuffStat& rhs) { n = lhs.n - rhs.n; weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ data_size_t SampleSize() { return n; } @@ -268,11 +878,47 @@ class LogLinearVarianceLeafModel { public: LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();} ~LogLinearVarianceLeafModel() {} + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + * + * \param left_stat Sufficient statistics of the left node formed by the proposed split + * \param right_stat Sufficient statistics of the right node formed by the proposed split + * \param global_variance Global error variance parameter + */ double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance); + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance); double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior shape parameter. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance); + /*! + * \brief Leaf node posterior scale parameter. + * + * \param suff_stat Sufficient statistics of the node being evaluated + * \param global_variance Global error variance parameter + */ double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance); + /*! 
+ * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param residual Data object containing the "full" residual net of all the model's mean terms + * \param tree Tree to be updated + * \param tree_num Integer index of tree to be updated + * \param global_variance Value of the global error variance parameter + * \param gen C++ random number generator + */ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetPriorShape(double a) {a_ = a;} @@ -284,11 +930,27 @@ class LogLinearVarianceLeafModel { GammaSampler gamma_sampler_; }; +/*! + * \brief Unifying layer for disparate sufficient statistic class types + * + * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, + * GaussianMultivariateRegressionSuffStat, and LogLinearVarianceSuffStat + * as a combined "variant" type. See the std::variant documentation + * for more detail. + */ using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat>; +/*! + * \brief Unifying layer for disparate leaf model class types + * + * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, + * GaussianMultivariateRegressionLeafModel, and LogLinearVarianceLeafModel + * as a combined "variant" type. See the std::variant documentation + * for more detail. + */ using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel>; @@ -316,6 +984,15 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di } } +/*! + * \brief Factory function that creates a new `LeafModel` object for the specified model type + * + * \param model_type Enumeration storing the model type + * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian` or `model_type = kUnivariateRegressionLeafGaussian` + * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian` + * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` + * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` + */ static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { if (model_type == kConstantLeafGaussian) { return createLeafModel<GaussianConstantLeafModel>(tau); @@ -323,10 +1000,8 @@ return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau); } else if (model_type == kMultivariateRegressionLeafGaussian) { return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0); - } else if (model_type == kLogLinearVariance) { - return createLeafModel<LogLinearVarianceLeafModel>(a, b); } else { - Log::Fatal("Incompatible model type provided to leaf model factory"); + return createLeafModel<LogLinearVarianceLeafModel>(a, b); } } @@ -423,6 +1098,8 @@ static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, F } } +/*!
\} */ // end of leaf_model_group + } // namespace StochTree #endif // STOCHTREE_LEAF_MODEL_H_ diff --git a/include/stochtree/mainpage.h b/include/stochtree/mainpage.h new file mode 100644 index 00000000..71da0945 --- /dev/null +++ b/include/stochtree/mainpage.h @@ -0,0 +1,118 @@ +#ifndef STOCHTREE_MAINPAGE_H_ +#define STOCHTREE_MAINPAGE_H_ + +/*! + * \mainpage stochtree C++ Documentation + * + * \section getting-started Getting Started + * + * `stochtree` can be built and run as a standalone C++ program directly from source using `cmake`: + * + * \subsection cloning-repo Cloning the Repository + * + * To clone the repository, you must have git installed, which you can do following these instructions. + * + * Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running `cd` followed by the path to the directory). + * Then, clone the `stochtree` repo as a subfolder by running + * \code{.sh} + * git clone --recursive https://github.com/StochasticTree/stochtree.git + * \endcode + * + * NOTE: this project incorporates several dependencies as git submodules, + * which is why the `--recursive` flag is necessary (some systems may perform a recursive clone without this flag, but + * `--recursive` ensures this behavior on all platforms). If you have already cloned the repo without the `--recursive` flag, + * you can retrieve the submodules recursively by running `git submodule update --init --recursive` in the main repo directory. + * + * \section key-components Key Components + * + * The stochtree C++ core consists of thousands of lines of C++ code, but it can be organized and understood through several components (see [topics](topics.html) for more detail): + * + * - Trees: the most important "primitive" of decision tree algorithms is the \ref tree_group "decision tree itself", which in stochtree is defined by a \ref StochTree::Tree "Tree" class as well as a series of static helper functions for prediction. + * - Forest: individual trees are combined into a \ref forest_group "forest", or ensemble, which in stochtree is defined by the \ref StochTree::TreeEnsemble "TreeEnsemble" class and a container of forests is defined by the \ref StochTree::ForestContainer "ForestContainer" class. + * - Dataset: data can be loaded from a variety of sources into a `stochtree` \ref data_group "data layer". + * - Leaf Model: `stochtree`'s data structures are generalized to support a wide range of models, which are defined via specialized classes in the \ref leaf_model_group "leaf model layer". + * - Sampler: helper functions that sample forests from training data comprise the \ref sampling_group "sampling layer" of `stochtree`. + * + * \section extending-stochtree Extending `stochtree` + * + * \subsection custom-leaf-models Custom Leaf Models + * + * The \ref leaf_model_group "leaf model documentation" details the key components of new decision tree models: + * custom `LeafModel` and `SuffStat` classes that implement a model's log marginal likelihood and posterior computations. + * + * Adding a new leaf model will consist largely of implementing new versions of each of these classes that track the + * API of the existing classes. Once these classes exist, they need to be reflected in several places. + * + * Suppose, for the sake of illustration, that the newest custom leaf model is a multinomial logit model.
+ * + * First, add an entry to the \ref StochTree::ModelType "ModelType" enumeration for this new model type + * + * \code{.cpp} + * enum ModelType { + * kConstantLeafGaussian, + * kUnivariateRegressionLeafGaussian, + * kMultivariateRegressionLeafGaussian, + * kLogLinearVariance, + * kMultinomialLogit, + * }; + * \endcode + * + * Next, add entries to the `std::variant` aliases that bundle related `SuffStat` and `LeafModel` classes + * + * \code{.cpp} + * using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, MultinomialLogitSuffStat>; + * \endcode + * + * \code{.cpp} + * using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, MultinomialLogitLeafModel>; + * \endcode + * + * Finally, update the \ref StochTree::suffStatFactory "suffStatFactory" and \ref StochTree::leafModelFactory "leafModelFactory" functions to add a logic branch registering these new objects + * + * \code{.cpp} + * static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) { + * if (model_type == kConstantLeafGaussian) { + * return createSuffStat<GaussianConstantSuffStat>(); + * } else if (model_type == kUnivariateRegressionLeafGaussian) { + * return createSuffStat<GaussianUnivariateRegressionSuffStat>(); + * } else if (model_type == kMultivariateRegressionLeafGaussian) { + * return createSuffStat<GaussianMultivariateRegressionSuffStat>(basis_dim); + * } else if (model_type == kLogLinearVariance) { + * return createSuffStat<LogLinearVarianceSuffStat>(); + * } else if (model_type == kMultinomialLogit) { + * return createSuffStat<MultinomialLogitSuffStat>(); + * } else { + * Log::Fatal("Incompatible model type provided to suff stat factory"); + * } + * } + * \endcode + * + * \code{.cpp} + * static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { + * if (model_type == kConstantLeafGaussian) { + * return createLeafModel<GaussianConstantLeafModel>(tau); + * } else if (model_type == kUnivariateRegressionLeafGaussian) { + * return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau); + * } else if (model_type == kMultivariateRegressionLeafGaussian) { + * return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0); + * } else if (model_type == kLogLinearVariance) { + * return createLeafModel<LogLinearVarianceLeafModel>(a, b); + * } else if (model_type == kMultinomialLogit) { + * return createLeafModel<MultinomialLogitLeafModel>(); + * } else { + * Log::Fatal("Incompatible model type provided to leaf model factory"); + * } + * } + * \endcode + * + */ + +#endif // STOCHTREE_MAINPAGE_H_ diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index b252f9cb..5a820d66 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -55,7 +55,17 @@ enum FeatureSplitType { /*! \brief Forward declaration of TreeSplit class */ class TreeSplit; -/*! \brief API for constructing decision trees (splitting, pruning, setting parameter values) */ +/*! + * \defgroup tree_group Tree API + * + * \brief Classes / functions for creating and modifying decision trees. + * + * \section tree_design Design + * + * \{ + */ + +/*! \brief Decision tree data structure */ class Tree { public: static constexpr std::int32_t kInvalidNodeId{-1}; @@ -961,6 +971,8 @@ class TreeSplit { std::vector<std::uint32_t> split_categories_; }; +/*! \} */ // end of tree_group + } // namespace StochTree #endif // STOCHTREE_TREE_H_ diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index d8db03b9..a47660ea 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -22,6 +22,31 @@ namespace StochTree { +/*! + * \defgroup sampling_group Forest Sampler API + * + * \brief Functions for sampling from a forest.
The core interface of these functions, + * as used by the R, Python, and standalone C++ program, is defined by + * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a + * given forest, and \ref GFRSampleOneIter, which runs one iteration of the + * grow-from-root (GFR) algorithm for a given forest. All other functions are + * essentially helpers used in a sampling function, which are documented here + * to make extending the C++ codebase more straightforward. + * + * \{ + */ + +/*! + * \brief Compute the range of available split values for a continuous variable, given the current structure of a tree. + * + * \param tracker Tracking data structures that speed up sampler operations. + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights. + * \param tree_num Index of the tree for which a split is proposed. + * \param leaf_split Index of the leaf in `tree_num` for which a split is proposed. + * \param feature_split Index of the feature for which we will query the available range. + * \param var_min Current minimum feature value (passed by reference and modified by this function). + * \param var_max Current maximum feature value (passed by reference and modified by this function). + */ static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, int tree_num, int leaf_split, int feature_split, double& var_min, double& var_max) { var_min = std::numeric_limits<double>::max(); var_max = std::numeric_limits<double>::min(); @@ -41,6 +66,17 @@ static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, } } +/*! + * \brief Determines whether a proposed split yields two leaf nodes that each have non-constant values of at least one feature (i.e. two nodes that could themselves be split further). + * + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights. + * \param tracker Tracking data structures that speed up sampler operations. + * \param split Proposed split of tree `tree_num` at node `leaf_split`. + * \param tree_num Index of the tree for which a split is proposed. + * \param leaf_split Index of the leaf in `tree_num` for which a split is proposed. + * \param feature_split Index of the feature to which `split` will be applied + * \return `true` if both nodes created by `split` have non-constant values of at least one feature in `dataset`, `false` otherwise.
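+ * + * For example, a grow move might call this function as a validity check before evaluating the marginal likelihood of a proposed split (a sketch only; it assumes a numeric `TreeSplit` constructed from a cutpoint `split_value` and the surrounding objects named in the parameter list above): + * + * \code{.cpp} + * TreeSplit split(split_value); + * if (!NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_split, feature_split)) { + * // At least one child node would be constant in every feature, so treat the proposal as invalid + * } + * \endcode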
+ */ static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) { int p = dataset.GetCovariates().cols(); data_size_t idx; @@ -409,39 +445,6 @@ static inline std::tuple EvaluateExist return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n); } -// template -// static inline void ModelInitialization(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, -// ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, -// std::mt19937& gen, std::vector& variable_weights, double global_variance, -// bool pre_initialized, bool backfitting, int prev_num_samples, bool var_trees = false) { -// if ((prev_num_samples == 0) && (!pre_initialized)) { -// // Add new forest to the container -// forests.AddSamples(1); - -// // Set initial value for each leaf in the forest -// double leaf_value; -// if (var_trees) { -// leaf_value = std::log(ComputeVarianceOutcome(residual)) / static_cast(forests.NumTrees()); -// } else { -// leaf_value = ComputeMeanOutcome(residual) / static_cast(forests.NumTrees()); -// } -// TreeEnsemble* ensemble = forests.GetEnsemble(0); -// leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, leaf_value); -// tracker.AssignAllSamplesToConstantPrediction(leaf_value); -// } else if (prev_num_samples > 0) { -// // Add new forest to the container -// forests.AddSamples(1); - -// // NOTE: only doing this for the simplicity of the partial residual step -// // We could alternatively "reach back" to the tree predictions from a previous -// // sample (whenever there is more than one sample). This is cleaner / quicker -// // to implement during this refactor. -// forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); -// } else { -// forests.IncrementSampleCount(); -// } -// } - template <typename LeafModel> static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { @@ -729,25 +732,41 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore } } +/*! + * \brief Runs one iteration of the "grow-from-root" (GFR) sampler for a tree ensemble model, which consists of two steps for every tree in a forest: + * 1. Growing a tree by recursively sampling cutpoints via the GFR algorithm + * 2. Sampling leaf node parameters, conditional on an updated tree, via a Gibbs sampler + * + * \tparam LeafModel Leaf model type (i.e. `GaussianConstantLeafModel`, `GaussianUnivariateRegressionLeafModel`, etc...) + * \tparam LeafSuffStat Leaf sufficient statistic type (i.e. `GaussianConstantSuffStat`, `GaussianUnivariateRegressionSuffStat`, etc...) + * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, + * this is `int`, while none of the other three sufficient statistic classes takes a constructor argument. + * \param active_forest Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class + * of stored ensemble samples, because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler). + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state.
+ * \param forests Container of "stored" forests. + * \param leaf_model Leaf model object -- type is determined by template argument `LeafModel`. + * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights. + * \param residual Data object containing residual used in training. The state of `residual` is updated by this function (the prior predictions of `active_forest` are added to the residual and the updated predictions from `active_forest` are subtracted back out). + * \param tree_prior Configuration for tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability). + * \param gen Random number generator for sampler. + * \param variable_weights Vector of selection weights for each variable in `dataset`. + * \param global_variance Current value of (possibly stochastic) global error variance parameter. + * \param feature_types Enum-coded vector of feature types (see \ref FeatureType) for each feature in `dataset`. + * \param cutpoint_grid_size Maximum size of a grid of potential cutpoints (the grow-from-root algorithm evaluates a series of potential cutpoints for each feature and this parameter "thins" the cutpoint candidates for numeric variables). + * \param keep_forest Whether or not `active_forest` should be retained in `forests`. + * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon). + * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via + * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered). + * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. + */ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs> static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { - // // Previous number of samples - // int prev_num_samples = forests.NumSamples(); - - // // Handle any "initialization" of a model (trees, ForestTracker, etc...)
-  // // the model was not pre-initialized
-  // bool var_trees;
-  // if (std::is_same_v<LeafModel, LogLinearVarianceLeafModel>) var_trees = true;
-  // else var_trees = false;
-  // ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen, 
-  //                     variable_weights, global_variance, pre_initialized, backfitting, 
-  //                     prev_num_samples, var_trees);
   // Run the GFR algorithm for each tree
-  // TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
   int num_trees = forests.NumTrees();
   for (int i = 0; i < num_trees; i++) {
     // Adjust any model state needed to run a tree sampler
@@ -813,8 +832,6 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
   // Select a split variable at random
   int p = dataset.GetCovariates().cols();
   CHECK_EQ(variable_weights.size(), p);
-  // std::vector<double> var_weights(p);
-  // std::fill(var_weights.begin(), var_weights.end(), 1.0/p);
   std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end());
   int var_chosen = var_dist(gen);
@@ -1026,31 +1043,43 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For
   }
 }
 
+/*!
+ * \brief Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for every tree in a forest:
+ *   1. Sampling "birth-death" tree modifications via the Metropolis-Hastings algorithm
+ *   2. Sampling leaf node parameters, conditional on a (possibly-updated) tree, via a Gibbs sampler
+ * 
+ * \tparam LeafModel Leaf model type (i.e. `GaussianConstantLeafModel`, `GaussianUnivariateRegressionLeafModel`, etc...)
+ * \tparam LeafSuffStat Leaf sufficient statistic type (i.e. `GaussianConstantSuffStat`, `GaussianUnivariateRegressionSuffStat`, etc...)
+ * \tparam LeafSuffStatConstructorArgs Type of constructor arguments used to initialize the `LeafSuffStat` class. For `GaussianMultivariateRegressionSuffStat`, 
+ *   this is `int`, while the other three sufficient statistic classes do not take a constructor argument.
+ * \param active_forest Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class 
+ *   of stored ensemble samples, because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler).
+ * \param tracker Tracking data structures that speed up sampler operations, kept synchronized with the state of `active_forest`.
+ * \param forests Container of "stored" forests.
+ * \param leaf_model Leaf model object -- type is determined by template argument `LeafModel`.
+ * \param dataset Data object containing training data, including covariates, leaf regression bases, and case weights.
+ * \param residual Data object containing the residual used in training. The state of `residual` is updated by this function (the prior predictions of `active_forest` are added to the residual and the updated predictions from `active_forest` are subtracted back out).
+ * \param tree_prior Configuration for tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability).
+ * \param gen Random number generator for sampler.
+ * \param variable_weights Vector of selection weights for each variable in `dataset`.
+ * \param global_variance Current value of (possibly stochastic) global error variance parameter.
+ * \param keep_forest Whether or not `active_forest` should be retained in `forests`.
+ * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
+ * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via 
+ *   their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
+ * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
+ */
 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, 
                                      ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, 
                                      std::vector<double>& variable_weights, double global_variance, bool keep_forest, 
                                      bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
-  // // Previous number of samples
-  // int prev_num_samples = forests.NumSamples();
-  
-  // // Handle any "initialization" of a model (trees, ForestTracker, etc...) if this is the first sample and
-  // // the model was not pre-initialized
-  // bool var_trees;
-  // if (std::is_same_v<LeafModel, LogLinearVarianceLeafModel>) var_trees = true;
-  // else var_trees = false;
-  // ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen, 
-  //                     variable_weights, global_variance, pre_initialized, backfitting, 
-  //                     prev_num_samples, var_trees);
-  
   // Run the MCMC algorithm for each tree
-  // TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
   int num_trees = forests.NumTrees();
   for (int i = 0; i < num_trees; i++) {
     // Adjust any model state needed to run a tree sampler
     // For models that involve Bayesian backfitting, this amounts to adding tree i's 
     // predictions back to the residual (thus, training a model on the "partial residual")
     // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
-    // Tree* tree = ensemble->GetTree(i);
     Tree* tree = active_forest.GetTree(i);
     AdjustStateBeforeTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
@@ -1077,6 +1106,8 @@ static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker&
   }
 }
 
+/*! \} */ // end of sampling_group
+
 } // namespace StochTree
 
 #endif // STOCHTREE_TREE_SAMPLER_H_
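
The C++ samplers above are surfaced in the R package through the `ForestModel` R6 class, whose regenerated man page follows; the new tree-prior hooks make it possible to adjust `alpha`, `beta`, `min_samples_leaf`, and `max_depth` between sampler iterations. A minimal, hypothetical R sketch (the names `forest_model`, `num_warmup`, and `num_samples` are illustrative only, and the remaining sampler state is assumed to be set up as in the package's low-level sampling interface):

    # Loosen the tree prior after an initial warm-up phase, then keep sampling
    for (i in seq_len(num_samples)) {
      if (i == num_warmup + 1) {
        forest_model$update_alpha(0.95)         # prior split probability at depth 0
        forest_model$update_beta(2)             # depth penalty exponent
        forest_model$update_min_samples_leaf(5)
        forest_model$update_max_depth(-1)       # -1 removes the depth limit
      }
      # ... a call to forest_model$sample_one_iteration(...) would go here ...
    }
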
diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd
index 1d710ac1..f00d7566 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -24,6 +24,10 @@ trees, and exposes functionality to run a forest sampler
 \item \href{#method-ForestModel-sample_one_iteration}{\code{ForestModel$sample_one_iteration()}}
 \item \href{#method-ForestModel-propagate_basis_update}{\code{ForestModel$propagate_basis_update()}}
 \item \href{#method-ForestModel-propagate_residual_update}{\code{ForestModel$propagate_residual_update()}}
+\item \href{#method-ForestModel-update_alpha}{\code{ForestModel$update_alpha()}}
+\item \href{#method-ForestModel-update_beta}{\code{ForestModel$update_beta()}}
+\item \href{#method-ForestModel-update_min_samples_leaf}{\code{ForestModel$update_min_samples_leaf()}}
+\item \href{#method-ForestModel-update_max_depth}{\code{ForestModel$update_max_depth()}}
 }
 }
 \if{html}{\out{<hr>}}
@@ -182,4 +186,84 @@ This function is run after the \code{Outcome} class's \code{update_data} method,
 NULL
 }
 }
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestModel-update_alpha"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestModel-update_alpha}{}}}
+\subsection{Method \code{update_alpha()}}{
+Update alpha in the tree prior
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestModel$update_alpha(alpha)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{alpha}}{New value of alpha to be used}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestModel-update_beta"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestModel-update_beta}{}}}
+\subsection{Method \code{update_beta()}}{
+Update beta in the tree prior
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestModel$update_beta(beta)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{beta}}{New value of beta to be used}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestModel-update_min_samples_leaf"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestModel-update_min_samples_leaf}{}}}
+\subsection{Method \code{update_min_samples_leaf()}}{
+Update min_samples_leaf in the tree prior
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestModel$update_min_samples_leaf(min_samples_leaf)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{min_samples_leaf}}{New value of min_samples_leaf to be used}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestModel-update_max_depth"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestModel-update_max_depth}{}}}
+\subsection{Method \code{update_max_depth()}}{
+Update max_depth in the tree prior
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestModel$update_max_depth(max_depth)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{max_depth}}{New value of max_depth to be used}
+}
+\if{html}{\out{</div>
}} +} +\subsection{Returns}{ +NULL +} +} } diff --git a/man/bart.Rd b/man/bart.Rd index 1e8b8dd5..379c5c6c 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -73,7 +73,7 @@ that were not in the training set.} \item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: \code{100}. \item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. \item \code{sample_sigma2_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(sigma2_global_shape, sigma2_global_scale)}. Default: \code{TRUE}. -\item \code{sigma2_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var((y_train-mean(y_train))/sd(y_train))} if not set. +\item \code{sigma2_global_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var(y_train)}, where \code{y_train} is the possibly standardized outcome, if not set. \item \code{sigma2_global_shape} Shape parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{sigma2_global_scale} Scale parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{random_seed} Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}. @@ -106,7 +106,7 @@ that were not in the training set.} \item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: \code{5}. \item \code{max_depth} Maximum depth of any tree in the ensemble in the variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. \item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. -\item \code{init_root_val} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees} if not set. +\item \code{var_forest_leaf_init} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var(y_train))/num_trees}, where \code{y_train} is the possibly standardized outcome, if not set. \item \code{var_forest_prior_shape} Shape parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2 + 0.5} if not set. \item \code{var_forest_prior_scale} Scale parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2} if not set. }} diff --git a/man/bcf.Rd b/man/bcf.Rd index 653ea440..1aa2c76e 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -75,7 +75,7 @@ that were not in the training set.} \item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: \code{100}. 
\item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. \item \code{sample_sigma2_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(sigma2_global_shape, sigma2_global_scale)}. Default: \code{TRUE}. -\item \code{sigma2_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var((y_train-mean(y_train))/sd(y_train))} if not set. +\item \code{sigma2_global_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var((y_train-mean(y_train))/sd(y_train))} if not set. \item \code{sigma2_global_shape} Shape parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{sigma2_global_scale} Scale parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to \code{1/ncol(X_train)}. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in \code{X_train} and then set \code{propensity_covariate} to \code{'none'} adjust \code{keep_vars_mu}, \code{keep_vars_tau} and \code{keep_vars_variance} accordingly. @@ -94,12 +94,12 @@ that were not in the training set.} \item{mu_forest_params}{(Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{num_trees} Number of trees in the ensemble for the conditional mean model. Default: \code{200}. If \code{num_trees = 0}, the conditional mean will not be modeled using a forest, and the function will only proceed if \code{num_trees > 0} for the variance forest. -\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. -\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. -\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: \code{5}. -\item \code{max_depth} Maximum depth of any tree in the ensemble in the mean model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{num_trees} Number of trees in the ensemble for the prognostic forest. Default: \code{250}. Must be a positive integer. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. 
+\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the ensemble in the prognostic forest. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. \item \code{sample_sigma2_leaf} Whether or not to update the leaf scale variance parameter based on \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: \code{FALSE}. \item \code{sigma2_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here. \item \code{sigma2_leaf_shape} Shape parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Default: \code{3}. @@ -110,12 +110,12 @@ that were not in the training set.} \item{tau_forest_params}{(Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{num_trees} Number of trees in the ensemble for the conditional mean model. Default: \code{200}. If \code{num_trees = 0}, the conditional mean will not be modeled using a forest, and the function will only proceed if \code{num_trees > 0} for the variance forest. -\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. -\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. -\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: \code{5}. -\item \code{max_depth} Maximum depth of any tree in the ensemble in the mean model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{num_trees} Number of trees in the ensemble for the treatment effect forest. Default: \code{50}. Must be a positive integer. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.25}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{3}. 
+\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the ensemble in the treatment effect forest. Default: \code{5}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. \item \code{sample_sigma2_leaf} Whether or not to update the leaf scale variance parameter based on \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: \code{FALSE}. \item \code{sigma2_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here. \item \code{sigma2_leaf_shape} Shape parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Default: \code{3}. @@ -131,8 +131,7 @@ that were not in the training set.} \item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. \item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: \code{5}. \item \code{max_depth} Maximum depth of any tree in the ensemble in the variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. -\item \code{init_root_val} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees} if not set. +\item \code{variance_forest_init} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees} if not set. \item \code{var_forest_prior_shape} Shape parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2 + 0.5} if not set. \item \code{var_forest_prior_scale} Scale parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2} if not set. \item \code{keep_vars} Vector of variable names or column indices denoting variables that should be included in the forest. Default: \code{NULL}. 
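
As the updated defaults above encode, BCF now regularizes the treatment effect forest far more aggressively than the prognostic forest (fewer trees, a smaller \code{alpha}, a larger \code{beta}, and a shallower \code{max_depth}). A quick base-R illustration (not part of the package) of the documented split prior \code{alpha*(1+node_depth)^-beta} under the two sets of defaults:

    # Probability that a node at a given depth splits, under the tree prior
    split_prob <- function(alpha, beta, depth) alpha * (1 + depth)^(-beta)
    depths <- 0:4
    split_prob(0.95, 2, depths)  # prognostic defaults: ~0.95, 0.24, 0.11, 0.06, 0.04
    split_prob(0.25, 3, depths)  # treatment defaults:  ~0.25, 0.03, 0.009, 0.004, 0.002
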
diff --git a/man/convertBCFModelToJson.Rd b/man/convertBCFModelToJson.Rd index 29d1051c..71a7ff73 100644 --- a/man/convertBCFModelToJson.Rd +++ b/man/convertBCFModelToJson.Rd @@ -66,13 +66,15 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +mu_params <- list(sample_sigma_leaf = TRUE) +tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - params = bcf_params) + mu_forest_params = mu_params, + tau_forest_params = tau_params) # bcf_json <- convertBCFModelToJson(bcf_model) } diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index e605b9c9..254dbe74 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -68,14 +68,16 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +mu_params <- list(sample_sigma_leaf = TRUE) +tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - params = bcf_params) + mu_forest_params = mu_params, + tau_forest_params = tau_params) # bcf_json <- convertBCFModelToJson(bcf_model) # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) } diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index ec0075bf..64b81b16 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -68,14 +68,16 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +mu_params <- list(sample_sigma_leaf = TRUE) +tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - params = bcf_params) + mu_forest_params = mu_params, + tau_forest_params = tau_params) # saveBCFModelToJsonFile(bcf_model, "test.json") # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json") } diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 72419b54..14d3084d 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -51,9 +51,9 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train 
= X_train, y_train = y_train, - group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, - X_test = X_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + group_ids_train = group_ids_train, group_ids_test = group_ids_test, + rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcf.Rd b/man/getRandomEffectSamples.bcf.Rd index 4e38df8e..ddfb0300 100644 --- a/man/getRandomEffectSamples.bcf.Rd +++ b/man/getRandomEffectSamples.bcf.Rd @@ -70,13 +70,15 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +mu_params <- list(sample_sigma_leaf = TRUE) +tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - params = bcf_params) + mu_forest_params = mu_params, + tau_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/orderedCatInitializeAndPreprocess.Rd b/man/orderedCatInitializeAndPreprocess.Rd index efd38e3b..7996808c 100644 --- a/man/orderedCatInitializeAndPreprocess.Rd +++ b/man/orderedCatInitializeAndPreprocess.Rd @@ -24,7 +24,8 @@ ordered levels to integers if necessary, and storing the unique levels of a variable. } \examples{ -x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") +x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", + "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") preprocess_list <- orderedCatInitializeAndPreprocess(x) x_preprocessed <- preprocess_list$x_preprocessed } diff --git a/man/orderedCatPreprocess.Rd b/man/orderedCatPreprocess.Rd index 4c5bf54c..12011e6a 100644 --- a/man/orderedCatPreprocess.Rd +++ b/man/orderedCatPreprocess.Rd @@ -28,7 +28,8 @@ ordered levels to integers if necessary, and storing the unique levels of a variable. } \examples{ -x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", +x_levels <- c("1. Strongly disagree", "2. Disagree", + "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") diff --git a/man/preprocessBartParams.Rd b/man/preprocessBartParams.Rd deleted file mode 100644 index 31d29c0b..00000000 --- a/man/preprocessBartParams.Rd +++ /dev/null @@ -1,25 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils.R -\name{preprocessBartParams} -\alias{preprocessBartParams} -\title{Preprocess BART parameter list. 
Override defaults with any provided parameters.} -\usage{ -preprocessBartParams( - general_params, - mean_forest_params, - variance_forest_params -) -} -\arguments{ -\item{general_params}{List of any non-forest-specific parameters} - -\item{mean_forest_params}{List of any mean forest parameters} - -\item{variance_forest_params}{List of any variance forest parameters} -} -\value{ -Parameter list with defaults overriden by values supplied in parameter lists -} -\description{ -Preprocess BART parameter list. Override defaults with any provided parameters. -} diff --git a/man/preprocessBcfParams.Rd b/man/preprocessBcfParams.Rd deleted file mode 100644 index 84b15f74..00000000 --- a/man/preprocessBcfParams.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils.R -\name{preprocessBcfParams} -\alias{preprocessBcfParams} -\title{Preprocess BCF parameter list. Override defaults with any provided parameters.} -\usage{ -preprocessBcfParams(params) -} -\arguments{ -\item{general_params}{List of any non-forest-specific parameters} - -\item{mu_forest_params}{List of any mu forest parameters} - -\item{tau_forest_params}{List of any tau forest parameters} - -\item{variance_forest_params}{List of any variance forest parameters} -} -\value{ -Parameter list with defaults overriden by values supplied in parameter lists -} -\description{ -Preprocess BCF parameter list. Override defaults with any provided parameters. -} diff --git a/man/preprocessParams.Rd b/man/preprocessParams.Rd index 93a808c0..9e1732d4 100644 --- a/man/preprocessParams.Rd +++ b/man/preprocessParams.Rd @@ -4,12 +4,12 @@ \alias{preprocessParams} \title{Preprocess a parameter list, overriding defaults with any provided parameters.} \usage{ -preprocessParams(user_params, default_params) +preprocessParams(default_params, user_params = NULL) } \arguments{ -\item{user_params}{User-supplied overrides to \code{default_params}.} - \item{default_params}{List of parameters with default values set.} + +\item{user_params}{(Optional) User-supplied overrides to \code{default_params}.} } \value{ Parameter list with defaults overriden by values supplied in \code{user_params} diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index 5c6bb6c0..f7685c48 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -65,13 +65,15 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +mu_params <- list(sample_sigma_leaf = TRUE) +tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - params = bcf_params) + mu_forest_params = mu_params, + tau_forest_params = tau_params) # saveBCFModelToJsonFile(bcf_model, "test.json") } diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 63c0d298..be405f5f 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -66,13 +66,15 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- 
rfx_term[train_inds]
-bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
+mu_params <- list(sample_sigma_leaf = TRUE)
+tau_params <- list(sample_sigma_leaf = FALSE)
 bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
                  group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, 
                  X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, 
                  rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, 
-                 params = bcf_params)
+                 mu_forest_params = mu_params, 
+                 tau_forest_params = tau_params)
 # saveBCFModelToJsonString(bcf_model)
 }
diff --git a/man/stochtree-package.Rd b/man/stochtree-package.Rd
index 4496129a..942f0e4a 100644
--- a/man/stochtree-package.Rd
+++ b/man/stochtree-package.Rd
@@ -11,7 +11,7 @@ Stochastic tree ensembles (XBART and BART) for supervised learning and causal in
 \seealso{
 Useful links:
 \itemize{
-  \item \url{https://stochastictree.github.io/stochtree-r/}
+  \item \url{https://stochtree.ai}
 }
 }
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index 9952659f..0091dffd 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -1072,6 +1072,38 @@ extern "C" SEXP _stochtree_tree_prior_cpp(SEXP alpha, SEXP beta, SEXP min_sample
   END_CPP11
 }
 // sampler.cpp
+void update_alpha_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr, double alpha);
+extern "C" SEXP _stochtree_update_alpha_tree_prior_cpp(SEXP tree_prior_ptr, SEXP alpha) {
+  BEGIN_CPP11
+    update_alpha_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(alpha));
+    return R_NilValue;
+  END_CPP11
+}
+// sampler.cpp
+void update_beta_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr, double beta);
+extern "C" SEXP _stochtree_update_beta_tree_prior_cpp(SEXP tree_prior_ptr, SEXP beta) {
+  BEGIN_CPP11
+    update_beta_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(beta));
+    return R_NilValue;
+  END_CPP11
+}
+// sampler.cpp
+void update_min_samples_leaf_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr, int min_samples_leaf);
+extern "C" SEXP _stochtree_update_min_samples_leaf_tree_prior_cpp(SEXP tree_prior_ptr, SEXP min_samples_leaf) {
+  BEGIN_CPP11
+    update_min_samples_leaf_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf));
+    return R_NilValue;
+  END_CPP11
+}
+// sampler.cpp
+void update_max_depth_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr, int max_depth);
+extern "C" SEXP _stochtree_update_max_depth_tree_prior_cpp(SEXP tree_prior_ptr, SEXP max_depth) {
+  BEGIN_CPP11
+    update_max_depth_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth));
+    return R_NilValue;
+  END_CPP11
+}
+// sampler.cpp
 cpp11::external_pointer<StochTree::ForestTracker> forest_tracker_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::integers feature_types, int num_trees, StochTree::data_size_t n);
 extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEXP num_trees, SEXP n) {
   BEGIN_CPP11
@@ -1497,6 +1529,10 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
     {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2},
     {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
+    {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2},
+    {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2},
+    {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2},
+    {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2},
     {NULL, NULL, 0}
 };
 }
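
cpp11 also regenerates matching R-side wrappers for these registrations (in R/cpp11.R, which is not shown in this diff); based on cpp11's usual code generation for void-returning functions, they should look roughly like:

    update_alpha_tree_prior_cpp <- function(tree_prior_ptr, alpha) {
      invisible(.Call(`_stochtree_update_alpha_tree_prior_cpp`, tree_prior_ptr, alpha))
    }
    update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) {
      invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth))
    }
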
{"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, + {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, {NULL, NULL, 0} }; } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 87279397..f8cf32b7 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1109,6 +1109,22 @@ class ForestSamplerCpp { StochTree::UpdateResidualNewOutcome(*tracker_, *residual_ptr); } + void UpdateAlpha(double alpha) { + split_prior_->SetAlpha(alpha); + } + + void UpdateBeta(double beta) { + split_prior_->SetBeta(beta); + } + + void UpdateMinSamplesLeaf(int min_samples_leaf) { + split_prior_->SetMinSamplesLeaf(min_samples_leaf); + } + + void UpdateMaxDepth(int max_depth) { + split_prior_->SetMaxDepth(max_depth); + } + private: std::unique_ptr tracker_; std::unique_ptr split_prior_; @@ -1585,7 +1601,11 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration) .def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel) .def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate) - .def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate); + .def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate) + .def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha) + .def("UpdateBeta", &ForestSamplerCpp::UpdateBeta) + .def("UpdateMinSamplesLeaf", &ForestSamplerCpp::UpdateMinSamplesLeaf) + .def("UpdateMaxDepth", &ForestSamplerCpp::UpdateMaxDepth); py::class_(m, "GlobalVarianceModelCpp") .def(py::init<>()) diff --git a/src/sampler.cpp b/src/sampler.cpp index f6e0f3c6..5b5d8afb 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -197,6 +197,30 @@ cpp11::external_pointer tree_prior_cpp(double alpha, doubl return cpp11::external_pointer(prior_ptr_.release()); } +[[cpp11::register]] +void update_alpha_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, double alpha) { + // Update alpha + tree_prior_ptr->SetAlpha(alpha); +} + +[[cpp11::register]] +void update_beta_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, double beta) { + // Update beta + tree_prior_ptr->SetBeta(beta); +} + +[[cpp11::register]] +void update_min_samples_leaf_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, int min_samples_leaf) { + // Update min_samples_leaf + tree_prior_ptr->SetMinSamplesLeaf(min_samples_leaf); +} + +[[cpp11::register]] +void update_max_depth_tree_prior_cpp(cpp11::external_pointer tree_prior_ptr, int max_depth) { + // Update max_depth + tree_prior_ptr->SetMaxDepth(max_depth); +} + [[cpp11::register]] cpp11::external_pointer forest_tracker_cpp(cpp11::external_pointer data, cpp11::integers feature_types, int num_trees, StochTree::data_size_t n) { // Convert vector of integers to std::vector of enum FeatureType diff --git a/stochtree/__init__.py b/stochtree/__init__.py index 94e93250..95b49ae3 100644 --- a/stochtree/__init__.py +++ b/stochtree/__init__.py @@ -8,6 +8,19 @@ from .serialization import JSONSerializer from .utils import NotSampledError -__all__ = ['BARTModel', 'BCFModel', 'Dataset', 'Residual', 'ForestContainer', 'Forest', - 'CovariateTransformer', 'RNG', 'ForestSampler', 'GlobalVarianceModel', - 'LeafVarianceModel', 'JSONSerializer', 'NotSampledError', 'calibrate_global_error_variance'] \ No newline at end of file +__all__ = [ + 'BARTModel', + 'BCFModel', + 'Dataset', + 'Residual', + 'ForestContainer', + 'Forest', + 'CovariateTransformer', + 'RNG', + 'ForestSampler', + 
+    'LeafVarianceModel', 
+    'JSONSerializer', 
+    'NotSampledError', 
+    'calibrate_global_error_variance'
+]
\ No newline at end of file
diff --git a/stochtree/bart.py b/stochtree/bart.py
index 0dd5f340..808f1f31 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -5,7 +5,7 @@ from math import log
 import numpy as np
 import pandas as pd
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, Union
 from .data import Dataset, Residual
 from .forest import ForestContainer, Forest
 from .preprocessing import CovariateTransformer, _preprocess_params
@@ -14,17 +14,44 @@ from .utils import NotSampledError
 class BARTModel:
-    """Class that handles sampling, storage, and serialization of stochastic forest models like BART, XBART, and Warm-Start BART
-    """
+    r"""
+    Class that handles sampling, storage, and serialization of stochastic forest models for supervised learning. 
+    The class takes its name from Bayesian Additive Regression Trees, an MCMC sampler originally developed in 
+    Chipman, George, McCulloch (2010), but supports several sampling algorithms:
+    
+    - MCMC: The "classic" sampler defined in Chipman, George, McCulloch (2010). In order to run the MCMC sampler, set `num_gfr = 0` (explained below) and then define a sampler according to several parameters:
+        - `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples are helpful for allowing a sampler to converge before retaining samples.
+        - `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains")
+        - `num_mcmc`: the number of "retained" samples of the posterior distribution
+        - `keep_every`: after a sampler has "burned in", we will run the sampler for `keep_every` * `num_mcmc` iterations, retaining one of every `keep_every` iterations in a chain.
+    - GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in He and Hahn (2021). GFR sampler iterations are governed by the `num_gfr` parameter, and there are two primary ways to use this sampler:
+        - Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
+        - Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. This is typically referred to as "warm start BART".
+    
+    In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BART model of Chipman, George, McCulloch (2010) is
+    
+    \begin{equation*}
+    \begin{aligned}
+    y &= f(X) + \epsilon\\
+    f(X) &\sim \text{BART}(\cdot)\\
+    \epsilon &\sim N(0, \sigma^2)\\
+    \sigma^2 &\sim IG(\nu, \nu\lambda)
+    \end{aligned}
+    \end{equation*}
+    In words, there is a nonparametric mean function governed by a tree ensemble with a BART prior and an additive (mean-zero) Gaussian error
+ + The `BARTModel` class supports the following extensions of this model: + + - Leaf Regression: Rather than letting `f(X)` define a standard decision tree ensemble, in which each tree uses `X` to partition the data and then serve up constant predictions, we allow for models `f(X,Z)` in which `X` and `Z` together define a partitioned linear model (`X` partitions the data and `Z` serves as the basis for regression models). This model can be run by specifying `basis_train` in the `sample` method. + - Heteroskedasticity: Rather than define $\epsilon$ parameterically, we can let a forest $\sigma^2(X)$ model a conditional error variance function. This can be done by setting `num_trees_variance > 0` in the `params` dictionary passed to the `sample` method. + """ def __init__(self) -> None: # Internal flag for whether the sample() method has been run self.sampled = False self.rng = np.random.default_rng() - def is_sampled(self) -> bool: - return self.sampled - def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None, X_test: np.array = None, basis_test: np.array = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None) -> None: @@ -33,64 +60,64 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N Parameters ---------- - X_train : :obj:`np.array` + X_train : np.array Training set covariates on which trees may be partitioned. - y_train : :obj:`np.array` + y_train : np.array Training set outcome. - basis_train : :obj:`np.array`, optional + basis_train : np.array, optional Optional training set basis vector used to define a regression to be run in the leaves of each tree. - X_test : :obj:`np.array`, optional + X_test : np.array, optional Optional test set covariates. - basis_test : :obj:`np.array`, optional + basis_test : np.array, optional Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if basis_train is provided, then basis_test must be provided alongside X_test). - num_gfr : :obj:`int`, optional - Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to ``5``. - num_burnin : :obj:`int`, optional - Number of "burn-in" iterations of the MCMC sampler. Defaults to ``0``. Ignored if ``num_gfr > 0``. - num_mcmc : :obj:`int`, optional - Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. - general_params : :obj:`dict`, optional + num_gfr : int, optional + Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. + num_burnin : int, optional + Number of "burn-in" iterations of the MCMC sampler. Defaults to `0`. Ignored if `num_gfr > 0`. + num_mcmc : int, optional + Number of "retained" iterations of the MCMC sampler. Defaults to `100`. If this is set to 0, GFR (XBART) samples will be retained. + general_params : dict, optional Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. - * ``cutpoint_grid_size`` (``int``): Maximum number of cutpoints to consider for each feature. Defaults to ``100``. - * ``standardize`` (``bool``): Whether or not to standardize the outcome (and store the offset / scale in the model object). 
-            * ``sample_sigma2_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(sigma2_global_shape, sigma2_global_scale)``. Defaults to ``True``.
-            * ``sigma2_init`` (``float``): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here.
-            * ``sigma2_global_shape`` (``float``): Shape parameter in the ``IG(sigma2_global_shape, b_glsigma2_global_scaleobal)`` global error variance model. Defaults to ``0``.
-            * ``sigma2_global_scale`` (``float``): Scale parameter in the ``IG(sigma2_global_shape, b_glsigma2_global_scaleobal)`` global error variance model. Defaults to ``0``.
-            * ``random_seed`` (``int``): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``.
-            * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``.
-            * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``.
-            * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
-            * ``num_chains`` (``int``): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`.
+            * `cutpoint_grid_size` (`int`): Maximum number of cutpoints to consider for each feature. Defaults to `100`.
+            * `standardize` (`bool`): Whether or not to standardize the outcome (and store the offset / scale in the model object). Defaults to `True`.
+            * `sample_sigma2_global` (`bool`): Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Defaults to `True`.
+            * `sigma2_init` (`float`): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here.
+            * `sigma2_global_shape` (`float`): Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Defaults to `0`.
+            * `sigma2_global_scale` (`float`): Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Defaults to `0`.
+            * `random_seed` (`int`): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
+            * `keep_burnin` (`bool`): Whether or not "burnin" samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`.
+            * `keep_gfr` (`bool`): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`.
+            * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained.
Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. + * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. - mean_forest_params : :obj:`dict`, optional + mean_forest_params : dict, optional Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional. - * ``num_trees`` (``int``): Number of trees in the conditional mean model. Defaults to ``200``. If ``num_trees = 0``, the conditional mean will not be modeled using a forest and sampling will only proceed if ``num_trees > 0`` for the variance forest. - * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional mean model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. - * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional mean model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. - * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional mean model. Defaults to ``5``. - * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional mean model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. - * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. - * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. - * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here. - - variance_forest_params : :obj:`dict`, optional + * `num_trees` (`int`): Number of trees in the conditional mean model. Defaults to `200`. If `num_trees = 0`, the conditional mean will not be modeled using a forest and sampling will only proceed if `num_trees > 0` for the variance forest. + * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the conditional mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. + * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional mean model. 
+              Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `2`.
+            * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the conditional mean model. Defaults to `5`.
+            * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the conditional mean model. Defaults to `10`. Can be overridden with `-1` which does not enforce any depth limits on trees.
+            * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of `X_train` if not provided.
+            * `sample_sigma2_leaf` (`bool`): Whether or not to update the `tau` leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `basis_train` has more than one column. Defaults to `False`.
+            * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
+            * `sigma2_leaf_shape` (`float`): Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Defaults to `3`.
+            * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
+        
+        variance_forest_params : dict, optional
            Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional.
 
-            * ``num_trees`` (``int``): Number of trees in the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees > 0``.
-            * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``.
-            * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``.
-            * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to ``5``.
-            * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees.
-            * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided.
-            * ``var_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(0.6*np.var(y_train))/num_trees_variance``, where `y_train` is the possibly standardized outcome, if not set.
-            * ``var_forest_prior_shape`` (``float``): Shape parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance forest (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2 + 0.5`` if not set here.
- * ``var_forest_prior_scale`` (``float``): Scale parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance forest (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2`` if not set here. + * `num_trees` (`int`): Number of trees in the conditional variance model. Defaults to `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. + * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. + * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `2`. + * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to `5`. + * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to `10`. Can be overridden with `-1`, which does not enforce any depth limits on trees. + * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of `X_train` if not provided. + * `var_forest_leaf_init` (`float`): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `np.log(0.6*np.var(y_train))/num_trees_variance`, where `y_train` is the possibly standardized outcome, if not set. + * `var_forest_prior_shape` (`float`): Shape parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set here. + * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set here. Returns ------- @@ -553,20 +580,24 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N else: self.sigma2_x_test = sigma_x_test_raw*self.sigma2_init*self.y_std*self.y_std - def predict(self, covariates: np.array, basis: np.array = None) -> np.array: - """Return predictions from every forest sampled (either / both of mean and variance) + def predict(self, covariates: np.array, basis: np.array = None) -> Union[np.array, tuple]: + """Return predictions from every forest sampled (either / both of mean and variance). + Return type is either a single array of predictions, if a BART model only includes a + mean or variance term, or a tuple of prediction arrays, if a BART model includes both. Parameters ---------- - covariates : :obj:`np.array` + covariates : np.array Test set covariates. - basis : :obj:`np.array`, optional + basis : np.array, optional Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. Returns ------- - tuple of :obj:`np.array` - Tuple of arrays of predictions corresponding to each forest (mean and variance, depending on whether either / both was included).
Each array will contain as many rows as in ``covariates`` and as many columns as retained samples of the algorithm. + mu_x : np.array, optional + Mean forest predictions. + sigma2_x : np.array, optional + Variance forest predictions. """ if not self.is_sampled(): msg = ( @@ -615,15 +646,15 @@ def predict_mean(self, covariates: np.array, basis: np.array = None) -> np.array Parameters ---------- - covariates : :obj:`np.array` + covariates : np.array Test set covariates. - basis : :obj:`np.array`, optional + basis : np.array, optional Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. Returns ------- - tuple of :obj:`np.array` - Tuple of arrays of predictions corresponding to each forest (mean and variance, depending on whether either / both was included). Each array will contain as many rows as in ``covariates`` and as many columns as retained samples of the algorithm. + np.array + Mean forest predictions. """ if not self.is_sampled(): msg = ( @@ -670,8 +701,8 @@ def predict_variance(self, covariates: np.array) -> np.array: Returns ------- - tuple of :obj:`np.array` - Tuple of arrays of predictions corresponding to the variance forest. Each array will contain as many rows as in ``covariates`` and as many columns as retained samples of the algorithm. + np.array + Variance forest predictions. """ if not self.is_sampled(): msg = ( @@ -706,11 +737,11 @@ def predict_variance(self, covariates: np.array) -> np.array: def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or - processed using the ``json`` library) + processed using the `json` library) Returns ------- - :obj:`str` + str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ if not self.is_sampled: @@ -759,7 +790,7 @@ def from_json(self, json_string: str) -> None: Parameters ---------- - json_string : :obj:`str` + json_string : str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ # Parse string to a JSON object in C++ @@ -804,3 +835,13 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True + + def is_sampled(self) -> bool: + """Whether or not a BART model has been sampled. + + Returns + ------- + bool + `True` if a BART model has been sampled, `False` otherwise + """ + return self.sampled diff --git a/stochtree/bcf.py b/stochtree/bcf.py index fc16e7dc..47b5f19c 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -14,17 +14,53 @@ from .utils import NotSampledError class BCFModel: - """Class that handles sampling, storage, and serialization of causal BART models like BCF, XBCF, and Warm-Start BCF - """ + r""" + Class that handles sampling, storage, and serialization of stochastic forest models for causal effect estimation. + The class takes its name from Bayesian Causal Forests, an MCMC sampler originally developed in + Hahn, Murray, Carvalho (2020), but supports several sampling algorithms: + + * MCMC: The "classic" sampler defined in Hahn, Murray, Carvalho (2020). In order to run the MCMC sampler, + set `num_gfr = 0` (explained below) and then define a sampler according to several parameters: + * `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples + are helpful for allowing a sampler to converge before retaining samples. 
+ * `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains") + * `num_mcmc`: the number of "retained" samples of the posterior distribution + * `keep_every`: after a sampler has "burned in", we will run the sampler for `keep_every` * `num_mcmc` iterations, retaining one out of every `keep_every` iterations in each chain. + * GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in Krantsevich, He, and Hahn (2023). GFR sampler iterations are + governed by the `num_gfr` parameter, and there are two primary ways to use this sampler: + * Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART). + * Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. + This is typically referred to as "warm start BART". + + In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BCF model of Hahn, Murray, Carvalho (2020) is + + \begin{equation*} + \begin{aligned} + y &= a(X) + b_z(X) + \epsilon\\ + b_z(X) &= (b_1 Z + b_0 (1-Z)) t(X)\\ + b_0, b_1 &\sim N(0, \frac{1}{2})\\\\ + a(X) &\sim \text{BART}()\\ + t(X) &\sim \text{BART}()\\ + \epsilon &\sim N(0, \sigma^2)\\ + \sigma^2 &\sim IG(a, b) + \end{aligned} + \end{equation*} + + for continuous outcome $y$, binary treatment $Z$, and covariates $X$. + In words, there are two nonparametric mean functions -- a "prognostic" function and a "treatment effect" function -- governed by tree ensembles with BART priors and an additive (mean-zero) Gaussian error + term, whose variance is parameterized with an inverse gamma prior. + + The `BCFModel` class supports the following extensions of this model: + + - Continuous Treatment: If $Z$ is continuous rather than binary, we define $b_z(X) = \tau(X, Z) = Z \tau(X)$, where the "leaf model" for the $\tau$ forest is essentially a regression on continuous $Z$. + - Heteroskedasticity: Rather than define $\epsilon$ parametrically, we can let a forest $\sigma^2(X)$ model a conditional error variance function. This can be done by setting `num_trees > 0` in the `variance_forest_params` dictionary passed to the `sample` method. + """ def __init__(self) -> None: # Internal flag for whether the sample() method has been run self.sampled = False self.rng = np.random.default_rng() - def is_sampled(self) -> bool: - return self.sampled - def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array, pi_train: np.array = None, X_test: Union[pd.DataFrame, np.array] = None, Z_test: np.array = None, pi_test: np.array = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, general_params: Optional[Dict[str, Any]] = None, @@ -35,91 +71,91 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr Parameters ---------- - X_train : :obj:`np.array` or :obj:`pd.DataFrame` + X_train : np.array or pd.DataFrame Covariates used to split trees in the ensemble. Can be passed as either a matrix or dataframe. - Z_train : :obj:`np.array` + Z_train : np.array Array of (continuous or binary; univariate or multivariate) treatment assignments. - y_train : :obj:`np.array` + y_train : np.array Outcome to be modeled by the ensemble.
- pi_train : :obj:`np.array` + pi_train : np.array Optional vector of propensity scores. If not provided, this will be estimated from the data. - X_test : :obj:`np.array`, optional + X_test : np.array, optional Optional test set of covariates used to define "out of sample" evaluation data. - Z_test : :obj:`np.array`, optional + Z_test : np.array, optional Optional test set of (continuous or binary) treatment assignments. - Must be provided if ``X_test`` is provided. - pi_test : :obj:`np.array`, optional - Optional test set vector of propensity scores. If not provided (but ``X_test`` and ``Z_test`` are), this will be estimated from the data. - num_gfr : :obj:`int`, optional - Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to ``5``. - num_burnin : :obj:`int`, optional - Number of "burn-in" iterations of the MCMC sampler. Defaults to ``0``. Ignored if ``num_gfr > 0``. - num_mcmc : :obj:`int`, optional - Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. - general_params : :obj:`dict`, optional + Must be provided if `X_test` is provided. + pi_test : np.array, optional + Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. + num_gfr : int, optional + Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. + num_burnin : int, optional + Number of "burn-in" iterations of the MCMC sampler. Defaults to `0`. Ignored if `num_gfr > 0`. + num_mcmc : int, optional + Number of "retained" iterations of the MCMC sampler. Defaults to `100`. If this is set to 0, GFR (XBART) samples will be retained. + general_params : dict, optional Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. - * ``cutpoint_grid_size`` (``int``): Maximum number of cutpoints to consider for each feature. Defaults to ``100``. - * ``standardize`` (``bool``): Whether or not to standardize the outcome (and store the offset / scale in the model object). Defaults to ``True``. - * ``sample_sigma2_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(sigma2_global_shape, sigma2_global_scale)``. Defaults to ``True``. - * ``sigma2_global_init`` (``float``): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here. - * ``sigma2_global_shape`` (``float``): Shape parameter in the ``IG(sigma2_global_shape, b_glsigma2_global_scaleobal)`` global error variance model. Defaults to ``0``. - * ``sigma2_global_scale`` (``float``): Scale parameter in the ``IG(sigma2_global_shape, b_glsigma2_global_scaleobal)`` global error variance model. Defaults to ``0``. - * ``variable_weights`` (`np.`array``): Numeric weights reflecting the relative probability of splitting on each variable in each of the forests. Does not need to sum to 1 but cannot be negative. Defaults to ``np.repeat(1/X_train.shape[1], X_train.shape[1])`` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to ``1/X_train.shape[1]``. 
A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in ``X_train`` and then set ``propensity_covariate`` to ``'none'`` and adjust ``keep_vars`` accordingly for the mu or tau forests. - * ``propensity_covariate`` (``str``): Whether to include the propensity score as a covariate in either or both of the forests. Enter ``"none"`` for neither, ``"mu"`` for the prognostic forest, ``"tau"`` for the treatment forest, and ``"both"`` for both forests. - If this is not ``"none"`` and a propensity score is not provided, it will be estimated from (``X_train``, ``Z_train``) using ``BARTModel``. Defaults to ``"mu"``. - * ``adaptive_coding`` (``bool``): Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via - parameters ``b_0`` and ``b_1`` that attach to the outcome model ``[b_0 (1-Z) + b_1 Z] tau(X)``. This is ignored when Z is not binary. Defaults to True. - * ``control_coding_init`` (``float``): Initial value of the "control" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``-0.5``. - * ``treated_coding_init`` (``float``): Initial value of the "treated" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``0.5``. - * ``random_seed`` (``int``): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``. - * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. - * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. - * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. - * ``num_chains`` (``int``): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. + * `cutpoint_grid_size` (`int`): Maximum number of cutpoints to consider for each feature. Defaults to `100`. + * `standardize` (`bool`): Whether or not to standardize the outcome (and store the offset / scale in the model object). Defaults to `True`. + * `sample_sigma2_global` (`bool`): Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Defaults to `True`. + * `sigma2_global_init` (`float`): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here. + * `sigma2_global_shape` (`float`): Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Defaults to `0`. + * `sigma2_global_scale` (`float`): Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Defaults to `0`.
+ * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in each of the forests. Does not need to sum to 1 but cannot be negative. Defaults to `np.repeat(1/X_train.shape[1], X_train.shape[1])` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/X_train.shape[1]`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train` and then set `propensity_covariate` to `'none'` and adjust `keep_vars` accordingly for the mu or tau forests. + * `propensity_covariate` (`str`): Whether to include the propensity score as a covariate in either or both of the forests. Enter `"none"` for neither, `"mu"` for the prognostic forest, `"tau"` for the treatment forest, and `"both"` for both forests. + If this is not `"none"` and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `BARTModel`. Defaults to `"mu"`. + * `adaptive_coding` (`bool`): Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via + parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Defaults to True. + * `control_coding_init` (`float`): Initial value of the "control" group coding parameter. This is ignored when `Z` is not binary. Default: `-0.5`. + * `treated_coding_init` (`float`): Initial value of the "treated" group coding parameter. This is ignored when `Z` is not binary. Default: `0.5`. + * `random_seed` (`int`): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. + * `keep_burnin` (`bool`): Whether or not "burnin" samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`. + * `keep_gfr` (`bool`): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`. + * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. + * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. - mu_forest_params : :obj:`dict`, optional + mu_forest_params : dict, optional Dictionary of prognostic forest model parameters, each of which has a default value processed internally, so this argument is optional. - * ``num_trees`` (``int``): Number of trees in the prognostic forest. Defaults to ``250``. Must be a positive integer. - * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. 
- * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. - * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Defaults to ``5``. - * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the prognostic forest. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. - * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. - * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. - * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here. - * ``keep_vars`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be included in the prognostic (``mu(X)``) forest. Defaults to ``None``. - * ``drop_vars`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the prognostic (``mu(X)``) forest. Defaults to ``None``. If both ``drop_vars_mu`` and ``keep_vars_mu`` are set, ``drop_vars_mu`` will be ignored. + * `num_trees` (`int`): Number of trees in the prognostic forest. Defaults to `250`. Must be a positive integer. + * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. + * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `2`. + * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Defaults to `5`. + * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the prognostic forest. Defaults to `10`. Can be overridden with `-1`, which does not enforce any depth limits on trees. + * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of `X_train` if not provided. + * `sample_sigma2_leaf` (`bool`): Whether or not to update the `tau` leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `basis_train` has more than one column. Defaults to `False`. + * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter.
Calibrated internally as `1/num_trees` if not set here. + * `sigma2_leaf_shape` (`float`): Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Defaults to `3`. + * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. + * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the prognostic (`mu(X)`) forest. Defaults to `None`. + * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the prognostic (`mu(X)`) forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. - tau_forest_params : :obj:`dict`, optional + tau_forest_params : dict, optional Dictionary of treatment effect forest model parameters, each of which has a default value processed internally, so this argument is optional. - * ``num_trees`` (``int``): Number of trees in the treatment effect forest. Defaults to ``50``. Must be a positive integer. - * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.25``. - * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``3``. - * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Defaults to ``5``. - * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the treatment effect forest. Defaults to ``5``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. - * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. - * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. - * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here. + * `num_trees` (`int`): Number of trees in the treatment effect forest. Defaults to `50`. Must be a positive integer. + * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.25`. + * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest.
Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `3`. + * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Defaults to `5`. + * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the treatment effect forest. Defaults to `5`. Can be overridden with `-1`, which does not enforce any depth limits on trees. + * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of `X_train` if not provided. + * `sample_sigma2_leaf` (`bool`): Whether or not to update the `tau` leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `basis_train` has more than one column. Defaults to `False`. + * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. + * `sigma2_leaf_shape` (`float`): Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Defaults to `3`. + * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. - variance_forest_params : :obj:`dict`, optional + variance_forest_params : dict, optional Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. - * ``num_trees`` (``int``): Number of trees in the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees > 0``. - * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. - * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. - * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to ``5``. - * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``var_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(0.6*np.var(y_train))/num_trees_variance``, where `y_train` is the possibly standardized outcome, if not set. - * ``var_forest_prior_shape`` (``float``): Shape parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance forest (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2 + 0.5`` if not set here.
- * ``var_forest_prior_scale`` (``float``): Scale parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance forest (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2`` if not set here. + * `num_trees` (`int`): Number of trees in the conditional variance model. Defaults to `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. + * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. + * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `2`. + * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to `5`. + * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to `10`. Can be overridden with `-1`, which does not enforce any depth limits on trees. + * `variable_weights` (`np.array`): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of `X_train` if not provided. + * `var_forest_leaf_init` (`float`): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `np.log(0.6*np.var(y_train))/num_trees_variance`, where `y_train` is the possibly standardized outcome, if not set. + * `var_forest_prior_shape` (`float`): Shape parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set here. + * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set here. Returns ------- @@ -1093,13 +1129,13 @@ def predict_tau(self, X: np.array, Z: np.array, propensity: np.array = None) -> Test set covariates. Z : np.array Test set treatment indicators. - propensity : :obj:`np.array`, optional + propensity : np.array, optional Optional test set propensities. Must be provided if propensities were provided when the model was sampled. Returns ------- np.array - Array with as many rows as in ``X`` and as many columns as retained samples of the algorithm. + Array with as many rows as in `X` and as many columns as retained samples of the algorithm. """ if not self.is_sampled(): msg = ( @@ -1166,13 +1202,13 @@ def predict_variance(self, covariates: np.array, propensity: np.array = None) -> ---------- covariates : np.array Test set covariates. - covariates : np.array + propensity : np.array, optional Test set propensity scores. Optional (not currently used in variance forests). Returns ------- - tuple of :obj:`np.array` - Tuple of arrays of predictions corresponding to the variance forest. Each array will contain as many rows as in ``covariates`` and as many columns as retained samples of the algorithm. + np.array + Array of predictions corresponding to the variance forest.
The array will contain as many rows as in `covariates` and as many columns as retained samples of the algorithm. """ if not self.is_sampled(): msg = ( @@ -1221,9 +1257,9 @@ def predict_variance(self, covariates: np.array, propensity: np.array = None) -> return variance_pred - def predict(self, X: np.array, Z: np.array, propensity: np.array = None) -> np.array: + def predict(self, X: np.array, Z: np.array, propensity: np.array = None) -> tuple: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. - Predicted outcomes are computed as ``yhat = mu_x + Z*tau_x`` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. + Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. Parameters ---------- @@ -1231,16 +1267,20 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None) -> np.a Test set covariates. Z : np.array Test set treatment indicators. - propensity : :obj:`np.array`, optional + propensity : np.array, optional Optional test set propensities. Must be provided if propensities were provided when the model was sampled. Returns ------- - tuple of np.array - Tuple of arrays with as many rows as in ``X`` and as many columns as retained samples of the algorithm. - The first entry of the tuple contains conditional average treatment effect (CATE) samples, - the second entry contains prognostic effect samples, and the third entry contains outcome prediction samples. - The optional fourth array contains variance forest samples. + tau_x : np.array + Conditional average treatment effect (CATE) samples for every observation provided. + mu_x : np.array + Prognostic effect samples for every observation provided. + yhat_x : np.array + Outcome prediction samples for every observation provided. + sigma2_x : np.array, optional + Variance forest samples for every observation provided. Only returned if the + model includes a heteroskedasticity forest. """ if not self.is_sampled(): msg = ( @@ -1318,11 +1358,11 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None) -> np.a def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or - processed using the ``json`` library) + processed using the `json` library) Returns ------- - :obj:`str` + str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ if not self.is_sampled: msg = ( @@ -1377,7 +1417,7 @@ def from_json(self, json_string: str) -> None: Parameters ---------- - json_string : :obj:`str` + json_string : str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ # Parse string to a JSON object in C++ @@ -1426,4 +1466,14 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True
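As a worked illustration of the `BCFModel` workflow documented in this file, here is a minimal sketch that strings `sample` and `predict` together; the simulated data-generating process and the package-root import path are illustrative assumptions, not part of this changeset:

```python
# Minimal warm-start BCF sketch based on the sample() / predict() docstrings above.
# Assumptions: `BCFModel` is importable from the package root; the simulated DGP is illustrative.
import numpy as np
from stochtree import BCFModel

rng = np.random.default_rng(2024)
n, p = 500, 5
X = rng.uniform(size=(n, p))
pi = 0.25 + 0.5 * X[:, 0]                       # propensity scores
Z = rng.binomial(1, pi).astype(float)           # binary treatment
tau = 1.0 + X[:, 1]                             # heterogeneous treatment effect
y = X[:, 0] + tau * Z + rng.standard_normal(n)  # outcome

bcf = BCFModel()
# num_gfr > 0 with num_mcmc > 0 runs "warm-start" BCF: GFR ensembles initialize the MCMC chain(s)
bcf.sample(X_train=X, Z_train=Z, y_train=y, pi_train=pi,
           num_gfr=10, num_burnin=0, num_mcmc=100)

# With no heteroskedasticity forest, predict() returns (tau_x, mu_x, yhat_x) per the docstring
tau_hat, mu_hat, yhat = bcf.predict(X=X, Z=Z, propensity=pi)
print(tau_hat.shape)  # (n, number of retained samples)
```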
+ + def is_sampled(self) -> bool: + """Whether or not a BCF model has been sampled. + + Returns + ------- + bool + `True` if a BCF model has been sampled, `False` otherwise + """ + return self.sampled diff --git a/stochtree/calibration.py b/stochtree/calibration.py index 3e4d1fbe..7ff6a0f3 100644 --- a/stochtree/calibration.py +++ b/stochtree/calibration.py @@ -8,24 +8,24 @@ def calibrate_global_error_variance(X: np.array, y: np.array, nu: float = 3, q: float = 0.9, standardize: bool = True) -> float: """Calibrates scale parameter of global error variance model as in Chipman et al (2010) by setting a value of lambda, - part of the scale parameter in the ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` prior. + part of the scale parameter in the `sigma2 ~ IG(nu/2, (nu*lambda)/2)` prior. Parameters ---------- - X : :obj:`np.array` + X : np.array Covariates to be used as split candidates for constructing trees. - y : :obj:`np.array` + y : np.array Outcome to be used as target for constructing trees. - nu : :obj:`float`, optional - Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``. - q : :obj:`float`, optional - Quantile used to calibrated ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``. - standardize : :obj:`bool`, optional - Whether or not ``y`` should be standardized before calibration. Defaults to ``True``. + nu : float, optional + Shape parameter in the `IG(nu, nu*lamb)` global error variance model. Defaults to `3`. + q : float, optional + Quantile used to calibrate `lamb` as in Sparapani et al (2021). Defaults to `0.9`. + standardize : bool, optional + Whether or not `y` should be standardized before calibration. Defaults to `True`. Returns ------- - lamb : :obj:`float` + float Part of scale parameter of global error variance model """ # Convert X and y to expected dimensions diff --git a/stochtree/data.py b/stochtree/data.py index 88c70f83..83a2a662 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -1,17 +1,29 @@ -""" -Python classes wrapping C++ data objects -""" import numpy as np from stochtree_cpp import ForestDatasetCpp, ResidualCpp class Dataset: + """ + Wrapper around a C++ class that stores all of the non-outcome data used in `stochtree`. This includes: + + 1. Features used for partitioning (also referred to as "covariates" in many places in these docs). + 2. Basis vectors used to define non-constant leaf models. This is optional but may be included via the `add_basis` method. + 3. Variance weights used to define heteroskedastic or otherwise weighted models. This is optional but may be included via the `add_variance_weights` method. + """ def __init__(self) -> None: - # Initialize a ForestDatasetCpp object + """ + Initialize a `Dataset` object + """ self.dataset_cpp = ForestDatasetCpp() def add_covariates(self, covariates: np.array): """ Add covariates to a dataset + + Parameters + ---------- + covariates : np.array + Numpy array of covariates. If data contain categorical, string, time series, or other columns in a + dataframe, please first preprocess using the `CovariateTransformer`. """ covariates_ = np.expand_dims(covariates, 1) if np.ndim(covariates) == 1 else covariates n, p = covariates_.shape @@ -21,6 +33,11 @@ def add_covariates(self, covariates: np.array): def add_basis(self, basis: np.array): """ Add basis matrix to a dataset + + Parameters + ---------- + basis : np.array + Numpy array of basis vectors.
""" basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis n, p = basis_.shape @@ -29,7 +46,13 @@ def add_basis(self, basis: np.array): def update_basis(self, basis: np.array): """ - Update basis matrix in a dataset + Update basis matrix in a dataset. Allows users to build an ensemble whose leaves + regress on bases that are updated throughout the sampler. + + Parameters + ---------- + basis : np.array + Numpy array of basis vectors. """ basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis n, p = basis_.shape @@ -39,25 +62,56 @@ def update_basis(self, basis: np.array): def add_variance_weights(self, variance_weights: np.array): """ Add variance weights to a dataset + + Parameters + ---------- + variance_weights : np.array + Univariate numpy array of variance weights. """ n = variance_weights.size self.dataset_cpp.AddVarianceWeights(variance_weights, n) class Residual: + """ + Wrapper around a C++ class that stores residual data used in `stochtree`. + This object becomes part of the real-time model "state" in that its contents + always contain a full or partial residual, depending on the state of the sampler. + + Typically this object is initialized with the original outcome and then "residualized" + by subtracting out the initial prediction value of every tree in every forest term + (as well as the predictions of any other model term). + """ def __init__(self, residual: np.array) -> None: - # Initialize a ResidualCpp object + """ + Initialize a `Residual` object + + Parameters + ---------- + residual : np.array + Univariate numpy array of residual values. + """ n = residual.size self.residual_cpp = ResidualCpp(residual, n) def get_residual(self) -> np.array: """ Extract the current values of the residual as a numpy array + + Returns + ------- + np.array + Current values of the residual (which may be net of any forest / other model terms) """ return self.residual_cpp.GetResidualArray() def update_data(self, new_vector: np.array) -> None: """ - Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of ``new_vector`` + Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector` + + Parameters + ---------- + new_vector : np.array + Univariate numpy array of new residual values. """ n = new_vector.size self.residual_cpp.UpdateData(new_vector, n) diff --git a/stochtree/forest.py b/stochtree/forest.py index 033a6b41..576dc60f 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -2,14 +2,31 @@ Python classes wrapping C++ forest container object """ import numpy as np -from .data import Dataset, Residual -# from .serialization import JSONSerializer +from .data import Dataset from stochtree_cpp import ForestContainerCpp, ForestCpp from typing import Union class ForestContainer: - def __init__(self, num_trees: int, output_dimension: int, leaf_constant: bool, is_exponentiated: bool) -> None: - # Initialize a ForestContainerCpp object + """ + Container that stores sampled (and retained) tree ensembles from BART, BCF or a custom sampler. + + Parameters + ---------- + num_trees : int + Number of trees that each forest should contain + output_dimension : int, optional + Dimension of the leaf node parameters in each tree + leaf_constant : bool, optional + Whether the leaf node model is "constant" (i.e. prediction is simply a + sum of leaf node parameters for every observation in a dataset) or not (i.e. 
+ each leaf node parameter is multiplied by a "basis vector" before being returned + as a prediction). + is_exponentiated : bool, optional + Whether or not the leaf node parameters are stored in log scale (in which case, they + must be exponentiated before being returned as predictions). + """ + def __init__(self, num_trees: int, output_dimension: int = 1, + leaf_constant: bool = True, is_exponentiated: bool = False) -> None: self.forest_container_cpp = ForestContainerCpp(num_trees, output_dimension, leaf_constant, is_exponentiated) self.num_trees = num_trees self.output_dimension = output_dimension @@ -17,11 +34,38 @@ def __init__(self, num_trees: int, output_dimension: int, leaf_constant: bool, i self.is_exponentiated = is_exponentiated def predict(self, dataset: Dataset) -> np.array: - # Predict samples from Dataset + """ + Predict from each forest in the container, using the provided `Dataset` object. + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures. + + Returns + ------- + np.array + Numpy array with (`n`, `m`) dimensions, where `n` is the number of observations in `dataset` and `m` + is the number of samples in the forest container. + """ return self.forest_container_cpp.Predict(dataset.dataset_cpp) def predict_raw(self, dataset: Dataset) -> np.array: - # Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset + """ + Predict raw leaf values for every forest in the container, using the provided `Dataset` object + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures. + + Returns + ------- + np.array + Numpy array with (`n`, `k`, `m`) dimensions, where `n` is the number of observations in `dataset`, + `k` is the dimension of the leaf parameter, and `m` is the number of samples in the forest container. + If `k = 1`, then the returned array is simply (`n`, `m`) dimensions. + """ result = self.forest_container_cpp.PredictRaw(dataset.dataset_cpp) if result.ndim == 3: if result.shape[1] == 1: @@ -29,15 +73,60 @@ def predict_raw_single_forest(self, dataset: Dataset, forest_num: int) -> np.array: return result def predict_raw_single_forest(self, dataset: Dataset, forest_num: int) -> np.array: - # Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset + """ + Predict raw leaf values for a specific forest (indexed by `forest_num`), using the provided `Dataset` object + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures. + forest_num : int + Index of the forest from which to predict. Forest indices are 0-based. + + Returns + ------- + np.array + Numpy array with (`n`, `k`) dimensions, where `n` is the number of observations in `dataset` and + `k` is the dimension of the leaf parameter. + """ return self.forest_container_cpp.PredictRawSingleForest(dataset.dataset_cpp, forest_num) def predict_raw_single_tree(self, dataset: Dataset, forest_num: int, tree_num: int) -> np.array: - # Predict raw leaf values for a specific tree from specific forest from Dataset + """ + Predict raw leaf values for a specific tree of a specific forest (indexed by `tree_num` and `forest_num` + respectively), using the provided `Dataset` object. + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures.
+ forest_num : int + Index of the forest from which to predict. Forest indices are 0-based. + tree_num : int + Index of the tree from which to predict (within forest indexed by `forest_num`). Tree indices are 0-based. + + Returns + ------- + np.array + Numpy array with (`n`, `k`) dimensions, where `n` is the number of observations in `dataset` and + `k` is the dimension of the leaf parameter. + """ return self.forest_container_cpp.PredictRawSingleTree(dataset.dataset_cpp, forest_num, tree_num) def set_root_leaves(self, forest_num: int, leaf_value: Union[float, np.array]) -> None: - # Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset + """ + Set constant (root) leaf node values for every tree in the forest indexed by `forest_num`. + Assumes the forest consists of all root (single-node) trees. + + Parameters + ---------- + forest_num : int + Index of the forest for which we will set root node parameters. + leaf_value : float or np.array + Constant values to which root nodes are to be set. If the trees in forest `forest_num` + are univariate, then `leaf_value` must be a `float`, while if the trees in forest `forest_num` + are multivariate, then `leaf_value` must be a `np.array`. + """ if not isinstance(leaf_value, np.ndarray) and not isinstance(leaf_value, float): raise ValueError("leaf_value must be either a float or np.array") if isinstance(leaf_value, np.ndarray): @@ -49,15 +138,50 @@ def set_root_leaves(self, forest_num: int, leaf_value: Union[float, np.array]) - self.forest_container_cpp.SetRootValue(forest_num, leaf_value) def save_to_json_file(self, json_filename: str) -> None: + """ + Save the forests in the container to a JSON file. + + Parameters + ---------- + json_filename : str + Name of JSON file to which forest container state will be saved. + May contain absolute or relative paths. + """ self.forest_container_cpp.SaveToJsonFile(json_filename) def load_from_json_file(self, json_filename: str) -> None: + """ + Load a forest container from output stored in a JSON file. + + Parameters + ---------- + json_filename : str + Name of JSON file from which forest container state will be restored. + May contain absolute or relative paths. + """ self.forest_container_cpp.LoadFromJsonFile(json_filename) def dump_json_string(self) -> str: + """ + Dump a forest container into an in-memory JSON string (which can be directly serialized or + combined with other JSON strings before serialization). + + Returns + ------- + str + In-memory string containing state of a forest container. + """ return self.forest_container_cpp.DumpJsonString() def load_from_json_string(self, json_string: str) -> None: + """ + Reload a forest container from an in-memory JSON string. + + Parameters + ---------- + json_string : str + In-memory string containing state of a forest container. + """ self.forest_container_cpp.LoadFromJsonString(json_string)
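The four serialization methods above compose into simple save/restore round trips. A brief sketch, assuming `ForestContainer` is exported from the package root and using a hypothetical file name:

```python
# Hedged sketch: file-based and in-memory round trips with the methods documented above.
from stochtree import ForestContainer  # assumed package-root export

container = ForestContainer(num_trees=50)    # remaining arguments take the defaults shown above

container.save_to_json_file("forests.json")  # hypothetical file name
restored = ForestContainer(num_trees=50)
restored.load_from_json_file("forests.json")

json_state = container.dump_json_string()    # in-memory alternative
restored.load_from_json_string(json_state)
```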
+ """ self.forest_container_cpp.LoadFromJsonString(json_string) def add_sample(self, leaf_value: Union[float, np.array]) -> None: @@ -66,8 +190,8 @@ def add_sample(self, leaf_value: Union[float, np.array]) -> None: Parameters ---------- - leaf_value : :obj:`float` or :obj:`np.array` - Value (or vector of values) to initialize root nodes in tree + leaf_value : float or np.array + Value (or vector of values) to initialize root nodes of every tree in a forest """ if isinstance(leaf_value, np.ndarray): leaf_value = np.squeeze(leaf_value) @@ -82,19 +206,19 @@ def add_numeric_split(self, forest_num: int, tree_num: int, leaf_num: int, featu Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest which contains the tree to be split - tree_num : :obj:`int` + tree_num : int Index of the tree to be split - leaf_num : :obj:`int` + leaf_num : int Leaf to be split - feature_num : :obj:`int` + feature_num : int Feature that defines the new split - split_threshold : :obj:`float` + split_threshold : float Value that defines the cutoff of the new split - left_leaf_value : :obj:`float` or :obj:`np.array` + left_leaf_value : float or np.array Value (or array of values) to assign to the newly created left node - right_leaf_value : :obj:`float` or :obj:`np.array` + right_leaf_value : float or np.array Value (or array of values) to assign to the newly created right node """ if isinstance(left_leaf_value, np.ndarray): @@ -110,59 +234,72 @@ def get_tree_leaves(self, forest_num: int, tree_num: int) -> np.array: Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest which contains tree `tree_num` - tree_num : :obj:`float` or :obj:`np.array` + tree_num : float or np.array Index of the tree for which leaf indices will be retrieved + + Returns + ------- + np.array + One-dimensional numpy array, containing the indices of leaf nodes in a given tree. """ return self.forest_container_cpp.GetTreeLeaves(forest_num, tree_num) def get_tree_split_counts(self, forest_num: int, tree_num: int, num_features: int) -> np.array: """ - Retrieve a vector of split counts for every training set variable in a given tree in a given forest + Retrieve a vector of split counts for every training set feature in a given tree in a given forest Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest which contains tree `tree_num` - tree_num : :obj:`int` + tree_num : int Index of the tree for which split counts will be retrieved - num_features : :obj:`int` + num_features : int Total number of features in the training set + + Returns + ------- + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the split count for each feature for a given forest and tree. 
""" return self.forest_container_cpp.GetTreeSplitCounts(forest_num, tree_num, num_features) def get_forest_split_counts(self, forest_num: int, num_features: int) -> np.array: """ - Retrieve a vector of split counts for every training set variable in a given forest + Retrieve a vector of split counts for every training set feature in a given forest Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest which contains tree `tree_num` - num_features : :obj:`int` + num_features : int Total number of features in the training set Returns ------- - :obj:`np.array` - One-dimensional numpy array, containing the number of splits a variable receives, summed across each tree of a given forest in a ``ForestContainer`` + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the split count for each feature for a given forest (summed across every tree in the forest). """ return self.forest_container_cpp.GetForestSplitCounts(forest_num, num_features) def get_overall_split_counts(self, num_features: int) -> np.array: """ - Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees + Retrieve a vector of split counts for every training set feature, aggregated across ensembles and trees. Parameters ---------- - num_features : :obj:`int` + num_features : int Total number of features in the training set Returns ------- - :obj:`np.array` - One-dimensional numpy array, containing the number of splits a variable receives, summed across each tree of every forest in a ``ForestContainer`` + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the split count for each feature summed across every forest of every tree in the container. """ return self.forest_container_cpp.GetOverallSplitCounts(num_features) @@ -172,13 +309,15 @@ def get_granular_split_counts(self, num_features: int) -> np.array: Parameters ---------- - num_features : :obj:`int` + num_features : int Total number of features in the training set Returns ------- - :obj:`np.array` - Three-dimensional numpy array, containing the number of splits a variable receives in each tree of each forest in a ``ForestContainer`` + np.array + Three-dimensional numpy array, containing the number of splits a variable receives in each tree of each forest in a ``ForestContainer``. + Array will have dimensions (`m`,`b`,`p`) where `m` is the number of forests in the container, `b` is the number of trees in each + forest, and `p` is the number of features in the forest model's training dataset. 
""" return self.forest_container_cpp.GetGranularSplitCounts(num_features) @@ -188,12 +327,12 @@ def num_forest_leaves(self, forest_num: int) -> int: Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest to be queried Returns ------- - :obj:`int` + int Number of leaves in a given forest in a ``ForestContainer`` """ return self.forest_container_cpp.NumLeavesForest(forest_num) @@ -204,12 +343,12 @@ def sum_leaves_squared(self, forest_num: int) -> float: Parameters ---------- - forest_num : :obj:`int` + forest_num : int Index of the forest to be queried Returns ------- - :obj:`float` + float Sum of squared leaf values in a given forest in a ``ForestContainer`` """ return self.forest_container_cpp.SumLeafSquared(forest_num) @@ -218,12 +357,19 @@ def is_leaf_node(self, forest_num: int, tree_num: int, node_id: int) -> bool: """ Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a leaf - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` of forest `forest_num` is a leaf, `False` otherwise """ return self.forest_container_cpp.IsLeafNode(forest_num, tree_num, node_id) @@ -231,12 +377,19 @@ def is_numeric_split_node(self, forest_num: int, tree_num: int, node_id: int) -> """ Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a numeric split node - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` of forest `forest_num` is a numeric split node, `False` otherwise """ return self.forest_container_cpp.IsNumericSplitNode(forest_num, tree_num, node_id) @@ -244,12 +397,19 @@ def is_categorical_split_node(self, forest_num: int, tree_num: int, node_id: int """ Whether or not a given node of a given tree in a given forest in the ``ForestContainer`` is a categorical split node - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` of forest `forest_num` is a categorical split node, `False` otherwise """ return self.forest_container_cpp.IsCategoricalSplitNode(forest_num, tree_num, node_id) @@ -257,12 +417,20 @@ def parent_node(self, forest_num: int, tree_num: int, node_id: int) -> int: """ Parent node of given node of a given tree in a given forest in the ``ForestContainer`` - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the parent of node `node_id` in tree `tree_num` of forest `forest_num`. + If `node_id` is a root node, returns `-1`. 
""" return self.forest_container_cpp.ParentNode(forest_num, tree_num, node_id) @@ -270,12 +438,20 @@ def left_child_node(self, forest_num: int, tree_num: int, node_id: int) -> int: """ Left child node of given node of a given tree in a given forest in the ``ForestContainer`` - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the left child of node `node_id` in tree `tree_num` of forest `forest_num`. + If `node_id` is a leaf, returns `-1`. """ return self.forest_container_cpp.LeftChildNode(forest_num, tree_num, node_id) @@ -283,12 +459,20 @@ def right_child_node(self, forest_num: int, tree_num: int, node_id: int) -> int: """ Right child node of given node of a given tree in a given forest in the ``ForestContainer`` - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the right child of node `node_id` in tree `tree_num` of forest `forest_num`. + If `node_id` is a leaf, returns `-1`. """ return self.forest_container_cpp.RightChildNode(forest_num, tree_num, node_id) @@ -296,12 +480,20 @@ def node_depth(self, forest_num: int, tree_num: int, node_id: int) -> int: """ Depth of given node of a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Depth of node `node_id` in tree `tree_num` of forest `forest_num`. The root node is defined + as "depth zero." """ return self.forest_container_cpp.NodeDepth(forest_num, tree_num, node_id) @@ -310,12 +502,19 @@ def node_split_index(self, forest_num: int, tree_num: int, node_id: int) -> int: Split index of given node of a given tree in a given forest in the ``ForestContainer``. Returns ``-1`` if the node is a leaf. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Split index of `node_id` in tree `tree_num` of forest `forest_num`. """ if self.is_leaf_node(forest_num, tree_num, node_id): return -1 @@ -327,12 +526,19 @@ def node_split_threshold(self, forest_num: int, tree_num: int, node_id: int) -> Threshold that defines a numeric split for a given node of a given tree in a given forest in the ``ForestContainer``. Returns ``np.Inf`` if the node is a leaf or a categorical split node. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + float + Threshold that defines a numeric split for node `node_id` in tree `tree_num` of forest `forest_num`. 
""" if self.is_leaf_node(forest_num, tree_num, node_id) or self.is_categorical_split_node(forest_num, tree_num, node_id): return np.Inf @@ -344,12 +550,19 @@ def node_split_categories(self, forest_num: int, tree_num: int, node_id: int) -> Array of category indices that define a categorical split for a given node of a given tree in a given forest in the ``ForestContainer``. Returns ``np.array([np.Inf])`` if the node is a leaf or a numeric split node. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + np.array + Array of category indices that define a categorical split for node `node_id` in tree `tree_num` of forest `forest_num`. """ if self.is_leaf_node(forest_num, tree_num, node_id) or self.is_numeric_split_node(forest_num, tree_num, node_id): return np.array([np.Inf]) @@ -358,15 +571,22 @@ def node_split_categories(self, forest_num: int, tree_num: int, node_id: int) -> def node_leaf_values(self, forest_num: int, tree_num: int, node_id: int) -> np.array: """ - Leaf node value(s) for a given node of a given tree in a given forest in the ``ForestContainer``. + Node parameter value(s) for a given node of a given tree in a given forest in the ``ForestContainer``. Values are stale if the node is a split node. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + np.array + Array of parameter values for node `node_id` in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.NodeLeafValues(forest_num, tree_num, node_id) @@ -374,10 +594,17 @@ def num_nodes(self, forest_num: int, tree_num: int) -> int: """ Number of nodes in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of nodes in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.NumNodes(forest_num, tree_num) @@ -385,10 +612,17 @@ def num_leaves(self, forest_num: int, tree_num: int) -> int: """ Number of leaves in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of leaves in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.NumLeaves(forest_num, tree_num) @@ -396,10 +630,17 @@ def num_leaf_parents(self, forest_num: int, tree_num: int) -> int: """ Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of leaf parents in tree `tree_num` of forest `forest_num`. 
""" return self.forest_container_cpp.NumLeafParents(forest_num, tree_num) @@ -407,10 +648,17 @@ def num_split_nodes(self, forest_num: int, tree_num: int) -> int: """ Number of split_nodes in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of split nodes in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.NumSplitNodes(forest_num, tree_num) @@ -418,10 +666,17 @@ def nodes(self, forest_num: int, tree_num: int) -> np.array: """ Array of node indices in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + np.array + Array of indices of nodes in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.Nodes(forest_num, tree_num) @@ -429,10 +684,17 @@ def leaves(self, forest_num: int, tree_num: int) -> np.array: """ Array of leaf indices in a given tree in a given forest in the ``ForestContainer``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be queried - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried + + Returns + ------- + np.array + Array of indices of leaf nodes in tree `tree_num` of forest `forest_num`. """ return self.forest_container_cpp.Leaves(forest_num, tree_num) @@ -440,15 +702,39 @@ def delete_sample(self, forest_num: int) -> None: """ Modify the ``ForestContainer`` by removing the forest sample indexed by ``forest_num``. - forest_num : :obj:`int` + Parameters + ---------- + forest_num : int Index of the forest to be removed from the ``ForestContainer`` """ return self.forest_container_cpp.DeleteSample(forest_num) class Forest: - def __init__(self, num_trees: int, output_dimension: int, leaf_constant: bool, is_exponentiated: bool) -> None: - # Initialize a ForestCpp object + """ + In-memory python wrapper around a C++ tree ensemble object + + Parameters + ---------- + num_trees : int + Number of trees that each forest should contain + output_dimension : int, optional + Dimension of the leaf node parameters in each tree + leaf_constant : bool, optional + Whether the leaf node model is "constant" (i.e. prediction is simply a + sum of leaf node parameters for every observation in a dataset) or not (i.e. + each leaf node parameter is multiplied by a "basis vector" before being returned + as a prediction). + is_exponentiated : bool, optional + Whether or not the leaf node parameters are stored in log scale (in which case, they + must be exponentiated before being returned as predictions). 
+ """ + def __init__(self, num_trees: int, output_dimension: int = 1, + leaf_constant: bool = True, is_exponentiated: bool = False) -> None: self.forest_cpp = ForestCpp(num_trees, output_dimension, leaf_constant, is_exponentiated) + self.num_trees = num_trees + self.output_dimension = output_dimension + self.leaf_constant = leaf_constant + self.is_exponentiated = is_exponentiated def reset_root(self) -> None: """ @@ -462,19 +748,45 @@ def reset(self, forest_container: ForestContainer, forest_num: int) -> None: Parameters ---------- - forest_container : :obj:`ForestContainer` + forest_container : `ForestContainer Stochtree object storing tree ensembles - forest_num : :obj:`int` + forest_num : int Index of the ensemble used to reset the ``Forest`` """ self.forest_cpp.Reset(forest_container.forest_container_cpp, forest_num) def predict(self, dataset: Dataset) -> np.array: - # Predict samples from Dataset + """ + Predict from each forest in the container, using the provided `Dataset` object. + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures. + + Returns + ------- + np.array + One-dimensional numpy array with length equal to the number of observations in `dataset`. + """ return self.forest_cpp.Predict(dataset.dataset_cpp) def predict_raw(self, dataset: Dataset) -> np.array: - # Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset + """ + Predict raw leaf values for a every forest in the container, using the provided `Dataset` object + + Parameters + ---------- + dataset : Dataset + Python object wrapping the "dataset" class used by C++ sampling and prediction data structures. + + Returns + ------- + np.array + Numpy array with (`n`, `k`) dimensions, where `n` is the number of observations in `dataset` and + `k` is the dimension of the leaf parameter. If `k = 1`, then the returned array is simply one-dimensional + with `n` observations. + """ result = self.forest_cpp.PredictRaw(dataset.dataset_cpp) if result.ndim == 3: if result.shape[1] == 1: @@ -482,7 +794,17 @@ def predict_raw(self, dataset: Dataset) -> np.array: return result def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None: - # Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset + """ + Set constant (root) leaf node values for every tree in the forest. + Assumes the forest consists of all root (single-node) trees. + + Parameters + ---------- + leaf_value : float or np.array + Constant values to which root nodes are to be set. If the trees in forest `forest_num` + are univariate, then `leaf_value` must be a `float`, while if the trees in forest `forest_num` + are multivariate, then `leaf_value` must be a `np.array`. 
+ """ if not isinstance(leaf_value, np.ndarray) and not isinstance(leaf_value, float): raise ValueError("leaf_value must be either a float or np.array") if isinstance(leaf_value, np.ndarray): @@ -500,17 +822,17 @@ def add_numeric_split(self, tree_num: int, leaf_num: int, feature_num: int, spli Parameters ---------- - tree_num : :obj:`int` + tree_num : int Index of the tree to be split - leaf_num : :obj:`int` + leaf_num : int Leaf to be split - feature_num : :obj:`int` + feature_num : int Feature that defines the new split - split_threshold : :obj:`float` + split_threshold : float Value that defines the cutoff of the new split - left_leaf_value : :obj:`float` or :obj:`np.array` + left_leaf_value : float or np.array Value (or array of values) to assign to the newly created left node - right_leaf_value : :obj:`float` or :obj:`np.array` + right_leaf_value : float or np.array Value (or array of values) to assign to the newly created right node """ if isinstance(left_leaf_value, np.ndarray): @@ -526,8 +848,13 @@ def get_tree_leaves(self, tree_num: int) -> np.array: Parameters ---------- - tree_num : :obj:`float` or :obj:`np.array` + tree_num : float or np.array Index of the tree for which leaf indices will be retrieved + + Returns + ------- + np.array + One-dimensional numpy array, containing the indices of leaf nodes in a given tree. """ return self.forest_cpp.GetTreeLeaves(tree_num) @@ -537,10 +864,16 @@ def get_tree_split_counts(self, tree_num: int, num_features: int) -> np.array: Parameters ---------- - tree_num : :obj:`int` + tree_num : int Index of the tree for which split counts will be retrieved - num_features : :obj:`int` + num_features : int Total number of features in the training set + + Returns + ------- + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the split count for each feature for a given tree of the forest. """ return self.forest_cpp.GetTreeSplitCounts(tree_num, num_features) @@ -550,8 +883,14 @@ def get_overall_split_counts(self, num_features: int) -> np.array: Parameters ---------- - num_features : :obj:`int` + num_features : int Total number of features in the training set + + Returns + ------- + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the overall split count in the forest for each feature. """ return self.forest_cpp.GetOverallSplitCounts(num_features) @@ -561,8 +900,14 @@ def get_granular_split_counts(self, num_features: int) -> np.array: Parameters ---------- - num_features : :obj:`int` + num_features : int Total number of features in the training set + + Returns + ------- + np.array + One-dimensional numpy array with as many elements as in the forest model's training set, + containing the split count for each feature for a every tree in the forest. 
""" return self.forest_cpp.GetGranularSplitCounts(num_features) @@ -572,7 +917,7 @@ def num_forest_leaves(self) -> int: Returns ------- - :obj:`int` + int Number of leaves in a forest """ return self.forest_cpp.NumLeavesForest() @@ -583,7 +928,7 @@ def sum_leaves_squared(self) -> float: Returns ------- - :obj:`float` + float Sum of squared leaf values in a forest """ return self.forest_cpp.SumLeafSquared() @@ -592,10 +937,15 @@ def is_leaf_node(self, tree_num: int, node_id: int) -> bool: """ Whether or not a given node of a given tree of a forest is a leaf - tree_num : :obj:`int` + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` is a leaf, `False` otherwise """ return self.forest_cpp.IsLeafNode(tree_num, node_id) @@ -603,10 +953,17 @@ def is_numeric_split_node(self, tree_num: int, node_id: int) -> bool: """ Whether or not a given node of a given tree of a forest is a numeric split node - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` is a numeric split node, `False` otherwise """ return self.forest_cpp.IsNumericSplitNode(tree_num, node_id) @@ -614,10 +971,17 @@ def is_categorical_split_node(self, tree_num: int, node_id: int) -> bool: """ Whether or not a given node of a given tree of a forest is a categorical split node - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + bool + `True` if node `node_id` in tree `tree_num` is a categorical split node, `False` otherwise """ return self.forest_cpp.IsCategoricalSplitNode(tree_num, node_id) @@ -625,10 +989,18 @@ def parent_node(self, tree_num: int, node_id: int) -> int: """ Parent node of given node of a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the parent of node `node_id` in tree `tree_num`. + If `node_id` is a root node, returns `-1`. """ return self.forest_cpp.ParentNode(tree_num, node_id) @@ -636,10 +1008,18 @@ def left_child_node(self, tree_num: int, node_id: int) -> int: """ Left child node of given node of a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the left child of node `node_id` in tree `tree_num`. + If `node_id` is a leaf, returns `-1`. """ return self.forest_cpp.LeftChildNode(tree_num, node_id) @@ -647,10 +1027,18 @@ def right_child_node(self, tree_num: int, node_id: int) -> int: """ Right child node of given node of a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Index of the right child of node `node_id` in tree `tree_num`. + If `node_id` is a leaf, returns `-1`. 
""" return self.forest_cpp.RightChildNode(tree_num, node_id) @@ -659,10 +1047,17 @@ def node_depth(self, tree_num: int, node_id: int) -> int: Depth of given node of a given tree of a forest Returns ``-1`` if the node is a leaf. - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Depth of node `node_id` in tree `tree_num`. The root node is defined as "depth zero." """ return self.forest_cpp.NodeDepth(tree_num, node_id) @@ -671,10 +1066,17 @@ def node_split_index(self, tree_num: int, node_id: int) -> int: Split index of given node of a given tree of a forest. Returns ``-1`` if the node is a leaf. - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + int + Split index of `node_id` in tree `tree_num`. """ if self.is_leaf_node(tree_num, node_id): return -1 @@ -686,10 +1088,17 @@ def node_split_threshold(self, tree_num: int, node_id: int) -> float: Threshold that defines a numeric split for a given node of a given tree of a forest. Returns ``np.Inf`` if the node is a leaf or a categorical split node. - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + float + Threshold that defines a numeric split for node `node_id` in tree `tree_num`. """ if self.is_leaf_node(tree_num, node_id) or self.is_categorical_split_node(tree_num, node_id): return np.Inf @@ -701,10 +1110,17 @@ def node_split_categories(self, tree_num: int, node_id: int) -> np.array: Array of category indices that define a categorical split for a given node of a given tree of a forest. Returns ``np.array([np.Inf])`` if the node is a leaf or a numeric split node. - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + np.array + Array of category indices that define a categorical split for node `node_id` in tree `tree_num`. """ if self.is_leaf_node(tree_num, node_id) or self.is_numeric_split_node(tree_num, node_id): return np.array([np.Inf]) @@ -716,10 +1132,17 @@ def node_leaf_values(self, tree_num: int, node_id: int) -> np.array: Leaf node value(s) for a given node of a given tree of a forest. Values are stale if the node is a split node. - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried - node_id : :obj:`int` + node_id : int Index of the node to be queried + + Returns + ------- + np.array + Array of parameter values for node `node_id` in tree `tree_num`. """ return self.forest_cpp.NodeLeafValues(tree_num, node_id) @@ -727,8 +1150,15 @@ def num_nodes(self, tree_num: int) -> int: """ Number of nodes in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of nodes in tree `tree_num`. """ return self.forest_cpp.NumNodes(tree_num) @@ -736,8 +1166,15 @@ def num_leaves(self, tree_num: int) -> int: """ Number of leaves in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of leaves in tree `tree_num`. 
""" return self.forest_cpp.NumLeaves(tree_num) @@ -745,8 +1182,15 @@ def num_leaf_parents(self, tree_num: int) -> int: """ Number of leaf parents in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of leaf parents in tree `tree_num`. """ return self.forest_cpp.NumLeafParents(tree_num) @@ -754,8 +1198,15 @@ def num_split_nodes(self, tree_num: int) -> int: """ Number of split_nodes in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + int + Total number of split nodes in tree `tree_num`. """ return self.forest_cpp.NumSplitNodes(tree_num) @@ -763,8 +1214,15 @@ def nodes(self, tree_num: int) -> np.array: """ Array of node indices in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + np.array + Array of indices of nodes in tree `tree_num`. """ return self.forest_cpp.Nodes(tree_num) @@ -772,7 +1230,14 @@ def leaves(self, tree_num: int) -> np.array: """ Array of leaf indices in a given tree of a forest - tree_num : :obj:`int` + Parameters + ---------- + tree_num : int Index of the tree to be queried + + Returns + ------- + np.array + Array of indices of leaf nodes in tree `tree_num`. """ return self.forest_cpp.Leaves(tree_num) diff --git a/stochtree/preprocessing.py b/stochtree/preprocessing.py index 7a796198..a586afd8 100644 --- a/stochtree/preprocessing.py +++ b/stochtree/preprocessing.py @@ -5,7 +5,6 @@ """ from typing import Union, Optional, Any, Dict from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder -from sklearn.utils.validation import check_array, column_or_1d import numpy as np import pandas as pd import warnings @@ -131,9 +130,10 @@ def _preprocess_bcf_params(params: Optional[Dict[str, Any]] = None) -> Dict[str, class CovariateTransformer: - """Class that transforms covariates to a format that can be used to define tree splits """ - + Class that transforms covariates to a format that can be used to define tree splits. + Modeled after the [scikit-learn preprocessing classes](https://scikit-learn.org/1.5/modules/preprocessing.html). + """ def __init__(self) -> None: self._is_fitted = False self._ordinal_encoders = [] @@ -346,50 +346,88 @@ def _check_is_fitted(self) -> bool: return self._is_fitted def fit(self, covariates: Union[pd.DataFrame, np.array]) -> None: - """Fits a ``CovariateTransformer`` by unpacking (and storing) data type information on the input (raw) covariates + r"""Fits a `CovariateTransformer` by unpacking (and storing) data type information on the input (raw) covariates and then converting to a numpy array which can be passed to a tree ensemble sampler. - If ``covariates`` is a ``pd.DataFrame``, `column dtypes `_ + If `covariates` is a `pd.DataFrame`, [column dtypes](https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes) will be handled as follows: - * ``category``: one-hot encoded if unordered, ordinal encoded if ordered - * ``string``: one-hot encoded - * ``boolean``: passed through as binary integer, treated as ordered categorical by tree samplers - * integer (i.e. ``Int8``, ``Int16``, etc...): passed through as double (**note**: if you have categorical data stored as integers, you should explicitly convert it to categorical in pandas, see this `user guide `_) - * float (i.e. 
``Float32``, ``Float64``): passed through as double - * ``object``: currently unsupported, convert object columns to numeric or categorical before passing - * Datetime (i.e. ``datetime64``): currently unsupported, though datetime columns can be converted to numeric features, see `here `_ - * Period (i.e. ``period[]``): currently unsupported, though period columns can be converted to numeric features, see `here `_ - * Interval (i.e. ``interval``, ``Interval[datetime64[ns]]``): currently unsupported, though interval columns can be converted to numeric or categorical features, see `here `_ - * Sparse (i.e. ``Sparse``, ``Sparse[float]``): currently unsupported, convert sparse columns to dense before passing + * `category`: one-hot encoded if unordered, ordinal encoded if ordered + * `string`: one-hot encoded + * `boolean`: passed through as binary integer, treated as ordered categorical by tree samplers + * integer (i.e. `Int8`, `Int16`, etc...): passed through as double (**note**: if you have categorical data stored as integers, you should explicitly convert it to categorical in pandas, see this [user guide](https://pandas.pydata.org/pandas-docs/stable/user_guide/categorical.html)) + * float (i.e. `Float32`, `Float64`): passed through as double + * `object`: currently unsupported, convert object columns to numeric or categorical before passing + * Datetime (i.e. `datetime64`): currently unsupported, though datetime columns can be converted to numeric features, see [here](https://pandas.pydata.org/docs/reference/api/pandas.Timestamp.html#pandas.Timestamp) + * Period (i.e. `period[]`): currently unsupported, though period columns can be converted to numeric features, see [here](https://pandas.pydata.org/docs/reference/api/pandas.Period.html#pandas.Period) + * Interval (i.e. `interval`, `Interval[datetime64[ns]]`): currently unsupported, though interval columns can be converted to numeric or categorical features, see [here](https://pandas.pydata.org/docs/reference/api/pandas.Interval.html#pandas.Interval) + * Sparse (i.e. `Sparse`, `Sparse[float]`): currently unsupported, convert sparse columns to dense before passing Columns with unsupported types will be ignored, with a warning. - If ``covariates`` is a ``np.array``, columns must be numeric and the only preprocessing done by ``CovariateTransformer.fit()`` is to + If `covariates` is a `np.array`, columns must be numeric and the only preprocessing done by `CovariateTransformer.fit()` is to auto-detect binary columns. All other integer-valued columns will be passed through to the tree sampler as (continuous) numeric data. If you would like to treat integer-valued data as categorical, you can either convert your numpy array to a pandas dataframe and - explicitly tag such columns as ordered / unordered categorical, or preprocess manually using ``sklearn.preprocessing.OneHotEncoder`` - and ``sklearn.preprocessing.OrdinalEncoder``. + explicitly tag such columns as ordered / unordered categorical, or preprocess manually using `sklearn.preprocessing.OneHotEncoder` + and `sklearn.preprocessing.OrdinalEncoder`. Parameters ---------- - covariates : :obj:`np.array` or :obj:`pd.DataFrame` + covariates : np.array or pd.DataFrame Covariates to be preprocessed. - - Returns - ------- - self : CovariateTransformer - Fitted CovariateTransformer. 
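A short sketch of the fit/transform workflow documented here and in the methods just below, using an invented mixed-type dataframe (and assuming `CovariateTransformer` is exported at the package root):

```python
import pandas as pd
from stochtree import CovariateTransformer

df = pd.DataFrame({
    "x1": [0.1, 2.3, 1.7, 0.4],                                  # numeric: passed through
    "x2": pd.Categorical(["a", "b", "a", "c"]),                  # unordered: one-hot encoded
    "x3": pd.Categorical(["lo", "hi", "lo", "hi"],
                         categories=["lo", "hi"], ordered=True), # ordered: ordinal encoded
})
transformer = CovariateTransformer()
X = transformer.fit_transform(df)
# Map each preprocessed column back to its source feature,
# e.g. [0, 1, 1, 1, 2] if x2 one-hot encodes into three columns
print(transformer.fetch_original_feature_indices())
```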
""" self._fit(covariates) return self def transform(self, covariates: Union[pd.DataFrame, np.array]) -> np.array: + r"""Run a fitted a `CovariateTransformer` on a new covariate set, + returning a numpy array of covariates preprocessed into a format needed + to sample or predict from a `stochtree` ensemble. + + Parameters + ---------- + covariates : np.array or pd.DataFrame + Covariates to be preprocessed. + + Returns + ------- + np.array + Numpy array of preprocessed covariates, with as many rows as in `covariates` + and as many columns as were created during pre-processing (including one-hot encoding + categorical features). + """ return self._transform(covariates) def fit_transform(self, covariates: Union[pd.DataFrame, np.array]) -> np.array: + r"""Runs the `fit()` and `transform()` methods in sequence. + + Parameters + ---------- + covariates : np.array or pd.DataFrame + Covariates to be preprocessed. + + Returns + ------- + np.array + Numpy array of preprocessed covariates, with as many rows as in `covariates` + and as many columns as were created during pre-processing (including one-hot encoding + categorical features). + """ self._fit(covariates) return self._transform(covariates) def fetch_original_feature_indices(self) -> list: + r"""Map features in a preprocessed covariate set back to the + original set of features provided to a `CovariateTransformer`. + + Returns + ------- + list + List with as many entries as features in the preprocessed results + returned by a fitted `CovariateTransformer`. Each element is a feature + index indicating the feature from which a given preprocessed feature was generated. + If a single categorical feature were one-hot encoded into 5 binary features, + this method would return a list `[0,0,0,0,0]`. If the transformer merely passes + through `k` numeric features, this method would return a list `[0,...,k-1]`. + """ return self._original_feature_indices diff --git a/stochtree/sampler.py b/stochtree/sampler.py index b384f258..a5876099 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -8,30 +8,63 @@ from typing import Union class RNG: - def __init__(self, random_seed: int) -> None: - # Initialize a ForestDatasetCpp object + """ + Wrapper around the C++ standard library random number generator. + Accepts an optional random seed at initialization for replicability. + + Parameters + ---------- + random_seed : int, optional + Random seed for replicability. If not specified, the default value of `-1` + triggers an initialization of the RNG based on + [std::random_device](https://en.cppreference.com/w/cpp/numeric/random/random_device). + """ + def __init__(self, random_seed: int = -1) -> None: self.rng_cpp = RngCpp(random_seed) class ForestSampler: + """ + Wrapper around many of the core C++ sampling data structures and algorithms. + + Parameters + ---------- + dataset : Dataset + `stochtree` dataset object storing covariates / bases / weights + feature_types : np.array + Array of integer-coded values indicating the column type of each feature in `dataset`. + Integer codes map `0` to "numeric" (continuous), `1` to "ordered categorical, and `2` to + "unordered categorical". + num_trees : int + Number of trees in the forest model that this sampler class will fit. + num_obs : int + Number of observations / "rows" in `dataset`. + alpha : float + Prior probability of splitting for a tree of depth 0 in a forest model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. 
+ beta : float + Exponent that decreases split probabilities for nodes of depth > 0 in a forest model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. + min_samples_leaf : int + Minimum allowable size of a leaf, in terms of training samples, in a forest model. + max_depth : int, optional + Maximum depth of any tree in the ensemble in a forest model. + """ def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, num_obs: int, alpha: float, beta: float, min_samples_leaf: int, max_depth: int = -1) -> None: - # Initialize a ForestDatasetCpp object self.forest_sampler_cpp = ForestSamplerCpp(dataset.dataset_cpp, feature_types, num_trees, num_obs, alpha, beta, min_samples_leaf, max_depth) def reconstitute_from_forest(self, forest: Forest, dataset: Dataset, residual: Residual, is_mean_model: bool) -> None: """ - Re-initialize a forest sampler tracking data structures from a specific forest in a ``ForestContainer`` + Re-initialize a forest sampler's tracking data structures from a specific forest in a `ForestContainer` Parameters ---------- - dataset : :obj:`Dataset` - Stochtree dataset object storing covariates / bases / weights - residual : :obj:`Residual` - Stochtree object storing continuously updated partial / full residual - forest : :obj:`Forest` - Stochtree object storing tree ensemble - is_mean_model : :obj:`bool` - Indicator of whether the model being updated a conditional mean model (``True``) or a conditional variance model (``False``) + dataset : Dataset + `stochtree` dataset object storing covariates / bases / weights + residual : Residual + `stochtree` object storing continuously updated partial / full residual + forest : Forest + `stochtree` object storing tree ensemble + is_mean_model : bool + Indicator of whether the model being updated is a conditional mean model (`True`) or a conditional variance model (`False`) """ self.forest_sampler_cpp.ReconstituteTrackerFromForest(forest.forest_cpp, dataset.dataset_cpp, residual.residual_cpp, is_mean_model) @@ -44,37 +77,37 @@ def sample_one_iteration(self, forest_container: ForestContainer, forest: Forest Parameters ---------- - forest_container : :obj:`ForestContainer` - Stochtree object storing tree ensembles - forest : :obj:`Forest` - Stochtree object storing the "active" forest being sampled - dataset : :obj:`Dataset` - Stochtree dataset object storing covariates / bases / weights - residual : :obj:`Residual` - Stochtree object storing continuously updated partial / full residual - rng : :obj:`RNG` - Stochtree object storing C++ random number generator to be used sampling algorithm - feature_types : :obj:`np.array` + forest_container : ForestContainer + `stochtree` object storing tree ensembles + forest : Forest + `stochtree` object storing the "active" forest being sampled + dataset : Dataset + `stochtree` dataset object storing covariates / bases / weights + residual : Residual + `stochtree` object storing continuously updated partial / full residual + rng : RNG + `stochtree` object storing C++ random number generator to be used by the sampling algorithm + feature_types : np.array Array of integer-coded feature types (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - cutpoint_grid_size : :obj:`int` + cutpoint_grid_size : int Maximum size of a grid of available cutpoints (which thins the number of possible splits, particularly useful in the grow-from-root algorithm) - leaf_model_scale_input : :obj:`np.array` + leaf_model_scale_input : np.array Numpy array containing leaf model
scale parameter (if the leaf model is univariate, this is essentially a scalar which is used as such in the C++ source, but stored as a numpy array) - variable_weights : :obj:`np.array` + variable_weights : np.array Numpy array containing sampling probabilities for each feature - a_forest : :obj:`float` - Scale parameter for the inverse gamma outcome model for heteroskedasticity forest - b_forest : :obj:`float` - Scale parameter for the inverse gamma outcome model for heteroskedasticity forest - global_variance : :obj:`float` + a_forest : float + Shape parameter for the inverse gamma outcome model for a heteroskedasticity forest + b_forest : float + Scale parameter for the inverse gamma outcome model for a heteroskedasticity forest + global_variance : float Current value of the global error variance parameter - leaf_model_int : :obj:`int` + leaf_model_int : int Integer encoding the leaf model type (0 = constant Gaussian leaf mean model, 1 = univariate Gaussian leaf regression mean model, 2 = multivariate Gaussian leaf regression mean model, 3 = univariate Inverse Gamma constant leaf variance model) - keep_forest : :obj:`bool` - Whether or not the resulting forest should be retained in ``forest_container`` or discarded (due to burnin or thinning for example) - gfr : :obj:`bool` - Whether or not the "grow-from-root" (GFR) sampler is run (if this is ``True`` and ``leaf_model_int=0`` this is equivalent to XBART, if this is ``FALSE`` and ``leaf_model_int=0`` this is equivalent to the original BART) - pre_initialized : :obj:`bool` + keep_forest : bool + Whether or not the resulting forest should be retained in `forest_container` or discarded (due to burnin or thinning for example) + gfr : bool + Whether or not the "grow-from-root" (GFR) sampler is run (if this is `True` and `leaf_model_int=0` this is equivalent to XBART, if this is `False` and `leaf_model_int=0` this is equivalent to the original BART) + pre_initialized : bool Whether or not the forest being sampled has already been initialized """ self.forest_sampler_cpp.SampleOneIteration(forest_container.forest_container_cpp, forest.forest_cpp, dataset.dataset_cpp, residual.residual_cpp, rng.rng_cpp, @@ -87,15 +120,15 @@ def prepare_for_sampler(self, dataset: Dataset, residual: Residual, forest: Fore Parameters ---------- - dataset : :obj:`Dataset` - Stochtree dataset object storing covariates / bases / weights - residual : :obj:`Residual` - Stochtree object storing continuously updated partial / full residual - forest : :obj:`Forest` - Stochtree object storing the "active" forest being sampled - leaf_model : :obj:`int` + dataset : Dataset + `stochtree` dataset object storing covariates / bases / weights + residual : Residual + `stochtree` object storing continuously updated partial / full residual + forest : Forest + `stochtree` object storing the "active" forest being sampled + leaf_model : int Integer encoding the leaf model type - initial_values : :obj:`np.array` + initial_values : np.array Constant root node value(s) at which to initialize forest prediction (internally, it is divided by the number of trees and typically it is 0 for mean models and 1 for variance models).
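For orientation, here is a heavily condensed sketch of the low-level loop these classes support, for a hypothetical homoskedastic constant-leaf model. Argument values are illustrative only, the `Dataset`/`Residual` construction is assumed from the package's data module, and the high-level BART/BCF interfaces wrap all of this; treat the documented signatures above as authoritative, not this sketch:

```python
import numpy as np
from stochtree import (Dataset, Residual, Forest, ForestContainer,
                       ForestSampler, RNG, GlobalVarianceModel)

n, p, num_trees = 100, 5, 50
rng_np = np.random.default_rng(2024)
X = rng_np.uniform(size=(n, p))
y = np.where(X[:, 0] > 0.5, 1.0, -1.0) + rng_np.normal(size=n)

dataset = Dataset()
dataset.add_covariates(X)               # assumed Dataset API
residual = Residual(y)                  # assumed Residual API
feature_types = np.zeros(p, dtype=int)  # all numeric
forest_container = ForestContainer(num_trees, 1, True, False)
active_forest = Forest(num_trees)
sampler = ForestSampler(dataset, feature_types, num_trees, n,
                        alpha=0.95, beta=2.0, min_samples_leaf=5)
cpp_rng = RNG(1234)
global_var_model = GlobalVarianceModel()

sampler.prepare_for_sampler(dataset, residual, active_forest,
                            leaf_model=0, initial_values=np.array([0.0]))
sigma2 = 1.0
for _ in range(100):
    # One MCMC sweep over the trees, then a Gibbs draw of the error variance
    sampler.sample_one_iteration(
        forest_container, active_forest, dataset, residual, cpp_rng,
        feature_types, cutpoint_grid_size=100,
        leaf_model_scale_input=np.array([[1.0 / num_trees]]),
        variable_weights=np.full(p, 1.0 / p),
        a_forest=0.0, b_forest=0.0, global_variance=sigma2,
        leaf_model_int=0, keep_forest=True, gfr=False, pre_initialized=True)
    sigma2 = global_var_model.sample_one_iteration(residual, cpp_rng, a=1.0, b=1.0)
```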
""" self.forest_sampler_cpp.InitializeForestModel(dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp, leaf_model, initial_values) @@ -109,16 +142,16 @@ def adjust_residual(self, dataset: Dataset, residual: Residual, forest: Forest, Parameters ---------- - dataset : :obj:`Dataset` - Stochtree dataset object storing covariates / bases / weights - residual : :obj:`Residual` - Stochtree object storing continuously updated partial / full residual - forest : :obj:`Forest` - Stochtree object storing the "active" forest being sampled - requires_basis : :obj:`bool` + dataset : Dataset + `stochtree` dataset object storing covariates / bases / weights + residual : Residual + `stochtree` object storing continuously updated partial / full residual + forest : Forest + `stochtree` object storing the "active" forest being sampled + requires_basis : bool Whether or not the forest requires a basis dot product when predicting - add : :obj:`bool` - Whether the predictions of each tree are added (if ``add=True``) or subtracted (``add=False``) from the outcome to form the new residual + add : bool + Whether the predictions of each tree are added (if `add=True`) or subtracted (`add=False`) from the outcome to form the new residual """ forest.forest_cpp.AdjustResidual(dataset.dataset_cpp, residual.residual_cpp, self.forest_sampler_cpp, requires_basis, add) @@ -133,38 +166,121 @@ def propagate_basis_update(self, dataset: Dataset, residual: Residual, forest: F Parameters ---------- - dataset : :obj:`Dataset` + dataset : Dataset Stochtree dataset object storing covariates / bases / weights - residual : :obj:`Residual` + residual : Residual Stochtree object storing continuously updated partial / full residual - forest : :obj:`Forest` + forest : Forest Stochtree object storing the "active" forest being sampled """ self.forest_sampler_cpp.PropagateBasisUpdate(dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp) - def propagate_residual_update(self, residual: Residual) -> None: - self.forest_sampler_cpp.PropagateResidualUpdate(residual.residual_cpp) + def update_alpha(self, alpha: float) -> None: + """ + Update `alpha` in the tree prior + + Parameters + ---------- + alpha : float + New value of `alpha` to be used + """ + self.forest_sampler_cpp.UpdateAlpha(alpha) + + def update_beta(self, beta: float) -> None: + """ + Update `beta` in the tree prior + + Parameters + ---------- + beta : float + New value of `beta` to be used + """ + self.forest_sampler_cpp.UpdateBeta(beta) + + def update_min_samples_leaf(self, min_samples_leaf: int) -> None: + """ + Update `min_samples_leaf` in the tree prior + + Parameters + ---------- + min_samples_leaf : int + New value of `min_samples_leaf` to be used + """ + self.forest_sampler_cpp.UpdateMinSamplesLeaf(min_samples_leaf) + + def update_max_depth(self, max_depth: int) -> None: + """ + Update `max_depth` in the tree prior + + Parameters + ---------- + max_depth : int + New value of `max_depth` to be used + """ + self.forest_sampler_cpp.UpdateMaxDepth(max_depth) class GlobalVarianceModel: + """ + Wrapper around methods / functions for sampling a "global" error variance model + with [inverse gamma](https://en.wikipedia.org/wiki/Inverse-gamma_distribution) prior. 
+ """ def __init__(self) -> None: - # Initialize a GlobalVarianceModelCpp object self.variance_model_cpp = GlobalVarianceModelCpp() def sample_one_iteration(self, residual: Residual, rng: RNG, a: float, b: float) -> float: """ Sample one iteration of a global error variance parameter + + Parameters + ---------- + residual : Residual + `stochtree` object storing continuously updated partial / full residual + rng : RNG + `stochtree` object storing C++ random number generator to be used sampling algorithm + a : float + Shape parameter for the inverse gamma error variance model + b : float + Scale parameter for the inverse gamma error variance model + + Returns + ------- + float + One draw from a Gibbs sampler for the error variance model, which depends + on the rest of the model only through the "full" residual stored in + a `Residual` object (net of predictions of any mean term such as a forest or + an additive parametric fixed / random effect term). """ return self.variance_model_cpp.SampleOneIteration(residual.residual_cpp, rng.rng_cpp, a, b) class LeafVarianceModel: + """ + Wrapper around methods / functions for sampling a "leaf scale" model for the variance term of a Gaussian + leaf model with [inverse gamma](https://en.wikipedia.org/wiki/Inverse-gamma_distribution) prior. + """ def __init__(self) -> None: - # Initialize a LeafVarianceModelCpp object self.variance_model_cpp = LeafVarianceModelCpp() def sample_one_iteration(self, forest: Forest, rng: RNG, a: float, b: float) -> float: """ - Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly ``N(0, tau)``) + Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly `N(0, tau)`) + + Parameters + ---------- + forest : Forest + `stochtree` object storing the "active" forest being sampled + rng : RNG + `stochtree` object storing C++ random number generator to be used sampling algorithm + a : float + Shape parameter for the inverse gamma leaf scale model + b : float + Scale parameter for the inverse gamma leaf scale model + + Returns + ------- + float + One draw from a Gibbs sampler for the leaf scale model, which depends + on the rest of the model only through its respective forest. 
""" return self.variance_model_cpp.SampleOneIteration(forest.forest_cpp, rng.rng_cpp, a, b) diff --git a/stochtree/serialization.py b/stochtree/serialization.py index 22ee1176..acbb9e85 100644 --- a/stochtree/serialization.py +++ b/stochtree/serialization.py @@ -7,9 +7,9 @@ from stochtree_cpp import JsonCpp class JSONSerializer: - """Class that handles serialization and deserialization of stochastic forest models """ - + Class that handles serialization and deserialization of stochastic forest models + """ def __init__(self) -> None: self.json_cpp = JsonCpp() self.num_forests = 0 @@ -21,18 +21,18 @@ def return_json_string(self) -> str: Returns ------- - :obj:`str` + str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ return self.json_cpp.DumpJson() def load_from_json_string(self, json_string: str) -> None: """ - Parse in-memory JSON string to ``JsonCpp`` object + Parse in-memory JSON string to `JsonCpp` object Parameters ------- - json_string : :obj:`str` + json_string : str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ self.json_cpp.LoadFromString(json_string) @@ -40,8 +40,10 @@ def load_from_json_string(self, json_string: str) -> None: def add_forest(self, forest_samples: ForestContainer) -> None: """Adds a container of forest samples to a json object - :param forest_samples: Samples of a tree ensemble - :type forest_samples: ForestContainer + Parameters + ---------- + forest_samples : ForestContainer + Samples of a tree ensemble """ forest_label = self.json_cpp.AddForest(forest_samples.forest_container_cpp) self.num_forests += 1 @@ -50,12 +52,14 @@ def add_forest(self, forest_samples: ForestContainer) -> None: def add_scalar(self, field_name: str, field_value: float, subfolder_name: str = None) -> None: """Adds a scalar (numeric) value to a json object - :param field_name: Name of the json field / label under which the numeric value will be stored - :type field_name: str - :param field_value: Numeric value to be stored - :type field_value: float - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the numeric value will be stored + field_value : float + Numeric value to be stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` to be stored in the json hierarchy """ if subfolder_name is None: self.json_cpp.AddDouble(field_name, field_value) @@ -65,12 +69,14 @@ def add_scalar(self, field_name: str, field_value: float, subfolder_name: str = def add_boolean(self, field_name: str, field_value: bool, subfolder_name: str = None) -> None: """Adds a scalar (boolean) value to a json object - :param field_name: Name of the json field / label under which the boolean value will be stored - :type field_name: str - :param field_value: Boolean value to be stored - :type field_value: bool - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the boolean value will be stored + field_value : bool + Boolean value to be stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` to be stored in the json hierarchy """ if subfolder_name is None: self.json_cpp.AddBool(field_name, 
field_value) @@ -80,12 +86,14 @@ def add_boolean(self, field_name: str, field_value: bool, subfolder_name: str = def add_string(self, field_name: str, field_value: str, subfolder_name: str = None) -> None: """Adds a string to a json object - :param field_name: Name of the json field / label under which the numeric value will be stored - :type field_name: str - :param field_value: String field to be stored - :type field_value: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the string will be stored + field_value : str + String field to be stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is to be stored in the json hierarchy """ if subfolder_name is None: self.json_cpp.AddString(field_name, field_value) @@ -95,12 +103,14 @@ def add_numeric_vector(self, field_name: str, field_vector: np.array, subfolder_name: str = None) -> None: """Adds a numeric vector (stored as a numpy array) to a json object - :param field_name: Name of the json field / label under which the numeric vector will be stored - :type field_name: str - :param field_vector: Numpy array containing the vector to be stored in json. Should be one-dimensional. - :type field_vector: np.array - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the numeric vector will be stored + field_vector : np.array + Numpy array containing the vector to be stored in json. Should be one-dimensional.
+ subfolder_name : str, optional + Name of "subfolder" under which `field_name` is to be stored in the json hierarchy """ # Runtime checks if not isinstance(field_vector, np.ndarray): @@ -118,12 +128,14 @@ def add_string_vector(self, field_name: str, field_vector: list, subfolder_name: str = None) -> None: """Adds a list of strings to a json object as an array - :param field_name: Name of the json field / label under which the string list will be stored - :type field_name: str - :param field_vector: Python list of strings containing the array to be stored in json - :type field_vector: list - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the string list will be stored + field_vector : list + Python list of strings containing the array to be stored in json + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is to be stored in the json hierarchy """ # Runtime checks if not isinstance(field_vector, list): @@ -137,10 +149,12 @@ def get_scalar(self, field_name: str, subfolder_name: str = None) -> float: """Retrieves a scalar (numeric) value from a json object - :param field_name: Name of the json field / label under which the numeric value is stored - :type field_name: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` is stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the numeric value is stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is stored in the json hierarchy """ if subfolder_name is None: return self.json_cpp.ExtractDouble(field_name) @@ -150,10 +164,12 @@ def get_boolean(self, field_name: str, subfolder_name: str = None) -> bool: """Retrieves a scalar (boolean) value from a json object - :param field_name: Name of the json field / label under which the boolean value is stored - :type field_name: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` is stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the boolean value is stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is stored in the json hierarchy """ if subfolder_name is None: return self.json_cpp.ExtractBool(field_name) @@ -163,10 +179,12 @@ def get_string(self, field_name: str, subfolder_name: str = None) -> str: """Retrieve a string to a json object - :param field_name: Name of the json field / label under which the numeric value is stored - :type field_name: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` is stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the string is stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is stored in the json hierarchy """ if subfolder_name is None: return
self.json_cpp.ExtractString(field_name) @@ -176,10 +194,12 @@ def get_numeric_vector(self, field_name: str, subfolder_name: str = None) -> np.array: """Adds a string to a json object - :param field_name: Name of the json field / label under which the numeric vector is stored - :type field_name: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the numeric vector is stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is stored in the json hierarchy """ if subfolder_name is None: return self.json_cpp.ExtractDoubleVector(field_name) @@ -189,18 +209,32 @@ def get_string_vector(self, field_name: str, subfolder_name: str = None) -> list: """Adds a string to a json object - :param field_name: Name of the json field / label under which the string list is stored - :type field_name: str - :param subfolder_name: Name of "subfolder" under which ``field_name`` to be stored in the json hierarchy - :type subfolder_name: str, optional + Parameters + ---------- + field_name : str + Name of the json field / label under which the string list is stored + subfolder_name : str, optional + Name of "subfolder" under which `field_name` is stored in the json hierarchy """ if subfolder_name is None: return self.json_cpp.ExtractStringVector(field_name) else: return self.json_cpp.ExtractStringVectorSubfolder(subfolder_name, field_name) - def get_forest_container(self, forest_label: str) -> ForestContainer: + def get_forest_container(self, forest_str: str) -> ForestContainer: + """Reconstructs a `ForestContainer` object from forest samples stored in this json object under a given label. + + Parameters + ---------- + forest_str : str + Label under which a forest container is stored in the json object + + Returns + ------- + ForestContainer + In-memory `ForestContainer` python object, created from its json representation + """ # TODO: read this from JSON result = ForestContainer(0, 1, True, False) - result.forest_container_cpp.LoadFromJson(self.json_cpp, forest_label) + result.forest_container_cpp.LoadFromJson(self.json_cpp, forest_str) return result
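Putting the serialization pieces together, a round-trip sketch; the `"forest_0"` label is an assumption here, mirroring the label counter that `add_forest` maintains, and `forest_container` is a hypothetical fitted container:

```python
from stochtree import JSONSerializer

serializer = JSONSerializer()
serializer.add_forest(forest_container)  # stored under a "forest_<num>" label (assumed)
serializer.add_scalar("sigma2", 1.0, subfolder_name="parameters")
json_string = serializer.return_json_string()

restored = JSONSerializer()
restored.load_from_json_string(json_string)
restored_container = restored.get_forest_container("forest_0")
print(restored.get_scalar("sigma2", subfolder_name="parameters"))
```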