From 5e92a2a26b779f8c7d238261a9a4e77f861d73dc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 3 Aug 2024 00:45:55 -0500 Subject: [PATCH 01/41] Added functions to convert BCF model to JSON string --- NAMESPACE | 3 + R/bcf.R | 290 ++++++++++++++++++++++++++++ R/cpp11.R | 16 +- R/serialization.R | 32 ++- _pkgdown.yml | 2 + man/CppJson.Rd | 35 ++++ man/createBCFModelFromJson.Rd | 66 +++++++ man/createBCFModelFromJsonString.Rd | 80 ++++++++ man/createCppJsonString.Rd | 17 ++ man/saveBCFModelToJsonString.Rd | 77 ++++++++ src/cpp11.cpp | 33 +++- src/serialization.cpp | 14 +- 12 files changed, 649 insertions(+), 16 deletions(-) create mode 100644 man/createBCFModelFromJsonString.Rd create mode 100644 man/createCppJsonString.Rd create mode 100644 man/saveBCFModelToJsonString.Rd diff --git a/NAMESPACE b/NAMESPACE index ab87b7b9..03e83f44 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,8 +11,10 @@ export(computeForestLeafIndices) export(convertBCFModelToJson) export(createBCFModelFromJson) export(createBCFModelFromJsonFile) +export(createBCFModelFromJsonString) export(createCppJson) export(createCppJsonFile) +export(createCppJsonString) export(createForestContainer) export(createForestCovariates) export(createForestCovariatesFromMetadata) @@ -43,6 +45,7 @@ export(preprocessTrainMatrix) export(sample_sigma2_one_iteration) export(sample_tau_one_iteration) export(saveBCFModelToJsonFile) +export(saveBCFModelToJsonString) importFrom(R6,R6Class) importFrom(stats,lm) importFrom(stats,model.matrix) diff --git a/R/bcf.R b/R/bcf.R index 8e2287e2..1d554429 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1322,6 +1322,79 @@ saveBCFModelToJsonFile <- function(object, filename){ jsonobj$save_file(filename) } +#' Convert the persistent aspects of a BCF model to (in-memory) JSON string +#' +#' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs. 
+#' @return JSON string +#' @export +#' +#' @examples +#' n <- 500 +#' x1 <- rnorm(n) +#' x2 <- rnorm(n) +#' x3 <- rnorm(n) +#' x4 <- as.numeric(rbinom(n,1,0.5)) +#' x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +#' X <- cbind(x1,x2,x3,x4,x5) +#' p <- ncol(X) +#' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +#' tau1 <- function(x) {rep(3,nrow(x))} +#' tau2 <- function(x) {1+2*x[,2]*x[,4]} +#' mu_x <- mu1(X) +#' tau_x <- tau2(X) +#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +#' Z <- rbinom(n,1,pi_x) +#' E_XZ <- mu_x + Z*tau_x +#' snr <- 3 +#' group_ids <- rep(c(1,2), n %/% 2) +#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +#' rfx_basis <- cbind(1, runif(n, -1, 1)) +#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) +#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +#' X <- as.data.frame(X) +#' X$x4 <- factor(X$x4, ordered = TRUE) +#' X$x5 <- factor(X$x5, ordered = TRUE) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' pi_test <- pi_x[test_inds] +#' pi_train <- pi_x[train_inds] +#' Z_test <- Z[test_inds] +#' Z_train <- Z[train_inds] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' mu_test <- mu_x[test_inds] +#' mu_train <- mu_x[train_inds] +#' tau_test <- tau_x[test_inds] +#' tau_train <- tau_x[train_inds] +#' group_ids_test <- group_ids[test_inds] +#' group_ids_train <- group_ids[train_inds] +#' rfx_basis_test <- rfx_basis[test_inds,] +#' rfx_basis_train <- rfx_basis[train_inds,] +#' rfx_term_test <- rfx_term[test_inds] +#' rfx_term_train <- rfx_term[train_inds] +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' pi_train = pi_train, group_ids_train = group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 100, num_burnin = 0, num_mcmc = 100, +#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' # saveBCFModelToJsonString(bcf_model) +saveBCFModelToJsonString <- function(object){ + # Convert to Json + jsonobj <- convertBCFModelToJson(object) + + # Dump to string + return(jsonobj$return_json_string()) +} + #' Convert an (in-memory) JSON representation of a BCF model to a BCF model object #' which can be used for prediction, etc... #' @@ -1538,3 +1611,220 @@ createBCFModelFromJsonFile <- function(json_filename){ return(bcf_object) } + +#' Convert a JSON string containing sample information on a trained BCF model +#' to a BCF model object which can be used for prediction, etc... 
+#' +#' @param json_string JSON string dump +#' +#' @return Object of type `bcf` +#' @export +#' +#' @examples +#' n <- 500 +#' x1 <- rnorm(n) +#' x2 <- rnorm(n) +#' x3 <- rnorm(n) +#' x4 <- as.numeric(rbinom(n,1,0.5)) +#' x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +#' X <- cbind(x1,x2,x3,x4,x5) +#' p <- ncol(X) +#' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +#' tau1 <- function(x) {rep(3,nrow(x))} +#' tau2 <- function(x) {1+2*x[,2]*x[,4]} +#' mu_x <- mu1(X) +#' tau_x <- tau2(X) +#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +#' Z <- rbinom(n,1,pi_x) +#' E_XZ <- mu_x + Z*tau_x +#' snr <- 3 +#' group_ids <- rep(c(1,2), n %/% 2) +#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +#' rfx_basis <- cbind(1, runif(n, -1, 1)) +#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) +#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +#' X <- as.data.frame(X) +#' X$x4 <- factor(X$x4, ordered = TRUE) +#' X$x5 <- factor(X$x5, ordered = TRUE) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' pi_test <- pi_x[test_inds] +#' pi_train <- pi_x[train_inds] +#' Z_test <- Z[test_inds] +#' Z_train <- Z[train_inds] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' mu_test <- mu_x[test_inds] +#' mu_train <- mu_x[train_inds] +#' tau_test <- tau_x[test_inds] +#' tau_train <- tau_x[train_inds] +#' group_ids_test <- group_ids[test_inds] +#' group_ids_train <- group_ids[train_inds] +#' rfx_basis_test <- rfx_basis[test_inds,] +#' rfx_basis_train <- rfx_basis[train_inds,] +#' rfx_term_test <- rfx_term[test_inds] +#' rfx_term_train <- rfx_term[train_inds] +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' pi_train = pi_train, group_ids_train = group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 100, num_burnin = 0, num_mcmc = 100, +#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' # bcf_json <- saveBCFModelToJsonString(bcf_model) +#' # bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) +createBCFModelFromJsonString <- function(json_string){ + # Load a `CppJson` object from string + bcf_json <- createCppJsonString(json_string) + + # Create and return the BCF object + bcf_object <- createBCFModelFromJson(bcf_json) + + return(bcf_object) +} + +#' Convert an (in-memory) JSON representation of a BCF model to a BCF model object +#' which can be used for prediction, etc... 
+#' +#' @param json_object Object of type `CppJson` containing Json representation of a BCF model +#' +#' @return Object of type `bcf` +#' @export +#' +#' @examples +#' n <- 500 +#' x1 <- rnorm(n) +#' x2 <- rnorm(n) +#' x3 <- rnorm(n) +#' x4 <- as.numeric(rbinom(n,1,0.5)) +#' x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +#' X <- cbind(x1,x2,x3,x4,x5) +#' p <- ncol(X) +#' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +#' tau1 <- function(x) {rep(3,nrow(x))} +#' tau2 <- function(x) {1+2*x[,2]*x[,4]} +#' mu_x <- mu1(X) +#' tau_x <- tau2(X) +#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +#' Z <- rbinom(n,1,pi_x) +#' E_XZ <- mu_x + Z*tau_x +#' snr <- 3 +#' group_ids <- rep(c(1,2), n %/% 2) +#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +#' rfx_basis <- cbind(1, runif(n, -1, 1)) +#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) +#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +#' X <- as.data.frame(X) +#' X$x4 <- factor(X$x4, ordered = TRUE) +#' X$x5 <- factor(X$x5, ordered = TRUE) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' pi_test <- pi_x[test_inds] +#' pi_train <- pi_x[train_inds] +#' Z_test <- Z[test_inds] +#' Z_train <- Z[train_inds] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' mu_test <- mu_x[test_inds] +#' mu_train <- mu_x[train_inds] +#' tau_test <- tau_x[test_inds] +#' tau_train <- tau_x[train_inds] +#' group_ids_test <- group_ids[test_inds] +#' group_ids_train <- group_ids[train_inds] +#' rfx_basis_test <- rfx_basis[test_inds,] +#' rfx_basis_train <- rfx_basis[train_inds,] +#' rfx_term_test <- rfx_term[test_inds] +#' rfx_term_train <- rfx_term[train_inds] +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' pi_train = pi_train, group_ids_train = group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 100, num_burnin = 0, num_mcmc = 100, +#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +#' # bcf_json <- convertBCFModelToJson(bcf_model) +#' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) +createBCFModelFromJson <- function(json_object){ + # Initialize the BCF model + output <- list() + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") + output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars") + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars") + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", 
train_set_metadata[["ordered_cat_vars"]]) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + } + output[["train_set_metadata"]] <- train_set_metadata + output[["keep_indices"]] <- json_object$get_vector("keep_indices") + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") + model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu") + model_params[["sample_sigma_leaf_tau"]] <- json_object$get_boolean("sample_sigma_leaf_tau") + model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate") + model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding") + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma_global"]]) { + output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + } + if (model_params[["sample_sigma_leaf_mu"]]) { + output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") + } + if (model_params[["sample_sigma_leaf_tau"]]) { + output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + } + if (model_params[["adaptive_coding"]]) { + output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") + output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) + } + + class(output) <- "bcf" + return(output) +} diff --git a/R/cpp11.R b/R/cpp11.R index 16d8b449..b9e710b7 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -464,10 +464,18 @@ json_add_rfx_groupids_cpp <- function(json_ptr, groupids) { .Call(`_stochtree_json_add_rfx_groupids_cpp`, json_ptr, groupids) } -json_save_cpp <- function(json_ptr, filename) { - invisible(.Call(`_stochtree_json_save_cpp`, json_ptr, filename)) +get_json_string_cpp <- function(json_ptr) { + .Call(`_stochtree_get_json_string_cpp`, json_ptr) } -json_load_cpp <- function(json_ptr, filename) { - invisible(.Call(`_stochtree_json_load_cpp`, json_ptr, filename)) +json_save_file_cpp <- function(json_ptr, filename) { + invisible(.Call(`_stochtree_json_save_file_cpp`, json_ptr, filename)) +} + +json_load_file_cpp <- function(json_ptr, filename) { + invisible(.Call(`_stochtree_json_load_file_cpp`, json_ptr, 
filename))
+}
+
+json_load_string_cpp <- function(json_ptr, json_string) {
+  invisible(.Call(`_stochtree_json_load_string_cpp`, json_ptr, json_string))
+}
diff --git a/R/serialization.R b/R/serialization.R
index eda655f5..525f5abf 100644
--- a/R/serialization.R
+++ b/R/serialization.R
@@ -276,12 +276,19 @@
     return(output)
   },
 
+  #' @description
+  #' Convert a JSON object to in-memory string
+  #' @return JSON string
+  return_json_string = function() {
+    return(get_json_string_cpp(self$json_ptr))
+  },
+
   #' @description
   #' Save a json object to file
   #' @param filename String of filepath, must end in ".json"
   #' @return NULL
   save_file = function(filename) {
-    json_save_cpp(self$json_ptr, filename)
+    json_save_file_cpp(self$json_ptr, filename)
   },
 
   #' @description
@@ -289,7 +296,15 @@
   #' @param filename String of filepath, must end in ".json"
   #' @return NULL
   load_from_file = function(filename) {
-    json_load_cpp(self$json_ptr, filename)
+    json_load_file_cpp(self$json_ptr, filename)
+  },
+
+  #' @description
+  #' Load a json object from string
+  #' @param json_string JSON string dump
+  #' @return NULL
+  load_from_string = function(json_string) {
+    json_load_string_cpp(self$json_ptr, json_string)
   }
 )
)
@@ -379,3 +394,16 @@ createCppJsonFile <- function(json_filename) {
   output$load_from_file(json_filename)
   return(output)
 }
+
+#' Create a C++ Json object from a Json string
+#'
+#' @param json_string JSON string dump
+#' @return `CppJson` object
+#' @export
+createCppJsonString <- function(json_string) {
+  invisible((
+    output <- CppJson$new()
+  ))
+  output$load_from_string(json_string)
+  return(output)
+}
diff --git a/_pkgdown.yml b/_pkgdown.yml
index bffe900a..8292609e 100644
--- a/_pkgdown.yml
+++ b/_pkgdown.yml
@@ -18,6 +18,7 @@ reference:
   - predict.bcf
   - saveBCFModelToJsonFile
   - createBCFModelFromJsonFile
+  - createBCFModelFromJsonString
   - convertBCFModelToJson
   - createBCFModelFromJson
 
@@ -34,6 +35,7 @@
   - loadVectorJson
   - loadScalarJson
   - createCppJsonFile
+  - createCppJsonString
 
 - subtitle: Data
   desc: >
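Taken together, the R/serialization.R additions above give `CppJson` a full in-memory round trip alongside the existing file-based one. A minimal sketch of how the new pieces compose, assuming nothing beyond the methods shown in this patch (the field name "example_scalar" and its value are arbitrary placeholders, not fields the package itself writes):

    json_a <- createCppJson()
    json_a$add_scalar("example_scalar", 1.5)   # any persisted field works here
    json_str <- json_a$return_json_string()    # serialize to an R character string
    json_b <- createCppJsonString(json_str)    # rebuild a CppJson from that string
    json_b$get_scalar("example_scalar")        # recovers 1.5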
diff --git a/man/CppJson.Rd b/man/CppJson.Rd
index fa484513..a7f7e448 100644
--- a/man/CppJson.Rd
+++ b/man/CppJson.Rd
@@ -45,8 +45,10 @@ Wrapper around a C++ container of tree ensembles
 \item \href{#method-CppJson-get_string_vector}{\code{CppJson$get_string_vector()}}
 \item \href{#method-CppJson-get_numeric_list}{\code{CppJson$get_numeric_list()}}
 \item \href{#method-CppJson-get_string_list}{\code{CppJson$get_string_list()}}
+\item \href{#method-CppJson-return_json_string}{\code{CppJson$return_json_string()}}
 \item \href{#method-CppJson-save_file}{\code{CppJson$save_file()}}
 \item \href{#method-CppJson-load_from_file}{\code{CppJson$load_from_file()}}
+\item \href{#method-CppJson-load_from_string}{\code{CppJson$load_from_string()}}
 }
 }
 \if{html}{\out{<hr>}}
@@ -421,6 +423,19 @@
 NULL
 }
 }
 \if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-CppJson-return_json_string"></a>}}
+\if{latex}{\out{\hypertarget{method-CppJson-return_json_string}{}}}
+\subsection{Method \code{return_json_string()}}{
+Convert a JSON object to in-memory string
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{CppJson$return_json_string()}\if{html}{\out{</div>}}
+}
+
+\subsection{Returns}{
+JSON string
+}
+}
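As a quick aside for readers of the generated docs: `return_json_string()` pairs with the `load_from_string()` method documented further down, so an existing object can be overwritten in place rather than constructed fresh. A sketch under those assumptions (`json_a` is any populated `CppJson`):

    json_b <- createCppJson()
    json_b$load_from_string(json_a$return_json_string())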
+\if{html}{\out{<hr>}}
 \if{html}{\out{<a id="method-CppJson-save_file"></a>}}
 \if{latex}{\out{\hypertarget{method-CppJson-save_file}{}}}
 \subsection{Method \code{save_file()}}{
@@ -460,4 +475,24 @@
 Load a json object from file
 NULL
 }
 }
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-CppJson-load_from_string"></a>}}
+\if{latex}{\out{\hypertarget{method-CppJson-load_from_string}{}}}
+\subsection{Method \code{load_from_string()}}{
+Load a json object from string
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{CppJson$load_from_string(json_string)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{json_string}}{JSON string dump}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+NULL
+}
+}
+}
diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd
index 76c3ebcb..28b21ef0 100644
--- a/man/createBCFModelFromJson.Rd
+++ b/man/createBCFModelFromJson.Rd
@@ -5,15 +5,22 @@
 \title{Convert an (in-memory) JSON representation of a BCF model to a BCF model object
which can be used for prediction, etc...}
 \usage{
+createBCFModelFromJson(json_object)
+
 createBCFModelFromJson(json_object)
 }
 \arguments{
 \item{json_object}{Object of type \code{CppJson} containing Json representation of a BCF model}
 }
 \value{
+Object of type \code{bcf}
+
 Object of type \code{bcf}
 }
 \description{
+Convert an (in-memory) JSON representation of a BCF model to a BCF model object
+which can be used for prediction, etc...
+
 Convert an (in-memory) JSON representation of a BCF model to a BCF model object
which can be used for prediction, etc...
 }
@@ -77,4 +84,63 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                  sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
 # bcf_json <- convertBCFModelToJson(bcf_model)
 # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
+n <- 500
+x1 <- rnorm(n)
+x2 <- rnorm(n)
+x3 <- rnorm(n)
+x4 <- as.numeric(rbinom(n,1,0.5))
+x5 <- as.numeric(sample(1:3,n,replace=TRUE))
+X <- cbind(x1,x2,x3,x4,x5)
+p <- ncol(X)
+g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))}
+mu1 <- function(x) {1+g(x)+x[,1]*x[,3]}
+mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)}
+tau1 <- function(x) {rep(3,nrow(x))}
+tau2 <- function(x) {1+2*x[,2]*x[,4]}
+mu_x <- mu1(X)
+tau_x <- tau2(X)
+pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
+Z <- rbinom(n,1,pi_x)
+E_XZ <- mu_x + Z*tau_x
+snr <- 3
+group_ids <- rep(c(1,2), n \%/\% 2)
+rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
+rfx_basis <- cbind(1, runif(n, -1, 1))
+rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
+y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = TRUE)
+X$x5 <- factor(X$x5, ordered = TRUE)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) \%in\% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+pi_test <- pi_x[test_inds]
+pi_train <- pi_x[train_inds]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+mu_test <- mu_x[test_inds]
+mu_train <- mu_x[train_inds]
+tau_test <- tau_x[test_inds]
+tau_train <- tau_x[train_inds]
+group_ids_test <- group_ids[test_inds]
+group_ids_train <- group_ids[train_inds]
+rfx_basis_test <- rfx_basis[test_inds,]
+rfx_basis_train <- rfx_basis[train_inds,]
+rfx_term_test <- rfx_term[test_inds]
+rfx_term_train <- rfx_term[train_inds]
+bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                 pi_train = pi_train, group_ids_train = group_ids_train,
+                 rfx_basis_train = rfx_basis_train, X_test = X_test,
+                 Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
+                 rfx_basis_test = rfx_basis_test,
+                 num_gfr = 100, num_burnin = 0, num_mcmc = 100,
+                 sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
+# bcf_json <- convertBCFModelToJson(bcf_model)
+# bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
 }
diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd
new file mode 100644
index 00000000..b25557ab
--- /dev/null
+++ b/man/createBCFModelFromJsonString.Rd
@@ -0,0 +1,80 @@
+% Generated by 
roxygen2: do not edit by hand +% Please edit documentation in R/bcf.R +\name{createBCFModelFromJsonString} +\alias{createBCFModelFromJsonString} +\title{Convert a JSON string containing sample information on a trained BCF model +to a BCF model object which can be used for prediction, etc...} +\usage{ +createBCFModelFromJsonString(json_string) +} +\arguments{ +\item{json_string}{JSON string dump} +} +\value{ +Object of type \code{bcf} +} +\description{ +Convert a JSON string containing sample information on a trained BCF model +to a BCF model object which can be used for prediction, etc... +} +\examples{ +n <- 500 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n,1,0.5)) +x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +tau1 <- function(x) {rep(3,nrow(x))} +tau2 <- function(x) {1+2*x[,2]*x[,4]} +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +snr <- 3 +group_ids <- rep(c(1,2), n \%/\% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +rfx_basis <- cbind(1, runif(n, -1, 1)) +rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) +y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds,] +rfx_basis_train <- rfx_basis[train_inds,] +rfx_term_test <- rfx_term[test_inds] +rfx_term_train <- rfx_term[train_inds] +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + pi_train = pi_train, group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 100, num_burnin = 0, num_mcmc = 100, + sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) +# bcf_json <- saveBCFModelToJsonString(bcf_model) +# bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) +} diff --git a/man/createCppJsonString.Rd b/man/createCppJsonString.Rd new file mode 100644 index 00000000..a8215cc6 --- /dev/null +++ b/man/createCppJsonString.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/serialization.R +\name{createCppJsonString} +\alias{createCppJsonString} +\title{Create a C++ Json object from a Json string} +\usage{ +createCppJsonString(json_string) +} +\arguments{ +\item{json_string}{JSON string dump} +} +\value{ +\code{CppJson} object +} +\description{ +Create a C++ Json object from a Json string +} diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd new file mode 100644 index 00000000..7dd31418 --- /dev/null +++ 
b/man/saveBCFModelToJsonString.Rd
@@ -0,0 +1,77 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bcf.R
+\name{saveBCFModelToJsonString}
+\alias{saveBCFModelToJsonString}
+\title{Convert the persistent aspects of a BCF model to (in-memory) JSON string}
+\usage{
+saveBCFModelToJsonString(object)
+}
+\arguments{
+\item{object}{Object of type \code{bcf} containing draws of a Bayesian causal forest model and associated sampling outputs.}
+}
+\value{
+JSON string
+}
+\description{
+Convert the persistent aspects of a BCF model to (in-memory) JSON string
+}
+\examples{
+n <- 500
+x1 <- rnorm(n)
+x2 <- rnorm(n)
+x3 <- rnorm(n)
+x4 <- as.numeric(rbinom(n,1,0.5))
+x5 <- as.numeric(sample(1:3,n,replace=TRUE))
+X <- cbind(x1,x2,x3,x4,x5)
+p <- ncol(X)
+g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))}
+mu1 <- function(x) {1+g(x)+x[,1]*x[,3]}
+mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)}
+tau1 <- function(x) {rep(3,nrow(x))}
+tau2 <- function(x) {1+2*x[,2]*x[,4]}
+mu_x <- mu1(X)
+tau_x <- tau2(X)
+pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
+Z <- rbinom(n,1,pi_x)
+E_XZ <- mu_x + Z*tau_x
+snr <- 3
+group_ids <- rep(c(1,2), n \%/\% 2)
+rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
+rfx_basis <- cbind(1, runif(n, -1, 1))
+rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
+y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = TRUE)
+X$x5 <- factor(X$x5, ordered = TRUE)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) \%in\% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+pi_test <- pi_x[test_inds]
+pi_train <- pi_x[train_inds]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+mu_test <- mu_x[test_inds]
+mu_train <- mu_x[train_inds]
+tau_test <- tau_x[test_inds]
+tau_train <- tau_x[train_inds]
+group_ids_test <- group_ids[test_inds]
+group_ids_train <- group_ids[train_inds]
+rfx_basis_test <- rfx_basis[test_inds,]
+rfx_basis_train <- rfx_basis[train_inds,]
+rfx_term_test <- rfx_term[test_inds]
+rfx_term_train <- rfx_term[train_inds]
+bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                 pi_train = pi_train, group_ids_train = group_ids_train,
+                 rfx_basis_train = rfx_basis_train, X_test = X_test,
+                 Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
+                 rfx_basis_test = rfx_basis_test,
+                 num_gfr = 100, num_burnin = 0, num_mcmc = 100,
+                 sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
+# saveBCFModelToJsonString(bcf_model)
+}
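The commented-out example line above stops at producing the string; the intended full cycle, assuming a fitted bcf_model as constructed in the example (object names here are illustrative, not part of the package):

    bcf_json_string <- saveBCFModelToJsonString(bcf_model)
    bcf_model_restored <- createBCFModelFromJsonString(bcf_json_string)
    # the restored object carries the same sampling metadata
    bcf_model_restored$model_params$num_samples == bcf_model$model_params$num_samples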
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index 53423c30..6a5b883f 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -858,18 +858,33 @@ extern "C" SEXP _stochtree_json_add_rfx_groupids_cpp(SEXP json_ptr, SEXP groupid
   END_CPP11
 }
 // serialization.cpp
-void json_save_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename);
-extern "C" SEXP _stochtree_json_save_cpp(SEXP json_ptr, SEXP filename) {
+std::string get_json_string_cpp(cpp11::external_pointer<nlohmann::json> json_ptr);
+extern "C" SEXP _stochtree_get_json_string_cpp(SEXP json_ptr) {
   BEGIN_CPP11
-    json_save_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr), cpp11::as_cpp<std::string>(filename));
+    return cpp11::as_sexp(get_json_string_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr)));
+  END_CPP11
+}
+// serialization.cpp
+void json_save_file_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename);
+extern "C" SEXP _stochtree_json_save_file_cpp(SEXP json_ptr, SEXP filename) {
+  BEGIN_CPP11
+    json_save_file_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr), cpp11::as_cpp<std::string>(filename));
+    return R_NilValue;
+  END_CPP11
+}
+// serialization.cpp
+void json_load_file_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename);
+extern "C" SEXP _stochtree_json_load_file_cpp(SEXP json_ptr, SEXP filename) {
+  BEGIN_CPP11
+    json_load_file_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr), cpp11::as_cpp<std::string>(filename));
     return R_NilValue;
   END_CPP11
 }
 // serialization.cpp
-void json_load_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename);
-extern "C" SEXP _stochtree_json_load_cpp(SEXP json_ptr, SEXP filename) {
+void json_load_string_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string json_string);
+extern "C" SEXP _stochtree_json_load_string_cpp(SEXP json_ptr, SEXP json_string) {
   BEGIN_CPP11
-    json_load_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr), cpp11::as_cpp<std::string>(filename));
+    json_load_string_cpp(cpp11::as_cpp<cpp11::external_pointer<nlohmann::json>>(json_ptr), cpp11::as_cpp<std::string>(json_string));
     return R_NilValue;
   END_CPP11
 }
@@ -910,6 +925,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
     {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
     {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},
+    {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1},
     {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2},
     {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1},
     {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3},
@@ -943,9 +959,10 @@
     {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2},
     {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3},
     {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1},
-    {"_stochtree_json_load_cpp", (DL_FUNC) &_stochtree_json_load_cpp, 2},
+    {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2},
     {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2},
-    {"_stochtree_json_save_cpp", (DL_FUNC) &_stochtree_json_save_cpp, 2},
+    {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2},
+    {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2},
     {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2},
     {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1},
     {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1},
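The registrations above are generated glue; the R6 method `return_json_string()` shown earlier reduces to a single call through this table. A sketch using the internal (non-exported) wrapper purely for illustration, not as supported API:

    json <- createCppJson()
    # equivalent to json$return_json_string(), via the registered C++ entry point
    stochtree:::get_json_string_cpp(json$json_ptr)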
diff --git a/src/serialization.cpp b/src/serialization.cpp
index ba734757..3593f1a5 100644
--- a/src/serialization.cpp
+++ b/src/serialization.cpp
@@ -305,15 +305,25 @@ std::string json_add_rfx_groupids_cpp(cpp11::external_pointer<nlohmann::json> js
 }
 
 [[cpp11::register]]
-void json_save_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename) {
+std::string get_json_string_cpp(cpp11::external_pointer<nlohmann::json> json_ptr) {
+  return json_ptr->dump();
+}
+
+[[cpp11::register]]
+void json_save_file_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename) {
   std::ofstream output_file(filename);
   output_file << *json_ptr << std::endl;
 }
 
 [[cpp11::register]]
-void json_load_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename) {
+void json_load_file_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string filename) {
   std::ifstream f(filename);
   // nlohmann::json file_json = nlohmann::json::parse(f);
   *json_ptr = nlohmann::json::parse(f);
   // json_ptr.reset(&file_json);
 }
+
+[[cpp11::register]]
+void json_load_string_cpp(cpp11::external_pointer<nlohmann::json> json_ptr, std::string json_string) {
+  *json_ptr = nlohmann::json::parse(json_string);
+}

From bd060616ff6ebfc4efc25d1190dbc7801185d10f Mon Sep 17 00:00:00 2001
From: Drew Herren 
Date: Sun, 4 Aug 2024 08:50:01 -0500
Subject: [PATCH 02/41] Allow BART json serialization in R

---
 NAMESPACE                            |   6 +
 R/bart.R                             | 344 +++++++++++++++++++++++++++
 R/bcf.R                              | 139 -----------
 man/convertBARTModelToJson.Rd        |  41 ++++
 man/createBARTModelFromJson.Rd       |  44 ++++
 man/createBARTModelFromJsonFile.Rd   |  44 ++++
 man/createBARTModelFromJsonString.Rd |  46 ++++
 man/createBCFModelFromJson.Rd        |  66 -----
 man/saveBARTModelToJsonFile.Rd       |  40 ++++
 man/saveBARTModelToJsonString.Rd     |  41 ++++
 10 files changed, 606 insertions(+), 205 deletions(-)
 create mode 100644 man/convertBARTModelToJson.Rd
 create mode 100644 man/createBARTModelFromJson.Rd
 create mode 100644 man/createBARTModelFromJsonFile.Rd
 create mode 100644 man/createBARTModelFromJsonString.Rd
 create mode 100644 man/saveBARTModelToJsonFile.Rd
 create mode 100644 man/saveBARTModelToJsonString.Rd

diff --git a/NAMESPACE b/NAMESPACE
index 03e83f44..5c8c8869 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -8,7 +8,11 @@ export(bart)
 export(bcf)
 export(computeForestKernels)
 export(computeForestLeafIndices)
+export(convertBARTModelToJson)
 export(convertBCFModelToJson)
+export(createBARTModelFromJson)
+export(createBARTModelFromJsonFile)
+export(createBARTModelFromJsonString)
 export(createBCFModelFromJson)
 export(createBCFModelFromJsonFile)
 export(createBCFModelFromJsonString)
@@ -44,6 +48,8 @@ export(preprocessTrainDataFrame)
 export(preprocessTrainMatrix)
 export(sample_sigma2_one_iteration)
 export(sample_tau_one_iteration)
+export(saveBARTModelToJsonFile)
+export(saveBARTModelToJsonString)
 export(saveBCFModelToJsonFile)
 export(saveBCFModelToJsonString)
 importFrom(R6,R6Class)
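With these exports in place, BART models get the same string-based workflow that patch 01 added for BCF. Assuming a fitted bart_model as in the examples that follow, the end-to-end usage is:

    bart_json <- saveBARTModelToJsonString(bart_model)
    bart_model_restored <- createBARTModelFromJsonString(bart_json)
    # predictions from the restored model, as in the roundtrip example below
    y_hat_restored <- rowMeans(predict(bart_model_restored, X_train)$y_hat)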
diff --git a/R/bart.R b/R/bart.R
index 08f79549..9eaeb47c 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -688,3 +688,347 @@ getRandomEffectSamples.bartmodel <- function(object, ...){
     return(result)
 }
+
+#' Convert the persistent aspects of a BART model to (in-memory) JSON
+#'
+#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
+#'
+#' @return Object of type `CppJson`
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart(X_train = X_train, y_train = y_train)
+#' # bart_json <- convertBARTModelToJson(bart_model)
+convertBARTModelToJson <- function(object){
+    jsonobj <- createCppJson()
+    
+    if (is.null(object$model_params)) {
+        stop("This BART model has not yet been sampled")
+    }
+    
+    # Add the forests
+    jsonobj$add_forest(object$forests)
+    
+    # Add metadata
+    jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars)
+    jsonobj$add_scalar("num_ordered_cat_vars", object$train_set_metadata$num_ordered_cat_vars)
+    jsonobj$add_scalar("num_unordered_cat_vars", object$train_set_metadata$num_unordered_cat_vars)
+    if (object$train_set_metadata$num_numeric_vars > 0) {
+        jsonobj$add_string_vector("numeric_vars", object$train_set_metadata$numeric_vars)
+    }
+    if (object$train_set_metadata$num_ordered_cat_vars > 0) {
+        jsonobj$add_string_vector("ordered_cat_vars", object$train_set_metadata$ordered_cat_vars)
+        jsonobj$add_string_list("ordered_unique_levels", object$train_set_metadata$ordered_unique_levels)
+    }
+    if (object$train_set_metadata$num_unordered_cat_vars > 0) {
+        jsonobj$add_string_vector("unordered_cat_vars", object$train_set_metadata$unordered_cat_vars)
+        jsonobj$add_string_list("unordered_unique_levels", object$train_set_metadata$unordered_unique_levels)
+    }
+    
+    # Add global parameters
+    jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
+    jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
+    jsonobj$add_boolean("sample_sigma", object$model_params$sample_sigma)
+    jsonobj$add_boolean("sample_tau", object$model_params$sample_tau)
+    jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
+    jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
+    jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
+    jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
+    jsonobj$add_scalar("num_burnin", object$model_params$num_burnin)
+    jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc)
+    jsonobj$add_scalar("num_samples", object$model_params$num_samples)
+    jsonobj$add_scalar("num_covariates", object$model_params$num_covariates)
+    jsonobj$add_scalar("num_basis", object$model_params$num_basis)
+    jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
+    jsonobj$add_vector("keep_indices", object$keep_indices)
+    if (object$model_params$sample_sigma) {
+        jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters")
+    }
+    if (object$model_params$sample_tau) {
+        jsonobj$add_vector("tau_samples", object$tau_samples, "parameters")
+    }
+    
+    # Add random effects (if present)
+    if (object$model_params$has_rfx) {
+        jsonobj$add_random_effects(object$rfx_samples)
+        jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids)
+    }
+    
+    return(jsonobj)
+}
+
+#' Convert the 
persistent aspects of a BART model to (in-memory) JSON and save to a file +#' +#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs. +#' @param filename String of filepath, must end in ".json" +#' +#' @return NULL +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, y_train = y_train) +#' # saveBARTModelToJsonFile(bart_model, "test.json") +saveBARTModelToJsonFile <- function(object, filename){ + # Convert to Json + jsonobj <- convertBARTModelToJson(object) + + # Save to file + jsonobj$save_file(filename) +} + +#' Convert the persistent aspects of a BART model to (in-memory) JSON string +#' +#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs. +#' @return JSON string +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, y_train = y_train) +#' # saveBARTModelToJsonString(bart_model) +saveBARTModelToJsonString <- function(object){ + # Convert to Json + jsonobj <- convertBARTModelToJson(object) + + # Dump to string + return(jsonobj$return_json_string()) +} + +#' Convert an (in-memory) JSON representation of a BART model to a BART model object +#' which can be used for prediction, etc... 
+#'
+#' @param json_object Object of type `CppJson` containing Json representation of a BART model
+#'
+#' @return Object of type `bartmodel`
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart(X_train = X_train, y_train = y_train)
+#' # bart_json <- convertBARTModelToJson(bart_model)
+#' # bart_model_roundtrip <- createBARTModelFromJson(bart_json)
+createBARTModelFromJson <- function(json_object){
+    # Initialize the BART model
+    output <- list()
+    
+    # Unpack the forests
+    output[["forests"]] <- loadForestContainerJson(json_object, "forest_0")
+    
+    # Unpack metadata
+    train_set_metadata = list()
+    train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars")
+    train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars")
+    train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars")
+    if (train_set_metadata[["num_numeric_vars"]] > 0) {
+        train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars")
+    }
+    if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
+        train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars")
+        train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]])
+    }
+    if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
+        train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars")
+        train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]])
+    }
+    output[["train_set_metadata"]] <- train_set_metadata
+    output[["keep_indices"]] <- json_object$get_vector("keep_indices")
+    
+    # Unpack model params
+    model_params = list()
+    model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
+    model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
+    model_params[["sample_sigma"]] <- json_object$get_boolean("sample_sigma")
+    model_params[["sample_tau"]] <- json_object$get_boolean("sample_tau")
+    model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx")
+    model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
+    model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
+    model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
+    model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
+    model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
+    model_params[["num_samples"]] <- json_object$get_scalar("num_samples")
+    model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates")
+    model_params[["num_basis"]] <- json_object$get_scalar("num_basis")
+    model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis")
+    output[["model_params"]] <- model_params
+    
+    # Unpack sampled parameters
+    if (model_params[["sample_sigma"]]) {
+        output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters")
+    }
+    if (model_params[["sample_tau"]]) {
+        output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters")
+    }
+    
+    # Unpack random effects
+    if (model_params[["has_rfx"]]) {
+        output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids")
+        output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
+    }
+    
+    class(output) <- "bartmodel"
+    return(output)
+}
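Since the list built above mirrors the structure returned by bart(), the restored object can be inspected the same way as a freshly sampled one. A brief sketch (assumes a fitted bart_model; field names are exactly those unpacked above):

    bart_json_obj <- convertBARTModelToJson(bart_model)
    bart_model2 <- createBARTModelFromJson(bart_json_obj)
    bart_model2$model_params$num_samples            # scalar restored from JSON
    bart_model2$train_set_metadata$num_numeric_vars # metadata restored from JSON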
+
+#' Convert a JSON file containing sample information on a trained BART model
+#' to a BART model object which can be used for prediction, etc...
+#'
+#' @param json_filename String of filepath, must end in ".json"
+#'
+#' @return Object of type `bartmodel`
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart(X_train = X_train, y_train = y_train)
+#' # saveBARTModelToJsonFile(bart_model, "test.json")
+#' # bart_model_roundtrip <- createBARTModelFromJsonFile("test.json")
+createBARTModelFromJsonFile <- function(json_filename){
+    # Load a `CppJson` object from file
+    bart_json <- createCppJsonFile(json_filename)
+    
+    # Create and return the BART object
+    bart_object <- createBARTModelFromJson(bart_json)
+    
+    return(bart_object)
+}
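For a self-contained test of the file-based path, a temporary file avoids leaving "test.json" behind, unlike the commented example above (sketch; assumes a fitted bart_model):

    tmp_json <- tempfile(fileext = ".json")
    saveBARTModelToJsonFile(bart_model, tmp_json)
    bart_model_from_file <- createBARTModelFromJsonFile(tmp_json)
    unlink(tmp_json)  # clean up the temporary file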
+
+#' Convert a JSON string containing sample information on a trained BART model
+#' to a BART model object which can be used for prediction, etc...
+#'
+#' @param json_string JSON string dump
+#'
+#' @return Object of type `bartmodel`
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart(X_train = X_train, y_train = y_train)
+#' # bart_json <- saveBARTModelToJsonString(bart_model)
+#' # bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
+#' # y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
+#' # plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip)
+createBARTModelFromJsonString <- function(json_string){
+    # Load a `CppJson` object from string
+    bart_json <- createCppJsonString(json_string)
+    
+    # Create and return the BART object
+    bart_object <- createBARTModelFromJson(bart_json)
+    
+    return(bart_object)
+}
diff --git a/R/bcf.R b/R/bcf.R
index 1d554429..b959ef2d 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -1689,142 +1689,3 @@ createBCFModelFromJsonString <- function(json_string){
 
   return(bcf_object)
 }
-
-#' Convert an (in-memory) JSON representation of a BCF model to a BCF model object
-#' which can be used for prediction, etc...
-#'
-#' @param json_object Object of type `CppJson` containing Json representation of a BCF model
-#'
-#' @return Object of type `bcf`
-#' @export
-#'
-#' @examples
-#' n <- 500
-#' x1 <- rnorm(n)
-#' x2 <- rnorm(n)
-#' x3 <- rnorm(n)
-#' x4 <- as.numeric(rbinom(n,1,0.5))
-#' x5 <- as.numeric(sample(1:3,n,replace=TRUE))
-#' X <- cbind(x1,x2,x3,x4,x5)
-#' p <- ncol(X)
-#' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))}
-#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]}
-#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)}
-#' tau1 <- function(x) {rep(3,nrow(x))}
-#' tau2 <- function(x) {1+2*x[,2]*x[,4]}
-#' mu_x <- mu1(X)
-#' tau_x <- tau2(X)
-#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
-#' Z <- rbinom(n,1,pi_x)
-#' E_XZ <- mu_x + Z*tau_x
-#' snr <- 3
-#' group_ids <- rep(c(1,2), n %/% 2)
-#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
-#' rfx_basis <- cbind(1, runif(n, -1, 1))
-#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
-#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
-#' X <- as.data.frame(X)
-#' X$x4 <- factor(X$x4, ordered = TRUE)
-#' X$x5 <- factor(X$x5, ordered = TRUE)
-#' test_set_pct <- 0.2
-#' n_test <- round(test_set_pct*n)
-#' n_train <- n - n_test
-#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
-#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
-#' X_test <- X[test_inds,]
-#' X_train <- X[train_inds,]
-#' pi_test <- pi_x[test_inds]
-#' pi_train <- pi_x[train_inds]
-#' Z_test <- Z[test_inds]
-#' Z_train <- Z[train_inds]
-#' y_test <- y[test_inds]
-#' y_train <- y[train_inds]
-#' mu_test <- mu_x[test_inds]
-#' mu_train <- mu_x[train_inds]
-#' tau_test <- tau_x[test_inds]
-#' tau_train <- tau_x[train_inds]
-#' group_ids_test <- group_ids[test_inds]
-#' group_ids_train <- group_ids[train_inds]
-#' rfx_basis_test <- rfx_basis[test_inds,]
-#' rfx_basis_train <- 
rfx_basis[train_inds,] -#' rfx_term_test <- rfx_term[test_inds] -#' rfx_term_train <- rfx_term[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' pi_train = pi_train, group_ids_train = group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 100, num_burnin = 0, num_mcmc = 100, -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) -#' # bcf_json <- convertBCFModelToJson(bcf_model) -#' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) -createBCFModelFromJson <- function(json_object){ - # Initialize the BCF model - output <- list() - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") - output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) - } - output[["train_set_metadata"]] <- train_set_metadata - output[["keep_indices"]] <- json_object$get_vector("keep_indices") - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") - model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu") - model_params[["sample_sigma_leaf_tau"]] <- json_object$get_boolean("sample_sigma_leaf_tau") - model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate") - model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding") - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { - output[["sigma2_samples"]] <- 
json_object$get_vector("sigma2_samples", "parameters") - } - if (model_params[["sample_sigma_leaf_mu"]]) { - output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") - } - if (model_params[["sample_sigma_leaf_tau"]]) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") - } - if (model_params[["adaptive_coding"]]) { - output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) - } - - class(output) <- "bcf" - return(output) -} diff --git a/man/convertBARTModelToJson.Rd b/man/convertBARTModelToJson.Rd new file mode 100644 index 00000000..de28613a --- /dev/null +++ b/man/convertBARTModelToJson.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{convertBARTModelToJson} +\alias{convertBARTModelToJson} +\title{Convert the persistent aspects of a BART model to (in-memory) JSON} +\usage{ +convertBARTModelToJson(object) +} +\arguments{ +\item{object}{Object of type \code{bartmodel} containing draws of a BART model and associated sampling outputs.} +} +\value{ +Object of type \code{CppJson} +} +\description{ +Convert the persistent aspects of a BART model to (in-memory) JSON +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# bart_json <- convertBARTModelToJson(bart_model) +} diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd new file mode 100644 index 00000000..0ebea7ee --- /dev/null +++ b/man/createBARTModelFromJson.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{createBARTModelFromJson} +\alias{createBARTModelFromJson} +\title{Convert an (in-memory) JSON representation of a BART model to a BART model object +which can be used for prediction, etc...} +\usage{ +createBARTModelFromJson(json_object) +} +\arguments{ +\item{json_object}{Object of type \code{CppJson} containing Json representation of a BART model} +} +\value{ +Object of type \code{bartmodel} +} +\description{ +Convert an (in-memory) JSON representation of a BART model to a BART model object +which can be used for prediction, etc... 
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# bart_json <- convertBARTModelToJson(bart_model) +# bart_model_roundtrip <- createBARTModelFromJson(bart_json) +} diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd new file mode 100644 index 00000000..e776bb6f --- /dev/null +++ b/man/createBARTModelFromJsonFile.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{createBARTModelFromJsonFile} +\alias{createBARTModelFromJsonFile} +\title{Convert a JSON file containing sample information on a trained BART model +to a BART model object which can be used for prediction, etc...} +\usage{ +createBARTModelFromJsonFile(json_filename) +} +\arguments{ +\item{json_filename}{String of filepath, must end in ".json"} +} +\value{ +Object of type \code{bartmodel} +} +\description{ +Convert a JSON file containing sample information on a trained BART model +to a BART model object which can be used for prediction, etc... +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# saveBARTModelToJsonFile(bart_model, "test.json") +# bart_model_roundtrip <- createBARTModelFromJsonFile("test.json") +} diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd new file mode 100644 index 00000000..f26b9089 --- /dev/null +++ b/man/createBARTModelFromJsonString.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{createBARTModelFromJsonString} +\alias{createBARTModelFromJsonString} +\title{Convert a JSON string containing sample information on a trained BART model +to a BART model object which can be used for prediction, etc...} +\usage{ +createBARTModelFromJsonString(json_string) +} +\arguments{ +\item{json_string}{JSON string dump} +} +\value{ +Object of type \code{bartmodel} +} +\description{ +Convert a JSON string containing sample information on a trained BART model +to a BART model object which can be used for prediction, etc... 
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# bart_json <- saveBARTModelToJsonString(bart_model) +# bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) +# y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) +# plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip) +} diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index 28b21ef0..76c3ebcb 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -5,22 +5,15 @@ \title{Convert an (in-memory) JSON representation of a BCF model to a BCF model object which can be used for prediction, etc...} \usage{ -createBCFModelFromJson(json_object) - createBCFModelFromJson(json_object) } \arguments{ \item{json_object}{Object of type \code{CppJson} containing Json representation of a BCF model} } \value{ -Object of type \code{bcf} - Object of type \code{bcf} } \description{ -Convert an (in-memory) JSON representation of a BCF model to a BCF model object -which can be used for prediction, etc... - Convert an (in-memory) JSON representation of a BCF model to a BCF model object which can be used for prediction, etc... } @@ -84,63 +77,4 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) # bcf_json <- convertBCFModelToJson(bcf_model) # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) -n <- 500 -x1 <- rnorm(n) -x2 <- rnorm(n) -x3 <- rnorm(n) -x4 <- as.numeric(rbinom(n,1,0.5)) -x5 <- as.numeric(sample(1:3,n,replace=TRUE)) -X <- cbind(x1,x2,x3,x4,x5) -p <- ncol(X) -g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} -mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} -mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} -tau1 <- function(x) {rep(3,nrow(x))} -tau2 <- function(x) {1+2*x[,2]*x[,4]} -mu_x <- mu1(X) -tau_x <- tau2(X) -pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 -Z <- rbinom(n,1,pi_x) -E_XZ <- mu_x + Z*tau_x -snr <- 3 -group_ids <- rep(c(1,2), n \%/\% 2) -rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) -rfx_basis <- cbind(1, runif(n, -1, 1)) -rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) -y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) -X <- as.data.frame(X) -X$x4 <- factor(X$x4, ordered = TRUE) -X$x5 <- factor(X$x5, ordered = TRUE) -test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) -n_train <- n - n_test -test_inds <- sort(sample(1:n, n_test, replace = FALSE)) -train_inds <- (1:n)[!((1:n) \%in\% test_inds)] -X_test <- X[test_inds,] -X_train <- X[train_inds,] -pi_test <- pi_x[test_inds] -pi_train <- pi_x[train_inds] -Z_test <- Z[test_inds] -Z_train <- Z[train_inds] -y_test <- y[test_inds] -y_train <- y[train_inds] -mu_test <- mu_x[test_inds] -mu_train <- mu_x[train_inds] -tau_test <- tau_x[test_inds] -tau_train <- tau_x[train_inds] -group_ids_test <- group_ids[test_inds] -group_ids_train <- group_ids[train_inds] -rfx_basis_test <- 
rfx_basis[test_inds,] -rfx_basis_train <- rfx_basis[train_inds,] -rfx_term_test <- rfx_term[test_inds] -rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - pi_train = pi_train, group_ids_train = group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 100, num_burnin = 0, num_mcmc = 100, - sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) -# bcf_json <- convertBCFModelToJson(bcf_model) -# bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd new file mode 100644 index 00000000..29763e81 --- /dev/null +++ b/man/saveBARTModelToJsonFile.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{saveBARTModelToJsonFile} +\alias{saveBARTModelToJsonFile} +\title{Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file} +\usage{ +saveBARTModelToJsonFile(object, filename) +} +\arguments{ +\item{object}{Object of type \code{bartmodel} containing draws of a BART model and associated sampling outputs.} + +\item{filename}{String of filepath, must end in ".json"} +} +\description{ +Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# saveBARTModelToJsonFile(bart_model, "test.json") +} diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd new file mode 100644 index 00000000..031b6d1e --- /dev/null +++ b/man/saveBARTModelToJsonString.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{saveBARTModelToJsonString} +\alias{saveBARTModelToJsonString} +\title{Convert the persistent aspects of a BART model to (in-memory) JSON string} +\usage{ +saveBARTModelToJsonString(object) +} +\arguments{ +\item{object}{Object of type \code{bartmodel} containing draws of a BART model and associated sampling outputs.} +} +\value{ +JSON string +} +\description{ +Convert the persistent aspects of a BART model to (in-memory) JSON string +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# 
saveBARTModelToJsonString(bart_model) +} From 5db1257b43e765b13634361a90c1993d2d925b76 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 4 Aug 2024 09:04:40 -0500 Subject: [PATCH 03/41] Updated example code --- R/bart.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/bart.R b/R/bart.R index 9eaeb47c..30dd0e70 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1022,7 +1022,8 @@ createBARTModelFromJsonFile <- function(json_filename){ #' # bart_json <- saveBARTModelToJsonString(bart_model) #' # bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) #' # y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) -#' # plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip) +#' # plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip, +#' # xlab = "original", ylab = "roundtrip") createBARTModelFromJsonString <- function(json_string){ # Load a `CppJson` object from string bart_json <- createCppJsonString(json_string) From 8bd7948016f74bd398c54b4414ebd7b7e8878db0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 8 Aug 2024 12:36:55 -0700 Subject: [PATCH 04/41] Added code to combine multiple forests --- NAMESPACE | 2 + R/cpp11.R | 12 ++ R/forest.R | 29 +++- R/serialization.R | 40 ++++++ _pkgdown.yml | 7 +- include/stochtree/container.h | 2 + man/ForestSamples.Rd | 71 +++++++++- man/createBARTModelFromJsonString.Rd | 3 +- man/loadForestContainerCombinedJson.Rd | 19 +++ man/loadForestContainerCombinedJsonString.Rd | 19 +++ src/container.cpp | 22 +++ src/cpp11.cpp | 26 ++++ src/forest.cpp | 40 ++++++ tools/debug/multichain_seq.R | 47 +++++++ vignettes/MultiChain.Rmd | 140 +++++++++++++++++++ 15 files changed, 474 insertions(+), 5 deletions(-) create mode 100644 man/loadForestContainerCombinedJson.Rd create mode 100644 man/loadForestContainerCombinedJsonString.Rd create mode 100644 tools/debug/multichain_seq.R create mode 100644 vignettes/MultiChain.Rmd diff --git a/NAMESPACE b/NAMESPACE index 5c8c8869..ecefca33 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -32,6 +32,8 @@ export(createRandomEffectsDataset) export(createRandomEffectsModel) export(createRandomEffectsTracker) export(getRandomEffectSamples) +export(loadForestContainerCombinedJson) +export(loadForestContainerCombinedJsonString) export(loadForestContainerJson) export(loadRandomEffectSamplesJson) export(loadScalarJson) diff --git a/R/cpp11.R b/R/cpp11.R index b9e710b7..4debd218 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -188,6 +188,18 @@ forest_container_from_json_cpp <- function(json_ptr, forest_label) { .Call(`_stochtree_forest_container_from_json_cpp`, json_ptr, forest_label) } +forest_container_append_from_json_cpp <- function(forest_sample_ptr, json_ptr, forest_label) { + invisible(.Call(`_stochtree_forest_container_append_from_json_cpp`, forest_sample_ptr, json_ptr, forest_label)) +} + +forest_container_from_json_string_cpp <- function(json_string, forest_label) { + .Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label) +} + +forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) { + invisible(.Call(`_stochtree_forest_container_append_from_json_string_cpp`, forest_sample_ptr, json_string, forest_label)) +} + num_samples_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples) } diff --git a/R/forest.R b/R/forest.R index 953aa585..2feec6af 100644 --- a/R/forest.R +++ b/R/forest.R @@ -22,7 +22,7 @@ ForestSamples <- R6::R6Class( }, #' 
@description - #' Create a new ForestContainer object from a json object + #' Create a new `ForestContainer` object from a json object #' @param json_object Object of class `CppJson` #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy #' @return A new `ForestContainer` object. @@ -30,6 +30,33 @@ ForestSamples <- R6::R6Class( self$forest_container_ptr <- forest_container_from_json_cpp(json_object$json_ptr, json_forest_label) }, + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return NULL + append_from_json = function(json_object, json_forest_label) { + forest_container_append_from_json_cpp(self$forest_container_ptr, json_object$json_ptr, json_forest_label) + }, + + #' @description + #' Create a new `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return A new `ForestContainer` object. + load_from_json_string = function(json_string, json_forest_label) { + self$forest_container_ptr <- forest_container_from_json_string_cpp(json_string, json_forest_label) + }, + + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return NULL + append_from_json_string = function(json_string, json_forest_label) { + forest_container_append_from_json_string_cpp(self$forest_container_ptr, json_string, json_forest_label) + }, + #' @description #' Predict every tree ensemble on every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class diff --git a/R/serialization.R b/R/serialization.R index 525f5abf..cf0ba267 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -322,6 +322,46 @@ loadForestContainerJson <- function(json_object, json_forest_label) { return(output) } +#' Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container +#' +#' @param json_object_list List of objects of class `CppJson` +#' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy (must exist in every json object in the list) +#' +#' @return `ForestSamples` object +#' @export +loadForestContainerCombinedJson <- function(json_object_list, json_forest_label) { + invisible(output <- ForestSamples$new(0,1,T)) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[i] + if (i == 1) { + output$load_from_json(json_object, json_forest_label) + } else { + output$append_from_json(json_object, json_forest_label) + } + } + return(output) +} + +#' Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container +#' +#' @param json_string_list List of strings that parse into objects of type `CppJson` +#' @param json_forest_label Label referring to a particular forest (i.e. 
"forest_0") in the overall json hierarchy (must exist in every json object in the list) +#' +#' @return `ForestSamples` object +#' @export +loadForestContainerCombinedJsonString <- function(json_string_list, json_forest_label) { + invisible(output <- ForestSamples$new(0,1,T)) + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + if (i == 1) { + output$load_from_json_string(json_string, json_forest_label) + } else { + output$append_from_json_string(json_string, json_forest_label) + } + } + return(output) +} + #' Load a container of random effect samples from json #' #' @param json_object Object of class `CppJson` diff --git a/_pkgdown.yml b/_pkgdown.yml index 8292609e..f919a864 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -30,12 +30,14 @@ reference: contents: - CppJson - createCppJson + - createCppJsonFile + - createCppJsonString - loadForestContainerJson + - loadForestContainerCombinedJson + - loadForestContainerCombinedJsonString - loadRandomEffectSamplesJson - loadVectorJson - loadScalarJson - - createCppJsonFile - - createCppJsonString - subtitle: Data desc: > @@ -104,6 +106,7 @@ articles: contents: - BayesianSupervisedLearning - CausalInference + - MultiChain - ModelSerialization - title: Prototype Interface diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 78139bb3..e189957a 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -90,6 +90,8 @@ class ForestContainer { nlohmann::json to_json(); /*! \brief Load from JSON */ void from_json(const nlohmann::json& forest_container_json); + /*! \brief Append to a forest container from JSON, requires that the ensemble already contains a nonzero number of forests */ + void append_from_json(const nlohmann::json& forest_container_json); private: std::vector> forests_; diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd index b629ca1a..4b945926 100644 --- a/man/ForestSamples.Rd +++ b/man/ForestSamples.Rd @@ -18,6 +18,9 @@ Wrapper around a C++ container of tree ensembles \itemize{ \item \href{#method-ForestSamples-new}{\code{ForestSamples$new()}} \item \href{#method-ForestSamples-load_from_json}{\code{ForestSamples$load_from_json()}} +\item \href{#method-ForestSamples-append_from_json}{\code{ForestSamples$append_from_json()}} +\item \href{#method-ForestSamples-load_from_json_string}{\code{ForestSamples$load_from_json_string()}} +\item \href{#method-ForestSamples-append_from_json_string}{\code{ForestSamples$append_from_json_string()}} \item \href{#method-ForestSamples-predict}{\code{ForestSamples$predict()}} \item \href{#method-ForestSamples-predict_raw}{\code{ForestSamples$predict_raw()}} \item \href{#method-ForestSamples-predict_raw_single_forest}{\code{ForestSamples$predict_raw_single_forest()}} @@ -69,7 +72,7 @@ A new \code{ForestContainer} object. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-load_from_json}{}}} \subsection{Method \code{load_from_json()}}{ -Create a new ForestContainer object from a json object +Create a new \code{ForestContainer} object from a json object \subsection{Usage}{ \if{html}{\out{
}}\preformatted{ForestSamples$load_from_json(json_object, json_forest_label)}\if{html}{\out{
}} } @@ -88,6 +91,72 @@ A new \code{ForestContainer} object. } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-append_from_json}{}}} +\subsection{Method \code{append_from_json()}}{ +Append to a \code{ForestContainer} object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$append_from_json(json_object, json_forest_label)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_object}}{Object of class \code{CppJson}} + +\item{\code{json_forest_label}}{Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-load_from_json_string}{}}} +\subsection{Method \code{load_from_json_string()}}{ +Create a new \code{ForestContainer} object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$load_from_json_string(json_string, json_forest_label)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_string}}{JSON string which parses into object of class \code{CppJson}} + +\item{\code{json_forest_label}}{Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new \code{ForestContainer} object. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-append_from_json_string}{}}} +\subsection{Method \code{append_from_json_string()}}{ +Append to a \code{ForestContainer} object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$append_from_json_string(json_string, json_forest_label)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_string}}{JSON string which parses into object of class \code{CppJson}} + +\item{\code{json_forest_label}}{Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-predict}{}}} \subsection{Method \code{predict()}}{ diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index f26b9089..735fb48f 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -42,5 +42,6 @@ bart_model <- bart(X_train = X_train, y_train = y_train) # bart_json <- saveBARTModelToJsonString(bart_model) # bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) # y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) -# plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip) +# plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip, +# xlab = "original", ylab = "roundtrip") } diff --git a/man/loadForestContainerCombinedJson.Rd b/man/loadForestContainerCombinedJson.Rd new file mode 100644 index 00000000..90d4e051 --- /dev/null +++ b/man/loadForestContainerCombinedJson.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/serialization.R +\name{loadForestContainerCombinedJson} +\alias{loadForestContainerCombinedJson} +\title{Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container} +\usage{ +loadForestContainerCombinedJson(json_object_list, json_forest_label) +} +\arguments{ +\item{json_object_list}{List of objects of class \code{CppJson}} + +\item{json_forest_label}{Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy (must exist in every json object in the list)} +} +\value{ +\code{ForestSamples} object +} +\description{ +Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container +} diff --git a/man/loadForestContainerCombinedJsonString.Rd b/man/loadForestContainerCombinedJsonString.Rd new file mode 100644 index 00000000..7b9a4d82 --- /dev/null +++ b/man/loadForestContainerCombinedJsonString.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/serialization.R +\name{loadForestContainerCombinedJsonString} +\alias{loadForestContainerCombinedJsonString} +\title{Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container} +\usage{ +loadForestContainerCombinedJsonString(json_string_list, json_forest_label) +} +\arguments{ +\item{json_string_list}{List of strings that parse into objects of type \code{CppJson}} + +\item{json_forest_label}{Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy (must exist in every json object in the list)} +} +\value{ +\code{ForestSamples} object +} +\description{ +Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container +} diff --git a/src/container.cpp b/src/container.cpp index 79e940d6..fac4a0ed 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -157,4 +157,26 @@ void ForestContainer::from_json(const json& forest_container_json) { } } +/*! 
\brief Append forests to a container from a JSON forest specification */ +void ForestContainer::append_from_json(const json& forest_container_json) { + CHECK_GT(this->num_samples_, 0); + CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees")); + CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension")); + CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant")); + CHECK_EQ(this->initialized_, forest_container_json.at("initialized")); + int new_num_samples = forest_container_json.at("num_samples"); + + std::string forest_label; + // forests_.resize(this->num_samples_); + int forest_ind; + for (int i = 0; i < forest_container_json.at("num_samples"); i++) { + forest_ind = this->num_samples_ + i; + forest_label = "forest_" + std::to_string(i); + // forests_[forest_ind] = std::make_unique(this->num_trees_, this->output_dimension_, this->is_leaf_constant_); + forests_.push_back(std::make_unique(this->num_trees_, this->output_dimension_, this->is_leaf_constant_)); + forests_[forest_ind]->from_json(forest_container_json.at(forest_label)); + } + this->num_samples_ += new_num_samples; +} + } // namespace StochTree \ No newline at end of file diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 6a5b883f..ca2f168e 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -349,6 +349,29 @@ extern "C" SEXP _stochtree_forest_container_from_json_cpp(SEXP json_ptr, SEXP fo END_CPP11 } // forest.cpp +void forest_container_append_from_json_cpp(cpp11::external_pointer forest_sample_ptr, cpp11::external_pointer json_ptr, std::string forest_label); +extern "C" SEXP _stochtree_forest_container_append_from_json_cpp(SEXP forest_sample_ptr, SEXP json_ptr, SEXP forest_label) { + BEGIN_CPP11 + forest_container_append_from_json_cpp(cpp11::as_cpp>>(forest_sample_ptr), cpp11::as_cpp>>(json_ptr), cpp11::as_cpp>(forest_label)); + return R_NilValue; + END_CPP11 +} +// forest.cpp +cpp11::external_pointer forest_container_from_json_string_cpp(std::string json_string, std::string forest_label); +extern "C" SEXP _stochtree_forest_container_from_json_string_cpp(SEXP json_string, SEXP forest_label) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_container_from_json_string_cpp(cpp11::as_cpp>(json_string), cpp11::as_cpp>(forest_label))); + END_CPP11 +} +// forest.cpp +void forest_container_append_from_json_string_cpp(cpp11::external_pointer forest_sample_ptr, std::string json_string, std::string forest_label); +extern "C" SEXP _stochtree_forest_container_append_from_json_string_cpp(SEXP forest_sample_ptr, SEXP json_string, SEXP forest_label) { + BEGIN_CPP11 + forest_container_append_from_json_string_cpp(cpp11::as_cpp>>(forest_sample_ptr), cpp11::as_cpp>(json_string), cpp11::as_cpp>(forest_label)); + return R_NilValue; + END_CPP11 +} +// forest.cpp int num_samples_forest_container_cpp(cpp11::external_pointer forest_samples); extern "C" SEXP _stochtree_num_samples_forest_container_cpp(SEXP forest_samples) { BEGIN_CPP11 @@ -909,8 +932,11 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, + {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, + 
{"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, + {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, diff --git a/src/forest.cpp b/src/forest.cpp index c1f22c4e..fc2dd574 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -36,6 +36,46 @@ cpp11::external_pointer forest_container_from_json_c return cpp11::external_pointer(forest_sample_ptr_.release()); } +[[cpp11::register]] +void forest_container_append_from_json_cpp(cpp11::external_pointer forest_sample_ptr, cpp11::external_pointer json_ptr, std::string forest_label) { + // Extract the forest's json + nlohmann::json forest_json = json_ptr->at("forests").at(forest_label); + + // Append to the forest sample container using the json + forest_sample_ptr->append_from_json(forest_json); +} + +[[cpp11::register]] +cpp11::external_pointer forest_container_from_json_string_cpp(std::string json_string, std::string forest_label) { + // Create smart pointer to newly allocated object + std::unique_ptr forest_sample_ptr_ = std::make_unique(0, 1, true); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the forest's json + nlohmann::json forest_json = json_object.at("forests").at(forest_label); + + // Reset the forest sample container using the json + forest_sample_ptr_->Reset(); + forest_sample_ptr_->from_json(forest_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(forest_sample_ptr_.release()); +} + +[[cpp11::register]] +void forest_container_append_from_json_string_cpp(cpp11::external_pointer forest_sample_ptr, std::string json_string, std::string forest_label) { + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the forest's json + nlohmann::json forest_json = json_object.at("forests").at(forest_label); + + // Append to the forest sample container using the json + forest_sample_ptr->append_from_json(forest_json); +} + [[cpp11::register]] int num_samples_forest_container_cpp(cpp11::external_pointer forest_samples) { return forest_samples->NumSamples(); diff --git a/tools/debug/multichain_seq.R b/tools/debug/multichain_seq.R new file mode 100644 index 00000000..49b159c8 --- /dev/null +++ b/tools/debug/multichain_seq.R @@ -0,0 +1,47 @@ +library(stochtree) +n <- 500 +p_x <- 10 +p_w <- 1 +snr <- 3 +X <- matrix(runif(n*p_x), ncol = p_x) +W <- matrix(runif(n*p_w), ncol = p_w) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) +) +noise_sd <- sd(f_XW) / snr +y <- f_XW + rnorm(n, 0, 1)*noise_sd +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- 
(1:n)[!((1:n) %in% test_inds)]
+X_test <- as.data.frame(X[test_inds,])
+X_train <- as.data.frame(X[train_inds,])
+W_test <- W[test_inds,]
+W_train <- W[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+num_chains <- 4
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+num_trees <- 100
+bart_models <- list()
+for (i in 1:num_chains) {
+    bart_models[[i]] <- stochtree::bart(
+        X_train = X_train, W_train = W_train, y_train = y_train, 
+        X_test = X_test, W_test = W_test, num_trees = num_trees, 
+        num_gfr = num_gfr, num_burnin = num_burnin, 
+        num_mcmc = num_mcmc, sample_sigma = T, sample_tau = T
+    )
+}
+json_string_list <- list()
+for (i in 1:num_chains) {
+    json_string_list[[i]] <- saveBARTModelToJsonString(bart_models[[i]])
+}
+combined_forests <- loadForestContainerCombinedJsonString(json_string_list, "forest_0")
+test_dataset <- createForestDataset(as.matrix(X_test), W_test)
+yhat_combined <- combined_forests$predict(test_dataset)
\ No newline at end of file
diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd
new file mode 100644
index 00000000..4c09a49b
--- /dev/null
+++ b/vignettes/MultiChain.Rmd
@@ -0,0 +1,140 @@
+---
+title: "Running Multiple Chains (Sequentially or in Parallel) in StochTree"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{Running Multiple Chains (Sequentially or in Parallel) in StochTree}
+  %\VignetteEncoding{UTF-8}
+  %\VignetteEngine{knitr::rmarkdown}
+bibliography: vignettes.bib
+editor_options: 
+  markdown: 
+    wrap: 72
+---
+
+```{r, include = FALSE}
+knitr::opts_chunk$set(
+  collapse = TRUE,
+  comment = "#>"
+)
+```
+
+# Motivation
+
+Mixing of an MCMC sampler is a perennial concern for complex Bayesian models, 
+and BART is no exception. One common way to address such concerns is to run 
+multiple independent "chains" of an MCMC sampler, so that if each chain gets 
+stuck in a different region of the posterior, their combined samples attain 
+better coverage of the full posterior.
+
+This idea works with the classic "from-root" MCMC sampler of @chipman2010bart, 
+but a key insight of @he2023stochastic is that the XBART algorithm may be used 
+to initialize ("warm-start") multiple chains of the BART MCMC sampler.
+
+Operationally, the above two approaches have the same implementation (setting 
+`num_gfr` > 0 if warm-start initialization is desired), so this vignette will 
+demonstrate how to run a multi-chain sampler sequentially or in parallel. 
+
+To begin, load the `stochtree` package
+
+```{r setup}
+library(stochtree)
+```
+
+# Demo 1: Supervised Learning, Sequential Multi Chain Sampler
+
+## Simulation
+
+Simulate a simple partitioned linear model
+
+```{r}
+# Generate the data
+n <- 500
+p_x <- 10
+p_w <- 1
+snr <- 3
+X <- matrix(runif(n*p_x), ncol = p_x)
+W <- matrix(runif(n*p_w), ncol = p_w)
+f_XW <- (
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + 
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + 
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + 
+    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
+)
+noise_sd <- sd(f_XW) / snr
+y <- f_XW + rnorm(n, 0, 1)*noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- as.data.frame(X[test_inds,])
+X_train <- as.data.frame(X[train_inds,])
+W_test <- W[test_inds,]
+W_train <- W[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+```
+
+## Sampling
+
+Define some high-level parameters, including number of chains to run and number of 
+samples per chain. Here we run 4 independent chains with 10 warm-start iterations 
+and 100 MCMC iterations each.
+
+```{r}
+num_chains <- 4
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+num_trees <- 100
+```
+
+Run the sampler, storing the resulting BART objects in a list
+
+```{r}
+bart_models <- list()
+for (i in 1:num_chains) {
+    bart_models[[i]] <- stochtree::bart(
+        X_train = X_train, W_train = W_train, y_train = y_train, 
+        X_test = X_test, W_test = W_test, num_trees = num_trees, 
+        num_gfr = num_gfr, num_burnin = num_burnin, 
+        num_mcmc = num_mcmc, sample_sigma = T, sample_tau = T
+    )
+}
+```
+
+Now, if we want to combine the forests from each of these BART models into a 
+single forest, we can do so as follows
+
+```{r}
+json_string_list <- list()
+for (i in 1:num_chains) {
+    json_string_list[[i]] <- saveBARTModelToJsonString(bart_models[[i]])
+}
+combined_forests <- loadForestContainerCombinedJsonString(json_string_list, "forest_0")
+```
+
+We can predict from this combined forest as follows
+
+```{r}
+test_dataset <- createForestDataset(as.matrix(X_test), W_test)
+yhat_combined <- combined_forests$predict(test_dataset)
+```
+
+Compare to the original $\hat{y}$ values
+
+```{r}
+num_samples <- num_gfr+num_burnin+num_mcmc
+for (i in 1:num_chains) {
+    offset <- (i-1)*num_samples
+    inds_start <- offset + 1 + num_burnin + num_gfr
+    inds_end <- offset + num_samples
+    plot(rowMeans(bart_models[[i]]$y_hat_test), 
+         rowMeans(yhat_combined[,inds_start:inds_end]))
+}
+```
+
+
+# References

From 1ab2828e9dac19bab80346d6657d7875a79e9d0c Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Thu, 8 Aug 2024 18:01:59 -0700
Subject: [PATCH 05/41] Updated multichain vignette

---
 vignettes/MultiChain.Rmd | 83 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 79 insertions(+), 4 deletions(-)

diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd
index 4c09a49b..13275117 100644
--- a/vignettes/MultiChain.Rmd
+++ b/vignettes/MultiChain.Rmd
@@ -34,15 +34,17 @@ Operationally, the above two approaches have the same implementation (setting
 `num_gfr` > 0 if warm-start initialization is desired), so this vignette will 
 demonstrate how to run a multi-chain sampler sequentially or in parallel. 
 
-To begin, load the `stochtree` package
+To begin, load `stochtree` and other necessary packages
 
 ```{r setup}
 library(stochtree)
+library(foreach)
+library(doParallel)
 ```
 
-# Demo 1: Supervised Learning, Sequential Multi Chain Sampler
+# Demo 1: Supervised Learning
 
-## Simulation
+## Data Simulation
 
 Simulate a simple partitioned linear model
 
@@ -77,7 +79,7 @@ y_test <- y[test_inds]
 y_train <- y[train_inds]
 ```
 
-## Sampling
+## Sampling Multiple Chains Sequentially
 
 Define some high-level parameters, including number of chains to run and number of 
 samples per chain. Here we run 4 independent chains with 10 warm-start iterations 
@@ -136,5 +138,78 @@ for (i in 1:num_chains) {
 }
 ```
 
+## Sampling Multiple Chains in Parallel
+
+We use the same high-level parameters as in the sequential demo.
+
+```{r}
+num_chains <- 4
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+num_trees <- 100
+```
+
+In order to run this sampler in parallel, a parallel backend must be registered in your R environment. 
+The code below will register a parallel backend with access to as many cores as are available on your machine. 
+Note that we do not **evaluate** the code snippet below in order to interact nicely with CRAN / GitHub Actions environments.
+
+```{r, eval=FALSE}
+ncores <- parallel::detectCores()
+cl <- makeCluster(ncores)
+registerDoParallel(cl)
+```
+
+Run the sampler, storing the resulting BART objects in a list
+
+```{r}
+bart_models <- foreach (i = 1:num_chains) %dopar% {
+    random_seed <- i
+    stochtree::bart(
+        X_train = X_train, W_train = W_train, y_train = y_train, 
+        X_test = X_test, W_test = W_test, num_trees = num_trees, 
+        num_gfr = num_gfr, num_burnin = num_burnin, 
+        num_mcmc = num_mcmc, sample_sigma = T, sample_tau = T, 
+        random_seed = random_seed
+    )
+}
+```
+
+Close the parallel cluster (not evaluated here, as explained above). 
+ +```{r, eval=FALSE} +stopCluster(cl) +``` + +Now, if we want to combine the forests from each of these BART models into a +single forest, we can do so as follows + +```{r} +json_string_list <- list() +for (i in 1:num_chains) { + json_string_list[[i]] <- saveBARTModelToJsonString(bart_models[[i]]) +} +combined_forests <- loadForestContainerCombinedJsonString(json_string_list, "forest_0") +``` + +We can predict from this combined forest as follows + +```{r} +test_dataset <- createForestDataset(as.matrix(X_test), W_test) +yhat_combined <- combined_forests$predict(test_dataset) +``` + +Compare to the original $\hat{y}$ values + +```{r} +num_samples <- num_gfr+num_burnin+num_mcmc +for (i in 1:num_chains) { + offset <- (i-1)*num_samples + inds_start <- offset + 1 + num_burnin + num_gfr + inds_end <- offset + num_samples + plot(rowMeans(bart_models[[i]]$y_hat_test), + rowMeans(yhat_combined[,inds_start:inds_end])) +} +``` # References From 35be06771d09ae4ce6717fa2da4fb8bec85c7dac Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 8 Aug 2024 18:02:22 -0700 Subject: [PATCH 06/41] Added functions to combine random effects samples from multiple JSON strings / objects --- NAMESPACE | 1 + R/cpp11.R | 20 +++++ R/random_effects.R | 35 ++++++++ R/serialization.R | 48 ++++++++++- include/stochtree/random_effects.h | 1 + man/RandomEffectSamples.Rd | 96 ++++++++++++++++++++++ man/loadRandomEffectSamplesCombinedJson.Rd | 27 ++++++ src/R_random_effects.cpp | 79 ++++++++++++++++++ src/cpp11.cpp | 42 ++++++++++ src/random_effects.cpp | 23 ++++++ 10 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 man/loadRandomEffectSamplesCombinedJson.Rd diff --git a/NAMESPACE b/NAMESPACE index ecefca33..4029e17c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -35,6 +35,7 @@ export(getRandomEffectSamples) export(loadForestContainerCombinedJson) export(loadForestContainerCombinedJsonString) export(loadForestContainerJson) +export(loadRandomEffectSamplesCombinedJson) export(loadRandomEffectSamplesJson) export(loadScalarJson) export(loadVectorJson) diff --git a/R/cpp11.R b/R/cpp11.R index 4debd218..a6d060fd 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -96,6 +96,26 @@ rfx_group_ids_from_json_cpp <- function(json_ptr, rfx_label) { .Call(`_stochtree_rfx_group_ids_from_json_cpp`, json_ptr, rfx_label) } +rfx_container_append_from_json_cpp <- function(rfx_container_ptr, json_ptr, rfx_label) { + invisible(.Call(`_stochtree_rfx_container_append_from_json_cpp`, rfx_container_ptr, json_ptr, rfx_label)) +} + +rfx_container_from_json_string_cpp <- function(json_string, rfx_label) { + .Call(`_stochtree_rfx_container_from_json_string_cpp`, json_string, rfx_label) +} + +rfx_label_mapper_from_json_string_cpp <- function(json_string, rfx_label) { + .Call(`_stochtree_rfx_label_mapper_from_json_string_cpp`, json_string, rfx_label) +} + +rfx_group_ids_from_json_string_cpp <- function(json_string, rfx_label) { + .Call(`_stochtree_rfx_group_ids_from_json_string_cpp`, json_string, rfx_label) +} + +rfx_container_append_from_json_string_cpp <- function(rfx_container_ptr, json_string, rfx_label) { + invisible(.Call(`_stochtree_rfx_container_append_from_json_string_cpp`, rfx_container_ptr, json_string, rfx_label)) +} + rfx_model_cpp <- function(num_components, num_groups) { .Call(`_stochtree_rfx_model_cpp`, num_components, num_groups) } diff --git a/R/random_effects.R b/R/random_effects.R index dbcdbbb4..f9d0eaf9 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -53,6 +53,41 @@ RandomEffectSamples <- R6::R6Class( 
self$training_group_ids <- rfx_group_ids_from_json_cpp(json_object$json_ptr, json_rfx_groupids_label)
     },
+    
+    #' @description
+    #' Append random effect draws to `RandomEffectSamples` object from a json object
+    #' @param json_object Object of class `CppJson`
+    #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy
+    #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy
+    #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy
+    #' @return NULL (updates object in-place)
+    append_from_json = function(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) {
+        rfx_container_append_from_json_cpp(self$rfx_container_ptr, json_object$json_ptr, json_rfx_container_label)
+    },
+    
+    #' @description
+    #' Construct RandomEffectSamples object from a json object
+    #' @param json_string JSON string which parses into object of class `CppJson`
+    #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy
+    #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy
+    #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy
+    #' @return A new `RandomEffectSamples` object.
+    load_from_json_string = function(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) {
+        self$rfx_container_ptr <- rfx_container_from_json_string_cpp(json_string, json_rfx_container_label)
+        self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp(json_string, json_rfx_mapper_label)
+        self$training_group_ids <- rfx_group_ids_from_json_string_cpp(json_string, json_rfx_groupids_label)
+    },
+    
+    #' @description
+    #' Append random effect draws to `RandomEffectSamples` object from a json object
+    #' @param json_string JSON string which parses into object of class `CppJson`
+    #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy
+    #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy
+    #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy
+    #' @return NULL (updates object in-place)
+    append_from_json_string = function(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) {
+        rfx_container_append_from_json_string_cpp(self$rfx_container_ptr, json_string, json_rfx_container_label)
+    },
+    
     #' @description
     #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`.
     #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. 
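
As a quick orientation for the additions above: the following is a minimal sketch of the intended workflow, mirroring the forest-combining pattern from the multi-chain vignette. Here `bcf_models`, `rfx_group_ids`, and `rfx_basis` are hypothetical user-supplied objects (a list of BCF fits with random effects sharing one JSON schema, plus group IDs and a basis for prediction); `loadRandomEffectSamplesCombinedJsonString` is the string-based loader added in the R/serialization.R diff that follows.

```r
# Assumption: bcf_models is a list of fitted BCF models with random effects
# Serialize each fitted model to an in-memory JSON string
json_string_list <- lapply(bcf_models, saveBCFModelToJsonString)

# Combine the first random effects term ("random_effect_*_0") across models
combined_rfx <- loadRandomEffectSamplesCombinedJsonString(json_string_list, 0)

# Predict combined random effect draws for (assumed) group IDs and basis
rfx_preds <- combined_rfx$predict(rfx_group_ids, rfx_basis)
```
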
diff --git a/R/serialization.R b/R/serialization.R
index cf0ba267..6d25e802 100644
--- a/R/serialization.R
+++ b/R/serialization.R
@@ -332,7 +332,7 @@ loadForestContainerJson <- function(json_object, json_forest_label) {
 loadForestContainerCombinedJson <- function(json_object_list, json_forest_label) {
     invisible(output <- ForestSamples$new(0,1,T))
     for (i in 1:length(json_object_list)) {
-        json_object <- json_object_list[i]
+        json_object <- json_object_list[[i]]
         if (i == 1) {
             output$load_from_json(json_object, json_forest_label)
         } else {
@@ -378,6 +378,52 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) {
     return(output)
 }
 
+#' Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container
+#'
+#' @param json_object_list List of objects of class `CppJson`
+#' @param json_rfx_num Integer index indicating the position of the random effects term to be unpacked
+#'
+#' @return `RandomEffectSamples` object
+#' @export
+loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) {
+    json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num)
+    json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num)
+    json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num)
+    invisible(output <- RandomEffectSamples$new())
+    for (i in 1:length(json_object_list)) {
+        json_object <- json_object_list[[i]]
+        if (i == 1) {
+            output$load_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
+        } else {
+            output$append_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
+        }
+    }
+    return(output)
+}
+
+#' Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container
+#'
+#' @param json_string_list List of strings that parse into objects of type `CppJson`
+#' @param json_rfx_num Integer index indicating the position of the random effects term to be unpacked
+#'
+#' @return `RandomEffectSamples` object
+#' @export
+loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx_num) {
+    json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num)
+    json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num)
+    json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num)
+    invisible(output <- RandomEffectSamples$new())
+    for (i in 1:length(json_string_list)) {
+        json_string <- json_string_list[[i]]
+        if (i == 1) {
+            output$load_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
+        } else {
+            output$append_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
+        }
+    }
+    return(output)
+}
+
 #' Load a vector from json
 #'
 #' @param json_object Object of class `CppJson`
diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h
index 7d7a65c0..623a1103 100644
--- a/include/stochtree/random_effects.h
+++ b/include/stochtree/random_effects.h
@@ -279,6 +279,7 @@ class RandomEffectsContainer {
   std::vector<double>& GetSigma() {return sigma_xi_;}
   nlohmann::json to_json();
   void from_json(const nlohmann::json& rfx_container_json);
+  void append_from_json(const nlohmann::json& rfx_container_json);
  private:
   int num_samples_;
   int num_components_;
diff --git a/man/RandomEffectSamples.Rd b/man/RandomEffectSamples.Rd
index 55887c03..90981546 100644
--- a/man/RandomEffectSamples.Rd
+++ 
b/man/RandomEffectSamples.Rd @@ -28,6 +28,9 @@ needed for prediction / serialization \item \href{#method-RandomEffectSamples-new}{\code{RandomEffectSamples$new()}} \item \href{#method-RandomEffectSamples-load_in_session}{\code{RandomEffectSamples$load_in_session()}} \item \href{#method-RandomEffectSamples-load_from_json}{\code{RandomEffectSamples$load_from_json()}} +\item \href{#method-RandomEffectSamples-append_from_json}{\code{RandomEffectSamples$append_from_json()}} +\item \href{#method-RandomEffectSamples-load_from_json_string}{\code{RandomEffectSamples$load_from_json_string()}} +\item \href{#method-RandomEffectSamples-append_from_json_string}{\code{RandomEffectSamples$append_from_json_string()}} \item \href{#method-RandomEffectSamples-predict}{\code{RandomEffectSamples$predict()}} \item \href{#method-RandomEffectSamples-extract_parameter_samples}{\code{RandomEffectSamples$extract_parameter_samples()}} \item \href{#method-RandomEffectSamples-extract_label_mapping}{\code{RandomEffectSamples$extract_label_mapping()}} @@ -106,6 +109,99 @@ A new \code{RandomEffectSamples} object. } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectSamples-append_from_json}{}}} +\subsection{Method \code{append_from_json()}}{ +Append random effect draws to \code{RandomEffectSamples} object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectSamples$append_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_object}}{Object of class \code{CppJson}} + +\item{\code{json_rfx_container_label}}{Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy} + +\item{\code{json_rfx_mapper_label}}{Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy} + +\item{\code{json_rfx_groupids_label}}{Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL (updates object in-place) +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectSamples-load_from_json_string}{}}} +\subsection{Method \code{load_from_json_string()}}{ +Construct RandomEffectSamples object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectSamples$load_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_string}}{JSON string which parses into object of class \code{CppJson}} + +\item{\code{json_rfx_container_label}}{Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy} + +\item{\code{json_rfx_mapper_label}}{Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy} + +\item{\code{json_rfx_groupids_label}}{Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new \code{RandomEffectSamples} object. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectSamples-append_from_json_string}{}}} +\subsection{Method \code{append_from_json_string()}}{ +Append random effect draws to \code{RandomEffectSamples} object from a json object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectSamples$append_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{json_string}}{JSON string which parses into object of class \code{CppJson}} + +\item{\code{json_rfx_container_label}}{Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy} + +\item{\code{json_rfx_mapper_label}}{Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy} + +\item{\code{json_rfx_groupids_label}}{Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL (updates object in-place) +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectSamples-predict}{}}} \subsection{Method \code{predict()}}{ diff --git a/man/loadRandomEffectSamplesCombinedJson.Rd b/man/loadRandomEffectSamplesCombinedJson.Rd new file mode 100644 index 00000000..ac4c1723 --- /dev/null +++ b/man/loadRandomEffectSamplesCombinedJson.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/serialization.R +\name{loadRandomEffectSamplesCombinedJson} +\alias{loadRandomEffectSamplesCombinedJson} +\title{Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container} +\usage{ +loadRandomEffectSamplesCombinedJson(json_string_list, json_rfx_num) + +loadRandomEffectSamplesCombinedJson(json_string_list, json_rfx_num) +} +\arguments{ +\item{json_string_list}{List of objects of class \code{CppJson}} + +\item{json_rfx_num}{Integer index indicating the position of the random effects term to be unpacked} + +\item{json_object_list}{List of objects of class \code{CppJson}} +} +\value{ +\code{RandomEffectSamples} object + +\code{RandomEffectSamples} object +} +\description{ +Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container + +Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container +} diff --git a/src/R_random_effects.cpp b/src/R_random_effects.cpp index 463b8e97..ccb3aa98 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -69,6 +69,85 @@ cpp11::writable::integers rfx_group_ids_from_json_cpp(cpp11::external_pointer rfx_container_ptr, cpp11::external_pointer json_ptr, std::string rfx_label) { + // Extract the random effect container's json + nlohmann::json rfx_json = json_ptr->at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr->append_from_json(rfx_json); +} + +[[cpp11::register]] +cpp11::external_pointer rfx_container_from_json_string_cpp(std::string json_string, std::string rfx_label) { + // Create smart pointer to newly allocated object + std::unique_ptr rfx_container_ptr_ = std::make_unique(); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the random effect container's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr_->Reset(); + rfx_container_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(rfx_container_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer rfx_label_mapper_from_json_string_cpp(std::string json_string, std::string rfx_label) { + // Create smart pointer to newly allocated object + std::unique_ptr label_mapper_ptr_ = std::make_unique(); + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the label mapper's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the label mapper using the json + label_mapper_ptr_->Reset(); + label_mapper_ptr_->from_json(rfx_json); + + // Release management of the pointer to R session + return cpp11::external_pointer(label_mapper_ptr_.release()); +} + +[[cpp11::register]] +cpp11::writable::integers 
rfx_group_ids_from_json_string_cpp(std::string json_string, std::string rfx_label) { + // Create smart pointer to newly allocated object + cpp11::writable::integers output; + + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the groupids' json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + int num_groups = rfx_json.size(); + for (int i = 0; i < num_groups; i++) { + output.push_back(rfx_json.at(i)); + } + + return output; +} + +[[cpp11::register]] +void rfx_container_append_from_json_string_cpp(cpp11::external_pointer rfx_container_ptr, std::string json_string, std::string rfx_label) { + // Create a nlohmann::json object from the string + nlohmann::json json_object = nlohmann::json::parse(json_string); + + // Extract the random effect container's json + nlohmann::json rfx_json = json_object.at("random_effects").at(rfx_label); + + // Reset the forest sample container using the json + rfx_container_ptr->append_from_json(rfx_json); +} + [[cpp11::register]] cpp11::external_pointer rfx_model_cpp(int num_components, int num_groups) { // Create smart pointer to newly allocated object diff --git a/src/cpp11.cpp b/src/cpp11.cpp index ca2f168e..b4293c7c 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -181,6 +181,43 @@ extern "C" SEXP _stochtree_rfx_group_ids_from_json_cpp(SEXP json_ptr, SEXP rfx_l END_CPP11 } // R_random_effects.cpp +void rfx_container_append_from_json_cpp(cpp11::external_pointer rfx_container_ptr, cpp11::external_pointer json_ptr, std::string rfx_label); +extern "C" SEXP _stochtree_rfx_container_append_from_json_cpp(SEXP rfx_container_ptr, SEXP json_ptr, SEXP rfx_label) { + BEGIN_CPP11 + rfx_container_append_from_json_cpp(cpp11::as_cpp>>(rfx_container_ptr), cpp11::as_cpp>>(json_ptr), cpp11::as_cpp>(rfx_label)); + return R_NilValue; + END_CPP11 +} +// R_random_effects.cpp +cpp11::external_pointer rfx_container_from_json_string_cpp(std::string json_string, std::string rfx_label); +extern "C" SEXP _stochtree_rfx_container_from_json_string_cpp(SEXP json_string, SEXP rfx_label) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_container_from_json_string_cpp(cpp11::as_cpp>(json_string), cpp11::as_cpp>(rfx_label))); + END_CPP11 +} +// R_random_effects.cpp +cpp11::external_pointer rfx_label_mapper_from_json_string_cpp(std::string json_string, std::string rfx_label); +extern "C" SEXP _stochtree_rfx_label_mapper_from_json_string_cpp(SEXP json_string, SEXP rfx_label) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_label_mapper_from_json_string_cpp(cpp11::as_cpp>(json_string), cpp11::as_cpp>(rfx_label))); + END_CPP11 +} +// R_random_effects.cpp +cpp11::writable::integers rfx_group_ids_from_json_string_cpp(std::string json_string, std::string rfx_label); +extern "C" SEXP _stochtree_rfx_group_ids_from_json_string_cpp(SEXP json_string, SEXP rfx_label) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_group_ids_from_json_string_cpp(cpp11::as_cpp>(json_string), cpp11::as_cpp>(rfx_label))); + END_CPP11 +} +// R_random_effects.cpp +void rfx_container_append_from_json_string_cpp(cpp11::external_pointer rfx_container_ptr, std::string json_string, std::string rfx_label); +extern "C" SEXP _stochtree_rfx_container_append_from_json_string_cpp(SEXP rfx_container_ptr, SEXP json_string, SEXP rfx_label) { + BEGIN_CPP11 + rfx_container_append_from_json_string_cpp(cpp11::as_cpp>>(rfx_container_ptr), cpp11::as_cpp>(json_string), cpp11::as_cpp>(rfx_label)); + 
return R_NilValue; + END_CPP11 +} +// R_random_effects.cpp cpp11::external_pointer rfx_model_cpp(int num_components, int num_groups); extern "C" SEXP _stochtree_rfx_model_cpp(SEXP num_components, SEXP num_groups) { BEGIN_CPP11 @@ -996,8 +1033,11 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, + {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, + {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, @@ -1014,8 +1054,10 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, + {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, + {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, diff --git a/src/random_effects.cpp b/src/random_effects.cpp index bc746e81..efb141cf 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -294,4 +294,27 @@ void RandomEffectsContainer::from_json(const nlohmann::json& rfx_container_json) } } +void RandomEffectsContainer::append_from_json(const nlohmann::json& rfx_container_json) { + CHECK_EQ(this->num_components_, rfx_container_json.at("num_components")); + CHECK_EQ(this->num_groups_, rfx_container_json.at("num_groups")); + + // Update internal sample count and extract size of parameter vectors + int new_num_samples = rfx_container_json.at("num_samples"); + this->num_samples_ += new_num_samples; + int beta_size = rfx_container_json.at("beta_size"); + int alpha_size = rfx_container_json.at("alpha_size"); + + // Unpack beta and xi + for (int i = 0; i < beta_size; i++) { + beta_.push_back(rfx_container_json.at("beta").at(i)); + xi_.push_back(rfx_container_json.at("xi").at(i)); + } + + // Unpack alpha and sigma_xi + for (int i = 0; i < alpha_size; i++) { + alpha_.push_back(rfx_container_json.at("alpha").at(i)); + 
sigma_xi_.push_back(rfx_container_json.at("sigma_xi").at(i)); + } +} + } // namespace StochTree From be80a1c32a12558a0ae5755da6087afc2e0f1dd0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 9 Aug 2024 01:18:09 -0500 Subject: [PATCH 07/41] Updated multichain code and demos --- NAMESPACE | 3 + R/bart.R | 268 ++++++++++++++++++ R/serialization.R | 6 +- _pkgdown.yml | 11 + man/createBARTModelFromCombinedJson.Rd | 44 +++ man/createBARTModelFromCombinedJsonString.Rd | 44 +++ man/loadRandomEffectSamplesCombinedJson.Rd | 12 +- ...adRandomEffectSamplesCombinedJsonString.Rd | 19 ++ vignettes/MultiChain.Rmd | 63 ++-- 9 files changed, 432 insertions(+), 38 deletions(-) create mode 100644 man/createBARTModelFromCombinedJson.Rd create mode 100644 man/createBARTModelFromCombinedJsonString.Rd create mode 100644 man/loadRandomEffectSamplesCombinedJsonString.Rd diff --git a/NAMESPACE b/NAMESPACE index 4029e17c..eb7284d2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,8 @@ export(computeForestKernels) export(computeForestLeafIndices) export(convertBARTModelToJson) export(convertBCFModelToJson) +export(createBARTModelFromCombinedJson) +export(createBARTModelFromCombinedJsonString) export(createBARTModelFromJson) export(createBARTModelFromJsonFile) export(createBARTModelFromJsonString) @@ -36,6 +38,7 @@ export(loadForestContainerCombinedJson) export(loadForestContainerCombinedJsonString) export(loadForestContainerJson) export(loadRandomEffectSamplesCombinedJson) +export(loadRandomEffectSamplesCombinedJsonString) export(loadRandomEffectSamplesJson) export(loadScalarJson) export(loadVectorJson) diff --git a/R/bart.R b/R/bart.R index 30dd0e70..39ec45f6 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1033,3 +1033,271 @@ createBARTModelFromJsonString <- function(json_string){ return(bart_object) } + +#' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object +#' which can be used for prediction, etc... 
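+#'
+#' Sample-independent metadata (covariate preprocessing information and
+#' model settings) is taken from the first model in the list; all models
+#' supplied are therefore assumed to have been sampled with the same data
+#' preprocessing and model configuration.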
+#' +#' @param json_object_list List of objects of type `CppJson` containing Json representation of a BART model +#' +#' @return Object of type `bartmodel` +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, y_train = y_train) +#' # bart_json <- list(convertBARTModelToJson(bart_model)) +#' # bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) +createBARTModelFromCombinedJson <- function(json_object_list){ + # Initialize the BCF model + output <- list() + + # Unpack the forests + output[["forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") + train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") + model_params[["sample_sigma"]] <- json_object_default$get_boolean("sample_sigma") + model_params[["sample_tau"]] <- json_object_default$get_boolean("sample_tau") + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") + model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + + # Combine values that are sample-specific + keep_index_offset 
<- 0 + keep_indices <- c() + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) + } else { + prev_json <- json_object_list[[i-1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + keep_index_offset <- keep_index_offset + prev_json$get_scalar("num_samples") + keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) + } + } + output[["keep_indices"]] <- keep_indices + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + } else { + output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters")) + } + } + } + if (model_params[["sample_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters") + } else { + output[["tau_samples"]] <- c(output[["tau_samples"]], json_object$get_vector("tau_samples", "parameters")) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + } + + class(output) <- "bartmodel" + return(output) +} + +#' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object +#' which can be used for prediction, etc... 
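+#'
+#' Each JSON string is first parsed into a `CppJson` object (via
+#' `createCppJsonString()`); sample-independent metadata is then taken
+#' from the first model in the list, so all models supplied are assumed
+#' to share the same data preprocessing and model configuration.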
+#' +#' @param json_string_list List of JSON strings which can be parsed to objects of type `CppJson` containing Json representation of a BART model +#' +#' @return Object of type `bartmodel` +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, y_train = y_train) +#' # bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) +#' # bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) +createBARTModelFromCombinedJsonString <- function(json_string_list){ + # Initialize the BCF model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + } + + # Unpack the forests + output[["forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") + train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + } + output[["train_set_metadata"]] <- train_set_metadata + output[["keep_indices"]] <- json_object_default$get_vector("keep_indices") + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") + model_params[["sample_sigma"]] <- json_object_default$get_boolean("sample_sigma") + model_params[["sample_tau"]] <- json_object_default$get_boolean("sample_tau") + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- 
json_object_default$get_scalar("num_rfx_basis") + model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + + # Combine values that are sample-specific + keep_index_offset <- 0 + keep_indices <- c() + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) + } else { + prev_json <- json_object_list[[i-1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + keep_index_offset <- keep_index_offset + prev_json$get_scalar("num_samples") + keep_indices <- c(keep_indices, keep_index_offset + json_object$get_vector("keep_indices")) + } + } + output[["keep_indices"]] <- keep_indices + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + } else { + output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters")) + } + } + } + if (model_params[["sample_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters") + } else { + output[["tau_samples"]] <- c(output[["tau_samples"]], json_object$get_vector("tau_samples", "parameters")) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + } + + class(output) <- "bartmodel" + return(output) +} diff --git a/R/serialization.R b/R/serialization.R index 6d25e802..4c6ec0cb 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -408,7 +408,7 @@ loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) #' #' @return `RandomEffectSamples` object #' @export -loadRandomEffectSamplesCombinedJson <- function(json_string_list, json_rfx_num) { +loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx_num) { json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) @@ -416,9 +416,9 @@ loadRandomEffectSamplesCombinedJson <- function(json_string_list, json_rfx_num) for (i in 1:length(json_object_list)) { json_string <- json_string_list[[i]] if (i == 1) { - output$load_from_json(json_object, 
json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$load_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) } else { - output$append_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$append_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) } } return(output) diff --git a/_pkgdown.yml b/_pkgdown.yml index f919a864..ec8742fd 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -38,6 +38,17 @@ reference: - loadRandomEffectSamplesJson - loadVectorJson - loadScalarJson + - convertBARTModelToJson + - createBARTModelFromCombinedJson + - createBARTModelFromCombinedJsonString + - createBARTModelFromJson + - createBARTModelFromJsonFile + - createBARTModelFromJsonString + - loadRandomEffectSamplesCombinedJson + - loadRandomEffectSamplesCombinedJsonString + - saveBARTModelToJsonFile + - saveBARTModelToJsonString + - saveBCFModelToJsonString - subtitle: Data desc: > diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd new file mode 100644 index 00000000..72c3e675 --- /dev/null +++ b/man/createBARTModelFromCombinedJson.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{createBARTModelFromCombinedJson} +\alias{createBARTModelFromCombinedJson} +\title{Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object +which can be used for prediction, etc...} +\usage{ +createBARTModelFromCombinedJson(json_object_list) +} +\arguments{ +\item{json_object_list}{List of objects of type \code{CppJson} containing Json representation of a BART model} +} +\value{ +Object of type \code{bartmodel} +} +\description{ +Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object +which can be used for prediction, etc... 
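+Sample-independent metadata (covariate preprocessing information and
+model settings) is taken from the first model in the list; all models
+supplied are therefore assumed to have been sampled with the same data
+preprocessing and model configuration.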
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# bart_json <- list(convertBARTModelToJson(bart_model)) +# bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) +} diff --git a/man/createBARTModelFromCombinedJsonString.Rd b/man/createBARTModelFromCombinedJsonString.Rd new file mode 100644 index 00000000..99c248b7 --- /dev/null +++ b/man/createBARTModelFromCombinedJsonString.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{createBARTModelFromCombinedJsonString} +\alias{createBARTModelFromCombinedJsonString} +\title{Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object +which can be used for prediction, etc...} +\usage{ +createBARTModelFromCombinedJsonString(json_string_list) +} +\arguments{ +\item{json_string_list}{List of JSON strings which can be parsed to objects of type \code{CppJson} containing Json representation of a BART model} +} +\value{ +Object of type \code{bartmodel} +} +\description{ +Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object +which can be used for prediction, etc... 
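+Each JSON string is first parsed into a \code{CppJson} object (via
+\code{createCppJsonString()}); sample-independent metadata is then taken
+from the first model in the list, so all models supplied are assumed to
+share the same data preprocessing and model configuration.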
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, y_train = y_train) +# bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) +# bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) +} diff --git a/man/loadRandomEffectSamplesCombinedJson.Rd b/man/loadRandomEffectSamplesCombinedJson.Rd index ac4c1723..d7ef5705 100644 --- a/man/loadRandomEffectSamplesCombinedJson.Rd +++ b/man/loadRandomEffectSamplesCombinedJson.Rd @@ -4,24 +4,16 @@ \alias{loadRandomEffectSamplesCombinedJson} \title{Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container} \usage{ -loadRandomEffectSamplesCombinedJson(json_string_list, json_rfx_num) - -loadRandomEffectSamplesCombinedJson(json_string_list, json_rfx_num) +loadRandomEffectSamplesCombinedJson(json_object_list, json_rfx_num) } \arguments{ -\item{json_string_list}{List of objects of class \code{CppJson}} +\item{json_object_list}{List of objects of class \code{CppJson}} \item{json_rfx_num}{Integer index indicating the position of the random effects term to be unpacked} - -\item{json_object_list}{List of objects of class \code{CppJson}} } \value{ -\code{RandomEffectSamples} object - \code{RandomEffectSamples} object } \description{ Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container - -Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container } diff --git a/man/loadRandomEffectSamplesCombinedJsonString.Rd b/man/loadRandomEffectSamplesCombinedJsonString.Rd new file mode 100644 index 00000000..3531b968 --- /dev/null +++ b/man/loadRandomEffectSamplesCombinedJsonString.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/serialization.R +\name{loadRandomEffectSamplesCombinedJsonString} +\alias{loadRandomEffectSamplesCombinedJsonString} +\title{Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container} +\usage{ +loadRandomEffectSamplesCombinedJsonString(json_string_list, json_rfx_num) +} +\arguments{ +\item{json_string_list}{List of objects of class \code{CppJson}} + +\item{json_rfx_num}{Integer index indicating the position of the random effects term to be unpacked} +} +\value{ +\code{RandomEffectSamples} object +} +\description{ +Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container +} diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd index 13275117..9cecaadd 100644 --- a/vignettes/MultiChain.Rmd +++ b/vignettes/MultiChain.Rmd @@ -115,27 +115,30 @@ json_string_list <- list() for (i in 1:num_chains) { json_string_list[[i]] <- saveBARTModelToJsonString(bart_models[[i]]) } -combined_forests <- 
loadForestContainerCombinedJsonString(json_string_list, "forest_0") +combined_bart <- createBARTModelFromCombinedJsonString(json_string_list) ``` We can predict from this combined forest as follows ```{r} -test_dataset <- createForestDataset(as.matrix(X_test), W_test) -yhat_combined <- combined_forests$predict(test_dataset) +yhat_combined <- predict(combined_bart, X_test, W_test)$y_hat ``` Compare to the original $\hat{y}$ values ```{r} -num_samples <- num_gfr+num_burnin+num_mcmc +par(mfrow = c(1,2)) for (i in 1:num_chains) { - offset <- (i-1)*num_samples - inds_start <- offset + 1 + num_burnin + num_gfr - inds_end <- offset + num_samples + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc plot(rowMeans(bart_models[[i]]$y_hat_test), - rowMeans(yhat_combined[,inds_start:inds_end])) + rowMeans(yhat_combined[,inds_start:inds_end]), + xlab = "original", ylab = "deserialized", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) } +par(mfrow = c(1,1)) ``` ## Sampling Multiple Chains in Parallel @@ -160,18 +163,29 @@ cl <- makeCluster(ncores) registerDoParallel(cl) ``` -Run the sampler, storing the resulting BART objects in a list +Note that the `bartmodel` object contains external pointers to forests created by +the `stochtree` shared object, and when `stochtree::bart()` is run in parallel +on independent subprocesses, these pointers are not generally accessible in the +session that kicked off the parallel run. + +To overcome this, you can return a JSON representation of a `bartmodel` in memory +and combine them into a single in-memory `bartmodel` object. + +The first step of this process is to run the sampler in parallel, +storing the resulting BART JSON strings in a list. ```{r} -bart_models <- foreach (i = 1:num_chains) %dopar% { +bart_model_strings <- foreach (i = 1:num_chains) %dopar% { random_seed <- i - stochtree::bart( + bart_model <- stochtree::bart( X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, num_trees = num_trees, num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, sample_sigma = T, sample_tau = T, random_seed = random_seed ) + bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) + bart_model_string } ``` @@ -185,31 +199,30 @@ Now, if we want to combine the forests from each of these BART models into a single forest, we can do so as follows ```{r} -json_string_list <- list() -for (i in 1:num_chains) { - json_string_list[[i]] <- saveBARTModelToJsonString(bart_models[[i]]) -} -combined_forests <- loadForestContainerCombinedJsonString(json_string_list, "forest_0") +combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) ``` We can predict from this combined forest as follows ```{r} -test_dataset <- createForestDataset(as.matrix(X_test), W_test) -yhat_combined <- combined_forests$predict(test_dataset) +yhat_combined <- predict(combined_bart, X_test, W_test)$y_hat ``` -Compare to the original $\hat{y}$ values +Since we don't have access to the original $\hat{y}$ values, we instead +compare average predictions from each chain to the true $y$ values. 
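+Each chain contributes its `num_mcmc` retained draws to `yhat_combined` as a
+contiguous block of columns, with chains stored in order. As a minimal sketch
+(the helper below is purely illustrative, not part of `stochtree`), chain `i`'s
+block can be recovered with the same offset arithmetic used in the next chunk:
+
+```{r}
+# Illustrative helper (assumed layout: chains occupy consecutive
+# blocks of num_mcmc columns of yhat_combined, in chain order)
+chain_columns <- function(i, num_mcmc) {
+  offset <- (i - 1) * num_mcmc
+  (offset + 1):(offset + num_mcmc)
+}
+chain_columns(2, num_mcmc)[1:5]
+```
+
+The next chunk applies this indexing chain by chain: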
```{r} -num_samples <- num_gfr+num_burnin+num_mcmc +par(mfrow = c(1,2)) for (i in 1:num_chains) { - offset <- (i-1)*num_samples - inds_start <- offset + 1 + num_burnin + num_gfr - inds_end <- offset + num_samples - plot(rowMeans(bart_models[[i]]$y_hat_test), - rowMeans(yhat_combined[,inds_start:inds_end])) + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, + xlab = "predicted", ylab = "actual", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) } +par(mfrow = c(1,1)) ``` # References From 20170539b11274aec6e6d8a7470e6b7d781d2788 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 23 Aug 2024 23:54:14 -0500 Subject: [PATCH 08/41] Refactored the sampler classes into stateless templated functions --- debug/api_debug.cpp | 18 +- include/stochtree/tree_sampler.h | 854 +++++++++++++++---------------- src/py_stochtree.cpp | 18 +- src/sampler.cpp | 20 +- 4 files changed, 436 insertions(+), 474 deletions(-) diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index d827d8cb..d7420d5f 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -270,16 +270,13 @@ void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& f ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { if (leaf_model_type == ForestLeafModel::kConstant) { GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); } } @@ -288,16 +285,13 @@ void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { if (leaf_model_type == ForestLeafModel::kConstant) { GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, 
dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); } } diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 9c9a854e..9db97f02 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -234,266 +234,183 @@ static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& } template -class MCMCForestSampler { - public: - MCMCForestSampler() {} - ~MCMCForestSampler() {} +static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + double global_variance, bool pre_initialized = false) { + // Previous number of samples + int prev_num_samples = forests.NumSamples(); - void SampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, bool pre_initialized = false) { - // Previous number of samples - int prev_num_samples = forests.NumSamples(); + if ((prev_num_samples == 0) && (!pre_initialized)) { + // Add new forest to the container + forests.AddSamples(1); - if ((prev_num_samples == 0) && (!pre_initialized)) { - // Add new forest to the container - forests.AddSamples(1); - - // Set initial value for each leaf in the forest - double root_pred = ComputeMeanOutcome(residual) / static_cast(forests.NumTrees()); - TreeEnsemble* ensemble = forests.GetEnsemble(0); - leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred); - } else if (prev_num_samples > 0) { - // Add new forest to the container - forests.AddSamples(1); - - // Copy previous forest - forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); - } else { - forests.IncrementSampleCount(); - } + // Set initial value for each leaf in the forest + double root_pred = ComputeMeanOutcome(residual) / static_cast(forests.NumTrees()); + TreeEnsemble* ensemble = forests.GetEnsemble(0); + leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred); + } else if (prev_num_samples > 0) { + // Add new forest to the container + forests.AddSamples(1); - // Run the MCMC algorithm for each tree - TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); - Tree* tree; - 
int num_trees = forests.NumTrees(); - for (int i = 0; i < num_trees; i++) { - // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") - tree = ensemble->GetTree(i); - UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_, false); - - // Sample tree i - tree = ensemble->GetTree(i); - SampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance); - - // Sample leaf parameters for tree i - tree = ensemble->GetTree(i); - leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); - - // Subtract tree i's predictions back out of the residual - tree = ensemble->GetTree(i); - UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_, true); - } + // Copy previous forest + forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); + } else { + forests.IncrementSampleCount(); } - - private: - // Function objects for element-wise addition and subtraction (used in the residual update function which takes std::function as an argument) - std::plus plus_op_; - std::minus minus_op_; - void SampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance) { - // Determine whether it is possible to grow any of the leaves - bool grow_possible = false; - std::vector leaves = tree->GetLeaves(); - for (auto& leaf: leaves) { - if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) { - grow_possible = true; - break; - } - } + // Run the MCMC algorithm for each tree + TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); + Tree* tree; + int num_trees = forests.NumTrees(); + for (int i = 0; i < num_trees; i++) { + // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") + tree = ensemble->GetTree(i); + UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::plus(), false); + + // Sample tree i + tree = ensemble->GetTree(i); + MCMCSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance); + + // Sample leaf parameters for tree i + tree = ensemble->GetTree(i); + leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); + + // Subtract tree i's predictions back out of the residual + tree = ensemble->GetTree(i); + UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::minus(), true); + } +} - // Determine whether it is possible to prune the tree - bool prune_possible = false; - if (tree->NumValidNodes() > 1) { - prune_possible = true; +template +static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + int tree_num, double global_variance) { + // Determine whether it is possible to grow any of the leaves + bool grow_possible = false; + std::vector leaves = tree->GetLeaves(); + for (auto& leaf: leaves) { + if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) { + grow_possible = true; + break; } + } - // Determine the relative 
probability of grow vs prune (0 = grow, 1 = prune) - double prob_grow; - std::vector step_probs(2); - if (grow_possible && prune_possible) { - step_probs = {0.5, 0.5}; - prob_grow = 0.5; - } else if (!grow_possible && prune_possible) { - step_probs = {0.0, 1.0}; - prob_grow = 0.0; - } else if (grow_possible && !prune_possible) { - step_probs = {1.0, 0.0}; - prob_grow = 1.0; - } else { - Log::Fatal("In this tree, neither grow nor prune is possible"); - } - std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end()); + // Determine whether it is possible to prune the tree + bool prune_possible = false; + if (tree->NumValidNodes() > 1) { + prune_possible = true; + } - // Draw a split rule at random - data_size_t step_chosen = step_dist(gen); - bool accept; - - if (step_chosen == 0) { - GrowTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow); - } else { - PruneTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance); - } + // Determine the relative probability of grow vs prune (0 = grow, 1 = prune) + double prob_grow; + std::vector step_probs(2); + if (grow_possible && prune_possible) { + step_probs = {0.5, 0.5}; + prob_grow = 0.5; + } else if (!grow_possible && prune_possible) { + step_probs = {0.0, 1.0}; + prob_grow = 0.0; + } else if (grow_possible && !prune_possible) { + step_probs = {1.0, 0.0}; + prob_grow = 1.0; + } else { + Log::Fatal("In this tree, neither grow nor prune is possible"); } + std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end()); - void GrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, - double global_variance, double prob_grow_old) { - // Extract dataset information - data_size_t n = dataset.GetCovariates().rows(); - - // Choose a leaf node at random - int num_leaves = tree->NumLeaves(); - std::vector leaves = tree->GetLeaves(); - std::vector leaf_weights(num_leaves); - std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves); - std::discrete_distribution<> leaf_dist(leaf_weights.begin(), leaf_weights.end()); - int leaf_chosen = leaves[leaf_dist(gen)]; - int leaf_depth = tree->GetDepth(leaf_chosen); - - // Maximum leaf depth - int32_t max_depth = tree_prior.GetMaxDepth(); - - // Terminate early if cannot be split - bool accept; - if ((leaf_depth >= max_depth) && (max_depth != -1)) { - accept = false; - } else { + // Draw a split rule at random + data_size_t step_chosen = step_dist(gen); + bool accept; + + if (step_chosen == 0) { + MCMCGrowTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow); + } else { + MCMCPruneTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance); + } +} - // Select a split variable at random - int p = dataset.GetCovariates().cols(); - CHECK_EQ(variable_weights.size(), p); - // std::vector var_weights(p); - // std::fill(var_weights.begin(), var_weights.end(), 1.0/p); - std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end()); - int var_chosen = var_dist(gen); - - // Determine the range of possible cutpoints - // TODO: specialize this for binary / ordered categorical / unordered categorical variables - double var_min, var_max; - VarSplitRange(tracker, dataset, 
tree_num, leaf_chosen, var_chosen, var_min, var_max); - if (var_max <= var_min) { - return; - } - - // Split based on var_min to var_max in a given node - std::uniform_real_distribution split_point_dist(var_min, var_max); - double split_point_chosen = split_point_dist(gen); - - // Create a split object - TreeSplit split = TreeSplit(split_point_chosen); - - // Compute the marginal likelihood of split and no split, given the leaf prior - std::tuple split_eval = leaf_model.EvaluateProposedSplit(dataset, tracker, residual, split, tree_num, leaf_chosen, var_chosen, global_variance); - double split_log_marginal_likelihood = std::get<0>(split_eval); - double no_split_log_marginal_likelihood = std::get<1>(split_eval); - int32_t left_n = std::get<2>(split_eval); - int32_t right_n = std::get<3>(split_eval); - - // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - - // Determine whether a "grow" move is possible from the newly formed tree - // in order to compute the probability of choosing "prune" from the new tree - // (which is always possible by construction) - bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); - bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); - bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); - double prob_prune_new; - if (non_constant && (min_samples_left_check || min_samples_right_check)) { - prob_prune_new = 0.5; - } else { - prob_prune_new = 1.0; - } +template +static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, + double global_variance, double prob_grow_old) { + // Extract dataset information + data_size_t n = dataset.GetCovariates().rows(); - // Determine the number of leaves in the current tree and leaf parents in the proposed tree - int num_leaf_parents = tree->NumLeafParents(); - double p_leaf = 1/static_cast(num_leaves); - double p_leaf_parent = 1/static_cast(num_leaf_parents+1); - - // Compute the final MH ratio - double log_mh_ratio = ( - std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + - std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood - ); - // Threshold at 0 - if (log_mh_ratio > 0) { - log_mh_ratio = 0; - } + // Choose a leaf node at random + int num_leaves = tree->NumLeaves(); + std::vector leaves = tree->GetLeaves(); + std::vector leaf_weights(num_leaves); + std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves); + std::discrete_distribution<> leaf_dist(leaf_weights.begin(), leaf_weights.end()); + int leaf_chosen = leaves[leaf_dist(gen)]; + int leaf_depth = tree->GetDepth(leaf_chosen); + + // Maximum leaf depth + int32_t max_depth = tree_prior.GetMaxDepth(); + + // Terminate early if cannot be split + bool accept; + if ((leaf_depth >= max_depth) && (max_depth != -1)) { + accept = false; + } else { - // Draw a uniform random variable and accept/reject the proposal on this basis - std::uniform_real_distribution mh_accept(0.0, 1.0); - double 
log_acceptance_prob = std::log(mh_accept(gen)); - if (log_acceptance_prob <= log_mh_ratio) { - accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); - } else { - accept = false; - } + // Select a split variable at random + int p = dataset.GetCovariates().cols(); + CHECK_EQ(variable_weights.size(), p); + // std::vector var_weights(p); + // std::fill(var_weights.begin(), var_weights.end(), 1.0/p); + std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end()); + int var_chosen = var_dist(gen); + + // Determine the range of possible cutpoints + // TODO: specialize this for binary / ordered categorical / unordered categorical variables + double var_min, var_max; + VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max); + if (var_max <= var_min) { + return; } - } - - void PruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance) { - // Choose a "leaf parent" node at random - int num_leaves = tree->NumLeaves(); - int num_leaf_parents = tree->NumLeafParents(); - std::vector leaf_parents = tree->GetLeafParents(); - std::vector leaf_parent_weights(num_leaf_parents); - std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents); - std::discrete_distribution<> leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end()); - int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)]; - int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen); - int left_node = tree->LeftChild(leaf_parent_chosen); - int right_node = tree->RightChild(leaf_parent_chosen); - int feature_split = tree->SplitIndex(leaf_parent_chosen); - // Compute the marginal likelihood for the leaf parent and its left and right nodes - std::tuple split_eval = leaf_model.EvaluateExistingSplit(dataset, tracker, residual, global_variance, tree_num, leaf_parent_chosen, left_node, right_node); + // Split based on var_min to var_max in a given node + std::uniform_real_distribution split_point_dist(var_min, var_max); + double split_point_chosen = split_point_dist(gen); + + // Create a split object + TreeSplit split = TreeSplit(split_point_chosen); + + // Compute the marginal likelihood of split and no split, given the leaf prior + std::tuple split_eval = leaf_model.EvaluateProposedSplit(dataset, tracker, residual, split, tree_num, leaf_chosen, var_chosen, global_variance); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); int32_t left_n = std::get<2>(split_eval); int32_t right_n = std::get<3>(split_eval); // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta()); + double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - // Determine whether a "prune" move is possible from the new tree, - // in order to compute the probability of choosing "grow" from 
the new tree + // Determine whether a "grow" move is possible from the newly formed tree + // in order to compute the probability of choosing "prune" from the new tree // (which is always possible by construction) - bool non_root_tree = tree->NumNodes() > 1; - double prob_grow_new; - if (non_root_tree) { - prob_grow_new = 0.5; - } else { - prob_grow_new = 1.0; - } - - // Determine whether a "grow" move was possible from the old tree, - // in order to compute the probability of choosing "prune" from the old tree - bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node); - bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node); - double prob_prune_old; - if (non_constant_left && non_constant_right) { - prob_prune_old = 0.5; + bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); + bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); + bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); + double prob_prune_new; + if (non_constant && (min_samples_left_check || min_samples_right_check)) { + prob_prune_new = 0.5; } else { - prob_prune_old = 1.0; + prob_prune_new = 1.0; } // Determine the number of leaves in the current tree and leaf parents in the proposed tree - double p_leaf = 1/static_cast(num_leaves-1); - double p_leaf_parent = 1/static_cast(num_leaf_parents); + int num_leaf_parents = tree->NumLeafParents(); + double p_leaf = 1/static_cast(num_leaves); + double p_leaf_parent = 1/static_cast(num_leaf_parents+1); // Compute the final MH ratio double log_mh_ratio = ( - std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) + - std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood + std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + + std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood ); // Threshold at 0 if (log_mh_ratio > 0) { @@ -501,242 +418,305 @@ class MCMCForestSampler { } // Draw a uniform random variable and accept/reject the proposal on this basis - bool accept; std::uniform_real_distribution mh_accept(0.0, 1.0); double log_acceptance_prob = std::log(mh_accept(gen)); if (log_acceptance_prob <= log_mh_ratio) { accept = true; - RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false); + AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); } else { accept = false; } } -}; +} template -class GFRForestSampler { - public: - GFRForestSampler() {cutpoint_grid_size_ = 500;} - GFRForestSampler(int cutpoint_grid_size) {cutpoint_grid_size_ = cutpoint_grid_size;} - ~GFRForestSampler() {} - - void SampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, std::vector& feature_types, bool pre_initialized = false) { - // Previous number of samples - int prev_num_samples = forests.NumSamples(); - - if ((prev_num_samples == 0) && (!pre_initialized)) { - // Add new forest to the container - forests.AddSamples(1); - - // Set initial value for each leaf in the forest - double root_pred = ComputeMeanOutcome(residual) / 
static_cast(forests.NumTrees()); - TreeEnsemble* ensemble = forests.GetEnsemble(0); - leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred); - } else if (prev_num_samples > 0) { - // Add new forest to the container - forests.AddSamples(1); - - // NOTE: only doing this for the simplicity of the partial residual step - // We could alternatively "reach back" to the tree predictions from a previous - // sample (whenever there is more than one sample). This is cleaner / quicker - // to implement during this refactor. - forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); - } else { - forests.IncrementSampleCount(); - } - - // Run the GFR algorithm for each tree - TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); - int num_trees = forests.NumTrees(); - for (int i = 0; i < num_trees; i++) { - // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") - Tree* tree = ensemble->GetTree(i); - UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_, false); - - // Reset the tree and sample trackers - ensemble->ResetInitTree(i); - tracker.ResetRoot(dataset.GetCovariates(), feature_types, i); - tree = ensemble->GetTree(i); - - // Sample tree i - SampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types); - - // Sample leaf parameters for tree i - tree = ensemble->GetTree(i); - leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); - - // Subtract tree i's predictions back out of the residual - UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_, true); - } +static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance) { + // Choose a "leaf parent" node at random + int num_leaves = tree->NumLeaves(); + int num_leaf_parents = tree->NumLeafParents(); + std::vector leaf_parents = tree->GetLeafParents(); + std::vector leaf_parent_weights(num_leaf_parents); + std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents); + std::discrete_distribution<> leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end()); + int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)]; + int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen); + int left_node = tree->LeftChild(leaf_parent_chosen); + int right_node = tree->RightChild(leaf_parent_chosen); + int feature_split = tree->SplitIndex(leaf_parent_chosen); + + // Compute the marginal likelihood for the leaf parent and its left and right nodes + std::tuple split_eval = leaf_model.EvaluateExistingSplit(dataset, tracker, residual, global_variance, tree_num, leaf_parent_chosen, left_node, right_node); + double split_log_marginal_likelihood = std::get<0>(split_eval); + double no_split_log_marginal_likelihood = std::get<1>(split_eval); + int32_t left_n = std::get<2>(split_eval); + int32_t right_n = std::get<3>(split_eval); + + // Determine probability of growing the split node and its two new left and right nodes + double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, 
-tree_prior.GetBeta()); + + // Determine whether a "prune" move is possible from the new tree, + // in order to compute the probability of choosing "grow" from the new tree + // (which is always possible by construction) + bool non_root_tree = tree->NumNodes() > 1; + double prob_grow_new; + if (non_root_tree) { + prob_grow_new = 0.5; + } else { + prob_grow_new = 1.0; + } + + // Determine whether a "grow" move was possible from the old tree, + // in order to compute the probability of choosing "prune" from the old tree + bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node); + bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node); + double prob_prune_old; + if (non_constant_left && non_constant_right) { + prob_prune_old = 0.5; + } else { + prob_prune_old = 1.0; + } + + // Determine the number of leaves in the current tree and leaf parents in the proposed tree + double p_leaf = 1/static_cast(num_leaves-1); + double p_leaf_parent = 1/static_cast(num_leaf_parents); + + // Compute the final MH ratio + double log_mh_ratio = ( + std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) + + std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood + ); + // Threshold at 0 + if (log_mh_ratio > 0) { + log_mh_ratio = 0; } - private: - // Maximum cutpoint grid size in the enumeration of possible splits - int cutpoint_grid_size_; + // Draw a uniform random variable and accept/reject the proposal on this basis + bool accept; + std::uniform_real_distribution mh_accept(0.0, 1.0); + double log_acceptance_prob = std::log(mh_accept(gen)); + if (log_acceptance_prob <= log_mh_ratio) { + accept = true; + RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false); + } else { + accept = false; + } +} + +template +static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + double global_variance, std::vector& feature_types, int cutpoint_grid_size = 500, + bool pre_initialized = false) { + // Previous number of samples + int prev_num_samples = forests.NumSamples(); - // Function objects for element-wise addition and subtraction (used in the residual update function which takes std::function as an argument) - std::plus plus_op_; - std::minus minus_op_; + if ((prev_num_samples == 0) && (!pre_initialized)) { + // Add new forest to the container + forests.AddSamples(1); + + // Set initial value for each leaf in the forest + double root_pred = ComputeMeanOutcome(residual) / static_cast(forests.NumTrees()); + TreeEnsemble* ensemble = forests.GetEnsemble(0); + leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred); + } else if (prev_num_samples > 0) { + // Add new forest to the container + forests.AddSamples(1); + + // NOTE: only doing this for the simplicity of the partial residual step + // We could alternatively "reach back" to the tree predictions from a previous + // sample (whenever there is more than one sample). This is cleaner / quicker + // to implement during this refactor. 
+ forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); + } else { + forests.IncrementSampleCount(); + } - void SampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, std::vector& feature_types) { - int root_id = Tree::kRoot; - int curr_node_id; - data_size_t curr_node_begin; - data_size_t curr_node_end; - data_size_t n = dataset.GetCovariates().rows(); - // Mapping from node id to start and end points of sorted indices - std::unordered_map> node_index_map; - node_index_map.insert({root_id, std::make_pair(0, n)}); - std::pair begin_end; - // Add root node to the split queue - std::deque split_queue; - split_queue.push_back(Tree::kRoot); - // Run the "GrowFromRoot" procedure using a stack in place of recursion - while (!split_queue.empty()) { - // Remove the next node from the queue - curr_node_id = split_queue.front(); - split_queue.pop_front(); - // Determine the beginning and ending indices of the left and right nodes - begin_end = node_index_map[curr_node_id]; - curr_node_begin = begin_end.first; - curr_node_end = begin_end.second; - // Draw a split rule at random - SampleSplitRule(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size_, - node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types); - } + // Run the GFR algorithm for each tree + TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); + int num_trees = forests.NumTrees(); + for (int i = 0; i < num_trees; i++) { + // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") + Tree* tree = ensemble->GetTree(i); + UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::plus(), false); + + // Reset the tree and sample trackers + ensemble->ResetInitTree(i); + tracker.ResetRoot(dataset.GetCovariates(), feature_types, i); + tree = ensemble->GetTree(i); + + // Sample tree i + GFRSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size); + + // Sample leaf parameters for tree i + tree = ensemble->GetTree(i); + leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); + + // Subtract tree i's predictions back out of the residual + UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::minus(), true); } +} - void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, - std::unordered_map>& node_index_map, std::deque& split_queue, - int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types) { - // Leaf depth - int leaf_depth = tree->GetDepth(node_id); - - // Maximum leaf depth - int32_t max_depth = tree_prior.GetMaxDepth(); +template +static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + int tree_num, double global_variance, std::vector& 
feature_types, int cutpoint_grid_size) { + int root_id = Tree::kRoot; + int curr_node_id; + data_size_t curr_node_begin; + data_size_t curr_node_end; + data_size_t n = dataset.GetCovariates().rows(); + // Mapping from node id to start and end points of sorted indices + std::unordered_map> node_index_map; + node_index_map.insert({root_id, std::make_pair(0, n)}); + std::pair begin_end; + // Add root node to the split queue + std::deque split_queue; + split_queue.push_back(Tree::kRoot); + // Run the "GrowFromRoot" procedure using a stack in place of recursion + while (!split_queue.empty()) { + // Remove the next node from the queue + curr_node_id = split_queue.front(); + split_queue.pop_front(); + // Determine the beginning and ending indices of the left and right nodes + begin_end = node_index_map[curr_node_id]; + curr_node_begin = begin_end.first; + curr_node_end = begin_end.second; + // Draw a split rule at random + SampleSplitRule(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, + node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types); + } +} - if ((max_depth == -1) || (leaf_depth < max_depth)) { +template +static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, + std::unordered_map>& node_index_map, std::deque& split_queue, + int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, + std::vector& feature_types) { + // Leaf depth + int leaf_depth = tree->GetDepth(node_id); + + // Maximum leaf depth + int32_t max_depth = tree_prior.GetMaxDepth(); + + if ((max_depth == -1) || (leaf_depth < max_depth)) { + + // Cutpoint enumeration + std::vector log_cutpoint_evaluations; + std::vector cutpoint_features; + std::vector cutpoint_values; + std::vector cutpoint_feature_types; + StochTree::data_size_t valid_cutpoint_count; + CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); + EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, + cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, + cutpoint_grid_container); + // TODO: maybe add some checks here? 
+ + // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood + double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); + std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); + for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ + cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); + } - // Cutpoint enumeration - std::vector log_cutpoint_evaluations; - std::vector cutpoint_features; - std::vector cutpoint_values; - std::vector cutpoint_feature_types; - StochTree::data_size_t valid_cutpoint_count; - CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); - EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, - cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container); - // TODO: maybe add some checks here? + // Sample the split (including a "no split" option) + std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); + data_size_t split_chosen = split_dist(gen); + + if (split_chosen == valid_cutpoint_count){ + // "No split" sampled, don't split or add any nodes to split queue + return; + } else { + // Split sampled + int feature_split = cutpoint_features[split_chosen]; + FeatureType feature_type = cutpoint_feature_types[split_chosen]; + double split_value = cutpoint_values[split_chosen]; + // Perform all of the relevant "split" operations in the model, tree and training dataset - // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); - std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); - for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ - cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); - } + // Compute node sample size + data_size_t node_n = node_end - node_begin; - // Sample the split (including a "no split" option) - std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); - data_size_t split_chosen = split_dist(gen); + // Actual numeric cutpoint used for ordered categorical and numeric features + double split_value_numeric; + TreeSplit tree_split; - if (split_chosen == valid_cutpoint_count){ - // "No split" sampled, don't split or add any nodes to split queue - return; + // We will use these later in the model expansion + data_size_t left_n = 0; + data_size_t right_n = 0; + data_size_t sort_idx; + double feature_value; + bool split_true; + + if (feature_type == FeatureType::kUnorderedCategorical) { + // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split + int num_categories; + std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); + tree_split = TreeSplit(categories); + } else if (feature_type == FeatureType::kOrderedCategorical) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); + } 
else if (feature_type == FeatureType::kNumeric) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); } else { - // Split sampled - int feature_split = cutpoint_features[split_chosen]; - FeatureType feature_type = cutpoint_feature_types[split_chosen]; - double split_value = cutpoint_values[split_chosen]; - // Perform all of the relevant "split" operations in the model, tree and training dataset - - // Compute node sample size - data_size_t node_n = node_end - node_begin; - - // Actual numeric cutpoint used for ordered categorical and numeric features - double split_value_numeric; - TreeSplit tree_split; - - // We will use these later in the model expansion - data_size_t left_n = 0; - data_size_t right_n = 0; - data_size_t sort_idx; - double feature_value; - bool split_true; - - if (feature_type == FeatureType::kUnorderedCategorical) { - // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split - int num_categories; - std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); - tree_split = TreeSplit(categories); - } else if (feature_type == FeatureType::kOrderedCategorical) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); - } else if (feature_type == FeatureType::kNumeric) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); - } else { - Log::Fatal("Invalid split type"); - } - - // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); - - // Determine the number of observation in the newly created left node - int left_node = tree->LeftChild(node_id); - int right_node = tree->RightChild(node_id); - auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); - auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); - for (auto i = left_begin_iter; i < left_end_iter; i++) { - left_n += 1; - } + Log::Fatal("Invalid split type"); + } + + // Add split to tree and trackers + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); + + // Determine the number of observation in the newly created left node + int left_node = tree->LeftChild(node_id); + int right_node = tree->RightChild(node_id); + auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); + auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); + for (auto i = left_begin_iter; i < left_end_iter; i++) { + left_n += 1; + } - // Add the begin and end indices for the new left and right nodes to node_index_map - node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); - node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)}); + // Add the begin and end indices for the new left and right nodes to node_index_map + node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); + node_index_map.insert({right_node, std::make_pair(node_begin + left_n, 
node_end)}); - // Add the left and right nodes to the split tracker - split_queue.push_front(right_node); - split_queue.push_front(left_node); - } + // Add the left and right nodes to the split tracker + split_queue.push_front(right_node); + split_queue.push_front(left_node); } } +} - void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, - std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, - std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, - std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, - std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container) { - // Evaluate all possible cutpoints according to the leaf node model, - // recording their log-likelihood and other split information in a series of vectors. - // The last element of these vectors concerns the "no-split" option. - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, tree_num, node_id, log_cutpoint_evaluations, - cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, - cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types); - - // Compute an adjustment to reflect the no split prior probability and the number of cutpoints - double bart_prior_no_split_adj; - double alpha = tree_prior.GetAlpha(); - double beta = tree_prior.GetBeta(); - int node_depth = tree->GetDepth(node_id); - if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); - } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); - } - log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; +template +static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, + std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, + std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container) { + // Evaluate all possible cutpoints according to the leaf node model, + // recording their log-likelihood and other split information in a series of vectors. + // The last element of these vectors concerns the "no-split" option. 
+ leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, tree_num, node_id, log_cutpoint_evaluations, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, + cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types); + + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints + double bart_prior_no_split_adj; + double alpha = tree_prior.GetAlpha(); + double beta = tree_prior.GetBeta(); + int node_depth = tree->GetDepth(node_id); + if (valid_cutpoint_count == 0) { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + } else { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); } - -}; + log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; +} } // namespace StochTree diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 3c8ca606..47e9e26b 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -512,16 +512,13 @@ class ForestSamplerCpp { Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { if (leaf_model_enum == ForestLeafModel::kConstant) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, pre_initialized); + GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, pre_initialized); + GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, pre_initialized); + GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } } @@ -530,16 +527,13 @@ class ForestSamplerCpp { Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { if (leaf_model_enum == ForestLeafModel::kConstant) { StochTree::GaussianConstantLeafModel leaf_model = 
StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } } }; diff --git a/src/sampler.cpp b/src/sampler.cpp index 0edf6a7a..bfb0fe6e 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -61,16 +61,13 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); + GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); + GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); + GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, 
var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } } @@ -125,16 +122,13 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { - StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); - sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); + MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } } From c3bfa66cf86bc1c23f3e011c437b9e111987b98c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 24 Aug 2024 00:05:57 -0500 Subject: [PATCH 09/41] Fixed R package bug and rearranged tree_sampler header file --- include/stochtree/tree_sampler.h | 470 +++++++++++++++---------------- src/sampler.cpp | 4 +- 2 files changed, 237 insertions(+), 237 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 9db97f02..302ae6bf 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -234,9 +234,169 @@ static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& } template -static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, bool pre_initialized = false) { +static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, + std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, + std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container) { + // Evaluate all possible cutpoints according to the leaf node model, + // recording their log-likelihood and other split information in a series of vectors. 
+ // The last element of these vectors concerns the "no-split" option. + leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, tree_num, node_id, log_cutpoint_evaluations, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, + cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types); + + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints + double bart_prior_no_split_adj; + double alpha = tree_prior.GetAlpha(); + double beta = tree_prior.GetBeta(); + int node_depth = tree->GetDepth(node_id); + if (valid_cutpoint_count == 0) { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + } else { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); + } + log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; +} + +template +static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, + std::unordered_map>& node_index_map, std::deque& split_queue, + int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, + std::vector& feature_types) { + // Leaf depth + int leaf_depth = tree->GetDepth(node_id); + + // Maximum leaf depth + int32_t max_depth = tree_prior.GetMaxDepth(); + + if ((max_depth == -1) || (leaf_depth < max_depth)) { + + // Cutpoint enumeration + std::vector log_cutpoint_evaluations; + std::vector cutpoint_features; + std::vector cutpoint_values; + std::vector cutpoint_feature_types; + StochTree::data_size_t valid_cutpoint_count; + CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); + EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, + cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, + cutpoint_grid_container); + // TODO: maybe add some checks here? 
+ + // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood + double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); + std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); + for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ + cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); + } + + // Sample the split (including a "no split" option) + std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); + data_size_t split_chosen = split_dist(gen); + + if (split_chosen == valid_cutpoint_count){ + // "No split" sampled, don't split or add any nodes to split queue + return; + } else { + // Split sampled + int feature_split = cutpoint_features[split_chosen]; + FeatureType feature_type = cutpoint_feature_types[split_chosen]; + double split_value = cutpoint_values[split_chosen]; + // Perform all of the relevant "split" operations in the model, tree and training dataset + + // Compute node sample size + data_size_t node_n = node_end - node_begin; + + // Actual numeric cutpoint used for ordered categorical and numeric features + double split_value_numeric; + TreeSplit tree_split; + + // We will use these later in the model expansion + data_size_t left_n = 0; + data_size_t right_n = 0; + data_size_t sort_idx; + double feature_value; + bool split_true; + + if (feature_type == FeatureType::kUnorderedCategorical) { + // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split + int num_categories; + std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); + tree_split = TreeSplit(categories); + } else if (feature_type == FeatureType::kOrderedCategorical) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); + } else if (feature_type == FeatureType::kNumeric) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); + } else { + Log::Fatal("Invalid split type"); + } + + // Add split to tree and trackers + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); + + // Determine the number of observation in the newly created left node + int left_node = tree->LeftChild(node_id); + int right_node = tree->RightChild(node_id); + auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); + auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); + for (auto i = left_begin_iter; i < left_end_iter; i++) { + left_n += 1; + } + + // Add the begin and end indices for the new left and right nodes to node_index_map + node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); + node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)}); + + // Add the left and right nodes to the split tracker + split_queue.push_front(right_node); + split_queue.push_front(left_node); + } + } +} + +template +static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, 
TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size) { + int root_id = Tree::kRoot; + int curr_node_id; + data_size_t curr_node_begin; + data_size_t curr_node_end; + data_size_t n = dataset.GetCovariates().rows(); + // Mapping from node id to start and end points of sorted indices + std::unordered_map> node_index_map; + node_index_map.insert({root_id, std::make_pair(0, n)}); + std::pair begin_end; + // Add root node to the split queue + std::deque split_queue; + split_queue.push_back(Tree::kRoot); + // Run the "GrowFromRoot" procedure using a stack in place of recursion + while (!split_queue.empty()) { + // Remove the next node from the queue + curr_node_id = split_queue.front(); + split_queue.pop_front(); + // Determine the beginning and ending indices of the left and right nodes + begin_end = node_index_map[curr_node_id]; + curr_node_begin = begin_end.first; + curr_node_end = begin_end.second; + // Draw a split rule at random + SampleSplitRule(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, + node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types); + } +} + +template +static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + double global_variance, std::vector& feature_types, int cutpoint_grid_size = 500, + bool pre_initialized = false) { // Previous number of samples int prev_num_samples = forests.NumSamples(); @@ -251,84 +411,41 @@ static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& fo } else if (prev_num_samples > 0) { // Add new forest to the container forests.AddSamples(1); - - // Copy previous forest + + // NOTE: only doing this for the simplicity of the partial residual step + // We could alternatively "reach back" to the tree predictions from a previous + // sample (whenever there is more than one sample). This is cleaner / quicker + // to implement during this refactor. 
forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); } else { forests.IncrementSampleCount(); } - // Run the MCMC algorithm for each tree + // Run the GFR algorithm for each tree TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); - Tree* tree; int num_trees = forests.NumTrees(); for (int i = 0; i < num_trees; i++) { // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") - tree = ensemble->GetTree(i); + Tree* tree = ensemble->GetTree(i); UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::plus(), false); - // Sample tree i + // Reset the tree and sample trackers + ensemble->ResetInitTree(i); + tracker.ResetRoot(dataset.GetCovariates(), feature_types, i); tree = ensemble->GetTree(i); - MCMCSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance); + + // Sample tree i + GFRSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size); // Sample leaf parameters for tree i tree = ensemble->GetTree(i); leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); // Subtract tree i's predictions back out of the residual - tree = ensemble->GetTree(i); UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::minus(), true); } } -template -static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance) { - // Determine whether it is possible to grow any of the leaves - bool grow_possible = false; - std::vector leaves = tree->GetLeaves(); - for (auto& leaf: leaves) { - if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) { - grow_possible = true; - break; - } - } - - // Determine whether it is possible to prune the tree - bool prune_possible = false; - if (tree->NumValidNodes() > 1) { - prune_possible = true; - } - - // Determine the relative probability of grow vs prune (0 = grow, 1 = prune) - double prob_grow; - std::vector step_probs(2); - if (grow_possible && prune_possible) { - step_probs = {0.5, 0.5}; - prob_grow = 0.5; - } else if (!grow_possible && prune_possible) { - step_probs = {0.0, 1.0}; - prob_grow = 0.0; - } else if (grow_possible && !prune_possible) { - step_probs = {1.0, 0.0}; - prob_grow = 1.0; - } else { - Log::Fatal("In this tree, neither grow nor prune is possible"); - } - std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end()); - - // Draw a split rule at random - data_size_t step_chosen = step_dist(gen); - bool accept; - - if (step_chosen == 0) { - MCMCGrowTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow); - } else { - MCMCPruneTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance); - } -} - template static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, @@ -506,10 +623,57 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf } 
template -static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - double global_variance, std::vector& feature_types, int cutpoint_grid_size = 500, - bool pre_initialized = false) { +static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + int tree_num, double global_variance) { + // Determine whether it is possible to grow any of the leaves + bool grow_possible = false; + std::vector leaves = tree->GetLeaves(); + for (auto& leaf: leaves) { + if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) { + grow_possible = true; + break; + } + } + + // Determine whether it is possible to prune the tree + bool prune_possible = false; + if (tree->NumValidNodes() > 1) { + prune_possible = true; + } + + // Determine the relative probability of grow vs prune (0 = grow, 1 = prune) + double prob_grow; + std::vector step_probs(2); + if (grow_possible && prune_possible) { + step_probs = {0.5, 0.5}; + prob_grow = 0.5; + } else if (!grow_possible && prune_possible) { + step_probs = {0.0, 1.0}; + prob_grow = 0.0; + } else if (grow_possible && !prune_possible) { + step_probs = {1.0, 0.0}; + prob_grow = 1.0; + } else { + Log::Fatal("In this tree, neither grow nor prune is possible"); + } + std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end()); + + // Draw a split rule at random + data_size_t step_chosen = step_dist(gen); + bool accept; + + if (step_chosen == 0) { + MCMCGrowTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow); + } else { + MCMCPruneTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance); + } +} + +template +static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, + ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, + double global_variance, bool pre_initialized = false) { // Previous number of samples int prev_num_samples = forests.NumSamples(); @@ -524,200 +688,36 @@ static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& for } else if (prev_num_samples > 0) { // Add new forest to the container forests.AddSamples(1); - - // NOTE: only doing this for the simplicity of the partial residual step - // We could alternatively "reach back" to the tree predictions from a previous - // sample (whenever there is more than one sample). This is cleaner / quicker - // to implement during this refactor. 
+ + // Copy previous forest forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1); } else { forests.IncrementSampleCount(); } - // Run the GFR algorithm for each tree + // Run the MCMC algorithm for each tree TreeEnsemble* ensemble = forests.GetEnsemble(prev_num_samples); + Tree* tree; int num_trees = forests.NumTrees(); for (int i = 0; i < num_trees; i++) { // Add tree i's predictions back to the residual (thus, training a model on the "partial residual") - Tree* tree = ensemble->GetTree(i); - UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::plus(), false); - - // Reset the tree and sample trackers - ensemble->ResetInitTree(i); - tracker.ResetRoot(dataset.GetCovariates(), feature_types, i); tree = ensemble->GetTree(i); + UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::plus(), false); // Sample tree i - GFRSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size); + tree = ensemble->GetTree(i); + MCMCSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance); // Sample leaf parameters for tree i tree = ensemble->GetTree(i); leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen); // Subtract tree i's predictions back out of the residual + tree = ensemble->GetTree(i); UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), std::minus(), true); } } -template -static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size) { - int root_id = Tree::kRoot; - int curr_node_id; - data_size_t curr_node_begin; - data_size_t curr_node_end; - data_size_t n = dataset.GetCovariates().rows(); - // Mapping from node id to start and end points of sorted indices - std::unordered_map> node_index_map; - node_index_map.insert({root_id, std::make_pair(0, n)}); - std::pair begin_end; - // Add root node to the split queue - std::deque split_queue; - split_queue.push_back(Tree::kRoot); - // Run the "GrowFromRoot" procedure using a stack in place of recursion - while (!split_queue.empty()) { - // Remove the next node from the queue - curr_node_id = split_queue.front(); - split_queue.pop_front(); - // Determine the beginning and ending indices of the left and right nodes - begin_end = node_index_map[curr_node_id]; - curr_node_begin = begin_end.first; - curr_node_end = begin_end.second; - // Draw a split rule at random - SampleSplitRule(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, - node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types); - } -} - -template -static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, - std::unordered_map>& node_index_map, std::deque& split_queue, - int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types) { - // Leaf depth - 
int leaf_depth = tree->GetDepth(node_id); - - // Maximum leaf depth - int32_t max_depth = tree_prior.GetMaxDepth(); - - if ((max_depth == -1) || (leaf_depth < max_depth)) { - - // Cutpoint enumeration - std::vector log_cutpoint_evaluations; - std::vector cutpoint_features; - std::vector cutpoint_values; - std::vector cutpoint_feature_types; - StochTree::data_size_t valid_cutpoint_count; - CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); - EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, - cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container); - // TODO: maybe add some checks here? - - // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); - std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); - for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ - cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); - } - - // Sample the split (including a "no split" option) - std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); - data_size_t split_chosen = split_dist(gen); - - if (split_chosen == valid_cutpoint_count){ - // "No split" sampled, don't split or add any nodes to split queue - return; - } else { - // Split sampled - int feature_split = cutpoint_features[split_chosen]; - FeatureType feature_type = cutpoint_feature_types[split_chosen]; - double split_value = cutpoint_values[split_chosen]; - // Perform all of the relevant "split" operations in the model, tree and training dataset - - // Compute node sample size - data_size_t node_n = node_end - node_begin; - - // Actual numeric cutpoint used for ordered categorical and numeric features - double split_value_numeric; - TreeSplit tree_split; - - // We will use these later in the model expansion - data_size_t left_n = 0; - data_size_t right_n = 0; - data_size_t sort_idx; - double feature_value; - bool split_true; - - if (feature_type == FeatureType::kUnorderedCategorical) { - // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split - int num_categories; - std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); - tree_split = TreeSplit(categories); - } else if (feature_type == FeatureType::kOrderedCategorical) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); - } else if (feature_type == FeatureType::kNumeric) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); - } else { - Log::Fatal("Invalid split type"); - } - - // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); - - // Determine the number of observation in the newly created left node - int left_node = tree->LeftChild(node_id); - 
int right_node = tree->RightChild(node_id); - auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); - auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); - for (auto i = left_begin_iter; i < left_end_iter; i++) { - left_n += 1; - } - - // Add the begin and end indices for the new left and right nodes to node_index_map - node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); - node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)}); - - // Add the left and right nodes to the split tracker - split_queue.push_front(right_node); - split_queue.push_front(left_node); - } - } -} - -template -static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, - std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, - std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, - std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, - std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container) { - // Evaluate all possible cutpoints according to the leaf node model, - // recording their log-likelihood and other split information in a series of vectors. - // The last element of these vectors concerns the "no-split" option. - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, tree_num, node_id, log_cutpoint_evaluations, - cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, - cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types); - - // Compute an adjustment to reflect the no split prior probability and the number of cutpoints - double bart_prior_no_split_adj; - double alpha = tree_prior.GetAlpha(); - double beta = tree_prior.GetBeta(); - int node_depth = tree->GetDepth(node_id); - if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); - } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); - } - log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; -} - } // namespace StochTree #endif // STOCHTREE_TREE_SAMPLER_H_ \ No newline at end of file diff --git a/src/sampler.cpp b/src/sampler.cpp index bfb0fe6e..2fc241e6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -127,8 +127,8 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { - StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); + MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } } From 
From e934db26d2a52ae665a1e5ec2d2428c3739d2bad Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Sat, 24 Aug 2024 00:12:42 -0500
Subject: [PATCH 10/41] Added StochTree scope to sampler function calls

---
 src/py_stochtree.cpp | 12 ++++++------
 src/sampler.cpp      | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp
index 47e9e26b..7025d8a9 100644
--- a/src/py_stochtree.cpp
+++ b/src/py_stochtree.cpp
@@ -512,13 +512,13 @@ class ForestSamplerCpp {
                 Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) {
     if (leaf_model_enum == ForestLeafModel::kConstant) {
       StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
-      GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
+      StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
     } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
       StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
-      GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
+      StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
     } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
       StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
-      GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
+      StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized);
     }
   }
 
@@ -527,13 +527,13 @@ class ForestSamplerCpp {
                 Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) {
     if (leaf_model_enum == ForestLeafModel::kConstant) {
       StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
-      MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
+      StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
     } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
       StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
-      MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
+      StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
     } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
       StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
-      MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
+      StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized);
     }
   }
 };
diff --git a/src/sampler.cpp b/src/sampler.cpp
index 2fc241e6..1229d6f0 100644
--- a/src/sampler.cpp
+++ b/src/sampler.cpp
@@ -61,13 +61,13 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
   if (leaf_model_enum == ForestLeafModel::kConstant) {
     StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
-    GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+    StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
     StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
-    GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+    StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
     StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
-    GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+    StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
   }
 }
 
@@ -122,13 +122,13 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
   if (leaf_model_enum == ForestLeafModel::kConstant) {
     StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
-    MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
     StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
-    MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
     StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
-    MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
  }
 }

From 9646a0844488e2ef3db58284764ff720b09a593b Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Tue, 27 Aug 2024 00:46:12 -0500
Subject: [PATCH 11/41] Refactor sampler iteration to avoid incremental object creation

---
 debug/README.md                  |  15 +-
 debug/api_debug.cpp              | 121 ++++----
 include/stochtree/leaf_model.h   | 155 ++++++++--
 include/stochtree/meta.h         |  10 +-
 include/stochtree/tree_sampler.h | 248 ++++++++++++---
 src/leaf_model.cpp               | 501 -------------------------------
 src/py_stochtree.cpp             |  12 +-
 src/sampler.cpp                  |  12 +-
 8 files changed, 444 insertions(+), 630 deletions(-)

diff --git a/debug/README.md b/debug/README.md
index 907a4ec0..e740b4ef 100644
--- a/debug/README.md
+++ b/debug/README.md
@@ -4,12 +4,19 @@ This subdirectory contains a debug program for the C++ codebase.
 
 The program takes several command line arguments (in order):
 
 1. Which data-generating process (DGP) to run (integer-coded, see below for a detailed description)
-2. Whether or not to include random effects (0 = no, 1 = yes)
-3. Number of grow-from-root (GFR) samples
-4. Number of MCMC samples
-5. Seed for random number generator (-1 means we defer to C++ `std::random_device`)
+2. Which leaf model to sample (integer-coded, see below for a detailed description)
+3. Whether or not to include random effects (0 = no, 1 = yes)
+4. Number of grow-from-root (GFR) samples
+5. Number of MCMC samples
+6. Seed for random number generator (-1 means we defer to C++ `std::random_device`)
 
 The DGPs are numbered as follows:
 
 0. Simple leaf regression model with a univariate basis for the leaf model
 1. Constant leaf model with a large number of deep interactions between features
+
+The models are numbered as follows:
+
+0. Constant leaf tree model (the "classic" BART / XBART model)
+1. "Univariate basis" leaf regression model
+2. "Multivariate basis" leaf regression model
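(With the expanded argument list documented above, a typical invocation of the debug program would look like `./api_debug 0 1 0 10 100 -1`: DGP 0, the univariate-basis leaf model, no random effects, 10 GFR draws, 100 MCMC draws, and a seed deferred to `std::random_device`. The binary name depends on the local build target and is only an assumption here.)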
"Multivariate basis" leaf regression model diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index d7420d5f..5d84c00b 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -16,16 +16,11 @@ #include #include #include +#include #include namespace StochTree{ -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; - void GenerateDGP1(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, std::vector& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) { // Data dimensions n = 1000; @@ -265,37 +260,37 @@ void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& } } -void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, - ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, - ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { - if (leaf_model_type == ForestLeafModel::kConstant) { - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); - } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { - GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); - } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { - GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); - } -} - -void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, - ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, - ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { - if (leaf_model_type == ForestLeafModel::kConstant) { - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { - GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { - GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } -} - -void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_mcmc = 
100, int random_seed = -1) { +// void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, +// ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, +// ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { +// if (leaf_model_type == ForestLeafModel::kConstant) { +// GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); +// GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); +// } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { +// GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); +// GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); +// } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { +// GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); +// GFRSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size); +// } +// } + +// void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, +// ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, +// ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { +// if (leaf_model_type == ForestLeafModel::kConstant) { +// GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); +// MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { +// GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); +// MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { +// GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); +// MCMCSampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } +// } + +void RunDebug(int dgp_num = 0, ModelType model_type = kConstantLeafGaussian, bool rfx_included = false, int num_gfr = 10, int num_mcmc = 100, int random_seed = -1) { // Flag the data as row-major bool row_major = true; @@ -326,24 +321,26 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int std::vector rfx_groups; std::vector feature_types; + // Check for DGP : ModelType compatibility + if ((model_type != kConstantLeafGaussian) && (dgp_num == 1)) { + Log::Fatal("dgp 2 is only compatible with a constant leaf model"); + } + // Generate the data int output_dimension; bool is_leaf_constant; - ForestLeafModel leaf_model_type; if (dgp_num == 0) { GenerateDGP1(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, 
rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); dataset.AddBasis(basis_raw.data(), n, omega_cols, row_major); output_dimension = 1; is_leaf_constant = false; - leaf_model_type = ForestLeafModel::kUnivariateRegression; } else if (dgp_num == 1) { GenerateDGP2(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); output_dimension = 1; is_leaf_constant = true; - leaf_model_type = ForestLeafModel::kConstant; } else { Log::Fatal("Invalid dgp_num"); @@ -441,6 +438,9 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int std::vector global_variance_samples{}; std::vector leaf_variance_samples{}; + // Prepare the samplers + LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix); + // Run the GFR sampler if (num_gfr > 0) { for (int i = 0; i < num_gfr; i++) { @@ -454,8 +454,13 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int } // Sample tree ensemble - sampleGFR(tracker, tree_prior, forest_samples, dataset, residual, gen, feature_types, variable_weights, - leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size); + if (model_type == ModelType::kConstantLeafGaussian) { + GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, false); + } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { + GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, false); + } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { + GFRSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, false, omega_cols); + } if (rfx_included) { // Sample random effects @@ -484,8 +489,13 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int } // Sample tree ensemble - sampleMCMC(tracker, tree_prior, forest_samples, dataset, residual, gen, feature_types, variable_weights, - leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size); + if (model_type == ModelType::kConstantLeafGaussian) { + MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, false); + } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) { + MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, false); + } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) { + MCMCSampleOneIter(tracker, forest_samples, std::get(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, false, omega_cols); + } if (rfx_included) { // Sample random effects @@ -531,24 +541,29 @@ int main(int argc, char* argv[]) { if ((dgp_num != 0) && (dgp_num != 1)) { StochTree::Log::Fatal("The first command line argument must be 0 or 1"); } - int rfx_int = std::stoi(argv[2]); + int model_type_int = static_cast(std::stoi(argv[2])); + if ((model_type_int != 0) && 
(model_type_int != 1) && (model_type_int != 2)) { + StochTree::Log::Fatal("The second command line argument must be 0, 1, or 2"); + } + StochTree::ModelType model_type = static_cast(model_type_int); + int rfx_int = std::stoi(argv[3]); if ((rfx_int != 0) && (rfx_int != 1)) { - StochTree::Log::Fatal("The second command line argument must be 0 or 1"); + StochTree::Log::Fatal("The third command line argument must be 0 or 1"); } bool rfx_included = static_cast(rfx_int); - int num_gfr = std::stoi(argv[3]); + int num_gfr = std::stoi(argv[4]); if (num_gfr < 0) { - StochTree::Log::Fatal("The third command line argument must be >= 0"); + StochTree::Log::Fatal("The fourth command line argument must be >= 0"); } - int num_mcmc = std::stoi(argv[4]); + int num_mcmc = std::stoi(argv[5]); if (num_mcmc < 0) { - StochTree::Log::Fatal("The fourth command line argument must be >= 0"); + StochTree::Log::Fatal("The fifth command line argument must be >= 0"); } - int random_seed = std::stoi(argv[5]); + int random_seed = std::stoi(argv[6]); if (random_seed < -1) { - StochTree::Log::Fatal("The fifth command line argument must be >= -0"); + StochTree::Log::Fatal("The sixth command line argument must be >= -0"); } // Run the debug program - StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_mcmc); + StochTree::RunDebug(dgp_num, model_type, rfx_included, num_gfr, num_mcmc); } diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 3ea7a8bb..f006566f 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -20,6 +20,12 @@ namespace StochTree { +enum ModelType { + kConstantLeafGaussian, + kUnivariateRegressionLeafGaussian, + kMultivariateRegressionLeafGaussian +}; + /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */ class GaussianConstantSuffStat { public: @@ -67,12 +73,6 @@ class GaussianConstantLeafModel { public: GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianConstantLeafModel() {} - std::tuple EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance); - std::tuple EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id); - void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id, - std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, - data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types); double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance); double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance); double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance); @@ -133,12 +133,6 @@ class GaussianUnivariateRegressionLeafModel { public: GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianUnivariateRegressionLeafModel() {} - std::tuple 
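The RunDebug changes above replace branch-by-branch model construction with a single std::variant built once up front, dispatched later via std::get. A minimal self-contained illustration of the pattern (the types and factory here are simplified stand-ins, not the library's actual classes):

    #include <iostream>
    #include <variant>

    struct ConstantModel { double tau; };
    struct RegressionModel { double tau; };
    using ModelVariant = std::variant<ConstantModel, RegressionModel>;

    enum ModelType { kConstant, kRegression };

    // Factory returns the variant; callers recover the concrete type with
    // std::get<T>, which throws std::bad_variant_access on a mismatch.
    ModelVariant modelFactory(ModelType type, double tau) {
      if (type == kConstant) return ConstantModel{tau};
      return RegressionModel{tau};
    }

    int main() {
      ModelVariant m = modelFactory(kRegression, 0.5);
      std::cout << std::get<RegressionModel>(m).tau << "\n";  // prints 0.5
      return 0;
    }

Constructing the model once and handing `std::get<T>(leaf_model)` to the templated samplers is what lets this refactor avoid re-creating a leaf model object on every sampler iteration.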
diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h
index 3ea7a8bb..f006566f 100644
--- a/include/stochtree/leaf_model.h
+++ b/include/stochtree/leaf_model.h
@@ -20,6 +20,12 @@
 
 namespace StochTree {
 
+enum ModelType {
+  kConstantLeafGaussian, 
+  kUnivariateRegressionLeafGaussian, 
+  kMultivariateRegressionLeafGaussian
+};
+
 /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */
 class GaussianConstantSuffStat {
  public:
@@ -67,12 +73,6 @@ class GaussianConstantLeafModel {
  public:
   GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
   ~GaussianConstantLeafModel() {}
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance);
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
-  void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id, 
-                                 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types, 
-                                 data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights, 
-                                 std::vector<FeatureType>& feature_types);
   double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
   double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
   double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
@@ -133,12 +133,6 @@ class GaussianUnivariateRegressionLeafModel {
  public:
   GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
   ~GaussianUnivariateRegressionLeafModel() {}
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance);
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
-  void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id, 
-                                 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types, 
-                                 data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights, 
-                                 std::vector<FeatureType>& feature_types);
   double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance);
   double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
   double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
@@ -201,12 +195,6 @@ class GaussianMultivariateRegressionLeafModel {
  public:
   GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();}
   ~GaussianMultivariateRegressionLeafModel() {}
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance);
-  std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
-  void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id, 
-                                 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types, 
-                                 data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights, 
-                                 std::vector<FeatureType>& feature_types);
   double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance);
   double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
   Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
@@ -220,6 +208,137 @@ class GaussianMultivariateRegressionLeafModel {
   MultivariateNormalSampler multivariate_normal_sampler_;
 };
 
+using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat>;
+
+using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel>;
+
+template <typename SuffStatType, typename... SuffStatConstructorArgs>
+static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
+  return SuffStatType(leaf_suff_stat_args...);
+}
+
+template <typename LeafModelType, typename... LeafModelConstructorArgs>
+static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
+  return LeafModelType(leaf_model_args...);
+}
+
+static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
+  if (model_type == kConstantLeafGaussian) {
+    return createSuffStat<GaussianConstantSuffStat>();
+  } else if (model_type == kUnivariateRegressionLeafGaussian) {
+    return createSuffStat<GaussianUnivariateRegressionSuffStat>();
+  } else {
+    return createSuffStat<GaussianMultivariateRegressionSuffStat>(basis_dim);
+  }
+}
+
+static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0) {
+  if (model_type == kConstantLeafGaussian) {
+    return createLeafModel<GaussianConstantLeafModel>(tau);
+  } else if (model_type == kUnivariateRegressionLeafGaussian) {
+    return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau);
+  } else {
+    return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0);
+  }
+}
+
+template <typename SuffStatType>
+static inline void AccumulateSuffStatProposed(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, 
+                                              ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature) {
+  // Acquire iterators
+  auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_num);
+  auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_num);
+  
+  // Accumulate sufficient statistics
+  for (auto i = node_begin_iter; i != node_end_iter; i++) {
+    auto idx = *i;
+    double feature_value = dataset.CovariateValue(idx, split_feature);
+    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+    if (split.SplitTrue(feature_value)) {
+      left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+    } else {
+      right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+    }
+  }
+}
+
+template <typename SuffStatType>
+static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, 
+                                              ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
+  // Acquire iterators
+  auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
+  auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
+  auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
+  auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
+  
+  // Accumulate sufficient statistics for the left and split nodes
+  for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
+    auto idx = *i;
+    left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+  }
+  
+  // Accumulate sufficient statistics for the right and split nodes
+  for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
+    auto idx = *i;
+    right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+  }
+}
+
+template <typename SuffStatType, bool sorted>
+static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
+  // Acquire iterators
+  std::vector<data_size_t>::iterator node_begin_iter;
+  std::vector<data_size_t>::iterator node_end_iter;
+  if (sorted) {
+    // Default to the first feature if we're using the presort tracker
+    node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
+    node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
+  } else {
+    node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
+    node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
+  }
+  
+  // Accumulate sufficient statistics
+  for (auto i = node_begin_iter; i != node_end_iter; i++) {
+    auto idx = *i;
+    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+  }
+}
+
+template <typename SuffStatType>
+static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, 
+                                                 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, 
+                                                 int feature_num, int cutpoint_num) {
+  // Acquire iterators
+  auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
+  auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
+  
+  // Determine node start point
+  data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
+  
+  // Determine cutpoint bin start and end points
+  data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
+  data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
+  data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
+  
+  // Cutpoint specific iterators
+  // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
+  auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
+  auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
+  
+  // Accumulate sufficient statistics
+  for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
+    auto idx = *i;
+    left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
+  }
+}
+
 } // namespace StochTree
 
 #endif // STOCHTREE_LEAF_MODEL_H_
diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h
index b77179ec..f078777c 100644
--- a/include/stochtree/meta.h
+++ b/include/stochtree/meta.h
@@ -41,11 +41,11 @@ enum ForestLeafVarianceType {
   kFixed
 };
 
-enum ForestLeafPriorType {
-  kConstantLeafGaussian,
-  kUnivariateRegressionLeafGaussian,
-  kMultivariateRegressionLeafGaussian
-};
+// enum ForestLeafPriorType {
+//   kConstantLeafGaussian,
+//   kUnivariateRegressionLeafGaussian,
+//   kMultivariateRegressionLeafGaussian
+// };
 
 enum ForestSampler {
   kMCMC,
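The `AccumulateSingleNodeSuffStat` template in the leaf_model.h diff above uses a bool non-type template parameter to pick an index source at compile time. A minimal runnable stand-in for that pattern (the names and data here are illustrative, not the library's):

    #include <iostream>
    #include <vector>

    // A compile-time flag selects between two index sources, mirroring how
    // the diff chooses the sorted (presort) tracker vs. the unsorted tracker.
    template <bool sorted>
    int SumIndices(const std::vector<int>& sorted_idx, const std::vector<int>& unsorted_idx) {
      const std::vector<int>& idx = sorted ? sorted_idx : unsorted_idx;
      int total = 0;
      for (int i : idx) total += i;  // stand-in for IncrementSuffStat over node indices
      return total;
    }

    int main() {
      std::vector<int> s{1, 2, 3}, u{4, 5};
      std::cout << SumIndices<true>(s, u) << " " << SumIndices<false>(s, u) << "\n";  // 6 9
      return 0;
    }

Because the flag is a template parameter, GFR code paths (which need the presorted tracker) and MCMC code paths (which do not) each get their own instantiation without a per-observation runtime check.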
diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h
index 302ae6bf..b75b6b00 100644
--- a/include/stochtree/tree_sampler.h
+++ b/include/stochtree/tree_sampler.h
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -111,7 +112,8 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke
   return false;
 }
 
-static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) {
+static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, 
+                                   int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) {
   // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
   if (tree->OutputDimension() > 1) {
     std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
@@ -127,7 +129,8 @@ static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& datase
   tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted);
 }
 
-static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
+static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, 
+                                        int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
   // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
   if (tree->OutputDimension() > 1) {
     std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
@@ -150,7 +153,8 @@ static inline double ComputeMeanOutcome(ColumnVector& residual) {
   return total_outcome / static_cast<double>(n);
 }
 
-static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
+static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, 
+                                              bool requires_basis, std::function<double(double, double)> op) {
   data_size_t n = dataset.GetCovariates().rows();
   double tree_pred = 0.;
   double pred_value = 0.;
@@ -175,7 +179,8 @@ static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestData
   }
 }
 
-static inline void UpdateResidualTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
+static inline void UpdateResidualTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, 
+                                      bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
   data_size_t n = dataset.GetCovariates().rows();
   double pred_value;
   int32_t leaf_pred;
@@ -233,19 +238,168 @@ static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset&
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
+static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
+    ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, 
+    TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance, 
+    LeafSuffStatConstructorArgs&... leaf_suff_stat_args
+) {
+  // Initialize sufficient statistics
+  LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+
+  // Accumulate sufficient statistics
+  AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, 
+                             residual, global_variance, split, tree_num, leaf_num, split_feature);
+  data_size_t left_n = left_suff_stat.n;
+  data_size_t right_n = right_suff_stat.n;
+
+  // Evaluate split
+  double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
+  double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
+
+  return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
+}
+
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
+static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(
+    ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, 
+    double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id, 
+    LeafSuffStatConstructorArgs&... leaf_suff_stat_args
+) {
+  // Initialize sufficient statistics
+  LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+
+  // Accumulate sufficient statistics
+  AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, 
+                             residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
+  data_size_t left_n = left_suff_stat.n;
+  data_size_t right_n = right_suff_stat.n;
+
+  // Evaluate split
+  double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
+  double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
+
+  return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
+}
+
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
+static inline void EvaluateAllPossibleSplits(
+    ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id, 
+    std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types, 
+    data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights, 
+    std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
+) {
+  // Initialize sufficient statistics
+  LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+
+  // Accumulate aggregate sufficient statistic for the node to be split
+  AccumulateSingleNodeSuffStat<LeafSuffStat, true>(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id);
+
+  // Compute the "no split" log marginal likelihood
+  double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
+
+  // Unpack data
+  Eigen::MatrixXd covariates = dataset.GetCovariates();
+  Eigen::VectorXd outcome = residual.GetData();
+  Eigen::VectorXd var_weights;
+  bool has_weights = dataset.HasVarWeights();
+  if (has_weights) var_weights = dataset.GetVarWeights();
+
+  // Minimum size of newly created leaf nodes (used to rule out invalid splits)
+  int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
+
+  // Compute sufficient statistics for each possible split
+  data_size_t num_cutpoints = 0;
+  bool valid_split = false;
+  data_size_t node_row_iter;
+  data_size_t current_bin_begin, current_bin_size, next_bin_begin;
+  data_size_t feature_sort_idx;
+  data_size_t row_iter_idx;
+  double outcome_val, outcome_val_sq;
+  FeatureType feature_type;
+  double feature_value = 0.0;
+  double cutoff_value = 0.0;
+  double log_split_eval = 0.0;
+  double split_log_ml;
+  for (int j = 0; j < covariates.cols(); j++) {
+
+    if (std::abs(variable_weights.at(j)) > kEpsilon) {
+      // Enumerate cutpoint strides
+      cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types);
+
+      // Reset sufficient statistics
+      left_suff_stat.ResetSuffStat();
+      right_suff_stat.ResetSuffStat();
+
+      // Iterate through possible cutpoints
+      int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
+      feature_type = feature_types[j];
+      // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
+      for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
+        current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
+        current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
+        next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
+
+        // Accumulate sufficient statistics for the left node
+        AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, 
+                                      global_variance, tree_num, split_node_id, j, cutpoint_idx);
+
+        // Compute the corresponding right node sufficient statistics
+        right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
+
+        // Store the bin index as the "cutpoint value" - we can use this to query the actual split 
+        // value or the set of split categories later on once a split is chosen
+        cutoff_value = cutpoint_idx;
+
+        // Only include cutpoint for consideration if it defines a valid split in the training data
+        valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && 
+                       right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
+        if (valid_split) {
+          num_cutpoints++;
+          // Add to split rule vector
+          cutpoint_feature_types.push_back(feature_type);
+          cutpoint_features.push_back(j);
+          cutpoint_values.push_back(cutoff_value);
+          // Add the log marginal likelihood of the split to the split eval vector 
+          split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
+          log_cutpoint_evaluations.push_back(split_log_ml);
+        }
+      }
+    }
+
+  }
+
+  // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
+  cutpoint_features.push_back(-1);
+  cutpoint_values.push_back(std::numeric_limits<double>::max());
+  cutpoint_feature_types.push_back(FeatureType::kNumeric);
+  log_cutpoint_evaluations.push_back(no_split_log_ml);
+
+  // Update valid cutpoint count
+  valid_cutpoint_count = num_cutpoints;
+}
+
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
                                      std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
                                      std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values,
                                      std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<double>& variable_weights,
-                                     std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container) {
+                                     std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Evaluate all possible cutpoints according to the leaf node model,
   // recording their log-likelihood and other split information in a series of vectors.
   // The last element of these vectors concerns the "no-split" option.
-  leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, tree_num, node_id, log_cutpoint_evaluations,
-                                       cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count,
-                                       cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types);
-
+  EvaluateAllPossibleSplits<LeafModel, LeafSuffStat>(
+      dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, 
+      cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 
+      node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args...
+  );
+  
   // Compute an adjustment to reflect the no split prior probability and the number of cutpoints
   double bart_prior_no_split_adj;
   double alpha = tree_prior.GetAlpha();
   double beta = tree_prior.GetBeta();
@@ -259,12 +413,12 @@ static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafMod
   log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj;
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
                                    ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance,
                                    int cutpoint_grid_size, std::unordered_map<int, std::pair<data_size_t, data_size_t>>& node_index_map,
                                    std::deque<int>& split_queue, int node_id, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
-                                   std::vector<FeatureType>& feature_types) {
+                                   std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Leaf depth
   int leaf_depth = tree->GetDepth(node_id);
 
@@ -280,10 +434,12 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel
   std::vector<FeatureType> cutpoint_feature_types;
   StochTree::data_size_t valid_cutpoint_count;
   CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
-  EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
-                    cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
-                    cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types,
-                    cutpoint_grid_container);
+  EvaluateCutpoints<LeafModel, LeafSuffStat>(
+      tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, 
+      cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, 
+      cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, 
+      cutpoint_grid_container, leaf_suff_stat_args...
+  );
   // TODO: maybe add some checks here?
 
   // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood
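The no-split adjustment in EvaluateCutpoints above encodes the tree prior: with split probability alpha * (1 + depth)^(-beta), the log prior odds of not splitting are log((1 + depth)^beta / alpha - 1), plus log of the number of valid cutpoints to account for the uniform prior over candidate splits (per the XBART paper, as the diff's comment notes). A standalone numeric check of that formula (the parameter values here are illustrative only):

    #include <cmath>
    #include <cstdio>

    int main() {
      double alpha = 0.95, beta = 2.0;   // commonly used BART prior defaults
      int depth = 1, n_cutpoints = 100;
      // Log prior odds of "no split" at this depth, plus the cutpoint-count correction
      double adj = std::log(std::pow(1 + depth, beta) / alpha - 1.0) + std::log(static_cast<double>(n_cutpoints));
      std::printf("no-split adjustment: %f\n", adj);  // log(4/0.95 - 1) + log(100), roughly 5.77
      return 0;
    }

Because this term is added to the last entry of `log_cutpoint_evaluations`, the subsequent exponentiate-and-normalize step weighs "no split" against every candidate cutpoint on a common scale.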
@@ -361,10 +517,11 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
                                         ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
-                                        int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size) {
+                                        int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size, 
+                                        LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   int root_id = Tree::kRoot;
   int curr_node_id;
   data_size_t curr_node_begin;
@@ -387,16 +544,18 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
     curr_node_begin = begin_end.first;
     curr_node_end = begin_end.second;
     // Draw a split rule at random
-    SampleSplitRule(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size,
-                    node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types);
+    SampleSplitRule<LeafModel, LeafSuffStat>(
+        tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, 
+        node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, 
+        leaf_suff_stat_args...);
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
                                     ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
-                                    double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size = 500,
-                                    bool pre_initialized = false) {
+                                    double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size, 
+                                    bool pre_initialized, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Previous number of samples
   int prev_num_samples = forests.NumSamples();
 
@@ -435,7 +594,11 @@ static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& for
     tree = ensemble->GetTree(i);
 
     // Sample tree i
-    GFRSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size);
+    GFRSampleTreeOneIter<LeafModel, LeafSuffStat>(
+        tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, 
+        variable_weights, i, global_variance, feature_types, cutpoint_grid_size, 
+        leaf_suff_stat_args...
+    );
 
     // Sample leaf parameters for tree i
     tree = ensemble->GetTree(i);
@@ -446,10 +609,10 @@ static inline void GFRSampleOneIter(ForestTracker& tracker, ForestContainer& for
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
                                        ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector<double>& variable_weights,
-                                       double global_variance, double prob_grow_old) {
+                                       double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Extract dataset information
   data_size_t n = dataset.GetCovariates().rows();
 
@@ -495,7 +658,9 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
   TreeSplit split = TreeSplit(split_point_chosen);
 
   // Compute the marginal likelihood of split and no split, given the leaf prior
-  std::tuple<double, double, data_size_t, data_size_t> split_eval = leaf_model.EvaluateProposedSplit(dataset, tracker, residual, split, tree_num, leaf_chosen, var_chosen, global_variance);
+  std::tuple<double, double, data_size_t, data_size_t> split_eval = EvaluateProposedSplit<LeafModel, LeafSuffStat>(
+      dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args...
+  );
   double split_log_marginal_likelihood = std::get<0>(split_eval);
   double no_split_log_marginal_likelihood = std::get<1>(split_eval);
   int32_t left_n = std::get<2>(split_eval);
@@ -546,9 +711,9 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
-                                        TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance) {
+                                        TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Choose a "leaf parent" node at random
   int num_leaves = tree->NumLeaves();
   int num_leaf_parents = tree->NumLeafParents();
@@ -563,7 +728,9 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf
   int feature_split = tree->SplitIndex(leaf_parent_chosen);
 
   // Compute the marginal likelihood for the leaf parent and its left and right nodes
-  std::tuple<double, double, data_size_t, data_size_t> split_eval = leaf_model.EvaluateExistingSplit(dataset, tracker, residual, global_variance, tree_num, leaf_parent_chosen, left_node, right_node);
+  std::tuple<double, double, data_size_t, data_size_t> split_eval = EvaluateExistingSplit<LeafModel, LeafSuffStat>(
+      dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...
+  );
   double split_log_marginal_likelihood = std::get<0>(split_eval);
   double no_split_log_marginal_likelihood = std::get<1>(split_eval);
   int32_t left_n = std::get<2>(split_eval);
@@ -622,10 +789,10 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model,
                                          ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
-                                         int tree_num, double global_variance) {
+                                         int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Determine whether it is possible to grow any of the leaves
   bool grow_possible = false;
   std::vector<int> leaves = tree->GetLeaves();
@@ -664,16 +831,20 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For
   bool accept;
   if (step_chosen == 0) {
-    MCMCGrowTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow);
+    MCMCGrowTreeOneIter<LeafModel, LeafSuffStat>(
+        tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args...
+    );
   } else {
-    MCMCPruneTreeOneIter(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance);
+    MCMCPruneTreeOneIter<LeafModel, LeafSuffStat>(
+        tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args...
+    );
   }
 }
 
-template <typename LeafModel>
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
                                      ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
-                                     double global_variance, bool pre_initialized = false) {
+                                     double global_variance, bool pre_initialized, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Previous number of samples
   int prev_num_samples = forests.NumSamples();
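Throughout the refactor above, the sufficient-statistic type and its constructor arguments travel through the samplers as a trailing template parameter pack, so each call site can construct fresh statistics in place (only the multivariate model needs an extra argument, the basis dimension). A minimal runnable sketch of that forwarding pattern, with simplified stand-in types rather than the library's:

    #include <iostream>

    struct ScalarSuffStat {
      double sum = 0.0;
    };

    struct VectorSuffStat {
      int dim;
      explicit VectorSuffStat(int basis_dim) : dim(basis_dim) {}
    };

    // Mirrors the diff: the suff stat type and its constructor arguments are
    // template parameters, so statistics are built in place per iteration
    // with no virtual dispatch and no reusable mutable state.
    template <typename SuffStat, typename... SuffStatArgs>
    void SampleOneIter(SuffStatArgs&... suff_stat_args) {
      SuffStat node_stat = SuffStat(suff_stat_args...);  // fresh statistics each call
      // ... accumulate over node observations and evaluate splits with node_stat ...
      (void)node_stat;
    }

    int main() {
      SampleOneIter<ScalarSuffStat>();             // no constructor args needed
      int omega_cols = 3;
      SampleOneIter<VectorSuffStat>(omega_cols);   // pack forwards the basis dimension
      std::cout << "ok\n";
      return 0;
    }

This is why the RunDebug dispatch earlier in the patch can pass `omega_cols` as a trailing argument only in the multivariate branch.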
@@ -706,7 +877,10 @@ static inline void MCMCSampleOneIter(ForestTracker& tracker, ForestContainer& fo
     // Sample tree i
     tree = ensemble->GetTree(i);
-    MCMCSampleTreeOneIter(tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance);
+    MCMCSampleTreeOneIter<LeafModel, LeafSuffStat>(
+        tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, 
+        global_variance, leaf_suff_stat_args...
+    );
 
     // Sample leaf parameters for tree i
     tree = ensemble->GetTree(i);
diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp
index 797e3758..a3ae9b38 100644
--- a/src/leaf_model.cpp
+++ b/src/leaf_model.cpp
@@ -2,234 +2,6 @@
 
 namespace StochTree {
 
-template <typename SuffStatType>
-void AccumulateSuffStatProposed(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
-                                ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature) {
-  // Acquire iterators
-  auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_num);
-  auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_num);
-
-  // Accumulate sufficient statistics
-  for (auto i = node_begin_iter; i != node_end_iter; i++) {
-    auto idx = *i;
-    double feature_value = dataset.CovariateValue(idx, split_feature);
-    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-    if (split.SplitTrue(feature_value)) {
-      left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-    } else {
-      right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-    }
-  }
-}
-
-template <typename SuffStatType>
-void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
-                                ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
-  // Acquire iterators
-  auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
-  auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
-  auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
-  auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
-
-  // Accumulate sufficient statistics for the left and split nodes
-  for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
-    auto idx = *i;
-    left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-  }
-
-  // Accumulate sufficient statistics for the right and split nodes
-  for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
-    auto idx = *i;
-    right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-  }
-}
-
-template <typename SuffStatType, bool sorted>
-void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
-  // Acquire iterators
-  std::vector<data_size_t>::iterator node_begin_iter;
-  std::vector<data_size_t>::iterator node_end_iter;
-  if (sorted) {
-    // Default to the first feature if we're using the presort tracker
-    node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
-    node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
-  } else {
-    node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
-    node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
-  }
-
-  // Accumulate sufficient statistics
-  for (auto i = node_begin_iter; i != node_end_iter; i++) {
-    auto idx = *i;
-    node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-  }
-}
-
-template <typename SuffStatType>
-void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
-                                   ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
-                                   int feature_num, int cutpoint_num) {
-  // Acquire iterators
-  auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
-  auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
-
-  // Determine node start point
-  data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
-
-  // Determine cutpoint bin start and end points
-  data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
-  data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
-  data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
-
-  // Cutpoint specific iterators
-  // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
-  auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
-  auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
-
-  // Accumulate sufficient statistics
-  for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
-    auto idx = *i;
-    left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), idx);
-  }
-}
-
-std::tuple<double, double, data_size_t, data_size_t> GaussianConstantLeafModel::EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual,
-                                                                                                      TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance) {
-  // Initialize sufficient statistics
-  GaussianConstantSuffStat node_suff_stat = GaussianConstantSuffStat();
-  GaussianConstantSuffStat left_suff_stat = GaussianConstantSuffStat();
-  GaussianConstantSuffStat right_suff_stat = GaussianConstantSuffStat();
-
-  // Accumulate sufficient statistics
-  AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
-                             residual, global_variance, split, tree_num, leaf_num, split_feature);
-  data_size_t left_n = left_suff_stat.n;
-  data_size_t right_n = right_suff_stat.n;
-
-  // Evaluate split
-  double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
-  double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
-
-  return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
-}
-
-std::tuple<double, double, data_size_t, data_size_t> GaussianConstantLeafModel::EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance,
-                                                                                                      int tree_num, int split_node_id, int left_node_id, int right_node_id) {
-  // Initialize sufficient statistics
-  GaussianConstantSuffStat node_suff_stat = GaussianConstantSuffStat();
-  GaussianConstantSuffStat left_suff_stat = GaussianConstantSuffStat();
-  GaussianConstantSuffStat right_suff_stat = GaussianConstantSuffStat();
-
-  // Accumulate sufficient statistics
-  AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
-                             residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
-  data_size_t left_n
= left_suff_stat.n; - data_size_t right_n = right_suff_stat.n; - - // Evaluate split - double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n); -} - -void GaussianConstantLeafModel::EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int node_id, std::vector& log_cutpoint_evaluations, - std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, - CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, std::vector& feature_types) { - // Initialize sufficient statistics - GaussianConstantSuffStat node_suff_stat = GaussianConstantSuffStat(); - GaussianConstantSuffStat left_suff_stat = GaussianConstantSuffStat(); - GaussianConstantSuffStat right_suff_stat = GaussianConstantSuffStat(); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); - Eigen::VectorXd outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - bool valid_split = false; - data_size_t node_row_iter; - data_size_t current_bin_begin, current_bin_size, next_bin_begin; - data_size_t feature_sort_idx; - data_size_t row_iter_idx; - double outcome_val, outcome_val_sq; - FeatureType feature_type; - double feature_value = 0.0; - double cutoff_value = 0.0; - double log_split_eval = 0.0; - double split_log_ml; - for (int j = 0; j < covariates.cols(); j++) { - - if (std::abs(variable_weights.at(j)) > kEpsilon) { - // Enumerate cutpoint strides - cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - - // Reset sufficient statistics - left_suff_stat.ResetSuffStat(); - right_suff_stat.ResetSuffStat(); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); - feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); - next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, node_id, j, cutpoint_idx); - - // Compute the 
corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - num_cutpoints++; - // Add to split rule vector - cutpoint_feature_types.push_back(feature_type); - cutpoint_features.push_back(j); - cutpoint_values.push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - log_cutpoint_evaluations.push_back(split_log_ml); - } - } - } - - } - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - cutpoint_features.push_back(-1); - cutpoint_values.push_back(std::numeric_limits::max()); - cutpoint_feature_types.push_back(FeatureType::kNumeric); - log_cutpoint_evaluations.push_back(no_split_log_ml); - - // Update valid cutpoint count - valid_cutpoint_count = num_cutpoints; -} - double GaussianConstantLeafModel::SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance) { double left_log_ml = ( -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) @@ -294,141 +66,6 @@ void GaussianConstantLeafModel::SetEnsembleRootPredictedValue(ForestDataset& dat } } -std::tuple GaussianUnivariateRegressionLeafModel::EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, - TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance) { - // Initialize sufficient statistics - GaussianUnivariateRegressionSuffStat node_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat left_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat right_suff_stat = GaussianUnivariateRegressionSuffStat(); - - // Accumulate sufficient statistics - AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature); - data_size_t left_n = left_suff_stat.n; - data_size_t right_n = right_suff_stat.n; - - // Evaluate split - double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n); -} - -std::tuple GaussianUnivariateRegressionLeafModel::EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, - int tree_num, int split_node_id, int left_node_id, int right_node_id) { - // Initialize sufficient statistics - GaussianUnivariateRegressionSuffStat node_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat left_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat right_suff_stat = 
GaussianUnivariateRegressionSuffStat(); - - // Accumulate sufficient statistics - AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id); - data_size_t left_n = left_suff_stat.n; - data_size_t right_n = right_suff_stat.n; - - // Evaluate split - double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n); -} - -void GaussianUnivariateRegressionLeafModel::EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int node_id, std::vector& log_cutpoint_evaluations, - std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, - CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, std::vector& feature_types) { - // Initialize sufficient statistics - GaussianUnivariateRegressionSuffStat node_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat left_suff_stat = GaussianUnivariateRegressionSuffStat(); - GaussianUnivariateRegressionSuffStat right_suff_stat = GaussianUnivariateRegressionSuffStat(); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); - Eigen::VectorXd outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - bool valid_split = false; - data_size_t node_row_iter; - data_size_t current_bin_begin, current_bin_size, next_bin_begin; - data_size_t feature_sort_idx; - data_size_t row_iter_idx; - double outcome_val, outcome_val_sq; - FeatureType feature_type; - double feature_value = 0.0; - double cutoff_value = 0.0; - double log_split_eval = 0.0; - double split_log_ml; - for (int j = 0; j < covariates.cols(); j++) { - - if (std::abs(variable_weights.at(j)) > kEpsilon) { - // Enumerate cutpoint strides - cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - - // Reset sufficient statistics - left_suff_stat.ResetSuffStat(); - right_suff_stat.ResetSuffStat(); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); - feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - current_bin_size = 
cutpoint_grid_container.BinLength(cutpoint_idx, j); - next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, node_id, j, cutpoint_idx); - - // Compute the corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - num_cutpoints++; - // Add to split rule vector - cutpoint_feature_types.push_back(feature_type); - cutpoint_features.push_back(j); - cutpoint_values.push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - log_cutpoint_evaluations.push_back(split_log_ml); - } - } - } - - } - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - cutpoint_features.push_back(-1); - cutpoint_values.push_back(std::numeric_limits::max()); - cutpoint_feature_types.push_back(FeatureType::kNumeric); - log_cutpoint_evaluations.push_back(no_split_log_ml); - - // Update valid cutpoint count - valid_cutpoint_count = num_cutpoints; -} - double GaussianUnivariateRegressionLeafModel::SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance) { double left_log_ml = ( -0.5*std::log(1 + tau_*(left_stat.sum_xxw/global_variance)) + ((tau_*left_stat.sum_yxw*left_stat.sum_yxw)/(2.0*global_variance*(tau_*left_stat.sum_xxw + global_variance))) @@ -493,144 +130,6 @@ void GaussianUnivariateRegressionLeafModel::SetEnsembleRootPredictedValue(Forest } } -std::tuple GaussianMultivariateRegressionLeafModel::EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, - TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance) { - // Initialize sufficient statistics - int num_basis = dataset.GetBasis().cols(); - GaussianMultivariateRegressionSuffStat node_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - GaussianMultivariateRegressionSuffStat left_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - GaussianMultivariateRegressionSuffStat right_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - - // Accumulate sufficient statistics - AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature); - data_size_t left_n = left_suff_stat.n; - data_size_t right_n = right_suff_stat.n; - - // Evaluate split - double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n); -} - -std::tuple 
GaussianMultivariateRegressionLeafModel::EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, - int tree_num, int split_node_id, int left_node_id, int right_node_id) { - // Initialize sufficient statistics - int num_basis = dataset.GetBasis().cols(); - GaussianMultivariateRegressionSuffStat node_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - GaussianMultivariateRegressionSuffStat left_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - GaussianMultivariateRegressionSuffStat right_suff_stat = GaussianMultivariateRegressionSuffStat(num_basis); - - // Accumulate sufficient statistics - AccumulateSuffStatExisting(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id); - data_size_t left_n = left_suff_stat.n; - data_size_t right_n = right_suff_stat.n; - - // Evaluate split - double split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - return std::tuple(split_log_ml, no_split_log_ml, left_n, right_n); -} - -void GaussianMultivariateRegressionLeafModel::EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int node_id, std::vector& log_cutpoint_evaluations, - std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, - CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, std::vector& feature_types) { - // Initialize sufficient statistics - int basis_dim = dataset.GetBasis().cols(); - GaussianMultivariateRegressionSuffStat node_suff_stat = GaussianMultivariateRegressionSuffStat(basis_dim); - GaussianMultivariateRegressionSuffStat left_suff_stat = GaussianMultivariateRegressionSuffStat(basis_dim); - GaussianMultivariateRegressionSuffStat right_suff_stat = GaussianMultivariateRegressionSuffStat(basis_dim); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); - Eigen::VectorXd outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - bool valid_split = false; - data_size_t node_row_iter; - data_size_t current_bin_begin, current_bin_size, next_bin_begin; - data_size_t feature_sort_idx; - data_size_t row_iter_idx; - double outcome_val, outcome_val_sq; - FeatureType feature_type; - double feature_value = 0.0; - double cutoff_value = 0.0; - double log_split_eval = 0.0; - double split_log_ml; - for (int j = 0; j < covariates.cols(); j++) { - - if (std::abs(variable_weights.at(j)) > kEpsilon) { - // Enumerate cutpoint strides - 
cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - - // Reset sufficient statistics - left_suff_stat.ResetSuffStat(); - right_suff_stat.ResetSuffStat(); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); - feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); - next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, node_id, j, cutpoint_idx); - - // Compute the corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - num_cutpoints++; - // Add to split rule vector - cutpoint_feature_types.push_back(feature_type); - cutpoint_features.push_back(j); - cutpoint_values.push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - split_log_ml = SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - log_cutpoint_evaluations.push_back(split_log_ml); - } - } - } - - } - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - cutpoint_features.push_back(-1); - cutpoint_values.push_back(std::numeric_limits::max()); - cutpoint_feature_types.push_back(FeatureType::kNumeric); - log_cutpoint_evaluations.push_back(no_split_log_ml); - - // Update valid cutpoint count - valid_cutpoint_count = num_cutpoints; -} - double GaussianMultivariateRegressionLeafModel::SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance) { Eigen::MatrixXd I_p = Eigen::MatrixXd::Identity(left_stat.p, left_stat.p); double left_log_ml = ( diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 7025d8a9..207d1ce0 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -512,13 +512,13 @@ class ForestSamplerCpp { Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { if (leaf_model_enum == ForestLeafModel::kConstant) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, 
leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); } } @@ -527,13 +527,13 @@ class ForestSamplerCpp { Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { if (leaf_model_enum == ForestLeafModel::kConstant) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); } } }; diff --git a/src/sampler.cpp b/src/sampler.cpp index 1229d6f0..7631bdd1 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -61,13 +61,13 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, 
feature_types_, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); } } @@ -122,13 +122,13 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); } } From 894deb2a09c6ee25e056eccd3843c4dec3d71d5c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 27 Aug 2024 01:31:57 -0500 Subject: [PATCH 12/41] Refactored R package C++ calls --- debug/api_debug.cpp | 106 ++++++++++++++++++++++++++++--- src/sampler.cpp | 79 +++++++++++++---------- tools/perf/bart_microbenchmark.R | 29 +++++++++ 3 files changed, 172 insertions(+), 42 deletions(-) create mode 100644 tools/perf/bart_microbenchmark.R diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index 5d84c00b..5f19af57 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -239,6 +239,86 @@ void GenerateDGP2(std::vector& covariates, std::vector& basis, s 
 }
 }
 
+void GenerateDGP3(std::vector<double>& covariates, std::vector<double>& basis, std::vector<double>& outcome, std::vector<double>& rfx_basis, std::vector<int32_t>& rfx_groups, std::vector<FeatureType>& feature_types, std::mt19937& gen, int& n, int& x_cols, int& omega_cols, int& y_cols, int& rfx_basis_cols, int& num_rfx_groups, bool rfx_included, int random_seed = -1) {
+  // Data dimensions
+  n = 1000;
+  x_cols = 2;
+  omega_cols = 2;
+  y_cols = 1;
+  if (rfx_included) {
+    num_rfx_groups = 2;
+    rfx_basis_cols = 1;
+  } else {
+    num_rfx_groups = 0;
+    rfx_basis_cols = 0;
+  }
+
+  // Resize data
+  covariates.resize(n * x_cols);
+  basis.resize(n * omega_cols);
+  rfx_basis.resize(n * rfx_basis_cols);
+  outcome.resize(n * y_cols);
+  rfx_groups.resize(n);
+  feature_types.resize(x_cols, FeatureType::kNumeric);
+
+  // Random number generation
+  std::uniform_real_distribution<double> uniform_dist{0.0,1.0};
+  std::normal_distribution<double> normal_dist(0.,1.);
+
+  // DGP parameters
+  std::vector<double> betas{-10, -5, 5, 10};
+  int num_partitions = betas.size();
+  double f_x_omega;
+  double rfx;
+  double error;
+
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < x_cols; j++) {
+      covariates[i*x_cols + j] = uniform_dist(gen);
+    }
+
+    for (int j = 0; j < omega_cols; j++) {
+      basis[i*omega_cols + j] = uniform_dist(gen);
+    }
+
+    if (rfx_included) {
+      for (int j = 0; j < rfx_basis_cols; j++) {
+        rfx_basis[i * rfx_basis_cols + j] = 1;
+      }
+
+      if (i % 2 == 0) {
+        rfx_groups[i] = 1;
+      } else {
+        rfx_groups[i] = 2;
+      }
+    }
+
+    for (int j = 0; j < y_cols; j++) {
+      if ((covariates[i * x_cols + 0] >= 0.0) && covariates[i * x_cols + 0] < 0.25) {
+        f_x_omega = betas[0] * basis[i * omega_cols + 0];
+      } else if ((covariates[i * x_cols + 0] >= 0.25) && covariates[i * x_cols + 0] < 0.5) {
+        f_x_omega = betas[1] * basis[i * omega_cols + 0];
+      } else if ((covariates[i * x_cols + 0] >= 0.5) && covariates[i * x_cols + 0] < 0.75) {
+        f_x_omega = betas[2] * basis[i * omega_cols + 0];
+      } else {
+        f_x_omega = betas[3] * basis[i * omega_cols + 0];
+      }
+      error = 0.1 * normal_dist(gen);
+      outcome[i * y_cols + j] = f_x_omega + error;
+      if (rfx_included) {
+        if (rfx_groups[i] == 1) {
+          rfx = 5.;
+        } else {
+          rfx = -5.;
+        }
+        outcome[i * y_cols + j] += rfx;
+      }
+    }
+  }
+}
+
 void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& outcome_scale) {
   data_size_t n = residual.NumRows();
   double outcome_val = 0.0;
@@ -335,14 +415,18 @@ void RunDebug(int dgp_num = 0, ModelType model_type = kConstantLeafGaussian, boo
     dataset.AddBasis(basis_raw.data(), n, omega_cols, row_major);
     output_dimension = 1;
     is_leaf_constant = false;
-  }
-  else if (dgp_num == 1) {
+  } else if (dgp_num == 1) {
     GenerateDGP2(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed);
     dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major);
     output_dimension = 1;
     is_leaf_constant = true;
-  }
-  else {
+  } else if (dgp_num == 2) {
+    GenerateDGP3(covariates_raw, basis_raw, outcome_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed);
+    dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major);
+    dataset.AddBasis(basis_raw.data(), n, omega_cols, row_major);
+    output_dimension = omega_cols;
+    is_leaf_constant = false;
+  } else {
     Log::Fatal("Invalid dgp_num");
   }
@@ -413,14 +497,16 @@ void RunDebug(int dgp_num = 0, ModelType model_type = kConstantLeafGaussian, boo
   double lamb = 0.5;
 
   // Set leaf model 
parameters - double leaf_scale_init = 1.; - Eigen::MatrixXd leaf_scale_matrix, leaf_scale_matrix_init; - // leaf_scale_matrix_init << 1.0, 0.0, 0.0, 1.0; double leaf_scale; + double leaf_scale_init = 1.; + Eigen::MatrixXd leaf_scale_matrix(omega_cols, omega_cols); + Eigen::MatrixXd leaf_scale_matrix_init(omega_cols, omega_cols); + leaf_scale_matrix_init << 1.0, 0.0, 0.0, 1.0; + leaf_scale_matrix = leaf_scale_matrix_init; // Set global variance - double global_variance_init = 1.0; double global_variance; + double global_variance_init = 1.0; // Set variable weights double const_var_wt = static_cast(1. / x_cols); @@ -538,8 +624,8 @@ void RunDebug(int dgp_num = 0, ModelType model_type = kConstantLeafGaussian, boo int main(int argc, char* argv[]) { // Unpack command line arguments int dgp_num = std::stoi(argv[1]); - if ((dgp_num != 0) && (dgp_num != 1)) { - StochTree::Log::Fatal("The first command line argument must be 0 or 1"); + if ((dgp_num != 0) && (dgp_num != 1) && (dgp_num != 2)) { + StochTree::Log::Fatal("The first command line argument must be 0, 1, or 2"); } int model_type_int = static_cast(std::stoi(argv[2])); if ((model_type_int != 0) && (model_type_int != 1) && (model_type_int != 2)) { diff --git a/src/sampler.cpp b/src/sampler.cpp index 7631bdd1..e9da2c81 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include [[cpp11::register]] @@ -30,18 +31,18 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointerNumBasis(); + // Run one iteration of the sampler - if (leaf_model_enum == ForestLeafModel::kConstant) { - StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { - StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { - StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + // StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); + // StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + StochTree::GFRSampleOneIter(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + // StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); + // StochTree::GFRSampleOneIter(*tracker, 
*forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+    StochTree::GFRSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+  } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
+    // StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
+    // StochTree::GFRSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized);
+    StochTree::GFRSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat, int>(*tracker, *forest_samples, std::get<StochTree::GaussianMultivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, num_basis);
   }
 }
 
@@ -91,18 +99,18 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer
+  int num_basis = data->NumBasis();
+
   // Run one iteration of the sampler
-  if (leaf_model_enum == ForestLeafModel::kConstant) {
-    StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
-    StochTree::MCMCSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
-  } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
-    StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
-    StochTree::MCMCSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
-  } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
-    StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
-    StochTree::MCMCSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat>(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+  if (model_type == StochTree::ModelType::kConstantLeafGaussian) {
+    // StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
+    // StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    StochTree::MCMCSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianConstantLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+  } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) {
+    // StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
+    // StochTree::MCMCSampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    StochTree::MCMCSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*tracker, *forest_samples, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+  } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
+    // StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
+    // StochTree::MCMCSampleOneIter(*tracker, 
*forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); + StochTree::MCMCSampleOneIter(*tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized, num_basis); } } diff --git a/tools/perf/bart_microbenchmark.R b/tools/perf/bart_microbenchmark.R new file mode 100644 index 00000000..21e5e171 --- /dev/null +++ b/tools/perf/bart_microbenchmark.R @@ -0,0 +1,29 @@ +library(microbenchmark) +library(stochtree) + +# Generate data needed to train BART model +n <- 1000 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Run microbenchmark +microbenchmark( + bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_mcmc = 1000) +) From 7e2a110533db041446d5cff0fefd30d8dbec60c4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 27 Aug 2024 02:22:57 -0500 Subject: [PATCH 13/41] Updated python library C++ code --- src/py_stochtree.cpp | 68 +++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 207d1ce0..b576cb2d 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -461,18 +461,18 @@ class ForestSamplerCpp { } // Convert leaf model type to enum - ForestLeafModel leaf_model_enum; - if (leaf_model_int == 0) leaf_model_enum = ForestLeafModel::kConstant; - else if (leaf_model_int == 1) leaf_model_enum = ForestLeafModel::kUnivariateRegression; - else if (leaf_model_int == 2) leaf_model_enum = ForestLeafModel::kMultivariateRegression; + StochTree::ModelType model_type; + if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian; + else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; + else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; // Unpack leaf model parameters double leaf_scale; Eigen::MatrixXd leaf_scale_matrix; - if ((leaf_model_enum == ForestLeafModel::kConstant) || - (leaf_model_enum == ForestLeafModel::kUnivariateRegression)) { + if ((model_type == StochTree::ModelType::kConstantLeafGaussian) || + (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian)) { leaf_scale = leaf_model_scale_input.at(0,0); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { int num_row = leaf_model_scale_input.shape(0); int num_col = leaf_model_scale_input.shape(1); leaf_scale_matrix.resize(num_row, num_col); @@ -482,60 +482,44 @@ class ForestSamplerCpp { } } } - + // Convert variable weights to std::vector std::vector var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights.at(i); } + // Prepare the samplers + StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, 
leaf_scale_matrix); + // Run one iteration of the sampler StochTree::ForestContainer* forest_sample_ptr = forest_samples.GetContainer(); StochTree::ForestDataset* forest_data_ptr = dataset.GetDataset(); StochTree::ColumnVector* residual_data_ptr = residual.GetData(); + int num_basis = forest_data_ptr->NumBasis(); std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { - InternalSampleGFR(*forest_sample_ptr, *forest_data_ptr, *residual_data_ptr, *rng_ptr, feature_types_, var_weights_vector, - leaf_model_enum, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size, pre_initialized); + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + StochTree::GFRSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, feature_types_, cutpoint_grid_size, pre_initialized, num_basis); + } } else { - InternalSampleMCMC(*forest_sample_ptr, *forest_data_ptr, *residual_data_ptr, *rng_ptr, feature_types_, var_weights_vector, - leaf_model_enum, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size, pre_initialized); + if (model_type == StochTree::ModelType::kConstantLeafGaussian) { + StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized); + } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { + StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized); + } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { + StochTree::MCMCSampleOneIter(*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, global_variance, pre_initialized, num_basis); + } } } private: std::unique_ptr tracker_; std::unique_ptr split_prior_; - - void InternalSampleGFR(StochTree::ForestContainer& forest_samples, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, - std::vector& feature_types, std::vector& var_weights_vector, ForestLeafModel leaf_model_enum, - Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { - if (leaf_model_enum == ForestLeafModel::kConstant) { - StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, 
cutpoint_grid_size, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { - StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { - StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::GFRSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, feature_types, cutpoint_grid_size, pre_initialized); - } - } - - void InternalSampleMCMC(StochTree::ForestContainer& forest_samples, StochTree::ForestDataset& dataset, StochTree::ColumnVector& residual, std::mt19937& rng, - std::vector& feature_types, std::vector& var_weights_vector, ForestLeafModel leaf_model_enum, - Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size, bool pre_initialized) { - if (leaf_model_enum == ForestLeafModel::kConstant) { - StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { - StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { - StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - StochTree::MCMCSampleOneIter(*(tracker_.get()), forest_samples, leaf_model, dataset, residual, *(split_prior_.get()), rng, var_weights_vector, global_variance, pre_initialized); - } - } }; class GlobalVarianceModelCpp { From 64c19e8b0f4b012ee8ef02075fa0c23d4a603d37 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 27 Aug 2024 02:27:42 -0500 Subject: [PATCH 14/41] Added include --- include/stochtree/leaf_model.h | 1 + include/stochtree/tree_sampler.h | 1 + 2 files changed, 2 insertions(+) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index f006566f..57ac1b69 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -17,6 +17,7 @@ #include #include +#include namespace StochTree { diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index b75b6b00..4aadb373 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace StochTree { From 06efbb5560b04f23cf4eadc3c03675de1eb5d9e1 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 27 Aug 2024 02:36:42 -0500 Subject: [PATCH 15/41] Updated unit tests --- test/cpp/test_model.cpp | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/test/cpp/test_model.cpp 
b/test/cpp/test_model.cpp index 23e2e929..0e729bef 100644 --- a/test/cpp/test_model.cpp +++ b/test/cpp/test_model.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -49,9 +50,10 @@ TEST(LeafConstantModel, FullEnumeration) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau); // Evaluate all possible cutpoints - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, - feature_types); + StochTree::EvaluateAllPossibleSplits( + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1); @@ -107,9 +109,10 @@ TEST(LeafConstantModel, CutpointThinning) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau); // Evaluate all possible cutpoints - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, - feature_types); + StochTree::EvaluateAllPossibleSplits( + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1); @@ -165,9 +168,10 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau); // Evaluate all possible cutpoints - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, - feature_types); + StochTree::EvaluateAllPossibleSplits( + dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered ASSERT_EQ(log_cutpoint_evaluations.size(), (n - 2*min_samples_leaf + 1)*p + 1); @@ -224,9 +228,11 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(tau); // Evaluate all possible cutpoints - leaf_model.EvaluateAllPossibleSplits(dataset, tracker, residual, tree_prior, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, - feature_types); + StochTree::EvaluateAllPossibleSplits( 
+ dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, + cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + ); + // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered ASSERT_EQ(log_cutpoint_evaluations.size(), (cutpoint_grid_size - 1)*p + 1); From ad03adb3ab3dc2dc06f1e2f2a97dc9378d2de7cc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 31 Aug 2024 01:31:43 -0500 Subject: [PATCH 16/41] Initial setup for building and publishing C++ documentation --- .gitignore | 4 +- cpp_docs/Doxyfile | 2862 +++++++++++++++++++++++++++++++++++++ cpp_docs/Makefile | 20 + cpp_docs/README.md | 21 + cpp_docs/conf.py | 40 + cpp_docs/index.rst | 8 + cpp_docs/make.bat | 35 + cpp_docs/requirements.txt | 39 + python_docs/README.md | 2 +- 9 files changed, 3029 insertions(+), 2 deletions(-) create mode 100644 cpp_docs/Doxyfile create mode 100644 cpp_docs/Makefile create mode 100644 cpp_docs/README.md create mode 100644 cpp_docs/conf.py create mode 100644 cpp_docs/index.rst create mode 100644 cpp_docs/make.bat create mode 100644 cpp_docs/requirements.txt diff --git a/.gitignore b/.gitignore index 14f9b134..8d64c5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ ## System and data files *.pdf *.csv -*.txt *.DS_Store lib/ build/ @@ -9,6 +8,9 @@ build/ xcode/ *.json .vs/ +cpp_docs/doxyoutput/html +cpp_docs/doxyoutput/xml +cpp_docs/doxyoutput/latex ## R gitignore diff --git a/cpp_docs/Doxyfile b/cpp_docs/Doxyfile new file mode 100644 index 00000000..2b178775 --- /dev/null +++ b/cpp_docs/Doxyfile @@ -0,0 +1,2862 @@ +# Doxyfile 1.12.0 + +# This file describes the settings to be used by the documentation system +# Doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use Doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use Doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. 
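For reference, the tags in this Doxyfile operate on Doxygen comment blocks in the C++ headers. A minimal sketch of the comment style the generated site will extract (the function below is a hypothetical stand-in, not part of the stochtree API):

/*!
 * \brief Compute the log marginal likelihood of a proposed split.
 *
 * \param left_n Number of observations routed to the left child
 * \param right_n Number of observations routed to the right child
 * \return Log marginal likelihood of the proposed split
 */
double ExampleSplitLogMarginalLikelihood(int left_n, int right_n);

Tags such as BRIEF_MEMBER_DESC and REPEAT_BRIEF below control whether and where the \brief line of such a block is displayed.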
+
+PROJECT_NAME = "StochTree"
+
+# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
+# could be handy for archiving the generated documentation or if some version
+# control system is used.
+
+PROJECT_NUMBER = 0.0.1
+
+# Using the PROJECT_BRIEF tag one can provide an optional one line description
+# for a project that appears at the top of each page and should give the viewer
+# a quick idea about the purpose of the project. Keep the description short.
+
+PROJECT_BRIEF =
+
+# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
+# in the documentation. The maximum height of the logo should not exceed 55
+# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
+# the logo to the output directory.
+
+PROJECT_LOGO =
+
+# With the PROJECT_ICON tag one can specify an icon that is included in the tabs
+# when the HTML document is shown. Doxygen will copy the icon to the output
+# directory.
+
+PROJECT_ICON =
+
+# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
+# into which the generated documentation will be written. If a relative path is
+# entered, it will be relative to the location where Doxygen was started. If
+# left blank the current directory will be used.
+
+OUTPUT_DIRECTORY = doxyoutput
+
+# If the CREATE_SUBDIRS tag is set to YES then Doxygen will create up to 4096
+# sub-directories (in 2 levels) under the output directory of each output format
+# and will distribute the generated files over these directories. Enabling this
+# option can be useful when feeding Doxygen a huge amount of source files, where
+# putting all generated files in the same directory would otherwise cause
+# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to
+# control the number of sub-directories.
+# The default value is: NO.
+
+CREATE_SUBDIRS = NO
+
+# Controls the number of sub-directories that will be created when
+# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every
+# level increment doubles the number of directories, resulting in 4096
+# directories at level 8 which is the default and also the maximum value. The
+# sub-directories are organized in 2 levels, the first level always has a fixed
+# number of 16 directories.
+# Minimum value: 0, maximum value: 8, default value: 8.
+# This tag requires that the tag CREATE_SUBDIRS is set to YES.
+
+CREATE_SUBDIRS_LEVEL = 8
+
+# If the ALLOW_UNICODE_NAMES tag is set to YES, Doxygen will allow non-ASCII
+# characters to appear in the names of generated files. If set to NO, non-ASCII
+# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode
+# U+3044.
+# The default value is: NO.
+
+ALLOW_UNICODE_NAMES = NO
+
+# The OUTPUT_LANGUAGE tag is used to specify the language in which all
+# documentation generated by Doxygen is written. Doxygen will use this
+# information to generate all constant output in the proper language.
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian,
+# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English
+# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek,
+# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with
+# English messages), Korean, Korean-en (Korean with English messages), Latvian,
+# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese,
+# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish,
+# Swedish, Turkish, Ukrainian and Vietnamese.
+# The default value is: English.
+
+OUTPUT_LANGUAGE = English
+
+# If the BRIEF_MEMBER_DESC tag is set to YES, Doxygen will include brief member
+# descriptions after the members that are listed in the file and class
+# documentation (similar to Javadoc). Set to NO to disable this.
+# The default value is: YES.
+
+BRIEF_MEMBER_DESC = YES
+
+# If the REPEAT_BRIEF tag is set to YES, Doxygen will prepend the brief
+# description of a member or function before the detailed description.
+#
+# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
+# brief descriptions will be completely suppressed.
+# The default value is: YES.
+
+REPEAT_BRIEF = YES
+
+# This tag implements a quasi-intelligent brief description abbreviator that is
+# used to form the text in various listings. Each string in this list, if found
+# as the leading text of the brief description, will be stripped from the text
+# and the result, after processing the whole list, is used as the annotated
+# text. Otherwise, the brief description is used as-is. If left blank, the
+# following values are used ($name is automatically replaced with the name of
+# the entity): The $name class, The $name widget, The $name file, is, provides,
+# specifies, contains, represents, a, an and the.
+
+ABBREVIATE_BRIEF = "The $name class" \
+                   "The $name widget" \
+                   "The $name file" \
+                   is \
+                   provides \
+                   specifies \
+                   contains \
+                   represents \
+                   a \
+                   an \
+                   the
+
+# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then
+# Doxygen will generate a detailed section even if there is only a brief
+# description.
+# The default value is: NO.
+
+ALWAYS_DETAILED_SEC = NO
+
+# If the INLINE_INHERITED_MEMB tag is set to YES, Doxygen will show all
+# inherited members of a class in the documentation of that class as if those
+# members were ordinary class members. Constructors, destructors and assignment
+# operators of the base classes will not be shown.
+# The default value is: NO.
+
+INLINE_INHERITED_MEMB = NO
+
+# If the FULL_PATH_NAMES tag is set to YES, Doxygen will prepend the full path
+# before file names in the file list and in the header files. If set to NO the
+# shortest path that makes the file name unique will be used.
+# The default value is: YES.
+
+FULL_PATH_NAMES = YES
+
+# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path.
+# Stripping is only done if one of the specified strings matches the left-hand
+# part of the path. The tag can be used to show relative paths in the file list.
+# If left blank the directory from which Doxygen is run is used as the path to
+# strip.
+#
+# Note that you can specify absolute paths here, but also relative paths, which
+# will be relative from the directory where Doxygen is started.
+# This tag requires that the tag FULL_PATH_NAMES is set to YES.
+
+STRIP_FROM_PATH =
+
+# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
+# path mentioned in the documentation of a class, which tells the reader which
+# header file to include in order to use a class. If left blank only the name of
+# the header file containing the class definition is used. Otherwise one should
+# specify the list of include paths that are normally passed to the compiler
+# using the -I flag.
+
+STRIP_FROM_INC_PATH =
+
+# If the SHORT_NAMES tag is set to YES, Doxygen will generate much shorter (but
+# less readable) file names. This can be useful if your file system doesn't
+# support long names like on DOS, Mac, or CD-ROM.
+# The default value is: NO.
+
+SHORT_NAMES = NO
+
+# If the JAVADOC_AUTOBRIEF tag is set to YES then Doxygen will interpret the
+# first line (until the first dot) of a Javadoc-style comment as the brief
+# description. If set to NO, the Javadoc-style will behave just like regular Qt-
+# style comments (thus requiring an explicit @brief command for a brief
+# description).
+# The default value is: NO.
+
+JAVADOC_AUTOBRIEF = NO
+
+# If the JAVADOC_BANNER tag is set to YES then Doxygen will interpret a line
+# such as
+# /***************
+# as being the beginning of a Javadoc-style comment "banner". If set to NO, the
+# Javadoc-style will behave just like regular comments and it will not be
+# interpreted by Doxygen.
+# The default value is: NO.
+
+JAVADOC_BANNER = NO
+
+# If the QT_AUTOBRIEF tag is set to YES then Doxygen will interpret the first
+# line (until the first dot) of a Qt-style comment as the brief description. If
+# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
+# requiring an explicit \brief command for a brief description).
+# The default value is: NO.
+
+QT_AUTOBRIEF = NO
+
+# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make Doxygen treat a
+# multi-line C++ special comment block (i.e. a block of //! or /// comments) as
+# a brief description. This used to be the default behavior. The new default is
+# to treat a multi-line C++ comment block as a detailed description. Set this
+# tag to YES if you prefer the old behavior instead.
+#
+# Note that setting this tag to YES also means that Rational Rose comments are
+# not recognized any more.
+# The default value is: NO.
+
+MULTILINE_CPP_IS_BRIEF = NO
+
+# By default Python docstrings are displayed as preformatted text and Doxygen's
+# special commands cannot be used. By setting PYTHON_DOCSTRING to NO,
+# Doxygen's special commands can be used and the contents of the docstring
+# documentation blocks are shown as Doxygen documentation.
+# The default value is: YES.
+
+PYTHON_DOCSTRING = YES
+
+# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
+# documentation from any documented member that it re-implements.
+# The default value is: YES.
+
+INHERIT_DOCS = YES
+
+# If the SEPARATE_MEMBER_PAGES tag is set to YES then Doxygen will produce a new
+# page for each member. If set to NO, the documentation of a member will be part
+# of the file/class/namespace that contains it.
+# The default value is: NO.
+
+SEPARATE_MEMBER_PAGES = NO
+
+# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen
+# uses this value to replace tabs by spaces in code fragments.
+# Minimum value: 1, maximum value: 16, default value: 4.
+
+TAB_SIZE = 4
+
+# This tag can be used to specify a number of aliases that act as commands in
+# the documentation. An alias has the form:
+# name=value
+# For example, adding
+# "sideeffect=@par Side Effects:^^"
+# will allow you to put the command \sideeffect (or @sideeffect) in the
+# documentation, which will result in a user-defined paragraph with heading
+# "Side Effects:". Note that you cannot put \n's in the value part of an alias
+# to insert newlines (in the resulting output). You can put ^^ in the value part
+# of an alias to insert a newline as if a physical newline was in the original
+# file. When you need a literal { or } or , in the value part of an alias you
+# have to escape them by means of a backslash (\); this can lead to conflicts
+# with the commands \{ and \}, for which it is advised to use the versions @{
+# and @} or a double escape (\\{ and \\}).
+
+ALIASES =
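To make the comment-style options above concrete, here is a short C++ documentation comment as Doxygen would treat it under this configuration. The function is purely illustrative, and the \sideeffect alias referenced is the hypothetical example from the ALIASES description above (this project leaves ALIASES empty):

/*!
 * \brief Compute the number of candidate cutpoints for a feature.
 *
 * Because JAVADOC_AUTOBRIEF = NO in this Doxyfile, this sentence is the brief
 * description only because it is explicitly tagged with \brief; it would not
 * be promoted automatically from the first line of the comment.
 *
 * If ALIASES contained "sideeffect=@par Side Effects:^^", writing \sideeffect
 * here would render a user-defined "Side Effects:" paragraph in the output.
 */
int NumCandidateCutpoints(int n, int min_samples_leaf);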
+
+# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
+# only. Doxygen will then generate output that is more tailored for C. For
+# instance, some of the names that are used will be different. The list of all
+# members will be omitted, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_FOR_C = NO
+
+# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or
+# Python sources only. Doxygen will then generate output that is more tailored
+# for that language. For instance, namespaces will be presented as packages,
+# qualified scopes will look different, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_JAVA = NO
+
+# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran
+# sources. Doxygen will then generate output that is tailored for Fortran.
+# The default value is: NO.
+
+OPTIMIZE_FOR_FORTRAN = NO
+
+# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL
+# sources. Doxygen will then generate output that is tailored for VHDL.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_VHDL = NO
+
+# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice
+# sources only. Doxygen will then generate output that is more tailored for that
+# language. For instance, namespaces will be presented as modules, types will be
+# separated into more groups, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_SLICE = NO
+
+# Doxygen selects the parser to use depending on the extension of the files it
+# parses. With this tag you can assign which parser to use for a given
+# extension. Doxygen has a built-in mapping, but you can override or extend it
+# using this tag. The format is ext=language, where ext is a file extension, and
+# language is one of the parsers supported by Doxygen: IDL, Java, JavaScript,
+# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice,
+# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran:
+# FortranFree, unknown formatted Fortran: Fortran. In the latter case the parser
+# tries to guess whether the code is fixed or free formatted code, this is the
+# default for Fortran type files). For instance, to make Doxygen treat .inc files
+# as Fortran files (default is PHP), and .f files as C (default is Fortran),
+# use: inc=Fortran f=C.
+#
+# Note: For files without extension you can use no_extension as a placeholder.
+#
+# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
+# the files are not read by Doxygen. When specifying no_extension you should add
+# * to the FILE_PATTERNS.
+#
+# Note: see also the list of default file extension mappings.
+
+EXTENSION_MAPPING =
+
+# If the MARKDOWN_SUPPORT tag is enabled then Doxygen pre-processes all comments
+# according to the Markdown format, which allows for more readable
+# documentation. See https://daringfireball.net/projects/markdown/ for details.
+# The output of markdown processing is further processed by Doxygen, so you can
+# mix Doxygen, HTML, and XML commands with Markdown formatting. Disable only in
+# case of backward compatibility issues.
+# The default value is: YES.
+
+MARKDOWN_SUPPORT = YES
+
+# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
+# to that level are automatically included in the table of contents, even if
+# they do not have an id attribute.
+# Note: This feature currently applies only to Markdown headings.
+# Minimum value: 0, maximum value: 99, default value: 6.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+TOC_INCLUDE_HEADINGS = 6
+
+# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to
+# generate identifiers for the Markdown headings. Note: Every identifier is
+# unique.
+# Possible values are: DOXYGEN use a fixed 'autotoc_md' string followed by a
+# sequence number starting at 0 and GITHUB use the lower case version of title
+# with any whitespace replaced by '-' and punctuation characters removed.
+# The default value is: DOXYGEN.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+MARKDOWN_ID_STYLE = DOXYGEN
+
+# When enabled Doxygen tries to link words that correspond to documented
+# classes, or namespaces to their corresponding documentation. Such a link can
+# be prevented in individual cases by putting a % sign in front of the word or
+# globally by setting AUTOLINK_SUPPORT to NO.
+# The default value is: YES.
+
+AUTOLINK_SUPPORT = YES
+
+# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
+# to include (a tag file for) the STL sources as input, then you should set this
+# tag to YES in order to let Doxygen match function declarations and
+# definitions whose arguments contain STL classes (e.g. func(std::string);
+# versus func(std::string) {}). This also makes the inheritance and
+# collaboration diagrams that involve STL classes more complete and accurate.
+# The default value is: NO.
+
+BUILTIN_STL_SUPPORT = NO
+
+# If you use Microsoft's C++/CLI language, you should set this option to YES to
+# enable parsing support.
+# The default value is: NO.
+
+CPP_CLI_SUPPORT = NO
+
+# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
+# https://www.riverbankcomputing.com/software) sources only. Doxygen will parse
+# them like normal C++ but will assume all classes use public instead of private
+# inheritance when no explicit protection keyword is present.
+# The default value is: NO.
+
+SIP_SUPPORT = NO
+
+# For Microsoft's IDL there are propget and propput attributes to indicate
+# getter and setter methods for a property. Setting this option to YES will make
+# Doxygen replace the get and set methods by a property in the documentation.
+# This will only work if the methods are indeed getting or setting a simple
+# type. If this is not the case, or you want to show the methods anyway, you
+# should set this option to NO.
+# The default value is: YES.
+
+IDL_PROPERTY_SUPPORT = YES
+
+# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
+# tag is set to YES then Doxygen will reuse the documentation of the first
+# member in the group (if any) for the other members of the group. By default
+# all members of a group must be documented explicitly.
+# The default value is: NO.
+
+DISTRIBUTE_GROUP_DOC = NO
+
+# If one adds a struct or class to a group and this option is enabled, then also
+# any nested class or struct is added to the same group. By default this option
+# is disabled and one has to add nested compounds explicitly via \ingroup.
+# The default value is: NO.
+
+GROUP_NESTED_COMPOUNDS = NO
+
+# Set the SUBGROUPING tag to YES to allow class member groups of the same type
+# (for instance a group of public functions) to be put as a subgroup of that
+# type (e.g. under the Public Functions section). Set it to NO to prevent
+# subgrouping. Alternatively, this can be done per class using the
+# \nosubgrouping command.
+# The default value is: YES.
+
+SUBGROUPING = YES
+
+# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
+# are shown inside the group in which they are included (e.g. using \ingroup)
+# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
+# and RTF).
+#
+# Note that this feature does not work in combination with
+# SEPARATE_MEMBER_PAGES.
+# The default value is: NO.
+
+INLINE_GROUPED_CLASSES = NO
+
+# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
+# with only public data fields or simple typedef fields will be shown inline in
+# the documentation of the scope in which they are defined (i.e. file,
+# namespace, or group documentation), provided this scope is documented. If set
+# to NO, structs, classes, and unions are shown on a separate page (for HTML and
+# Man pages) or section (for LaTeX and RTF).
+# The default value is: NO.
+
+INLINE_SIMPLE_STRUCTS = NO
+
+# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or
+# enum is documented as struct, union, or enum with the name of the typedef. So
+# typedef struct TypeS {} TypeT, will appear in the documentation as a struct
+# with name TypeT. When disabled the typedef will appear as a member of a file,
+# namespace, or class. And the struct will be named TypeS. This can typically be
+# useful for C code in case the coding convention dictates that all compound
+# types are typedef'ed and only the typedef is referenced, never the tag name.
+# The default value is: NO.
+
+TYPEDEF_HIDES_STRUCT = NO
+
+# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This
+# cache is used to resolve symbols given their name and scope. Since this can be
+# an expensive process and often the same symbol appears multiple times in the
+# code, Doxygen keeps a cache of pre-resolved symbols. If the cache is too small
+# Doxygen will become slower. If the cache is too large, memory is wasted. The
+# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range
+# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536
+# symbols. At the end of a run Doxygen will report the cache usage and suggest
+# the optimal cache size from a speed point of view.
+# Minimum value: 0, maximum value: 9, default value: 0.
+
+LOOKUP_CACHE_SIZE = 0
+
+# The NUM_PROC_THREADS specifies the number of threads Doxygen is allowed to use
+# during processing. When set to 0 Doxygen will base this on the number of
+# cores available in the system. You can set it explicitly to a value larger
+# than 0 to get more control over the balance between CPU load and processing
+# speed. At this moment only the input processing can be done using multiple
+# threads.
Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. +# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = NO + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, Doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, Doxygen will hide all +# undocumented members inside documented classes or files. 
If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, Doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, Doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, Doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# With the correct setting of option CASE_SENSE_NAMES Doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and macOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. + +CASE_SENSE_NAMES = SYSTEM + +# If the HIDE_SCOPE_NAMES tag is set to NO then Doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then Doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + +# If the SHOW_INCLUDE_FILES tag is set to YES then Doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. 
+# The default value is: NO.
+
+SHOW_GROUPED_MEMB_INC = NO
+
+# If the FORCE_LOCAL_INCLUDES tag is set to YES then Doxygen will list include
+# files with double quotes in the documentation rather than with sharp brackets.
+# The default value is: NO.
+
+FORCE_LOCAL_INCLUDES = NO
+
+# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
+# documentation for inline members.
+# The default value is: YES.
+
+INLINE_INFO = YES
+
+# If the SORT_MEMBER_DOCS tag is set to YES then Doxygen will sort the
+# (detailed) documentation of file and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order.
+# The default value is: YES.
+
+SORT_MEMBER_DOCS = YES
+
+# If the SORT_BRIEF_DOCS tag is set to YES then Doxygen will sort the brief
+# descriptions of file, namespace and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order. Note that
+# this will also influence the order of the classes in the class list.
+# The default value is: NO.
+
+SORT_BRIEF_DOCS = NO
+
+# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then Doxygen will sort the
+# (brief and detailed) documentation of class members so that constructors and
+# destructors are listed first. If set to NO the constructors will appear in the
+# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS.
+# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief
+# member documentation.
+# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting
+# detailed member documentation.
+# The default value is: NO.
+
+SORT_MEMBERS_CTORS_1ST = NO
+
+# If the SORT_GROUP_NAMES tag is set to YES then Doxygen will sort the hierarchy
+# of group names into alphabetical order. If set to NO the group names will
+# appear in their defined order.
+# The default value is: NO.
+
+SORT_GROUP_NAMES = NO
+
+# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by
+# fully-qualified names, including namespaces. If set to NO, the class list will
+# be sorted only by class name, not including the namespace part.
+# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES.
+# Note: This option applies only to the class list, not to the alphabetical
+# list.
+# The default value is: NO.
+
+SORT_BY_SCOPE_NAME = NO
+
+# If the STRICT_PROTO_MATCHING option is enabled and Doxygen fails to do proper
+# type resolution of all parameters of a function it will reject a match between
+# the prototype and the implementation of a member function even if there is
+# only one candidate or it is obvious which candidate to choose by doing a
+# simple string match. By disabling STRICT_PROTO_MATCHING Doxygen will still
+# accept a match between prototype and implementation in such cases.
+# The default value is: NO.
+
+STRICT_PROTO_MATCHING = NO
+
+# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo
+# list. This list is created by putting \todo commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TODOLIST = YES
+
+# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
+# list. This list is created by putting \test commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TESTLIST = YES
+
+# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
+# list. This list is created by putting \bug commands in the documentation.
+# The default value is: YES.
+
+GENERATE_BUGLIST = YES
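The todo/test/bug lists enabled above are populated from special commands placed in documentation comments. A brief illustrative C++ snippet (the function and its annotations are hypothetical, not from StochTree):

/*!
 * \brief Sample a split rule for the given tree node.
 * \todo Support categorical cutpoints in the thinned grid.   (collected into the Todo list)
 * \test Covered by LeafConstantModel.CutpointThinning in test_model.cpp.
 * \bug Ties in feature values are currently broken arbitrarily.
 */
void SampleSplitRule(int node_id);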
+
+# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
+# the deprecated list. This list is created by putting \deprecated commands in
+# the documentation.
+# The default value is: YES.
+
+GENERATE_DEPRECATEDLIST= YES
+
+# The ENABLED_SECTIONS tag can be used to enable conditional documentation
+# sections, marked by \if <section_label> ... \endif and \cond <section_label>
+# ... \endcond blocks.
+
+ENABLED_SECTIONS =
+
+# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
+# initial value of a variable or macro / define can have for it to appear in the
+# documentation. If the initializer consists of more lines than specified here
+# it will be hidden. Use a value of 0 to hide initializers completely. The
+# appearance of the value of individual variables and macros / defines can be
+# controlled using \showinitializer or \hideinitializer command in the
+# documentation regardless of this setting.
+# Minimum value: 0, maximum value: 10000, default value: 30.
+
+MAX_INITIALIZER_LINES = 30
+
+# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
+# the bottom of the documentation of classes and structs. If set to YES, the
+# list will mention the files that were used to generate the documentation.
+# The default value is: YES.
+
+SHOW_USED_FILES = YES
+
+# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
+# will remove the Files entry from the Quick Index and from the Folder Tree View
+# (if specified).
+# The default value is: YES.
+
+SHOW_FILES = YES
+
+# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
+# page. This will remove the Namespaces entry from the Quick Index and from the
+# Folder Tree View (if specified).
+# The default value is: YES.
+
+SHOW_NAMESPACES = YES
+
+# The FILE_VERSION_FILTER tag can be used to specify a program or script that
+# Doxygen should invoke to get the current version for each file (typically from
+# the version control system). Doxygen will invoke the program by executing (via
+# popen()) the command <command> <input-file>, where <command> is the value of
+# the FILE_VERSION_FILTER tag, and <input-file> is the name of an input file
+# provided by Doxygen. Whatever the program writes to standard output is used as
+# the file version. For an example see the documentation.
+
+FILE_VERSION_FILTER =
+
+# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
+# by Doxygen. The layout file controls the global structure of the generated
+# output files in an output format independent way. To create the layout file
+# that represents Doxygen's defaults, run Doxygen with the -l option. You can
+# optionally specify a file name after the option, if omitted DoxygenLayout.xml
+# will be used as the name of the layout file. See also section "Changing the
+# layout of pages" for information.
+#
+# Note that if you run Doxygen from a directory containing a file called
+# DoxygenLayout.xml, Doxygen will parse it automatically even if the LAYOUT_FILE
+# tag is left empty.
+
+LAYOUT_FILE =
+
+# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
+# the reference definitions. This must be a list of .bib files. The .bib
+# extension is automatically appended if omitted. This requires the bibtex tool
+# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info.
+# For LaTeX the style of the bibliography can be controlled using
+# LATEX_BIB_STYLE.
To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +# The EXTERNAL_TOOL_PATH tag can be used to extend the search path (PATH +# environment variable) so that external tools such as latex and gs can be +# found. +# Note: Directories specified with EXTERNAL_TOOL_PATH are added in front of the +# path already specified by the PATH variable, and are added in the order +# specified. +# Note: This option is particularly useful for macOS version 14 (Sonoma) and +# higher, when running Doxygen from Doxywizard, because in this case any user- +# defined changes to the PATH are ignored. A typical example on macOS is to set +# EXTERNAL_TOOL_PATH = /Library/TeX/texbin /usr/local/bin +# together with the standard path, the full search path used by doxygen when +# launching external tools will then become +# PATH=/Library/TeX/texbin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin + +EXTERNAL_TOOL_PATH = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by Doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by Doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then Doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, Doxygen will generate warnings for +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# If WARN_IF_INCOMPLETE_DOC is set to YES, Doxygen will warn about incomplete +# function parameter documentation. If set to NO, Doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, Doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, Doxygen will warn about +# undocumented enumeration values. If set to NO, Doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then Doxygen will immediately stop when +# a warning is encountered. 
If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS
+# then Doxygen will continue running as if the WARN_AS_ERROR tag is set to NO, but
+# at the end of the Doxygen process Doxygen will return with a non-zero status.
+# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then Doxygen behaves
+# like FAIL_ON_WARNINGS, but in case no WARN_LOGFILE is defined Doxygen will not
+# write the warning messages in between other messages but write them at the end
+# of a run; in case a WARN_LOGFILE is defined, the warning messages will, besides
+# being written to the defined file, also be shown at the end of a run, unless
+# the WARN_LOGFILE is defined as - (i.e. standard output (stdout)), in which case
+# the behavior will remain as with the setting FAIL_ON_WARNINGS.
+# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT.
+# The default value is: NO.
+
+WARN_AS_ERROR = NO
+
+# The WARN_FORMAT tag determines the format of the warning messages that Doxygen
+# can produce. The string should contain the $file, $line, and $text tags, which
+# will be replaced by the file and line number from which the warning originated
+# and the warning text. Optionally the format may contain $version, which will
+# be replaced by the version of the file (if it could be obtained via
+# FILE_VERSION_FILTER).
+# See also: WARN_LINE_FORMAT
+# The default value is: $file:$line: $text.
+
+WARN_FORMAT = "$file:$line: $text"
+
+# In the $text part of the WARN_FORMAT command it is possible that a reference
+# to a more specific place is given. To make it easier to jump to this place
+# (outside of Doxygen) the user can define a custom "cut" / "paste" string.
+# Example:
+# WARN_LINE_FORMAT = "'vi $file +$line'"
+# See also: WARN_FORMAT
+# The default value is: at line $line of file $file.
+
+WARN_LINE_FORMAT = "at line $line of file $file"
+
+# The WARN_LOGFILE tag can be used to specify a file to which warning and error
+# messages should be written. If left blank the output is written to standard
+# error (stderr). In case the file specified cannot be opened for writing the
+# warning and error messages are written to standard error. When - is specified
+# as the file, the warning and error messages are written to standard output
+# (stdout).
+
+WARN_LOGFILE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the input files
+#---------------------------------------------------------------------------
+
+# The INPUT tag is used to specify the files and/or directories that contain
+# documented source files. You may enter file names like myfile.cpp or
+# directories like /usr/src/myproject. Separate the files or directories with
+# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING.
+# Note: If this tag is empty the current directory is searched.
+
+INPUT = ../
+
+# This tag can be used to specify the character encoding of the source files
+# that Doxygen parses. Internally Doxygen uses the UTF-8 encoding. Doxygen uses
+# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
+# documentation (see:
+# https://www.gnu.org/software/libiconv/) for the list of possible encodings.
+# See also: INPUT_FILE_ENCODING
+# The default value is: UTF-8.
+
+INPUT_ENCODING = UTF-8
+
+# This tag can be used to specify the character encoding of the source files
+# that Doxygen parses. The INPUT_FILE_ENCODING tag can be used to specify
+# character encoding on a per file pattern basis.
Doxygen will compare the file
+# name with each pattern and apply the encoding (instead of the default
+# INPUT_ENCODING) if there is a match. The character encodings are a list of the
+# form: pattern=encoding (like *.php=ISO-8859-1).
+# See also: INPUT_ENCODING for further information on supported encodings.
+
+INPUT_FILE_ENCODING =
+
+# If the value of the INPUT tag contains directories, you can use the
+# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
+# *.h) to filter out the source-files in the directories.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# read by Doxygen.
+#
+# Note the list of default checked file patterns might differ from the list of
+# default file extension mappings.
+#
+# If left blank the following patterns are tested: *.c, *.cc, *.cxx, *.cxxm,
+# *.cpp, *.cppm, *.ccm, *.c++, *.c++m, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl,
+# *.idl, *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp, *.h++, *.ixx, *.l, *.cs, *.d,
+# *.php, *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to
+# be provided as Doxygen C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08,
+# *.f18, *.f, *.for, *.vhd, *.vhdl, *.ucf, *.qsf and *.ice.
+
+FILE_PATTERNS = *.c \
+                *.cc \
+                *.cxx \
+                *.cxxm \
+                *.cpp \
+                *.cppm \
+                *.ccm \
+                *.c++ \
+                *.c++m \
+                *.h \
+                *.hh \
+                *.hxx \
+                *.hpp \
+                *.h++
+
+# The RECURSIVE tag can be used to specify whether or not subdirectories should
+# be searched for input files as well.
+# The default value is: NO.
+
+RECURSIVE = YES
+
+# The EXCLUDE tag can be used to specify files and/or directories that should be
+# excluded from the INPUT source files. This way you can easily exclude a
+# subdirectory from a directory tree whose root is specified with the INPUT tag.
+#
+# Note that relative paths are relative to the directory from which Doxygen is
+# run.
+
+EXCLUDE = ../src/cpp11.cpp \
+          ../src/py_stochtree.cpp \
+          ../src/R_data.cpp \
+          ../src/R_random_effects.cpp \
+          ../src/sampler.cpp \
+          ../src/serialization.cpp \
+          ../src/stochtree_types.h
+
+# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
+# directories that are symbolic links (a Unix file system feature) are excluded
+# from the input.
+# The default value is: NO.
+
+EXCLUDE_SYMLINKS = NO
+
+# If the value of the INPUT tag contains directories, you can use the
+# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
+# certain files from those directories.
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories for example use the pattern */test/*
+
+EXCLUDE_PATTERNS = */test/* \
+                   */tools/* \
+                   */vignettes/* \
+                   */R/* \
+                   */nlohmann/* \
+                   */debug/* \
+                   */demo/* \
+                   */deps/* \
+                   */venv/* \
+                   */xcode/*
+
+# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
+# (namespaces, classes, functions, etc.) that should be excluded from the
+# output. The symbol name can be a fully qualified name, a word, or if the
+# wildcard * is used, a substring. Examples: ANamespace, AClass,
+# ANamespace::AClass, ANamespace::*Test
+
+EXCLUDE_SYMBOLS =
+
+# The EXAMPLE_PATH tag can be used to specify one or more files or directories
+# that contain example code fragments that are included (see the \include
+# command).
+
+EXAMPLE_PATH =
+
+# If the value of the EXAMPLE_PATH tag contains directories, you can use the
+# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and
+# *.h) to filter out the source-files in the directories. If left blank all
+# files are included.
+
+EXAMPLE_PATTERNS = *
+
+# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
+# searched for input files to be used with the \include or \dontinclude commands
+# irrespective of the value of the RECURSIVE tag.
+# The default value is: NO.
+
+EXAMPLE_RECURSIVE = NO
+
+# The IMAGE_PATH tag can be used to specify one or more files or directories
+# that contain images that are to be included in the documentation (see the
+# \image command).
+
+IMAGE_PATH =
+
+# The INPUT_FILTER tag can be used to specify a program that Doxygen should
+# invoke to filter for each input file. Doxygen will invoke the filter program
+# by executing (via popen()) the command:
+#
+# <filter> <input-file>
+#
+# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
+# name of an input file. Doxygen will then use the output that the filter
+# program writes to standard output. If FILTER_PATTERNS is specified, this tag
+# will be ignored.
+#
+# Note that the filter must not add or remove lines; it is applied before the
+# code is scanned, but not when the output code is generated. If lines are added
+# or removed, the anchors will not be placed correctly.
+#
+# Note that Doxygen will use the data processed and written to standard output
+# for further processing, therefore nothing else, like debug statements or used
+# commands (so in case of a Windows batch file always use @echo OFF), should be
+# written to standard output.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by Doxygen.
+
+INPUT_FILTER =
+
+# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
+# basis. Doxygen will compare the file name with each pattern and apply the
+# filter if there is a match. The filters are a list of the form: pattern=filter
+# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
+# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
+# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by Doxygen.
+
+FILTER_PATTERNS =
+
+# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
+# INPUT_FILTER) will also be used to filter the input files that are used for
+# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
+# The default value is: NO.
+
+FILTER_SOURCE_FILES = NO
+
+# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
+# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
+# it is also possible to disable source filtering for a specific pattern using
+# *.ext= (so without naming a filter).
+# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
+
+FILTER_SOURCE_PATTERNS =
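The INPUT_FILTER contract described above (read one source file, write the transformed text to stdout, add or remove nothing else) is easy to get wrong, so here is a minimal sketch of a conforming filter program. It is purely illustrative and not part of this patch, which leaves INPUT_FILTER empty:

// Minimal INPUT_FILTER sketch: echo the file to stdout line by line, without
// adding or removing lines (which would misplace Doxygen's anchors).
#include <fstream>
#include <iostream>
#include <string>

int main(int argc, char** argv) {
  if (argc < 2) return 1;        // Doxygen passes the input file path as argv[1]
  std::ifstream input(argv[1]);
  if (!input) return 1;
  std::string line;
  while (std::getline(input, line)) {
    // A real filter would rewrite `line` here (same number of lines out as in);
    // only the filtered source may go to stdout, never debug output.
    std::cout << line << '\n';
  }
  return 0;
}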
+
+# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
+# is part of the input, its contents will be placed on the main page
+# (index.html). This can be useful if you have a project on, for instance,
+# GitHub and want to reuse the introduction page also for the Doxygen output.
+
+USE_MDFILE_AS_MAINPAGE =
+
+# The Fortran standard specifies that for fixed formatted Fortran code all
+# characters from position 72 are to be considered as comment. A common
+# extension is to allow longer lines before the automatic comment starts. The
+# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can
+# be processed before the automatic comment starts.
+# Minimum value: 7, maximum value: 10000, default value: 72.
+
+FORTRAN_COMMENT_AFTER = 72
+
+#---------------------------------------------------------------------------
+# Configuration options related to source browsing
+#---------------------------------------------------------------------------
+
+# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
+# generated. Documented entities will be cross-referenced with these sources.
+#
+# Note: To get rid of all source code in the generated output, make sure that
+# also VERBATIM_HEADERS is set to NO.
+# The default value is: NO.
+
+SOURCE_BROWSER = NO
+
+# Setting the INLINE_SOURCES tag to YES will include the body of functions,
+# multi-line macros, enums or list initialized variables directly into the
+# documentation.
+# The default value is: NO.
+
+INLINE_SOURCES = NO
+
+# Setting the STRIP_CODE_COMMENTS tag to YES will instruct Doxygen to hide any
+# special comment blocks from generated source code fragments. Normal C, C++ and
+# Fortran comments will always remain visible.
+# The default value is: YES.
+
+STRIP_CODE_COMMENTS = YES
+
+# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
+# entity all documented functions referencing it will be listed.
+# The default value is: NO.
+
+REFERENCED_BY_RELATION = NO
+
+# If the REFERENCES_RELATION tag is set to YES then for each documented function
+# all documented entities called/used by that function will be listed.
+# The default value is: NO.
+
+REFERENCES_RELATION = NO
+
+# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
+# to YES then the hyperlinks from functions in REFERENCES_RELATION and
+# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
+# link to the documentation.
+# The default value is: YES.
+
+REFERENCES_LINK_SOURCE = YES
+
+# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
+# source code will show a tooltip with additional information such as prototype,
+# brief description and links to the definition and documentation. Since this
+# will make the HTML file larger and loading of large files a bit slower, you
+# can opt to disable this feature.
+# The default value is: YES.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+SOURCE_TOOLTIPS = YES
+
+# If the USE_HTAGS tag is set to YES then the references to source code will
+# point to the HTML generated by the htags(1) tool instead of Doxygen built-in
+# source browser. The htags tool is part of GNU's global source tagging system
+# (see https://www.gnu.org/software/global/global.html). You will need version
+# 4.8.6 or higher.
+#
+# To use it do the following:
+# - Install the latest version of global
+# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file
+# - Make sure the INPUT points to the root of the source tree
+# - Run doxygen as normal
+#
+# Doxygen will invoke htags (and that will in turn invoke gtags), so these
+# tools must be available from the command line (i.e. in the search path).
+#
+# The result: instead of the source browser generated by Doxygen, the links to
+# source code will now point to the output of htags.
+# The default value is: NO.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS = NO
+
+# If the VERBATIM_HEADERS tag is set to YES then Doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+
+# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
+# compounds will be generated. Enable this if the project contains a lot of
+# classes, structs, unions or interfaces.
+# The default value is: YES.
+
+ALPHABETICAL_INDEX = YES
+
+# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes)
+# that should be ignored while generating the index headers. The IGNORE_PREFIX
+# tag works for classes, function and member names. The entity will be placed in
+# the alphabetical list under the first letter of the entity name that remains
+# after removing the prefix.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+IGNORE_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the HTML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_HTML tag is set to YES, Doxygen will generate HTML output
+# The default value is: YES.
+
+GENERATE_HTML = NO
+
+# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_OUTPUT = html
+
+# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
+# generated HTML page (for example: .htm, .php, .asp).
+# The default value is: .html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FILE_EXTENSION = .html
+
+# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
+# each generated HTML page. If the tag is left blank Doxygen will generate a
+# standard header.
+#
+# To get valid HTML, the header file must include any scripts and style sheets
+# that Doxygen needs, which depend on the configuration options used (e.g. the
+# setting GENERATE_TREEVIEW). It is highly recommended to start with a
+# default header using
+# doxygen -w html new_header.html new_footer.html new_stylesheet.css
+# YourConfigFile
+# and then modify the file new_header.html. See also section "Doxygen usage"
+# for information on how to generate the default header that Doxygen normally
+# uses.
+# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of Doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank Doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that Doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank Doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that Doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by Doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. 
+# Possible values are: LIGHT always generates light mode output, DARK always
+# generates dark mode output, AUTO_LIGHT automatically sets the mode according
+# to the user preference, uses light mode if no preference is set (the default),
+# AUTO_DARK automatically sets the mode according to the user preference, uses
+# dark mode if no preference is set, and TOGGLE allows a user to switch between
+# light and dark mode via a button.
+# The default value is: AUTO_LIGHT.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE = AUTO_LIGHT
+
+# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
+# will adjust the colors in the style sheet and background images according to
+# this color. Hue is specified as an angle on a color-wheel, see
+# https://en.wikipedia.org/wiki/Hue for more information. For instance the value
+# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
+# purple, and 360 is red again.
+# Minimum value: 0, maximum value: 359, default value: 220.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_HUE = 220
+
+# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
+# in the HTML output. For a value of 0 the output will use gray-scales only. A
+# value of 255 will produce the most vivid colors.
+# Minimum value: 0, maximum value: 255, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_SAT = 100
+
+# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
+# luminance component of the colors in the HTML output. Values below 100
+# gradually make the output lighter, whereas values above 100 make the output
+# darker. The value divided by 100 is the actual gamma applied, so 80 represents
+# a gamma of 0.8, the value 220 represents a gamma of 2.2, and 100 does not
+# change the gamma.
+# Minimum value: 40, maximum value: 240, default value: 80.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_GAMMA = 80
+
+# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
+# documentation will contain a main index with vertical navigation menus that
+# are dynamically created via JavaScript. If disabled, the navigation index will
+# consist of multiple levels of tabs that are statically embedded in every HTML
+# page. Disable this option to support browsers that do not have JavaScript,
+# like the Qt help browser.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_MENUS = YES
+
+# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
+# documentation will contain sections that can be hidden and shown after the
+# page has loaded.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_SECTIONS = NO
+
+# If the HTML_CODE_FOLDING tag is set to YES then classes and functions can be
+# dynamically folded and expanded in the generated HTML source code.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_CODE_FOLDING = YES
+
+# If the HTML_COPY_CLIPBOARD tag is set to YES then Doxygen will show an icon in
+# the top right corner of code and text fragments that allows the user to copy
+# its content to the clipboard. Note this only works if supported by the browser
+# and the web page is served via a secure context (see:
+# https://www.w3.org/TR/secure-contexts/), i.e.
using the https: or file:
+# protocol.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COPY_CLIPBOARD = YES
+
+# Doxygen stores a couple of settings persistently in the browser (via e.g.
+# cookies). By default these settings apply to all HTML pages generated by
+# Doxygen across all projects. The HTML_PROJECT_COOKIE tag can be used to store
+# the settings under a project specific key, such that the user preferences will
+# be stored separately.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_PROJECT_COOKIE =
+
+# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
+# shown in the various tree structured indices initially; the user can expand
+# and collapse entries dynamically later on. Doxygen will expand the tree to
+# such a level that at most the specified number of entries are visible (unless
+# a fully collapsed tree already exceeds this amount). So setting the number of
+# entries to 1 will produce a fully collapsed tree by default; 0 is a special
+# value representing an infinite number of entries and will result in a fully
+# expanded tree by default.
+# Minimum value: 0, maximum value: 9999, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_INDEX_NUM_ENTRIES = 100
+
+# If the GENERATE_DOCSET tag is set to YES, additional index files will be
+# generated that can be used as input for Apple's Xcode 3 integrated development
+# environment (see:
+# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To
+# create a documentation set, Doxygen will generate a Makefile in the HTML
+# output directory. Running make will produce the docset in that directory and
+# running make install will install the docset in
+# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
+# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy
+# genXcode/_index.html for more information.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_DOCSET = NO
+
+# This tag determines the name of the docset feed. A documentation feed provides
+# an umbrella under which multiple documentation sets from a single provider
+# (such as a company or product suite) can be grouped.
+# The default value is: Doxygen generated docs.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_FEEDNAME = "Doxygen generated docs"
+
+# This tag determines the URL of the docset feed. A documentation feed provides
+# an umbrella under which multiple documentation sets from a single provider
+# (such as a company or product suite) can be grouped.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_FEEDURL =
+
+# This tag specifies a string that should uniquely identify the documentation
+# set bundle. This should be a reverse domain-name style string, e.g.
+# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_BUNDLE_ID = org.doxygen.Project
+
+# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
+# the documentation publisher. This should be a reverse domain-name style
+# string, e.g. com.mycompany.MyDocSet.documentation.
+# The default value is: org.doxygen.Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
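+# For example, a hypothetical publisher ID for a project like this one might
+# be com.example.project.documentation; the Doxygen default is kept below.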
+
+DOCSET_PUBLISHER_ID = org.doxygen.Publisher
+
+# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
+# The default value is: Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_NAME = Publisher
+
+# If the GENERATE_HTMLHELP tag is set to YES then Doxygen generates three
+# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
+# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
+# on Windows. In the beginning of 2021 Microsoft took the original page, with
+# among others the download links, offline (the HTML help workshop was already
+# many years in maintenance mode). You can download the HTML help workshop from
+# the web archives at Installation executable (see:
+# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo
+# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe).
+#
+# The HTML Help Workshop contains a compiler that can convert all HTML output
+# generated by Doxygen into a single compiled HTML file (.chm). Compiled HTML
+# files are now used as the Windows 98 help format, and will replace the old
+# Windows help format (.hlp) on all Windows platforms in the future. Compressed
+# HTML files also contain an index, a table of contents, and you can search for
+# words in the documentation. The HTML workshop also contains a viewer for
+# compressed HTML files.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_HTMLHELP = NO
+
+# The CHM_FILE tag can be used to specify the file name of the resulting .chm
+# file. You can add a path in front of the file if the result should not be
+# written to the html output directory.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_FILE =
+
+# The HHC_LOCATION tag can be used to specify the location (absolute path
+# including file name) of the HTML help compiler (hhc.exe). If non-empty,
+# Doxygen will try to run the HTML help compiler on the generated index.hhp.
+# The file has to be specified with full path.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+HHC_LOCATION =
+
+# The GENERATE_CHI flag controls whether a separate .chi index file is generated
+# (YES) or included in the main .chm file (NO).
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+GENERATE_CHI = NO
+
+# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
+# and project file content.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_INDEX_ENCODING =
+
+# The BINARY_TOC flag controls whether a binary table of contents is generated
+# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it
+# enables the Previous and Next buttons.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+BINARY_TOC = NO
+
+# The TOC_EXPAND flag can be set to YES to add extra items for group members to
+# the table of contents of the HTML help documentation and to the tree view.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+TOC_EXPAND = NO
+
+# The SITEMAP_URL tag is used to specify the full URL of the place where the
+# generated documentation will be placed on the server by the user during the
+# deployment of the documentation. The generated sitemap is called sitemap.xml
+# and placed in the directory specified by HTML_OUTPUT.
In case no SITEMAP_URL
+# is specified, no sitemap is generated. For information about the sitemap
+# protocol see https://www.sitemaps.org
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+SITEMAP_URL =
+
+# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
+# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
+# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
+# (.qch) of the generated HTML documentation.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_QHP = NO
+
+# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
+# the file name of the resulting .qch file. The path specified is relative to
+# the HTML output folder.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QCH_FILE =
+
+# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
+# Project output. For more information please see Qt Help Project / Namespace
+# (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_NAMESPACE = org.doxygen.Project
+
+# The QHP_VIRTUAL_FOLDER tag specifies the virtual folder to use when
+# generating Qt Help Project output. For more information please see Qt Help
+# Project / Virtual Folders (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders).
+# The default value is: doc.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_VIRTUAL_FOLDER = doc
+
+# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
+# filter to add. For more information please see Qt Help Project / Custom
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_NAME =
+
+# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
+# custom filter to add. For more information please see Qt Help Project / Custom
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_ATTRS =
+
+# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
+# project's filter section matches. For more information please see Qt Help
+# Project / Filter Attributes (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_SECT_FILTER_ATTRS =
+
+# The QHG_LOCATION tag can be used to specify the location (absolute path
+# including file name) of Qt's qhelpgenerator. If non-empty Doxygen will try to
+# run qhelpgenerator on the generated .qhp file.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHG_LOCATION =
+
+# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
+# generated that, together with the HTML files, form an Eclipse help plugin. To
+# install this plugin and make it available under the help contents menu in
+# Eclipse, the contents of the directory containing the HTML and XML files need
+# to be copied into the plugins directory of Eclipse. The name of the directory
+# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
+# After copying, Eclipse needs to be restarted before the help appears.
+# The default value is: NO.
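+# (A hypothetical setup would flip this to YES and give the plugin its own
+# identifier via the ECLIPSE_DOC_ID tag below, e.g. com.example.project.)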
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_ECLIPSEHELP = NO
+
+# A unique identifier for the Eclipse help plugin. When installing the plugin
+# the directory name containing the HTML and XML files should also have this
+# name. Each documentation set should have its own identifier.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
+
+ECLIPSE_DOC_ID = org.doxygen.Project
+
+# If you want full control over the layout of the generated HTML pages it might
+# be necessary to disable the index and replace it with your own. The
+# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
+# of each HTML page. A value of NO enables the index and the value YES disables
+# it. Since the tabs in the index contain the same information as the navigation
+# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+DISABLE_INDEX = NO
+
+# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
+# structure should be generated to display hierarchical information. If the tag
+# value is set to YES, a side panel will be generated containing a tree-like
+# index structure (just like the one that is generated for HTML Help). For this
+# to work a browser that supports JavaScript, DHTML, CSS and frames is required
+# (i.e. any modern browser). Windows users are probably better off using the
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
+# further fine-tune the look of the index (see "Fine-tuning the output"). As an
+# example, the default style sheet generated by Doxygen has an example that
+# shows how to put an image at the root of the tree instead of the PROJECT_NAME.
+# Since the tree basically has the same information as the tab index, you could
+# consider setting DISABLE_INDEX to YES when enabling this option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_TREEVIEW = NO
+
+# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the
+# FULL_SIDEBAR option determines if the side bar is limited to only the treeview
+# area (value NO) or if it should extend to the full height of the window (value
+# YES). Setting this to YES gives a layout similar to
+# https://docs.readthedocs.io with more room for contents, but less room for the
+# project logo, title, and description. If either GENERATE_TREEVIEW or
+# DISABLE_INDEX is set to NO, this option has no effect.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FULL_SIDEBAR = NO
+
+# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
+# Doxygen will group on one line in the generated HTML documentation.
+#
+# Note that a value of 0 will completely suppress the enum values from appearing
+# in the overview section.
+# Minimum value: 0, maximum value: 20, default value: 4.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+ENUM_VALUES_PER_LINE = 4
+
+# When the SHOW_ENUM_VALUES tag is set to YES, Doxygen will show the specified
+# enumeration values beside the enumeration mnemonics.
+# The default value is: NO.
+
+SHOW_ENUM_VALUES = NO
+
+# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
+# to set the initial width (in pixels) of the frame in which the tree is shown.
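+# (For example, a hypothetical wider panel would combine GENERATE_TREEVIEW =
+# YES with TREEVIEW_WIDTH = 350; the valid range is given below.)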
+# Minimum value: 0, maximum value: 1500, default value: 250.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+TREEVIEW_WIDTH = 250
+
+# If the EXT_LINKS_IN_WINDOW option is set to YES, Doxygen will open links to
+# external symbols imported via tag files in a separate window.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+EXT_LINKS_IN_WINDOW = NO
+
+# If the OBFUSCATE_EMAILS tag is set to YES, Doxygen will obfuscate email
+# addresses.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+OBFUSCATE_EMAILS = YES
+
+# If the HTML_FORMULA_FORMAT option is set to svg, Doxygen will use the pdf2svg
+# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see
+# https://inkscape.org) to generate formulas as SVG images instead of PNGs for
+# the HTML output. These images will generally look nicer at scaled resolutions.
+# Possible values are: png (the default) and svg (looks nicer but requires the
+# pdf2svg or inkscape tool).
+# The default value is: png.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FORMULA_FORMAT = png
+
+# Use this tag to change the font size of LaTeX formulas included as images in
+# the HTML documentation. When you change the font size after a successful
+# Doxygen run you need to manually remove any form_*.png images from the HTML
+# output directory to force them to be regenerated.
+# Minimum value: 8, maximum value: 50, default value: 10.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_FONTSIZE = 10
+
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client-side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want the formulas to look prettier in the HTML output.
+# When enabled you may also need to install MathJax separately and configure
+# the path to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = NO
+
+# With MATHJAX_VERSION it is possible to specify the MathJax version to be used.
+# Note that the different versions of MathJax have different requirements with
+# regard to the different settings, so it is possible that other MathJax
+# settings also have to be changed when switching between the different MathJax
+# versions.
+# Possible values are: MathJax_2 and MathJax_3.
+# The default value is: MathJax_2.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_VERSION = MathJax_2
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. For more details about the output format see MathJax
+# version 2 (see:
+# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3
+# (see:
+# http://docs.mathjax.org/en/latest/web/components/output.html).
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility. This is the name for MathJax version 2, for MathJax version 3
+# this will be translated into chtml), NativeMML (i.e. MathML. Only supported
+# for MathJax 2.
For MathJax version 3 chtml will be used instead.), chtml (This
+# is the name for MathJax version 3, for MathJax version 2 this will be
+# translated into HTML-CSS) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment. The default value is:
+# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2
+# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH =
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering. For example
+# for MathJax version 2 (see
+# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions):
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# For example for MathJax version 3 (see
+# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html):
+# MATHJAX_EXTENSIONS = ams
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with JavaScript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see:
+# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled Doxygen will generate a search box for
+# the HTML output. The underlying search engine uses JavaScript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the JavaScript-based search engine can be slow; in that
+# case enabling SERVER_BASED_SEARCH may provide a better solution. It is
+# possible to search using the keyboard; to jump to the search box use
+# <access key> + S (what the <access key> is depends on the OS and browser,
+# but it is typically <CTRL>, <ALT>/