Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions R/config.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ ForestModelConfig <- R6::R6Class(
#' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
feature_types = NULL,

#' @field sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
sweep_update_indices = NULL,

#' @field num_trees Number of trees in the forest being sampled
num_trees = NULL,

Expand Down Expand Up @@ -62,6 +65,7 @@ ForestModelConfig <- R6::R6Class(
#' Create a new ForestModelConfig object.
#'
#' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
#' @param num_trees Number of trees in the forest being sampled
#' @param num_features Number of features in training dataset
#' @param num_observations Number of observations in training dataset
Expand All @@ -78,7 +82,7 @@ ForestModelConfig <- R6::R6Class(
#' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`)
#'
#' @return A new ForestModelConfig object.
initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL,
initialize = function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL,
num_observations = NULL, variable_weights = NULL, leaf_dimension = 1,
alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1,
leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0,
Expand All @@ -101,6 +105,10 @@ ForestModelConfig <- R6::R6Class(
if (is.null(num_trees)) {
stop("num_trees must be provided")
}
if (!is.null(sweep_update_indices)) {
stopifnot(min(sweep_update_indices) >= 0)
stopifnot(max(sweep_update_indices) < num_trees)
}
if (is.null(num_observations)) {
stop("num_observations must be provided")
}
Expand All @@ -111,6 +119,7 @@ ForestModelConfig <- R6::R6Class(
stop("`variable_weights` must have `num_features` total elements")
}
self$feature_types <- feature_types
self$sweep_update_indices <- sweep_update_indices
self$variable_weights <- variable_weights
self$num_trees <- num_trees
self$num_features <- num_features
Expand Down Expand Up @@ -158,6 +167,17 @@ ForestModelConfig <- R6::R6Class(
self$feature_types <- feature_types
},

#' @description
#' Update sweep update indices
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
update_sweep_indices = function(sweep_update_indices) {
  # Replaces the stored sweep update indices after validating them.
  # `sweep_update_indices` holds 0-indexed tree indices; passing NULL clears the stored value.
  #
  # NULL and zero-length vectors skip validation: there is nothing to range-check,
  # and min()/max() on an empty vector would only warn and return Inf/-Inf.
  if (length(sweep_update_indices) > 0) {
    if (!is.numeric(sweep_update_indices) || anyNA(sweep_update_indices)) {
      stop("`sweep_update_indices` must be a numeric vector with no missing values")
    }
    # Tree indices must be whole numbers (integer or integer-valued double)
    if (any(sweep_update_indices != round(sweep_update_indices))) {
      stop("`sweep_update_indices` must only contain whole numbers")
    }
    # Indices are 0-based, so the valid range is 0 to num_trees - 1
    if (min(sweep_update_indices) < 0 || max(sweep_update_indices) >= self$num_trees) {
      stop("`sweep_update_indices` must be between 0 and `num_trees` - 1")
    }
  }
  self$sweep_update_indices <- sweep_update_indices
},

#' @description
#' Update variable weights
#' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
Expand Down Expand Up @@ -242,6 +262,13 @@ ForestModelConfig <- R6::R6Class(
return(self$feature_types)
},

#' @description
#' Query sweep update indices for this ForestModelConfig object
#' @returns Vector of (0-indexed) indices of trees to update in a sweep
get_sweep_indices = function() {
  # Hand back the stored sweep update indices exactly as they were set
  # (NULL when none have been provided).
  self$sweep_update_indices
},

#' @description
#' Query variable weights for this ForestModelConfig object
#' @returns Vector specifying sampling probability for all p covariates in ForestDataset
Expand Down Expand Up @@ -382,6 +409,7 @@ GlobalModelConfig <- R6::R6Class(
#' Create a forest model config object
#'
#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
#' @param num_trees Number of trees in the forest being sampled
#' @param num_features Number of features in training dataset
#' @param num_observations Number of observations in training dataset
Expand All @@ -401,13 +429,13 @@ GlobalModelConfig <- R6::R6Class(
#'
#' @examples
#' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100)
createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL,
createForestModelConfig <- function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL,
num_observations = NULL, variable_weights = NULL, leaf_dimension = 1,
alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1,
leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0,
variance_forest_scale = 1.0, cutpoint_grid_size = 100){
return(invisible((
ForestModelConfig$new(feature_types, num_trees, num_features, num_observations,
ForestModelConfig$new(feature_types, sweep_update_indices, num_trees, num_features, num_observations,
variable_weights, leaf_dimension, alpha, beta, min_samples_leaf,
max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape,
variance_forest_scale, cutpoint_grid_size)
Expand Down
8 changes: 4 additions & 4 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
}

sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
Expand Down
14 changes: 11 additions & 3 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ ForestModel <- R6::R6Class(
#' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`.
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`.
sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest,
rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE) {
rng, forest_model_config, global_model_config,
keep_forest = TRUE, gfr = TRUE) {
if (active_forest$is_empty()) {
stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.")
}

# Unpack parameters from model config object
feature_types <- forest_model_config$feature_types
sweep_update_indices <- forest_model_config$sweep_update_indices
leaf_model_int <- forest_model_config$leaf_model_type
leaf_model_scale <- forest_model_config$leaf_model_scale
variable_weights <- forest_model_config$variable_weights
Expand All @@ -85,6 +87,12 @@ ForestModel <- R6::R6Class(
global_scale <- global_model_config$global_error_variance
cutpoint_grid_size <- forest_model_config$cutpoint_grid_size

# Default to updating every tree (0-indexed: 0 to num_trees - 1) if sweep_update_indices is NULL
if (is.null(sweep_update_indices)) {
# sweep_update_indices <- integer(0)
sweep_update_indices <- 0:(forest_model_config$num_trees - 1)
}

# Detect changes to tree prior
if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) {
update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha)
Expand All @@ -104,14 +112,14 @@ ForestModel <- R6::R6Class(
sample_gfr_one_iteration_cpp(
forest_dataset$data_ptr, residual$data_ptr,
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale,
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
)
} else {
sample_mcmc_one_iteration_cpp(
forest_dataset$data_ptr, residual$data_ptr,
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale,
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
)
}
Expand Down
20 changes: 12 additions & 8 deletions debug/api_debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
// Prepare the samplers
LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest);

// Initialize vector of sweep update indices
std::vector<int> sweep_indices(num_trees);
std::iota(sweep_indices.begin(), sweep_indices.end(), 0);

// Run the GFR sampler
if (num_gfr > 0) {
for (int i = 0; i < num_gfr; i++) {
Expand All @@ -683,13 +687,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia

// Sample tree ensemble
if (model_type == ModelType::kConstantLeafGaussian) {
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true);
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
} else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) {
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true);
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
} else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) {
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols);
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols);
} else if (model_type == ModelType::kLogLinearVariance) {
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, false);
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false);
}

if (rfx_included) {
Expand Down Expand Up @@ -720,13 +724,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia

// Sample tree ensemble
if (model_type == ModelType::kConstantLeafGaussian) {
MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true);
MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true);
} else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) {
MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true);
MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true);
} else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) {
MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true, omega_cols);
MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, omega_cols);
} else if (model_type == ModelType::kLogLinearVariance) {
MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, false);
MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, false);
}

if (rfx_included) {
Expand Down
Loading
Loading