diff --git a/DESCRIPTION b/DESCRIPTION index 71157a8d..6dc68ff2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -15,7 +15,7 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.1 LinkingTo: - cpp11 + cpp11, BH Suggests: doParallel, foreach, diff --git a/NAMESPACE b/NAMESPACE index e590fbe3..365e9dfd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,12 +17,14 @@ export(createBARTModelFromCombinedJsonString) export(createBARTModelFromJson) export(createBARTModelFromJsonFile) export(createBARTModelFromJsonString) +export(createBCFModelFromCombinedJsonString) export(createBCFModelFromJson) export(createBCFModelFromJsonFile) export(createBCFModelFromJsonString) export(createCppJson) export(createCppJsonFile) export(createCppJsonString) +export(createForest) export(createForestContainer) export(createForestCovariates) export(createForestCovariatesFromMetadata) @@ -55,6 +57,13 @@ export(preprocessPredictionMatrix) export(preprocessTrainData) export(preprocessTrainDataFrame) export(preprocessTrainMatrix) +export(resetActiveForest) +export(resetForestModel) +export(resetRandomEffectsModel) +export(resetRandomEffectsTracker) +export(rootResetActiveForest) +export(rootResetRandomEffectsModel) +export(rootResetRandomEffectsTracker) export(sample_sigma2_one_iteration) export(sample_tau_one_iteration) export(saveBARTModelToJsonFile) @@ -62,6 +71,7 @@ export(saveBARTModelToJsonString) export(saveBCFModelToJsonFile) export(saveBCFModelToJsonString) importFrom(R6,R6Class) +importFrom(stats,coef) importFrom(stats,lm) importFrom(stats,model.matrix) importFrom(stats,qgamma) diff --git a/R/bart.R b/R/bart.R index 1e058765..9098b8df 100644 --- a/R/bart.R +++ b/R/bart.R @@ -27,6 +27,8 @@ #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param previous_model_json (Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`. +#' @param warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `warmstart_sample_num = 1`). Default: `NULL`. #' @param params The list of model parameters, each of which has a default value. #' #' **1. Global Parameters** @@ -40,8 +42,10 @@ #' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. #' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`. #' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. -#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `TRUE`. Ignored if `num_mcmc = 0`. +#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. #' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`. 
+#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. +#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' #' **2. Mean Forest Parameters** @@ -114,6 +118,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, rfx_basis_train = NULL, X_test = NULL, W_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + previous_model_json = NULL, warmstart_sample_num = NULL, params = list()) { # Extract BART parameters bart_params <- preprocessBartParams(params) @@ -148,8 +153,56 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, keep_burnin <- bart_params$keep_burnin keep_gfr <- bart_params$keep_gfr standardize <- bart_params$standardize + keep_every <- bart_params$keep_every + num_chains <- bart_params$num_chains verbose <- bart_params$verbose + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains") + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) keep_gfr <- T + + # Check if previous model JSON is provided and parse it if so + # TODO: check that warmstart_sample_num is <= the number of samples in this previous model + has_prev_model <- !is.null(previous_model_json) + if (has_prev_model) { + previous_bart_model <- createBARTModelFromJsonString(previous_model_json) + previous_y_bar <- previous_bart_model$model_params$outcome_mean + previous_y_scale <- previous_bart_model$model_params$outcome_scale + previous_var_scale <- previous_bart_model$model_params$variance_scale + if (previous_bart_model$model_params$include_mean_forest) { + previous_forest_samples_mean <- previous_bart_model$mean_forests + } else previous_forest_samples_mean <- NULL + if (previous_bart_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bart_model$variance_forests + } else previous_forest_samples_variance <- NULL + if (previous_bart_model$model_params$sample_sigma_global) { + previous_global_var_samples <- previous_bart_model$sigma2_global_samples*( + previous_var_scale / (previous_y_scale*previous_y_scale) + ) + } else previous_global_var_samples <- NULL + if (previous_bart_model$model_params$sample_sigma_leaf) { + previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples + } else previous_leaf_var_samples <- NULL + if (previous_bart_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bart_model$rfx_samples + } else previous_rfx_samples <- NULL + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_var_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_samples <- NULL + previous_rfx_samples <- NULL +
previous_forest_samples_mean <- NULL + previous_forest_samples_variance <- NULL + } + # Determine whether conditional mean, variance, or both will be modeled if (num_trees_variance > 0) include_variance_forest = T else include_variance_forest = F @@ -401,13 +454,16 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, # Container of forest samples if (include_mean_forest) { forest_samples_mean <- createForestContainer(num_trees_mean, output_dimension, is_leaf_constant, FALSE) + active_forest_mean <- createForest(num_trees_mean, output_dimension, is_leaf_constant, FALSE) } if (include_variance_forest) { forest_samples_variance <- createForestContainer(num_trees_variance, 1, TRUE, TRUE) + active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE) } - # Random effects prior parameters + # Random effects initialization if (has_rfx) { + # Prior parameters if (num_rfx_components == 1) { alpha_init <- c(1) } else if (num_rfx_components > 1) { alpha_init <- c(1,rep(0,num_rfx_components-1)) } else { stop("There must be at least 1 random effect component") } xi_init <- matrix(rep(alpha_init, num_rfx_groups), num_rfx_components, num_rfx_groups) sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) sigma_xi_shape <- 1 sigma_xi_scale <- 1 - } - - # Random effects data structure and storage container - if (has_rfx) { + + # Random effects data structure and storage container rfx_dataset_train <- createRandomEffectsDataset(group_ids_train, rfx_basis_train) rfx_tracker_train <- createRandomEffectsTracker(group_ids_train) rfx_model <- createRandomEffectsModel(num_rfx_components, num_rfx_groups) @@ -437,26 +491,34 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, } # Container of variance parameter samples - num_samples <- num_gfr + num_burnin + num_mcmc - if (sample_sigma_global) global_var_samples <- rep(0, num_samples) - if (sample_sigma_leaf) leaf_scale_samples <- rep(0, num_samples) + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin * num_chains, 0) + num_mcmc * num_chains + if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples) + if (sample_sigma_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) + sample_counter <- 0 # Initialize the leaves of each tree in the mean forest if (include_mean_forest) { if (requires_basis) init_values_mean_forest <- rep(0., ncol(W_train)) else init_values_mean_forest <- 0.
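A minimal sketch of how the new `keep_every` / `num_chains` options documented above and this retained-sample accounting compose at the user level; the simulated data and hyperparameter values here are illustrative only, not part of this patch:

# Illustrative multi-chain, thinned BART fit
n <- 500
X <- matrix(runif(n * 5), ncol = 5)
y <- 2 * X[, 1] + sin(4 * pi * X[, 2]) + rnorm(n, 0, 0.5)
# Four chains, each seeded from a distinct GFR ensemble (num_gfr >= num_chains);
# each chain runs num_mcmc * keep_every = 500 post-burn-in iterations and
# retains every 5th draw, so num_mcmc * num_chains = 400 MCMC samples are kept
bart_model <- bart(X_train = X, y_train = y,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 100,
                   params = list(num_chains = 4, keep_every = 5))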
- forest_samples_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest) + active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest) } # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { - forest_samples_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) } # Run GFR (warm start) if specified if (num_gfr > 0){ - gfr_indices = 1:num_gfr for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, T, F) + keep_sample <- T + if (keep_sample) sample_counter <- sample_counter + 1 # Print progress if (verbose) { if ((i %% 10 == 0) || (i == num_gfr)) { @@ -466,83 +528,178 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, if (include_mean_forest) { forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, rng, feature_types, - leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, + rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, rng, feature_types, - leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, + rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T ) } if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] + current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 } if (sample_sigma_leaf) { - leaf_scale_samples[i] <- sample_tau_one_iteration(forest_samples_mean, rng, a_leaf, b_leaf, i-1) - current_leaf_scale <- as.matrix(leaf_scale_samples[i]) + leaf_scale_double <- sample_tau_one_iteration(active_forest_mean, rng, a_leaf, b_leaf) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double } if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng) + rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) } } } # Run MCMC if (num_burnin + num_mcmc > 0) { - if (num_burnin > 0) { - burnin_indices = (num_gfr+1):(num_gfr+num_burnin) - } - 
if (num_mcmc > 0) { - mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc) - } - for (i in (num_gfr+1):num_samples) { - # Print progress - if (verbose) { - if (num_burnin > 0) { - if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { - cat("Sampling", i - num_gfr, "out of", num_burnin, "BART burn-in draws\n") + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + if (include_mean_forest) { + resetActiveForest(active_forest_mean, forest_samples_mean, forest_ind) + resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) + if (sample_sigma_leaf) { + leaf_scale_double <- leaf_scale_samples[forest_ind + 1] + current_leaf_scale <- as.matrix(leaf_scale_double) + } + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + } + if (has_rfx) { + resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } + if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1] + } else if (has_prev_model) { + if (include_mean_forest) { + resetActiveForest(active_forest_mean, previous_forest_samples_mean, warmstart_sample_num - 1) + resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) + if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) { + leaf_scale_double <- previous_leaf_var_samples[warmstart_sample_num] + current_leaf_scale <- as.matrix(leaf_scale_double) + } + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance, previous_forest_samples_variance, warmstart_sample_num - 1) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + } + # TODO: also initialize from previous RFX samples + # if (has_rfx) { + # rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + # sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + # rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + # } + if (sample_sigma_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[warmstart_sample_num] + } + } + } else { + if (include_mean_forest) { + rootResetActiveForest(active_forest_mean) + active_forest_mean$set_root_leaves(init_values_mean_forest / num_trees_mean) + resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) + if (sample_sigma_leaf) { + current_leaf_scale <- as.matrix(sigma_leaf_init) } } - if (num_mcmc > 0) { - if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) { - cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BART MCMC draws\n") + if (include_variance_forest) { + rootResetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves(log(variance_forest_init) / num_trees_variance) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + } + if (has_rfx) { + rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rootResetRandomEffectsTracker(rfx_tracker_train, 
rfx_model, rfx_dataset_train, outcome_train) + } + if (sample_sigma_global) current_sigma2 <- sigma2_init + } + for (i in (num_gfr+1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) keep_sample <- T + else keep_sample <- F + } else { + if (keep_burnin) keep_sample <- T + else keep_sample <- F + } + if (keep_sample) sample_counter <- sample_counter + 1 + # Print progress + if (verbose) { + if (num_burnin > 0) { + if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { + cat("Sampling", i - num_gfr, "out of", num_burnin, "BART burn-in draws; Chain number ", chain_num, "\n") + } + } + if (num_mcmc > 0) { + if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) { + cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BART MCMC draws; Chain number ", chain_num, "\n") + } } } + + if (include_mean_forest) { + forest_model_mean$sample_one_iteration( + forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, + rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + ) + } + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, + rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + ) + } + if (sample_sigma_global) { + current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + } + if (sample_sigma_leaf) { + leaf_scale_double <- sample_tau_one_iteration(active_forest_mean, rng, a_leaf, b_leaf) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + } + if (has_rfx) { + rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + } } - + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { if (include_mean_forest) { - forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, rng, feature_types, - leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) + forest_samples_mean$delete_sample(i-1) } if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, rng, feature_types, - leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - } - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf) { - leaf_scale_samples[i] <- sample_tau_one_iteration(forest_samples_mean, rng, a_leaf, b_leaf, i-1) - current_leaf_scale <- as.matrix(leaf_scale_samples[i]) + forest_samples_variance$delete_sample(i-1) } if (has_rfx) { - 
rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng) + rfx_samples$delete_sample(i-1) } } + if (sample_sigma_global) { + global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] + } + if (sample_sigma_leaf) { + leaf_scale_samples <- leaf_scale_samples[(num_gfr+1):length(leaf_scale_samples)] + } + num_retained_samples <- num_retained_samples - num_gfr } - + # Mean forest predictions if (include_mean_forest) { y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train/sqrt(variance_scale) + y_bar_train @@ -564,53 +721,18 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, rfx_preds_test <- rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std_train/sqrt(variance_scale) y_hat_test <- y_hat_test + rfx_preds_test } - - # Compute retention indices - if (num_mcmc > 0) { - keep_indices = mcmc_indices - if (keep_gfr) keep_indices <- c(gfr_indices, keep_indices) - if (keep_burnin) keep_indices <- c(burnin_indices, keep_indices) - } else { - if ((num_gfr > 0) && (num_burnin > 0)) { - # Override keep_gfr = FALSE since there are no MCMC samples - # Don't retain both GFR and burnin samples - keep_indices = gfr_indices - } else if ((num_gfr <= 0) && (num_burnin > 0)) { - # Override keep_burnin = FALSE since there are no MCMC or GFR samples - keep_indices = burnin_indices - } else if ((num_gfr > 0) && (num_burnin <= 0)) { - # Override keep_gfr = FALSE since there are no MCMC samples - keep_indices = gfr_indices - } else { - stop("There are no samples to retain!") - } - } - - # Subset forest and RFX predictions - if (include_mean_forest) { - y_hat_train <- y_hat_train[,keep_indices] - if (has_test) y_hat_test <- y_hat_test[,keep_indices] - } - if (include_variance_forest) { - sigma_x_hat_train <- sigma_x_hat_train[,keep_indices] - if (has_test) sigma_x_hat_test <- sigma_x_hat_test[,keep_indices] - } - if (has_rfx) { - rfx_preds_train <- rfx_preds_train[,keep_indices] - if (has_test) rfx_preds_test <- rfx_preds_test[,keep_indices] - } # Global error variance - if (sample_sigma_global) sigma2_samples <- global_var_samples[keep_indices]*(y_std_train^2)/variance_scale + if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2)/variance_scale # Leaf parameter variance - if (sample_sigma_leaf) tau_samples <- leaf_scale_samples[keep_indices] + if (sample_sigma_leaf) tau_samples <- leaf_scale_samples # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { if (sample_sigma_global) { - sigma_x_hat_train <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) - if (has_test) sigma_x_hat_test <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) + sigma_x_hat_train <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) + if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) } else { sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train/sqrt(variance_scale) if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train/sqrt(variance_scale) @@ -619,6 +741,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, # Return results as a list # TODO: store variance_scale and propagate through predict function + # TODO: refactor out the "num_retained_samples" variable now that we 
burn-in/thin correctly model_params <- list( "sigma2_init" = sigma2_init, "sigma_leaf_init" = sigma_leaf_init, @@ -637,11 +760,12 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, "requires_basis" = requires_basis, "num_covariates" = ncol(X_train), "num_basis" = ifelse(is.null(W_train),0,ncol(W_train)), - "num_samples" = num_samples, + "num_samples" = num_retained_samples, "num_gfr" = num_gfr, "num_burnin" = num_burnin, "num_mcmc" = num_mcmc, - "num_retained_samples" = length(keep_indices), + "keep_every" = keep_every, + "num_chains" = num_chains, "has_basis" = !is.null(W_train), "has_rfx" = has_rfx, "has_rfx_basis" = has_basis_rfx, @@ -654,8 +778,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, ) result <- list( "model_params" = model_params, - "train_set_metadata" = X_train_metadata, - "keep_indices" = keep_indices + "train_set_metadata" = X_train_metadata ) if (include_mean_forest) { result[["mean_forests"]] = forest_samples_mean @@ -689,7 +812,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, return(result) } - #' Predict from a sampled BART model on new data #' #' @param bart Object of type `bart` containing draws of a regression forest and associated sampling outputs. @@ -787,6 +909,7 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL else prediction_dataset <- createForestDataset(X_test) # Compute mean forest predictions + num_samples <- bart$model_params$num_samples variance_scale <- bart$model_params$variance_scale y_std <- bart$model_params$outcome_scale y_bar <- bart$model_params$outcome_mean @@ -805,21 +928,11 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL rfx_predictions <- bart$rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std/sqrt(variance_scale) } - # Restrict predictions to the "retained" samples (if applicable) - keep_indices = bart$keep_indices - if (bart$model_params$include_mean_forest) { - mean_forest_predictions <- mean_forest_predictions[,keep_indices] - } - if (bart$model_params$include_variance_forest) { - s_x_raw <- s_x_raw[,keep_indices] - } - if (bart$model_params$has_rfx) rfx_predictions <- rfx_predictions[,keep_indices] - # Scale variance forest predictions if (bart$model_params$include_variance_forest) { if (bart$model_params$sample_sigma_global) { sigma2_samples <- bart$sigma2_global_samples - variance_forest_predictions <- sapply(1:length(keep_indices), function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) + variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) } else { variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std/sqrt(variance_scale) } @@ -995,8 +1108,9 @@ convertBARTModelToJson <- function(object){ jsonobj$add_scalar("num_samples", object$model_params$num_samples) jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) jsonobj$add_scalar("num_basis", object$model_params$num_basis) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) - jsonobj$add_vector("keep_indices", object$keep_indices) if (object$model_params$sample_sigma_global) { jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") } @@ -1013,6 +1127,69 @@ convertBARTModelToJson <- function(object){ return(jsonobj) } +#' Convert in-memory BART model objects (forests, 
random effects, vectors) to in-memory JSON. +#' This function is primarily a convenience function for serialization / deserialization in a parallel BART sampler. +#' +#' @param param_list List containing high-level model state parameters +#' @param mean_forest Container of conditional mean forest samples (optional). Default: `NULL`. +#' @param variance_forest Container of conditional variance forest samples (optional). Default: `NULL`. +#' @param rfx_samples Container of random effect samples (optional). Default: `NULL`. +#' @param global_variance_samples Vector of global error variance samples (optional). Default: `NULL`. +#' @param local_variance_samples Vector of leaf scale samples (optional). Default: `NULL`. +#' +#' @return Object of type `CppJson` +convertBARTStateToJson <- function(param_list, mean_forest = NULL, variance_forest = NULL, + rfx_samples = NULL, global_variance_samples = NULL, + local_variance_samples = NULL) { + # Initialize JSON object + jsonobj <- createCppJson() + + # Add global parameters + jsonobj$add_scalar("variance_scale", param_list$variance_scale) + jsonobj$add_scalar("outcome_scale", param_list$outcome_scale) + jsonobj$add_scalar("outcome_mean", param_list$outcome_mean) + jsonobj$add_boolean("standardize", param_list$standardize) + jsonobj$add_scalar("sigma2_init", param_list$sigma2_init) + jsonobj$add_boolean("sample_sigma_global", param_list$sample_sigma_global) + jsonobj$add_boolean("sample_sigma_leaf", param_list$sample_sigma_leaf) + jsonobj$add_boolean("include_mean_forest", param_list$include_mean_forest) + jsonobj$add_boolean("include_variance_forest", param_list$include_variance_forest) + jsonobj$add_boolean("has_rfx", param_list$has_rfx) + jsonobj$add_boolean("has_rfx_basis", param_list$has_rfx_basis) + jsonobj$add_scalar("num_rfx_basis", param_list$num_rfx_basis) + jsonobj$add_scalar("num_gfr", param_list$num_gfr) + jsonobj$add_scalar("num_burnin", param_list$num_burnin) + jsonobj$add_scalar("num_mcmc", param_list$num_mcmc) + jsonobj$add_scalar("num_covariates", param_list$num_covariates) + jsonobj$add_scalar("num_basis", param_list$num_basis) + jsonobj$add_scalar("keep_every", param_list$keep_every) + jsonobj$add_boolean("requires_basis", param_list$requires_basis) + + # Add the forests + if (param_list$include_mean_forest) { + jsonobj$add_forest(mean_forest) + } + if (param_list$include_variance_forest) { + jsonobj$add_forest(variance_forest) + } + + # Add sampled parameters + if (param_list$sample_sigma_global) { + jsonobj$add_vector("sigma2_global_samples", global_variance_samples, "parameters") + } + if (param_list$sample_sigma_leaf) { + jsonobj$add_vector("sigma2_leaf_samples", local_variance_samples, "parameters") + } + + # Add random effects + if (param_list$has_rfx) { + jsonobj$add_random_effects(rfx_samples) + jsonobj$add_string_vector("rfx_unique_group_ids", param_list$rfx_unique_group_ids) + } + + return(jsonobj) +} + #' Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file #' #' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
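A sketch of the warm-start workflow that `previous_model_json` and `warmstart_sample_num` enable, using the exported serialization helpers (same training data in both calls, reusing the illustrative `X` and `y` from the earlier sketch):

# Fit an initial model and serialize it to a JSON string
bart_model <- bart(X_train = X, y_train = y, num_gfr = 10, num_mcmc = 100)
model_json <- saveBARTModelToJsonString(bart_model)
# Continue sampling, warm-started from the last retained draw
# (warmstart_sample_num is one-indexed, per the roxygen docs above; with the
# default keep_gfr = FALSE this run retains exactly 100 MCMC draws)
bart_model_2 <- bart(X_train = X, y_train = y,
                     num_gfr = 0, num_burnin = 0, num_mcmc = 100,
                     previous_model_json = model_json,
                     warmstart_sample_num = 100)

Because unretained samples are now deleted from the containers themselves, the `keep_indices` bookkeeping is no longer needed when recombining such runs with `createBARTModelFromCombinedJsonString`, as reflected in the loader changes below.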
@@ -1154,7 +1331,6 @@ createBARTModelFromJson <- function(json_object){ train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) } output[["train_set_metadata"]] <- train_set_metadata - output[["keep_indices"]] <- json_object$get_vector("keep_indices") # Unpack model params model_params = list() @@ -1176,6 +1352,8 @@ createBARTModelFromJson <- function(json_object){ model_params[["num_samples"]] <- json_object$get_scalar("num_samples") model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") model_params[["num_basis"]] <- json_object$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis") output[["model_params"]] <- model_params @@ -1233,7 +1411,7 @@ createBARTModelFromJsonFile <- function(json_filename){ # Load a `CppJson` object from file bart_json <- createCppJsonFile(json_filename) - # Create and return the BCF object + # Create and return the BART object bart_object <- createBARTModelFromJson(bart_json) return(bart_object) @@ -1278,7 +1456,7 @@ createBARTModelFromJsonString <- function(json_string){ # Load a `CppJson` object from string bart_json <- createCppJsonString(json_string) - # Create and return the BCF object + # Create and return the BART object bart_object <- createBARTModelFromJson(bart_json) return(bart_object) @@ -1370,10 +1548,10 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") # Combine values that are sample-specific - keep_index_offset <- 0 - keep_indices <- c() for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -1381,18 +1559,14 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) } else { prev_json <- json_object_list[[i-1]] model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") - keep_index_offset <- keep_index_offset + prev_json$get_scalar("num_samples") - keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) } } - output[["keep_indices"]] <- keep_indices output[["model_params"]] <- model_params # Unpack sampled parameters @@ -1503,8 +1677,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", 
train_set_metadata[["unordered_cat_vars"]]) } output[["train_set_metadata"]] <- train_set_metadata - output[["keep_indices"]] <- json_object_default$get_vector("keep_indices") - + # Unpack model params model_params = list() model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale") @@ -1521,11 +1694,11 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") # Combine values that are sample-specific - keep_index_offset <- 0 - keep_indices <- c() for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -1533,18 +1706,14 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) } else { prev_json <- json_object_list[[i-1]] model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") - keep_index_offset <- keep_index_offset + prev_json$get_scalar("num_samples") - keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) } } - output[["keep_indices"]] <- keep_indices output[["model_params"]] <- model_params # Unpack sampled parameters diff --git a/R/bcf.R b/R/bcf.R index 4f30e7a0..921b7570 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -24,6 +24,8 @@ #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`. +#' @param warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `warmstart_sample_num = 1`). Default: `NULL`. #' @param params The list of model parameters, each of which has a default value. #' #' **1. Global Parameters** @@ -42,6 +44,8 @@ #' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. #' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. 
#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`. +#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. +#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`. #' @@ -154,14 +158,17 @@ #' tau_train <- tau_x[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, #' X_test = X_test, Z_test = Z_test, pi_test = pi_test) -#' # plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +#' # plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", +#' # ylab = "actual", main = "Prognostic function") #' # abline(0,1,col="red",lty=3,lwd=3) -#' # plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +#' # plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", +#' # ylab = "actual", main = "Treatment effect") #' # abline(0,1,col="red",lty=3,lwd=3) bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL, rfx_basis_train = NULL, X_test = NULL, Z_test = NULL, pi_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL, num_gfr = 5, - num_burnin = 0, num_mcmc = 100, params = list()) { + num_burnin = 0, num_mcmc = 100, previous_model_json = NULL, + warmstart_sample_num = NULL, params = list()) { # Extract BCF parameters bcf_params <- preprocessBcfParams(params) cutpoint_grid_size <- bcf_params$cutpoint_grid_size @@ -201,9 +208,6 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU num_trees_mu <- bcf_params$num_trees_mu num_trees_tau <- bcf_params$num_trees_tau num_trees_variance <- bcf_params$num_trees_variance - num_gfr <- bcf_params$num_gfr - num_burnin <- bcf_params$num_burnin - num_mcmc <- bcf_params$num_mcmc sample_sigma_global <- bcf_params$sample_sigma_global sample_sigma_leaf_mu <- bcf_params$sample_sigma_leaf_mu sample_sigma_leaf_tau <- bcf_params$sample_sigma_leaf_tau @@ -216,8 +220,62 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU keep_burnin <- bcf_params$keep_burnin keep_gfr <- bcf_params$keep_gfr standardize <- bcf_params$standardize + keep_every <- bcf_params$keep_every + num_chains <- bcf_params$num_chains verbose <- bcf_params$verbose + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains") + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) keep_gfr <- T + + # Check if previous model JSON is provided and parse it if so + # TODO: check that warmstart_sample_num is <= the number of samples in this previous model + has_prev_model <- !is.null(previous_model_json)
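The BCF interface mirrors the BART changes; a minimal sketch of a multi-chain fit followed by a warm-started continuation (data objects as in the roxygen example above; the draw index passed to `warmstart_sample_num` assumes the default `keep_gfr = FALSE`, so the 400 retained draws are all MCMC samples):

# Multi-chain, thinned BCF fit
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 pi_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 100,
                 params = list(num_chains = 4, keep_every = 5))
# Serialize and continue from the last retained draw
bcf_json <- saveBCFModelToJsonString(bcf_model)
bcf_model_2 <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                   pi_train = pi_train, num_gfr = 0, num_burnin = 0, num_mcmc = 100,
                   previous_model_json = bcf_json,
                   warmstart_sample_num = 400)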
+ if (has_prev_model) { + previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) + previous_y_bar <- previous_bcf_model$model_params$outcome_mean + previous_y_scale <- previous_bcf_model$model_params$outcome_scale + previous_var_scale <- previous_bcf_model$model_params$variance_scale + previous_forest_samples_mu <- previous_bcf_model$forests_mu + previous_forest_samples_tau <- previous_bcf_model$forests_tau + if (previous_bcf_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bcf_model$forests_variance + } else previous_forest_samples_variance <- NULL + if (previous_bcf_model$model_params$sample_sigma_global) { + previous_global_var_samples <- previous_bcf_model$sigma2_samples*( + previous_var_scale / (previous_y_scale*previous_y_scale) + ) + } else previous_global_var_samples <- NULL + if (previous_bcf_model$model_params$sample_sigma_leaf_mu) { + previous_leaf_var_mu_samples <- previous_bcf_model$sigma_leaf_mu_samples + } else previous_leaf_var_mu_samples <- NULL + if (previous_bcf_model$model_params$sample_sigma_leaf_tau) { + previous_leaf_var_tau_samples <- previous_bcf_model$sigma_leaf_tau_samples + } else previous_leaf_var_tau_samples <- NULL + if (previous_bcf_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bcf_model$rfx_samples + } else previous_rfx_samples <- NULL + if (previous_bcf_model$model_params$adaptive_coding) { + previous_b_1_samples <- previous_bcf_model$b_1_samples + previous_b_0_samples <- previous_bcf_model$b_0_samples + } else { + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_var_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_mu_samples <- NULL + previous_leaf_var_tau_samples <- NULL + previous_rfx_samples <- NULL + previous_forest_samples_mu <- NULL + previous_forest_samples_tau <- NULL + previous_forest_samples_variance <- NULL + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + # Determine whether conditional variance will be modeled if (num_trees_variance > 0) include_variance_forest = T else include_variance_forest = F @@ -625,18 +683,23 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU } # Container of variance parameter samples - num_samples <- num_gfr + num_burnin + num_mcmc - if (sample_sigma_global) global_var_samples <- rep(0, num_samples) - if (sample_sigma_leaf_mu) leaf_scale_mu_samples <- rep(0, num_samples) - if (sample_sigma_leaf_tau) leaf_scale_tau_samples <- rep(0, num_samples) + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin * num_chains, 0) + num_mcmc * num_chains + if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples) + if (sample_sigma_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples) + if (sample_sigma_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples) + sample_counter <- 0 # Prepare adaptive coding structure if ((!is.numeric(b_0)) || (!is.numeric(b_1)) || (length(b_0) > 1) || (length(b_1) > 1)) { stop("b_0 and b_1 must be single numeric values") } if (adaptive_coding) { - b_0_samples <- rep(0, num_samples) - b_1_samples <- rep(0, num_samples) + b_0_samples <- rep(NA, num_retained_samples) + b_1_samples <- rep(NA,
num_retained_samples) current_b_0 <- b_0 current_b_1 <- b_1 tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 @@ -665,27 +728,35 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Container of forest samples forest_samples_mu <- createForestContainer(num_trees_mu, 1, T) forest_samples_tau <- createForestContainer(num_trees_tau, 1, F) + active_forest_mu <- createForest(num_trees_mu, 1, T) + active_forest_tau <- createForest(num_trees_tau, 1, F) if (include_variance_forest) { forest_samples_variance <- createForestContainer(num_trees_variance, 1, TRUE, TRUE) + active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE) } # Initialize the leaves of each tree in the prognostic forest init_mu <- mean(resid_train) - forest_samples_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu) - + active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu) + active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, F, F) + # Initialize the leaves of each tree in the treatment effect forest init_tau <- 0. - forest_samples_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau) - + active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau) + active_forest_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, T, F) + # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { - forest_samples_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) } # Run GFR (warm start) if specified if (num_gfr > 0){ - gfr_indices = 1:num_gfr for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, T, F) + keep_sample <- T + if (keep_sample) sample_counter <- sample_counter + 1 # Print progress if (verbose) { if ((i %% 10 == 0) || (i == num_gfr)) { @@ -695,33 +766,33 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, rng, feature_types, - 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, + rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T ) # Sample variance parameters (if requested) if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] + current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) } if (sample_sigma_leaf_mu) { - leaf_scale_mu_samples[i] <- sample_tau_one_iteration(forest_samples_mu, rng, a_leaf_mu, b_leaf_mu, i-1) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_samples[i]) + leaf_scale_mu_double <- sample_tau_one_iteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) + 
current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double } # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, rng, feature_types, - 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, + rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T ) # Sample coding parameters (if requested) if (adaptive_coding) { # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_train, i-1) - tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_train, i-1) + mu_x_raw_train <- active_forest_mu$predict_raw(forest_dataset_train) + tau_x_raw_train <- active_forest_tau$predict_raw(forest_dataset_train) partial_resid_mu_train <- resid_train - mu_x_raw_train if (has_rfx) { rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) @@ -741,147 +812,297 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Update basis for the leaf regression tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 forest_dataset_train$update_basis(tau_basis_train) - b_0_samples[i] <- current_b_0 - b_1_samples[i] <- current_b_1 + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 + } if (has_test) { tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 forest_dataset_test$update_basis(tau_basis_test) } # Update leaf predictions and residual - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1) + forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, rng, feature_types, - leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, + rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T ) } if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] + current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 } if (sample_sigma_leaf_tau) { - leaf_scale_tau_samples[i] <- sample_tau_one_iteration(forest_samples_tau, rng, a_leaf_tau, b_leaf_tau, i-1) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_samples[i]) + leaf_scale_tau_double <- sample_tau_one_iteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if (keep_sample) 
leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double } # Sample random effects parameters (if requested) if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng) + rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) } } } # Run MCMC if (num_burnin + num_mcmc > 0) { - if (num_burnin > 0) { - burnin_indices = (num_gfr+1):(num_gfr+num_burnin) - } - if (num_mcmc > 0) { - mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc) - } - for (i in (num_gfr+1):num_samples) { - # Print progress - if (verbose) { - if (num_burnin > 0) { - if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { - cat("Sampling", i - num_gfr, "out of", num_gfr, "BCF burn-in draws\n") + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + resetActiveForest(active_forest_mu, forest_samples_mu, forest_ind) + resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) + resetActiveForest(active_forest_tau, forest_samples_tau, forest_ind) + resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + if (sample_sigma_leaf_mu) { + leaf_scale_mu_double <- leaf_scale_mu_samples[forest_ind + 1] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + } + if (sample_sigma_leaf_tau) { + leaf_scale_tau_double <- leaf_scale_tau_samples[forest_ind + 1] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + } + if (has_rfx) { + resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } + if (adaptive_coding) { + current_b_1 <- b_1_samples[forest_ind + 1] + current_b_0 <- b_0_samples[forest_ind + 1] + tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + forest_dataset_test$update_basis(tau_basis_test) } + forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (num_mcmc > 0) { - if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) { - cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BCF MCMC draws\n") + if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1] + } else if (has_prev_model) { + resetActiveForest(active_forest_mu, previous_forest_samples_mu, warmstart_sample_num - 1) + resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) + resetActiveForest(active_forest_tau, previous_forest_samples_tau, warmstart_sample_num - 1) + resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + if (include_variance_forest) { + resetActiveForest(active_forest_variance, previous_forest_samples_variance, warmstart_sample_num - 1) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + } + if (sample_sigma_leaf_mu && 
(!is.null(previous_leaf_var_mu_samples))) { + leaf_scale_mu_double <- previous_leaf_var_mu_samples[warmstart_sample_num] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + } + if (sample_sigma_leaf_tau && (!is.null(previous_leaf_var_tau_samples))) { + leaf_scale_tau_double <- previous_leaf_var_tau_samples[warmstart_sample_num] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + } + if (adaptive_coding) { + if (!is.null(previous_b_1_samples)) { + current_b_1 <- previous_b_1_samples[warmstart_sample_num] + } + if (!is.null(previous_b_0_samples)) { + current_b_0 <- previous_b_0_samples[warmstart_sample_num] } + tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + } + # TODO: also initialize from previous RFX samples + # if (has_rfx) { + # rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + # sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + # rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + # } + if (sample_sigma_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[warmstart_sample_num] + } + } + } else { + rootResetActiveForest(active_forest_mu) + active_forest_mu$set_root_leaves(init_mu / num_trees_mu) + resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) + rootResetActiveForest(active_forest_tau) + active_forest_tau$set_root_leaves(init_tau / num_trees_tau) + resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + if (sample_sigma_leaf_mu) { + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } + if (sample_sigma_leaf_tau) { + current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + } + if (include_variance_forest) { + rootResetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves(log(variance_forest_init) / num_trees_variance) + resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) } - } - - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, rng, feature_types, - 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - - # Sample variance parameters (if requested) - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf_mu) { - leaf_scale_mu_samples[i] <- sample_tau_one_iteration(forest_samples_mu, rng, a_leaf_mu, b_leaf_mu, i-1) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_samples[i]) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, rng, feature_types, - 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - - # Sample coding parameters (if requested) - if (adaptive_coding) { - # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- 
forest_samples_mu$predict_raw_single_forest(forest_dataset_train, i-1)
-        tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_train, i-1)
-        partial_resid_mu_train <- resid_train - mu_x_raw_train
       if (has_rfx) {
-          rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train)
-          partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train
+        rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init,
+                                    sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
+        rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train)
+      }
+      if (adaptive_coding) {
+        current_b_1 <- b_1
+        current_b_0 <- b_0
+        tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
+        forest_dataset_train$update_basis(tau_basis_train)
+        if (has_test) {
+          tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
+          forest_dataset_test$update_basis(tau_basis_test)
+        }
+        forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
+      }
+      if (sample_sigma_global) current_sigma2 <- sigma2_init
+    }
+    for (i in (num_gfr+1):num_samples) {
+      # Determine whether this iteration's draw is retained: MCMC draws are
+      # thinned to every keep_every-th iteration, while burn-in draws are
+      # retained only if keep_burnin is TRUE
+      is_mcmc <- i > (num_gfr + num_burnin)
+      if (is_mcmc) {
+        mcmc_counter <- i - (num_gfr + num_burnin)
+        if (mcmc_counter %% keep_every == 0) keep_sample <- T
+        else keep_sample <- F
+      } else {
+        if (keep_burnin) keep_sample <- T
+        else keep_sample <- F
+      }
+      if (keep_sample) sample_counter <- sample_counter + 1
+      # Print progress
+      if (verbose) {
+        if (num_burnin > 0) {
+          if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) {
+            cat("Sampling", i - num_gfr, "out of", num_burnin, "BCF burn-in draws\n")
+          }
+        }
+        if (num_mcmc > 0) {
+          if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) {
+            cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BCF MCMC draws\n")
+          }
+        }
       }
-      # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z]
-      s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0))
-      s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1))
-      s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0))
-      s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1))
+      # Sample the prognostic forest
+      forest_model_mu$sample_one_iteration(
+        forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu,
+        rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest,
+        current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T
+      )
-      # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z)
-      current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2)))
-      current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2)))
+      # Sample variance parameters (if requested)
+      if (sample_sigma_global) {
+        current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
+      }
+      if (sample_sigma_leaf_mu) {
+        leaf_scale_mu_double <- sample_tau_one_iteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu)
+        current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
+        if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double
+      }
-      # Update basis for the leaf regression
-      tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
-      forest_dataset_train$update_basis(tau_basis_train)
-      b_0_samples[i] <- current_b_0
-      b_1_samples[i] <- current_b_1
-      if (has_test) {
-        tau_basis_test <-
(1-Z_test)*current_b_0 + Z_test*current_b_1 - forest_dataset_test$update_basis(tau_basis_test) + # Sample the treatment forest + forest_model_tau$sample_one_iteration( + forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, + rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + ) + + # Sample coding parameters (if requested) + if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) + mu_x_raw_train <- active_forest_mu$predict_raw(forest_dataset_train) + tau_x_raw_train <- active_forest_tau$predict_raw(forest_dataset_train) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + } + + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) + s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) + s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) + s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) + + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) + current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) + + # Update basis for the leaf regression + tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 + } + if (has_test) { + tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + + # Update leaf predictions and residual + forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - # Update leaf predictions and residual - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, forest_samples_tau, i-1) + # Sample variance parameters (if requested) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, + rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + ) + } + if (sample_sigma_global) { + current_sigma2 <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + } + if (sample_sigma_leaf_tau) { + leaf_scale_tau_double <- sample_tau_one_iteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double + } + + # Sample random effects parameters (if requested) + if (has_rfx) { + rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + } } - - # Sample variance parameters (if requested) + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && 
(num_gfr > 0)) { + for (i in 1:num_gfr) { + forest_samples_mu$delete_sample(i-1) + forest_samples_tau$delete_sample(i-1) if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, rng, feature_types, - leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - } - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - current_sigma2 <- global_var_samples[i] + forest_samples_variance$delete_sample(i-1) } - if (sample_sigma_leaf_tau) { - leaf_scale_tau_samples[i] <- sample_tau_one_iteration(forest_samples_tau, rng, a_leaf_tau, b_leaf_tau, i-1) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_samples[i]) - } - - # Sample random effects parameters (if requested) if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng) + rfx_samples$delete_sample(i-1) } } + if (sample_sigma_global) { + global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] + } + if (sample_sigma_leaf_mu) { + leaf_scale_mu_samples <- leaf_scale_mu_samples[(num_gfr+1):length(leaf_scale_mu_samples)] + } + if (sample_sigma_leaf_tau) { + leaf_scale_tau_samples <- leaf_scale_tau_samples[(num_gfr+1):length(leaf_scale_tau_samples)] + } + num_retained_samples <- num_retained_samples - num_gfr } - + # Forest predictions mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train if (adaptive_coding) { @@ -916,61 +1137,20 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU y_hat_test <- y_hat_test + rfx_preds_test } - # Compute retention indices - if (num_mcmc > 0) { - keep_indices = mcmc_indices - if (keep_gfr) keep_indices <- c(gfr_indices, keep_indices) - if (keep_burnin) keep_indices <- c(burnin_indices, keep_indices) - } else { - if ((num_gfr > 0) && (num_burnin > 0)) { - # Override keep_gfr = FALSE since there are no MCMC samples - # Don't retain both GFR and burnin samples - keep_indices = gfr_indices - } else if ((num_gfr <= 0) && (num_burnin > 0)) { - # Override keep_burnin = FALSE since there are no MCMC or GFR samples - keep_indices = burnin_indices - } else if ((num_gfr > 0) && (num_burnin <= 0)) { - # Override keep_gfr = FALSE since there are no MCMC samples - keep_indices = gfr_indices - } else { - stop("There are no samples to retain!") - } - } - - # Subset forest and RFX predictions - mu_hat_train <- mu_hat_train[,keep_indices] - tau_hat_train <- tau_hat_train[,keep_indices] - y_hat_train <- y_hat_train[,keep_indices] - if (has_rfx) { - rfx_preds_train <- rfx_preds_train[,keep_indices] - } - if (has_test) { - mu_hat_test <- mu_hat_test[,keep_indices] - tau_hat_test <- tau_hat_test[,keep_indices] - y_hat_test <- y_hat_test[,keep_indices] - if (has_rfx_test) { - rfx_preds_test <- rfx_preds_test[,keep_indices] - } - } - if (include_variance_forest) { - sigma_x_hat_train <- sigma_x_hat_train[,keep_indices] - if (has_test) sigma_x_hat_test <- sigma_x_hat_test[,keep_indices] - } - # Global error variance - if (sample_sigma_global) sigma2_samples <- global_var_samples[keep_indices]*(y_std_train^2) + if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2) # Leaf parameter variance for prognostic forest - if (sample_sigma_leaf_mu) sigma_leaf_mu_samples <- 
leaf_scale_mu_samples[keep_indices] + if (sample_sigma_leaf_mu) sigma_leaf_mu_samples <- leaf_scale_mu_samples # Leaf parameter variance for treatment effect forest - if (sample_sigma_leaf_tau) sigma_leaf_tau_samples <- leaf_scale_tau_samples[keep_indices] + if (sample_sigma_leaf_tau) sigma_leaf_tau_samples <- leaf_scale_tau_samples # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { if (sample_sigma_global) { - sigma_x_hat_train <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) - if (has_test) sigma_x_hat_test <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) + sigma_x_hat_train <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) + if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) } else { sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train @@ -1008,10 +1188,12 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "propensity_covariate" = propensity_covariate, "binary_treatment" = binary_treatment, "adaptive_coding" = adaptive_coding, - "num_samples" = num_samples, + "num_samples" = num_retained_samples, "num_gfr" = num_gfr, "num_burnin" = num_burnin, "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, "has_rfx" = has_rfx, "has_rfx_basis" = has_basis_rfx, "num_rfx_basis" = num_basis_rfx, @@ -1027,12 +1209,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "mu_hat_train" = mu_hat_train, "tau_hat_train" = tau_hat_train, "y_hat_train" = y_hat_train, - "train_set_metadata" = X_train_metadata, - "keep_indices" = keep_indices + "train_set_metadata" = X_train_metadata ) - if (num_gfr > 0) result[["gfr_indices"]] = gfr_indices - if (num_burnin > 0) result[["burnin_indices"]] = burnin_indices - if (num_mcmc > 0) result[["mcmc_indices"]] = mcmc_indices if (has_test) result[["mu_hat_test"]] = mu_hat_test if (has_test) result[["tau_hat_test"]] = tau_hat_test if (has_test) result[["y_hat_test"]] = y_hat_test @@ -1116,9 +1294,11 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU #' tau_train <- tau_x[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) -#' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +#' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", +#' # ylab = "actual", main = "Prognostic function") #' # abline(0,1,col="red",lty=3,lwd=3) -#' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +#' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", +#' # ylab = "actual", main = "Treatment effect") #' # abline(0,1,col="red",lty=3,lwd=3) predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL){ # Preprocess covariates @@ -1193,6 +1373,7 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU prediction_dataset_tau <- createForestDataset(X_test_tau, Z_test) # Compute forest predictions + num_samples <- bcf$model_params$num_samples y_std <- bcf$model_params$outcome_scale y_bar <- bcf$model_params$outcome_mean 
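+  # Every draw stored in the model object is a retained sample, so predictions
+  # are computed for all num_samples draws without any further subsetting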
initial_sigma2 <- bcf$model_params$initial_sigma2 @@ -1216,21 +1397,11 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) if (bcf$model_params$has_rfx) y_hat_test <- y_hat_test + rfx_predictions - # Restrict predictions to the "retained" samples (if applicable) - keep_indices = bcf$keep_indices - mu_hat_test <- mu_hat_test[,keep_indices] - tau_hat_test <- tau_hat_test[,keep_indices] - y_hat_test <- y_hat_test[,keep_indices] - if (bcf$model_params$has_rfx) rfx_predictions <- rfx_predictions[,keep_indices] - if (bcf$model_params$include_variance_forest) { - s_x_raw <- s_x_raw[,keep_indices] - } - # Scale variance forest predictions if (bcf$model_params$include_variance_forest) { if (bcf$model_params$sample_sigma_global) { sigma2_samples <- bcf$sigma2_global_samples - variance_forest_predictions <- sapply(1:length(keep_indices), function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) + variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) } else { variance_forest_predictions <- sqrt(s_x_raw*initial_sigma2)*y_std } @@ -1310,13 +1481,14 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' params = bcf_params) #' rfx_samples <- getRandomEffectSamples(bcf_model) getRandomEffectSamples.bcf <- function(object, ...){ result = list() @@ -1396,13 +1568,14 @@ getRandomEffectSamples.bcf <- function(object, ...){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' params = bcf_params) #' # bcf_json <- convertBCFModelToJson(bcf_model) convertBCFModelToJson <- function(object){ jsonobj <- createCppJson() @@ -1452,8 +1625,9 @@ convertBCFModelToJson <- function(object){ jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) jsonobj$add_scalar("num_samples", object$model_params$num_samples) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_vector("keep_indices", object$keep_indices) if (object$model_params$sample_sigma_global) { jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters") } @@ -1536,13 +1710,14 @@ 
convertBCFModelToJson <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' params = bcf_params) #' # saveBCFModelToJsonFile(bcf_model, "test.json") saveBCFModelToJsonFile <- function(object, filename){ # Convert to Json @@ -1609,13 +1784,14 @@ saveBCFModelToJsonFile <- function(object, filename){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' params = bcf_params) #' # saveBCFModelToJsonString(bcf_model) saveBCFModelToJsonString <- function(object){ # Convert to Json @@ -1684,13 +1860,14 @@ saveBCFModelToJsonString <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' params = bcf_params) #' # bcf_json <- convertBCFModelToJson(bcf_model) #' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) createBCFModelFromJson <- function(json_object){ @@ -1722,8 +1899,7 @@ createBCFModelFromJson <- function(json_object){ train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) } output[["train_set_metadata"]] <- train_set_metadata - output[["keep_indices"]] <- json_object$get_vector("keep_indices") - + # Unpack model params model_params = list() model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") @@ -1830,13 +2006,14 @@ createBCFModelFromJson <- function(json_object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' pi_train = pi_train, group_ids_train = group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, #' 
rfx_basis_test = rfx_basis_test,
 #'                  num_gfr = 100, num_burnin = 0, num_mcmc = 100,
-#'                  sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
+#'                  params = bcf_params)
 #' # saveBCFModelToJsonFile(bcf_model, "test.json")
 #' # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json")
 createBCFModelFromJsonFile <- function(json_filename){
@@ -1913,8 +2090,7 @@ createBCFModelFromJsonFile <- function(json_filename){
 #'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
 #'                  Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
 #'                  rfx_basis_test = rfx_basis_test,
-#'                  num_gfr = 100, num_burnin = 0, num_mcmc = 100,
-#'                  sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
+#'                  num_gfr = 100, num_burnin = 0, num_mcmc = 100)
 #' # bcf_json <- saveBCFModelToJsonString(bcf_model)
 #' # bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json)
 createBCFModelFromJsonString <- function(json_string){
@@ -1926,3 +2102,214 @@ createBCFModelFromJsonString <- function(json_string){
   return(bcf_object)
 }
+
+#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object
+#' which can be used for prediction, etc...
+#'
+#' @param json_string_list List of JSON strings which can be parsed to objects of type `CppJson` containing Json representation of a BCF model
+#'
+#' @return Object of type `bcf`
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' x1 <- rnorm(n)
+#' x2 <- rnorm(n)
+#' x3 <- rnorm(n)
+#' x4 <- rnorm(n)
+#' x5 <- rnorm(n)
+#' X <- cbind(x1,x2,x3,x4,x5)
+#' p <- ncol(X)
+#' g <- function(x) {ifelse(x[,5] < -0.44,2,ifelse(x[,5] < 0.44,-1,4))}
+#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]}
+#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)}
+#' tau1 <- function(x) {rep(3,nrow(x))}
+#' tau2 <- function(x) {1+2*x[,2]*(x[,4] > 0)}
+#' mu_x <- mu1(X)
+#' tau_x <- tau2(X)
+#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
+#' Z <- rbinom(n,1,pi_x)
+#' E_XZ <- mu_x + Z*tau_x
+#' snr <- 3
+#' group_ids <- rep(c(1,2), n %/% 2)
+#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
+#' rfx_basis <- cbind(1, runif(n, -1, 1))
+#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
+#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+#' X <- as.data.frame(X)
+#' X$x4 <- factor(X$x4, ordered = TRUE)
+#' X$x5 <- factor(X$x5, ordered = TRUE)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' pi_test <- pi_x[test_inds]
+#' pi_train <- pi_x[train_inds]
+#' Z_test <- Z[test_inds]
+#' Z_train <- Z[train_inds]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' mu_test <- mu_x[test_inds]
+#' mu_train <- mu_x[train_inds]
+#' tau_test <- tau_x[test_inds]
+#' tau_train <- tau_x[train_inds]
+#' group_ids_test <- group_ids[test_inds]
+#' group_ids_train <- group_ids[train_inds]
+#' rfx_basis_test <- rfx_basis[test_inds,]
+#' rfx_basis_train <- rfx_basis[train_inds,]
+#' rfx_term_test <- rfx_term[test_inds]
+#' rfx_term_train <- rfx_term[train_inds]
+#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+#'                  pi_train = pi_train, group_ids_train = group_ids_train,
+#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
+#'                  Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
+#'                  rfx_basis_test = rfx_basis_test,
+#'                  num_gfr = 100, num_burnin = 0, num_mcmc = 100)
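+#' # In practice, json_string_list would collect strings from several
+#' # independently sampled BCF models (e.g. chains run in parallel); a
+#' # single-model list is shown here purely for illustration: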
+#' # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) +#' # bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) +createBCFModelFromCombinedJsonString <- function(json_string_list){ + # Initialize the BCF model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + } + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + output[["forests_tau"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1") + include_variance_forest <- json_object_default$get_boolean("include_variance_forest") + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerCombinedJson(json_object_list, "forest_2") + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") + train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") + model_params[["standardize"]] <- json_object_default$get_boolean("standardize") + model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") + model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") + model_params[["sample_sigma_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma_leaf_mu") + model_params[["sample_sigma_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma_leaf_tau") + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") + model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + 
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + } else { + prev_json <- json_object_list[[i-1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + } else { + output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters")) + } + } + } + if (model_params[["sample_sigma_leaf_mu"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") + } else { + output[["sigma_leaf_mu_samples"]] <- c(output[["sigma_leaf_mu_samples"]], json_object$get_vector("sigma_leaf_mu_samples", "parameters")) + } + } + } + if (model_params[["sample_sigma_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + } else { + output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + } + } + } + if (model_params[["sample_sigma_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + } else { + output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + } + } + } + if (model_params[["adaptive_coding"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") + output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + } else { + output[["b_1_samples"]] <- c(output[["b_1_samples"]], json_object$get_vector("b_1_samples", "parameters")) + output[["b_0_samples"]] <- c(output[["b_0_samples"]], json_object$get_vector("b_0_samples", "parameters")) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + } + + class(output) <- "bcf" + 
return(output) +} + diff --git a/R/calibration.R b/R/calibration.R index 998ec419..6df0575c 100644 --- a/R/calibration.R +++ b/R/calibration.R @@ -1,13 +1,13 @@ -#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) [1] +#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) #' -#' [1] Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 +#' Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 #' #' @param y Outcome to be modeled using BART, BCF or another nonparametric ensemble method. #' @param X Covariates to be used to partition trees in an ensemble or series of ensemble. -#' @param W [Optional] Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`. +#' @param W (Optional) Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`. #' @param nu The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as `nu*lambda` where `lambda` is the output of this function. Default: `3`. -#' @param quant [Optional] Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`. -#' @param standardize [Optional] Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`. +#' @param quant (Optional) Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`. +#' @param standardize (Optional) Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`. 
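+#'
+#' @details A sketch of the underlying identity (the exported function also
+#' handles preprocessing, so the exact implementation may differ): letting
+#' `sigma2hat` denote the error variance estimate from `lm(y ~ X)`, setting
+#' `lambda <- sigma2hat * qgamma(1 - quant, nu) / nu`
+#' places prior probability `quant` below `sigma2hat` under
+#' `sigma^2 ~ IG(nu, nu*lambda)`.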
#' #' @return Value of `lambda` which determines the scale parameter of the global error variance prior (`sigma^2 ~ IG(nu,nu*lambda)`) #' @export diff --git a/R/cpp11.R b/R/cpp11.R index c36a3cd7..3cc45075 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -144,8 +144,8 @@ rfx_label_mapper_cpp <- function(rfx_tracker) { .Call(`_stochtree_rfx_label_mapper_cpp`, rfx_tracker) } -rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng) { - invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng)) +rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, keep_sample, global_variance, rng) { + invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, keep_sample, global_variance, rng)) } rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) { @@ -168,6 +168,10 @@ rfx_container_num_groups_cpp <- function(rfx_container) { .Call(`_stochtree_rfx_container_num_groups_cpp`, rfx_container) } +rfx_container_delete_sample_cpp <- function(rfx_container, sample_num) { + invisible(.Call(`_stochtree_rfx_container_delete_sample_cpp`, rfx_container, sample_num)) +} + rfx_model_set_working_parameter_cpp <- function(rfx_model, working_param_init) { invisible(.Call(`_stochtree_rfx_model_set_working_parameter_cpp`, rfx_model, working_param_init)) } @@ -216,6 +220,22 @@ rfx_label_mapper_to_list_cpp <- function(label_mapper_ptr) { .Call(`_stochtree_rfx_label_mapper_to_list_cpp`, label_mapper_ptr) } +reset_rfx_model_cpp <- function(rfx_model, rfx_container, sample_num) { + invisible(.Call(`_stochtree_reset_rfx_model_cpp`, rfx_model, rfx_container, sample_num)) +} + +reset_rfx_tracker_cpp <- function(tracker, dataset, residual, rfx_model) { + invisible(.Call(`_stochtree_reset_rfx_tracker_cpp`, tracker, dataset, residual, rfx_model)) +} + +root_reset_rfx_tracker_cpp <- function(tracker, dataset, residual, rfx_model) { + invisible(.Call(`_stochtree_root_reset_rfx_tracker_cpp`, tracker, dataset, residual, rfx_model)) +} + +active_forest_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) { + .Call(`_stochtree_active_forest_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated) +} + forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) { .Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated) } @@ -280,6 +300,10 @@ is_leaf_constant_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_is_leaf_constant_forest_container_cpp`, forest_samples) } +is_exponentiated_forest_container_cpp <- function(forest_samples) { + .Call(`_stochtree_is_exponentiated_forest_container_cpp`, forest_samples) +} + all_roots_forest_container_cpp <- function(forest_samples, forest_num) { .Call(`_stochtree_all_roots_forest_container_cpp`, forest_samples, forest_num) } @@ -412,6 +436,10 @@ propagate_basis_update_forest_container_cpp <- function(data, residual, forest_s invisible(.Call(`_stochtree_propagate_basis_update_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num)) } +remove_sample_forest_container_cpp <- function(forest_samples, forest_num) { + invisible(.Call(`_stochtree_remove_sample_forest_container_cpp`, forest_samples, forest_num)) +} + predict_forest_cpp <- 
function(forest_samples, dataset) { .Call(`_stochtree_predict_forest_cpp`, forest_samples, dataset) } @@ -428,6 +456,98 @@ predict_forest_raw_single_tree_cpp <- function(forest_samples, dataset, forest_n .Call(`_stochtree_predict_forest_raw_single_tree_cpp`, forest_samples, dataset, forest_num, tree_num) } +predict_active_forest_cpp <- function(active_forest, dataset) { + .Call(`_stochtree_predict_active_forest_cpp`, active_forest, dataset) +} + +predict_raw_active_forest_cpp <- function(active_forest, dataset) { + .Call(`_stochtree_predict_raw_active_forest_cpp`, active_forest, dataset) +} + +output_dimension_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_output_dimension_active_forest_cpp`, active_forest) +} + +average_max_depth_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_average_max_depth_active_forest_cpp`, active_forest) +} + +num_trees_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_num_trees_active_forest_cpp`, active_forest) +} + +ensemble_tree_max_depth_active_forest_cpp <- function(active_forest, tree_num) { + .Call(`_stochtree_ensemble_tree_max_depth_active_forest_cpp`, active_forest, tree_num) +} + +is_leaf_constant_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_is_leaf_constant_active_forest_cpp`, active_forest) +} + +is_exponentiated_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_is_exponentiated_active_forest_cpp`, active_forest) +} + +all_roots_active_forest_cpp <- function(active_forest) { + .Call(`_stochtree_all_roots_active_forest_cpp`, active_forest) +} + +set_leaf_value_active_forest_cpp <- function(active_forest, leaf_value) { + invisible(.Call(`_stochtree_set_leaf_value_active_forest_cpp`, active_forest, leaf_value)) +} + +set_leaf_vector_active_forest_cpp <- function(active_forest, leaf_vector) { + invisible(.Call(`_stochtree_set_leaf_vector_active_forest_cpp`, active_forest, leaf_vector)) +} + +add_numeric_split_tree_value_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) { + invisible(.Call(`_stochtree_add_numeric_split_tree_value_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value)) +} + +add_numeric_split_tree_vector_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector) { + invisible(.Call(`_stochtree_add_numeric_split_tree_vector_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector)) +} + +get_tree_leaves_active_forest_cpp <- function(active_forest, tree_num) { + .Call(`_stochtree_get_tree_leaves_active_forest_cpp`, active_forest, tree_num) +} + +get_tree_split_counts_active_forest_cpp <- function(active_forest, tree_num, num_features) { + .Call(`_stochtree_get_tree_split_counts_active_forest_cpp`, active_forest, tree_num, num_features) +} + +get_overall_split_counts_active_forest_cpp <- function(active_forest, num_features) { + .Call(`_stochtree_get_overall_split_counts_active_forest_cpp`, active_forest, num_features) +} + +get_granular_split_count_array_active_forest_cpp <- function(active_forest, num_features) { + .Call(`_stochtree_get_granular_split_count_array_active_forest_cpp`, active_forest, num_features) +} + +initialize_forest_model_active_forest_cpp <- function(data, residual, active_forest, tracker, init_values, leaf_model_int) { + 
invisible(.Call(`_stochtree_initialize_forest_model_active_forest_cpp`, data, residual, active_forest, tracker, init_values, leaf_model_int)) +} + +adjust_residual_active_forest_cpp <- function(data, residual, active_forest, tracker, requires_basis, add) { + invisible(.Call(`_stochtree_adjust_residual_active_forest_cpp`, data, residual, active_forest, tracker, requires_basis, add)) +} + +propagate_basis_update_active_forest_cpp <- function(data, residual, active_forest, tracker) { + invisible(.Call(`_stochtree_propagate_basis_update_active_forest_cpp`, data, residual, active_forest, tracker)) +} + +reset_active_forest_cpp <- function(active_forest, forest_samples, forest_num) { + invisible(.Call(`_stochtree_reset_active_forest_cpp`, active_forest, forest_samples, forest_num)) +} + +reset_forest_model_cpp <- function(forest_tracker, forest, data, residual, is_mean_model) { + invisible(.Call(`_stochtree_reset_forest_model_cpp`, forest_tracker, forest, data, residual, is_mean_model)) +} + +root_reset_active_forest_cpp <- function(active_forest) { + invisible(.Call(`_stochtree_root_reset_active_forest_cpp`, active_forest)) +} + forest_container_get_max_leaf_index_cpp <- function(forest_container, forest_num) { .Call(`_stochtree_forest_container_get_max_leaf_index_cpp`, forest_container, forest_num) } @@ -436,20 +556,20 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, 
cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
 }
 
 sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
   .Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, dataset, rng, a, b)
 }
 
-sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) {
-  .Call(`_stochtree_sample_tau_one_iteration_cpp`, forest_samples, rng, a, b, sample_num)
+sample_tau_one_iteration_cpp <- function(active_forest, rng, a, b) {
+  .Call(`_stochtree_sample_tau_one_iteration_cpp`, active_forest, rng, a, b)
 }
 
 rng_cpp <- function(random_seed) {
diff --git a/R/forest.R b/R/forest.R
index b799e48a..ff8eb52c 100644
--- a/R/forest.R
+++ b/R/forest.R
@@ -229,6 +229,20 @@ ForestSamples <- R6::R6Class(
     return(output_dimension_forest_container_cpp(self$forest_container_ptr))
   },
 
+    #' @description
+    #' Return constant leaf status of trees in a `ForestContainer` object
+    #' @return `T` if leaves are constant, `F` otherwise
+    is_constant_leaf = function() {
+      return(is_leaf_constant_forest_container_cpp(self$forest_container_ptr))
+    },
+
+    #' @description
+    #' Return exponentiation status of trees in a `ForestContainer` object
+    #' @return `T` if leaf predictions must be exponentiated, `F` otherwise
+    is_exponentiated = function() {
+      return(is_exponentiated_forest_container_cpp(self$forest_container_ptr))
+    },
+
     #' @description
     #' Add a new all-root ensemble to the container, with all of the leaves
     #' set to the value / vector provided
@@ -242,7 +256,7 @@ ForestSamples <- R6::R6Class(
   },
 
     #' @description
-    #' Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
+    #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble
     #' @param forest_num Index of the forest which contains the tree to be split
     #' @param tree_num Index of the tree to be split
     #' @param leaf_num Leaf to be split
@@ -520,6 +534,217 @@ ForestSamples <- R6::R6Class(
     #' @return Indices of leaf nodes
     leaves = function(forest_num, tree_num) {
       return(leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num))
+    },
+
+    #' @description
+    #' Modify the `ForestSamples` object by removing the forest sample indexed by `forest_num`
+    #' @param forest_num Index of the forest to be removed
+    delete_sample = function(forest_num) {
+      return(remove_sample_forest_container_cpp(self$forest_container_ptr, forest_num))
+    }
+  )
+)
+
+#' Class that stores a single ensemble of decision trees (often treated as the "active forest")
+#'
+#' @description
+#' Wrapper around a C++ tree ensemble
+
+Forest <- R6::R6Class(
+  classname = "Forest",
+  cloneable = FALSE,
+  public = list(
+
+    #' @field forest_ptr External pointer to a C++ TreeEnsemble class
+    forest_ptr = NULL,
+
+    #' @description
+    #' Create a new Forest object.
+    #' @param num_trees Number of trees in the forest
+    #' @param output_dimension Dimensionality of the outcome model
+    #' @param is_leaf_constant Whether leaf is constant
+    #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
+    #' @return A new `Forest` object.
+    initialize = function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
+      self$forest_ptr <- active_forest_cpp(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
+    },
+
+    #' @description
+    #' Predict forest on every sample in `forest_dataset`
+    #' @param forest_dataset `ForestDataset` R class
+    #' @return vector of predictions with as many rows as in `forest_dataset`
+    predict = function(forest_dataset) {
+      stopifnot(!is.null(forest_dataset$data_ptr))
+      stopifnot(!is.null(self$forest_ptr))
+      return(predict_active_forest_cpp(self$forest_ptr, forest_dataset$data_ptr))
+    },
+
+    #' @description
+    #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset`
+    #' @param forest_dataset `ForestDataset` R class
+    #' @return Array of predictions for each observation in `forest_dataset`, with
+    #' each prediction having the dimensionality of the forest's leaf model. In the
+    #' case of a constant leaf model or univariate leaf regression, this array is a
+    #' vector (length is the number of observations). In the case of a multivariate
+    #' leaf regression, this array is a matrix (number of observations by leaf
+    #' model dimension).
+    predict_raw = function(forest_dataset) {
+      stopifnot(!is.null(forest_dataset$data_ptr))
+      # Unpack dimensions
+      output_dim <- output_dimension_active_forest_cpp(self$forest_ptr)
+      n <- dataset_num_rows_cpp(forest_dataset$data_ptr)
+
+      # Predict leaf values from forest
+      predictions <- predict_raw_active_forest_cpp(self$forest_ptr, forest_dataset$data_ptr)
+      if (output_dim > 1) {
+        dim(predictions) <- c(n, output_dim)
+      }
+
+      return(predictions)
+    },
+
+    #' @description
+    #' Set a constant predicted value for every tree in the ensemble.
+    #' Stops program if any tree is more than a root node.
+    #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
+    set_root_leaves = function(leaf_value) {
+      stopifnot(!is.null(self$forest_ptr))
+
+      # Set leaf values
+      if (length(leaf_value) == 1) {
+        stopifnot(output_dimension_active_forest_cpp(self$forest_ptr) == 1)
+        set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value)
+      } else if (length(leaf_value) > 1) {
+        stopifnot(output_dimension_active_forest_cpp(self$forest_ptr) == length(leaf_value))
+        set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value)
+      } else {
+        stop("leaf_value must be a numeric value or vector of length >= 1")
+      }
+    },
+
+    #' @description
+    #' Initialize the forest and the sampler's tracking data structures, setting every tree to a constant root prediction.
+    #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...)
+    #' @param outcome `Outcome` Outcome class (residual / partial residual)
+    #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
+    #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
+    #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
+    prepare_for_sampler = function(dataset, outcome, forest_model, leaf_model_int, leaf_value) {
+      stopifnot(!is.null(dataset$data_ptr))
+      stopifnot(!is.null(outcome$data_ptr))
+      stopifnot(!is.null(forest_model$tracker_ptr))
+      stopifnot(!is.null(self$forest_ptr))
+
+      # Initialize the model
+      initialize_forest_model_active_forest_cpp(
+        dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
+        forest_model$tracker_ptr, leaf_value, leaf_model_int
+      )
+    },
+
+    #' @description
+    #' Adjusts residual based on the predictions of a forest
+    #'
+    #' This is typically run just once at the beginning of a forest sampling algorithm.
+    #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
+    #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
+    #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
+    #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
+    #' @param requires_basis Whether or not a forest requires a basis for prediction
+    #' @param add Whether forest predictions should be added to or subtracted from residuals
+    adjust_residual = function(dataset, outcome, forest_model, requires_basis, add) {
+      stopifnot(!is.null(dataset$data_ptr))
+      stopifnot(!is.null(outcome$data_ptr))
+      stopifnot(!is.null(forest_model$tracker_ptr))
+      stopifnot(!is.null(self$forest_ptr))
+
+      adjust_residual_active_forest_cpp(
+        dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
+        forest_model$tracker_ptr, requires_basis, add
+      )
+    },
+
+    #' @description
+    #' Return the number of trees in a `Forest` object
+    #' @return Tree count
+    num_trees = function() {
+      return(num_trees_active_forest_cpp(self$forest_ptr))
+    },
+
+    #' @description
+    #' Return output dimension of trees in a `Forest` object
+    #' @return Leaf node parameter size
+    output_dimension = function() {
+      return(output_dimension_active_forest_cpp(self$forest_ptr))
+    },
+
+    #' @description
+    #' Return constant leaf status of trees in a `Forest` object
+    #' @return `T` if leaves are constant, `F` otherwise
+    is_constant_leaf = function() {
+      return(is_leaf_constant_active_forest_cpp(self$forest_ptr))
+    },
+
+    #' @description
+    #' Return exponentiation status of trees in a `Forest` object
+    #' @return `T` if leaf predictions must be exponentiated, `F` otherwise
+    is_exponentiated = function() {
+      return(is_exponentiated_active_forest_cpp(self$forest_ptr))
+    },
+
+    #' @description
+    #' Add a numeric (i.e.
`X[,i] <= c`) split to a given tree in the ensemble + #' @param tree_num Index of the tree to be split + #' @param leaf_num Leaf to be split + #' @param feature_num Feature that defines the new split + #' @param split_threshold Value that defines the cutoff of the new split + #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node + #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node + add_numeric_split_tree = function(tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) { + if (length(left_leaf_value) > 1) { + add_numeric_split_tree_vector_active_forest_cpp(self$forest_ptr, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + } else { + add_numeric_split_tree_value_active_forest_cpp(self$forest_ptr, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + } + }, + + #' @description + #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest + #' @param tree_num Index of the tree for which leaf indices will be retrieved + get_tree_leaves = function(tree_num) { + return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given tree in the forest + #' @param tree_num Index of the tree for which split counts will be retrieved + #' @param num_features Total number of features in the training set + get_tree_split_counts = function(tree_num, num_features) { + return(get_tree_split_counts_active_forest_cpp(self$forest_ptr, tree_num, num_features)) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in the forest + #' @param num_features Total number of features in the training set + get_forest_split_counts = function(num_features) { + return(get_forest_split_counts_active_forest_cpp(self$forest_ptr, num_features)) + }, + + #' @description + #' Maximum depth of a specific tree in the forest + #' @param tree_num Tree index within forest + #' @return Maximum leaf depth + tree_max_depth = function(tree_num) { + return(ensemble_tree_max_depth_active_forest_cpp(self$forest_ptr, tree_num)) + }, + + #' @description + #' Average the maximum depth of each tree in the forest + #' @return Average maximum depth + average_max_depth = function() { + return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr)) } ) ) @@ -538,3 +763,50 @@ createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constan ForestSamples$new(num_trees, output_dimension, is_leaf_constant, is_exponentiated) ))) } + +#' Create a forest +#' +#' @param num_trees Number of trees in the forest +#' @param output_dimension Dimensionality of the outcome model +#' @param is_leaf_constant Whether leaf is constant +#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned +#' +#' @return `Forest` object +#' @export +createForest <- function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) { + return(invisible(( + Forest$new(num_trees, output_dimension, is_leaf_constant, is_exponentiated) + ))) +} + +#' Re-initialize an active forest from a specific forest in a `ForestContainer` +#' +#' @param active_forest Current active forest +#' @param forest_samples Container of forest samples from which to re-initialize active forest +#' @param forest_num Index of forest samples from which to initialize active forest +#' 
@export +resetActiveForest <- function(active_forest, forest_samples, forest_num) { + reset_active_forest_cpp(active_forest$forest_ptr, forest_samples$forest_container_ptr, forest_num) +} + +#' Re-initialize a forest model (tracking data structures) from a specific forest in a `ForestContainer` +#' +#' @param forest_model Forest model with tracking data structures +#' @param forest Forest from which to re-initialize forest model +#' @param dataset Training dataset object +#' @param residual Residual which will also be updated +#' @param is_mean_model Whether the model being updated is a conditional mean model +#' @export +resetForestModel <- function(forest_model, forest, dataset, residual, is_mean_model) { + reset_forest_model_cpp(forest_model$tracker_ptr, forest$forest_ptr, dataset$data_ptr, residual$data_ptr, is_mean_model) +} + +#' Reset an active forest to an ensemble of single-node (i.e. root) trees +#' +#' @param active_forest Current active forest +#' +#' @return None. `active_forest` is reset in place. +#' @export +rootResetActiveForest <- function(active_forest) { + root_reset_active_forest_cpp(active_forest$forest_ptr) +} diff --git a/R/model.R index 2e2ffc3a..90ddbb3c 100644 --- a/R/model.R +++ b/R/model.R @@ -63,6 +63,7 @@ ForestModel <- R6::R6Class( #' @param forest_dataset Dataset used to sample the forest #' @param residual Outcome used to sample the forest #' @param forest_samples Container of forest samples + #' @param active_forest "Active" forest updated by the sampler in each iteration #' @param rng Wrapper around C++ random number generator #' @param feature_types Vector specifying the type of all p covariates in `forest_dataset` (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) #' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) @@ -71,26 +72,27 @@ ForestModel <- R6::R6Class( #' @param a_forest Shape parameter on variance forest model (if applicable) #' @param b_forest Scale parameter on variance forest model (if applicable) #' @param global_scale Global variance parameter - #' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: 500, currently only used when `GFR = TRUE`) - #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm - #' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F. - sample_one_iteration = function(forest_dataset, residual, forest_samples, rng, feature_types, + #' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: `500`, currently only used when `gfr = TRUE`) + #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `T`. + #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `T`. + #' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: `F`. 
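Taken together with `createForest`, these helpers support the warm-start pattern used by the multi-chain samplers: a minimal sketch, assuming `forest_samples`, `forest_model`, `dataset`, and `residual` already exist from a prototype-interface loop (e.g. via `createForestContainer`, `createForestModel`, `createForestDataset`, and `createOutcome`), and that `forest_num` is zero-based like the underlying C++ container:

```r
active_forest <- createForest(num_trees = 100)

# Warm start: copy a stored (e.g. GFR) sample into the active forest, then
# rebuild the tracking data structures and residual to match
resetActiveForest(active_forest, forest_samples, forest_num = 9)
resetForestModel(forest_model, active_forest, dataset, residual, is_mean_model = TRUE)

# Cold start: alternatively, reset the active forest to root-only trees
rootResetActiveForest(active_forest)
```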
+ sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, rng, feature_types, leaf_model_int, leaf_model_scale, variable_weights, a_forest, b_forest, global_scale, cutpoint_grid_size = 500, - gfr = T, pre_initialized = F) { + keep_forest = T, gfr = T, pre_initialized = F) { if (gfr) { sample_gfr_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, - forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr, - rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized + forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, + self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized ) } else { sample_mcmc_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, - forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr, - rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized + forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, + self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized ) } }, @@ -106,17 +108,16 @@ ForestModel <- R6::R6Class( #' changed and this should be reflected through to the residual before the next sampling loop is run. #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param forest_samples `ForestSamples` object storing draws of tree ensembles - #' @param forest_num Index of forest used to update residuals (starting at 1, in R style) - propagate_basis_update = function(dataset, outcome, forest_samples, forest_num) { + #' @param active_forest "Active" forest updated by the sampler in each iteration + propagate_basis_update = function(dataset, outcome, active_forest) { stopifnot(!is.null(dataset$data_ptr)) stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(self$tracker_ptr)) - stopifnot(!is.null(forest_samples$forest_container_ptr)) + stopifnot(!is.null(active_forest$forest_ptr)) - propagate_basis_update_forest_container_cpp( - dataset$data_ptr, outcome$data_ptr, forest_samples$forest_container_ptr, - self$tracker_ptr, forest_num + propagate_basis_update_active_forest_cpp( + dataset$data_ptr, outcome$data_ptr, active_forest$forest_ptr, + self$tracker_ptr ) }, @@ -152,6 +153,7 @@ createRNG <- function(random_seed = -1){ #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf +#' @param max_depth Maximum depth of any tree in the ensemble. Setting to `-1` does not enforce any depth limits on trees. #' #' @return `ForestModel` object #' @export diff --git a/R/random_effects.R index f9d0eaf9..4604bd9b 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -146,6 +146,13 @@ RandomEffectSamples <- R6::R6Class( return(output) }, + #' @description + #' Modify the `RandomEffectSamples` object by removing the parameter sample indexed by `sample_num`. 
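Returning to `propagate_basis_update` above: since it now reads from the active forest, a basis update in a custom loop no longer passes a forest index. A minimal sketch, where `Z`, `b0`, and `b1` are hypothetical adaptive-coding quantities and `dataset$update_basis()` is assumed to exist in R (mirroring the `update_basis` call in the Python demo later in this diff):

```r
# Rebuild the treatment basis after drawing new coding parameters b0 / b1 ...
new_basis <- (1 - Z) * b0 + Z * b1
dataset$update_basis(new_basis)

# ... then propagate the change through cached predictions and the residual
forest_model$propagate_basis_update(dataset, residual, active_forest)
```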
+ #' @param sample_num Index of the RFX sample to be removed + delete_sample = function(sample_num) { + rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num) + }, + #' @description #' Convert the mapping of group IDs to random effect components indices from C++ to R native format #' @return List mapping group ID to random effect components. @@ -224,13 +231,14 @@ RandomEffectsModel <- R6::R6Class( #' @param residual Object of type `Outcome` #' @param rfx_tracker Object of type `RandomEffectsTracker` #' @param rfx_samples Object of type `RandomEffectSamples` + #' @param keep_sample Whether the sample should be retained in `rfx_samples`. If `FALSE`, the state of `rfx_tracker` will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning. #' @param global_variance Scalar global variance parameter #' @param rng Object of type `CppRNG` #' @return None - sample_random_effect = function(rfx_dataset, residual, rfx_tracker, rfx_samples, global_variance, rng) { + sample_random_effect = function(rfx_dataset, residual, rfx_tracker, rfx_samples, keep_sample, global_variance, rng) { rfx_model_sample_random_effects_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, residual$data_ptr, rfx_tracker$rfx_tracker_ptr, - rfx_samples$rfx_container_ptr, global_variance, rng$rng_ptr) + rfx_samples$rfx_container_ptr, keep_sample, global_variance, rng$rng_ptr) }, #' @description @@ -357,3 +365,58 @@ createRandomEffectsModel <- function(num_components, num_groups) { RandomEffectsModel$new(num_components, num_groups) ))) } + +#' Reset a `RandomEffectsModel` object based on the parameters indexed by `sample_num` in a `RandomEffectSamples` object +#' +#' @param rfx_model Object of type `RandomEffectsModel`. +#' @param rfx_samples Object of type `RandomEffectSamples`. +#' @param sample_num Index of the sample stored in `rfx_samples` from which to reset the state of a random effects model. Zero-indexed, so resetting based on the first sample would require setting `sample_num = 0`. +#' @param sigma_alpha_init Initial value of the "working parameter" scale parameter. +#' @export +resetRandomEffectsModel <- function(rfx_model, rfx_samples, sample_num, sigma_alpha_init) { + reset_rfx_model_cpp(rfx_model$rfx_model_ptr, rfx_samples$rfx_container_ptr, sample_num) + rfx_model$set_working_parameter_cov(sigma_alpha_init) +} + +#' Reset a `RandomEffectsTracker` object based on the current state of a `RandomEffectsModel` (e.g. after a call to `resetRandomEffectsModel`) +#' +#' @param rfx_tracker Object of type `RandomEffectsTracker`. +#' @param rfx_model Object of type `RandomEffectsModel`. +#' @param rfx_dataset Object of type `RandomEffectsDataset`. +#' @param residual Object of type `Outcome`. +#' @param rfx_samples Object of type `RandomEffectSamples`. +#' @export +resetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, residual, rfx_samples) { + reset_rfx_tracker_cpp(rfx_tracker$rfx_tracker_ptr, rfx_dataset$data_ptr, residual$data_ptr, rfx_model$rfx_model_ptr) +} + +#' Reset a `RandomEffectsModel` object to its "default" state +#' +#' @param rfx_model Object of type `RandomEffectsModel`. +#' @param alpha_init Initial value of the "working parameter". +#' @param xi_init Initial value of the "group parameters". +#' @param sigma_alpha_init Initial value of the "working parameter" scale parameter. +#' @param sigma_xi_init Initial value of the "group parameters" scale parameter. 
+#' @param sigma_xi_shape Shape parameter for the inverse gamma variance model on the group parameters. +#' @param sigma_xi_scale Scale parameter for the inverse gamma variance model on the group parameters. +#' @export +rootResetRandomEffectsModel <- function(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) { + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) +} + +#' Reset a `RandomEffectsTracker` object to its "default" state +#' +#' @param rfx_tracker Object of type `RandomEffectsTracker`. +#' @param rfx_model Object of type `RandomEffectsModel`. +#' @param rfx_dataset Object of type `RandomEffectsDataset`. +#' @param residual Object of type `Outcome`. +#' @export +rootResetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, residual) { + root_reset_rfx_tracker_cpp(rfx_tracker$rfx_tracker_ptr, rfx_dataset$data_ptr, residual$data_ptr, rfx_model$rfx_model_ptr) +} diff --git a/R/stochtree-package.R b/R/stochtree-package.R index 08317e4f..7a912fd7 100644 --- a/R/stochtree-package.R +++ b/R/stochtree-package.R @@ -1,4 +1,5 @@ ## usethis namespace: start +#' @importFrom stats coef #' @importFrom stats lm #' @importFrom stats model.matrix #' @importFrom stats qgamma diff --git a/R/utils.R b/R/utils.R index 34edc32a..dd755c70 100644 --- a/R/utils.R +++ b/R/utils.R @@ -19,8 +19,8 @@ preprocessBartParams <- function(params) { variable_weights_mean = NULL, variable_weights_variance = NULL, num_trees_mean = 200, num_trees_variance = 0, sample_sigma_global = T, sample_sigma_leaf = F, - random_seed = -1, keep_burnin = F, keep_gfr = F, - standardize = T, verbose = F + random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1, + num_chains = 1, standardize = T, verbose = F ) # Override defaults @@ -50,16 +50,17 @@ preprocessBcfParams <- function(params) { beta_mu = 2.0, beta_tau = 3.0, beta_variance = 2.0, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, min_samples_leaf_variance = 5, max_depth_mu = 10, max_depth_tau = 5, max_depth_variance = 10, - a_global = 0, b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, - a_forest = NULL, b_forest = NULL, sigma2_init = NULL, variance_forest_init = NULL, - pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1, + a_global = 0, b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, + b_leaf_tau = NULL, a_forest = NULL, b_forest = NULL, sigma2_init = NULL, + variance_forest_init = NULL, pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, drop_vars_tau = NULL, keep_vars_variance = NULL, - drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, num_trees_variance = 0, - num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, - sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, + drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, + num_trees_variance = 0, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, + propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, 
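Tying the random-effects reset helpers together, a minimal sketch of restarting a chain (assuming `rfx_model`, `rfx_tracker`, `rfx_dataset`, `residual`, and `rfx_samples` come from an existing loop via `createRandomEffectsModel` and related constructors; the initial values below are illustrative and their shapes depend on the model dimensions):

```r
num_components <- 1  # illustrative model dimensions
num_groups <- 3

# Warm start from the first retained sample (zero-indexed, per the docs above)
resetRandomEffectsModel(rfx_model, rfx_samples, sample_num = 0, sigma_alpha_init = 1.0)
resetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, residual, rfx_samples)

# Or reset to "default" initial values instead
rootResetRandomEffectsModel(rfx_model, alpha_init = rep(1, num_components),
                            xi_init = matrix(0, num_components, num_groups),
                            sigma_alpha_init = 1.0, sigma_xi_init = 1.0,
                            sigma_xi_shape = 1.0, sigma_xi_scale = 1.0)
rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, residual)
```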
keep_burnin = F, keep_gfr = F, - standardize = T, verbose = F + keep_every = 1, num_chains = 1, standardize = T, verbose = F ) # Override defaults @@ -79,7 +80,6 @@ preprocessBcfParams <- function(params) { #' types. Matrices will be passed through assuming all columns are numeric. #' #' @param input_data Covariates, provided as either a dataframe or a matrix -#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable #' #' @return List with preprocessed (unmodified) data and details on the number of each type #' of variable, unique categories associated with categorical variables, and the @@ -735,8 +735,10 @@ orderedCatInitializeAndPreprocess <- function(x_input) { #' @export #' #' @examples -#' x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") -#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") +#' x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", +#' "4. Agree", "5. Strongly agree") +#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", +#' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") #' x_processed <- orderedCatPreprocess(x, x_levels) orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) diff --git a/R/variance.R b/R/variance.R index d225226e..b12bc89e 100644 --- a/R/variance.R +++ b/R/variance.R @@ -13,13 +13,12 @@ sample_sigma2_one_iteration <- function(residual, dataset, rng, a, b) { #' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!) 
#' -#' @param forest_samples Container of forest samples +#' @param forest C++ forest #' @param rng C++ random number generator #' @param a Leaf variance shape parameter #' @param b Leaf variance scale parameter -#' @param sample_num Sample index #' #' @export -sample_tau_one_iteration <- function(forest_samples, rng, a, b, sample_num) { - return(sample_tau_one_iteration_cpp(forest_samples$forest_container_ptr, rng$rng_ptr, a, b, sample_num)) +sample_tau_one_iteration <- function(forest, rng, a, b) { + return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b)) } diff --git a/_pkgdown.yml b/_pkgdown.yml index dd01867c..195c0bb9 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -16,11 +16,6 @@ reference: contents: - bcf - predict.bcf - - saveBCFModelToJsonFile - - createBCFModelFromJsonFile - - createBCFModelFromJsonString - - convertBCFModelToJson - - createBCFModelFromJson - title: Low-level functionality @@ -39,6 +34,7 @@ reference: - loadVectorJson - loadScalarJson - convertBARTModelToJson + - convertBARTStateToJson - createBARTModelFromCombinedJson - createBARTModelFromCombinedJsonString - createBARTModelFromJson @@ -48,7 +44,13 @@ reference: - loadRandomEffectSamplesCombinedJsonString - saveBARTModelToJsonFile - saveBARTModelToJsonString + - saveBCFModelToJsonFile - saveBCFModelToJsonString + - createBCFModelFromJsonFile + - createBCFModelFromJsonString + - convertBCFModelToJson + - createBCFModelFromJson + - createBCFModelFromCombinedJsonString - subtitle: Data desc: > @@ -77,6 +79,8 @@ reference: desc: > Classes and functions for constructing and persisting forests contents: + - Forest + - createForest - ForestModel - createForestModel - ForestSamples @@ -89,6 +93,9 @@ reference: - computeMaxLeafIndex - computeForestLeafIndices - computeForestLeafVariances + - resetActiveForest + - resetForestModel + - rootResetActiveForest - subtitle: Random Effects desc: > @@ -105,6 +112,10 @@ reference: - getRandomEffectSamples.bcf - sample_sigma2_one_iteration - sample_tau_one_iteration + - resetRandomEffectsModel + - resetRandomEffectsTracker + - rootResetRandomEffectsModel + - rootResetRandomEffectsTracker - title: Package info desc: > diff --git a/cran-bootstrap.R b/cran-bootstrap.R index 6a3b1d29..295abc7c 100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -24,7 +24,6 @@ pkg_core_files <- c( ".Rbuildignore", "DESCRIPTION", "LICENSE", - "LICENSE.md", list.files("man", recursive = TRUE, full.names = TRUE), "NAMESPACE", list.files("R", recursive = TRUE, full.names = TRUE), @@ -250,72 +249,17 @@ if (all(file.exists(eigen_files_to_vendor_src))) { } } -# Copy boost_math headers / implementations to an include/ subdirectory of src/ -boost_header_files_to_vendor_src <- c() -boost_header_files_to_vendor_dst <- c() -# Existing header files -boost_header_subfolder_src <- "deps/boost_math/include/boost" -boost_header_filenames_src <- list.files(boost_header_subfolder_src, pattern = "\\.(hpp)$", recursive = TRUE) -boost_header_files_to_vendor_src <- file.path(boost_header_subfolder_src, boost_header_filenames_src) -# Existing implementation files -boost_impl_subfolder_src <- "deps/boost_math/src" -boost_impl_filenames_src <- list.files(boost_impl_subfolder_src, pattern = "\\.(cpp)$", recursive = TRUE) -boost_impl_files_to_vendor_src <- file.path(boost_impl_subfolder_src, boost_impl_filenames_src) -# Destination files -boost_header_subfolder_dst <- "src/include/boost" -boost_header_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_header_filenames_src) 
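To make the `sample_tau_one_iteration` signature change above concrete, a minimal sketch (assuming `active_forest` and `rng` from a prototype-interface loop; the shape and scale values are illustrative):

```r
# The leaf-scale (tau) draw now reads the active forest directly, with no
# sample index (previously: sample_tau_one_iteration(forest_samples, rng, a, b, sample_num))
tau_sample <- sample_tau_one_iteration(active_forest, rng, a = 3.0, b = 0.5)
```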
-boost_impl_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_impl_filenames_src) - -if (all(file.exists(boost_header_files_to_vendor_src))) { - n_removed <- suppressWarnings(sum(file.remove(boost_header_files_to_vendor_dst))) - if (n_removed > 0) { - cat(sprintf("Removed %d previously vendored files from src/include/boost\n", n_removed)) - } - - cat( - sprintf( - "Vendoring files from deps/boost_math/include/boost/ to src/include/boost\n" - ) - ) - - # Recreate the directory structure - dst_dirs <- unique(dirname(boost_header_files_to_vendor_dst)) - for (dst_dir in dst_dirs) { - if (!dir.exists(dst_dir)) { - dir.create(dst_dir, recursive = TRUE) - } - } - - if (all(file.copy(boost_header_files_to_vendor_src, boost_header_files_to_vendor_dst))) { - cat("All deps/boost_math/include/boost header files successfully copied to src/include/boost\n") - } else { - stop("Failed to vendor all deps/boost_math/include/boost header files") - } -} - -if (all(file.exists(boost_impl_files_to_vendor_src))) { - n_removed <- suppressWarnings(sum(file.remove(boost_impl_files_to_vendor_dst))) - if (n_removed > 0) { - cat(sprintf("Removed %d previously vendored cpp files from src/include/boost\n", n_removed)) - } - - cat( - sprintf( - "Vendoring files from deps/boost_math/src/ to src/include/boost\n" - ) - ) - - # Recreate the directory structure - dst_dirs <- unique(dirname(boost_impl_files_to_vendor_dst)) - for (dst_dir in dst_dirs) { - if (!dir.exists(dst_dir)) { - dir.create(dst_dir, recursive = TRUE) - } - } - - if (all(file.copy(boost_impl_files_to_vendor_src, boost_impl_files_to_vendor_dst))) { - cat("All deps/boost_math/src header files successfully copied to src/include/boost\n") - } else { - stop("Failed to vendor all deps/boost_math/src header files") - } +# Clean up pragmas that suppress warnings in Eigen and JSON headers +# File 1: Eigen "DisableStupidWarnings" header +cran_eigen_suppress_warnings <- file.path(cran_dir, "src/include/Eigen/src/Core/util/DisableStupidWarnings.h") +eigen_suppress_warnings_lines <- readLines(cran_eigen_suppress_warnings) +for (i in seq_along(eigen_suppress_warnings_lines)) { + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma clang diagnostic.*$", "", eigen_suppress_warnings_lines[i]) + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma diag_suppress.*$", "", eigen_suppress_warnings_lines[i]) + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma GCC diagnostic.*$", "", eigen_suppress_warnings_lines[i]) + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma region.*$", "", eigen_suppress_warnings_lines[i]) + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma endregion.*$", "", eigen_suppress_warnings_lines[i]) + eigen_suppress_warnings_lines[i] <- gsub("^.*#pragma warning.*$", "", eigen_suppress_warnings_lines[i]) } +writeLines(eigen_suppress_warnings_lines, cran_eigen_suppress_warnings) diff --git a/debug/README.md index 6e018fe6..8dc5a15c 100644 --- a/debug/README.md +++ b/debug/README.md @@ -9,6 +9,10 @@ The program takes several command line arguments (in order): 4. Number of grow-from-root (GFR) samples 5. Number of MCMC samples 6. Seed for random number generator (-1 means we defer to C++ `std::random_device`) +7. [Optional] name of data file to load for training, instead of simulating data (leave this blank as `""` if simulated data is desired) +8. [Optional] index of outcome column in data file (leave this blank as `0`) +9. 
[Optional] comma-delimited string of column indices of covariates (leave this blank as `""`) +10. [Optional] comma-delimited string of column indices of leaf regression bases (leave this blank as `""`) The DGPs are numbered as follows: @@ -23,3 +27,9 @@ The models are numbered as follows: 1. "Univariate basis" leaf regression model 2. "Multivariate basis" leaf regression model 3. Log linear heteroskedastic variance model + +For an example of how to run this program for DGP 0, leaf model 1, no random effects, 10 GFR samples, 100 MCMC samples, and a default seed (`-1`), run + +`./build/debugstochtree 0 1 0 10 100 -1 "" 0 "" ""` + +from the main `stochtree` project directory after building with `BUILD_DEBUG_TARGETS` set to `ON`. diff --git a/debug/api_debug.cpp index fffefd9b..39a06ad0 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -532,6 +532,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia double outcome_scale; OutcomeOffsetScale(residual, outcome_offset, outcome_scale); + // Prepare random effects sampling (if desired) RandomEffectsDataset rfx_dataset; std::vector rfx_init(n, 0); RandomEffectsTracker rfx_tracker = RandomEffectsTracker(rfx_init); @@ -577,6 +578,9 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia } else { forest_exponentiated = false; } + // "Active" tree ensemble + TreeEnsemble active_forest = TreeEnsemble(num_trees, output_dimension, is_leaf_constant, forest_exponentiated); + // Stored forest samples ForestContainer forest_samples = ForestContainer(num_trees, output_dimension, is_leaf_constant, forest_exponentiated); // Initialize a leaf model @@ -605,8 +609,9 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia Eigen::MatrixXd leaf_scale_matrix(omega_cols, omega_cols); Eigen::MatrixXd leaf_scale_matrix_init(omega_cols, omega_cols); if (omega_cols > 0) { - leaf_scale_matrix_init << 1.0, 0.0, 0.0, 1.0; - leaf_scale_matrix = leaf_scale_matrix_init; + leaf_scale_matrix_init = Eigen::MatrixXd::Identity(omega_cols, omega_cols); + // leaf_scale_matrix_init << 1.0, 0.0, 0.0, 1.0; + leaf_scale_matrix = leaf_scale_matrix_init / num_trees; } // Set global variance @@ -636,27 +641,27 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia if (model_type == kConstantLeafGaussian) { init_val_glob = ComputeMeanOutcome(residual); init_val = init_val_glob / static_cast(num_trees); - forest_samples.InitializeRoot(init_val); - UpdateResidualEntireForest(tracker, dataset, residual, forest_samples.GetEnsemble(0), false, std::minus()); - tracker.UpdatePredictions(forest_samples.GetEnsemble(0), dataset); + active_forest.SetLeafValue(init_val); + UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, false, std::minus()); + tracker.UpdatePredictions(&active_forest, dataset); } else if (model_type == kUnivariateRegressionLeafGaussian) { init_val_glob = ComputeMeanOutcome(residual); init_val = init_val_glob / static_cast(num_trees); - forest_samples.InitializeRoot(init_val); - UpdateResidualEntireForest(tracker, dataset, residual, forest_samples.GetEnsemble(0), true, std::minus()); - tracker.UpdatePredictions(forest_samples.GetEnsemble(0), dataset); + active_forest.SetLeafValue(init_val); + UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, true, std::minus()); + tracker.UpdatePredictions(&active_forest, dataset); } else if (model_type == kMultivariateRegressionLeafGaussian) { init_val_glob = 
ComputeMeanOutcome(residual); init_val = init_val_glob / static_cast(num_trees); init_vec = std::vector(omega_cols, init_val); - forest_samples.InitializeRoot(init_vec); - UpdateResidualEntireForest(tracker, dataset, residual, forest_samples.GetEnsemble(0), true, std::minus()); - tracker.UpdatePredictions(forest_samples.GetEnsemble(0), dataset); + active_forest.SetLeafVector(init_vec); + UpdateResidualEntireForest(tracker, dataset, residual, &active_forest, true, std::minus()); + tracker.UpdatePredictions(&active_forest, dataset); } else if (model_type == kLogLinearVariance) { init_val_glob = ComputeVarianceOutcome(residual) * 0.4; init_val = std::log(init_val_glob) / static_cast(num_trees); - forest_samples.InitializeRoot(init_val); - tracker.UpdatePredictions(forest_samples.GetEnsemble(0), dataset); + active_forest.SetLeafValue(init_val); + tracker.UpdatePredictions(&active_forest, dataset); std::vector initial_preds(n, init_val_glob); dataset.AddVarianceWeights(initial_preds.data(), n); } @@ -678,13 +683,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia // Sample tree ensemble if (model_type == ModelType::kConstantLeafGaussian) { - GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true); } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true); } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, omega_cols); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols); } else if (model_type == ModelType::kLogLinearVariance) { - GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, false); + GFRSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, false); } if (rfx_included) { @@ -694,7 +699,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia } // Sample leaf node variance - leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen)); + leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen)); // Sample global variance global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); @@ -715,13 +720,13 @@ void 
RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia // Sample tree ensemble if (model_type == ModelType::kConstantLeafGaussian) { - MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true); } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { - MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true); } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { - MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, omega_cols); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true, omega_cols); } else if (model_type == ModelType::kLogLinearVariance) { - MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, false); + MCMCSampleOneIter(active_forest, tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, false); } if (rfx_included) { @@ -731,7 +736,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia } // Sample leaf node variance - leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen)); + leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen)); // Sample global variance global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); diff --git a/demo/debug/supervised_learning.py b/demo/debug/supervised_learning.py index f7e3b1ab..e6115957 100644 --- a/demo/debug/supervised_learning.py +++ b/demo/debug/supervised_learning.py @@ -59,14 +59,14 @@ def outcome_mean(X, W): bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100) # Inspect the MCMC (BART) samples -forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:] +forest_preds_y_mcmc = bart_model.y_hat_test y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True) y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"]) sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome") plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) 
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") plt.show() @@ -82,14 +82,14 @@ def outcome_mean(X, W): bart_model.sample(X_train=X_train_aug, y_train=y_train, X_test=X_test_aug, num_gfr=10, num_mcmc=100) # Inspect the MCMC (BART) samples -forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:] +forest_preds_y_mcmc = bart_model.y_hat_test y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True) y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"]) sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome") plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") plt.show() @@ -103,14 +103,14 @@ def outcome_mean(X, W): bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=10, num_mcmc=100) # Inspect the MCMC (BART) samples -forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:] +forest_preds_y_mcmc = bart_model.y_hat_test y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True) y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"]) sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome") plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") plt.show() diff --git a/demo/notebooks/causal_inference.ipynb b/demo/notebooks/causal_inference.ipynb index 085243a0..c6d39642 100644 --- a/demo/notebooks/causal_inference.ipynb +++ b/demo/notebooks/causal_inference.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -69,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -98,12 +98,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bcf_model = BCFModel()\n", - "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)" + "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, params={\"keep_every\": 5})" ] }, { @@ -161,7 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = 
pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", "plt.show()" ] @@ -172,7 +172,7 @@ "metadata": {}, "outputs": [], "source": [ - "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Beta_0\", \"Beta_1\"])\n", + "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Beta_0\", \"Beta_1\"])\n", "sns.scatterplot(data=b_df_mcmc, x=\"Sample\", y=\"Beta_0\")\n", "sns.scatterplot(data=b_df_mcmc, x=\"Sample\", y=\"Beta_1\")\n", "plt.show()" diff --git a/demo/notebooks/causal_inference_feature_subsets.ipynb b/demo/notebooks/causal_inference_feature_subsets.ipynb index b84c9c7f..f746baec 100644 --- a/demo/notebooks/causal_inference_feature_subsets.ipynb +++ b/demo/notebooks/causal_inference_feature_subsets.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -103,12 +103,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bcf_model = BCFModel()\n", - "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)" + "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, params={\"keep_every\": 5})" ] }, { @@ -166,7 +166,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", "plt.show()" ] @@ -177,7 +177,7 @@ "metadata": {}, "outputs": [], "source": [ - "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Beta_0\", \"Beta_1\"])\n", + "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Beta_0\", \"Beta_1\"])\n", "sns.scatterplot(data=b_df_mcmc, 
x=\"Sample\", y=\"Beta_0\")\n", "sns.scatterplot(data=b_df_mcmc, x=\"Sample\", y=\"Beta_1\")\n", "plt.show()" @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -256,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model_subset.num_samples - bcf_model_subset.num_gfr),axis=1), \n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model_subset.num_samples),axis=1), \n", " np.expand_dims(bcf_model_subset.global_var_samples,axis=1)), axis = 1), \n", " columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", @@ -269,7 +269,7 @@ "metadata": {}, "outputs": [], "source": [ - "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model_subset.num_samples - bcf_model_subset.num_gfr),axis=1), \n", + "b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model_subset.num_samples),axis=1), \n", " np.expand_dims(bcf_model_subset.b0_samples,axis=1), \n", " np.expand_dims(bcf_model_subset.b1_samples,axis=1)), axis = 1), \n", " columns=[\"Sample\", \"Beta_0\", \"Beta_1\"])\n", diff --git a/demo/notebooks/heteroskedastic_supervised_learning.ipynb b/demo/notebooks/heteroskedastic_supervised_learning.ipynb index 9eab1b68..9fe170ae 100644 --- a/demo/notebooks/heteroskedastic_supervised_learning.ipynb +++ b/demo/notebooks/heteroskedastic_supervised_learning.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/demo/notebooks/multivariate_treatment_causal_inference.ipynb b/demo/notebooks/multivariate_treatment_causal_inference.ipynb index 4d959109..6e175bd5 100644 --- a/demo/notebooks/multivariate_treatment_causal_inference.ipynb +++ b/demo/notebooks/multivariate_treatment_causal_inference.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index 0d7a36cd..881e8f87 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -51,9 +51,8 @@ "While the algorithm itself is conceptually simple, much of the core \n", "computation is carried out in low-level languages such as C or C++ \n", "because of the tree data structure. 
As a result, any changes to this \n", - "algorithm, such as supporting heteroskedasticity (@pratola2020heteroscedastic), \n", - "categorical outcomes (@murray2021log) or causal effect estimation (@hahn2020bayesian) \n", - "require modifying low-level code. \n", + "algorithm, such as supporting heteroskedasticity and categorical outcomes (Murray 2021) \n", + "or causal effect estimation (Hahn et al 2020) require modifying low-level code. \n", "\n", "The prototype interface exposes the core components of the \n", "loop above at the R level, thus making it possible to interchange \n", @@ -86,7 +85,7 @@ "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", - "from stochtree import Dataset, Residual, RNG, ForestSampler, ForestContainer, GlobalVarianceModel, LeafVarianceModel" + "from stochtree import Dataset, Residual, RNG, ForestSampler, ForestContainer, Forest, GlobalVarianceModel, LeafVarianceModel" ] }, { @@ -201,7 +200,8 @@ "metadata": {}, "outputs": [], "source": [ - "forest_container = ForestContainer(num_trees, W.shape[1], False)\n", + "forest_container = ForestContainer(num_trees, W.shape[1], False, False)\n", + "active_forest = Forest(num_trees, W.shape[1], False, False)\n", "forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)\n", "cpp_rng = RNG(random_seed)\n", "global_var_model = GlobalVarianceModel()\n", @@ -242,9 +242,11 @@ "outputs": [], "source": [ "for i in range(num_warmstart):\n", - " forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, True, False)\n", + " forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, \n", + " feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, \n", + " 0.0, 0.0, global_var_samples[i], 1, True, True, False)\n", " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)\n", - " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)\n", + " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(active_forest, cpp_rng, a_leaf, b_leaf)\n", " leaf_prior_scale[0,0] = leaf_scale_samples[i+1]" ] }, @@ -262,9 +264,11 @@ "outputs": [], "source": [ "for i in range(num_warmstart, num_samples):\n", - " forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, False, False)\n", + " forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, \n", + " feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, \n", + " 0.0, 0.0, global_var_samples[i], 1, True, False, False)\n", " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)\n", - " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)\n", + " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(active_forest, cpp_rng, a_leaf, b_leaf)\n", " leaf_prior_scale[0,0] = leaf_scale_samples[i+1]" ] }, @@ -487,12 +491,14 @@ "outputs": [], "source": [ "# Prognostic forest sampling classes\n", - "forest_container_mu = ForestContainer(num_trees_mu, 1, True)\n", + "forest_container_mu = ForestContainer(num_trees_mu, 1, True, False)\n", + "active_forest_mu = Forest(num_trees_mu, 1, True, 
False)\n", "forest_sampler_mu = ForestSampler(dataset_mu, feature_types_mu, num_trees_mu, n, alpha_mu, beta_mu, min_samples_leaf_mu)\n", "leaf_var_model_mu = LeafVarianceModel()\n", "\n", "# Treatment forest sampling classes\n", - "forest_container_tau = ForestContainer(num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False)\n", + "forest_container_tau = ForestContainer(num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False, False)\n", + "active_forest_tau = Forest(num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False, False)\n", "forest_sampler_tau = ForestSampler(dataset_tau, feature_types_tau, num_trees_tau, n, alpha_tau, beta_tau, min_samples_leaf_tau)\n", "leaf_var_model_tau = LeafVarianceModel()\n", "\n", @@ -545,16 +551,20 @@ "source": [ "for i in range(num_warmstart):\n", " # Sample the prognostic forest\n", - " forest_sampler_mu.sample_one_iteration(forest_container_mu, dataset_mu, residual, cpp_rng, feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, global_var_samples[i], 0, True, False)\n", - " leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)\n", + " forest_sampler_mu.sample_one_iteration(forest_container_mu, active_forest_mu, dataset_mu, residual, cpp_rng, \n", + " feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, \n", + " 0.0, 0.0, global_var_samples[i], 0, True, True, False)\n", + " leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu)\n", " leaf_prior_scale_mu[0,0] = leaf_scale_samples_mu[i+1]\n", - " mu_x = forest_container_mu.predict_raw_single_forest(dataset_mu, i)\n", + " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", "\n", " # Sample the treatment effect forest\n", - " forest_sampler_tau.sample_one_iteration(forest_container_tau, dataset_tau, residual, cpp_rng, feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, global_var_samples[i], 1, True, False)\n", - " # leaf_scale_samples_tau[i+1] = leaf_var_model_tau.sample_one_iteration(forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)\n", + " forest_sampler_tau.sample_one_iteration(forest_container_tau, active_forest_tau, dataset_tau, residual, cpp_rng, \n", + " feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, \n", + " 0.0, 0.0, global_var_samples[i], 1, True, True, False)\n", + " # leaf_scale_samples_tau[i+1] = leaf_var_model_tau.sample_one_iteration(forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau)\n", " # leaf_prior_scale_tau[0,0] = leaf_scale_samples_tau[i+1]\n", - " tau_x = np.squeeze(forest_container_tau.predict_raw_single_forest(dataset_tau, i))\n", + " tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))\n", " s_tt0 = np.sum(tau_x*tau_x*(Z==0))\n", " s_tt1 = np.sum(tau_x*tau_x*(Z==1))\n", " partial_resid_mu = resid - np.squeeze(mu_x)\n", @@ -564,6 +574,7 @@ " b_1_samples[i+1] = rng.normal(loc = (s_ty1/(s_tt1 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt1 + 2*global_var_samples[i])), size = 1)\n", " tau_basis = (1-Z)*b_0_samples[i+1] + Z*b_1_samples[i+1]\n", " dataset_tau.update_basis(tau_basis)\n", + " forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)\n", " \n", " # Sample global variance\n", " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)" @@ -584,16 +595,20 @@ "source": [ "for i in range(num_warmstart, 
num_samples):\n", " # Sample the prognostic forest\n", - " forest_sampler_mu.sample_one_iteration(forest_container_mu, dataset_mu, residual, cpp_rng, feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, global_var_samples[i], 0, False, False)\n", - " leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)\n", + " forest_sampler_mu.sample_one_iteration(forest_container_mu, active_forest_mu, dataset_mu, residual, cpp_rng, \n", + " feature_types_mu, cutpoint_grid_size_mu, leaf_prior_scale_mu, var_weights_mu, \n", + " 0.0, 0.0, global_var_samples[i], 0, True, False, False)\n", + " leaf_scale_samples_mu[i+1] = leaf_var_model_mu.sample_one_iteration(active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu)\n", " leaf_prior_scale_mu[0,0] = leaf_scale_samples_mu[i+1]\n", - " mu_x = forest_container_mu.predict_raw_single_forest(dataset_mu, i)\n", + " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", "\n", " # Sample the treatment effect forest\n", - " forest_sampler_tau.sample_one_iteration(forest_container_tau, dataset_tau, residual, cpp_rng, feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, global_var_samples[i], 1, False, False)\n", + " forest_sampler_tau.sample_one_iteration(forest_container_tau, active_forest_tau, dataset_tau, residual, cpp_rng, \n", + " feature_types_tau, cutpoint_grid_size_tau, leaf_prior_scale_tau, var_weights_tau, \n", + " 0.0, 0.0, global_var_samples[i], 1, True, False, False)\n", " # leaf_scale_samples_tau[i+1] = leaf_var_model_tau.sample_one_iteration(forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)\n", " # leaf_prior_scale_tau[0,0] = leaf_scale_samples_tau[i+1]\n", - " tau_x = np.squeeze(forest_container_tau.predict_raw_single_forest(dataset_tau, i))\n", + " tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))\n", " s_tt0 = np.sum(tau_x*tau_x*(Z==0))\n", " s_tt1 = np.sum(tau_x*tau_x*(Z==1))\n", " partial_resid_mu = resid - np.squeeze(mu_x)\n", @@ -603,20 +618,12 @@ " b_1_samples[i+1] = rng.normal(loc = (s_ty1/(s_tt1 + 2*global_var_samples[i])), scale = np.sqrt(global_var_samples[i]/(s_tt1 + 2*global_var_samples[i])), size = 1)\n", " tau_basis = (1-Z)*b_0_samples[i+1] + Z*b_1_samples[i+1]\n", " dataset_tau.update_basis(tau_basis)\n", + " forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)\n", " \n", " # Sample global variance\n", " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "forest_container_tau.predict_raw(dataset_tau)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -762,6 +769,22 @@ "sns.scatterplot(data=b_df_mcmc, x=\"Sample\", y=\"Beta_1\")\n", "plt.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# References" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Murray, Jared S. \"Log-linear Bayesian additive regression trees for multinomial logistic and count regression models.\" Journal of the American Statistical Association 116, no. 534 (2021): 756-769.\n", + "\n", + "Hahn, P. Richard, Jared S. Murray, and Carlos M. Carvalho. \"Bayesian regression tree models for causal inference: Regularization, confounding, and heterogeneous effects (with discussion).\" Bayesian Analysis 15, no. 3 (2020): 965-1056." 
+ ] } ], "metadata": { diff --git a/demo/notebooks/supervised_learning.ipynb b/demo/notebooks/supervised_learning.ipynb index ca7d6312..4fe6465d 100644 --- a/demo/notebooks/supervised_learning.ipynb +++ b/demo/notebooks/supervised_learning.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -114,12 +114,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bart_model = BARTModel()\n", - "bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100)" + "param_dict = {\"num_chains\": 3}\n", + "bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, params=param_dict)" ] }, { @@ -135,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]\n", + "forest_preds_y_mcmc = bart_model.y_hat_test\n", "y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n", "y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n", "sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n", @@ -149,7 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", "plt.show()" ] @@ -186,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -209,7 +210,7 @@ "metadata": {}, "outputs": [], "source": [ - "forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]\n", + "forest_preds_y_mcmc = bart_model.y_hat_test\n", "y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n", "y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n", "sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n", @@ -223,7 +224,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", "plt.show()" ] @@ -260,7 +261,7 @@ }, { 
"cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -281,7 +282,7 @@ "metadata": {}, "outputs": [], "source": [ - "forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]\n", + "forest_preds_y_mcmc = bart_model.y_hat_test\n", "y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n", "y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n", "sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n", @@ -295,7 +296,7 @@ "metadata": {}, "outputs": [], "source": [ - "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", "plt.show()" ] diff --git a/include/nlohmann/json.hpp b/include/nlohmann/json.hpp index 181d8e95..94e44eca 100644 --- a/include/nlohmann/json.hpp +++ b/include/nlohmann/json.hpp @@ -1056,9 +1056,11 @@ NLOHMANN_JSON_NAMESPACE_END JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) - #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) + // #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) + #define JSON_HEDLEY_PRAGMA(value) #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_PRAGMA(value) __pragma(value) + // #define JSON_HEDLEY_PRAGMA(value) __pragma(value) + #define JSON_HEDLEY_PRAGMA(value) #else #define JSON_HEDLEY_PRAGMA(value) #endif @@ -1070,22 +1072,32 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_POP #endif #if defined(__clang__) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) - #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) + // #define JSON_HEDLEY_DIAGNOSTIC_POP 
__pragma(warning(pop)) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ @@ -1093,11 +1105,15 @@ NLOHMANN_JSON_NAMESPACE_END JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") + // #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + // #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP #else #define JSON_HEDLEY_DIAGNOSTIC_PUSH #define JSON_HEDLEY_DIAGNOSTIC_POP @@ -1113,26 +1129,14 @@ NLOHMANN_JSON_NAMESPACE_END # if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") # if JSON_HEDLEY_HAS_WARNING("-Wc++1z-extensions") # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ - _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ - _Pragma("clang diagnostic ignored \"-Wc++1z-extensions\"") \ - xpr \ - JSON_HEDLEY_DIAGNOSTIC_POP + xpr # else # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ - _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ - xpr \ - JSON_HEDLEY_DIAGNOSTIC_POP + xpr # endif # else # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ - xpr \ - JSON_HEDLEY_DIAGNOSTIC_POP + xpr # endif # endif #endif @@ -1150,10 +1154,10 @@ NLOHMANN_JSON_NAMESPACE_END JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) # define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ + /* JSON_HEDLEY_DIAGNOSTIC_PUSH \*/ + /* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \*/ ((T) (expr)); \ - JSON_HEDLEY_DIAGNOSTIC_POP \ + /* JSON_HEDLEY_DIAGNOSTIC_POP \*/ })) #else # define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) @@ -1183,15 +1187,9 @@ NLOHMANN_JSON_NAMESPACE_END #if defined(__cplusplus) # if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") # define JSON_HEDLEY_CPP_CAST(T, expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ - ((T) (expr)) \ - JSON_HEDLEY_DIAGNOSTIC_POP + ((T) (expr)) # elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) -# define JSON_HEDLEY_CPP_CAST(T, expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("diag_suppress=Pe137") \ - JSON_HEDLEY_DIAGNOSTIC_POP +# define JSON_HEDLEY_CPP_CAST(T, expr) # else # define 
JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) # endif @@ -1203,21 +1201,29 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #endif #if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ @@ -1230,15 +1236,20 @@ NLOHMANN_JSON_NAMESPACE_END JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) - #define 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #endif @@ -1247,29 +1258,39 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif \ JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS 
_Pragma("diag_suppress=Pe161") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #endif @@ -1278,30 +1299,41 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif \ JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") + // #define 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #endif @@ -1310,11 +1342,14 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #endif #if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #endif @@ -1323,13 +1358,17 @@ NLOHMANN_JSON_NAMESPACE_END #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #endif #if JSON_HEDLEY_HAS_WARNING("-Wunused-function") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #elif JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505)) + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505)) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142") + // #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #endif @@ -1581,16 +1620,12 @@ NLOHMANN_JSON_NAMESPACE_END JSON_HEDLEY_DIAGNOSTIC_PUSH #if JSON_HEDLEY_HAS_WARNING("-Wpedantic") - #pragma clang diagnostic ignored "-Wpedantic" #endif #if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) - 
#pragma clang diagnostic ignored "-Wc++98-compat-pedantic" #endif #if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) #if defined(__clang__) - #pragma clang diagnostic ignored "-Wvariadic-macros" #elif defined(JSON_HEDLEY_GCC_VERSION) - #pragma GCC diagnostic ignored "-Wvariadic-macros" #endif #endif #if defined(JSON_HEDLEY_NON_NULL) @@ -2262,15 +2297,9 @@ JSON_HEDLEY_DIAGNOSTIC_POP #if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) # if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") # define JSON_HEDLEY_REQUIRE(expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ - __attribute__((diagnose_if(!(expr), #expr, "error"))) \ - JSON_HEDLEY_DIAGNOSTIC_POP + __attribute__((diagnose_if(!(expr), #expr, "error"))) # define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ - __attribute__((diagnose_if(!(expr), msg, "error"))) \ - JSON_HEDLEY_DIAGNOSTIC_POP + __attribute__((diagnose_if(!(expr), msg, "error"))) # else # define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) # define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) @@ -2294,10 +2323,7 @@ JSON_HEDLEY_DIAGNOSTIC_POP #endif #if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) # define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("warning(disable:188)") \ ((T) (expr)); \ - JSON_HEDLEY_DIAGNOSTIC_POP \ })) #else # define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) @@ -2507,9 +2533,6 @@ JSON_HEDLEY_DIAGNOSTIC_POP // disable documentation warnings on clang #if defined(__clang__) - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wdocumentation" - #pragma clang diagnostic ignored "-Wdocumentation-unknown-command" #endif // allow disabling exceptions @@ -5363,9 +5386,6 @@ namespace std { #if defined(__clang__) - // Fix: https://github.com/nlohmann/json/issues/1401 - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wmismatched-tags" #endif template class tuple_size<::nlohmann::detail::iteration_proxy_value> // NOLINT(cert-dcl58-cpp) @@ -5380,7 +5400,6 @@ class tuple_element> ::nlohmann::detail::iteration_proxy_value> ())); }; #if defined(__clang__) - #pragma clang diagnostic pop #endif } // namespace std @@ -16801,8 +16820,6 @@ class binary_writer void write_compact_float(const number_float_t n, detail::input_format_t format) { #ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wfloat-equal" #endif if (static_cast(n) >= static_cast(std::numeric_limits::lowest()) && static_cast(n) <= static_cast((std::numeric_limits::max)()) && @@ -16821,7 +16838,6 @@ class binary_writer write_number(n); } #ifdef __GNUC__ -#pragma GCC diagnostic pop #endif } @@ -17983,8 +17999,6 @@ char* to_chars(char* first, const char* last, FloatType value) } #ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wfloat-equal" #endif if (value == 0) // +-0 { @@ -17995,7 +18009,6 @@ char* to_chars(char* first, const char* last, FloatType value) return first; } #ifdef __GNUC__ -#pragma GCC diagnostic pop #endif JSON_ASSERT(last - first >= std::numeric_limits::max_digits10); @@ -20071,8 +20084,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec // ordered_json uses a vector internally, so pointers could have // been invalidated; see https://github.com/nlohmann/json/issues/2962 #ifdef JSON_HEDLEY_MSVC_VERSION -#pragma 
warning(push )
-#pragma warning(disable : 4127) // ignore warning to replace if with if constexpr
 #endif
 if (detail::is_ordered_map<object_t>::value)
 {
@@ -20080,7 +20091,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec
 return j;
 }
 #ifdef JSON_HEDLEY_MSVC_VERSION
-#pragma warning( pop )
 #endif
 j.m_parent = this;
@@ -22985,13 +22995,10 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec
 bool operator==(const_reference rhs) const noexcept
 {
 #ifdef __GNUC__
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wfloat-equal"
 #endif
 const_reference lhs = *this;
 JSON_IMPLEMENT_OPERATOR( ==, true, false, false)
 #ifdef __GNUC__
-#pragma GCC diagnostic pop
 #endif
 }
@@ -23089,12 +23096,9 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec
 friend bool operator==(const_reference lhs, const_reference rhs) noexcept
 {
 #ifdef __GNUC__
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wfloat-equal"
 #endif
 JSON_IMPLEMENT_OPERATOR( ==, true, false, false)
 #ifdef __GNUC__
-#pragma GCC diagnostic pop
 #endif
 }
@@ -24569,7 +24573,6 @@ inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC
 // restore clang diagnostic settings
 #if defined(__clang__)
-    #pragma clang diagnostic pop
 #endif
 // clean up
diff --git a/include/stochtree/container.h b/include/stochtree/container.h
index 1a3dc1c7..13fa098c 100644
--- a/include/stochtree/container.h
+++ b/include/stochtree/container.h
@@ -26,6 +26,8 @@ class ForestContainer {
   ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
   ~ForestContainer() {}
+  void DeleteSample(int sample_num);
+  void AddSample(TreeEnsemble& forest);
   void InitializeRoot(double leaf_value);
   void InitializeRoot(std::vector<double>& leaf_vector);
   void AddSamples(int num_samples);
@@ -64,6 +66,8 @@ class ForestContainer {
   inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();}
   inline bool IsLeafConstant() {return is_leaf_constant_;}
   inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();}
+  inline bool IsExponentiated() {return is_exponentiated_;}
+  inline bool IsExponentiated(int ensemble_num) {return forests_[ensemble_num]->IsExponentiated();}
   inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();}
   inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);}
   inline void SetLeafVector(int ensemble_num, std::vector<double>& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);}
diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h
index bbec292c..5f4330d3 100644
--- a/include/stochtree/ensemble.h
+++ b/include/stochtree/ensemble.h
@@ -49,10 +49,8 @@ class TreeEnsemble {
     trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
     for (int i = 0; i < num_trees_; i++) {
       trees_[i].reset(new Tree());
-      // trees_[i]->Init(output_dimension);
     }
     // Clone trees in the ensemble
-    // trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
     for (int j = 0; j < num_trees_; j++) {
       Tree* tree = ensemble.GetTree(j);
       this->CloneFromExistingTree(j, tree);
     }
@@ -64,6 +62,12 @@ class TreeEnsemble {
     return trees_[i].get();
   }
+  inline void ResetRoot() {
+    for (int i = 0; i < num_trees_; i++) {
+      ResetInitTree(i);
+    }
+  }
+
   inline void ResetTree(int i) {
     trees_[i].reset(new Tree());
   }
@@ -77,6 +81,41 @@ class TreeEnsemble {
     return trees_[i]->CloneFromTree(tree);
   }
+  inline void ReconstituteFromForest(TreeEnsemble& ensemble) {
+    // Delete old tree pointers
+    trees_.clear();
+    // Unpack ensemble configurations
+    num_trees_ = ensemble.num_trees_;
+    output_dimension_ = ensemble.output_dimension_;
+    is_leaf_constant_ = ensemble.is_leaf_constant_;
+    is_exponentiated_ = ensemble.is_exponentiated_;
+    // Initialize trees in the ensemble
+    trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
+    for (int i = 0; i < num_trees_; i++) {
+      trees_[i].reset(new Tree());
+    }
+    // Clone trees in the ensemble
+    for (int j = 0; j < num_trees_; j++) {
+      Tree* tree = ensemble.GetTree(j);
+      this->CloneFromExistingTree(j, tree);
+    }
+  }
+
+  std::vector<double> Predict(ForestDataset& dataset) {
+    data_size_t n = dataset.NumObservations();
+    std::vector<double> output(n);
+    PredictInplace(dataset, output, 0);
+    return output;
+  }
+
+  std::vector<double> PredictRaw(ForestDataset& dataset) {
+    data_size_t n = dataset.NumObservations();
+    data_size_t total_output_size = n * output_dimension_;
+    std::vector<double> output(total_output_size);
+    PredictRawInplace(dataset, output, 0);
+    return output;
+  }
+
   inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
     PredictInplace(dataset, output, 0, trees_.size(), offset);
   }
diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h
index 4b9b97ef..56b6c2e6 100644
--- a/include/stochtree/partition_tracker.h
+++ b/include/stochtree/partition_tracker.h
@@ -59,11 +59,14 @@ class ForestTracker {
    */
   ForestTracker(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int num_trees, int num_observations);
   ~ForestTracker() {}
+  void ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
   void AssignAllSamplesToRoot();
   void AssignAllSamplesToRoot(int32_t tree_num);
   void AssignAllSamplesToConstantPrediction(double value);
   void AssignAllSamplesToConstantPrediction(int32_t tree_num, double value);
   void UpdatePredictions(TreeEnsemble* ensemble, ForestDataset& dataset);
+  void UpdateSampleTrackers(TreeEnsemble& forest, ForestDataset& dataset);
+  void UpdateSampleTrackersResidual(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
   void ResetRoot(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int32_t tree_num);
   void AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false);
   void RemoveSplit(Eigen::MatrixXd& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false);
@@ -109,9 +112,14 @@ class ForestTracker {
   int num_trees_;
   int num_observations_;
   int num_features_;
+  bool initialized_{false};
   void UpdatePredictionsInternal(TreeEnsemble* ensemble, Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis);
   void UpdatePredictionsInternal(TreeEnsemble* ensemble, Eigen::MatrixXd& covariates);
+  void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis);
+  void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates);
+  void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
+  void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
 };

 /*!
\brief Class storing sample-prediction map for each tree in an ensemble */ @@ -229,6 +237,9 @@ class FeatureUnsortedPartition { public: FeatureUnsortedPartition(data_size_t n); + /*! \brief Reconstitute a tree partition tracker from root based on a tree */ + void ReconstituteFromTree(Tree& tree, ForestDataset& dataset); + /*! \brief Partition a node based on a new split rule */ void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split); @@ -306,6 +317,9 @@ class UnsortedNodeSampleTracker { } } + /*! \brief Reconstruct the node sample tracker based on the splits in a forest */ + void ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& dataset); + /*! \brief Partition a node based on a new split rule */ void PartitionTreeNode(Eigen::MatrixXd& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split) { return feature_partitions_[tree_id]->PartitionNode(covariates, node_id, left_node_id, right_node_id, feature_split, split); diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 623a1103..1f324970 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -28,6 +28,11 @@ namespace StochTree { +/*! \brief Forward declarations */ +class LabelMapper; +class MultivariateRegressionRandomEffectsModel; +class RandomEffectsContainer; + /*! \brief Wrapper around data structures for random effects sampling algorithms */ class RandomEffectsTracker { public: @@ -49,6 +54,14 @@ class RandomEffectsTracker { std::vector& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);} void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;} + /*! \brief Resets RFX tracker based on a specific sample. Assumes tracker already exists in main memory. */ + void ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, + RandomEffectsDataset& rfx_dataset, ColumnVector& residual); + /*! \brief Resets RFX tracker to initial default. Assumes tracker already exists in main memory. + * Assumes that the initial "clean slate" prediction of a random effects model is 0. + */ + void RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, + RandomEffectsDataset& rfx_dataset, ColumnVector& residual); private: /*! \brief Mapper from observations to category indices */ @@ -102,6 +115,9 @@ class MultivariateRegressionRandomEffectsModel { working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); } ~MultivariateRegressionRandomEffectsModel() {} + + /*! \brief Reconstruction from serialized model parameter samples */ + void ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num); /*! 
\brief Samplers */
 void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
@@ -260,6 +276,7 @@ class RandomEffectsContainer {
   }
   ~RandomEffectsContainer() {}
   void AddSample(MultivariateRegressionRandomEffectsModel& model);
+  void DeleteSample(int sample_num);
   void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
   int NumSamples() {return num_samples_;}
   int NumComponents() {return num_components_;}
diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h
index 42c851e4..d8db03b9 100644
--- a/include/stochtree/tree_sampler.h
+++ b/include/stochtree/tree_sampler.h
@@ -196,6 +196,31 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas
   tracker.SyncPredictions();
 }
+static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
+                                                 bool requires_basis, std::function<double(double, double)> op) {
+  data_size_t n = dataset.GetCovariates().rows();
+  double tree_pred = 0.;
+  double pred_value = 0.;
+  double new_resid = 0.;
+  int32_t leaf_pred;
+  for (data_size_t i = 0; i < n; i++) {
+    // Reset the forest prediction for each observation
+    pred_value = 0.;
+    for (int j = 0; j < forest->NumTrees(); j++) {
+      Tree* tree = forest->GetTree(j);
+      leaf_pred = tracker.GetNodeId(i, j);
+      if (requires_basis) {
+        tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
+      } else {
+        tree_pred = tree->PredictFromNode(leaf_pred);
+      }
+      pred_value += tree_pred;
+    }
+
+    // Run op (either plus or minus) on the residual and the new prediction
+    new_resid = op(residual.GetElement(i), pred_value);
+    residual.SetElement(i, new_resid);
+  }
+}
+
 static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
                                               bool requires_basis, std::function<double(double, double)> op) {
   data_size_t n = dataset.GetCovariates().rows();
@@ -384,43 +409,42 @@ static inline std::tuple EvaluateExist
   return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n);
 }
-template <typename LeafModel>
-static inline void ModelInitialization(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
-                                       ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
-                                       std::mt19937& gen, std::vector<double>& variable_weights, double global_variance,
-                                       bool pre_initialized, bool backfitting, int prev_num_samples, bool var_trees = false) {
-  if ((prev_num_samples == 0) && (!pre_initialized)) {
-    // Add new forest to the container
-    forests.AddSamples(1);
+// template <typename LeafModel>
+// static inline void ModelInitialization(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
+//                                        ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
+//                                        std::mt19937& gen, std::vector<double>& variable_weights, double global_variance,
+//                                        bool pre_initialized, bool backfitting, int prev_num_samples, bool var_trees = false) {
+//   if ((prev_num_samples == 0) && (!pre_initialized)) {
+//     // Add new forest to the container
+//     forests.AddSamples(1);
-    // Set initial value for each leaf in the forest
-    double leaf_value;
-    if (var_trees) {
-      leaf_value = std::log(ComputeVarianceOutcome(residual)) / static_cast<double>(forests.NumTrees());
-    } else {
-      leaf_value = ComputeMeanOutcome(residual) / static_cast<double>(forests.NumTrees());
-    }
-    TreeEnsemble* ensemble = forests.GetEnsemble(0);
-    leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, leaf_value);
-    tracker.AssignAllSamplesToConstantPrediction(leaf_value);
-  } else if (prev_num_samples > 0) {
-    // Add new forest to the container
-    forests.AddSamples(1);
-
-    // NOTE: only doing this for the simplicity of the partial residual step
-    // We could alternatively "reach back" to the tree predictions from a previous
-    // sample (whenever there is more than one sample). This is cleaner / quicker
-    // to implement during this refactor.
-    forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1);
-  } else {
-    forests.IncrementSampleCount();
-  }
-}
+//     // Set initial value for each leaf in the forest
+//     double leaf_value;
+//     if (var_trees) {
+//       leaf_value = std::log(ComputeVarianceOutcome(residual)) / static_cast<double>(forests.NumTrees());
+//     } else {
+//       leaf_value = ComputeMeanOutcome(residual) / static_cast<double>(forests.NumTrees());
+//     }
+//     TreeEnsemble* ensemble = forests.GetEnsemble(0);
+//     leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, leaf_value);
+//     tracker.AssignAllSamplesToConstantPrediction(leaf_value);
+//   } else if (prev_num_samples > 0) {
+//     // Add new forest to the container
+//     forests.AddSamples(1);
+
+//     // NOTE: only doing this for the simplicity of the partial residual step
+//     // We could alternatively "reach back" to the tree predictions from a previous
+//     // sample (whenever there is more than one sample). This is cleaner / quicker
+//     // to implement during this refactor.
+//     forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1);
+//   } else {
+//     forests.IncrementSampleCount();
+//   }
+// }

 template <typename LeafModel>
-static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
-                                                 ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
-                                                 bool backfitting, Tree* tree, int tree_num) {
+static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
+                                                 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
   if (backfitting) {
     UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
   } else {
@@ -430,9 +454,8 @@ static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, ForestC
 }

 template <typename LeafModel>
-static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
-                                                ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
-                                                bool backfitting, Tree* tree, int tree_num) {
+static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
+                                                ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
   if (backfitting) {
     UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
   } else {
@@ -707,37 +730,37 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
 }

 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
-static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
+static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
                                     ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
                                     double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
-                                    bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
-  // Previous number of samples
-  int prev_num_samples = forests.NumSamples();
+                                    bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
+  // // Previous number of samples
+  // int prev_num_samples = forests.NumSamples();
-  // Handle any "initialization" of a model (trees, ForestTracker, etc...) if this is the first sample and
-  // the model was not pre-initialized
-  bool var_trees;
-  if (std::is_same_v) var_trees = true;
-  else var_trees = false;
-  ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
-                      variable_weights, global_variance, pre_initialized, backfitting,
-                      prev_num_samples, var_trees);
+  // // Handle any "initialization" of a model (trees, ForestTracker, etc...) if this is the first sample and
+  // // the model was not pre-initialized
+  // bool var_trees;
+  // if (std::is_same_v) var_trees = true;
+  // else var_trees = false;
+  // ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
+  //                     variable_weights, global_variance, pre_initialized, backfitting,
+  //                     prev_num_samples, var_trees);

   // Run the GFR algorithm for each tree
-  TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
+  // TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
   int num_trees = forests.NumTrees();
   for (int i = 0; i < num_trees; i++) {
     // Adjust any model state needed to run a tree sampler
     // For models that involve Bayesian backfitting, this amounts to adding tree i's
     // predictions back to the residual (thus, training a model on the "partial residual")
     // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
-    Tree* tree = ensemble->GetTree(i);
-    AdjustStateBeforeTreeSampling(tracker, forests, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+    Tree* tree = active_forest.GetTree(i);
+    AdjustStateBeforeTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);

     // Reset the tree and sample trackers
-    ensemble->ResetInitTree(i);
+    active_forest.ResetInitTree(i);
     tracker.ResetRoot(dataset.GetCovariates(), feature_types, i);
-    tree = ensemble->GetTree(i);
+    tree = active_forest.GetTree(i);

     // Sample tree i
     GFRSampleTreeOneIter(
@@ -747,14 +770,18 @@
     );

     // Sample leaf parameters for tree i
-    tree = ensemble->GetTree(i);
+    tree = active_forest.GetTree(i);
     leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);

     // Adjust any model state needed to run a tree sampler
     // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
     // predictions back out of the residual (thus, using an updated "partial residual" in the following iteration).
    // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
-    AdjustStateAfterTreeSampling(tracker, forests, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+    AdjustStateAfterTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+  }
+
+  if (keep_forest) {
+    forests.AddSample(active_forest);
+  }
 }
@@ -1000,49 +1027,53 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For
 }

 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
-static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
+static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
                                      ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
-                                     double global_variance, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
+                                     double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
-  // Previous number of samples
-  int prev_num_samples = forests.NumSamples();
+  // // Previous number of samples
+  // int prev_num_samples = forests.NumSamples();
-  // Handle any "initialization" of a model (trees, ForestTracker, etc...) if this is the first sample and
-  // the model was not pre-initialized
-  bool var_trees;
-  if (std::is_same_v) var_trees = true;
-  else var_trees = false;
-  ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
-                      variable_weights, global_variance, pre_initialized, backfitting,
-                      prev_num_samples, var_trees);
+  // // Handle any "initialization" of a model (trees, ForestTracker, etc...) if this is the first sample and
+  // // the model was not pre-initialized
+  // bool var_trees;
+  // if (std::is_same_v) var_trees = true;
+  // else var_trees = false;
+  // ModelInitialization(tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
+  //                     variable_weights, global_variance, pre_initialized, backfitting,
+  //                     prev_num_samples, var_trees);

   // Run the MCMC algorithm for each tree
-  TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
-  Tree* tree;
+  // TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples);
   int num_trees = forests.NumTrees();
   for (int i = 0; i < num_trees; i++) {
     // Adjust any model state needed to run a tree sampler
     // For models that involve Bayesian backfitting, this amounts to adding tree i's
     // predictions back to the residual (thus, training a model on the "partial residual")
     // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
-    Tree* tree = ensemble->GetTree(i);
-    AdjustStateBeforeTreeSampling(tracker, forests, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+    // Tree* tree = ensemble->GetTree(i);
+    Tree* tree = active_forest.GetTree(i);
+    AdjustStateBeforeTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);

     // Sample tree i
-    tree = ensemble->GetTree(i);
+    tree = active_forest.GetTree(i);
     MCMCSampleTreeOneIter(
       tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, leaf_suff_stat_args...
    );

     // Sample leaf parameters for tree i
-    tree = ensemble->GetTree(i);
+    tree = active_forest.GetTree(i);
     leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);

     // Adjust any model state needed to run a tree sampler
     // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
     // predictions back out of the residual (thus, using an updated "partial residual" in the following iteration).
     // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
-    AdjustStateAfterTreeSampling(tracker, forests, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+    AdjustStateAfterTreeSampling(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
+  }
+
+  if (keep_forest) {
+    forests.AddSample(active_forest);
+  }
 }
diff --git a/man/Forest.Rd b/man/Forest.Rd
new file mode 100644
index 00000000..6866779b
--- /dev/null
+++ b/man/Forest.Rd
@@ -0,0 +1,364 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/forest.R
+\name{Forest}
+\alias{Forest}
+\title{Class that stores a single ensemble of decision trees (often treated as the "active forest")}
+\description{
+Wrapper around a C++ tree ensemble
+}
+\section{Public fields}{
+\if{html}{\out{
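}}

Since the \code{Forest} class documented below wraps the same C++ \code{TreeEnsemble} that the refactored samplers in tree_sampler.h treat as the "active forest", here is a minimal C++ sketch of the sampling pattern introduced above. The leaf-model type names (GaussianConstantLeafModel, GaussianConstantSuffStat) and the chain-driver function itself are illustrative assumptions, not part of this diff:

#include <random>
#include <vector>
#include <stochtree/container.h>
#include <stochtree/tree_sampler.h>

using namespace StochTree;

// Sketch: run num_iter MCMC iterations over a single "active" forest,
// snapshotting it into the container only on every keep_every-th iteration.
// All arguments are assumed to be constructed and initialized elsewhere.
void RunChainSketch(TreeEnsemble& active_forest, ForestTracker& tracker,
                    ForestContainer& forests, GaussianConstantLeafModel& leaf_model,
                    ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
                    std::mt19937& gen, std::vector<double>& variable_weights,
                    double global_variance, int num_iter, int keep_every) {
  for (int iter = 0; iter < num_iter; iter++) {
    // Retain a copy of the active forest only for thinned draws
    bool keep_forest = ((iter + 1) % keep_every == 0);
    MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(
        active_forest, tracker, forests, leaf_model, dataset, residual, tree_prior,
        gen, variable_weights, global_variance, keep_forest,
        /*pre_initialized=*/true, /*backfitting=*/true);
  }
}

Because the container only grows when keep_forest is true, burned-in and thinned-out iterations incur no forest-copy overhead.

\if{html}{\out{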
}} +\describe{ +\item{\code{forest_ptr}}{External pointer to a C++ TreeEnsemble class} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-Forest-new}{\code{Forest$new()}} +\item \href{#method-Forest-predict}{\code{Forest$predict()}} +\item \href{#method-Forest-predict_raw}{\code{Forest$predict_raw()}} +\item \href{#method-Forest-set_root_leaves}{\code{Forest$set_root_leaves()}} +\item \href{#method-Forest-prepare_for_sampler}{\code{Forest$prepare_for_sampler()}} +\item \href{#method-Forest-adjust_residual}{\code{Forest$adjust_residual()}} +\item \href{#method-Forest-num_trees}{\code{Forest$num_trees()}} +\item \href{#method-Forest-output_dimension}{\code{Forest$output_dimension()}} +\item \href{#method-Forest-is_constant_leaf}{\code{Forest$is_constant_leaf()}} +\item \href{#method-Forest-is_exponentiated}{\code{Forest$is_exponentiated()}} +\item \href{#method-Forest-add_numeric_split_tree}{\code{Forest$add_numeric_split_tree()}} +\item \href{#method-Forest-get_tree_leaves}{\code{Forest$get_tree_leaves()}} +\item \href{#method-Forest-get_tree_split_counts}{\code{Forest$get_tree_split_counts()}} +\item \href{#method-Forest-get_forest_split_counts}{\code{Forest$get_forest_split_counts()}} +\item \href{#method-Forest-tree_max_depth}{\code{Forest$tree_max_depth()}} +\item \href{#method-Forest-average_max_depth}{\code{Forest$average_max_depth()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-new}{}}} +\subsection{Method \code{new()}}{ +Create a new Forest object. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$new( + num_trees, + output_dimension = 1, + is_leaf_constant = F, + is_exponentiated = F +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{num_trees}}{Number of trees in the forest} + +\item{\code{output_dimension}}{Dimensionality of the outcome model} + +\item{\code{is_leaf_constant}}{Whether leaf is constant} + +\item{\code{is_exponentiated}}{Whether forest predictions should be exponentiated before being returned} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new \code{Forest} object. +} +} +\if{html}{\out{
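}}

Under the hood, \code{Forest$new()} constructs the C++ \code{TreeEnsemble} from ensemble.h above; a sketch of the equivalent C++ construction (the constructor-argument spelling is assumed to mirror the \code{ForestContainer} constructor shown in container.h):

#include <stochtree/ensemble.h>

// A 100-tree, univariate, constant-leaf ensemble,
// analogous to Forest$new(100) on the R side
StochTree::TreeEnsemble forest(/*num_trees=*/100, /*output_dimension=*/1,
                               /*is_leaf_constant=*/true, /*is_exponentiated=*/false);

\if{html}{\out{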
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-Forest-predict}{}}}
+\subsection{Method \code{predict()}}{
+Predict forest on every observation in \code{forest_dataset}
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{Forest$predict(forest_dataset)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_dataset}}{\code{ForestDataset} R class} +} +\if{html}{\out{
}}
+}
+\subsection{Returns}{
+vector of predictions, with one element for each observation in \code{forest_dataset}
+}
+}
+\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-Forest-predict_raw}{}}}
+\subsection{Method \code{predict_raw()}}{
+Predict "raw" leaf values (without being multiplied by basis) for every observation in \code{forest_dataset}
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{Forest$predict_raw(forest_dataset)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_dataset}}{\code{ForestDataset} R class} +} +\if{html}{\out{
}}
+}
+\subsection{Returns}{
+Array of predictions for each observation in \code{forest_dataset}, with
+each prediction having the dimensionality of the forest's leaf model. In
+the case of a constant leaf model or univariate leaf regression, this array
+is a vector (length is the number of observations). In the case of a
+multivariate leaf regression, this array is a matrix (number of observations
+by leaf model dimension).
+}
+}
+\if{html}{\out{
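}}

The predict / predict_raw distinction mirrors the \code{Predict} / \code{PredictRaw} methods added to \code{TreeEnsemble} in ensemble.h above; a short sketch (assuming a univariate leaf-regression forest and a \code{ForestDataset} with a basis, both prepared elsewhere):

#include <vector>
#include <stochtree/ensemble.h>

void PredictBothWays(StochTree::TreeEnsemble& forest, StochTree::ForestDataset& dataset) {
  // Basis-weighted predictions: each leaf value multiplied by the supplied basis
  std::vector<double> y_hat = forest.Predict(dataset);
  // Raw leaf parameters, without the basis multiplication
  std::vector<double> leaf_values = forest.PredictRaw(dataset);
}

\if{html}{\out{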
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-Forest-set_root_leaves}{}}}
+\subsection{Method \code{set_root_leaves()}}{
+Set a constant predicted value for every tree in the ensemble.
+Stops the program if any tree has more than a root node.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{Forest$set_root_leaves(leaf_value)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{leaf_value}}{Constant leaf value(s) to be fixed for each tree in the ensemble indexed by \code{forest_num}. Can be either a single number or a vector, depending on the forest's leaf dimension.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-Forest-prepare_for_sampler}{}}}
+\subsection{Method \code{prepare_for_sampler()}}{
+Set a constant predicted value for every tree in the ensemble and initialize
+the sampler's tracking structures accordingly.
+Stops the program if any tree has more than a root node.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{Forest$prepare_for_sampler( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dataset}}{\code{ForestDataset} Dataset class (covariates, basis, etc...)} + +\item{\code{outcome}}{\code{Outcome} Outcome class (residual / partial residual)} + +\item{\code{forest_model}}{\code{ForestModel} object storing tracking structures used in training / sampling} + +\item{\code{leaf_model_int}}{Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).} + +\item{\code{leaf_value}}{Constant leaf value(s) to be fixed for each tree in the ensemble indexed by \code{forest_num}. Can be either a single number or a vector, depending on the forest's leaf dimension.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
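}}

A common choice of \code{leaf_value} follows the (now commented-out) C++ \code{ModelInitialization} helper in tree_sampler.h above: the outcome mean split evenly across trees for mean forests, or its log-variance analogue for variance forests. A sketch reusing the helper names visible in that code (header paths and namespace qualification are assumptions):

#include <cmath>
#include <stochtree/data.h>
#include <stochtree/ensemble.h>

// Sketch: default root-leaf initialization, per the retired ModelInitialization logic
double InitialLeafValue(StochTree::ColumnVector& residual, StochTree::TreeEnsemble& forest,
                        bool variance_forest) {
  if (variance_forest) {
    // Log-linear variance forest (leaf_model_int = 3)
    return std::log(StochTree::ComputeVarianceOutcome(residual)) / static_cast<double>(forest.NumTrees());
  }
  // Mean forests (leaf_model_int = 0, 1, or 2)
  return StochTree::ComputeMeanOutcome(residual) / static_cast<double>(forest.NumTrees());
}

\if{html}{\out{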
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-adjust_residual}{}}} +\subsection{Method \code{adjust_residual()}}{ +Adjusts residual based on the predictions of a forest + +This is typically run just once at the beginning of a forest sampling algorithm. +After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dataset}}{\code{ForestDataset} object storing the covariates and bases for a given forest} + +\item{\code{outcome}}{\code{Outcome} object storing the residuals to be updated based on forest predictions} + +\item{\code{forest_model}}{\code{ForestModel} object storing tracking structures used in training / sampling} + +\item{\code{requires_basis}}{Whether or not a forest requires a basis for prediction} + +\item{\code{add}}{Whether forest predictions should be added to or subtracted from residuals} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
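}}

In C++ terms, this method applies the forest's predictions to the residual through a plus or minus functor, as in the \code{UpdateResidualEntireForest} helper shown in tree_sampler.h above; a sketch (header paths are assumptions):

#include <functional>
#include <stochtree/data.h>
#include <stochtree/tree_sampler.h>

// Sketch: subtract freshly initialized root predictions out of the outcome
// once, before tree sampling begins (add = FALSE in the R method)
void SubtractInitialPredictions(StochTree::ForestTracker& tracker,
                                StochTree::ForestDataset& dataset,
                                StochTree::ColumnVector& residual,
                                StochTree::TreeEnsemble& active_forest) {
  StochTree::UpdateResidualEntireForest(tracker, dataset, residual, &active_forest,
                                        /*requires_basis=*/false, std::minus<double>());
}

\if{html}{\out{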
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-Forest-num_trees}{}}}
+\subsection{Method \code{num_trees()}}{
+Return the number of trees in a \code{Forest} object
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{Forest$num_trees()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Tree count +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-output_dimension}{}}} +\subsection{Method \code{output_dimension()}}{ +Return output dimension of trees in a \code{Forest} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$output_dimension()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Leaf node parameter size +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-is_constant_leaf}{}}} +\subsection{Method \code{is_constant_leaf()}}{ +Return constant leaf status of trees in a \code{Forest} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$is_constant_leaf()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +\code{T} if leaves are constant, \code{F} otherwise +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-is_exponentiated}{}}} +\subsection{Method \code{is_exponentiated()}}{ +Return exponentiation status of trees in a \code{Forest} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$is_exponentiated()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +\code{T} if leaf predictions must be exponentiated, \code{F} otherwise +} +} +\if{html}{\out{
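Continuing the sketch above, the four metadata accessors on a \code{Forest} can be checked directly:

active_forest$num_trees()         # 50
active_forest$output_dimension()  # 1
active_forest$is_constant_leaf()  # TRUE
active_forest$is_exponentiated()  # FALSE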
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-add_numeric_split_tree}{}}} +\subsection{Method \code{add_numeric_split_tree()}}{ +Add a numeric (i.e. \code{X[,i] <= c}) split to a given tree in the ensemble +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$add_numeric_split_tree( + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{tree_num}}{Index of the tree to be split} + +\item{\code{leaf_num}}{Leaf to be split} + +\item{\code{feature_num}}{Feature that defines the new split} + +\item{\code{split_threshold}}{Value that defines the cutoff of the new split} + +\item{\code{left_leaf_value}}{Value (or vector of values) to assign to the newly created left node} + +\item{\code{right_leaf_value}}{Value (or vector of values) to assign to the newly created right node} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
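A sketch of manual tree construction on the forest from the example above; the 1-based \code{tree_num} and 0-based \code{leaf_num} / \code{feature_num} conventions shown here are assumptions to verify against the package.

# Split the root node (leaf 0) of the first tree on feature 0 at 0.5,
# assigning illustrative left / right leaf values
active_forest$add_numeric_split_tree(tree_num = 1, leaf_num = 0,
                                     feature_num = 0, split_threshold = 0.5,
                                     left_leaf_value = -0.1,
                                     right_leaf_value = 0.1)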
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-get_tree_leaves}{}}} +\subsection{Method \code{get_tree_leaves()}}{ +Retrieve a vector of indices of leaf nodes for a given tree in the forest +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$get_tree_leaves(tree_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{tree_num}}{Index of the tree for which leaf indices will be retrieved} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
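Continuing the sketch, the leaves created by that split can be inspected:

# A single root split should leave exactly two leaf nodes
leaf_ids <- active_forest$get_tree_leaves(tree_num = 1)
length(leaf_ids)  # expected: 2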
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-get_tree_split_counts}{}}} +\subsection{Method \code{get_tree_split_counts()}}{ +Retrieve a vector of split counts for every training set variable in a given tree in the forest +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$get_tree_split_counts(tree_num, num_features)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{tree_num}}{Index of the tree for which split counts will be retrieved} + +\item{\code{num_features}}{Total number of features in the training set} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-get_forest_split_counts}{}}} +\subsection{Method \code{get_forest_split_counts()}}{ +Retrieve a vector of split counts for every training set variable in the forest +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$get_forest_split_counts(num_features)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{num_features}}{Total number of features in the training set} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
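Split counts lend themselves to a quick, informal variable-importance tally. The following is a sketch built on the example forest above, not a package feature:

# Tally how often each of the p = 5 features is split on across the forest
split_counts <- active_forest$get_forest_split_counts(num_features = 5)
names(split_counts) <- paste0("x", 1:5)
sort(split_counts, decreasing = TRUE)
# Per-tree tallies are available via get_tree_split_counts(tree_num, num_features)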
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-tree_max_depth}{}}} +\subsection{Method \code{tree_max_depth()}}{ +Maximum depth of a specific tree in the forest +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$tree_max_depth(tree_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{tree_num}}{Tree index within forest} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Maximum leaf depth +} +} +\if{html}{\out{
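A sketch of depth diagnostics combining this method with \code{num_trees()} (1-based tree indexing is an assumption here):

# Maximum depth of each tree, and the deepest tree overall
depths <- sapply(1:active_forest$num_trees(),
                 function(i) active_forest$tree_max_depth(i))
max(depths)
# average_max_depth(), documented next, returns the mean of these depths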
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-average_max_depth}{}}} +\subsection{Method \code{average_max_depth()}}{ +Average maximum depth across all trees in the forest +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$average_max_depth()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Average maximum depth +} +} +} diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index 0ca53796..1d710ac1 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -79,6 +79,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) forest_dataset, residual, forest_samples, + active_forest, rng, feature_types, leaf_model_int, @@ -88,6 +89,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) b_forest, global_scale, cutpoint_grid_size = 500, + keep_forest = T, gfr = T, pre_initialized = F )}\if{html}{\out{}} @@ -102,6 +104,8 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) \item{\code{forest_samples}}{Container of forest samples} +\item{\code{active_forest}}{"Active" forest updated by the sampler in each iteration} + \item{\code{rng}}{Wrapper around C++ random number generator} \item{\code{feature_types}}{Vector specifying the type of all p covariates in \code{forest_dataset} (0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} @@ -118,11 +122,13 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) \item{\code{global_scale}}{Global variance parameter} -\item{\code{cutpoint_grid_size}}{(Optional) Number of unique cutpoints to consider (default: 500, currently only used when \code{GFR = TRUE})} +\item{\code{cutpoint_grid_size}}{(Optional) Number of unique cutpoints to consider (default: \code{500}, currently only used when \code{GFR = TRUE})} + +\item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{T}.} -\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm} +\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: \code{T}.} -\item{\code{pre_initialized}}{(Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F.} +\item{\code{pre_initialized}}{(Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: \code{F}.} } \if{html}{\out{}} } @@ -140,12 +146,7 @@ of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed and this should be reflected through to the residual before the next sampling loop is run. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ForestModel$propagate_basis_update( - dataset, - outcome, - forest_samples, - forest_num -)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ForestModel$propagate_basis_update(dataset, outcome, active_forest)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -155,9 +156,7 @@ changed and this should be reflected through to the residual before the next sam \item{\code{outcome}}{\code{Outcome} object storing the residuals to be updated based on forest predictions} -\item{\code{forest_samples}}{\code{ForestSamples} object storing draws of tree ensembles} - -\item{\code{forest_num}}{Index of forest used to update residuals (starting at 1, in R style)} +\item{\code{active_forest}}{"Active" forest updated by the sampler in each iteration} } \if{html}{\out{}} } diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd index bead9332..6179c274 100644 --- a/man/ForestSamples.Rd +++ b/man/ForestSamples.Rd @@ -33,6 +33,8 @@ Wrapper around a C++ container of tree ensembles \item \href{#method-ForestSamples-num_samples}{\code{ForestSamples$num_samples()}} \item \href{#method-ForestSamples-num_trees}{\code{ForestSamples$num_trees()}} \item \href{#method-ForestSamples-output_dimension}{\code{ForestSamples$output_dimension()}} +\item \href{#method-ForestSamples-is_constant_leaf}{\code{ForestSamples$is_constant_leaf()}} +\item \href{#method-ForestSamples-is_exponentiated}{\code{ForestSamples$is_exponentiated()}} \item \href{#method-ForestSamples-add_forest_with_constant_leaves}{\code{ForestSamples$add_forest_with_constant_leaves()}} \item \href{#method-ForestSamples-add_numeric_split_tree}{\code{ForestSamples$add_numeric_split_tree()}} \item \href{#method-ForestSamples-get_tree_leaves}{\code{ForestSamples$get_tree_leaves()}} @@ -62,6 +64,7 @@ Wrapper around a C++ container of tree ensembles \item \href{#method-ForestSamples-num_split_nodes}{\code{ForestSamples$num_split_nodes()}} \item \href{#method-ForestSamples-nodes}{\code{ForestSamples$nodes()}} \item \href{#method-ForestSamples-leaves}{\code{ForestSamples$leaves()}} +\item \href{#method-ForestSamples-delete_sample}{\code{ForestSamples$delete_sample()}} } } \if{html}{\out{
}} @@ -443,6 +446,32 @@ Leaf node parameter size } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-is_constant_leaf}{}}} +\subsection{Method \code{is_constant_leaf()}}{ +Return constant leaf status of trees in a \code{ForestSamples} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$is_constant_leaf()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +\code{T} if leaves are constant, \code{F} otherwise +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-is_exponentiated}{}}} +\subsection{Method \code{is_exponentiated()}}{ +Return exponentiation status of trees in a \code{ForestSamples} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$is_exponentiated()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +\code{T} if leaf predictions must be exponentiated, \code{F} otherwise +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-add_forest_with_constant_leaves}{}}} \subsection{Method \code{add_forest_with_constant_leaves()}}{ @@ -464,7 +493,7 @@ set to the value / vector provided \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-add_numeric_split_tree}{}}} \subsection{Method \code{add_numeric_split_tree()}}{ -Add a numeric (i.e. X\link{,i} <= c) split to a given tree in the ensemble +Add a numeric (i.e. \code{X[,i] <= c}) split to a given tree in the ensemble \subsection{Usage}{ \if{html}{\out{
}}\preformatted{ForestSamples$add_numeric_split_tree( forest_num, @@ -1084,4 +1113,21 @@ Array of leaf indices in a given tree in a given forest in a \code{ForestSamples Indices of leaf nodes } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-delete_sample}{}}} +\subsection{Method \code{delete_sample()}}{ +Modify the \code{ForestSamples} object by removing the forest sample indexed by \code{forest_num}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$delete_sample(forest_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_num}}{Index of the forest to be removed} +} +\if{html}{\out{
}} +} +} } diff --git a/man/RandomEffectSamples.Rd b/man/RandomEffectSamples.Rd index 90981546..3f0da6b9 100644 --- a/man/RandomEffectSamples.Rd +++ b/man/RandomEffectSamples.Rd @@ -33,6 +33,7 @@ needed for prediction / serialization \item \href{#method-RandomEffectSamples-append_from_json_string}{\code{RandomEffectSamples$append_from_json_string()}} \item \href{#method-RandomEffectSamples-predict}{\code{RandomEffectSamples$predict()}} \item \href{#method-RandomEffectSamples-extract_parameter_samples}{\code{RandomEffectSamples$extract_parameter_samples()}} +\item \href{#method-RandomEffectSamples-delete_sample}{\code{RandomEffectSamples$delete_sample()}} \item \href{#method-RandomEffectSamples-extract_label_mapping}{\code{RandomEffectSamples$extract_label_mapping()}} } } @@ -244,6 +245,23 @@ The sigma array has dimension (\code{num_components}, \code{num_samples}) and is } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectSamples-delete_sample}{}}} +\subsection{Method \code{delete_sample()}}{ +Modify the \code{RandomEffectSamples} object by removing the parameter samples indexed by \code{sample_num}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectSamples$delete_sample(sample_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{sample_num}}{Index of the RFX sample to be removed} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
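Together with \code{ForestSamples$delete_sample()}, this method supports pruning retained draws after the fact, e.g. discarding burn-in samples that were kept during sampling. A sketch, with illustrative object names (\code{forest_samples}, \code{rfx_samples}, \code{n_drop}) and an assumed 0-based sample index to verify before use:

# Repeatedly drop the oldest retained draw from both containers so that
# forest and random-effects samples stay aligned
n_drop <- 10
for (s in seq_len(n_drop)) {
  forest_samples$delete_sample(0)
  rfx_samples$delete_sample(0)
}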
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectSamples-extract_label_mapping}{}}} \subsection{Method \code{extract_label_mapping()}}{ diff --git a/man/RandomEffectsModel.Rd b/man/RandomEffectsModel.Rd index 41c8a03a..1c7aa6b1 100644 --- a/man/RandomEffectsModel.Rd +++ b/man/RandomEffectsModel.Rd @@ -65,6 +65,7 @@ Sample from random effects model. residual, rfx_tracker, rfx_samples, + keep_sample, global_variance, rng )}\if{html}{\out{
}} @@ -81,6 +82,8 @@ Sample from random effects model. \item{\code{rfx_samples}}{Object of type \code{RandomEffectSamples}} +\item{\code{keep_sample}}{Whether sample should be retained in \code{rfx_samples}. If \code{FALSE}, the state of \code{rfx_tracker} will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning.} + \item{\code{global_variance}}{Scalar global variance parameter} \item{\code{rng}}{Object of type \code{CppRNG}} diff --git a/man/bart.Rd b/man/bart.Rd index 7d3e07d1..d5d7bacd 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -17,6 +17,8 @@ bart( num_gfr = 5, num_burnin = 0, num_mcmc = 100, + previous_model_json = NULL, + warmstart_sample_num = NULL, params = list() ) } @@ -60,6 +62,10 @@ that were not in the training set.} \item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} +\item{previous_model_json}{(Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: \code{NULL}.} + +\item{warmstart_sample_num}{(Optional) Sample number from \code{previous_model_json} that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting \code{warmstart_sample_num = 1}). Default: \code{NULL}.} + \item{params}{The list of model parameters, each of which has a default value. \strong{1. Global Parameters} @@ -73,8 +79,10 @@ that were not in the training set.} \item \code{random_seed} Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}. \item \code{sample_sigma_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: \code{TRUE}. \item \code{keep_burnin} Whether or not "burnin" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. -\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in cached predictions. Default \code{TRUE}. Ignored if \code{num_mcmc = 0}. +\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. \item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. +\item \code{keep_every} How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default \code{1}. Setting \code{keep_every <- k} for some \code{k > 1} will "thin" the MCMC samples by retaining every \code{k}-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. +\item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. 
} diff --git a/man/bcf.Rd b/man/bcf.Rd index 6e69348a..d5743a6d 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -19,6 +19,8 @@ bcf( num_gfr = 5, num_burnin = 0, num_mcmc = 100, + previous_model_json = NULL, + warmstart_sample_num = NULL, params = list() ) } @@ -61,6 +63,10 @@ that were not in the training set.} \item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} +\item{previous_model_json}{(Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: \code{NULL}.} + +\item{warmstart_sample_num}{(Optional) Sample number from \code{previous_model_json} that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting \code{warmstart_sample_num = 1}). Default: \code{NULL}.} + \item{params}{The list of model parameters, each of which has a default value. \strong{1. Global Parameters} @@ -79,6 +85,8 @@ that were not in the training set.} \item \code{keep_burnin} Whether or not "burnin" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. \item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. \item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. +\item \code{keep_every} How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default \code{1}. Setting \code{keep_every <- k} for some \code{k > 1} will "thin" the MCMC samples by retaining every \code{k}-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. +\item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. \item \code{sample_sigma_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: \code{TRUE}. 
} @@ -205,8 +213,10 @@ tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test) -# plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +# plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", +# ylab = "actual", main = "Prognostic function") # abline(0,1,col="red",lty=3,lwd=3) -# plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +# plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", +# ylab = "actual", main = "Treatment effect") # abline(0,1,col="red",lty=3,lwd=3) } diff --git a/man/calibrate_inverse_gamma_error_variance.Rd b/man/calibrate_inverse_gamma_error_variance.Rd index 9d4b2713..4e7aa68e 100644 --- a/man/calibrate_inverse_gamma_error_variance.Rd +++ b/man/calibrate_inverse_gamma_error_variance.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/calibration.R \name{calibrate_inverse_gamma_error_variance} \alias{calibrate_inverse_gamma_error_variance} -\title{Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) \link{1}} +\title{Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022)} \usage{ calibrate_inverse_gamma_error_variance( y, @@ -18,19 +18,19 @@ calibrate_inverse_gamma_error_variance( \item{X}{Covariates to be used to partition trees in an ensemble or series of ensemble.} -\item{W}{\link{Optional} Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: \code{NULL}.} +\item{W}{(Optional) Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: \code{NULL}.} \item{nu}{The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as \code{nu*lambda} where \code{lambda} is the output of this function. Default: \code{3}.} -\item{quant}{\link{Optional} Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of \code{sigma^2}. Default: \code{0.9}.} +\item{quant}{(Optional) Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of \code{sigma^2}. Default: \code{0.9}.} -\item{standardize}{\link{Optional} Whether or not outcome should be standardized (\code{(y-mean(y))/sd(y)}) before calibration of \code{lambda}. Default: \code{TRUE}.} +\item{standardize}{(Optional) Whether or not outcome should be standardized (\code{(y-mean(y))/sd(y)}) before calibration of \code{lambda}. Default: \code{TRUE}.} } \value{ Value of \code{lambda} which determines the scale parameter of the global error variance prior (\code{sigma^2 ~ IG(nu,nu*lambda)}) } \description{ -\link{1} Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. 
In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 +Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 } \examples{ n <- 100 diff --git a/man/convertBARTStateToJson.Rd b/man/convertBARTStateToJson.Rd new file mode 100644 index 00000000..004d8014 --- /dev/null +++ b/man/convertBARTStateToJson.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{convertBARTStateToJson} +\alias{convertBARTStateToJson} +\title{Convert in-memory BART model objects (forests, random effects, vectors) to in-memory JSON. +This function is primarily a convenience function for serialization / deserialization in a parallel BART sampler.} +\usage{ +convertBARTStateToJson( + param_list, + mean_forest = NULL, + variance_forest = NULL, + rfx_samples = NULL, + global_variance_samples = NULL, + local_variance_samples = NULL +) +} +\arguments{ +\item{param_list}{List containing high-level model state parameters} + +\item{mean_forest}{Container of conditional mean forest samples (optional). Default: \code{NULL}.} + +\item{variance_forest}{Container of conditional variance forest samples (optional). Default: \code{NULL}.} + +\item{rfx_samples}{Container of random effect samples (optional). Default: \code{NULL}.} + +\item{global_variance_samples}{Vector of global error variance samples (optional). Default: \code{NULL}.} + +\item{local_variance_samples}{Vector of leaf scale samples (optional). Default: \code{NULL}.} +} +\value{ +Object of type \code{CppJson} +} +\description{ +Convert in-memory BART model objects (forests, random effects, vectors) to in-memory JSON. +This function is primarily a convenience function for serialization / deserialization in a parallel BART sampler. 
+} diff --git a/man/convertBCFModelToJson.Rd index baf8d037..29d1051c 100644 --- a/man/convertBCFModelToJson.Rd +++ b/man/convertBCFModelToJson.Rd @@ -66,12 +66,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) # bcf_json <- convertBCFModelToJson(bcf_model) } diff --git a/man/createBCFModelFromCombinedJsonString.Rd new file mode 100644 index 00000000..a8a14194 --- /dev/null +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -0,0 +1,80 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bcf.R +\name{createBCFModelFromCombinedJsonString} +\alias{createBCFModelFromCombinedJsonString} +\title{Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object +which can be used for prediction, etc...} +\usage{ +createBCFModelFromCombinedJsonString(json_string_list) +} +\arguments{ +\item{json_string_list}{List of JSON strings which can be parsed to objects of type \code{CppJson} containing the JSON representation of a BCF model} +} +\value{ +Object of type \code{bcf} +} +\description{ +Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object +which can be used for prediction, etc...
+} +\examples{ +n <- 100 +p <- 5 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- rnorm(n) +x5 <- rnorm(n) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +g <- function(x) {ifelse(x[,5] < -0.44,2,ifelse(x[,5] < 0.44,-1,4))} +mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +tau1 <- function(x) {rep(3,nrow(x))} +tau2 <- function(x) {1+2*x[,2]*(x[,4] > 0)} +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +snr <- 3 +group_ids <- rep(c(1,2), n \%/\% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +rfx_basis <- cbind(1, runif(n, -1, 1)) +rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) +y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds,] +rfx_basis_train <- rfx_basis[train_inds,] +rfx_term_test <- rfx_term[test_inds] +rfx_term_train <- rfx_term[train_inds] +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + pi_train = pi_train, group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 100, num_burnin = 0, num_mcmc = 100) +# bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) +# bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) +} diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index 76c3ebcb..e605b9c9 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -68,13 +68,14 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) # bcf_json <- convertBCFModelToJson(bcf_model) # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) } diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index dd2354a1..ec0075bf 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -68,13 +68,14 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, 
sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) # saveBCFModelToJsonFile(bcf_model, "test.json") # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json") } diff --git a/man/createBCFModelFromJsonString.Rd index b25557ab..79b79773 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -73,8 +73,7 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, - num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json <- saveBCFModelToJsonString(bcf_model) # bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) } diff --git a/man/createForest.Rd new file mode 100644 index 00000000..541dc9f3 --- /dev/null +++ b/man/createForest.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/forest.R +\name{createForest} +\alias{createForest} +\title{Create a forest} +\usage{ +createForest( + num_trees, + output_dimension = 1, + is_leaf_constant = F, + is_exponentiated = F +) +} +\arguments{ +\item{num_trees}{Number of trees in the forest} + +\item{output_dimension}{Dimensionality of the outcome model} + +\item{is_leaf_constant}{Whether the forest's leaves are constant} + +\item{is_exponentiated}{Whether forest predictions should be exponentiated before being returned} +} +\value{ +\code{Forest} object +} +\description{ +Create a forest +} diff --git a/man/createForestModel.Rd index dee2d2d6..0ca6fe16 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -29,6 +29,8 @@ createForestModel( \item{beta}{Depth prior penalty in tree prior} \item{min_samples_leaf}{Minimum number of samples in a tree leaf} + +\item{max_depth}{Maximum depth of any tree in the ensemble.
Setting to \code{-1} does not enforce any depth limits on trees.} } \value{ \code{ForestModel} object diff --git a/man/getRandomEffectSamples.bcf.Rd b/man/getRandomEffectSamples.bcf.Rd index 4fbdc6ae..4e38df8e 100644 --- a/man/getRandomEffectSamples.bcf.Rd +++ b/man/getRandomEffectSamples.bcf.Rd @@ -70,12 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/orderedCatPreprocess.Rd b/man/orderedCatPreprocess.Rd index 5e410cc2..4c5bf54c 100644 --- a/man/orderedCatPreprocess.Rd +++ b/man/orderedCatPreprocess.Rd @@ -28,7 +28,9 @@ ordered levels to integers if necessary, and storing the unique levels of a variable. } \examples{ -x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") -x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") +x_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", + "4. Agree", "5. Strongly agree") +x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", + "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. 
Agree") x_processed <- orderedCatPreprocess(x, x_levels) } diff --git a/man/predict.bcf.Rd b/man/predict.bcf.Rd index 1841c14d..311cce50 100644 --- a/man/predict.bcf.Rd +++ b/man/predict.bcf.Rd @@ -77,8 +77,10 @@ tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train) preds <- predict(bcf_model, X_test, Z_test, pi_test) -# plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +# plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", +# ylab = "actual", main = "Prognostic function") # abline(0,1,col="red",lty=3,lwd=3) -# plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +# plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", +# ylab = "actual", main = "Treatment effect") # abline(0,1,col="red",lty=3,lwd=3) } diff --git a/man/preprocessTrainData.Rd b/man/preprocessTrainData.Rd index abd02e15..e9a31edd 100644 --- a/man/preprocessTrainData.Rd +++ b/man/preprocessTrainData.Rd @@ -9,8 +9,6 @@ preprocessTrainData(input_data) } \arguments{ \item{input_data}{Covariates, provided as either a dataframe or a matrix} - -\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable} } \value{ List with preprocessed (unmodified) data and details on the number of each type diff --git a/man/resetActiveForest.Rd b/man/resetActiveForest.Rd new file mode 100644 index 00000000..58d418df --- /dev/null +++ b/man/resetActiveForest.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/forest.R +\name{resetActiveForest} +\alias{resetActiveForest} +\title{Re-initialize an active forest from a specific forest in a \code{ForestContainer}} +\usage{ +resetActiveForest(active_forest, forest_samples, forest_num) +} +\arguments{ +\item{active_forest}{Current active forest} + +\item{forest_samples}{Container of forest samples from which to re-initialize active forest} + +\item{forest_num}{Index of forest samples from which to initialize active forest} +} +\description{ +Re-initialize an active forest from a specific forest in a \code{ForestContainer} +} diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd new file mode 100644 index 00000000..4c730b4a --- /dev/null +++ b/man/resetForestModel.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/forest.R +\name{resetForestModel} +\alias{resetForestModel} +\title{Re-initialize a forest model (tracking data structures) from a specific forest in a \code{ForestContainer}} +\usage{ +resetForestModel(forest_model, forest, dataset, residual, is_mean_model) +} +\arguments{ +\item{forest_model}{Forest model with tracking data structures} + +\item{forest}{Forest from which to re-initialize forest model} + +\item{dataset}{Training dataset object} + +\item{residual}{Residual which will also be updated} + +\item{is_mean_model}{Whether the model being updated is a conditional mean model} +} +\description{ +Re-initialize a forest model (tracking data structures) from a specific forest in a \code{ForestContainer} +} diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd new file mode 100644 index 00000000..4b2c4568 --- /dev/null +++ b/man/resetRandomEffectsModel.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/random_effects.R +\name{resetRandomEffectsModel} 
+\alias{resetRandomEffectsModel} +\title{Reset a \code{RandomEffectsModel} object based on the parameters indexed by \code{sample_num} in a \code{RandomEffectsSamples} object} +\usage{ +resetRandomEffectsModel(rfx_model, rfx_samples, sample_num, sigma_alpha_init) +} +\arguments{ +\item{rfx_model}{Object of type \code{RandomEffectsModel}.} + +\item{rfx_samples}{Object of type \code{RandomEffectSamples}.} + +\item{sample_num}{Index of sample stored in \code{rfx_samples} from which to reset the state of a random effects model. Zero-indexed, so resetting based on the first sample would require setting \code{sample_num = 0}.} + +\item{sigma_alpha_init}{Initial value of the "working parameter" scale parameter.} +} +\description{ +Reset a \code{RandomEffectsModel} object based on the parameters indexed by \code{sample_num} in a \code{RandomEffectsSamples} object +} diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd new file mode 100644 index 00000000..14db8d1a --- /dev/null +++ b/man/resetRandomEffectsTracker.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/random_effects.R +\name{resetRandomEffectsTracker} +\alias{resetRandomEffectsTracker} +\title{Reset a \code{RandomEffectsTracker} object based on the parameters indexed by \code{sample_num} in a \code{RandomEffectsSamples} object} +\usage{ +resetRandomEffectsTracker( + rfx_tracker, + rfx_model, + rfx_dataset, + residual, + rfx_samples +) +} +\arguments{ +\item{rfx_tracker}{Object of type \code{RandomEffectsTracker}.} + +\item{rfx_model}{Object of type \code{RandomEffectsModel}.} + +\item{rfx_dataset}{Object of type \code{RandomEffectsDataset}.} + +\item{residual}{Object of type \code{Outcome}.} + +\item{rfx_samples}{Object of type \code{RandomEffectSamples}.} +} +\description{ +Reset a \code{RandomEffectsTracker} object based on the parameters indexed by \code{sample_num} in a \code{RandomEffectsSamples} object +} diff --git a/man/rootResetActiveForest.Rd b/man/rootResetActiveForest.Rd new file mode 100644 index 00000000..1767c7b5 --- /dev/null +++ b/man/rootResetActiveForest.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/forest.R +\name{rootResetActiveForest} +\alias{rootResetActiveForest} +\title{Reset an active forest to an ensemble of single-node (i.e. root) trees} +\usage{ +rootResetActiveForest(active_forest) +} +\arguments{ +\item{active_forest}{Current active forest} +} +\value{ +\code{Forest} object +} +\description{ +Reset an active forest to an ensemble of single-node (i.e. 
root) trees +} diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd new file mode 100644 index 00000000..409ef715 --- /dev/null +++ b/man/rootResetRandomEffectsModel.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/random_effects.R +\name{rootResetRandomEffectsModel} +\alias{rootResetRandomEffectsModel} +\title{Reset a \code{RandomEffectsModel} object to its "default" state} +\usage{ +rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale +) +} +\arguments{ +\item{rfx_model}{Object of type \code{RandomEffectsModel}.} + +\item{alpha_init}{Initial value of the "working parameter".} + +\item{xi_init}{Initial value of the "group parameters".} + +\item{sigma_alpha_init}{Initial value of the "working parameter" scale parameter.} + +\item{sigma_xi_init}{Initial value of the "group parameters" scale parameter.} + +\item{sigma_xi_shape}{Shape parameter for the inverse gamma variance model on the group parameters.} + +\item{sigma_xi_scale}{Scale parameter for the inverse gamma variance model on the group parameters.} +} +\description{ +Reset a \code{RandomEffectsModel} object to its "default" state +} diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd new file mode 100644 index 00000000..3fbd1860 --- /dev/null +++ b/man/rootResetRandomEffectsTracker.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/random_effects.R +\name{rootResetRandomEffectsTracker} +\alias{rootResetRandomEffectsTracker} +\title{Reset a \code{RandomEffectsTracker} object to its "default" state} +\usage{ +rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, residual) +} +\arguments{ +\item{rfx_tracker}{Object of type \code{RandomEffectsTracker}.} + +\item{rfx_model}{Object of type \code{RandomEffectsModel}.} + +\item{rfx_dataset}{Object of type \code{RandomEffectsDataset}.} + +\item{residual}{Object of type \code{Outcome}.} +} +\description{ +Reset a \code{RandomEffectsTracker} object to its "default" state +} diff --git a/man/sample_tau_one_iteration.Rd b/man/sample_tau_one_iteration.Rd index 7fd48a8d..8e6201d5 100644 --- a/man/sample_tau_one_iteration.Rd +++ b/man/sample_tau_one_iteration.Rd @@ -4,18 +4,16 @@ \alias{sample_tau_one_iteration} \title{Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)} \usage{ -sample_tau_one_iteration(forest_samples, rng, a, b, sample_num) +sample_tau_one_iteration(forest, rng, a, b) } \arguments{ -\item{forest_samples}{Container of forest samples} +\item{forest}{C++ forest} \item{rng}{C++ random number generator} \item{a}{Leaf variance shape parameter} \item{b}{Leaf variance scale parameter} - -\item{sample_num}{Sample index} } \description{ Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!) 
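Taken together, the reset helpers documented above are the building blocks of the multi-chain, warm-started samplers described in \code{bart()} and \code{bcf()}. A sketch of how they might compose (object names and call order are assumptions; the real logic lives inside those functions):

for (chain in 1:num_chains) {
  if (num_gfr > 0) {
    # Warm-start this chain from one of the retained GFR ensembles
    resetActiveForest(active_forest, forest_samples, forest_num = chain)
    resetForestModel(forest_model, active_forest, forest_dataset, outcome,
                     is_mean_model = TRUE)
  } else {
    # With no GFR samples, restart the chain from an all-root ensemble
    rootResetActiveForest(active_forest)
  }
  # ... then run num_burnin + num_mcmc * keep_every MCMC iterations,
  # passing keep_forest = TRUE only for every keep_every-th retained draw
}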
diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index 47fa634e..5c6bb6c0 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -65,12 +65,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) # saveBCFModelToJsonFile(bcf_model, "test.json") } diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 7dd31418..63c0d298 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -66,12 +66,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] +bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) + params = bcf_params) # saveBCFModelToJsonString(bcf_model) } diff --git a/src/Makevars b/src/Makevars index 8f704cd9..83bd4627 100644 --- a/src/Makevars +++ b/src/Makevars @@ -3,7 +3,8 @@ PKGROOT=.. 
STOCHTREE_CPPFLAGS = -DSTOCHTREE_R_BUILD -PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/deps/eigen -I$(PKGROOT)/deps/fmt/include -I$(PKGROOT)/deps/fast_double_parser/include -I$(PKGROOT)/deps/boost_math/include $(STOCHTREE_CPPFLAGS) +# PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/deps/eigen -I$(PKGROOT)/deps/fmt/include -I$(PKGROOT)/deps/fast_double_parser/include -I$(PKGROOT)/deps/boost_math/include $(STOCHTREE_CPPFLAGS) +PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/deps/eigen -I$(PKGROOT)/deps/fmt/include -I$(PKGROOT)/deps/fast_double_parser/include $(STOCHTREE_CPPFLAGS) CXX_STD=CXX17 diff --git a/src/R_random_effects.cpp index ccb3aa98..f627b3c5 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -181,9 +181,9 @@ cpp11::external_pointer rfx_label_mapper_cpp(cpp11::exte [[cpp11::register]] void rfx_model_sample_random_effects_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, cpp11::external_pointer residual, cpp11::external_pointer rfx_tracker, - cpp11::external_pointer rfx_container, double global_variance, cpp11::external_pointer rng) { + cpp11::external_pointer rfx_container, bool keep_sample, double global_variance, cpp11::external_pointer rng) { rfx_model->SampleRandomEffects(*rfx_dataset, *residual, *rfx_tracker, global_variance, *rng); - rfx_container->AddSample(*rfx_model); + if (keep_sample) rfx_container->AddSample(*rfx_model); } [[cpp11::register]] @@ -220,6 +220,11 @@ int rfx_container_num_groups_cpp(cpp11::external_pointerNumGroups(); } +[[cpp11::register]] +void rfx_container_delete_sample_cpp(cpp11::external_pointer rfx_container, int sample_num) { + rfx_container->DeleteSample(sample_num); +} + [[cpp11::register]] void rfx_model_set_working_parameter_cpp(cpp11::external_pointer rfx_model, cpp11::doubles working_param_init) { Eigen::VectorXd working_param_eigen(working_param_init.size()); @@ -313,3 +318,29 @@ cpp11::list rfx_label_mapper_to_list_cpp(cpp11::external_pointer rfx_model, + cpp11::external_pointer rfx_container, + int sample_num) { + // Reset the RFX model from the sample indexed by sample_num + rfx_model->ResetFromSample(*rfx_container, sample_num); +} + +[[cpp11::register]] +void reset_rfx_tracker_cpp(cpp11::external_pointer tracker, + cpp11::external_pointer dataset, + cpp11::external_pointer residual, + cpp11::external_pointer rfx_model) { + // Reset the RFX tracker + tracker->ResetFromSample(*rfx_model, *dataset, *residual); +} + +[[cpp11::register]] +void root_reset_rfx_tracker_cpp(cpp11::external_pointer tracker, + cpp11::external_pointer dataset, + cpp11::external_pointer residual, + cpp11::external_pointer rfx_model) { + // Reset the RFX tracker to its root (default) state + tracker->RootReset(*rfx_model, *dataset, *residual); +} diff --git a/src/container.cpp index dd829c3b..db10e53b 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -32,6 +32,16 @@ void ForestContainer::CopyFromPreviousSample(int new_sample_id, int previous_sam forests_[new_sample_id].reset(new TreeEnsemble(*forests_[previous_sample_id])); } +void ForestContainer::DeleteSample(int sample_num) { + forests_.erase(forests_.begin() + sample_num); + num_samples_--; +} + +void ForestContainer::AddSample(TreeEnsemble& forest) { + forests_.push_back(std::make_unique(forest)); + num_samples_++; +} + void ForestContainer::InitializeRoot(double leaf_value) { CHECK(initialized_); CHECK_EQ(num_samples_, 0); diff --git a/src/cpp11.cpp index 220ddf3b..9952659f 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -271,10 +271,10 @@ extern "C"
SEXP _stochtree_rfx_label_mapper_cpp(SEXP rfx_tracker) { END_CPP11 } // R_random_effects.cpp -void rfx_model_sample_random_effects_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, cpp11::external_pointer residual, cpp11::external_pointer rfx_tracker, cpp11::external_pointer rfx_container, double global_variance, cpp11::external_pointer rng); -extern "C" SEXP _stochtree_rfx_model_sample_random_effects_cpp(SEXP rfx_model, SEXP rfx_dataset, SEXP residual, SEXP rfx_tracker, SEXP rfx_container, SEXP global_variance, SEXP rng) { +void rfx_model_sample_random_effects_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, cpp11::external_pointer residual, cpp11::external_pointer rfx_tracker, cpp11::external_pointer rfx_container, bool keep_sample, double global_variance, cpp11::external_pointer rng); +extern "C" SEXP _stochtree_rfx_model_sample_random_effects_cpp(SEXP rfx_model, SEXP rfx_dataset, SEXP residual, SEXP rfx_tracker, SEXP rfx_container, SEXP keep_sample, SEXP global_variance, SEXP rng) { BEGIN_CPP11 - rfx_model_sample_random_effects_cpp(cpp11::as_cpp>>(rfx_model), cpp11::as_cpp>>(rfx_dataset), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rfx_tracker), cpp11::as_cpp>>(rfx_container), cpp11::as_cpp>(global_variance), cpp11::as_cpp>>(rng)); + rfx_model_sample_random_effects_cpp(cpp11::as_cpp>>(rfx_model), cpp11::as_cpp>>(rfx_dataset), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rfx_tracker), cpp11::as_cpp>>(rfx_container), cpp11::as_cpp>(keep_sample), cpp11::as_cpp>(global_variance), cpp11::as_cpp>>(rng)); return R_NilValue; END_CPP11 } @@ -314,6 +314,14 @@ extern "C" SEXP _stochtree_rfx_container_num_groups_cpp(SEXP rfx_container) { END_CPP11 } // R_random_effects.cpp +void rfx_container_delete_sample_cpp(cpp11::external_pointer rfx_container, int sample_num); +extern "C" SEXP _stochtree_rfx_container_delete_sample_cpp(SEXP rfx_container, SEXP sample_num) { + BEGIN_CPP11 + rfx_container_delete_sample_cpp(cpp11::as_cpp>>(rfx_container), cpp11::as_cpp>(sample_num)); + return R_NilValue; + END_CPP11 +} +// R_random_effects.cpp void rfx_model_set_working_parameter_cpp(cpp11::external_pointer rfx_model, cpp11::doubles working_param_init); extern "C" SEXP _stochtree_rfx_model_set_working_parameter_cpp(SEXP rfx_model, SEXP working_param_init) { BEGIN_CPP11 @@ -403,6 +411,37 @@ extern "C" SEXP _stochtree_rfx_label_mapper_to_list_cpp(SEXP label_mapper_ptr) { return cpp11::as_sexp(rfx_label_mapper_to_list_cpp(cpp11::as_cpp>>(label_mapper_ptr))); END_CPP11 } +// R_random_effects.cpp +void reset_rfx_model_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_container, int sample_num); +extern "C" SEXP _stochtree_reset_rfx_model_cpp(SEXP rfx_model, SEXP rfx_container, SEXP sample_num) { + BEGIN_CPP11 + reset_rfx_model_cpp(cpp11::as_cpp>>(rfx_model), cpp11::as_cpp>>(rfx_container), cpp11::as_cpp>(sample_num)); + return R_NilValue; + END_CPP11 +} +// R_random_effects.cpp +void reset_rfx_tracker_cpp(cpp11::external_pointer tracker, cpp11::external_pointer dataset, cpp11::external_pointer residual, cpp11::external_pointer rfx_model); +extern "C" SEXP _stochtree_reset_rfx_tracker_cpp(SEXP tracker, SEXP dataset, SEXP residual, SEXP rfx_model) { + BEGIN_CPP11 + reset_rfx_tracker_cpp(cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(dataset), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rfx_model)); + return R_NilValue; + END_CPP11 +} +// R_random_effects.cpp +void root_reset_rfx_tracker_cpp(cpp11::external_pointer tracker, cpp11::external_pointer dataset, 
cpp11::external_pointer residual, cpp11::external_pointer rfx_model); +extern "C" SEXP _stochtree_root_reset_rfx_tracker_cpp(SEXP tracker, SEXP dataset, SEXP residual, SEXP rfx_model) { + BEGIN_CPP11 + root_reset_rfx_tracker_cpp(cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(dataset), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rfx_model)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension, bool is_leaf_constant, bool is_exponentiated); +extern "C" SEXP _stochtree_active_forest_cpp(SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP is_exponentiated) { + BEGIN_CPP11 + return cpp11::as_sexp(active_forest_cpp(cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(is_exponentiated))); + END_CPP11 +} // forest.cpp cpp11::external_pointer forest_container_cpp(int num_trees, int output_dimension, bool is_leaf_constant, bool is_exponentiated); extern "C" SEXP _stochtree_forest_container_cpp(SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP is_exponentiated) { @@ -520,6 +559,13 @@ extern "C" SEXP _stochtree_is_leaf_constant_forest_container_cpp(SEXP forest_sam END_CPP11 } // forest.cpp +int is_exponentiated_forest_container_cpp(cpp11::external_pointer forest_samples); +extern "C" SEXP _stochtree_is_exponentiated_forest_container_cpp(SEXP forest_samples) { + BEGIN_CPP11 + return cpp11::as_sexp(is_exponentiated_forest_container_cpp(cpp11::as_cpp>>(forest_samples))); + END_CPP11 +} +// forest.cpp bool all_roots_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num); extern "C" SEXP _stochtree_all_roots_forest_container_cpp(SEXP forest_samples, SEXP forest_num) { BEGIN_CPP11 @@ -761,6 +807,14 @@ extern "C" SEXP _stochtree_propagate_basis_update_forest_container_cpp(SEXP data END_CPP11 } // forest.cpp +void remove_sample_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num); +extern "C" SEXP _stochtree_remove_sample_forest_container_cpp(SEXP forest_samples, SEXP forest_num) { + BEGIN_CPP11 + remove_sample_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num)); + return R_NilValue; + END_CPP11 +} +// forest.cpp cpp11::writable::doubles_matrix<> predict_forest_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset); extern "C" SEXP _stochtree_predict_forest_cpp(SEXP forest_samples, SEXP dataset) { BEGIN_CPP11 @@ -788,6 +842,177 @@ extern "C" SEXP _stochtree_predict_forest_raw_single_tree_cpp(SEXP forest_sample return cpp11::as_sexp(predict_forest_raw_single_tree_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(dataset), cpp11::as_cpp>(forest_num), cpp11::as_cpp>(tree_num))); END_CPP11 } +// forest.cpp +cpp11::writable::doubles predict_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset); +extern "C" SEXP _stochtree_predict_active_forest_cpp(SEXP active_forest, SEXP dataset) { + BEGIN_CPP11 + return cpp11::as_sexp(predict_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(dataset))); + END_CPP11 +} +// forest.cpp +cpp11::writable::doubles predict_raw_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset); +extern "C" SEXP _stochtree_predict_raw_active_forest_cpp(SEXP active_forest, SEXP dataset) { + BEGIN_CPP11 + return cpp11::as_sexp(predict_raw_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(dataset))); + END_CPP11 +} +// forest.cpp +int 
output_dimension_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_output_dimension_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(output_dimension_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +double average_max_depth_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_average_max_depth_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +int num_trees_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_num_trees_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(num_trees_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +int ensemble_tree_max_depth_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num); +extern "C" SEXP _stochtree_ensemble_tree_max_depth_active_forest_cpp(SEXP active_forest, SEXP tree_num) { + BEGIN_CPP11 + return cpp11::as_sexp(ensemble_tree_max_depth_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(tree_num))); + END_CPP11 +} +// forest.cpp +int is_leaf_constant_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_is_leaf_constant_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(is_leaf_constant_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +int is_exponentiated_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_is_exponentiated_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(is_exponentiated_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +bool all_roots_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_all_roots_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + return cpp11::as_sexp(all_roots_active_forest_cpp(cpp11::as_cpp>>(active_forest))); + END_CPP11 +} +// forest.cpp +void set_leaf_value_active_forest_cpp(cpp11::external_pointer active_forest, double leaf_value); +extern "C" SEXP _stochtree_set_leaf_value_active_forest_cpp(SEXP active_forest, SEXP leaf_value) { + BEGIN_CPP11 + set_leaf_value_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(leaf_value)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void set_leaf_vector_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::doubles leaf_vector); +extern "C" SEXP _stochtree_set_leaf_vector_active_forest_cpp(SEXP active_forest, SEXP leaf_vector) { + BEGIN_CPP11 + set_leaf_vector_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(leaf_vector)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void add_numeric_split_tree_value_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, double left_leaf_value, double right_leaf_value); +extern "C" SEXP _stochtree_add_numeric_split_tree_value_active_forest_cpp(SEXP active_forest, SEXP tree_num, SEXP leaf_num, SEXP feature_num, SEXP split_threshold, SEXP left_leaf_value, SEXP right_leaf_value) { + BEGIN_CPP11 + add_numeric_split_tree_value_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(tree_num), cpp11::as_cpp>(leaf_num), cpp11::as_cpp>(feature_num), cpp11::as_cpp>(split_threshold), cpp11::as_cpp>(left_leaf_value), 
cpp11::as_cpp>(right_leaf_value)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void add_numeric_split_tree_vector_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, cpp11::doubles left_leaf_vector, cpp11::doubles right_leaf_vector); +extern "C" SEXP _stochtree_add_numeric_split_tree_vector_active_forest_cpp(SEXP active_forest, SEXP tree_num, SEXP leaf_num, SEXP feature_num, SEXP split_threshold, SEXP left_leaf_vector, SEXP right_leaf_vector) { + BEGIN_CPP11 + add_numeric_split_tree_vector_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(tree_num), cpp11::as_cpp>(leaf_num), cpp11::as_cpp>(feature_num), cpp11::as_cpp>(split_threshold), cpp11::as_cpp>(left_leaf_vector), cpp11::as_cpp>(right_leaf_vector)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +cpp11::writable::integers get_tree_leaves_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num); +extern "C" SEXP _stochtree_get_tree_leaves_active_forest_cpp(SEXP active_forest, SEXP tree_num) { + BEGIN_CPP11 + return cpp11::as_sexp(get_tree_leaves_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(tree_num))); + END_CPP11 +} +// forest.cpp +cpp11::writable::integers get_tree_split_counts_active_forest_cpp(cpp11::external_pointer active_forest, int tree_num, int num_features); +extern "C" SEXP _stochtree_get_tree_split_counts_active_forest_cpp(SEXP active_forest, SEXP tree_num, SEXP num_features) { + BEGIN_CPP11 + return cpp11::as_sexp(get_tree_split_counts_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(tree_num), cpp11::as_cpp>(num_features))); + END_CPP11 +} +// forest.cpp +cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::external_pointer active_forest, int num_features); +extern "C" SEXP _stochtree_get_overall_split_counts_active_forest_cpp(SEXP active_forest, SEXP num_features) { + BEGIN_CPP11 + return cpp11::as_sexp(get_overall_split_counts_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(num_features))); + END_CPP11 +} +// forest.cpp +cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11::external_pointer active_forest, int num_features); +extern "C" SEXP _stochtree_get_granular_split_count_array_active_forest_cpp(SEXP active_forest, SEXP num_features) { + BEGIN_CPP11 + return cpp11::as_sexp(get_granular_split_count_array_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>(num_features))); + END_CPP11 +} +// forest.cpp +void initialize_forest_model_active_forest_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::doubles init_values, int leaf_model_int); +extern "C" SEXP _stochtree_initialize_forest_model_active_forest_cpp(SEXP data, SEXP residual, SEXP active_forest, SEXP tracker, SEXP init_values, SEXP leaf_model_int) { + BEGIN_CPP11 + initialize_forest_model_active_forest_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>(init_values), cpp11::as_cpp>(leaf_model_int)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void adjust_residual_active_forest_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, bool requires_basis, bool add); +extern "C" SEXP _stochtree_adjust_residual_active_forest_cpp(SEXP data, SEXP residual, SEXP active_forest, SEXP 
tracker, SEXP requires_basis, SEXP add) { + BEGIN_CPP11 + adjust_residual_active_forest_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>(requires_basis), cpp11::as_cpp>(add)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void propagate_basis_update_active_forest_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer active_forest, cpp11::external_pointer tracker); +extern "C" SEXP _stochtree_propagate_basis_update_active_forest_cpp(SEXP data, SEXP residual, SEXP active_forest, SEXP tracker) { + BEGIN_CPP11 + propagate_basis_update_active_forest_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void reset_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer forest_samples, int forest_num); +extern "C" SEXP _stochtree_reset_active_forest_cpp(SEXP active_forest, SEXP forest_samples, SEXP forest_num) { + BEGIN_CPP11 + reset_active_forest_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void reset_forest_model_cpp(cpp11::external_pointer forest_tracker, cpp11::external_pointer forest, cpp11::external_pointer data, cpp11::external_pointer residual, bool is_mean_model); +extern "C" SEXP _stochtree_reset_forest_model_cpp(SEXP forest_tracker, SEXP forest, SEXP data, SEXP residual, SEXP is_mean_model) { + BEGIN_CPP11 + reset_forest_model_cpp(cpp11::as_cpp>>(forest_tracker), cpp11::as_cpp>>(forest), cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>(is_mean_model)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +void root_reset_active_forest_cpp(cpp11::external_pointer active_forest); +extern "C" SEXP _stochtree_root_reset_active_forest_cpp(SEXP active_forest) { + BEGIN_CPP11 + root_reset_active_forest_cpp(cpp11::as_cpp>>(active_forest)); + return R_NilValue; + END_CPP11 +} // kernel.cpp int forest_container_get_max_leaf_index_cpp(cpp11::external_pointer forest_container, int forest_num); extern "C" SEXP _stochtree_forest_container_get_max_leaf_index_cpp(SEXP forest_container, SEXP forest_num) { @@ -803,18 +1028,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool pre_initialized); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP pre_initialized) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers 
feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(pre_initialized)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool pre_initialized); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP pre_initialized) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), 
cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(pre_initialized)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); return R_NilValue; END_CPP11 } @@ -826,10 +1051,10 @@ extern "C" SEXP _stochtree_sample_sigma2_one_iteration_cpp(SEXP residual, SEXP d END_CPP11 } // sampler.cpp -double sample_tau_one_iteration_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer rng, double a, double b, int sample_num); -extern "C" SEXP _stochtree_sample_tau_one_iteration_cpp(SEXP forest_samples, SEXP rng, SEXP a, SEXP b, SEXP sample_num) { +double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer rng, double a, double b); +extern "C" SEXP _stochtree_sample_tau_one_iteration_cpp(SEXP active_forest, SEXP rng, SEXP a, SEXP b) { BEGIN_CPP11 - return cpp11::as_sexp(sample_tau_one_iteration_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(rng), cpp11::as_cpp>(a), cpp11::as_cpp>(b), cpp11::as_cpp>(sample_num))); + return cpp11::as_sexp(sample_tau_one_iteration_cpp(cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(rng), cpp11::as_cpp>(a), cpp11::as_cpp>(b))); END_CPP11 } // sampler.cpp @@ -1094,14 +1319,20 @@ extern "C" SEXP _stochtree_json_load_string_cpp(SEXP json_ptr, SEXP json_string) extern "C" { static const R_CallMethodDef CallEntries[] = { + {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, + {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, + {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, + {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, + {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, 
{"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, @@ -1113,6 +1344,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, @@ -1126,15 +1358,23 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, + {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, + {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, + {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, + {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, + {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, + {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, + 
{"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, + {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, @@ -1181,19 +1421,30 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, + {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_output_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_output_dimension_active_forest_cpp, 1}, {"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1}, {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, + {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, + {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, + {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, + {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, + {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, + {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, + {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, + {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, + {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) 
&_stochtree_rfx_container_from_json_cpp, 2}, {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, @@ -1219,7 +1470,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, - {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 7}, + {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, @@ -1230,11 +1481,15 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 15}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 15}, + {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, + {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 17}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, - {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5}, + {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, + {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, + {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, diff --git a/src/forest.cpp b/src/forest.cpp index ec839785..81881716 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -11,6 +11,15 @@ #include #include +[[cpp11::register]] +cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { + // Create smart pointer to newly allocated object + std::unique_ptr forest_ptr_ = 
std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); + + // Release management of the pointer to R session + return cpp11::external_pointer(forest_ptr_.release()); +} + [[cpp11::register]] cpp11::external_pointer forest_container_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { // Create smart pointer to newly allocated object @@ -133,6 +142,11 @@ int is_leaf_constant_forest_container_cpp(cpp11::external_pointerIsLeafConstant(); } +[[cpp11::register]] +int is_exponentiated_forest_container_cpp(cpp11::external_pointer forest_samples) { + return forest_samples->IsExponentiated(); +} + [[cpp11::register]] bool all_roots_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { return forest_samples->AllRoots(forest_num); @@ -481,7 +495,6 @@ void initialize_forest_model_cpp(cpp11::external_pointerUpdatePredictions(forest_samples->GetEnsemble(0), *data); int n = data->NumObservations(); std::vector initial_preds(n, init_val); - // for (int i = 0; i < n; i++) initial_preds[i] = 1/initial_preds[i]; data->AddVarianceWeights(initial_preds.data(), n); } } @@ -511,6 +524,12 @@ void propagate_basis_update_forest_container_cpp(cpp11::external_pointerGetEnsemble(forest_num)); } +[[cpp11::register]] +void remove_sample_forest_container_cpp(cpp11::external_pointer forest_samples, + int forest_num) { + forest_samples->DeleteSample(forest_num); +} + [[cpp11::register]] cpp11::writable::doubles_matrix<> predict_forest_cpp(cpp11::external_pointer forest_samples, cpp11::external_pointer dataset) { // Predict from the sampled forests @@ -534,7 +553,7 @@ cpp11::writable::doubles predict_forest_raw_cpp(cpp11::external_pointer output_raw = forest_samples->PredictRaw(*dataset); - // Convert result to a matrix + // Unpack / re-arrange results int n = dataset->GetCovariates().rows(); int num_samples = forest_samples->NumSamples(); int output_dimension = forest_samples->OutputDimension(); @@ -542,7 +561,7 @@ cpp11::writable::doubles predict_forest_raw_cpp(cpp11::external_pointer predict_forest_raw_single_tree_cpp(cpp11::exte output(i, j) = output_raw[i*output_dimension + j]; } } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles predict_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset) { + int n = dataset->GetCovariates().rows(); + std::vector output(n); + active_forest->PredictInplace(*dataset, output, 0); + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles predict_raw_active_forest_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer dataset) { + int n = dataset->GetCovariates().rows(); + int output_dimension = active_forest->OutputDimension(); + std::vector output_raw(n*output_dimension); + active_forest->PredictRawInplace(*dataset, output_raw, 0); + + cpp11::writable::doubles output(n*output_dimension); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < output_dimension; j++) { + // Convert from row-major to column-major + output.at(j*n + i) = output_raw[i*output_dimension + j]; + } + } return output; } + +[[cpp11::register]] +int output_dimension_active_forest_cpp(cpp11::external_pointer active_forest) { + return active_forest->OutputDimension(); +} + +[[cpp11::register]] +double average_max_depth_active_forest_cpp(cpp11::external_pointer active_forest) { + return active_forest->AverageMaxDepth(); +} + +[[cpp11::register]] +int num_trees_active_forest_cpp(cpp11::external_pointer active_forest) { + return 
active_forest->NumTrees();
+}
+
+[[cpp11::register]]
+int ensemble_tree_max_depth_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int tree_num) {
+  return active_forest->TreeMaxDepth(tree_num);
+}
+
+[[cpp11::register]]
+int is_leaf_constant_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest) {
+  return active_forest->IsLeafConstant();
+}
+
+[[cpp11::register]]
+int is_exponentiated_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest) {
+  return active_forest->IsExponentiated();
+}
+
+[[cpp11::register]]
+bool all_roots_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest) {
+  return active_forest->AllRoots();
+}
+
+[[cpp11::register]]
+void set_leaf_value_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, double leaf_value) {
+  active_forest->SetLeafValue(leaf_value);
+}
+
+[[cpp11::register]]
+void set_leaf_vector_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, cpp11::doubles leaf_vector) {
+  std::vector<double> leaf_vector_cast(leaf_vector.begin(), leaf_vector.end());
+  active_forest->SetLeafVector(leaf_vector_cast);
+}
+
+[[cpp11::register]]
+void add_numeric_split_tree_value_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, double left_leaf_value, double right_leaf_value) {
+  if (active_forest->OutputDimension() != 1) {
+    cpp11::stop("scalar leaf values require a forest with leaf dimension 1");
+  }
+  StochTree::Tree* tree = active_forest->GetTree(tree_num);
+  if (!tree->IsLeaf(leaf_num)) {
+    cpp11::stop("leaf_num is not a leaf");
+  }
+  tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value);
+}
+
+[[cpp11::register]]
+void add_numeric_split_tree_vector_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int tree_num, int leaf_num, int feature_num, double split_threshold, cpp11::doubles left_leaf_vector, cpp11::doubles right_leaf_vector) {
+  if (active_forest->OutputDimension() != left_leaf_vector.size()) {
+    cpp11::stop("left_leaf_vector must match forest leaf dimension");
+  }
+  if (active_forest->OutputDimension() != right_leaf_vector.size()) {
+    cpp11::stop("right_leaf_vector must match forest leaf dimension");
+  }
+  std::vector<double> left_leaf_vector_cast(left_leaf_vector.begin(), left_leaf_vector.end());
+  std::vector<double> right_leaf_vector_cast(right_leaf_vector.begin(), right_leaf_vector.end());
+  StochTree::Tree* tree = active_forest->GetTree(tree_num);
+  if (!tree->IsLeaf(leaf_num)) {
+    cpp11::stop("leaf_num is not a leaf");
+  }
+  tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast);
+}
+
+[[cpp11::register]]
+cpp11::writable::integers get_tree_leaves_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int tree_num) {
+  StochTree::Tree* tree = active_forest->GetTree(tree_num);
+  std::vector<int32_t> leaves_raw = tree->GetLeaves();
+  cpp11::writable::integers leaves(leaves_raw.begin(), leaves_raw.end());
+  return leaves;
+}
+
+[[cpp11::register]]
+cpp11::writable::integers get_tree_split_counts_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int tree_num, int num_features) {
+  cpp11::writable::integers output(num_features);
+  for (int i = 0; i < output.size(); i++) output.at(i) = 0;
+  StochTree::Tree* tree = active_forest->GetTree(tree_num);
+  std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+  for (int i = 0; i < split_nodes.size(); i++) {
+    auto node_id = split_nodes.at(i);
+    auto split_feature = tree->SplitIndex(node_id);
+    output.at(split_feature)++;
+  }
+  return output;
+}
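Note that GetInternalNodes() returns node ids, not feature indices, so each id must be mapped through SplitIndex() before a per-feature counter can be incremented; the split-count helpers above and below all follow that pattern. A minimal standalone sketch of the same bookkeeping (counts_for is an illustrative name and the header path is an assumption; the StochTree calls are the ones used in these functions):

#include <vector>
#include <stochtree/ensemble.h>  // assumed location of StochTree::TreeEnsemble / StochTree::Tree

// Tally how many times each feature is used as a split variable in one tree
std::vector<int> counts_for(StochTree::TreeEnsemble& forest, int tree_num, int num_features) {
  std::vector<int> counts(num_features, 0);
  StochTree::Tree* tree = forest.GetTree(tree_num);
  for (auto node_id : tree->GetInternalNodes()) {
    counts.at(tree->SplitIndex(node_id))++;  // map node id -> split feature
  }
  return counts;
}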
+
+[[cpp11::register]]
+cpp11::writable::integers get_overall_split_counts_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int num_features) {
+  cpp11::writable::integers output(num_features);
+  for (int i = 0; i < output.size(); i++) output.at(i) = 0;
+  int num_trees = active_forest->NumTrees();
+  for (int i = 0; i < num_trees; i++) {
+    StochTree::Tree* tree = active_forest->GetTree(i);
+    std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+    for (int j = 0; j < split_nodes.size(); j++) {
+      auto node_id = split_nodes.at(j);
+      auto split_feature = tree->SplitIndex(node_id);
+      output.at(split_feature)++;
+    }
+  }
+  return output;
+}
+
+[[cpp11::register]]
+cpp11::writable::integers get_granular_split_count_array_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, int num_features) {
+  int num_trees = active_forest->NumTrees();
+  cpp11::writable::integers output(num_features*num_trees);
+  for (int elem = 0; elem < output.size(); elem++) output.at(elem) = 0;
+  for (int i = 0; i < num_trees; i++) {
+    StochTree::Tree* tree = active_forest->GetTree(i);
+    std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+    for (int j = 0; j < split_nodes.size(); j++) {
+      auto node_id = split_nodes.at(j);
+      auto split_feature = tree->SplitIndex(node_id);
+      output.at(split_feature*num_trees + i)++;
+    }
+  }
+  return output;
+}
+
+[[cpp11::register]]
+void initialize_forest_model_active_forest_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, 
+                                               cpp11::external_pointer<StochTree::ColumnVector> residual, 
+                                               cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, 
+                                               cpp11::external_pointer<StochTree::ForestTracker> tracker, 
+                                               cpp11::doubles init_values, int leaf_model_int) {
+  // Convert leaf model type to enum
+  StochTree::ModelType model_type;
+  if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian;
+  else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian;
+  else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian;
+  else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance;
+  else StochTree::Log::Fatal("Invalid model type");
+  
+  // Unpack initial value
+  int num_trees = active_forest->NumTrees();
+  double init_val;
+  std::vector<double> init_value_vector;
+  if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || 
+      (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) || 
+      (model_type == StochTree::ModelType::kLogLinearVariance)) {
+    init_val = init_values.at(0);
+  } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
+    int leaf_dim = init_values.size();
+    init_value_vector.resize(leaf_dim);
+    for (int i = 0; i < leaf_dim; i++) {
+      init_value_vector[i] = init_values[i] / static_cast<double>(num_trees);
+    }
+  }
+  
+  // Initialize the models accordingly
+  double leaf_init_val;
+  if (model_type == StochTree::ModelType::kConstantLeafGaussian) {
+    leaf_init_val = init_val / static_cast<double>(num_trees);
+    active_forest->SetLeafValue(leaf_init_val);
+    UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), false, std::minus<double>());
+    tracker->UpdatePredictions(active_forest.get(), *data);
+  } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) {
+    leaf_init_val = init_val / static_cast<double>(num_trees);
+    active_forest->SetLeafValue(leaf_init_val);
+    UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus<double>());
+    tracker->UpdatePredictions(active_forest.get(), *data);
+  } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
+    active_forest->SetLeafVector(init_value_vector);
+    UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), true, std::minus<double>());
+    tracker->UpdatePredictions(active_forest.get(), *data);
+  } else if (model_type == StochTree::ModelType::kLogLinearVariance) {
+    leaf_init_val = std::log(init_val) / static_cast<double>(num_trees);
+    active_forest->SetLeafValue(leaf_init_val);
+    tracker->UpdatePredictions(active_forest.get(), *data);
+    int n = data->NumObservations();
+    std::vector<double> initial_preds(n, init_val);
+    data->AddVarianceWeights(initial_preds.data(), n);
+  }
+}
+
+[[cpp11::register]]
+void adjust_residual_active_forest_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, 
+                                       cpp11::external_pointer<StochTree::ColumnVector> residual, 
+                                       cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, 
+                                       cpp11::external_pointer<StochTree::ForestTracker> tracker, 
+                                       bool requires_basis, bool add) {
+  // Determine whether or not we are adding forest predictions to the residuals
+  std::function<double(double, double)> op;
+  if (add) op = std::plus<double>();
+  else op = std::minus<double>();
+  
+  // Perform the update (addition / subtraction) operation
+  StochTree::UpdateResidualEntireForest(*tracker, *data, *residual, active_forest.get(), requires_basis, op);
+}
+
+[[cpp11::register]]
+void propagate_basis_update_active_forest_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, 
+                                              cpp11::external_pointer<StochTree::ColumnVector> residual, 
+                                              cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, 
+                                              cpp11::external_pointer<StochTree::ForestTracker> tracker) {
+  // Propagate the new basis through the cached tree predictions and the residual
+  StochTree::UpdateResidualNewBasis(*tracker, *data, *residual, active_forest.get());
+}
+
+[[cpp11::register]]
+void reset_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest, 
+                             cpp11::external_pointer<StochTree::ForestContainer> forest_samples, 
+                             int forest_num) {
+  // Extract raw pointer to the forest held at index forest_num
+  StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num);
+  
+  // Reset active forest using the forest held at index forest_num
+  active_forest->ReconstituteFromForest(*forest);
+}
+
+[[cpp11::register]]
+void reset_forest_model_cpp(cpp11::external_pointer<StochTree::ForestTracker> forest_tracker, 
+                            cpp11::external_pointer<StochTree::TreeEnsemble> forest, 
+                            cpp11::external_pointer<StochTree::ForestDataset> data, 
+                            cpp11::external_pointer<StochTree::ColumnVector> residual, 
+                            bool is_mean_model) {
+  // Reset the forest tracker (and the residual) using the supplied forest
+  forest_tracker->ReconstituteFromForest(*forest, *data, *residual, is_mean_model);
+}
+
+[[cpp11::register]]
+void root_reset_active_forest_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> active_forest) {
+  // Reset active forest to root
+  active_forest->ResetRoot();
+}
diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp
index 38b4ce68..cda214be 100644
--- a/src/partition_tracker.cpp
+++ b/src/partition_tracker.cpp
@@ -26,6 +26,23 @@ ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int num_trees, data_size_t num_observations) {
+void ForestTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model) {
+  UpdateSampleTrackersResidual(forest, dataset, residual, is_mean_model);
+  unsorted_node_sample_tracker_->ReconstituteFromForest(forest, dataset);
+}
 
 void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int32_t tree_num) {
@@ -84,6 +101,142 @@ void ForestTracker::AssignAllSamplesToConstantPrediction(int32_t tree_num, double value) {
   sample_pred_mapper_->AssignAllSamplesToConstantPrediction(tree_num, value);
 }
 
+void ForestTracker::UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis) {
+  int output_dim = basis.cols();
+  double forest_pred, tree_pred;
+  
+  for (data_size_t i = 0; i < num_observations_; i++) {
+    forest_pred = 0.0;
+    for (int j = 0; j < num_trees_; j++) {
+      tree_pred = 0.0;
+      Tree* tree = forest.GetTree(j);
+      std::int32_t nidx = EvaluateTree(*tree, covariates, i);
+      sample_node_mapper_->SetNodeId(i, j, nidx);
+      for (int32_t k = 0; k < output_dim; k++) {
+        tree_pred += tree->LeafValue(nidx, k) * basis(i, k);
+      }
+      sample_pred_mapper_->SetPred(i, j, tree_pred);
+      
forest_pred += tree_pred; + } + sum_predictions_[i] = forest_pred; + } +} + +void ForestTracker::UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates) { + double forest_pred, tree_pred; + + for (data_size_t i = 0; i < num_observations_; i++) { + forest_pred = 0.0; + for (int j = 0; j < num_trees_; j++) { + Tree* tree = forest.GetTree(j); + std::int32_t nidx = EvaluateTree(*tree, covariates, i); + sample_node_mapper_->SetNodeId(i, j, nidx); + tree_pred = tree->LeafValue(nidx, 0); + sample_pred_mapper_->SetPred(i, j, tree_pred); + forest_pred += tree_pred; + } + sum_predictions_[i] = forest_pred; + } +} + +void ForestTracker::UpdateSampleTrackers(TreeEnsemble& forest, ForestDataset& dataset) { + if (!forest.IsLeafConstant()) { + CHECK(dataset.HasBasis()); + UpdateSampleTrackersInternal(forest, dataset.GetCovariates(), dataset.GetBasis()); + } else { + UpdateSampleTrackersInternal(forest, dataset.GetCovariates()); + } +} + +void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model) { + double new_forest_pred, new_tree_pred, prev_tree_pred, new_resid, new_weight; + Eigen::MatrixXd& covariates = dataset.GetCovariates(); + Eigen::MatrixXd& basis = dataset.GetBasis(); + int output_dim = basis.cols(); + if (!is_mean_model) { + CHECK(dataset.HasVarWeights()); + } + + for (data_size_t i = 0; i < num_observations_; i++) { + new_forest_pred = 0.0; + for (int j = 0; j < num_trees_; j++) { + // Query the previously cached prediction for tree j, observation i + prev_tree_pred = sample_pred_mapper_->GetPred(i, j); + + // Compute the new prediction for tree j, observation i + new_tree_pred = 0.0; + Tree* tree = forest.GetTree(j); + std::int32_t nidx = EvaluateTree(*tree, covariates, i); + for (int32_t k = 0; k < output_dim; k++) { + new_tree_pred += tree->LeafValue(nidx, k) * basis(i, k); + } + + if (is_mean_model) { + // Adjust the residual by adding the previous prediction and subtracting the new prediction + new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; + residual.SetElement(i, new_resid); + } else { + // Adjust the variance weights by subtracting the previous prediction and adding the new prediction (in log scale) and then exponentiating + new_weight = std::log(dataset.VarWeightValue(i)) + new_tree_pred - prev_tree_pred; + dataset.SetVarWeightValue(i, new_weight, true); + } + + // Update the sample node mapper and sample prediction mapper + sample_node_mapper_->SetNodeId(i, j, nidx); + sample_pred_mapper_->SetPred(i, j, new_tree_pred); + new_forest_pred += new_tree_pred; + } + // Update the overall cached forest prediction + sum_predictions_[i] = new_forest_pred; + } +} + +void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model) { + double new_forest_pred, new_tree_pred, prev_tree_pred, new_resid, new_weight; + Eigen::MatrixXd& covariates = dataset.GetCovariates(); + if (!is_mean_model) { + CHECK(dataset.HasVarWeights()); + } + + for (data_size_t i = 0; i < num_observations_; i++) { + new_forest_pred = 0.0; + for (int j = 0; j < num_trees_; j++) { + // Query the previously cached prediction for tree j, observation i + prev_tree_pred = sample_pred_mapper_->GetPred(i, j); + + // Compute the new prediction for tree j, observation i + Tree* tree = forest.GetTree(j); + std::int32_t nidx = EvaluateTree(*tree, covariates, i); + new_tree_pred = tree->LeafValue(nidx, 0); + + 
if (is_mean_model) { + // Adjust the residual by adding the previous prediction and subtracting the new prediction + new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; + residual.SetElement(i, new_resid); + } else { + new_weight = std::log(dataset.VarWeightValue(i)) + new_tree_pred - prev_tree_pred; + dataset.SetVarWeightValue(i, new_weight, true); + } + + // Update the sample node mapper and sample prediction mapper + sample_node_mapper_->SetNodeId(i, j, nidx); + sample_pred_mapper_->SetPred(i, j, new_tree_pred); + new_forest_pred += new_tree_pred; + } + // Update the overall cached forest prediction + sum_predictions_[i] = new_forest_pred; + } +} + +void ForestTracker::UpdateSampleTrackersResidual(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model) { + if (!forest.IsLeafConstant()) { + CHECK(dataset.HasBasis()); + UpdateSampleTrackersResidualInternalBasis(forest, dataset, residual, is_mean_model); + } else { + UpdateSampleTrackersResidualInternalNoBasis(forest, dataset, residual, is_mean_model); + } +} + void ForestTracker::UpdatePredictionsInternal(TreeEnsemble* ensemble, Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis) { int output_dim = basis.cols(); double forest_pred, tree_pred; @@ -172,6 +325,15 @@ void ForestTracker::SyncPredictions() { } } +void UnsortedNodeSampleTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& dataset) { + int n = dataset.NumObservations(); + for (int i = 0; i < num_trees_; i++) { + Tree* tree = forest.GetTree(i); + feature_partitions_[i].reset(new FeatureUnsortedPartition(n)); + feature_partitions_[i]->ReconstituteFromTree(*tree, dataset); + } +} + FeatureUnsortedPartition::FeatureUnsortedPartition(data_size_t n) { indices_.resize(n); std::iota(indices_.begin(), indices_.end(), 0); @@ -184,6 +346,92 @@ FeatureUnsortedPartition::FeatureUnsortedPartition(data_size_t n) { num_deleted_nodes_ = 0; } +void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& dataset) { + // Make sure this data structure is a root + CHECK_EQ(num_nodes_, 1); + CHECK_EQ(num_deleted_nodes_, 0); + data_size_t n = dataset.NumObservations(); + CHECK_EQ(indices_.size(), n); + + // Extract covariates + Eigen::MatrixXd& covariates = dataset.GetCovariates(); + + // Set node counters + num_nodes_ = tree.NumNodes(); + num_deleted_nodes_ = tree.NumDeletedNodes(); + + // Resize tracking vectors + node_begin_.resize(num_nodes_); + node_length_.resize(num_nodes_); + parent_nodes_.resize(num_nodes_); + left_nodes_.resize(num_nodes_); + right_nodes_.resize(num_nodes_); + + // Unpack tree's splits into this data structure + bool is_deleted; + TreeNodeType node_type; + data_size_t node_start_idx; + data_size_t num_node_elements; + data_size_t num_true, num_false; + TreeSplit split_rule; + int split_index; + for (int i = 0; i < num_nodes_; i++) { + is_deleted = tree.IsDeleted(i); + if (is_deleted) { + deleted_nodes_.push_back(i); + } else { + // Node beginning and length in indices_ + if (i == 0) { + node_start_idx = 0; + num_node_elements = n; + } else { + node_start_idx = node_begin_[i]; + num_node_elements = node_length_[i]; + } + // Tree node info + parent_nodes_[i] = tree.Parent(i); + node_type = tree.NodeType(i); + left_nodes_[i] = tree.LeftChild(i); + right_nodes_[i] = tree.RightChild(i); + // Only update indices_, node_begin_ and node_length_ if a split is to be added + if (node_type == TreeNodeType::kNumericalSplitNode) { + // Extract split rule + split_rule = TreeSplit(tree.Threshold(i)); + 
split_index = tree.SplitIndex(i); + } else if (node_type == TreeNodeType::kCategoricalSplitNode) { + std::vector categories = tree.CategoryList(i); + split_rule = TreeSplit(categories); + split_index = tree.SplitIndex(i); + } else { + continue; + } + // Partition the node indices + auto node_begin = (indices_.begin() + node_begin_[i]); + auto node_end = (indices_.begin() + node_begin_[i] + node_length_[i]); + auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split_rule.SplitTrue(covariates(row, split_index)); }); + + // Determine the number of true and false elements + node_begin = (indices_.begin() + node_begin_[i]); + num_true = std::distance(node_begin, right_node_begin); + num_false = num_node_elements - num_true; + + // Add left node tracking information + node_begin_[left_nodes_[i]] = node_start_idx; + node_length_[left_nodes_[i]] = num_true; + parent_nodes_[left_nodes_[i]] = i; + left_nodes_[left_nodes_[i]] = StochTree::Tree::kInvalidNodeId; + left_nodes_[right_nodes_[i]] = StochTree::Tree::kInvalidNodeId; + + // Add right node tracking information + node_begin_[right_nodes_[i]] = node_start_idx + num_true; + node_length_[right_nodes_[i]] = num_false; + parent_nodes_[right_nodes_[i]] = i; + right_nodes_[left_nodes_[i]] = StochTree::Tree::kInvalidNodeId; + right_nodes_[right_nodes_[i]] = StochTree::Tree::kInvalidNodeId; + } + } +} + data_size_t FeatureUnsortedPartition::NodeBegin(int node_id) { return node_begin_[node_id]; } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 45ff4127..87279397 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -174,6 +174,10 @@ class ForestContainerCpp { return forest->SumLeafSquared(); } + void DeleteSample(int forest_num) { + forest_samples_->DeleteSample(forest_num); + } + py::array_t Predict(ForestDatasetCpp& dataset) { // Predict from the forest container data_size_t n = dataset.NumRows(); @@ -623,6 +627,320 @@ class ForestContainerCpp { bool is_exponentiated_; }; +class ForestCpp { + public: + ForestCpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { + // Initialize pointer to C++ TreeEnsemble class + forest_ = std::make_unique(num_trees, output_dimension, is_leaf_constant, is_exponentiated); + num_trees_ = num_trees; + output_dimension_ = output_dimension; + is_leaf_constant_ = is_leaf_constant; + is_exponentiated_ = is_exponentiated; + } + ~ForestCpp() {} + + int OutputDimension() { + return forest_->OutputDimension(); + } + + int NumLeavesForest() { + return forest_->NumLeaves(); + } + + double SumLeafSquared(int forest_num) { + return forest_->SumLeafSquared(); + } + + void ResetRoot() { + // Reset active forest using the forest held at index forest_num + forest_->ResetRoot(); + } + + void Reset(ForestContainerCpp& forest_container, int forest_num) { + // Extract raw pointer to the forest held at index forest_num + StochTree::TreeEnsemble* forest = forest_container.GetForest(forest_num); + + // Reset active forest using the forest held at index forest_num + forest_->ReconstituteFromForest(*forest); + } + + py::array_t Predict(ForestDatasetCpp& dataset) { + // Predict from the forest container + data_size_t n = dataset.NumRows(); + StochTree::ForestDataset* data_ptr = dataset.GetDataset(); + std::vector output_raw = forest_->Predict(*data_ptr); + + // Convert result to a matrix + auto result = py::array_t(py::detail::any_container({n})); + auto accessor = result.mutable_unchecked<1>(); + for (size_t i = 0; i < n; i++) { + 
accessor(i) = output_raw[i];
+    }
+
+    return result;
+  }
+
+  py::array_t<double> PredictRaw(ForestDatasetCpp& dataset) {
+    // Predict from the forest
+    data_size_t n = dataset.NumRows();
+    int output_dim = this->OutputDimension();
+    StochTree::ForestDataset* data_ptr = dataset.GetDataset();
+    std::vector<double> output_raw = forest_->PredictRaw(*data_ptr);
+
+    // Convert result to 2 dimensional array (n x output_dim)
+    auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({n, output_dim}));
+    auto accessor = result.mutable_unchecked<2>();
+    for (size_t i = 0; i < n; i++) {
+      for (int j = 0; j < output_dim; j++) {
+        accessor(i,j) = output_raw[i*output_dim + j];
+      }
+    }
+
+    return result;
+  }
+
+  void SetRootValue(double leaf_value) {
+    forest_->SetLeafValue(leaf_value);
+  }
+
+  void SetRootVector(py::array_t<double>& leaf_vector, int leaf_size) {
+    std::vector<double> leaf_vector_converted(leaf_size);
+    for (int i = 0; i < leaf_size; i++) {
+      leaf_vector_converted[i] = leaf_vector.at(i);
+    }
+    forest_->SetLeafVector(leaf_vector_converted);
+  }
+
+  void AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, bool add);
+
+  StochTree::TreeEnsemble* GetEnsemble() {
+    return forest_.get();
+  }
+
+  void AddNumericSplitValue(int tree_num, int leaf_num, int feature_num, double split_threshold, 
+                            double left_leaf_value, double right_leaf_value) {
+    if (forest_->OutputDimension() != 1) {
+      StochTree::Log::Fatal("scalar leaf values require a forest with leaf dimension 1");
+    }
+    StochTree::TreeEnsemble* ensemble = forest_.get();
+    StochTree::Tree* tree = ensemble->GetTree(tree_num);
+    if (!tree->IsLeaf(leaf_num)) {
+      StochTree::Log::Fatal("leaf_num is not a leaf");
+    }
+    tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value);
+  }
+
+  void AddNumericSplitVector(int tree_num, int leaf_num, int feature_num, double split_threshold, 
+                             py::array_t<double> left_leaf_vector, py::array_t<double> right_leaf_vector) {
+    if (forest_->OutputDimension() != left_leaf_vector.size()) {
+      StochTree::Log::Fatal("left_leaf_vector must match forest leaf dimension");
+    }
+    if (forest_->OutputDimension() != right_leaf_vector.size()) {
+      StochTree::Log::Fatal("right_leaf_vector must match forest leaf dimension");
+    }
+    std::vector<double> left_leaf_vector_cast(left_leaf_vector.size());
+    std::vector<double> right_leaf_vector_cast(right_leaf_vector.size());
+    for (int i = 0; i < left_leaf_vector.size(); i++) left_leaf_vector_cast.at(i) = left_leaf_vector.at(i);
+    for (int i = 0; i < right_leaf_vector.size(); i++) right_leaf_vector_cast.at(i) = right_leaf_vector.at(i);
+    StochTree::TreeEnsemble* ensemble = forest_.get();
+    StochTree::Tree* tree = ensemble->GetTree(tree_num);
+    if (!tree->IsLeaf(leaf_num)) {
+      StochTree::Log::Fatal("leaf_num is not a leaf");
+    }
+    tree->ExpandNode(leaf_num, feature_num, split_threshold, left_leaf_vector_cast, right_leaf_vector_cast);
+  }
+
+  py::array_t<int> GetTreeLeaves(int tree_num) {
+    StochTree::Tree* tree = forest_->GetTree(tree_num);
+    std::vector<int32_t> leaves_raw = tree->GetLeaves();
+    int num_leaves = leaves_raw.size();
+    auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_leaves}));
+    auto accessor = result.mutable_unchecked<1>();
+    for (size_t i = 0; i < num_leaves; i++) {
+      accessor(i) = leaves_raw.at(i);
+    }
+    return result;
+  }
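Taken together, ForestCpp::Reset (above) and ForestSamplerCpp::ReconstituteTrackerFromForest (below) implement the same warm-start sequence that resetActiveForest / resetForestModel expose on the R side: copy a stored sample into the active forest, then rebuild the tracker state and residual around it. A minimal sketch of that sequence at the C++ level (warm_start_from_sample is an illustrative name; every call it makes appears elsewhere in this diff):

// Reset the active forest to a stored ensemble and re-sync the sampler state
void warm_start_from_sample(StochTree::TreeEnsemble& active_forest,
                            StochTree::ForestContainer& forest_samples,
                            StochTree::ForestTracker& tracker,
                            StochTree::ForestDataset& dataset,
                            StochTree::ColumnVector& residual,
                            int forest_num, bool is_mean_model) {
  // Copy the forest stored at index forest_num into the active forest
  StochTree::TreeEnsemble* stored = forest_samples.GetEnsemble(forest_num);
  active_forest.ReconstituteFromForest(*stored);
  // Rebuild node assignments, cached predictions, and the residual around it
  tracker.ReconstituteFromForest(*stored, dataset, residual, is_mean_model);
}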
+
+  py::array_t<int> GetTreeSplitCounts(int tree_num, int num_features) {
+    auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_features}));
+    auto accessor = result.mutable_unchecked<1>();
+    for (size_t i = 0; i < num_features; i++) {
+      accessor(i) = 0;
+    }
+    StochTree::Tree* tree = forest_->GetTree(tree_num);
+    std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+    for (int i = 0; i < split_nodes.size(); i++) {
+      auto node_id = split_nodes.at(i);
+      auto split_feature = tree->SplitIndex(node_id);
+      accessor(split_feature)++;
+    }
+    return result;
+  }
+
+  py::array_t<int> GetOverallSplitCounts(int num_features) {
+    auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_features}));
+    auto accessor = result.mutable_unchecked<1>();
+    for (size_t i = 0; i < num_features; i++) {
+      accessor(i) = 0;
+    }
+    int num_trees = forest_->NumTrees();
+    for (int i = 0; i < num_trees; i++) {
+      StochTree::Tree* tree = forest_->GetTree(i);
+      std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+      for (int j = 0; j < split_nodes.size(); j++) {
+        auto node_id = split_nodes.at(j);
+        auto split_feature = tree->SplitIndex(node_id);
+        accessor(split_feature)++;
+      }
+    }
+    return result;
+  }
+
+  py::array_t<int> GetGranularSplitCounts(int num_features) {
+    int num_trees = forest_->NumTrees();
+    auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_trees,num_features}));
+    auto accessor = result.mutable_unchecked<2>();
+    for (int i = 0; i < num_trees; i++) {
+      for (int j = 0; j < num_features; j++) {
+        accessor(i,j) = 0;
+      }
+    }
+    for (int i = 0; i < num_trees; i++) {
+      StochTree::Tree* tree = forest_->GetTree(i);
+      std::vector<int32_t> split_nodes = tree->GetInternalNodes();
+      for (int j = 0; j < split_nodes.size(); j++) {
+        // Map each internal node (indexed by the inner loop variable j) to its split feature
+        auto node_id = split_nodes.at(j);
+        auto split_feature = tree->SplitIndex(node_id);
+        accessor(i,split_feature)++;
+      }
+    }
+    return result;
+  }
+
+  bool IsLeafNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->IsLeaf(node_id);
+  }
+
+  bool IsNumericSplitNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->IsNumericSplitNode(node_id);
+  }
+
+  bool IsCategoricalSplitNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->IsCategoricalSplitNode(node_id);
+  }
+
+  int ParentNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->Parent(node_id);
+  }
+
+  int LeftChildNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->LeftChild(node_id);
+  }
+
+  int RightChildNode(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->RightChild(node_id);
+  }
+
+  int SplitIndex(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->SplitIndex(node_id);
+  }
+
+  int NodeDepth(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->GetDepth(node_id);
+  }
+
+  double SplitThreshold(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    return tree->Threshold(node_id);
+  }
+
+  py::array_t<int> SplitCategories(int tree_id, int node_id) {
+    StochTree::Tree* tree = forest_->GetTree(tree_id);
+    std::vector<std::uint32_t> raw_categories = tree->CategoryList(node_id);
+    int num_categories = raw_categories.size();
+    auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({num_categories}));
+    auto accessor = result.mutable_unchecked<1>();
+    for (int i = 0; i < num_categories; i++) {
+      accessor(i) = raw_categories.at(i);
+    }
+    return result;
+  }
+
+  py::array_t<double> 
NodeLeafValues(int tree_id, int node_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + int num_outputs = tree->OutputDimension(); + auto result = py::array_t(py::detail::any_container({num_outputs})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_outputs; i++) { + accessor(i) = tree->LeafValue(node_id, i); + } + return result; + } + + int NumNodes(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + return tree->NumValidNodes(); + } + + int NumLeaves(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + return tree->NumLeaves(); + } + + int NumLeafParents(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + return tree->NumLeafParents(); + } + + int NumSplitNodes(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + return tree->NumSplitNodes(); + } + + py::array_t Nodes(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + std::vector nodes = tree->GetNodes(); + int num_nodes = nodes.size(); + auto result = py::array_t(py::detail::any_container({num_nodes})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_nodes; i++) { + accessor(i) = nodes.at(i); + } + return result; + } + + py::array_t Leaves(int tree_id) { + StochTree::Tree* tree = forest_->GetTree(tree_id); + std::vector leaves = tree->GetLeaves(); + int num_leaves = leaves.size(); + auto result = py::array_t(py::detail::any_container({num_leaves})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < num_leaves; i++) { + accessor(i) = leaves.at(i); + } + return result; + } + + private: + std::unique_ptr forest_; + int num_trees_; + int output_dimension_; + bool is_leaf_constant_; + bool is_exponentiated_; +}; + class ForestSamplerCpp { public: ForestSamplerCpp(ForestDatasetCpp& dataset, py::array_t feature_types, int num_trees, data_size_t num_obs, double alpha, double beta, int min_samples_leaf, int max_depth) { @@ -641,10 +959,20 @@ class ForestSamplerCpp { StochTree::ForestTracker* GetTracker() {return tracker_.get();} - void SampleOneIteration(ForestContainerCpp& forest_samples, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, + void ReconstituteTrackerFromForest(ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, bool is_mean_model) { + // Extract raw pointer to the forest and dataset + StochTree::TreeEnsemble* forest_ptr = forest.GetEnsemble(); + StochTree::ForestDataset* data_ptr = dataset.GetDataset(); + StochTree::ColumnVector* residual_ptr = residual.GetData(); + + // Reset forest tracker using the forest held at index forest_num + tracker_->ReconstituteFromForest(*forest_ptr, *data_ptr, *residual_ptr, is_mean_model); + } + + void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, py::array_t feature_types, int cutpoint_grid_size, py::array_t leaf_model_scale_input, py::array_t variable_weights, double a_forest, double b_forest, double global_variance, - int leaf_model_int, bool gfr = true, bool pre_initialized = false) { + int leaf_model_int, bool keep_forest = true, bool gfr = true, bool pre_initialized = false) { // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { @@ -686,34 +1014,35 @@ class ForestSamplerCpp { // Run one iteration of the sampler StochTree::ForestContainer* forest_sample_ptr = forest_samples.GetContainer(); + StochTree::TreeEnsemble* active_forest_ptr = 
forest.GetEnsemble(); StochTree::ForestDataset* forest_data_ptr = dataset.GetDataset(); StochTree::ColumnVector* residual_data_ptr = residual.GetData(); int num_basis = forest_data_ptr->NumBasis(); std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true, num_basis); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, false); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, 
*(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, keep_forest, pre_initialized, false); } } } - void InitializeForestModel(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestContainerCpp& forest_samples, + void InitializeForestModel(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest, int leaf_model_int, py::array_t initial_values) { // Convert leaf model type to enum StochTree::ModelType model_type; @@ -724,10 +1053,10 @@ class ForestSamplerCpp { else StochTree::Log::Fatal("Invalid model type"); // Unpack initial value - StochTree::ForestContainer* forest_sample_ptr = forest_samples.GetContainer(); + StochTree::TreeEnsemble* forest_ptr = forest.GetEnsemble(); StochTree::ForestDataset* forest_data_ptr = dataset.GetDataset(); StochTree::ColumnVector* residual_data_ptr = residual.GetData(); - int num_trees = forest_sample_ptr->NumTrees(); + int num_trees = forest_ptr->NumTrees(); double init_val; std::vector init_value_vector; if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || @@ -743,30 +1072,34 @@ class ForestSamplerCpp { } // Initialize the models accordingly + double leaf_init_val; if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - forest_samples.InitializeRootValue(init_val / static_cast(num_trees)); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_sample_ptr->GetEnsemble(0), false, std::minus()); - tracker_->UpdatePredictions(forest_sample_ptr->GetEnsemble(0), *forest_data_ptr); + leaf_init_val = init_val / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, false, std::minus()); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - forest_samples.InitializeRootValue(init_val / static_cast(num_trees)); - StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_sample_ptr->GetEnsemble(0), true, std::minus()); - 
tracker_->UpdatePredictions(forest_sample_ptr->GetEnsemble(0), *forest_data_ptr);
+      leaf_init_val = init_val / static_cast<double>(num_trees);
+      forest_ptr->SetLeafValue(leaf_init_val);
+      StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus<double>());
+      tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr);
     } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
-      forest_samples.InitializeRootVector(init_value_vector);
-      StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_sample_ptr->GetEnsemble(0), true, std::minus<double>());
-      tracker_->UpdatePredictions(forest_sample_ptr->GetEnsemble(0), *forest_data_ptr);
+      forest_ptr->SetLeafVector(init_value_vector);
+      StochTree::UpdateResidualEntireForest(*tracker_, *forest_data_ptr, *residual_data_ptr, forest_ptr, true, std::minus<double>());
+      tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr);
     } else if (model_type == StochTree::ModelType::kLogLinearVariance) {
-      forest_samples.InitializeRootValue(std::log(init_val) / static_cast<double>(num_trees));
-      tracker_->UpdatePredictions(forest_sample_ptr->GetEnsemble(0), *forest_data_ptr);
+      leaf_init_val = std::log(init_val) / static_cast<double>(num_trees);
+      forest_ptr->SetLeafValue(leaf_init_val);
+      tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr);
       int n = forest_data_ptr->NumObservations();
       std::vector<double> initial_preds(n, init_val);
       forest_data_ptr->AddVarianceWeights(initial_preds.data(), n);
     }
   }
 
-  void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestContainerCpp& forest_samples, int forest_num) {
+  void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) {
     // Perform the update operation
-    StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest_samples.GetForest(forest_num));
+    StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble());
   }
 
   void PropagateResidualUpdate(ResidualCpp& residual) {
@@ -805,10 +1138,10 @@ class LeafVarianceModelCpp {
   }
   ~LeafVarianceModelCpp() {}
 
-  double SampleOneIteration(ForestContainerCpp& forest_samples, RngCpp& rng, double a, double b, int sample_num) {
-    StochTree::ForestContainer* forest_sample_ptr = forest_samples.GetContainer();
+  double SampleOneIteration(ForestCpp& forest, RngCpp& rng, double a, double b) {
+    StochTree::TreeEnsemble* forest_ptr = forest.GetEnsemble();
     std::mt19937* rng_ptr = rng.GetRng();
-    return var_model_.SampleVarianceParameter(forest_sample_ptr->GetEnsemble(sample_num), a, b, *rng_ptr);
+    return var_model_.SampleVarianceParameter(forest_ptr, a, b, *rng_ptr);
   }
 
  private:
@@ -825,6 +1158,16 @@ void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp&
   StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op);
 }
 
+void ForestCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, bool add) {
+  // Determine whether we are adding the forest's predictions to the residual or subtracting them
+  std::function<double(double, double)> op;
+  if (add) op = std::plus<double>();
+  else op = std::minus<double>();
+
+  // Perform the update (addition / subtraction) operation
+  StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_.get(), requires_basis, op);
+}
+
 class JsonCpp {
  public:
  JsonCpp() {
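The leaf-initialization arithmetic above is easiest to verify numerically: a constant-leaf mean forest sets every root leaf to `init_val / num_trees` so the tree outputs sum to `init_val`, while the exponentiated variance forest stores `log(init_val) / num_trees` per root so that exponentiating the summed leaves recovers `init_val`. A minimal sketch in plain Python (illustrative only, not the stochtree API):

```python
import numpy as np

num_trees = 50
init_val = 1.3  # outcome mean (mean forest) or initial variance (variance forest)

# Mean forest: additive ensemble, so each root leaf holds init_val / num_trees
mean_leaves = np.full(num_trees, init_val / num_trees)
assert np.isclose(mean_leaves.sum(), init_val)

# Variance forest: leaves live on the log scale and the summed output is exponentiated
log_leaves = np.full(num_trees, np.log(init_val) / num_trees)
assert np.isclose(np.exp(log_leaves.sum()), init_val)
```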
@@ -1156,6 +1499,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
     .def(py::init<int, int, bool, bool>())
     .def("OutputDimension", &ForestContainerCpp::OutputDimension)
     .def("NumSamples", &ForestContainerCpp::NumSamples)
+    .def("DeleteSample", &ForestContainerCpp::DeleteSample)
     .def("Predict", &ForestContainerCpp::Predict)
     .def("PredictRaw", &ForestContainerCpp::PredictRaw)
     .def("PredictRawSingleForest", &ForestContainerCpp::PredictRawSingleForest)
@@ -1196,8 +1540,48 @@ PYBIND11_MODULE(stochtree_cpp, m) {
     .def("Nodes", &ForestContainerCpp::Nodes)
     .def("Leaves", &ForestContainerCpp::Leaves);
 
+  py::class_<ForestCpp>(m, "ForestCpp")
+    .def(py::init<int, int, bool, bool>())
+    .def("OutputDimension", &ForestCpp::OutputDimension)
+    .def("NumLeavesForest", &ForestCpp::NumLeavesForest)
+    .def("SumLeafSquared", &ForestCpp::SumLeafSquared)
+    .def("ResetRoot", &ForestCpp::ResetRoot)
+    .def("Reset", &ForestCpp::Reset)
+    .def("Predict", &ForestCpp::Predict)
+    .def("PredictRaw", &ForestCpp::PredictRaw)
+    .def("SetRootValue", &ForestCpp::SetRootValue)
+    .def("SetRootVector", &ForestCpp::SetRootVector)
+    .def("AdjustResidual", &ForestCpp::AdjustResidual)
+    .def("AddNumericSplitValue", &ForestCpp::AddNumericSplitValue)
+    .def("AddNumericSplitVector", &ForestCpp::AddNumericSplitVector)
+    .def("GetEnsemble", &ForestCpp::GetEnsemble)
+    .def("GetTreeLeaves", &ForestCpp::GetTreeLeaves)
+    .def("GetTreeSplitCounts", &ForestCpp::GetTreeSplitCounts)
+    .def("GetOverallSplitCounts", &ForestCpp::GetOverallSplitCounts)
+    .def("GetGranularSplitCounts", &ForestCpp::GetGranularSplitCounts)
+    .def("IsLeafNode", &ForestCpp::IsLeafNode)
+    .def("IsNumericSplitNode", &ForestCpp::IsNumericSplitNode)
+    .def("IsCategoricalSplitNode", &ForestCpp::IsCategoricalSplitNode)
+    .def("ParentNode", &ForestCpp::ParentNode)
+    .def("LeftChildNode", &ForestCpp::LeftChildNode)
+    .def("RightChildNode", &ForestCpp::RightChildNode)
+    .def("SplitIndex", &ForestCpp::SplitIndex)
+    .def("NodeDepth", &ForestCpp::NodeDepth)
+    .def("SplitThreshold", &ForestCpp::SplitThreshold)
+    .def("SplitCategories", &ForestCpp::SplitCategories)
+    .def("NodeLeafValues", &ForestCpp::NodeLeafValues)
+    .def("NumNodes", &ForestCpp::NumNodes)
+    .def("NumLeaves", &ForestCpp::NumLeaves)
+    .def("NumLeafParents", &ForestCpp::NumLeafParents)
+    .def("NumSplitNodes", &ForestCpp::NumSplitNodes)
+    .def("Nodes", &ForestCpp::Nodes)
+    .def("Leaves", &ForestCpp::Leaves);
+
   py::class_<ForestSamplerCpp>(m, "ForestSamplerCpp")
     .def(py::init<ForestDatasetCpp&, py::array_t<int>, int, data_size_t, double, double, int, int>())
+    .def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest)
     .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration)
     .def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel)
     .def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate)
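These bindings split responsibilities between two objects: `ForestCpp` is the single mutable forest the sampler updates in place, while `ForestContainerCpp` is an append-only log of retained draws (now also supporting `DeleteSample`). A rough stand-in model of that relationship in plain Python (hypothetical classes for illustration, not the stochtree API), showing why a `keep_forest` flag decides whether the active state gets copied into the log:

```python
import copy

class ActiveForest:
    """Stand-in for ForestCpp: the one forest the sampler mutates in place."""
    def __init__(self, num_trees: int):
        self.leaves = [0.0] * num_trees

class ForestLog:
    """Stand-in for ForestContainerCpp: an append-only log of retained draws."""
    def __init__(self):
        self.samples = []

    def add_sample(self, forest: ActiveForest) -> None:
        # Deep-copy so later in-place updates cannot corrupt stored draws
        self.samples.append(copy.deepcopy(forest.leaves))

    def delete_sample(self, i: int) -> None:
        # Mirrors the new DeleteSample binding: drop a retained draw after the fact
        self.samples.pop(i)

active, log = ActiveForest(3), ForestLog()
for it in range(4):
    active.leaves = [v + 1.0 for v in active.leaves]  # one "sampler" update
    keep_forest = (it % 2 == 1)  # e.g., a thinning rule
    if keep_forest:
        log.add_sample(active)

print(len(log.samples))  # 2 retained draws
```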
diff --git a/src/random_effects.cpp b/src/random_effects.cpp
index efb141cf..40c828ba 100644
--- a/src/random_effects.cpp
+++ b/src/random_effects.cpp
@@ -39,6 +39,60 @@ void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) {
   }
 }
 
+void RandomEffectsTracker::ResetFromSample(MultivariateRegressionRandomEffectsModel& rfx_model, 
+                                           RandomEffectsDataset& rfx_dataset, ColumnVector& residual) {
+  Eigen::MatrixXd X = rfx_dataset.GetBasis();
+  std::vector<int32_t> group_labels = rfx_dataset.GetGroupLabels();
+  CHECK_EQ(X.rows(), group_labels.size());
+  int n = X.rows();
+  double prev_pred;
+  double new_pred;
+  double new_resid;
+  Eigen::MatrixXd alpha_diag = rfx_model.GetWorkingParameter().asDiagonal().toDenseMatrix();
+  Eigen::MatrixXd xi = rfx_model.GetGroupParameters();
+  std::int32_t group_ind;
+  for (int i = 0; i < n; i++) {
+    group_ind = CategoryNumber(group_labels[i]);
+    prev_pred = GetPrediction(i);
+    new_pred = X(i, Eigen::all) * alpha_diag * xi(Eigen::all, group_ind);
+    new_resid = residual.GetElement(i) - new_pred + prev_pred;
+    residual.SetElement(i, new_resid);
+    SetPrediction(i, new_pred);
+  }
+}
+
+void RandomEffectsTracker::RootReset(MultivariateRegressionRandomEffectsModel& rfx_model, 
+                                     RandomEffectsDataset& rfx_dataset, ColumnVector& residual) {
+  int n = rfx_dataset.NumObservations();
+  CHECK_EQ(n, num_observations_);
+  double prev_pred;
+  double new_pred;
+  double new_resid;
+  for (int i = 0; i < n; i++) {
+    prev_pred = GetPrediction(i);
+    new_pred = 0.;
+    new_resid = residual.GetElement(i) - new_pred + prev_pred;
+    residual.SetElement(i, new_resid);
+    SetPrediction(i, new_pred);
+  }
+}
+
+void MultivariateRegressionRandomEffectsModel::ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num) {
+  // Extract parameter vectors
+  std::vector<double>& alpha = rfx_container.GetAlpha();
+  std::vector<double>& xi = rfx_container.GetXi();
+  std::vector<double>& sigma = rfx_container.GetSigma();
+
+  // Unpack parameters
+  for (int i = 0; i < num_components_; i++) {
+    working_parameter_(i) = alpha.at(sample_num*num_components_ + i);
+    group_parameter_covariance_(i, i) = sigma.at(sample_num*num_components_ + i);
+    for (int j = 0; j < num_groups_; j++) {
+      group_parameters_(i,j) = xi.at(sample_num*num_groups_*num_components_ + j*num_components_ + i);
+    }
+  }
+}
+
 void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, 
                                                                    RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) {
   // Update partial residual to add back in the random effects
@@ -266,6 +320,47 @@ nlohmann::json RandomEffectsContainer::to_json() {
   return result_obj;
 }
 
+void RandomEffectsContainer::DeleteSample(int sample_num){
+  // Decrement number of samples
+  num_samples_--;
+
+  // Remove sample_num from alpha
+  // ----------------------------
+  // This code works because the data are stored in a "column-major" format, 
+  // with components comprising rows and samples comprising columns, so that 
+  // element `sample_num*num_components_ + i` contains the i-th component of the 
+  // sample indexed by sample_num. Erasing the `sample_num*num_components_ + 0` 
+  // element of the vector moves the element that was previously in position 
+  // `sample_num*num_components_ + 1` into position `sample_num*num_components_ + 0`, 
+  // so repeating `alpha_.erase(alpha_.begin() + sample_num*num_components_);` 
+  // exactly `num_components_` times erases every component pertaining to this sample.
+  for (int i = 0; i < num_components_; i++) {
+    alpha_.erase(alpha_.begin() + sample_num*num_components_);
+  }
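The repeated-erase-at-a-fixed-offset trick above is easy to sanity-check in a few lines. A minimal sketch (a plain Python list standing in for the flattened C++ `std::vector`), deleting sample `s` from a column-major `num_components x num_samples` buffer:

```python
num_components, num_samples = 3, 4
# Column-major flat storage: sample s occupies positions [s*num_components, (s+1)*num_components)
alpha = [f"s{s}c{c}" for s in range(num_samples) for c in range(num_components)]

s = 2  # sample to delete
for _ in range(num_components):
    # Each pop shifts the remainder left, so the offset stays fixed
    alpha.pop(s * num_components)

assert alpha == [f"s{t}c{c}" for t in (0, 1, 3) for c in range(num_components)]
```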
+
+  // Remove sample_num from xi and beta
+  // ----------------------------------
+  // This code works as above, with the added nuance of the three-dimensional (Fortran-aligned) array, 
+  // in which component number is the first dimension, group number is the second dimension, and sample 
+  // number is the third dimension. The nested loop covers all `num_groups_*num_components_` offsets, 
+  // expressed as `j*num_components_ + i`, and since each erase shifts the remaining elements down by 
+  // one position, we simply erase the element at `sample_num*num_groups_*num_components_` exactly 
+  // `num_groups_*num_components_` times.
+  for (int i = 0; i < num_components_; i++) {
+    for (int j = 0; j < num_groups_; j++) {
+      xi_.erase(xi_.begin() + sample_num*num_groups_*num_components_);
+      beta_.erase(beta_.begin() + sample_num*num_groups_*num_components_);
+    }
+  }
+
+  // Remove sample_num from sigma
+  // ----------------------------
+  // This code works as with alpha
+  for (int i = 0; i < num_components_; i++) {
+    sigma_xi_.erase(sigma_xi_.begin() + sample_num*num_components_);
+  }
+}
+
 void RandomEffectsContainer::from_json(const nlohmann::json& rfx_container_json) {
   int beta_size = rfx_container_json.at("beta_size");
   int alpha_size = rfx_container_json.at("alpha_size");
diff --git a/src/sampler.cpp b/src/sampler.cpp
index 1792f1b2..f6e0f3c6 100644
--- a/src/sampler.cpp
+++ b/src/sampler.cpp
@@ -15,6 +15,7 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
                                   cpp11::external_pointer<StochTree::ColumnVector> residual,
                                   cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
+                                  cpp11::external_pointer<StochTree::TreeEnsemble> active_forest,
                                   cpp11::external_pointer<StochTree::ForestTracker> tracker,
                                   cpp11::external_pointer<StochTree::TreePrior> split_prior,
                                   cpp11::external_pointer<std::mt19937> rng,
@@ -23,6 +24,7 @@ void sample_gfr_one_iteration_cpp(
   if (model_type == StochTree::ModelType::kConstantLeafGaussian) {
-    StochTree::GFRSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianConstantLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true);
+    StochTree::GFRSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*active_forest, *tracker, *forest_samples, std::get<StochTree::GaussianConstantLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true);
   } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) {
-    StochTree::GFRSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true);
+    StochTree::GFRSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*active_forest, *tracker, *forest_samples, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true);
   } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
-    StochTree::GFRSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianMultivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, true, num_basis);
+    StochTree::GFRSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat>(*active_forest, *tracker, *forest_samples, std::get<StochTree::GaussianMultivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis);
   } else if (model_type == StochTree::ModelType::kLogLinearVariance) {
-    StochTree::GFRSampleOneIter<StochTree::LogLinearVarianceLeafModel, StochTree::LogLinearVarianceSuffStat>(*tracker, *forest_samples, std::get<StochTree::LogLinearVarianceLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, false);
+    StochTree::GFRSampleOneIter<StochTree::LogLinearVarianceLeafModel, StochTree::LogLinearVarianceSuffStat>(*active_forest, *tracker, *forest_samples, std::get<StochTree::LogLinearVarianceLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false);
   }
 }
 
@@ -82,6 +84,7 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
cpp11::external_pointer residual, cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, @@ -90,6 +93,7 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, keep_forest, pre_initialized, false); } } @@ -161,13 +165,13 @@ double sample_sigma2_one_iteration_cpp(cpp11::external_pointer forest_samples, +double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, cpp11::external_pointer rng, - double a, double b, int sample_num + double a, double b ) { // Run one iteration of the sampler StochTree::LeafNodeHomoskedasticVarianceModel var_model = StochTree::LeafNodeHomoskedasticVarianceModel(); - return var_model.SampleVarianceParameter(forest_samples->GetEnsemble(sample_num), a, b, *rng); + return var_model.SampleVarianceParameter(active_forest.get(), a, b, *rng); } [[cpp11::register]] diff --git a/src/tree.cpp b/src/tree.cpp index 6eb43910..fa6fd8f8 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -599,13 +599,15 @@ void JsonToTreeNodeVectors(const json& tree_json, Tree* tree) { tree->category_list_begin_.clear(); tree->category_list_end_.clear(); + bool is_univariate = tree->OutputDimension() == 1; int num_nodes = tree->NumNodes(); for (int i = 0; i < num_nodes; i++) { tree->parent_.push_back(tree_json.at("parent").at(i)); tree->cleft_.push_back(tree_json.at("left").at(i)); tree->cright_.push_back(tree_json.at("right").at(i)); tree->split_index_.push_back(tree_json.at("split_index").at(i)); - tree->leaf_value_.push_back(tree_json.at("leaf_value").at(i)); + if (is_univariate) tree->leaf_value_.push_back(tree_json.at("leaf_value").at(i)); + else tree->leaf_value_.push_back(0.); tree->threshold_.push_back(tree_json.at("threshold").at(i)); 
tree->node_deleted_.push_back(tree_json.at("node_deleted").at(i)); // Handle type conversions for node_type, leaf_vector_begin/end, and category_list_begin/end diff --git a/stochtree/__init__.py b/stochtree/__init__.py index f4aa122b..94e93250 100644 --- a/stochtree/__init__.py +++ b/stochtree/__init__.py @@ -2,12 +2,12 @@ from .bcf import BCFModel from .calibration import calibrate_global_error_variance from .data import Dataset, Residual -from .forest import ForestContainer +from .forest import ForestContainer, Forest from .preprocessing import CovariateTransformer from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError -__all__ = ['BARTModel', 'BCFModel', 'Dataset', 'Residual', 'ForestContainer', +__all__ = ['BARTModel', 'BCFModel', 'Dataset', 'Residual', 'ForestContainer', 'Forest', 'CovariateTransformer', 'RNG', 'ForestSampler', 'GlobalVarianceModel', 'LeafVarianceModel', 'JSONSerializer', 'NotSampledError', 'calibrate_global_error_variance'] \ No newline at end of file diff --git a/stochtree/bart.py b/stochtree/bart.py index 1c8afba2..6d78dae8 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1,11 +1,13 @@ """ Bayesian Additive Regression Trees (BART) module """ +from numbers import Number, Integral +from math import log import numpy as np import pandas as pd from typing import Optional, Dict, Any from .data import Dataset, Residual -from .forest import ForestContainer +from .forest import ForestContainer, Forest from .preprocessing import CovariateTransformer, _preprocess_bart_params from .sampler import ForestSampler, RNG, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer @@ -80,6 +82,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N * ``random_seed`` (``int``): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``. * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. + * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. 
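As an aside, the retention rule that ``keep_every`` induces (implemented in the sampling loop further below) can be checked in a few lines of plain Python mirroring the ``mcmc_counter % keep_every`` test:

```python
num_gfr, num_burnin, num_mcmc, keep_every = 5, 10, 3, 2
num_temp_samples = num_gfr + num_burnin + num_mcmc * keep_every

retained = []
for i in range(num_gfr, num_temp_samples):
    is_mcmc = i + 1 > num_gfr + num_burnin
    if is_mcmc:
        mcmc_counter = i - num_gfr - num_burnin + 1
        if mcmc_counter % keep_every == 0:
            retained.append(i)

print(retained)       # [16, 18, 20]: every keep_every-th post-burn-in iteration
print(len(retained))  # == num_mcmc
```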
        Returns
        -------
@@ -119,7 +122,18 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
         keep_burnin = bart_params['keep_burnin']
         keep_gfr = bart_params['keep_gfr']
         self.standardize = bart_params['standardize']
-        
+        num_chains = bart_params['num_chains']
+        keep_every = bart_params['keep_every']
+        
+        # Check that num_chains >= 1
+        if not isinstance(num_chains, Integral) or num_chains < 1:
+            raise ValueError("num_chains must be an integer greater than 0")
+        
+        # Check if there are enough GFR samples to seed num_chains samplers
+        if num_gfr > 0:
+            if num_chains > num_gfr:
+                raise ValueError("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains")
+        
         # Determine which models (conditional mean, conditional variance, or both) we will fit
         self.include_mean_forest = True if num_trees_mean > 0 else False
         self.include_variance_forest = True if num_trees_variance > 0 else False
@@ -255,13 +269,23 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
         self.num_gfr = num_gfr
         self.num_burnin = num_burnin
         self.num_mcmc = num_mcmc
-        self.num_samples = num_gfr + num_burnin + num_mcmc
+        num_actual_mcmc_iter = num_mcmc * keep_every * num_chains
+        num_temp_samples = num_gfr + num_burnin + num_mcmc * keep_every
+        num_retained_samples = num_mcmc * num_chains
+        # Delete GFR samples from these containers after the fact if desired
+        # if keep_gfr:
+        #     num_retained_samples += num_gfr
+        num_retained_samples += num_gfr
+        if keep_burnin:
+            num_retained_samples += num_burnin * num_chains
+        self.num_samples = num_retained_samples
         self.sample_sigma_global = sample_sigma_global
         self.sample_sigma_leaf = sample_sigma_leaf
         if sample_sigma_global:
-            self.global_var_samples = np.zeros(self.num_samples)
+            self.global_var_samples = np.empty(self.num_samples, dtype = np.float64)
         if sample_sigma_leaf:
-            self.leaf_scale_samples = np.zeros(self.num_samples)
+            self.leaf_scale_samples = np.empty(self.num_samples, dtype = np.float64)
+        sample_counter = -1
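Worked numbers for the bookkeeping above (an illustrative aside with assumed settings, not part of the change): with `num_gfr = 10`, `num_burnin = 100`, `num_mcmc = 50`, `keep_every = 2`, and `num_chains = 4`, each chain runs `100 + 50*2 = 200` post-warm-start iterations and retains 50 of them, so:

```python
num_gfr, num_burnin, num_mcmc, keep_every, num_chains = 10, 100, 50, 2, 4

num_temp_samples = num_gfr + num_burnin + num_mcmc * keep_every  # per-chain sampler iterations + warm starts: 210
num_retained_samples = num_mcmc * num_chains                     # 200 retained MCMC draws across chains
num_retained_samples += num_gfr                                  # GFR draws kept until possibly deleted later: 210
keep_burnin = False
if keep_burnin:
    num_retained_samples += num_burnin * num_chains

print(num_temp_samples, num_retained_samples)  # 210 210
```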
 
         # Forest Dataset (covariates and optional basis)
         forest_dataset_train = Dataset()
@@ -303,8 +327,10 @@
         # Container of forest samples
         if self.include_mean_forest:
             self.forest_container_mean = ForestContainer(num_trees_mean, 1, True, False) if not self.has_basis else ForestContainer(num_trees_mean, self.num_basis, False, False)
+            active_forest_mean = Forest(num_trees_mean, 1, True, False) if not self.has_basis else Forest(num_trees_mean, self.num_basis, False, False)
         if self.include_variance_forest:
             self.forest_container_variance = ForestContainer(num_trees_variance, 1, True, True)
+            active_forest_variance = Forest(num_trees_variance, 1, True, True)
         
         # Variance samplers
         if self.sample_sigma_global:
@@ -318,123 +344,156 @@
                 init_val_mean = np.repeat(0., basis_train.shape[1])
             else:
                 init_val_mean = np.array([0.])
-            forest_sampler_mean.prepare_for_sampler(forest_dataset_train, residual_train, self.forest_container_mean, leaf_model_mean_forest, init_val_mean)
+            forest_sampler_mean.prepare_for_sampler(forest_dataset_train, residual_train, active_forest_mean, leaf_model_mean_forest, init_val_mean)
         
         # Initialize the leaves of each tree in the variance forest
         if self.include_variance_forest:
             init_val_variance = np.array([variance_forest_leaf_init])
-            forest_sampler_variance.prepare_for_sampler(forest_dataset_train, residual_train, self.forest_container_variance, leaf_model_variance_forest, init_val_variance)
+            forest_sampler_variance.prepare_for_sampler(forest_dataset_train, residual_train, active_forest_variance, leaf_model_variance_forest, init_val_variance)
         
         # Run GFR (warm start) if specified
         if self.num_gfr > 0:
-            gfr_indices = np.arange(self.num_gfr)
             for i in range(self.num_gfr):
+                # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
+                # keep_sample = keep_gfr
+                keep_sample = True
+                if keep_sample:
+                    sample_counter += 1
                 # Sample the mean forest
                 if self.include_mean_forest:
                     forest_sampler_mean.sample_one_iteration(
-                        self.forest_container_mean, forest_dataset_train, residual_train, cpp_rng, feature_types, 
-                        cutpoint_grid_size, current_leaf_scale, variable_weights_mean, a_forest, b_forest, 
-                        current_sigma2, leaf_model_mean_forest, True, True
+                        self.forest_container_mean, active_forest_mean, forest_dataset_train, residual_train, 
+                        cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale, variable_weights_mean, a_forest, b_forest, 
+                        current_sigma2, leaf_model_mean_forest, keep_sample, True, True
                     )
                 
                 # Sample the variance forest
                 if self.include_variance_forest:
                     forest_sampler_variance.sample_one_iteration(
-                        self.forest_container_variance, forest_dataset_train, residual_train, cpp_rng, feature_types, 
-                        cutpoint_grid_size, current_leaf_scale, variable_weights_variance, a_forest, b_forest, 
-                        current_sigma2, leaf_model_variance_forest, True, True
+                        self.forest_container_variance, active_forest_variance, forest_dataset_train, residual_train, 
+                        cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale, variable_weights_variance, a_forest, b_forest, 
+                        current_sigma2, leaf_model_variance_forest, keep_sample, True, True
                    )
                 
                 # Sample variance parameters (if requested)
                 if self.sample_sigma_global:
                     current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
-                    self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std/self.variance_scale
+                    if keep_sample:
+                        self.global_var_samples[sample_counter] = current_sigma2
                 if self.sample_sigma_leaf:
-                    self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container_mean, cpp_rng, a_leaf, b_leaf, i)
-                    current_leaf_scale[0,0] = self.leaf_scale_samples[i]
+                    current_leaf_scale[0,0] = leaf_var_model.sample_one_iteration(active_forest_mean, cpp_rng, a_leaf, b_leaf)
+                    if keep_sample:
+                        self.leaf_scale_samples[sample_counter] = current_leaf_scale[0,0]
         
         # Run MCMC
         if self.num_burnin + self.num_mcmc > 0:
-            if self.num_burnin > 0:
-                burnin_indices = np.arange(self.num_gfr, self.num_gfr + self.num_burnin)
-            if self.num_mcmc > 0:
-                mcmc_indices = np.arange(self.num_gfr + self.num_burnin, self.num_gfr + self.num_burnin + self.num_mcmc)
-            for i in range(self.num_gfr, self.num_samples):
-                # Sample the mean forest
-                if self.include_mean_forest:
-                    forest_sampler_mean.sample_one_iteration(
-                        self.forest_container_mean, forest_dataset_train, residual_train, cpp_rng, feature_types, 
-                        cutpoint_grid_size, current_leaf_scale, variable_weights_mean, a_forest, b_forest, 
-                        current_sigma2, leaf_model_mean_forest, False, True
-                    )
-                
-                # Sample the variance forest
-                if self.include_variance_forest:
-                    forest_sampler_variance.sample_one_iteration(
-                        self.forest_container_variance, forest_dataset_train, residual_train, cpp_rng, feature_types, 
-                        cutpoint_grid_size, current_leaf_scale, variable_weights_variance, a_forest, b_forest, 
-                        current_sigma2, leaf_model_variance_forest, False, True
-                    )
-                
-                # Sample variance parameters
(if requested)
-                if self.sample_sigma_global:
-                    current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
-                    self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std/self.variance_scale
-                if self.sample_sigma_leaf:
-                    self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container_mean, cpp_rng, a_leaf, b_leaf, i)
-                    current_leaf_scale[0,0] = self.leaf_scale_samples[i]
+            for chain_num in range(num_chains):
+                if num_gfr > 0:
+                    forest_ind = num_gfr - chain_num - 1
+                    if self.include_mean_forest:
+                        active_forest_mean.reset(self.forest_container_mean, forest_ind)
+                        forest_sampler_mean.reconstitute_from_forest(active_forest_mean, forest_dataset_train, residual_train, True)
+                    if self.include_variance_forest:
+                        active_forest_variance.reset(self.forest_container_variance, forest_ind)
+                        forest_sampler_variance.reconstitute_from_forest(active_forest_variance, forest_dataset_train, residual_train, False)
+                    if sample_sigma_global:
+                        current_sigma2 = self.global_var_samples[forest_ind]
+                else:
+                    if self.include_mean_forest:
+                        active_forest_mean.reset_root()
+                        if init_val_mean.shape[0] == 1:
+                            active_forest_mean.set_root_leaves(init_val_mean[0] / num_trees_mean)
+                        else:
+                            active_forest_mean.set_root_leaves(init_val_mean / num_trees_mean)
+                        forest_sampler_mean.reconstitute_from_forest(active_forest_mean, forest_dataset_train, residual_train, True)
+                    if self.include_variance_forest:
+                        active_forest_variance.reset_root()
+                        active_forest_variance.set_root_leaves(log(variance_forest_leaf_init) / num_trees_variance)
+                        forest_sampler_variance.reconstitute_from_forest(active_forest_variance, forest_dataset_train, residual_train, False)
+                
+                for i in range(self.num_gfr, num_temp_samples):
+                    is_mcmc = i + 1 > num_gfr + num_burnin
+                    if is_mcmc:
+                        mcmc_counter = i - num_gfr - num_burnin + 1
+                        if (mcmc_counter % keep_every == 0):
+                            keep_sample = True
+                        else:
+                            keep_sample = False
+                    else:
+                        if keep_burnin:
+                            keep_sample = True
+                        else:
+                            keep_sample = False
+                    if keep_sample:
+                        sample_counter += 1
+                    # Sample the mean forest
+                    if self.include_mean_forest:
+                        forest_sampler_mean.sample_one_iteration(
+                            self.forest_container_mean, active_forest_mean, forest_dataset_train, residual_train, 
+                            cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale, variable_weights_mean, a_forest, b_forest, 
+                            current_sigma2, leaf_model_mean_forest, keep_sample, False, True
+                        )
+                    
+                    # Sample the variance forest
+                    if self.include_variance_forest:
+                        forest_sampler_variance.sample_one_iteration(
+                            self.forest_container_variance, active_forest_variance, forest_dataset_train, residual_train, 
+                            cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale, variable_weights_variance, a_forest, b_forest, 
+                            current_sigma2, leaf_model_variance_forest, keep_sample, False, True
+                        )
+                    
+                    # Sample variance parameters (if requested)
+                    if self.sample_sigma_global:
+                        current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
+                        if keep_sample:
+                            self.global_var_samples[sample_counter] = current_sigma2
+                    if self.sample_sigma_leaf:
+                        current_leaf_scale[0,0] = leaf_var_model.sample_one_iteration(active_forest_mean, cpp_rng, a_leaf, b_leaf)
+                        if keep_sample:
+                            self.leaf_scale_samples[sample_counter] = current_leaf_scale[0,0]
         
         # Mark the model as sampled
         self.sampled = True
         
-        # Prediction indices to be stored
-        if self.num_mcmc > 0:
-            self.keep_indices = mcmc_indices
-            if keep_gfr:
-                self.keep_indices = np.concatenate((gfr_indices, self.keep_indices))
-
else: - # Don't retain both GFR and burnin samples - if keep_burnin: - self.keep_indices = np.concatenate((burnin_indices, self.keep_indices)) - else: - if self.num_gfr > 0 and self.num_burnin > 0: - # Override keep_gfr = False since there are no MCMC samples - # Don't retain both GFR and burnin samples - self.keep_indices = gfr_indices - elif self.num_gfr <= 0 and self.num_burnin > 0: - self.keep_indices = burnin_indices - elif self.num_gfr > 0 and self.num_burnin <= 0: - self.keep_indices = gfr_indices - else: - raise RuntimeError("There are no samples to retain!") - + # Remove GFR samples if they are not to be retained + if not keep_gfr and num_gfr > 0: + for i in range(num_gfr): + if self.include_mean_forest: + self.forest_container_mean.delete_sample(i) + if self.include_variance_forest: + self.forest_container_variance.delete_sample(i) + if self.sample_sigma_global: + self.global_var_samples = self.global_var_samples[num_gfr:] + if self.sample_sigma_leaf: + self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] + # Store predictions if self.sample_sigma_global: - self.global_var_samples = self.global_var_samples[self.keep_indices] + self.global_var_samples = self.global_var_samples*self.y_std*self.y_std/self.variance_scale if self.sample_sigma_leaf: - self.leaf_scale_samples = self.leaf_scale_samples[self.keep_indices] + self.leaf_scale_samples = self.leaf_scale_samples if self.include_mean_forest: - yhat_train_raw = self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp)[:,self.keep_indices] + yhat_train_raw = self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp) self.y_hat_train = yhat_train_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar if self.has_test: - yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp)[:,self.keep_indices] + yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp) self.y_hat_test = yhat_test_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar if self.include_variance_forest: - sigma_x_train_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp)[:,self.keep_indices] + sigma_x_train_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp) if self.sample_sigma_global: self.sigma_x_train = sigma_x_train_raw - for i in range(self.keep_indices.shape[0]): + for i in range(num_retained_samples): self.sigma_x_train[:,i] = np.sqrt(sigma_x_train_raw[:,i]*self.global_var_samples[i]) else: self.sigma_x_train = np.sqrt(sigma_x_train_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale) if self.has_test: - sigma_x_test_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp)[:,self.keep_indices] + sigma_x_test_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp) if self.sample_sigma_global: self.sigma_x_test = sigma_x_test_raw - for i in range(self.keep_indices.shape[0]): + for i in range(num_retained_samples): self.sigma_x_test[:,i] = np.sqrt(sigma_x_test_raw[:,i]*self.global_var_samples[i]) else: self.sigma_x_test = np.sqrt(sigma_x_test_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale) @@ -478,13 +537,13 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array: if basis is not None: pred_dataset.add_basis(basis) if self.include_mean_forest: - mean_pred_raw = 
self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp)[:,self.keep_indices] + mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp) mean_pred = mean_pred_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar if self.include_variance_forest: - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp)[:,self.keep_indices] + variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp) if self.sample_sigma_global: variance_pred = variance_pred_raw - for i in range(self.keep_indices.shape[0]): + for i in range(self.num_samples): variance_pred[:,i] = np.sqrt(variance_pred_raw[:,i]*self.global_var_samples[i]) else: variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale) @@ -541,7 +600,7 @@ def predict_mean(self, covariates: np.array, basis: np.array = None) -> np.array pred_dataset.add_covariates(covariates) if basis is not None: pred_dataset.add_basis(basis) - mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp)[:,self.keep_indices] + mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp) mean_pred = mean_pred_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar return mean_pred @@ -591,10 +650,10 @@ def predict_variance(self, covariates: np.array, basis: np.array = None) -> np.a pred_dataset.add_covariates(covariates) # if basis is not None: # pred_dataset.add_basis(basis) - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp)[:,self.keep_indices] + variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp) if self.sample_sigma_global: variance_pred = variance_pred_raw - for i in range(self.keep_indices.shape[0]): + for i in range(self.num_samples): variance_pred[:,i] = np.sqrt(variance_pred_raw[:,i]*self.global_var_samples[i]) else: variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale) @@ -643,7 +702,6 @@ def to_json(self) -> str: bart_json.add_scalar("num_samples", self.num_samples) bart_json.add_scalar("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) - bart_json.add_numeric_vector("keep_indices", self.keep_indices) # Add parameter samples if self.sample_sigma_global: @@ -696,7 +754,6 @@ def from_json(self, json_string: str) -> None: self.num_samples = bart_json.get_scalar("num_samples") self.num_basis = bart_json.get_scalar("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") - self.keep_indices = bart_json.get_numeric_vector("keep_indices").astype(int) # Unpack parameter samples if self.sample_sigma_global: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ab53f2e3..052dbaed 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -7,7 +7,7 @@ from typing import Optional, Union, Dict, Any from .bart import BARTModel from .data import Dataset, Residual -from .forest import ForestContainer +from .forest import ForestContainer, Forest from .preprocessing import CovariateTransformer, _preprocess_bcf_params from .sampler import ForestSampler, RNG, GlobalVarianceModel, LeafVarianceModel from .utils import NotSampledError @@ -105,6 +105,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr * ``random_seed`` (``int``): Integer parameterizing the C++ random number 
generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``. * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. + * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. Returns ------- @@ -150,6 +151,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr keep_burnin = bcf_params['keep_burnin'] keep_gfr = bcf_params['keep_gfr'] self.standardize = bcf_params['standardize'] + keep_every = bcf_params['keep_every'] # Variable weight preprocessing (and initialization if necessary) if variable_weights is None: @@ -558,19 +560,26 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr self.propensity_covariate = propensity_covariate # Container of variance parameter samples - self.num_gfr = num_gfr - self.num_burnin = num_burnin - self.num_mcmc = num_mcmc - self.num_samples = num_gfr + num_burnin + num_mcmc + num_actual_mcmc_iter = num_mcmc * keep_every + num_temp_samples = num_gfr + num_burnin + num_actual_mcmc_iter + num_retained_samples = num_mcmc + # Delete GFR samples from these containers after the fact if desired + # if keep_gfr: + # num_retained_samples += num_gfr + num_retained_samples += num_gfr + if keep_burnin: + num_retained_samples += num_burnin + self.num_samples = num_retained_samples self.sample_sigma_global = sample_sigma_global self.sample_sigma_leaf_mu = sample_sigma_leaf_mu self.sample_sigma_leaf_tau = sample_sigma_leaf_tau if sample_sigma_global: - self.global_var_samples = np.zeros(self.num_samples) + self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) if sample_sigma_leaf_mu: - self.leaf_scale_mu_samples = np.zeros(self.num_samples) + self.leaf_scale_mu_samples = np.empty(self.num_samples, dtype=np.float64) if sample_sigma_leaf_tau: - self.leaf_scale_tau_samples = np.zeros(self.num_samples) + self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64) + sample_counter = -1 # Prepare adaptive coding structure if self.adaptive_coding: @@ -578,8 +587,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr raise ValueError("b_0 and b_1 must be single numeric values") if not (isinstance(b_0, (int, float)) or isinstance(b_1, (int, float))): raise ValueError("b_0 and b_1 must be numeric values") - self.b0_samples = np.zeros(self.num_samples) - self.b1_samples = np.zeros(self.num_samples) + self.b0_samples = np.empty(self.num_samples, dtype=np.float64) + self.b1_samples = np.empty(self.num_samples, dtype=np.float64) current_b_0 = b_0 current_b_1 = b_1 tau_basis_train = (1-Z_train)*current_b_0 + Z_train*current_b_1 @@ -619,6 +628,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr # Container of forest samples self.forest_container_mu = ForestContainer(num_trees_mu, 1, True, False) self.forest_container_tau = ForestContainer(num_trees_tau, Z_train.shape[1], False, False) + active_forest_mu = Forest(num_trees_mu, 1, True, 
False) + active_forest_tau = Forest(num_trees_tau, Z_train.shape[1], False, False) # Variance samplers if self.sample_sigma_global: @@ -630,53 +641,59 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr # Initialize the leaves of each tree in the prognostic forest init_mu = np.array([np.squeeze(np.mean(resid_train))]) - forest_sampler_mu.prepare_for_sampler(forest_dataset_train, residual_train, self.forest_container_mu, 0, init_mu) + forest_sampler_mu.prepare_for_sampler(forest_dataset_train, residual_train, active_forest_mu, 0, init_mu) # Initialize the leaves of each tree in the treatment forest if self.multivariate_treatment: init_tau = np.zeros(Z_train.shape[1]) else: init_tau = np.array([0.]) - forest_sampler_tau.prepare_for_sampler(forest_dataset_train, residual_train, self.forest_container_tau, treatment_leaf_model, init_tau) + forest_sampler_tau.prepare_for_sampler(forest_dataset_train, residual_train, active_forest_tau, treatment_leaf_model, init_tau) # Run GFR (warm start) if specified - if self.num_gfr > 0: - gfr_indices = np.arange(self.num_gfr) - for i in range(self.num_gfr): + if num_gfr > 0: + for i in range(num_gfr): + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample = keep_gfr + keep_sample = True + if keep_sample: + sample_counter += 1 # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( - self.forest_container_mu, forest_dataset_train, residual_train, cpp_rng, feature_types, + self.forest_container_mu, active_forest_mu, forest_dataset_train, residual_train, cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, 0, True, True + current_sigma2, 0, keep_sample, True, True ) # Sample variance parameters (if requested) if self.sample_sigma_global: current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) - self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std if self.sample_sigma_leaf_mu: - self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i) - current_leaf_scale_mu[0,0] = self.leaf_scale_mu_samples[i] + current_leaf_scale_mu[0,0] = leaf_var_model_mu.sample_one_iteration(active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu) + if keep_sample: + self.leaf_scale_mu_samples[sample_counter] = current_leaf_scale_mu[0,0] # Sample the treatment forest forest_sampler_tau.sample_one_iteration( - self.forest_container_tau, forest_dataset_train, residual_train, cpp_rng, feature_types, + self.forest_container_tau, active_forest_tau, forest_dataset_train, residual_train, cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, treatment_leaf_model, True, True + current_sigma2, treatment_leaf_model, keep_sample, True, True ) # Sample variance parameters (if requested) if self.sample_sigma_global: current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) - self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std + if keep_sample: + self.global_var_samples[sample_counter] = current_sigma2 if self.sample_sigma_leaf_tau: - self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i) - current_leaf_scale_tau[0,0] = self.leaf_scale_tau_samples[i] + current_leaf_scale_tau[0,0] = 
leaf_var_model_tau.sample_one_iteration(active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau) + if keep_sample: + self.leaf_scale_tau_samples[sample_counter] = current_leaf_scale_tau[0,0] # Sample coding parameters (if requested) if self.adaptive_coding: - mu_x = self.forest_container_mu.predict_raw_single_forest(forest_dataset_train, i) - tau_x = np.squeeze(self.forest_container_tau.predict_raw_single_forest(forest_dataset_train, i)) + mu_x = active_forest_mu.predict_raw(forest_dataset_train) + tau_x = np.squeeze(active_forest_tau.predict_raw(forest_dataset_train)) s_tt0 = np.sum(tau_x*tau_x*(np.squeeze(Z_train)==0)) s_tt1 = np.sum(tau_x*tau_x*(np.squeeze(Z_train)==1)) partial_resid_mu = np.squeeze(resid_train - mu_x) @@ -691,53 +708,66 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr if self.has_test: tau_basis_test = (1-np.squeeze(Z_test))*current_b_0 + np.squeeze(Z_test)*current_b_1 forest_dataset_test.update_basis(tau_basis_test) - self.b0_samples[i] = current_b_0 - self.b1_samples[i] = current_b_1 + if keep_sample: + self.b0_samples[sample_counter] = current_b_0 + self.b1_samples[sample_counter] = current_b_1 # Update residual to reflect adjusted basis - forest_sampler_tau.propagate_basis_update(forest_dataset_train, residual_train, self.forest_container_tau, i) + forest_sampler_tau.propagate_basis_update(forest_dataset_train, residual_train, active_forest_tau) # Run MCMC - if self.num_burnin + self.num_mcmc > 0: - if self.num_burnin > 0: - burnin_indices = np.arange(self.num_gfr, self.num_gfr + self.num_burnin) - if self.num_mcmc > 0: - mcmc_indices = np.arange(self.num_gfr + self.num_burnin, self.num_gfr + self.num_burnin + self.num_mcmc) - for i in range(self.num_gfr, self.num_samples): + if num_burnin + num_mcmc > 0: + for i in range(num_gfr, num_temp_samples): + is_mcmc = i + 1 > num_gfr + num_burnin + if is_mcmc: + mcmc_counter = i - num_gfr - num_burnin + 1 + if (mcmc_counter % keep_every == 0): + keep_sample = True + else: + keep_sample = False + else: + if keep_burnin: + keep_sample = True + else: + keep_sample = False + if keep_sample: + sample_counter += 1 # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( - self.forest_container_mu, forest_dataset_train, residual_train, cpp_rng, feature_types, + self.forest_container_mu, active_forest_mu, forest_dataset_train, residual_train, cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, 0, False, True + current_sigma2, 0, keep_sample, False, True ) # Sample variance parameters (if requested) if self.sample_sigma_global: current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) - self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std if self.sample_sigma_leaf_mu: - self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i) - current_leaf_scale_mu[0,0] = self.leaf_scale_mu_samples[i] + current_leaf_scale_mu[0,0] = leaf_var_model_mu.sample_one_iteration(active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu) + if keep_sample: + self.leaf_scale_mu_samples[sample_counter] = current_leaf_scale_mu[0,0] # Sample the treatment forest forest_sampler_tau.sample_one_iteration( - self.forest_container_tau, forest_dataset_train, residual_train, cpp_rng, feature_types, + self.forest_container_tau, active_forest_tau, forest_dataset_train, residual_train, cpp_rng, feature_types, cutpoint_grid_size, 
current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, treatment_leaf_model, False, True + current_sigma2, treatment_leaf_model, keep_sample, False, True ) # Sample variance parameters (if requested) if self.sample_sigma_global: current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) - self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std + if keep_sample: + self.global_var_samples[sample_counter] = current_sigma2 if self.sample_sigma_leaf_tau: - self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i) - current_leaf_scale_tau[0,0] = self.leaf_scale_tau_samples[i] + current_leaf_scale_tau[0,0] = leaf_var_model_tau.sample_one_iteration(active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau) + if keep_sample: + self.leaf_scale_tau_samples[sample_counter] = current_leaf_scale_tau[0,0] # Sample coding parameters (if requested) if self.adaptive_coding: - mu_x = self.forest_container_mu.predict_raw_single_forest(forest_dataset_train, i) - tau_x = np.squeeze(self.forest_container_tau.predict_raw_single_forest(forest_dataset_train, i)) + mu_x = active_forest_mu.predict_raw(forest_dataset_train) + tau_x = np.squeeze(active_forest_tau.predict_raw(forest_dataset_train)) s_tt0 = np.sum(tau_x*tau_x*(np.squeeze(Z_train)==0)) s_tt1 = np.sum(tau_x*tau_x*(np.squeeze(Z_train)==1)) partial_resid_mu = np.squeeze(resid_train - mu_x) @@ -752,43 +782,38 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr if self.has_test: tau_basis_test = (1-np.squeeze(Z_test))*current_b_0 + np.squeeze(Z_test)*current_b_1 forest_dataset_test.update_basis(tau_basis_test) - self.b0_samples[i] = current_b_0 - self.b1_samples[i] = current_b_1 + if keep_sample: + self.b0_samples[sample_counter] = current_b_0 + self.b1_samples[sample_counter] = current_b_1 # Update residual to reflect adjusted basis - forest_sampler_tau.propagate_basis_update(forest_dataset_train, residual_train, self.forest_container_tau, i) + forest_sampler_tau.propagate_basis_update(forest_dataset_train, residual_train, active_forest_tau) # Mark the model as sampled self.sampled = True - # Prediction indices to be stored - if self.num_mcmc > 0: - self.keep_indices = mcmc_indices - if keep_gfr: - self.keep_indices = np.concatenate((gfr_indices, self.keep_indices)) - else: - # Don't retain both GFR and burnin samples - if keep_burnin: - self.keep_indices = np.concatenate((burnin_indices, self.keep_indices)) - else: - if self.num_gfr > 0 and self.num_burnin > 0: - # Override keep_gfr = False since there are no MCMC samples - # Don't retain both GFR and burnin samples - self.keep_indices = gfr_indices - elif self.num_gfr <= 0 and self.num_burnin > 0: - self.keep_indices = burnin_indices - elif self.num_gfr > 0 and self.num_burnin <= 0: - self.keep_indices = gfr_indices - else: - raise RuntimeError("There are no samples to retain!") - + # Remove GFR samples if they are not to be retained + if not keep_gfr and num_gfr > 0: + for i in range(num_gfr): + self.forest_container_mu.delete_sample(i) + self.forest_container_tau.delete_sample(i) + if self.adaptive_coding: + self.b1_samples = self.b1_samples[num_gfr:] + self.b0_samples = self.b0_samples[num_gfr:] + if self.sample_sigma_global: + self.global_var_samples = self.global_var_samples[num_gfr:] + if self.sample_sigma_leaf_mu: + self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:] + if self.sample_sigma_leaf_tau: + 
self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:] + # Store predictions mu_raw = self.forest_container_mu.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp) - self.mu_hat_train = mu_raw[:,self.keep_indices]*self.y_std + self.y_bar + self.mu_hat_train = mu_raw*self.y_std + self.y_bar tau_raw_train = self.forest_container_tau.forest_container_cpp.PredictRaw(forest_dataset_train.dataset_cpp) - self.tau_hat_train = tau_raw_train[:,self.keep_indices,:] + self.tau_hat_train = tau_raw_train if self.adaptive_coding: - adaptive_coding_weights = np.expand_dims(self.b1_samples[self.keep_indices] - self.b0_samples[self.keep_indices], axis=(0,2)) + adaptive_coding_weights = np.expand_dims(self.b1_samples - self.b0_samples, axis=(0,2)) self.tau_hat_train = self.tau_hat_train*adaptive_coding_weights self.tau_hat_train = np.squeeze(self.tau_hat_train*self.y_std) if self.multivariate_treatment: @@ -798,11 +823,11 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr self.y_hat_train = self.mu_hat_train + treatment_term_train if self.has_test: mu_raw_test = self.forest_container_mu.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp) - self.mu_hat_test = mu_raw_test[:,self.keep_indices]*self.y_std + self.y_bar + self.mu_hat_test = mu_raw_test*self.y_std + self.y_bar tau_raw_test = self.forest_container_tau.forest_container_cpp.PredictRaw(forest_dataset_test.dataset_cpp) - self.tau_hat_test = tau_raw_test[:,self.keep_indices,:] + self.tau_hat_test = tau_raw_test if self.adaptive_coding: - adaptive_coding_weights_test = np.expand_dims(self.b1_samples[self.keep_indices] - self.b0_samples[self.keep_indices], axis=(0,2)) + adaptive_coding_weights_test = np.expand_dims(self.b1_samples - self.b0_samples, axis=(0,2)) self.tau_hat_test = self.tau_hat_test*adaptive_coding_weights_test self.tau_hat_test = np.squeeze(self.tau_hat_test*self.y_std) if self.multivariate_treatment: @@ -812,17 +837,17 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr self.y_hat_test = self.mu_hat_test + treatment_term_test if self.sample_sigma_global: - self.global_var_samples = self.global_var_samples[self.keep_indices] + self.global_var_samples = self.global_var_samples*self.y_std*self.y_std - if self.sample_sigma_leaf_mu: - self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[self.keep_indices] - if self.sample_sigma_leaf_tau: - self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[self.keep_indices] - if self.adaptive_coding: - self.b0_samples = self.b0_samples[self.keep_indices] - self.b1_samples = self.b1_samples[self.keep_indices] def predict_tau(self, X: np.array, Z: np.array, propensity: np.array = None) -> np.array: """Predict CATE function for every provided observation.
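With the per-iteration keep_sample flags above, the forest containers and parameter arrays only ever hold retained draws, which is why the keep_indices slicing disappears from the prediction code. The retention rule itself is easy to sanity-check in isolation; the helper below is a minimal sketch (a hypothetical function, not part of stochtree) that reproduces the bookkeeping from the sampling loop above, assuming the same num_gfr / num_burnin / keep_every semantics.

import numpy as np

def retained_iteration_indices(num_gfr: int, num_burnin: int, num_mcmc: int,
                               keep_every: int, keep_gfr: bool, keep_burnin: bool) -> np.ndarray:
    """Zero-indexed sampler iterations whose draws are retained."""
    # Total iterations: GFR, then burn-in, then num_mcmc * keep_every MCMC steps
    num_temp_samples = num_gfr + num_burnin + num_mcmc * keep_every
    kept = []
    for i in range(num_temp_samples):
        if i < num_gfr:
            keep_sample = keep_gfr
        elif i < num_gfr + num_burnin:
            keep_sample = keep_burnin
        else:
            # Retain every keep_every-th post-burn-in iteration ("thinning")
            mcmc_counter = i - num_gfr - num_burnin + 1
            keep_sample = (mcmc_counter % keep_every == 0)
        if keep_sample:
            kept.append(i)
    return np.array(kept, dtype=int)

# 10 GFR draws, no burn-in, 20 MCMC draws thinned by 5: the sampler runs
# 10 + 100 iterations and retains iterations 14, 19, ..., 109
print(retained_iteration_indices(10, 0, 20, 5, keep_gfr=False, keep_burnin=False))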
@@ -886,7 +911,7 @@ def predict_tau(self, X: np.array, Z: np.array, propensity: np.array = None) -> # Estimate treatment effect tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(forest_dataset_tau.dataset_cpp) - tau_raw = tau_raw[:,self.keep_indices,:] if self.adaptive_coding: adaptive_coding_weights = np.expand_dims(self.b1_samples - self.b0_samples, axis=(0,2)) tau_raw = tau_raw*adaptive_coding_weights @@ -972,9 +997,9 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None) -> np.a # Compute predicted outcome and decomposed outcome model terms mu_raw = self.forest_container_mu.forest_container_cpp.Predict(forest_dataset_tau.dataset_cpp) - mu_x = mu_raw[:,self.keep_indices]*self.y_std + self.y_bar + mu_x = mu_raw*self.y_std + self.y_bar tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(forest_dataset_tau.dataset_cpp) - tau_raw = tau_raw[:,self.keep_indices,:] if self.adaptive_coding: adaptive_coding_weights = np.expand_dims(self.b1_samples - self.b0_samples, axis=(0,2)) tau_raw = tau_raw*adaptive_coding_weights diff --git a/stochtree/forest.py b/stochtree/forest.py index bf434884..033a6b41 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -4,7 +4,7 @@ import numpy as np from .data import Dataset, Residual # from .serialization import JSONSerializer -from stochtree_cpp import ForestContainerCpp +from stochtree_cpp import ForestContainerCpp, ForestCpp from typing import Union class ForestContainer: @@ -64,6 +64,8 @@ def add_sample(self, leaf_value: Union[float, np.array]) -> None: """ Add a new all-root ensemble to the container, with all of the leaves set to the value / vector provided + Parameters + ---------- leaf_value : :obj:`float` or :obj:`np.array` Value (or vector of values) to initialize root nodes in tree """ @@ -78,6 +80,8 @@ def add_numeric_split(self, forest_num: int, tree_num: int, leaf_num: int, featu """ Add a numeric (i.e.
X[,i] <= c) split to a given tree in the ensemble + Parameters + ---------- forest_num : :obj:`int` Index of the forest which contains the tree to be split tree_num : :obj:`int` @@ -104,6 +108,8 @@ def get_tree_leaves(self, forest_num: int, tree_num: int) -> np.array: """ Retrieve a vector of indices of leaf nodes for a given tree in a given forest + Parameters + ---------- forest_num : :obj:`int` Index of the forest which contains tree `tree_num` tree_num : :obj:`int` @@ -115,6 +121,8 @@ def get_tree_split_counts(self, forest_num: int, tree_num: int, num_features: in """ Retrieve a vector of split counts for every training set variable in a given tree in a given forest + Parameters + ---------- forest_num : :obj:`int` Index of the forest which contains tree `tree_num` tree_num : :obj:`int` @@ -128,10 +136,17 @@ def get_forest_split_counts(self, forest_num: int, num_features: int) -> np.arra """ Retrieve a vector of split counts for every training set variable in a given forest + Parameters + ---------- forest_num : :obj:`int` Index of the forest to be queried num_features : :obj:`int` Total number of features in the training set + + Returns + ------- + :obj:`np.array` + One-dimensional numpy array, containing the number of splits a variable receives, summed across each tree of a given forest in a ``ForestContainer`` """ return self.forest_container_cpp.GetForestSplitCounts(forest_num, num_features) @@ -139,8 +154,15 @@ def get_overall_split_counts(self, num_features: int) -> np.array: """ Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees + Parameters + ---------- num_features : :obj:`int` Total number of features in the training set + + Returns + ------- + :obj:`np.array` + One-dimensional numpy array, containing the number of splits a variable receives, summed across each tree of every forest in a ``ForestContainer`` """ return self.forest_container_cpp.GetOverallSplitCounts(num_features) @@ -148,8 +170,15 @@ def get_granular_split_counts(self, num_features: int) -> np.array: """ Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree + Parameters + ---------- num_features : :obj:`int` Total number of features in the training set + + Returns + ------- + :obj:`np.array` + Three-dimensional numpy array, containing the number of splits a variable receives in each tree of each forest in a ``ForestContainer`` """ return self.forest_container_cpp.GetGranularSplitCounts(num_features) @@ -157,8 +186,15 @@ def num_forest_leaves(self, forest_num: int) -> int: """ Return the total number of leaves for a given forest in the ``ForestContainer`` + Parameters + ---------- forest_num : :obj:`int` Index of the forest to be queried + + Returns + ------- + :obj:`int` + Number of leaves in a given forest in a ``ForestContainer`` """ return self.forest_container_cpp.NumLeavesForest(forest_num) @@ -166,11 +202,18 @@ def sum_leaves_squared(self, forest_num: int) -> float: """ Return the total sum of squared leaf values for a given forest in the ``ForestContainer`` + Parameters + ---------- forest_num : :obj:`int` Index of the forest to be queried + + Returns + ------- + :obj:`float` + Sum of squared leaf values in a given forest in a ``ForestContainer`` """ return self.forest_container_cpp.SumLeafSquared(forest_num) - + def is_leaf_node(self, forest_num: int, tree_num: int, node_id: int) -> bool: """ Whether or not a given
node of a given tree in a given forest in the ``ForestContainer`` is a leaf @@ -392,4 +435,344 @@ def leaves(self, forest_num: int, tree_num: int) -> np.array: Index of the tree to be queried """ return self.forest_container_cpp.Leaves(forest_num, tree_num) - \ No newline at end of file + + def delete_sample(self, forest_num: int) -> None: + """ + Modify the ``ForestContainer`` by removing the forest sample indexed by ``forest_num``. + + Parameters + ---------- + forest_num : :obj:`int` + Index of the forest to be removed from the ``ForestContainer`` + """ + return self.forest_container_cpp.DeleteSample(forest_num) + +class Forest: + def __init__(self, num_trees: int, output_dimension: int, leaf_constant: bool, is_exponentiated: bool) -> None: + # Initialize a ForestCpp object + self.forest_cpp = ForestCpp(num_trees, output_dimension, leaf_constant, is_exponentiated) + + def reset_root(self) -> None: + """ + Reset forest to a forest with all single node (i.e. "root") trees + """ + self.forest_cpp.ResetRoot() + + def reset(self, forest_container: ForestContainer, forest_num: int) -> None: + """ + Reset forest to the forest indexed by ``forest_num`` in ``forest_container`` + + Parameters + ---------- + forest_container : :obj:`ForestContainer` + Stochtree object storing tree ensembles + forest_num : :obj:`int` + Index of the ensemble used to reset the ``Forest`` + """ + self.forest_cpp.Reset(forest_container.forest_container_cpp, forest_num) + + def predict(self, dataset: Dataset) -> np.array: + # Predict from the forest for every observation in a Dataset + return self.forest_cpp.Predict(dataset.dataset_cpp) + + def predict_raw(self, dataset: Dataset) -> np.array: + # Predict raw leaf values from the forest for every observation in a Dataset + result = self.forest_cpp.PredictRaw(dataset.dataset_cpp) + if result.ndim == 3: + if result.shape[1] == 1: + result = result.reshape(result.shape[0], result.shape[2]) + return result + + def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None: + # Set every tree in the forest to a single root node with the provided leaf value / vector + if not isinstance(leaf_value, np.ndarray) and not isinstance(leaf_value, float): + raise ValueError("leaf_value must be either a float or np.array") + if isinstance(leaf_value, np.ndarray): + leaf_value = np.squeeze(leaf_value) + if len(leaf_value.shape) != 1: + raise ValueError("leaf_value must be a one-dimensional array") + self.forest_cpp.SetRootVector(leaf_value, leaf_value.shape[0]) + else: + self.forest_cpp.SetRootValue(leaf_value) + + def add_numeric_split(self, tree_num: int, leaf_num: int, feature_num: int, split_threshold: float, + left_leaf_value: Union[float, np.array], right_leaf_value: Union[float, np.array]) -> None: + """ + Add a numeric (i.e.
X[,i] <= c) split to a given tree in the forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be split + leaf_num : :obj:`int` + Leaf to be split + feature_num : :obj:`int` + Feature that defines the new split + split_threshold : :obj:`float` + Value that defines the cutoff of the new split + left_leaf_value : :obj:`float` or :obj:`np.array` + Value (or array of values) to assign to the newly created left node + right_leaf_value : :obj:`float` or :obj:`np.array` + Value (or array of values) to assign to the newly created right node + """ + if isinstance(left_leaf_value, np.ndarray): + left_leaf_value = np.squeeze(left_leaf_value) + right_leaf_value = np.squeeze(right_leaf_value) + self.forest_cpp.AddNumericSplitVector(tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + else: + self.forest_cpp.AddNumericSplitValue(tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + + def get_tree_leaves(self, tree_num: int) -> np.array: + """ + Retrieve a vector of indices of leaf nodes for a given tree in the forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree for which leaf indices will be retrieved + """ + return self.forest_cpp.GetTreeLeaves(tree_num) + + def get_tree_split_counts(self, tree_num: int, num_features: int) -> np.array: + """ + Retrieve a vector of split counts for every training set variable in a given tree in the forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree for which split counts will be retrieved + num_features : :obj:`int` + Total number of features in the training set + """ + return self.forest_cpp.GetTreeSplitCounts(tree_num, num_features) + + def get_overall_split_counts(self, num_features: int) -> np.array: + """ + Retrieve a vector of split counts for every training set variable in the forest + + Parameters + ---------- + num_features : :obj:`int` + Total number of features in the training set + """ + return self.forest_cpp.GetOverallSplitCounts(num_features) + + def get_granular_split_counts(self, num_features: int) -> np.array: + """ + Retrieve a vector of split counts for every training set variable in the forest, reported separately for each tree + + Parameters + ---------- + num_features : :obj:`int` + Total number of features in the training set + """ + return self.forest_cpp.GetGranularSplitCounts(num_features) + + def num_forest_leaves(self) -> int: + """ + Return the total number of leaves in a forest + + Returns + ------- + :obj:`int` + Number of leaves in a forest + """ + return self.forest_cpp.NumLeavesForest() + + def sum_leaves_squared(self) -> float: + """ + Return the total sum of squared leaf values in a forest + + Returns + ------- + :obj:`float` + Sum of squared leaf values in a forest + """ + return self.forest_cpp.SumLeafSquared() + + def is_leaf_node(self, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree of a forest is a leaf + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.IsLeafNode(tree_num, node_id) + + def is_numeric_split_node(self, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree of a forest is a numeric split node + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.IsNumericSplitNode(tree_num, node_id)
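The node accessors on ``Forest`` (``is_leaf_node``, ``is_numeric_split_node``, and the parent/child lookups defined just below) compose naturally into recursive tree walks. The following is a minimal sketch, not part of stochtree, and it assumes the root of each tree has node id 0:

def print_tree(forest, tree_num: int, node_id: int = 0, depth: int = 0) -> None:
    # Depth-first pretty-printer for one tree of an active Forest
    indent = "  " * depth
    if forest.is_leaf_node(tree_num, node_id):
        print(f"{indent}leaf {node_id}: {forest.node_leaf_values(tree_num, node_id)}")
    else:
        kind = "numeric" if forest.is_numeric_split_node(tree_num, node_id) else "categorical"
        print(f"{indent}{kind} split {node_id} on feature {forest.node_split_index(tree_num, node_id)}")
        print_tree(forest, tree_num, forest.left_child_node(tree_num, node_id), depth + 1)
        print_tree(forest, tree_num, forest.right_child_node(tree_num, node_id), depth + 1)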
+ + def is_categorical_split_node(self, tree_num: int, node_id: int) -> bool: + """ + Whether or not a given node of a given tree of a forest is a categorical split node + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.IsCategoricalSplitNode(tree_num, node_id) + + def parent_node(self, tree_num: int, node_id: int) -> int: + """ + Parent node of a given node of a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.ParentNode(tree_num, node_id) + + def left_child_node(self, tree_num: int, node_id: int) -> int: + """ + Left child node of a given node of a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.LeftChildNode(tree_num, node_id) + + def right_child_node(self, tree_num: int, node_id: int) -> int: + """ + Right child node of a given node of a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.RightChildNode(tree_num, node_id) + + def node_depth(self, tree_num: int, node_id: int) -> int: + """ + Depth of a given node of a given tree of a forest, with the root node having depth 0. + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.NodeDepth(tree_num, node_id) + + def node_split_index(self, tree_num: int, node_id: int) -> int: + """ + Split index of a given node of a given tree of a forest. + Returns ``-1`` if the node is a leaf. + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(tree_num, node_id): + return -1 + else: + return self.forest_cpp.SplitIndex(tree_num, node_id) + + def node_split_threshold(self, tree_num: int, node_id: int) -> float: + """ + Threshold that defines a numeric split for a given node of a given tree of a forest. + Returns ``np.Inf`` if the node is a leaf or a categorical split node. + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(tree_num, node_id) or self.is_categorical_split_node(tree_num, node_id): + return np.Inf + else: + return self.forest_cpp.SplitThreshold(tree_num, node_id) + + def node_split_categories(self, tree_num: int, node_id: int) -> np.array: + """ + Array of category indices that define a categorical split for a given node of a given tree of a forest. + Returns ``np.array([np.Inf])`` if the node is a leaf or a numeric split node. + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + if self.is_leaf_node(tree_num, node_id) or self.is_numeric_split_node(tree_num, node_id): + return np.array([np.Inf]) + else: + return self.forest_cpp.SplitCategories(tree_num, node_id) + + def node_leaf_values(self, tree_num: int, node_id: int) -> np.array: + """ + Leaf node value(s) for a given node of a given tree of a forest. + Values are stale if the node is a split node.
+ + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + node_id : :obj:`int` + Index of the node to be queried + """ + return self.forest_cpp.NodeLeafValues(tree_num, node_id) + + def num_nodes(self, tree_num: int) -> int: + """ + Number of nodes in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.NumNodes(tree_num) + + def num_leaves(self, tree_num: int) -> int: + """ + Number of leaves in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.NumLeaves(tree_num) + + def num_leaf_parents(self, tree_num: int) -> int: + """ + Number of leaf parents in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.NumLeafParents(tree_num) + + def num_split_nodes(self, tree_num: int) -> int: + """ + Number of split nodes in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.NumSplitNodes(tree_num) + + def nodes(self, tree_num: int) -> np.array: + """ + Array of node indices in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.Nodes(tree_num) + + def leaves(self, tree_num: int) -> np.array: + """ + Array of leaf indices in a given tree of a forest + + Parameters + ---------- + tree_num : :obj:`int` + Index of the tree to be queried + """ + return self.forest_cpp.Leaves(tree_num) diff --git a/stochtree/preprocessing.py b/stochtree/preprocessing.py index 9dea3845..d0f3f314 100644 --- a/stochtree/preprocessing.py +++ b/stochtree/preprocessing.py @@ -42,7 +42,9 @@ def _preprocess_bart_params(params: Optional[Dict[str, Any]] = None) -> Dict[str 'random_seed' : -1, 'keep_burnin' : False, 'keep_gfr' : False, - 'standardize': True + 'standardize': True, + 'num_chains' : 1, + 'keep_every' : 1 } if params: @@ -92,7 +94,9 @@ def _preprocess_bcf_params(params: Optional[Dict[str, Any]] = None) -> Dict[str, 'random_seed': -1, 'keep_burnin': False, 'keep_gfr': False, - 'standardize': True + 'standardize': True, + 'num_chains' : 1, + 'keep_every' : 1 } if params: diff --git a/stochtree/sampler.py b/stochtree/sampler.py index 9410e4e9..b384f258 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -3,7 +3,7 @@ """ import numpy as np from .data import Dataset, Residual -from .forest import ForestContainer +from .forest import ForestContainer, Forest from stochtree_cpp import RngCpp, ForestSamplerCpp, GlobalVarianceModelCpp, LeafVarianceModelCpp from typing import Union @@ -18,44 +18,111 @@ def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, nu # Initialize a ForestSamplerCpp object self.forest_sampler_cpp = ForestSamplerCpp(dataset.dataset_cpp, feature_types, num_trees, num_obs, alpha, beta, min_samples_leaf, max_depth) - def sample_one_iteration(self, forest_container: ForestContainer, dataset: Dataset, residual: Residual, rng: RNG, - feature_types: np.array, cutpoint_grid_size: int, leaf_model_scale_input: np.array, - variable_weights: np.array, a_forest: float, b_forest: float, global_variance: float, - leaf_model_int: int, gfr: bool, pre_initialized: bool): + def reconstitute_from_forest(self, forest: Forest, dataset: Dataset, residual: Residual, is_mean_model: bool) -> None: + """ + Re-initialize a forest sampler's tracking data structures from a specific ``Forest`` + + Parameters + ---------- + dataset : :obj:`Dataset` + Stochtree dataset object storing covariates
/ bases / weights + residual : :obj:`Residual` + Stochtree object storing continuously updated partial / full residual + forest : :obj:`Forest` + Stochtree object storing tree ensemble + is_mean_model : :obj:`bool` + Indicator of whether the model being updated is a conditional mean model (``True``) or a conditional variance model (``False``) + """ + self.forest_sampler_cpp.ReconstituteTrackerFromForest(forest.forest_cpp, dataset.dataset_cpp, residual.residual_cpp, is_mean_model) + + def sample_one_iteration(self, forest_container: ForestContainer, forest: Forest, dataset: Dataset, + residual: Residual, rng: RNG, feature_types: np.array, cutpoint_grid_size: int, + leaf_model_scale_input: np.array, variable_weights: np.array, a_forest: float, b_forest: float, + global_variance: float, leaf_model_int: int, keep_forest: bool, gfr: bool, pre_initialized: bool) -> None: """ Sample one iteration of a forest using the specified model and tree sampling algorithm + + Parameters + ---------- + forest_container : :obj:`ForestContainer` + Stochtree object storing tree ensembles + forest : :obj:`Forest` + Stochtree object storing the "active" forest being sampled + dataset : :obj:`Dataset` + Stochtree dataset object storing covariates / bases / weights + residual : :obj:`Residual` + Stochtree object storing continuously updated partial / full residual + rng : :obj:`RNG` + Stochtree object storing the C++ random number generator to be used by the sampling algorithm + feature_types : :obj:`np.array` + Array of integer-coded feature types (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + cutpoint_grid_size : :obj:`int` + Maximum size of a grid of available cutpoints (which thins the number of possible splits, particularly useful in the grow-from-root algorithm) + leaf_model_scale_input : :obj:`np.array` + Numpy array containing leaf model scale parameter (if the leaf model is univariate, this is essentially a scalar which is used as such in the C++ source, but stored as a numpy array) + variable_weights : :obj:`np.array` + Numpy array containing sampling probabilities for each feature + a_forest : :obj:`float` + Shape parameter for the inverse gamma outcome model for heteroskedasticity forest + b_forest : :obj:`float` + Scale parameter for the inverse gamma outcome model for heteroskedasticity forest + global_variance : :obj:`float` + Current value of the global error variance parameter + leaf_model_int : :obj:`int` + Integer encoding the leaf model type (0 = constant Gaussian leaf mean model, 1 = univariate Gaussian leaf regression mean model, 2 = multivariate Gaussian leaf regression mean model, 3 = univariate Inverse Gamma constant leaf variance model) + keep_forest : :obj:`bool` + Whether or not the resulting forest should be retained in ``forest_container`` or discarded (due to burnin or thinning for example) + gfr : :obj:`bool` + Whether or not the "grow-from-root" (GFR) sampler is run (if this is ``True`` and ``leaf_model_int=0`` this is equivalent to XBART, if this is ``False`` and ``leaf_model_int=0`` this is equivalent to the original BART) + pre_initialized : :obj:`bool` + Whether or not the forest being sampled has already been initialized """ - self.forest_sampler_cpp.SampleOneIteration(forest_container.forest_container_cpp, dataset.dataset_cpp, residual.residual_cpp, rng.rng_cpp, + self.forest_sampler_cpp.SampleOneIteration(forest_container.forest_container_cpp, forest.forest_cpp, dataset.dataset_cpp, residual.residual_cpp, rng.rng_cpp, feature_types, cutpoint_grid_size,
leaf_model_scale_input, variable_weights, - a_forest, b_forest, global_variance, leaf_model_int, gfr, pre_initialized) + a_forest, b_forest, global_variance, leaf_model_int, keep_forest, gfr, pre_initialized) - def prepare_for_sampler(self, dataset: Dataset, residual: Residual, forests: ForestContainer, leaf_model: int, initial_values: np.array): + def prepare_for_sampler(self, dataset: Dataset, residual: Residual, forest: Forest, leaf_model: int, initial_values: np.array) -> None: """ Initialize forest and tracking data structures with constant root values before running a sampler + Parameters + ---------- dataset : :obj:`Dataset` Stochtree dataset object storing covariates / bases / weights residual : :obj:`Residual` Stochtree object storing continuously updated partial / full residual - forests : :obj:`ForestContainer` - Stochtree object storing tree ensembles + forest : :obj:`Forest` + Stochtree object storing the "active" forest being sampled leaf_model : :obj:`int` Integer encoding the leaf model type initial_values : :obj:`np.array` Constant root node value(s) at which to initialize forest prediction (internally, it is divided by the number of trees and typically it is 0 for mean models and 1 for variance models). """ - self.forest_sampler_cpp.InitializeForestModel(dataset.dataset_cpp, residual.residual_cpp, forests.forest_container_cpp, leaf_model, initial_values) + self.forest_sampler_cpp.InitializeForestModel(dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp, leaf_model, initial_values) - def adjust_residual(self, dataset: Dataset, residual: Residual, forest_container: ForestContainer, requires_basis: bool, forest_num: int, add: bool) -> None: + def adjust_residual(self, dataset: Dataset, residual: Residual, forest: Forest, requires_basis: bool, add: bool) -> None: """ Method that "adjusts" the residual used for training tree ensembles by either adding or subtracting the prediction of each tree to the existing residual. This is typically run just once at the beginning of a forest sampling algorithm --- after trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. + + Parameters + ---------- + dataset : :obj:`Dataset` + Stochtree dataset object storing covariates / bases / weights + residual : :obj:`Residual` + Stochtree object storing continuously updated partial / full residual + forest : :obj:`Forest` + Stochtree object storing the "active" forest being sampled + requires_basis : :obj:`bool` + Whether or not the forest requires a basis dot product when predicting + add : :obj:`bool` + Whether the predictions of each tree are added (if ``add=True``) or subtracted (``add=False``) from the outcome to form the new residual """ - forest_container.forest_container_cpp.AdjustResidual(dataset.dataset_cpp, residual.residual_cpp, self.forest_sampler_cpp, requires_basis, forest_num, add) + forest.forest_cpp.AdjustResidual(dataset.dataset_cpp, residual.residual_cpp, self.forest_sampler_cpp, requires_basis, add) - def propagate_basis_update(self, dataset: Dataset, residual: Residual, forest_container: ForestContainer, forest_num: int) -> None: + def propagate_basis_update(self, dataset: Dataset, residual: Residual, forest: Forest) -> None: """ Propagates basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. 
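Taken together, ``set_root_leaves`` / ``prepare_for_sampler``, ``adjust_residual``, and ``sample_one_iteration`` form the hand-off between the "active" ``Forest`` and the ``ForestContainer`` of retained draws. A condensed usage sketch follows, adapted from the updated tests later in this diff; the constants are illustrative, and the ``Dataset.add_basis`` / ``Residual`` constructors are assumed to behave as in those tests:

import numpy as np
from stochtree import Dataset, Residual, Forest, ForestContainer, ForestSampler, RNG

# Simulated data with a univariate leaf-regression basis
rng_np = np.random.default_rng(1234)
n, p = 100, 5
X = rng_np.uniform(size=(n, p))
W = rng_np.uniform(size=(n, 1))
y = 10.0 * X[:, 0] + rng_np.normal(size=n)
resid = (y - np.mean(y)) / np.std(y)

num_trees = 50
dataset = Dataset()
dataset.add_covariates(X)
dataset.add_basis(W)
residual = Residual(resid)
feature_types = np.repeat(0, p).astype(int)  # all numeric features
forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, 0.95, 2.0, 1)
forest_container = ForestContainer(num_trees, W.shape[1], False, False)
active_forest = Forest(num_trees, W.shape[1], False, False)

# Root-initialize the active forest and fold its predictions into the residual
active_forest.set_root_leaves(np.squeeze(np.mean(resid)) / num_trees)
forest_sampler.adjust_residual(dataset, residual, active_forest, False, True)

# One GFR iteration; keep_forest=True copies the active forest into the container
leaf_scale = np.array([[1.0 / num_trees]])
var_weights = np.repeat(1.0 / p, p)
forest_sampler.sample_one_iteration(
    forest_container, active_forest, dataset, residual, RNG(1234), feature_types,
    100, leaf_scale, var_weights, 1.0, 1.0, 1.0, 1, True, True, True
)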
@@ -63,8 +130,17 @@ def propagate_basis_update(self, dataset: Dataset, residual: Residual, forest_co This is useful in cases where a basis (for e.g. leaf regression) is updated outside of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed and this should be reflected through to the residual before the next sampling loop is run. + + Parameters + ---------- + dataset : :obj:`Dataset` + Stochtree dataset object storing covariates / bases / weights + residual : :obj:`Residual` + Stochtree object storing continuously updated partial / full residual + forest : :obj:`Forest` + Stochtree object storing the "active" forest being sampled """ - self.forest_sampler_cpp.PropagateBasisUpdate(dataset.dataset_cpp, residual.residual_cpp, forest_container.forest_container_cpp, forest_num) + self.forest_sampler_cpp.PropagateBasisUpdate(dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp) def propagate_residual_update(self, residual: Residual) -> None: self.forest_sampler_cpp.PropagateResidualUpdate(residual.residual_cpp) @@ -87,8 +163,8 @@ def __init__(self) -> None: # Initialize a LeafVarianceModelCpp object self.variance_model_cpp = LeafVarianceModelCpp() - def sample_one_iteration(self, forest_container: ForestContainer, rng: RNG, a: float, b: float, sample_num: int) -> float: + def sample_one_iteration(self, forest: Forest, rng: RNG, a: float, b: float) -> float: """ Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly ``N(0, tau)``) """ - return self.variance_model_cpp.SampleOneIteration(forest_container.forest_container_cpp, rng.rng_cpp, a, b, sample_num) + return self.variance_model_cpp.SampleOneIteration(forest.forest_cpp, rng.rng_cpp, a, b) diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R new file mode 100644 index 00000000..d7a5f36f --- /dev/null +++ b/test/R/testthat/test-bart.R @@ -0,0 +1,131 @@ +test_that("MCMC BART", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 3 chains, no thinning + param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 1 chain, thinning + param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 3 chains, thinning + param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model 
<- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) +}) + +test_that("GFR BART", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 3 chains, no thinning + param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 1 chain, thinning + param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # 3 chains, thinning + param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # Check for error when more chains than GFR forests + param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) + + # Check for error when more chains than GFR forests + param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 10, num_mcmc = 10, + params = param_list) + ) +}) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index e65b620e..08833585 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -38,16 +38,17 @@ test_that("Residual updates correctly propagated after forest sampling step", { # Create forest sampler and forest container forest_model = createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) forest_samples = createForestContainer(num_trees, 1, F) + active_forest = createForest(num_trees, 1, F) # Initialize the leaves of each tree in the prognostic forest - forest_samples$set_root_leaves(0, mean(resid) / num_trees) - forest_samples$adjust_residual(forest_dataset, residual, forest_model, F, 0, F) - + active_forest$prepare_for_sampler(forest_dataset, residual, forest_model, 0, mean(resid)) + active_forest$adjust_residual(forest_dataset, residual, forest_model, F, F) + # Run the forest sampling algorithm for a single iteration forest_model$sample_one_iteration( - forest_dataset, residual, forest_samples, cpp_rng, feature_types, - 0, current_leaf_scale, variable_weights, a_forest, b_forest, - current_sigma2, 
cutpoint_grid_size, gfr = T, pre_initialized = T + forest_dataset, residual, forest_samples, active_forest, + cpp_rng, feature_types, 0, current_leaf_scale, variable_weights, a_forest, b_forest, + current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, pre_initialized = T ) # Get the current residual after running the sampler @@ -62,7 +63,7 @@ test_that("Residual updates correctly propagated after forest sampling step", { forest_dataset$update_basis(W_update) # Update residual to reflect adjusted basis - forest_model$propagate_basis_update(forest_dataset, residual, forest_samples, 0) + forest_model$propagate_basis_update(forest_dataset, residual, active_forest) # Get updated prediction from the tree ensemble updated_yhat = as.numeric(forest_samples$predict(forest_dataset)) diff --git a/test/python/test_json.py b/test/python/test_json.py index 262e4ffa..ee65e80c 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -1,6 +1,6 @@ import numpy as np from stochtree import ( - BARTModel, JSONSerializer, ForestContainer, Dataset, Residual, + BARTModel, JSONSerializer, ForestContainer, Forest, Dataset, Residual, RNG, ForestSampler, ForestContainer, GlobalVarianceModel ) @@ -66,7 +66,7 @@ def outcome_mean(X): # Predict from the deserialized forest container forest_dataset = Dataset() forest_dataset.add_covariates(X) - forest_preds_json_reload = forest_container.predict(forest_dataset)[:,bart_model.keep_indices] + forest_preds_json_reload = forest_container.predict(forest_dataset) forest_preds_json_reload = forest_preds_json_reload*bart_model.y_std + bart_model.y_bar # Check the predictions np.testing.assert_almost_equal(forest_preds_y_mcmc_cached, forest_preds_json_reload) @@ -133,6 +133,7 @@ def outcome_mean(X, W): # Forest samplers and temporary tracking data structures forest_container = ForestContainer(num_trees, W.shape[1], False, False) + active_forest = Forest(num_trees, W.shape[1], False, False) forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) cpp_rng = RNG(random_seed) global_var_model = GlobalVarianceModel() @@ -145,12 +146,12 @@ def outcome_mean(X, W): # Run "grow-from-root" sampler for i in range(num_warmstart): - forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False) + forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, True, False) global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) # Run MCMC sampler for i in range(num_warmstart, num_samples): - forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False) + forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False, False) global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) # Extract predictions from the sampler diff --git a/test/python/test_residual.py b/test/python/test_residual.py index d51e3379..c039a9cb 100644 --- a/test/python/test_residual.py +++ b/test/python/test_residual.py @@ -1,5 
+1,5 @@ import numpy as np -from stochtree import ForestContainer, Dataset, Residual, ForestSampler, RNG +from stochtree import ForestContainer, Forest, Dataset, Residual, ForestSampler, RNG class TestResidual: def test_basis_update(self): @@ -50,17 +50,18 @@ def test_basis_update(self): # Create forest sampler and forest container forest_sampler = ForestSampler(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) forest_container = ForestContainer(num_trees, 1, False, False) + active_forest = Forest(num_trees, 1, False, False) # Initialize the leaves of each tree in the prognostic forest init_root = np.squeeze(np.mean(resid)) / num_trees - forest_container.set_root_leaves(0, init_root) - forest_sampler.adjust_residual(forest_dataset, residual, forest_container, False, 0, True) + active_forest.set_root_leaves(init_root) + forest_sampler.adjust_residual(forest_dataset, residual, active_forest, False, True) # Run the forest sampling algorithm for a single iteration forest_sampler.sample_one_iteration( - forest_container, forest_dataset, residual, cpp_rng, feature_types, + forest_container, active_forest, forest_dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, current_leaf_scale, variable_weights, a_forest, b_forest, - current_sigma2, 1, True, True + current_sigma2, 1, True, True, True ) # Get the current residual after running the sampler @@ -75,7 +76,7 @@ def test_basis_update(self): forest_dataset.update_basis(W_update) # Update residual to reflect adjusted basis - forest_sampler.propagate_basis_update(forest_dataset, residual, forest_container, 0) + forest_sampler.propagate_basis_update(forest_dataset, residual, active_forest) # Get updated prediction from the tree ensemble updated_yhat = forest_container.predict(forest_dataset) diff --git a/tools/debug/forest_reset_debug.R b/tools/debug/forest_reset_debug.R new file mode 100644 index 00000000..b92dc139 --- /dev/null +++ b/tools/debug/forest_reset_debug.R @@ -0,0 +1,145 @@ +# Load libraries +library(stochtree) + +# Generate some data +seed <- 1234 +set.seed(seed) +n <- 1000 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +E_y <- 10*X[,1] +eps <- rnorm(n,0,1) +y <- E_y + eps +y_std <- (y-mean(y))/sd(y) + +# Prepare to run sampler +num_mcmc <- 100 +num_gfr <- 10 +num_burnin <- 0 +num_chains <- 4 +keep_every <- 5 +num_trees <- 100 +alpha <- 0.95 +beta <- 0.95 +min_samples_leaf <- 5 +max_depth <- 10 +cutpoint_grid_size <- 100 +variable_weights = rep(1/ncol(X), ncol(X)) +output_dimension <- 1 +is_leaf_constant <- T +leaf_model <- 0 +current_sigma2 <- 1. +current_leaf_scale <- as.matrix(1/num_trees) +a_forest <- 1. +b_forest <- 1. +a_global <- 0. +b_global <- 0. +a_leaf <- 3. 
+b_leaf <- 1./num_trees +forest_dataset <- createForestDataset(X) +outcome <- createOutcome(y_std) +rng <- createRNG(seed) +feature_types <- as.integer(rep(0,p)) +forest_model <- createForestModel(forest_dataset, feature_types, num_trees, nrow(X), alpha, beta, min_samples_leaf, max_depth) +forest_samples <- createForestContainer(num_trees, output_dimension, is_leaf_constant, FALSE) +active_forest <- createForest(num_trees, output_dimension, is_leaf_constant, FALSE) + +# Container of parameter samples +sample_sigma_global <- T +sample_sigma_leaf <- F +keep_burnin <- F +keep_gfr <- T +num_actual_mcmc_iter <- num_mcmc * keep_every * num_chains +num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter +# Sampler iterations per chain (burn-in plus thinned MCMC draws) +num_mcmc_samples <- num_burnin + num_mcmc * keep_every +num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin * num_chains, 0) + num_mcmc * num_chains +if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples) +if (sample_sigma_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) +sample_counter <- 0 + +# Initialize the forest model (ensemble of root-only trees) +init_root_value <- 0. +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, leaf_model, init_root_value) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, FALSE, FALSE) + +# Run GFR (warm start) if specified +if (num_gfr > 0){ + gfr_indices = 1:num_gfr + for (i in 1:num_gfr) { + keep_sample <- ifelse(keep_gfr, T, F) + if (keep_sample) sample_counter <- sample_counter + 1 + + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, + rng, feature_types, leaf_model, current_leaf_scale, variable_weights, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + ) + + if (sample_sigma_global) { + current_sigma2 <- sample_sigma2_one_iteration(outcome, forest_dataset, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + } + + if (sample_sigma_leaf) { + leaf_scale_double <- sample_tau_one_iteration(active_forest, rng, a_leaf, b_leaf) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + } + } +} + +# Run MCMC +if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + resetActiveForest(active_forest, forest_samples, forest_ind) + resetForestModel(forest_model, forest_dataset, forest_samples, forest_ind) + + # Run the MCMC sampler starting from the current active forest + for (i in 1:num_mcmc_samples) { + is_mcmc <- i > num_burnin + if (is_mcmc) { + mcmc_counter <- i - num_burnin + if (mcmc_counter %% keep_every == 0) keep_sample <- T + else keep_sample <- F + } else { + if (keep_burnin) keep_sample <- T + else keep_sample <- F + } + if (keep_sample) sample_counter <- sample_counter + 1 + + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, + rng, feature_types, leaf_model, current_leaf_scale, variable_weights, + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + ) + + if (sample_sigma_global) { + current_sigma2 <- sample_sigma2_one_iteration(outcome, forest_dataset, rng, a_global, b_global) + if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + } + if (sample_sigma_leaf) {
leaf_scale_double <- sample_tau_one_iteration(active_forest, rng, a_leaf, b_leaf) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + } + } + } +} + +# Obtain predictions for all of the warmstarted MCMC samples +yhat <- forest_samples$predict(forest_dataset)*sd(y) + mean(y) + +# Plot results for each chain +plot_chain <- function(y, yhat, i) { + yhat_mean_chain <- rowMeans(yhat[,(num_gfr + (i-1)*num_mcmc + 1):(num_gfr + i*num_mcmc)]) + plot(yhat_mean_chain, y, main = paste0("Chain ", i, " MCMC Samples"), xlab = "yhat") + abline(0,1,col="red",lty=3,lwd=3) +} +yhat_mean_gfr <- rowMeans(yhat[,1:num_gfr]) +plot(yhat_mean_gfr, y, main = "GFR Samples", xlab = "yhat") +abline(0,1,col="red",lty=3,lwd=3) +for (i in 1:num_chains) { + plot_chain(y, yhat, i) +} diff --git a/tools/debug/parallel_warmstart.R b/tools/debug/parallel_warmstart.R new file mode 100644 index 00000000..18b73574 --- /dev/null +++ b/tools/debug/parallel_warmstart.R @@ -0,0 +1,115 @@ +# Load libraries +library(stochtree) +library(foreach) +library(doParallel) + +# Sampler settings +num_chains <- 6 +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 20 +num_trees <- 100 + +# Generate the data +n <- 500 +p_x <- 20 +snr <- 2 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) + cos(4*pi*X[,4]) +noise_sd <- sd(f_XW) / snr +y <- f_XW + rnorm(n, 0, 1)*noise_sd + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Run the GFR algorithm +xbart_params <- list(sample_sigma_global = T, + num_trees_mean = num_trees, alpha_mean = 0.99, + beta_mean = 1, max_depth_mean = -1, + min_samples_leaf_mean = 1, sample_sigma_leaf = F, + sigma_leaf_init = 1/num_trees) +xbart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params +) +plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1) +cat(sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n") +cat(mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) + +# Parallel setup +ncores <- parallel::detectCores() +cl <- makeCluster(ncores) +registerDoParallel(cl) + +# Run the parallel BART MCMC samplers +bart_model_outputs <- foreach (i = 1:num_chains) %dopar% { + random_seed <- i + bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, + num_trees_mean = num_trees, random_seed = random_seed, + alpha_mean = 0.999, beta_mean = 1) + bart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params, + previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1 + ) + bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) + y_hat_test <- bart_model$y_hat_test + list(model=bart_model_string, yhat=y_hat_test) +} + +# Close the cluster connection +stopCluster(cl) + +# Combine the forests +bart_model_strings <- list() +bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol
= num_chains) +for (i in 1:length(bart_model_outputs)) { + bart_model_strings[[i]] <- bart_model_outputs[[i]]$model + bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat) +} +combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) + +# Inspect the results +yhat_combined <- predict(combined_bart, X_test)$y_hat +par(mfrow = c(1,2)) +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i], + xlab = "deserialized", ylab = "original", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) +} +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, + xlab = "predicted", ylab = "actual", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) + cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") + cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") +} +par(mfrow = c(1,1)) + +# Compare to a single chain of MCMC samples initialized at root +bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, + num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2) +bart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bart_params +) +plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +cat(sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n") +cat(mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") diff --git a/tools/debug/parallel_warmstart_bcf.R b/tools/debug/parallel_warmstart_bcf.R new file mode 100644 index 00000000..9d002b32 --- /dev/null +++ b/tools/debug/parallel_warmstart_bcf.R @@ -0,0 +1,134 @@ +# Load libraries +library(stochtree) +library(foreach) +library(doParallel) + +# Sampler settings +num_chains <- 6 +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 20 +num_trees_mu <- 100 +num_trees_tau <- 20 + +# Generate the data +n <- 500 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- rnorm(n,x2,1) +X <- cbind(x1,x2,x3,x4) +p <- ncol(X) +mu <- function(x) {-1*(x[,1]>(x[,2])) + 1*(x[,1]<(x[,2])) - 0.1} +tau <- function(x) {1/(1 + exp(-x[,3])) + x[,2]/10} +mu_x <- mu(X) +tau_x <- tau(X) +pi_x <- pnorm(mu_x) +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +sigma <- diff(range(mu_x + tau_x*pi_x))/8 +y <- E_XZ + sigma*rnorm(n) +X <- as.data.frame(X) + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Run the GFR algorithm +xbcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, + alpha_mu = 0.95, beta_mu = 1, max_depth_mu = -1, + alpha_tau = 0.8, beta_tau
= 2, max_depth_tau = 10) +xbcf_model <- stochtree::bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, + X_test = X_test, Z_test = Z_test, pi_test = pi_test, num_gfr = num_gfr, + num_burnin = 0, num_mcmc = 0, params = xbcf_params +) +plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1) +cat(sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n") +cat(mean((apply(xbcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model) + +# Parallel setup +ncores <- parallel::detectCores() +cl <- makeCluster(ncores) +registerDoParallel(cl) + +# Run the parallel BCF MCMC samplers +bcf_model_outputs <- foreach (i = 1:num_chains) %dopar% { + random_seed <- i + bcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, + random_seed = random_seed) + bcf_model <- stochtree::bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, + X_test = X_test, Z_test = Z_test, pi_test = pi_test, + num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bcf_params, + previous_model_json = xbcf_model_string, warmstart_sample_num = num_gfr - i + 1 + ) + bcf_model_string <- stochtree::saveBCFModelToJsonString(bcf_model) + y_hat_test <- bcf_model$y_hat_test + list(model=bcf_model_string, yhat=y_hat_test) +} + +# Close the cluster connection +stopCluster(cl) + +# Combine the forests +bcf_model_strings <- list() +bcf_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) +for (i in 1:length(bcf_model_outputs)) { + bcf_model_strings[[i]] <- bcf_model_outputs[[i]]$model + bcf_model_yhats[,i] <- rowMeans(bcf_model_outputs[[i]]$yhat) +} +combined_bcf <- createBCFModelFromCombinedJsonString(bcf_model_strings) + +# Inspect the results +yhat_combined <- predict(combined_bcf, X_test)$y_hat +par(mfrow = c(1,2)) +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), bcf_model_yhats[,i], + xlab = "deserialized", ylab = "original", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) +} +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, + xlab = "predicted", ylab = "actual", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) + cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") + cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") +} +par(mfrow = c(1,1)) + +# Compare to a single chain of MCMC samples initialized at root +bcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau) +bcf_model <- stochtree::bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, + X_test = X_test, Z_test = Z_test, pi_test = pi_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bcf_params +) +plot(rowMeans(bcf_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +cat(sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n") +cat(mean((apply(bcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) &
(apply(bcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") diff --git a/tools/debug/rfx_debug.R b/tools/debug/rfx_debug.R deleted file mode 100644 index f4b734e2..00000000 --- a/tools/debug/rfx_debug.R +++ /dev/null @@ -1,156 +0,0 @@ -library(stochtree) - -# Generate the data -n <- 500 -p_X <- 10 -p_W <- 1 -X <- matrix(runif(n*p_X), ncol = p_X) -W <- matrix(runif(n*p_W), ncol = p_W) -group_ids <- rep(c(1,2), n %/% 2) -rfx_coefs <- c(-5, 5) -rfx_basis <- rep(1, n) -f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-3*W[,1]) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-1*W[,1]) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (1*W[,1]) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (3*W[,1]) -) -rfx_term <- rfx_coefs[group_ids] * rfx_basis -y <- f_XW + rfx_term + rnorm(n, 0, 1) - -# Standardize outcome -y_bar <- mean(y) -y_std <- sd(y) -resid <- (y-y_bar)/y_std - -alpha <- 0.9 -beta <- 1.25 -min_samples_leaf <- 1 -num_trees <- 100 -cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) -nu <- 4 -lambda <- 0.5 -a_leaf <- 2. -b_leaf <- 0.5 -leaf_regression <- T -feature_types <- as.integer(rep(0, p_X)) # 0 = numeric -var_weights <- rep(1/p_X, p_X) - -alpha_init <- c(1) -xi_init <- matrix(c(1,1),1,2) -sigma_alpha_init <- matrix(c(1),1,1) -sigma_xi_init <- matrix(c(1),1,1) -sigma_xi_shape <- 1 -sigma_xi_scale <- 1 - -# Data -if (leaf_regression) { - forest_dataset <- createForestDataset(X, W) - outcome_model_type <- 1 -} else { - forest_dataset <- createForestDataset(X) - outcome_model_type <- 0 -} -outcome <- createOutcome(resid) - -# Random number generator (std::mt19937) -rng <- createRNG() - -# Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, min_samples_leaf) - -# Container of forest samples -if (leaf_regression) { - forest_samples <- createForestContainer(num_trees, 1, F) -} else { - forest_samples <- createForestContainer(num_trees, 1, T) -} - -# Random effects dataset -rfx_basis <- as.matrix(rfx_basis) -group_ids <- as.integer(group_ids) -rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis) - -# Random effects details -num_groups <- length(unique(group_ids)) -num_components <- ncol(rfx_basis) - -# Random effects tracker -rfx_tracker <- createRandomEffectsTracker(group_ids) - -# Random effects model -rfx_model <- createRandomEffectsModel(num_components, num_groups) -rfx_model$set_working_parameter(alpha_init) -rfx_model$set_group_parameters(xi_init) -rfx_model$set_working_parameter_cov(sigma_alpha_init) -rfx_model$set_group_parameter_cov(sigma_xi_init) -rfx_model$set_variance_prior_shape(sigma_xi_shape) -rfx_model$set_variance_prior_scale(sigma_xi_scale) - -# Random effect samples -rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker) - -num_warmstart <- 10 -num_mcmc <- 100 -num_samples <- num_warmstart + num_mcmc -global_var_samples <- c(global_variance_init, rep(0, num_samples)) -leaf_scale_samples <- c(tau_init, rep(0, num_samples)) - -for (i in 1:num_warmstart) { - # Sample forest - forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = T - ) - - # Sample global variance parameter - global_var_samples[i+1] <- sample_sigma2_one_iteration( - outcome, forest_dataset, rng, nu, lambda - ) - - # Sample leaf node variance parameter and update `leaf_prior_scale` - leaf_scale_samples[i+1] 
<- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 - ) - leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] - - # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) -} - -for (i in (num_warmstart+1):num_samples) { - # Sample forest - forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = F - ) - - # Sample global variance parameter - global_var_samples[i+1] <- sample_sigma2_one_iteration( - outcome, forest_dataset, rng, nu, lambda - ) - - # Sample leaf node variance parameter and update `leaf_prior_scale` - leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 - ) - leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] - - # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) -} - -# Forest predictions -preds <- forest_samples$predict(forest_dataset)*y_std + y_bar - -# Random effects predictions -rfx_preds <- rfx_samples$predict(group_ids, rfx_basis)*y_std - -# Global error variance -sigma_samples <- sqrt(global_var_samples)*y_std diff --git a/tools/debug/supervised_learning_task_analysis.R b/tools/debug/supervised_learning_task_analysis.R deleted file mode 100644 index b71d2a85..00000000 --- a/tools/debug/supervised_learning_task_analysis.R +++ /dev/null @@ -1,439 +0,0 @@ -################################################################################ -## Run stochastic tree ensemble models on data for a supervised learning task -## and inspect their performance in terms of: -## (a) Run time (s) -## (b) Train set RMSE -## (c) Test set RMSE -################################################################################ - -# Setup -library(stochtree) -library(BART) -library(dbarts) -source("tools/debug/dgps.R") - -# Generate dataset -generate_data <- function(dgp_name, n, p_x, p_w = NULL, snr = NULL, test_set_pct = 0.2) { - # Dispatch the right DGP simulation function - if (dgp_name == "partitioned_linear_model") { - data_list <- dgp_prediction_partitioned_lm(n, p_x, p_w, snr) - } else if (dgp_name == "step_function") { - data_list <- dgp_prediction_step_function(n, p_x, snr) - } else { - stop(paste0("Invalid dgp_name: ", dgp_name)) - } - - # Unpack the data - has_basis <- data_list$has_basis - y <- data_list$y - X <- data_list$X - if (has_basis) { - W <- data_list$W - } else { - W <- NULL - } - snr <- data_list$snr - noise_sd <- data_list$noise_sd - - # Run test / train split - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - - # Split data into test and train sets - X_test <- X[test_inds,] - X_train <- X[train_inds,] - if (has_basis) { - W_test <- W[test_inds,] - W_train <- W[train_inds,] - } else { - W_test <- NULL - W_train <- NULL - } - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Standardize outcome separately for test and train - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - resid_train <- (y_train-y_bar_train)/y_std_train - y_bar_test <- mean(y_test) - y_std_test <- sd(y_test) - resid_test <- (y_test-y_bar_test)/y_std_test - - return(list( - resid_train = resid_train, resid_test = resid_test, - y_train = y_train, y_test = y_test, - X_train = X_train, X_test = X_test, - 
W_train = W_train, W_test = W_test, - y_bar_train = y_bar_train, y_bar_test = y_bar_test, - y_std_train = y_std_train, y_std_test = y_std_test, - snr = snr, noise_sd = noise_sd, n = n, - n_train = n_train, n_test = n_test - )) -} - -# Performance analysis functions for stochtree -stochtree_analysis <- function(resid_train, resid_test, y_train, y_test, - X_train, X_test, y_bar_train, y_bar_test, - y_std_train, y_std_test, n, n_train, n_test, - num_gfr, num_burnin, num_mcmc_retained, - W_train = NULL, W_test = NULL, random_seed = NULL) { - # Model parameters - ntree <- 200 - if (is.null(W_train)) { - leaf_regression <- F - } else { - leaf_regression <- T - } - p_x <- ncol(X_train) - tau_init <- var(y_train) / ntree - # tau_init <- 0.1 - param_list <- list( - alpha = 0.95, beta = 2, min_samples_leaf = 1, num_trees = ntree, - cutpoint_grid_size = 100, global_variance_init = 1.0, tau_init = tau_init, - leaf_prior_scale = matrix(c(tau_init), ncol = 1), nu = 16, lambda = 0.25, - a_leaf = 3., b_leaf = 0.5 * tau_init, leaf_regression = leaf_regression, - feature_types = as.integer(rep(0, p_x)), var_weights = rep(1/p_x, p_x) - ) - - # Package the data - data_list <- list( - resid_train = resid_train, resid_test = resid_test, - y_train = y_train, y_test = y_test, X_train = X_train, X_test = X_test, - y_bar_train = y_bar_train, y_bar_test = y_bar_test, y_std_train = y_std_train, - y_std_test = y_std_test, W_train = W_train, W_test = W_test - ) - - return(dispatch_stochtree_run(num_gfr, num_burnin, num_mcmc_retained, param_list, data_list, random_seed)) -} - -dispatch_stochtree_run <- function(num_gfr, num_burnin, num_mcmc_retained, param_list, data_list, random_seed = NULL) { - # Start timer - start_time <- proc.time() - - # Data - if (param_list$leaf_regression) { - forest_dataset_train <- createForestDataset(data_list$X_train, data_list$W_train) - forest_dataset_test <- createForestDataset(data_list$X_test, data_list$W_test) - outcome_model_type <- 1 - } else { - forest_dataset_train <- createForestDataset(data_list$X_train) - forest_dataset_test <- createForestDataset(data_list$X_test) - outcome_model_type <- 0 - } - outcome_train <- createOutcome(data_list$resid_train) - - # Random number generator (std::mt19937) - if (is.null(random_seed)) {random_seed = sample(1:10000,1,F)} - rng <- createRNG(random_seed) - - # Sampling data structures - forest_model <- createForestModel(forest_dataset_train, param_list$feature_types, param_list$num_trees, nrow(data_list$X_train), param_list$alpha, param_list$beta, param_list$min_samples_leaf) - - # Container of forest samples - if (param_list$leaf_regression) { - forest_samples <- createForestContainer(param_list$num_trees, 1, F) - } else { - forest_samples <- createForestContainer(param_list$num_trees, 1, T) - } - - # Container of variance parameter samples - num_samples <- num_gfr + num_burnin + num_mcmc_retained - global_var_samples <- c(param_list$global_variance_init, rep(0, num_samples)) - leaf_scale_samples <- c(param_list$tau_init, rep(0, num_samples)) - - # Run GFR (warm start) if specified - if (num_gfr > 0){ - for (i in 1:num_gfr) { - forest_model$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples, rng, param_list$feature_types, - outcome_model_type, param_list$leaf_prior_scale, param_list$var_weights, - global_var_samples[i], param_list$cutpoint_grid_size, gfr = T - ) - global_var_samples[i+1] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, param_list$nu, param_list$lambda) - leaf_scale_samples[i+1] 
<- sample_tau_one_iteration(forest_samples, rng, param_list$a_leaf, param_list$b_leaf, i-1) - param_list$leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] - } - } - - # Run MCMC - for (i in (num_gfr+1):num_samples) { - forest_model$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples, rng, param_list$feature_types, - outcome_model_type, param_list$leaf_prior_scale, param_list$var_weights, - global_var_samples[i], param_list$cutpoint_grid_size, gfr = F - ) - global_var_samples[i+1] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, param_list$nu, param_list$lambda) - } - - # Forest predictions - train_preds <- forest_samples$predict(forest_dataset_train)*data_list$y_std_train + data_list$y_bar_train - test_preds <- forest_samples$predict(forest_dataset_test)*data_list$y_std_test + data_list$y_bar_test - - # End timer and measure run time - end_time <- proc.time() - runtime <- end_time[3] - start_time[3] - - # Global error variance - sigma_samples <- sqrt(global_var_samples)*data_list$y_std_train - - # RMSEs (post-burnin) - train_rmse <- sqrt(mean((rowMeans(train_preds[,(num_gfr+num_burnin+1):num_samples]) - data_list$y_train)^2)) - test_rmse <- sqrt(mean((rowMeans(test_preds[,(num_gfr+num_burnin+1):num_samples]) - data_list$y_test)^2)) - - return(c(runtime,train_rmse,test_rmse)) -} - -# Performance analysis functions for stochtree -wrapped_bart_stochtree_analysis <- function(resid_train, resid_test, y_train, y_test, - X_train, X_test, y_bar_train, y_bar_test, - y_std_train, y_std_test, n, n_train, n_test, - num_gfr, num_burnin, num_mcmc_retained, - W_train = NULL, W_test = NULL, random_seed = NULL) { - # Start timer - start_time <- proc.time() - - # Run BART - bart_model <- stochtree::bart( - X_train = X_train, W_train = W_train, y_train = y_train, - X_test = X_test, W_test = W_test, num_trees = 200, num_gfr = num_gfr, - num_burnin = num_burnin, num_mcmc = num_mcmc_retained, sample_sigma = T, - sample_tau = F, random_seed = 1234 - # random_seed = 1234, nu = 16 - ) - - # End timer and measure run time - end_time <- proc.time() - runtime <- end_time[3] - start_time[3] - - # RMSEs - num_samples <- num_gfr + num_burnin + num_mcmc_retained - ypred_mean_train <- rowMeans(bart_model$y_hat_train[,(num_gfr+num_burnin+1):num_samples]) - ypred_mean_test <- rowMeans(bart_model$y_hat_test[,(num_gfr+num_burnin+1):num_samples]) - train_rmse <- sqrt(mean((ypred_mean_train - y_train)^2)) - test_rmse <- sqrt(mean((ypred_mean_test - y_test)^2)) - - return(c(runtime,train_rmse,test_rmse)) -} - -# Performance analysis functions for wbart -wbart_analysis <- function(resid_train, resid_test, y_train, y_test, X_train, X_test, - y_bar_train, y_bar_test, y_std_train, y_std_test, - n, n_train, n_test, num_burnin, num_mcmc_retained, - W_train = NULL, W_test = NULL, random_seed = NULL) { - # Start timer - start_time <- proc.time() - - # Run wbart from the BART (add W to X if W is present, since wbart doesn't support leaf regression) - ntree <- 200 - alpha <- 0.95 - beta <- 2.0 - if (!is.null(W_train)) {X_train <- cbind(X_train, W_train)} - if (!is.null(W_test)) {X_test <- cbind(X_test, W_test)} - bartFit = wbart(X_train,y_train,X_test,power=beta,base=alpha,ntree=ntree,nskip=num_burnin,ndpost=num_mcmc_retained) - - # End timer and measure run time - end_time <- proc.time() - runtime <- end_time[3] - start_time[3] - - # RMSEs - train_rmse <- sqrt(mean((bartFit$yhat.train.mean - y_train)^2)) - test_rmse <- sqrt(mean((bartFit$yhat.test.mean - y_test)^2)) - - 
return(c(runtime,train_rmse,test_rmse)) -} - -# Run the code -# DGP 1 - Run 1 -dgp_name <- "partitioned_linear_model" -plm_data <- generate_data(dgp_name, n = 1000, p_x = 10, p_w = 1, snr = NULL, test_set_pct = 0.2) -warmstart_stochtree_results <- stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -mcmc_stochtree_results <- stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_warmstart_stochtree_results <- wrapped_bart_stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_mcmc_stochtree_results <- wrapped_bart_stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -mcmc_wbart_results <- wbart_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, - plm_data$n_test, num_burnin = 2000, num_mcmc_retained = 2000, - W_train = plm_data$W_train, W_test = plm_data$W_test, random_seed = NULL -) -results_dgp1a <- rbind(warmstart_stochtree_results, mcmc_stochtree_results, wrapped_bart_warmstart_stochtree_results, wrapped_bart_mcmc_stochtree_results, mcmc_wbart_results) -results_dgp1a <- cbind(results_dgp1a, plm_data$snr, dgp_name, - c("stochtree_warm_start", "stochtree_mcmc", "bart_stochtree_warm_start", "bart_stochtree_mcmc", "wbart_mcmc")) -cat("DGP 1 out of 2 - Run 1 out of 2\n") - -# DGP 1 - Run 2 -plm_data <- generate_data(dgp_name, n = 1000, p_x = 10, p_w = 1, snr = NULL, test_set_pct = 0.2) -warmstart_stochtree_results <- stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -mcmc_stochtree_results <- stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - 
plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_warmstart_stochtree_results <- wrapped_bart_stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_mcmc_stochtree_results <- wrapped_bart_stochtree_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, plm_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = plm_data$W_train, - W_test = plm_data$W_test, random_seed = NULL -) -gc() -mcmc_wbart_results <- wbart_analysis( - plm_data$resid_train, plm_data$resid_test, plm_data$y_train, plm_data$y_test, - plm_data$X_train, plm_data$X_test, plm_data$y_bar_train, plm_data$y_bar_test, - plm_data$y_std_train, plm_data$y_std_test, plm_data$n, plm_data$n_train, - plm_data$n_test, num_burnin = 2000, num_mcmc_retained = 2000, - W_train = plm_data$W_train, W_test = plm_data$W_test, random_seed = NULL -) -results_dgp1b <- rbind(warmstart_stochtree_results, mcmc_stochtree_results, wrapped_bart_warmstart_stochtree_results, wrapped_bart_mcmc_stochtree_results, mcmc_wbart_results) -results_dgp1b <- cbind(results_dgp1b, plm_data$snr, dgp_name, - c("stochtree_warm_start", "stochtree_mcmc", "bart_stochtree_warm_start", "bart_stochtree_mcmc", "wbart_mcmc")) -cat("DGP 1 out of 2 - Run 2 out of 2\n") - -# DGP 2 - Run 1 -dgp_name <- "step_function" -stpfn_data <- generate_data(dgp_name, n = 1000, p_x = 10, p_w = NULL, snr = NULL, test_set_pct = 0.2) -warmstart_stochtree_results <- stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -mcmc_stochtree_results <- stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_warmstart_stochtree_results <- wrapped_bart_stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 10, num_burnin = 0, 
num_mcmc_retained = 100, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_mcmc_stochtree_results <- wrapped_bart_stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -mcmc_wbart_results <- wbart_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, - stpfn_data$n_test, num_burnin = 2000, num_mcmc_retained = 2000, - W_train = stpfn_data$W_train, W_test = stpfn_data$W_test, random_seed = NULL -) -results_dgp2a <- rbind(warmstart_stochtree_results, mcmc_stochtree_results, wrapped_bart_warmstart_stochtree_results, wrapped_bart_mcmc_stochtree_results, mcmc_wbart_results) -results_dgp2a <- cbind(results_dgp2a, stpfn_data$snr, dgp_name, - c("stochtree_warm_start", "stochtree_mcmc", "bart_stochtree_warm_start", "bart_stochtree_mcmc", "wbart_mcmc")) -cat("DGP 2 out of 2 - Run 1 out of 2\n") - -# DGP 2 - Run 2 -stpfn_data <- generate_data(dgp_name, n = 1000, p_x = 10, p_w = NULL, snr = NULL, test_set_pct = 0.2) -warmstart_stochtree_results <- stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -mcmc_stochtree_results <- stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_warmstart_stochtree_results <- wrapped_bart_stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 10, num_burnin = 0, num_mcmc_retained = 100, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -wrapped_bart_mcmc_stochtree_results <- wrapped_bart_stochtree_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, stpfn_data$n_test, - num_gfr = 0, num_burnin = 2000, num_mcmc_retained = 2000, W_train = stpfn_data$W_train, - W_test = stpfn_data$W_test, random_seed = NULL -) -gc() -mcmc_wbart_results <- 
wbart_analysis( - stpfn_data$resid_train, stpfn_data$resid_test, stpfn_data$y_train, stpfn_data$y_test, - stpfn_data$X_train, stpfn_data$X_test, stpfn_data$y_bar_train, stpfn_data$y_bar_test, - stpfn_data$y_std_train, stpfn_data$y_std_test, stpfn_data$n, stpfn_data$n_train, - stpfn_data$n_test, num_burnin = 2000, num_mcmc_retained = 2000, - W_train = stpfn_data$W_train, W_test = stpfn_data$W_test, random_seed = NULL -) -results_dgp2b <- rbind(warmstart_stochtree_results, mcmc_stochtree_results, wrapped_bart_warmstart_stochtree_results, wrapped_bart_mcmc_stochtree_results, mcmc_wbart_results) -results_dgp2b <- cbind(results_dgp2b, stpfn_data$snr, dgp_name, - c("stochtree_warm_start", "stochtree_mcmc", "bart_stochtree_warm_start", "bart_stochtree_mcmc", "wbart_mcmc")) -cat("DGP 2 out of 2 - Run 2 out of 2\n") - -results_df <- data.frame(rbind(results_dgp1a, results_dgp1b, results_dgp2a, results_dgp2b)) -colnames(results_df) <- c("runtime", "train_rmse", "test_rmse", "snr", "dgp", "model_type") -rownames(results_df) <- 1:nrow(results_df) -results_df <- results_df[,c("dgp", "model_type", "snr", "runtime", "train_rmse", "test_rmse")] -results_df diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd index 74291196..fb461708 100644 --- a/vignettes/CausalInference.Rmd +++ b/vignettes/CausalInference.Rmd @@ -112,7 +112,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -159,7 +160,8 @@ num_gfr <- 0 num_burnin <- 1000 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_root <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -275,7 +277,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -322,7 +325,8 @@ num_gfr <- 0 num_burnin <- 100 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_root <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -438,7 +442,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -485,7 +490,8 @@ num_gfr <- 0 
num_burnin <- 100 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_root <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -599,7 +605,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -646,7 +653,8 @@ num_gfr <- 0 num_burnin <- 100 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_root <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -753,7 +761,8 @@ num_gfr <- 100 num_burnin <- 0 num_mcmc <- 500 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, @@ -867,7 +876,8 @@ num_gfr <- 0 num_burnin <- 1000 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_mcmc <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -926,7 +936,8 @@ num_gfr <- 0 num_burnin <- 1000 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, keep_vars_tau = c("x1","x2")) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_vars_tau = c("x1","x2"), keep_every = 5) bcf_model_mcmc <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -985,7 +996,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -1044,7 +1056,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, keep_vars_tau = c("x1","x2")) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_vars_tau = c("x1","x2"), keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, 
Z_test = Z_test, pi_test = pi_test, @@ -1169,7 +1182,8 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_warmstart <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, @@ -1216,7 +1230,8 @@ num_gfr <- 0 num_burnin <- 1000 num_mcmc <- 1000 num_samples <- num_gfr + num_burnin + num_mcmc -bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, + keep_every = 5) bcf_model_root <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index f78659ee..5d6acae9 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -2,7 +2,7 @@ title: "Custom Sampling Routines in StochTree" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Prototype-Interface} + %\VignetteIndexEntry{Custom-Sampling-Routine} %\VignetteEncoding{UTF-8} %\VignetteEngine{knitr::rmarkdown} bibliography: vignettes.bib @@ -135,11 +135,15 @@ forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) -# Container of forest samples +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) if (leaf_regression) { forest_samples <- createForestContainer(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) } else { forest_samples <- createForestContainer(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) } ``` @@ -159,9 +163,9 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -171,7 +175,7 @@ for (i in 1:num_warmstart) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] } @@ -184,9 +188,9 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -196,7 +200,7 @@ for (i in (num_warmstart+1):num_samples) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, 
b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] } @@ -320,11 +324,15 @@ forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) -# Container of forest samples +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) if (leaf_regression) { forest_samples <- createForestContainer(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) } else { forest_samples <- createForestContainer(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) } # Random effects dataset @@ -368,9 +376,9 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -380,12 +388,13 @@ for (i in 1:num_warmstart) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) + rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, + TRUE, global_var_samples[i+1], rng) } ``` @@ -396,9 +405,9 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -408,12 +417,13 @@ for (i in (num_warmstart+1):num_samples) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) + rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, + TRUE, global_var_samples[i+1], rng) } ``` @@ -549,11 +559,15 @@ forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) -# Container of forest samples +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) if (leaf_regression) { forest_samples <- createForestContainer(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) } else { forest_samples <- createForestContainer(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) } # Random effects dataset @@ -597,9 +611,9 @@ Run the 
grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -609,12 +623,13 @@ for (i in 1:num_warmstart) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) + rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, + TRUE, global_var_samples[i+1], rng) } ``` @@ -625,9 +640,9 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -637,12 +652,13 @@ for (i in (num_warmstart+1):num_samples) { # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sample_tau_one_iteration( - forest_samples, rng, a_leaf, b_leaf, i-1 + active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] # Sample random effects model - rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng) + rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, + TRUE, global_var_samples[i+1], rng) } ``` @@ -777,11 +793,15 @@ forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha_bart, beta_bart, min_samples_leaf, max_depth) -# Container of forest samples +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) if (leaf_regression) { forest_samples <- createForestContainer(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) } else { forest_samples <- createForestContainer(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) } ``` @@ -822,9 +842,9 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, sigma2, cutpoint_grid_size, gfr = T + 1, 1, sigma2, cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -860,9 +880,9 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, rng, feature_types, + forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, 
gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -1126,19 +1146,17 @@ forest_model_tau <- createForestModel( # Container of forest samples forest_samples_mu <- createForestContainer(num_trees_mu, 1, T) +active_forest_mu <- createForest(num_trees_mu, 1, T) forest_samples_tau <- createForestContainer(num_trees_tau, 1, F) +active_forest_tau <- createForest(num_trees_tau, 1, F) # Initialize the leaves of each tree in the prognostic forest -forest_samples_mu$set_root_leaves(0, mean(resid) / num_trees_mu) -forest_samples_mu$adjust_residual( - forest_dataset_mu, outcome, forest_model_mu, F, 0, F -) +active_forest_mu$prepare_for_sampler(forest_dataset_mu, outcome, forest_model_mu, 0, mean(resid)) +active_forest_mu$adjust_residual(forest_dataset_mu, outcome, forest_model_mu, F, F) # Initialize the leaves of each tree in the treatment effect forest -forest_samples_tau$set_root_leaves(0, 0.) -forest_samples_tau$adjust_residual( - forest_dataset_tau, outcome, forest_model_tau, T, 0, F -) +active_forest_tau$prepare_for_sampler(forest_dataset_tau, outcome, forest_model_tau, 1, 0.) +active_forest_tau$adjust_residual(forest_dataset_tau, outcome, forest_model_tau, T, F) ``` Run the grow-from-root sampler to "warm-start" BART, also updating the adaptive coding parameter $b_0$ and $b_1$ @@ -1148,9 +1166,10 @@ if (num_gfr > 0){ for (i in 1:num_gfr) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_mu, outcome, forest_samples_mu, rng, + forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, - 1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, + pre_initialized = T ) # Sample variance parameters (if requested) @@ -1161,14 +1180,15 @@ if (num_gfr > 0){ # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_tau, outcome, forest_samples_tau, rng, + forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, - 1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, + pre_initialized = T ) # Sample adaptive coding parameters - mu_x_raw <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu, i-1) - tau_x_raw <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau, i-1) + mu_x_raw <- active_forest_mu$predict_raw(forest_dataset_mu) + tau_x_raw <- active_forest_tau$predict_raw(forest_dataset_tau) s_tt0 <- sum(tau_x_raw*tau_x_raw*(Z==0)) s_tt1 <- sum(tau_x_raw*tau_x_raw*(Z==1)) partial_resid_mu <- resid - mu_x_raw @@ -1180,6 +1200,7 @@ if (num_gfr > 0){ sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) tau_basis <- (1-Z)*current_b_0 + Z*current_b_1 forest_dataset_tau$update_basis(tau_basis) + forest_model_tau$propagate_basis_update(forest_dataset_tau, outcome, active_forest_tau) b_0_samples[i] <- current_b_0 b_1_samples[i] <- current_b_1 @@ -1197,9 +1218,9 @@ if (num_burnin + num_mcmc > 0) { for (i in (num_gfr+1):num_samples) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_mu, outcome, forest_samples_mu, rng, feature_types_mu, + forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2, - 
cutpoint_grid_size, gfr = F, pre_initialized = T + cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T ) # Sample global variance parameter @@ -1208,14 +1229,14 @@ if (num_burnin + num_mcmc > 0) { # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_tau, outcome, forest_samples_tau, rng, feature_types_tau, + forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2, - cutpoint_grid_size, gfr = F, pre_initialized = T + cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T ) # Sample coding parameters - mu_x_raw <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu, i-1) - tau_x_raw <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau, i-1) + mu_x_raw <- active_forest_mu$predict_raw(forest_dataset_mu) + tau_x_raw <- active_forest_tau$predict_raw(forest_dataset_tau) s_tt0 <- sum(tau_x_raw*tau_x_raw*(Z==0)) s_tt1 <- sum(tau_x_raw*tau_x_raw*(Z==1)) partial_resid_mu <- resid - mu_x_raw @@ -1227,6 +1248,7 @@ if (num_burnin + num_mcmc > 0) { sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) tau_basis <- (1-Z)*current_b_0 + Z*current_b_1 forest_dataset_tau$update_basis(tau_basis) + forest_model_tau$propagate_basis_update(forest_dataset_tau, outcome, active_forest_tau) b_0_samples[i] <- current_b_0 b_1_samples[i] <- current_b_1 @@ -1268,7 +1290,7 @@ mean((rowMeans(tau_hat[,1:num_gfr]) - tau_x)^2) Inspect the warm start BART results ```{r bcf_warm_start_plot} -plot(sigma_samples[(num_gfr+1):num_samples], ylab="sigma^2") +plot(sigma2_samples[(num_gfr+1):num_samples], ylab="sigma^2") plot(rowMeans(mu_hat[,(num_gfr+1):num_samples]), mu_x, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "prognostic term") abline(0,1,col="red",lty=2,lwd=2.5) diff --git a/vignettes/Heteroskedasticity.Rmd b/vignettes/Heteroskedasticity.Rmd index 8195edca..d29008b0 100644 --- a/vignettes/Heteroskedasticity.Rmd +++ b/vignettes/Heteroskedasticity.Rmd @@ -2,7 +2,7 @@ title: "Bayesian Supervised Learning with Heteroskedasticity in StochTree" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Bayesian-Supervised-Learning} + %\VignetteIndexEntry{Heteroskedasticity} %\VignetteEncoding{UTF-8} %\VignetteEngine{knitr::rmarkdown} bibliography: vignettes.bib diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd index bc5adc2e..c9d69b50 100644 --- a/vignettes/MultiChain.Rmd +++ b/vignettes/MultiChain.Rmd @@ -2,7 +2,7 @@ title: "Running Multiple Chains (Sequentially or in Parallel) in StochTree" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Prototype-Interface} + %\VignetteIndexEntry{Multiple-Chains} %\VignetteEncoding{UTF-8} %\VignetteEngine{knitr::rmarkdown} bibliography: vignettes.bib @@ -246,4 +246,129 @@ for (i in 1:num_chains) { par(mfrow = c(1,1)) ``` +## Warmstarting Multiple Chains in Parallel + +In the above example, we ran multiple parallel chains with each MCMC sampler +starting from a "root" forest. Consider instead the "warmstart" approach +of @he2023stochastic, where forests are sampled using the fast "grow-from-root" (GFR) +algorithm and then several MCMC chains are run using different GFR forests. + +We use the same high-level parameters as in the parallel demo. 
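+
+At a high level, this works in three steps: run `num_gfr` iterations of the GFR
+sampler once in the main session, serialize that model to an in-memory JSON string,
+and then launch `num_chains` MCMC runs in parallel, each warm-started from a
+different GFR forest via the `previous_model_json` and `warmstart_sample_num`
+arguments (all three steps are shown below).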
+
+```{r}
+num_chains <- 4
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+num_trees <- 100
+```
+
+First, we sample a BART model using the grow-from-root algorithm in the main R session
+for several iterations (we will use these forests to seed independent parallel chains in a moment).
+
+```{r}
+xbart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T,
+                     num_trees_mean = num_trees)
+xbart_model <- stochtree::bart(
+  X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test,
+  num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params
+)
+xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)
+```
+
+In order to run this sampler in parallel, a parallel backend must be registered in your R environment.
+The code below will register a parallel backend with access to as many cores as are available on your machine.
+Note that we do not **evaluate** the code snippet below in order to interact nicely with CRAN / GitHub Actions environments.
+
+```{r, eval=FALSE}
+ncores <- parallel::detectCores()
+cl <- makeCluster(ncores)
+registerDoParallel(cl)
+```
+
+Note that the `bartmodel` object contains external pointers to forests created by
+the `stochtree` shared object, and when `stochtree::bart()` is run in parallel
+on independent subprocesses, these pointers are not generally accessible in the
+session that kicked off the parallel run.
+
+To overcome this, each subprocess can serialize its `bartmodel` to an in-memory JSON
+string, and those strings can then be combined into a single in-memory `bartmodel`
+object in the main session.
+
+The first step of this process is to run the sampler in parallel,
+storing the resulting BART JSON strings in a list.
+
+```{r}
+bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
+  random_seed <- i
+  bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T,
+                      num_trees_mean = num_trees, random_seed = random_seed)
+  bart_model <- stochtree::bart(
+    X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test,
+    num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params,
+    previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1
+  )
+  bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model)
+  y_hat_test <- bart_model$y_hat_test
+  list(model = bart_model_string, yhat = y_hat_test)
+}
+```
+
+Close the parallel cluster (not evaluated here, as explained above).
+
+```{r, eval=FALSE}
+stopCluster(cl)
+```
+
+Now, if we want to combine the forests from each of these BART models into a
+single model, we can do so as follows
+
+```{r}
+bart_model_strings <- list()
+bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains)
+for (i in seq_along(bart_model_outputs)) {
+  bart_model_strings[[i]] <- bart_model_outputs[[i]]$model
+  bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat)
+}
+combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings)
+```
+
+We can predict from this combined model as follows
+
+```{r}
+yhat_combined <- predict(combined_bart, X_test, W_test)$y_hat
+```
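+
+The offset arithmetic used below relies on `createBARTModelFromCombinedJsonString()`
+stacking each chain's `num_mcmc` draws in the order the JSON strings were supplied,
+so chain `i` occupies columns `((i-1)*num_mcmc + 1):(i*num_mcmc)` of `yhat_combined`.
+As a quick structural check (a minimal sketch, assuming the parallel runs retained
+only MCMC samples, i.e. `num_gfr = 0` in each subprocess):
+
+```{r}
+# The combined model should hold num_chains * num_mcmc draws in total
+ncol(yhat_combined) == num_chains * num_mcmc
+```
+
+Compare average predictions from each chain to the original predictions.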
+ +```{r} +par(mfrow = c(1,2)) +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i], + xlab = "deserialized", ylab = "original", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) +} +par(mfrow = c(1,1)) +``` + +And to the true $y$ values. + +```{r} +par(mfrow = c(1,2)) +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, + xlab = "predicted", ylab = "actual", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) +} +par(mfrow = c(1,1)) +``` + # References diff --git a/vignettes/PriorCalibration.Rmd b/vignettes/PriorCalibration.Rmd index 604ddab6..5e30202b 100644 --- a/vignettes/PriorCalibration.Rmd +++ b/vignettes/PriorCalibration.Rmd @@ -2,7 +2,7 @@ title: "Prior Calibration Approaches for Parametric Components of Stochastic Tree Ensembles" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Ensemble-Kernel} + %\VignetteIndexEntry{Prior-Calibration} %\VignetteEncoding{UTF-8} %\VignetteEngine{knitr::rmarkdown} bibliography: vignettes.bib diff --git a/vignettes/TreeInspection.Rmd b/vignettes/TreeInspection.Rmd index c06a3332..fd7b41d6 100644 --- a/vignettes/TreeInspection.Rmd +++ b/vignettes/TreeInspection.Rmd @@ -2,7 +2,7 @@ title: "Deeper Dive on Sampled Forests in StochTree" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Bayesian-Supervised-Learning} + %\VignetteIndexEntry{Tree-Inspection} %\VignetteEncoding{UTF-8} %\VignetteEngine{knitr::rmarkdown} bibliography: vignettes.bib