diff --git a/R/config.R b/R/config.R index dc156f5e..a5982cac 100644 --- a/R/config.R +++ b/R/config.R @@ -17,6 +17,9 @@ ForestModelConfig <- R6::R6Class( #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) feature_types = NULL, + #' @field sweep_update_indices Vector of trees to update in a sweep + sweep_update_indices = NULL, + #' @field num_trees Number of trees in the forest being sampled num_trees = NULL, @@ -62,6 +65,7 @@ ForestModelConfig <- R6::R6Class( #' Create a new ForestModelConfig object. #' #' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep #' @param num_trees Number of trees in the forest being sampled #' @param num_features Number of features in training dataset #' @param num_observations Number of observations in training dataset @@ -78,7 +82,7 @@ ForestModelConfig <- R6::R6Class( #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) #' #' @return A new ForestModelConfig object. - initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL, + initialize = function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, @@ -101,6 +105,10 @@ ForestModelConfig <- R6::R6Class( if (is.null(num_trees)) { stop("num_trees must be provided") } + if (!is.null(sweep_update_indices)) { + stopifnot(min(sweep_update_indices) >= 0) + stopifnot(max(sweep_update_indices) < num_trees) + } if (is.null(num_observations)) { stop("num_observations must be provided") } @@ -111,6 +119,7 @@ ForestModelConfig <- R6::R6Class( stop("`variable_weights` must have `num_features` total elements") } self$feature_types <- feature_types + self$sweep_update_indices <- sweep_update_indices self$variable_weights <- variable_weights self$num_trees <- num_trees self$num_features <- num_features @@ -158,6 +167,17 @@ ForestModelConfig <- R6::R6Class( self$feature_types <- feature_types }, + #' @description + #' Update sweep update indices + #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep + update_sweep_indices = function(sweep_update_indices) { + if (!is.null(sweep_update_indices)) { + stopifnot(min(sweep_update_indices) >= 0) + stopifnot(max(sweep_update_indices) < self$num_trees) + } + self$sweep_update_indices <- sweep_update_indices + }, + #' @description #' Update variable weights #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset @@ -242,6 +262,13 @@ ForestModelConfig <- R6::R6Class( return(self$feature_types) }, + #' @description + #' Query sweep update indices for this ForestModelConfig object + #' @returns Vector of (0-indexed) indices of trees to update in a sweep + get_sweep_indices = function() { + return(self$sweep_update_indices) + }, + #' @description #' Query variable weights for this ForestModelConfig object #' @returns Vector specifying sampling probability for all p covariates in ForestDataset @@ -382,6 +409,7 @@ GlobalModelConfig <- R6::R6Class( #' Create a forest model config object #' #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) +#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep #' @param num_trees Number of trees in the forest being sampled #' @param num_features Number of features in training dataset #' @param num_observations Number of observations in training dataset @@ -401,13 +429,13 @@ GlobalModelConfig <- R6::R6Class( #' #' @examples #' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) -createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL, +createForestModelConfig <- function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, variance_forest_scale = 1.0, cutpoint_grid_size = 100){ return(invisible(( - ForestModelConfig$new(feature_types, num_trees, num_features, num_observations, + ForestModelConfig$new(feature_types, sweep_update_indices, num_trees, num_features, num_observations, variable_weights, leaf_dimension, alpha, beta, min_samples_leaf, max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape, variance_forest_scale, cutpoint_grid_size) diff --git a/R/cpp11.R b/R/cpp11.R index c2af5982..2260cad6 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -580,12 +580,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) { diff --git a/R/model.R b/R/model.R index aa33cd7d..dfa88fc9 100644 --- a/R/model.R +++ b/R/model.R @@ -70,13 +70,15 @@ ForestModel <- R6::R6Class( #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, - rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE) { + rng, forest_model_config, global_model_config, + keep_forest = TRUE, gfr = TRUE) { if (active_forest$is_empty()) { stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") } # Unpack parameters from model config object feature_types <- forest_model_config$feature_types + sweep_update_indices <- forest_model_config$sweep_update_indices leaf_model_int <- forest_model_config$leaf_model_type leaf_model_scale <- forest_model_config$leaf_model_scale variable_weights <- forest_model_config$variable_weights @@ -85,6 +87,12 @@ ForestModel <- R6::R6Class( global_scale <- global_model_config$global_error_variance cutpoint_grid_size <- forest_model_config$cutpoint_grid_size + # Default to empty integer vector if sweep_update_indices is NULL + if (is.null(sweep_update_indices)) { + # sweep_update_indices <- integer(0) + sweep_update_indices <- 0:(forest_model_config$num_trees - 1) + } + # Detect changes to tree prior if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) { update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha) @@ -104,14 +112,14 @@ ForestModel <- R6::R6Class( sample_gfr_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, - self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, + self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale, variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } else { sample_mcmc_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, - self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, + self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale, variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index 39a06ad0..f31e8739 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -669,6 +669,10 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia // Prepare the samplers LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); + // Initialize vector of sweep update indices + std::vector sweep_indices(num_trees); + std::iota(sweep_indices.begin(), sweep_indices.end(), 0); + // Run the GFR sampler if (num_gfr > 0) { for (int i = 0; i < num_gfr; i++) { @@ -683,13 +687,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia // Sample tree ensemble if (model_type == ModelType::kConstantLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true); } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true); } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols); } else if (model_type == ModelType::kLogLinearVariance) { - GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, false); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false); } if (rfx_included) { @@ -720,13 +724,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia // Sample tree ensemble if (model_type == ModelType::kConstantLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true); } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true); } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true, omega_cols); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, omega_cols); } else if (model_type == ModelType::kLogLinearVariance) { - MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, false); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, false); } if (rfx_included) { diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index a47660ea..5c234d23 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -751,6 +751,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore * \param tree_prior Configuration for tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability). * \param gen Random number generator for sampler. * \param variable_weights Vector of selection weights for each variable in `dataset`. + * \param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep. * \param global_variance Current value of (possibly stochastic) global error variance parameter. * \param feature_types Enum-coded vector of feature types (see \ref FeatureType) for each feature in `dataset`. * \param cutpoint_grid_size Maximum size of a grid of potential cutpoints (the grow-from-root algorithm evaluates a series of potential cutpoints for each feature and this parameter "thins" the cutpoint candidates for numeric variables). @@ -763,12 +764,11 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore template static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, std::vector& feature_types, int cutpoint_grid_size, + std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { - // Run the GFR algorithm for each tree int num_trees = forests.NumTrees(); - for (int i = 0; i < num_trees; i++) { + for (const int& i : sweep_update_indices) { // Adjust any model state needed to run a tree sampler // For models that involve Bayesian backfitting, this amounts to adding tree i's // predictions back to the residual (thus, training a model on the "partial residual") @@ -1062,6 +1062,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For * \param tree_prior Configuration for tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability). * \param gen Random number generator for sampler. * \param variable_weights Vector of selection weights for each variable in `dataset`. + * \param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep. * \param global_variance Current value of (possibly stochastic) global error variance parameter. * \param keep_forest Whether or not `active_forest` should be retained in `forests`. * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon). @@ -1072,10 +1073,10 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For template static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); - for (int i = 0; i < num_trees; i++) { + for (const int& i : sweep_update_indices) { // Adjust any model state needed to run a tree sampler // For models that involve Bayesian backfitting, this amounts to adding tree i's // predictions back to the residual (thus, training a model on the "partial residual") diff --git a/man/ForestModelConfig.Rd b/man/ForestModelConfig.Rd index e75a6cd8..5fe5c755 100644 --- a/man/ForestModelConfig.Rd +++ b/man/ForestModelConfig.Rd @@ -7,6 +7,8 @@ for a forest model in the "low-level" stochtree interface} \value{ Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) +Vector of (0-indexed) indices of trees to update in a sweep + Vector specifying sampling probability for all p covariates in ForestDataset Number of trees in a forest @@ -46,6 +48,8 @@ forest model they wish to run. \describe{ \item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +\item{\code{sweep_update_indices}}{Vector of trees to update in a sweep} + \item{\code{num_trees}}{Number of trees in the forest being sampled} \item{\code{num_features}}{Number of features in training dataset} @@ -82,6 +86,7 @@ Create a new ForestModelConfig object.} \itemize{ \item \href{#method-ForestModelConfig-new}{\code{ForestModelConfig$new()}} \item \href{#method-ForestModelConfig-update_feature_types}{\code{ForestModelConfig$update_feature_types()}} +\item \href{#method-ForestModelConfig-update_sweep_indices}{\code{ForestModelConfig$update_sweep_indices()}} \item \href{#method-ForestModelConfig-update_variable_weights}{\code{ForestModelConfig$update_variable_weights()}} \item \href{#method-ForestModelConfig-update_alpha}{\code{ForestModelConfig$update_alpha()}} \item \href{#method-ForestModelConfig-update_beta}{\code{ForestModelConfig$update_beta()}} @@ -92,6 +97,7 @@ Create a new ForestModelConfig object.} \item \href{#method-ForestModelConfig-update_variance_forest_scale}{\code{ForestModelConfig$update_variance_forest_scale()}} \item \href{#method-ForestModelConfig-update_cutpoint_grid_size}{\code{ForestModelConfig$update_cutpoint_grid_size()}} \item \href{#method-ForestModelConfig-get_feature_types}{\code{ForestModelConfig$get_feature_types()}} +\item \href{#method-ForestModelConfig-get_sweep_indices}{\code{ForestModelConfig$get_sweep_indices()}} \item \href{#method-ForestModelConfig-get_variable_weights}{\code{ForestModelConfig$get_variable_weights()}} \item \href{#method-ForestModelConfig-get_num_trees}{\code{ForestModelConfig$get_num_trees()}} \item \href{#method-ForestModelConfig-get_num_features}{\code{ForestModelConfig$get_num_features()}} @@ -114,6 +120,7 @@ Create a new ForestModelConfig object.} \subsection{Usage}{ \if{html}{\out{
}}\preformatted{ForestModelConfig$new( feature_types = NULL, + sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, num_observations = NULL, @@ -136,6 +143,8 @@ Create a new ForestModelConfig object.} \describe{ \item{\code{feature_types}}{Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +\item{\code{sweep_update_indices}}{Vector of (0-indexed) indices of trees to update in a sweep} + \item{\code{num_trees}}{Number of trees in the forest being sampled} \item{\code{num_features}}{Number of features in training dataset} @@ -188,6 +197,23 @@ Update feature types } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_sweep_indices}{}}} +\subsection{Method \code{update_sweep_indices()}}{ +Update sweep update indices +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_sweep_indices(sweep_update_indices)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{sweep_update_indices}}{Vector of (0-indexed) indices of trees to update in a sweep} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variable_weights}{}}} \subsection{Method \code{update_variable_weights()}}{ @@ -349,6 +375,16 @@ Query feature types for this ForestModelConfig object \if{html}{\out{
}}\preformatted{ForestModelConfig$get_feature_types()}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_sweep_indices}{}}} +\subsection{Method \code{get_sweep_indices()}}{ +Query sweep update indices for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_sweep_indices()}\if{html}{\out{
}} +} + } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/createForestModelConfig.Rd b/man/createForestModelConfig.Rd index 90de767c..235552aa 100644 --- a/man/createForestModelConfig.Rd +++ b/man/createForestModelConfig.Rd @@ -6,6 +6,7 @@ \usage{ createForestModelConfig( feature_types = NULL, + sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, num_observations = NULL, @@ -25,6 +26,8 @@ createForestModelConfig( \arguments{ \item{feature_types}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +\item{sweep_update_indices}{Vector of (0-indexed) indices of trees to update in a sweep} + \item{num_trees}{Number of trees in the forest being sampled} \item{num_features}{Number of features in training dataset} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index b047e080..d32eba1f 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1076,18 +1076,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } @@ -1669,8 +1669,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 16}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 16}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 17}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 471fedba..916bffa9 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1026,7 +1026,7 @@ class ForestSamplerCpp { } void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, - py::array_t feature_types, int cutpoint_grid_size, py::array_t leaf_model_scale_input, + py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, py::array_t variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest = true, bool gfr = true) { // Refactoring completely out of the Python interface. @@ -1038,6 +1038,15 @@ class ForestSamplerCpp { for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types.at(i)); } + + // Unpack sweep indices + std::vector sweep_update_indices_; + if (sweep_update_indices.size() > 0) { + sweep_update_indices_.resize(sweep_update_indices.size()); + for (int i = 0; i < sweep_update_indices.size(); i++) { + sweep_update_indices_[i] = sweep_update_indices.at(i); + } + } // Convert leaf model type to enum StochTree::ModelType model_type; @@ -1081,23 +1090,23 @@ class ForestSamplerCpp { std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false); } } } diff --git a/src/sampler.cpp b/src/sampler.cpp index 8890237a..514ae006 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -19,6 +19,7 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, + cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, @@ -36,6 +37,15 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer(feature_types[i]); } + // Unpack sweep indices + std::vector sweep_indices_(sweep_indices.size()); + // if (sweep_indices.size() > 0) { + // sweep_indices_.resize(sweep_indices.size()); + for (int i = 0; i < sweep_indices.size(); i++) { + sweep_indices_[i] = sweep_indices[i]; + } + // } + // Convert leaf model type to enum StochTree::ModelType model_type; if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; @@ -73,13 +83,13 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false); } } @@ -91,6 +101,7 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, + cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, @@ -108,6 +119,15 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(feature_types[i]); } + // Unpack sweep indices + std::vector sweep_indices_; + if (sweep_indices.size() > 0) { + sweep_indices_.resize(sweep_indices.size()); + for (int i = 0; i < sweep_indices.size(); i++) { + sweep_indices_[i] = sweep_indices[i]; + } + } + // Convert leaf model type to enum StochTree::ModelType model_type; if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; @@ -145,13 +165,13 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); } } diff --git a/stochtree/config.py b/stochtree/config.py index bbeb639a..6cf57cc5 100644 --- a/stochtree/config.py +++ b/stochtree/config.py @@ -29,6 +29,8 @@ class ForestModelConfig: Number of observations in training dataset feature_types : np.array or list, optional Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + sweep_update_indices : np.array or list, optional + Vector of (0-indexed) indices of trees to update in a sweep variable_weights : np.array or list, optional Vector specifying sampling probability for all p covariates in ForestDataset leaf_dimension : int, optional @@ -59,6 +61,7 @@ def __init__( num_features=None, num_observations=None, feature_types=None, + sweep_update_indices=None, variable_weights=None, leaf_dimension=1, alpha=0.95, @@ -135,6 +138,17 @@ def __init__( raise ValueError( "`leaf_model_scale` must be a scalar value or a 2d numpy array with matching dimensions" ) + if sweep_update_indices is not None: + sweep_update_indices = _standardize_array_to_np(sweep_update_indices) + if np.min(sweep_update_indices) < 0: + raise ValueError( + "sweep_update_indices must be a list / np.array of indices >= 0 and < num_trees", + ) + if np.max(sweep_update_indices) >= num_trees: + raise ValueError( + "sweep_update_indices must be a list / np.array of indices >= 0 and < num_trees", + ) + self.sweep_update_indices = sweep_update_indices # Set internal config values self.num_trees = num_trees @@ -157,7 +171,7 @@ def update_feature_types(self, feature_types) -> None: Parameters ---------- - feature_types : list of np.ndarray + feature_types : list or np.ndarray Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) Returns @@ -169,6 +183,31 @@ def update_feature_types(self, feature_types) -> None: raise ValueError("`feature_types` must have `num_features` total elements") self.feature_types = feature_types + def update_sweep_indices(self, sweep_update_indices) -> None: + """ + Update feature types + + Parameters + ---------- + sweep_update_indices : list or np.ndarray + Vector of (0-indexed) indices of trees to update in a sweep + + Returns + ------- + self + """ + if sweep_update_indices is not None: + sweep_update_indices = _standardize_array_to_np(sweep_update_indices) + if np.min(sweep_update_indices) < 0: + raise ValueError( + "sweep_update_indices must be a list / np.array of indices >= 0 and < num_trees", + ) + if np.max(sweep_update_indices) >= self.num_trees: + raise ValueError( + "sweep_update_indices must be a list / np.array of indices >= 0 and < num_trees", + ) + self.sweep_update_indices = sweep_update_indices + def update_variable_weights( self, variable_weights: Union[list, np.ndarray] ) -> None: @@ -344,6 +383,17 @@ def get_feature_types(self) -> np.ndarray: """ return self.feature_types + def get_sweep_update_indices(self) -> Union[np.ndarray,None]: + """ + Query vector of (0-indexed) indices of trees to update in a sweep + + Returns + ------- + sweep_update_indices : np.ndarray or None + Vector of (0-indexed) indices of trees to update in a sweep, or `None` + """ + return self.feature_types + def get_variable_weights(self) -> np.ndarray: """ Query variable weights diff --git a/stochtree/sampler.py b/stochtree/sampler.py index ff6a371c..3586c24a 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -148,6 +148,11 @@ def sample_one_iteration( ) if self.forest_sampler_cpp.GetMaxDepth() != forest_config.get_max_depth(): self.forest_sampler_cpp.SetMaxDepth(forest_config.get_max_depth()) + + # Unpack sweep update indices (initializing empty numpy array if None) + sweep_update_indices = forest_config.get_sweep_update_indices() + if sweep_update_indices is None: + sweep_update_indices = np.arange(forest_config.get_num_trees(), dtype=int) # Run the sampler self.forest_sampler_cpp.SampleOneIteration( @@ -157,6 +162,7 @@ def sample_one_iteration( residual.residual_cpp, rng.rng_cpp, forest_config.get_feature_types(), + sweep_update_indices, forest_config.get_cutpoint_grid_size(), forest_config.get_leaf_model_scale(), forest_config.get_variable_weights(), diff --git a/tools/debug/restricted_sweep.R b/tools/debug/restricted_sweep.R new file mode 100644 index 00000000..6ec85c5b --- /dev/null +++ b/tools/debug/restricted_sweep.R @@ -0,0 +1,167 @@ +# Load library +library(stochtree) + +# Simulate a simple partitioned linear model +n <- 500 +p_X <- 10 +p_W <- 1 +X <- matrix(runif(n*p_X), ncol = p_X) +W <- matrix(runif(n*p_W), ncol = p_W) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-3*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-1*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (1*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (3*W[,1]) +) +y <- f_XW + rnorm(n, 0, 1) + +# Standardize outcome +y_bar <- mean(y) +y_std <- sd(y) +resid <- (y-y_bar)/y_std + +## Sampling + +# Set some parameters that inform the forest and variance parameter samplers +alpha <- 0.9 +beta <- 1.25 +min_samples_leaf <- 1 +max_depth <- 10 +num_trees <- 100 +cutpoint_grid_size <- 100 +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) +nu <- 4 +lambda <- 0.5 +a_leaf <- 2. +b_leaf <- 0.5 +leaf_regression <- T +feature_types <- as.integer(rep(0, p_X)) # 0 = numeric +var_weights <- rep(1/p_X, p_X) + +# Initialize R-level access to the C++ classes needed to sample our model + +# Data +if (leaf_regression) { + forest_dataset <- createForestDataset(X, W) + outcome_model_type <- 1 + leaf_dimension <- p_W +} else { + forest_dataset <- createForestDataset(X) + outcome_model_type <- 0 + leaf_dimension <- 1 +} +outcome <- createOutcome(resid) + +# Random number generator (std::mt19937) +rng <- createCppRNG() + +# Sampling data structures +forest_model_config <- createForestModelConfig( + feature_types = feature_types, sweep_update_indices = 0:9, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +if (leaf_regression) { + forest_samples <- createForestSamples(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) +} else { + forest_samples <- createForestSamples(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) +} + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) + +#Prepare to run the sampler +num_warmstart <- 10 +num_mcmc <- 100 +num_samples <- num_warmstart + num_mcmc +global_var_samples <- c(global_variance_init, rep(0, num_samples)) +leaf_scale_samples <- c(tau_init, rep(0, num_samples)) + +# Run the grow-from-root sampler to "warm-start" BART +for (i in 1:num_warmstart) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T + ) + + # Sample global variance parameter + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset, rng, nu, lambda + ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) + + # Sample leaf node variance parameter and update `leaf_prior_scale` + leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( + active_forest, rng, a_leaf, b_leaf + ) + leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) +} + +# Check the predictions of each tree +forest_samples$predict_raw_single_tree(forest_dataset, num_warmstart-1, 0) +forest_samples$predict_raw_single_tree(forest_dataset, num_warmstart-1, 1) +forest_samples$predict_raw_single_tree(forest_dataset, num_warmstart-1, 9) +forest_samples$predict_raw_single_tree(forest_dataset, num_warmstart-1, 10) +forest_samples$predict_raw_single_tree(forest_dataset, num_warmstart-1, num_trees-1) + +# Update sweep indices +# forest_model_config$update_sweep_indices(0:(num_trees-1)) +forest_model_config$update_sweep_indices(NULL) + +# Pick up from the last GFR forest (and associated global variance / leaf +# scale parameters) with an MCMC sampler +for (i in (num_warmstart+1):num_samples) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F + ) + + # Sample global variance parameter + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset, rng, nu, lambda + ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) + + # Sample leaf node variance parameter and update `leaf_prior_scale` + leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( + active_forest, rng, a_leaf, b_leaf + ) + leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) +} + +# Check the predictions of each tree +forest_samples$predict_raw_single_tree(forest_dataset, num_samples-1, 0) +forest_samples$predict_raw_single_tree(forest_dataset, num_samples-1, 1) +forest_samples$predict_raw_single_tree(forest_dataset, num_samples-1, 9) +forest_samples$predict_raw_single_tree(forest_dataset, num_samples-1, 10) +forest_samples$predict_raw_single_tree(forest_dataset, num_samples-1, num_trees-1) + +# Forest predictions +preds <- forest_samples$predict(forest_dataset)*y_std + y_bar + +# Global error variance +sigma_samples <- sqrt(global_var_samples)*y_std + +# Plot samples +# plot(rowMeans(preds), y); abline(0,1,col="red",lty=3,lwd=3) +# plot(sigma_samples); abline(h = 1,col="blue",lty=3,lwd=3) diff --git a/tools/perf/bart_microbenchmark.R b/tools/perf/bart_microbenchmark.R index 21e5e171..36b36bb9 100644 --- a/tools/perf/bart_microbenchmark.R +++ b/tools/perf/bart_microbenchmark.R @@ -2,14 +2,14 @@ library(microbenchmark) library(stochtree) # Generate data needed to train BART model -n <- 1000 -p <- 5 +n <- 10000 +p <- 20 X <- matrix(runif(n*p), ncol = p) f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 y <- f_XW + rnorm(n, 0, noise_sd) @@ -24,6 +24,7 @@ y_test <- y[test_inds] y_train <- y[train_inds] # Run microbenchmark -microbenchmark( - bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_mcmc = 1000) +bench_results <- microbenchmark( + bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_mcmc = 100), + times = 10 ) diff --git a/tools/perf/custom_loop_microbenchmark.R b/tools/perf/custom_loop_microbenchmark.R new file mode 100644 index 00000000..a2312a4f --- /dev/null +++ b/tools/perf/custom_loop_microbenchmark.R @@ -0,0 +1,160 @@ +# Load libraries +library(microbenchmark) +library(stochtree) + +# Simulate a simple partitioned linear model +n <- 10000 +p_X <- 20 +p_W <- 1 +X <- matrix(runif(n*p_X), ncol = p_X) +W <- matrix(runif(n*p_W), ncol = p_W) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-3*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-1*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (1*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (3*W[,1]) +) +y <- f_XW + rnorm(n, 0, 1) + +# Standardize outcome +y_bar <- mean(y) +y_std <- sd(y) +resid <- (y-y_bar)/y_std + +## Sampling + +# Set some parameters that inform the forest and variance parameter samplers +alpha <- 0.9 +beta <- 1.25 +min_samples_leaf <- 1 +max_depth <- 10 +num_trees <- 100 +cutpoint_grid_size <- 100 +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) +nu <- 4 +lambda <- 0.5 +a_leaf <- 2. +b_leaf <- 0.5 +leaf_regression <- T +feature_types <- as.integer(rep(0, p_X)) # 0 = numeric +var_weights <- rep(1/p_X, p_X) + +# Initialize R-level access to the C++ classes needed to sample our model + +# Data +if (leaf_regression) { + forest_dataset <- createForestDataset(X, W) + outcome_model_type <- 1 + leaf_dimension <- p_W +} else { + forest_dataset <- createForestDataset(X) + outcome_model_type <- 0 + leaf_dimension <- 1 +} +outcome <- createOutcome(resid) + +# Random number generator (std::mt19937) +rng <- createCppRNG() + +# Sampling data structures +forest_model_config <- createForestModelConfig( + feature_types = feature_types, sweep_update_indices = 0:9, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +if (leaf_regression) { + forest_samples <- createForestSamples(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) +} else { + forest_samples <- createForestSamples(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) +} + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) + +#Prepare to run the sampler +num_warmstart <- 10 +num_mcmc <- 100 +num_samples <- num_warmstart + num_mcmc +global_var_samples <- c(global_variance_init, rep(0, num_samples)) +leaf_scale_samples <- c(tau_init, rep(0, num_samples)) + +bench_results <- microbenchmark( +{ + + # Run the grow-from-root sampler to "warm-start" BART + for (i in 1:num_warmstart) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T + ) + + # Sample global variance parameter + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset, rng, nu, lambda + ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) + + # Sample leaf node variance parameter and update `leaf_prior_scale` + leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( + active_forest, rng, a_leaf, b_leaf + ) + leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) + } + + # Update sweep indices + # forest_model_config$update_sweep_indices(0:(num_trees-1)) + forest_model_config$update_sweep_indices(NULL) + + # Pick up from the last GFR forest (and associated global variance / leaf + # scale parameters) with an MCMC sampler + for (i in (num_warmstart+1):num_samples) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F + ) + + # Sample global variance parameter + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset, rng, nu, lambda + ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) + + # Sample leaf node variance parameter and update `leaf_prior_scale` + leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( + active_forest, rng, a_leaf, b_leaf + ) + leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) + } +}, +times = 5 +) + +# Forest predictions +preds <- forest_samples$predict(forest_dataset)*y_std + y_bar + +# Global error variance +sigma_samples <- sqrt(global_var_samples)*y_std + +# Plot samples +# plot(rowMeans(preds), y); abline(0,1,col="red",lty=3,lwd=3) +# plot(sigma_samples); abline(h = 1,col="blue",lty=3,lwd=3)