From 9aca63e981aab1fdcfd6255976b5e9d261e5a031 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 7 Oct 2024 18:19:08 -0500 Subject: [PATCH 1/2] Updated R serialization interface --- R/bart.R | 7 +++++-- vignettes/MultiChain.Rmd | 30 ++++++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/R/bart.R b/R/bart.R index 1c452bb0..8c96c3a5 100644 --- a/R/bart.R +++ b/R/bart.R @@ -915,6 +915,7 @@ convertBARTModelToJson <- function(object){ } # Add global parameters + jsonobj$add_scalar("variance_scale", object$model_params$variance_scale) jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global) @@ -1093,6 +1094,7 @@ createBARTModelFromJson <- function(json_object){ # Unpack model params model_params = list() + model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") @@ -1437,10 +1439,11 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ # Unpack model params model_params = list() + model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") - model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf") + model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") + model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf") model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd index 65752ae4..2e679b6b 100644 --- a/vignettes/MultiChain.Rmd +++ b/vignettes/MultiChain.Rmd @@ -176,7 +176,7 @@ The first step of this process is to run the sampler in parallel, storing the resulting BART JSON strings in a list. ```{r} -bart_model_strings <- foreach (i = 1:num_chains) %dopar% { +bart_model_outputs <- foreach (i = 1:num_chains) %dopar% { random_seed <- i bart_model <- stochtree::bart( X_train = X_train, W_train = W_train, y_train = y_train, @@ -186,7 +186,8 @@ bart_model_strings <- foreach (i = 1:num_chains) %dopar% { sample_sigma_leaf = T, random_seed = random_seed ) bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) - bart_model_string + y_hat_test <- bart_model$y_hat_test + list(model=bart_model_string, yhat=y_hat_test) } ``` @@ -200,6 +201,12 @@ Now, if we want to combine the forests from each of these BART models into a single forest, we can do so as follows ```{r} +bart_model_strings <- list() +bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) +for (i in 1:length(bart_model_outputs)) { + bart_model_strings[[i]] <- bart_model_outputs[[i]]$model + bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat) +} combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) ``` @@ -209,8 +216,23 @@ We can predict from this combined forest as follows yhat_combined <- predict(combined_bart, X_test, W_test)$y_hat ``` -Since we don't have access to the original $\hat{y}$ values, we instead -compare average predictions from each chain to the true $y$ values. +Compare average predictions from each chain to the original predictions. + +```{r} +par(mfrow = c(1,2)) +for (i in 1:num_chains) { + offset <- (i-1)*num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i], + xlab = "deserialized", ylab = "original", + main = paste0("Chain ", i, "\nPredictions")) + abline(0,1,col="red",lty=3,lwd=3) +} +par(mfrow = c(1,1)) +``` + +And to the true $y$ values. ```{r} par(mfrow = c(1,2)) From 7abc3d2637b10eefb131a4d86307d512c1f21e82 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 8 Oct 2024 01:32:36 -0500 Subject: [PATCH 2/2] Updated python serialization interface --- demo/debug/serialization.py | 95 ++++++++++++++++++++++++++++++++++ include/stochtree/container.h | 11 ++++ src/py_stochtree.cpp | 10 ++++ stochtree/forest.py | 6 +++ test/python/test_json.py | 97 ++++++++++++++++++++++++++++++++++- 5 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 demo/debug/serialization.py diff --git a/demo/debug/serialization.py b/demo/debug/serialization.py new file mode 100644 index 00000000..7aaae1b6 --- /dev/null +++ b/demo/debug/serialization.py @@ -0,0 +1,95 @@ +import numpy as np +from stochtree import ( + BARTModel, JSONSerializer, ForestContainer, Dataset, Residual, + RNG, ForestSampler, ForestContainer, GlobalVarianceModel +) + +# RNG +random_seed = 1234 +rng = np.random.default_rng(random_seed) + +# Generate covariates and basis +n = 1000 +p_X = 10 +p_W = 1 +X = rng.uniform(0, 1, (n, p_X)) +W = rng.uniform(0, 1, (n, p_W)) + +# Define the outcome mean function +def outcome_mean(X, W): + return np.where( + (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0], + np.where( + (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0], + np.where( + (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0], + 7.5 * W[:,0] + ) + ) + ) + +# Generate outcome +epsilon = rng.normal(0, 1, n) +y = outcome_mean(X, W) + epsilon + +# Standardize outcome +y_bar = np.mean(y) +y_std = np.std(y) +resid = (y-y_bar)/y_std + +# Sampler parameters +alpha = 0.9 +beta = 1.25 +min_samples_leaf = 1 +num_trees = 100 +cutpoint_grid_size = 100 +global_variance_init = 1. +tau_init = 0.5 +leaf_prior_scale = np.array([[tau_init]], order='C') +a_global = 4. +b_global = 2. +a_leaf = 2. +b_leaf = 0.5 +leaf_regression = True +feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric +var_weights = np.repeat(1/p_X, p_X) + +# Dataset (covariates and basis) +dataset = Dataset() +dataset.add_covariates(X) +dataset.add_basis(W) + +# Residual +residual = Residual(resid) + +# Forest samplers and temporary tracking data structures +forest_container = ForestContainer(num_trees, W.shape[1], False, False) +forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) +cpp_rng = RNG(random_seed) +global_var_model = GlobalVarianceModel() + +# Prepare to run sampler +num_warmstart = 10 +num_mcmc = 100 +num_samples = num_warmstart + num_mcmc +global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples))) + +# Run "grow-from-root" sampler +for i in range(num_warmstart): + forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False) + global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) + +# Run MCMC sampler +for i in range(num_warmstart, num_samples): + forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False) + global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) + +# Extract predictions from the sampler +y_hat_orig = forest_container.predict(dataset) + +# "Round-trip" the forest to JSON string and back and check that the predictions agree +forest_json_string = forest_container.dump_json_string() +forest_container_reloaded = ForestContainer(num_trees, W.shape[1], False, False) +forest_container_reloaded.load_from_json_string(forest_json_string) +y_hat_reloaded = forest_container_reloaded.predict(dataset) +np.testing.assert_approx_equal(y_hat_orig, y_hat_reloaded) diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 874f77ff..7b7719da 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -77,6 +77,17 @@ class ForestContainer { this->from_json(file_tree_json); } + std::string DumpJsonString() { + nlohmann::json model_json = this->to_json(); + return model_json.dump(); + } + + void LoadFromJsonString(std::string& json_string) { + nlohmann::json file_tree_json = nlohmann::json::parse(json_string); + this->Reset(); + this->from_json(file_tree_json); + } + void Reset() { forests_.clear(); num_samples_ = 0; diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 35c86b57..add3ed8a 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -260,6 +260,14 @@ class ForestContainerCpp { void LoadFromJson(JsonCpp& json, std::string forest_label); + std::string DumpJsonString() { + return forest_samples_->DumpJsonString(); + } + + void LoadFromJsonString(std::string& json_string) { + forest_samples_->LoadFromJsonString(json_string); + } + StochTree::ForestContainer* GetContainer() { return forest_samples_.get(); } @@ -973,6 +981,8 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile) .def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile) .def("LoadFromJson", &ForestContainerCpp::LoadFromJson) + .def("DumpJsonString", &ForestContainerCpp::DumpJsonString) + .def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString) .def("AddSampleValue", &ForestContainerCpp::AddSampleValue) .def("AddSampleVector", &ForestContainerCpp::AddSampleVector) .def("AddNumericSplitValue", &ForestContainerCpp::AddNumericSplitValue) diff --git a/stochtree/forest.py b/stochtree/forest.py index eb844d64..1a921330 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -45,6 +45,12 @@ def save_to_json_file(self, json_filename: str) -> None: def load_from_json_file(self, json_filename: str) -> None: self.forest_container_cpp.LoadFromJsonFile(json_filename) + + def dump_json_string(self) -> str: + return self.forest_container_cpp.DumpJsonString() + + def load_from_json_string(self, json_string: str) -> None: + self.forest_container_cpp.LoadFromJsonString(json_string) def add_sample(self, leaf_value: Union[float, np.array]) -> None: """ diff --git a/test/python/test_json.py b/test/python/test_json.py index 8a187709..71cdc563 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -1,5 +1,8 @@ import numpy as np -from stochtree import BARTModel, JSONSerializer, ForestContainer, Dataset +from stochtree import ( + BARTModel, JSONSerializer, ForestContainer, Dataset, Residual, + RNG, ForestSampler, ForestContainer, GlobalVarianceModel +) class TestJson: def test_value(self): @@ -68,4 +71,94 @@ def outcome_mean(X): # Check the predictions np.testing.assert_almost_equal(forest_preds_y_mcmc_cached, forest_preds_json_reload) np.testing.assert_almost_equal(forest_preds_y_mcmc_retrieved, forest_preds_json_reload) - \ No newline at end of file + + def test_forest_string(self): + # RNG + random_seed = 1234 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 1000 + p_X = 10 + p_W = 1 + X = rng.uniform(0, 1, (n, p_X)) + W = rng.uniform(0, 1, (n, p_W)) + + # Define the outcome mean function + def outcome_mean(X, W): + return np.where( + (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0], + np.where( + (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0], + np.where( + (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0], + 7.5 * W[:,0] + ) + ) + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = outcome_mean(X, W) + epsilon + + # Standardize outcome + y_bar = np.mean(y) + y_std = np.std(y) + resid = (y-y_bar)/y_std + + # Sampler parameters + alpha = 0.9 + beta = 1.25 + min_samples_leaf = 1 + num_trees = 100 + cutpoint_grid_size = 100 + global_variance_init = 1. + tau_init = 0.5 + leaf_prior_scale = np.array([[tau_init]], order='C') + a_global = 4. + b_global = 2. + a_leaf = 2. + b_leaf = 0.5 + leaf_regression = True + feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric + var_weights = np.repeat(1/p_X, p_X) + + # Dataset (covariates and basis) + dataset = Dataset() + dataset.add_covariates(X) + dataset.add_basis(W) + + # Residual + residual = Residual(resid) + + # Forest samplers and temporary tracking data structures + forest_container = ForestContainer(num_trees, W.shape[1], False, False) + forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) + cpp_rng = RNG(random_seed) + global_var_model = GlobalVarianceModel() + + # Prepare to run sampler + num_warmstart = 10 + num_mcmc = 100 + num_samples = num_warmstart + num_mcmc + global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples))) + + # Run "grow-from-root" sampler + for i in range(num_warmstart): + forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False) + global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) + + # Run MCMC sampler + for i in range(num_warmstart, num_samples): + forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False) + global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global) + + # Extract predictions from the sampler + y_hat_orig = forest_container.predict(dataset) + + # "Round-trip" the forest to JSON string and back and check that the predictions agree + forest_json_string = forest_container.dump_json_string() + forest_container_reloaded = ForestContainer(num_trees, W.shape[1], False, False) + forest_container_reloaded.load_from_json_string(forest_json_string) + y_hat_reloaded = forest_container_reloaded.predict(dataset) + np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)