Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ convertBARTModelToJson <- function(object){
}

# Add global parameters
jsonobj$add_scalar("variance_scale", object$model_params$variance_scale)
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
Expand Down Expand Up @@ -1093,6 +1094,7 @@ createBARTModelFromJson <- function(json_object){

# Unpack model params
model_params = list()
model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
Expand Down Expand Up @@ -1437,10 +1439,11 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){

# Unpack model params
model_params = list()
model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
model_params[["include_mean_forest"]] <- include_mean_forest
model_params[["include_variance_forest"]] <- include_variance_forest
model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
Expand Down
95 changes: 95 additions & 0 deletions demo/debug/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
# NOTE: duplicate `ForestContainer` removed from the original import list
from stochtree import (
    BARTModel, JSONSerializer, ForestContainer, Dataset, Residual,
    RNG, ForestSampler, GlobalVarianceModel
)

# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis
n = 1000   # number of observations
p_X = 10   # number of covariates
p_W = 1    # number of basis dimensions
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    """Piecewise-linear mean function: the slope on W[:, 0] steps through
    (-7.5, -2.5, 2.5, 7.5) across the four quartile bins of X[:, 0]."""
    x0 = X[:, 0]
    bin_conditions = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    slopes = [-7.5, -2.5, 2.5]
    # Anything not matched by the first three bins gets the top-bin slope,
    # matching the innermost "else" of the original nested np.where.
    slope = np.select(bin_conditions, slopes, default=7.5)
    return slope * W[:, 0]

# Generate outcome: mean function plus unit-variance Gaussian noise
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon

# Standardize outcome to zero mean / unit variance before sampling
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y - y_bar) / y_std

# Sampler parameters (alpha/beta presumably parameterize the tree depth
# prior; a_*/b_* the variance priors -- confirm against stochtree docs)
alpha = 0.9
beta = 1.25
min_samples_leaf = 1
num_trees = 100
cutpoint_grid_size = 100
global_variance_init = 1.
tau_init = 0.5
leaf_prior_scale = np.array([[tau_init]], order='C')
a_global = 4.
b_global = 2.
a_leaf = 2.
b_leaf = 0.5
leaf_regression = True
feature_types = np.repeat(0, p_X).astype(int)  # 0 = numeric
var_weights = np.repeat(1/p_X, p_X)

# Dataset (covariates and basis)
dataset = Dataset()
dataset.add_covariates(X)
dataset.add_basis(W)

# Residual
residual = Residual(resid)

# Forest samplers and temporary tracking data structures
forest_container = ForestContainer(num_trees, W.shape[1], False, False)
forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)
cpp_rng = RNG(random_seed)
global_var_model = GlobalVarianceModel()

# Prepare to run sampler: slot 0 holds the initial global variance,
# slots 1..num_samples are filled in during sampling
num_warmstart = 10
num_mcmc = 100
num_samples = num_warmstart + num_mcmc
global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples)))

# Run "grow-from-root" (warm-start) sampler
for i in range(num_warmstart):
    forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False)
    global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)

# Run MCMC sampler, continuing from the warm-start state
for i in range(num_warmstart, num_samples):
    forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False)
    global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)

# Extract predictions from the sampler
y_hat_orig = forest_container.predict(dataset)

# "Round-trip" the forest to JSON string and back and check that the predictions agree
forest_json_string = forest_container.dump_json_string()
forest_container_reloaded = ForestContainer(num_trees, W.shape[1], False, False)
forest_container_reloaded.load_from_json_string(forest_json_string)
y_hat_reloaded = forest_container_reloaded.predict(dataset)
# BUGFIX: np.testing.assert_approx_equal compares scalars only and errors on
# arrays; assert_almost_equal handles the array-valued predictions (and
# matches the equivalent check in test/python/test_json.py)
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)
11 changes: 11 additions & 0 deletions include/stochtree/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ class ForestContainer {
this->from_json(file_tree_json);
}

std::string DumpJsonString() {
nlohmann::json model_json = this->to_json();
return model_json.dump();
}

/*!
 * \brief Clear this container and reload every forest from an in-memory
 *        JSON string (the inverse of DumpJsonString).
 * \param json_string Serialized container, as produced by DumpJsonString.
 *
 * Takes the string by const reference (it is only parsed, never modified),
 * which also allows temporaries and const strings to be passed; existing
 * callers passing non-const lvalues are unaffected.
 */
void LoadFromJsonString(const std::string& json_string) {
  nlohmann::json file_tree_json = nlohmann::json::parse(json_string);
  this->Reset();
  this->from_json(file_tree_json);
}

void Reset() {
forests_.clear();
num_samples_ = 0;
Expand Down
10 changes: 10 additions & 0 deletions src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ class ForestContainerCpp {

void LoadFromJson(JsonCpp& json, std::string forest_label);

// Serialize the wrapped forest container to a JSON string; delegates to
// StochTree::ForestContainer::DumpJsonString.
std::string DumpJsonString() {
  return forest_samples_->DumpJsonString();
}

// Reset the wrapped forest container and reload it from a JSON string
// previously produced by DumpJsonString.
void LoadFromJsonString(std::string& json_string) {
  forest_samples_->LoadFromJsonString(json_string);
}

// Expose the underlying container as a raw, non-owning pointer
// (forest_samples_ retains ownership via its smart pointer).
StochTree::ForestContainer* GetContainer() {
  return forest_samples_.get();
}
Expand Down Expand Up @@ -973,6 +981,8 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile)
.def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile)
.def("LoadFromJson", &ForestContainerCpp::LoadFromJson)
.def("DumpJsonString", &ForestContainerCpp::DumpJsonString)
.def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString)
.def("AddSampleValue", &ForestContainerCpp::AddSampleValue)
.def("AddSampleVector", &ForestContainerCpp::AddSampleVector)
.def("AddNumericSplitValue", &ForestContainerCpp::AddNumericSplitValue)
Expand Down
6 changes: 6 additions & 0 deletions stochtree/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def save_to_json_file(self, json_filename: str) -> None:

def load_from_json_file(self, json_filename: str) -> None:
    """Load forests into this container from a JSON file on disk.

    :param json_filename: Path to a JSON file containing a serialized
        forest container (e.g. one written by ``save_to_json_file``).
    """
    self.forest_container_cpp.LoadFromJsonFile(json_filename)

def dump_json_string(self) -> str:
    """Serialize this forest container to an in-memory JSON string.

    :return: JSON representation of the container, suitable for
        ``load_from_json_string``.
    """
    return self.forest_container_cpp.DumpJsonString()

def load_from_json_string(self, json_string: str) -> None:
    """Reload this container from a JSON string produced by
    ``dump_json_string``.

    :param json_string: Serialized forest container.
    """
    self.forest_container_cpp.LoadFromJsonString(json_string)

def add_sample(self, leaf_value: Union[float, np.array]) -> None:
"""
Expand Down
97 changes: 95 additions & 2 deletions test/python/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from stochtree import BARTModel, JSONSerializer, ForestContainer, Dataset
# NOTE: duplicate `ForestContainer` removed from the original import list
from stochtree import (
    BARTModel, JSONSerializer, ForestContainer, Dataset, Residual,
    RNG, ForestSampler, GlobalVarianceModel
)

class TestJson:
def test_value(self):
Expand Down Expand Up @@ -68,4 +71,94 @@ def outcome_mean(X):
# Check the predictions
np.testing.assert_almost_equal(forest_preds_y_mcmc_cached, forest_preds_json_reload)
np.testing.assert_almost_equal(forest_preds_y_mcmc_retrieved, forest_preds_json_reload)


def test_forest_string(self):
    """Round-trip a sampled forest container through an in-memory JSON
    string and verify the reloaded container produces identical
    predictions."""
    # RNG
    random_seed = 1234
    rng = np.random.default_rng(random_seed)

    # Generate covariates and basis
    n = 1000
    p_X = 10
    p_W = 1
    X = rng.uniform(0, 1, (n, p_X))
    W = rng.uniform(0, 1, (n, p_W))

    # Define the outcome mean function: the slope on W[:, 0] steps through
    # (-7.5, -2.5, 2.5, 7.5) across the four quartile bins of X[:, 0]
    def outcome_mean(X, W):
        return np.where(
            (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0],
            np.where(
                (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0],
                np.where(
                    (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0],
                    7.5 * W[:,0]
                )
            )
        )

    # Generate outcome: mean function plus unit-variance Gaussian noise
    epsilon = rng.normal(0, 1, n)
    y = outcome_mean(X, W) + epsilon

    # Standardize outcome to zero mean / unit variance
    y_bar = np.mean(y)
    y_std = np.std(y)
    resid = (y-y_bar)/y_std

    # Sampler parameters (presumably tree-depth and variance prior
    # hyperparameters -- confirm against the stochtree sampler docs)
    alpha = 0.9
    beta = 1.25
    min_samples_leaf = 1
    num_trees = 100
    cutpoint_grid_size = 100
    global_variance_init = 1.
    tau_init = 0.5
    leaf_prior_scale = np.array([[tau_init]], order='C')
    a_global = 4.
    b_global = 2.
    a_leaf = 2.
    b_leaf = 0.5
    leaf_regression = True
    feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric
    var_weights = np.repeat(1/p_X, p_X)

    # Dataset (covariates and basis)
    dataset = Dataset()
    dataset.add_covariates(X)
    dataset.add_basis(W)

    # Residual
    residual = Residual(resid)

    # Forest samplers and temporary tracking data structures
    forest_container = ForestContainer(num_trees, W.shape[1], False, False)
    forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)
    cpp_rng = RNG(random_seed)
    global_var_model = GlobalVarianceModel()

    # Prepare to run sampler: slot 0 holds the initial global variance,
    # slots 1..num_samples are filled in during sampling
    num_warmstart = 10
    num_mcmc = 100
    num_samples = num_warmstart + num_mcmc
    global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples)))

    # Run "grow-from-root" sampler
    for i in range(num_warmstart):
        forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False)
        global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)

    # Run MCMC sampler, continuing from the warm-start state
    for i in range(num_warmstart, num_samples):
        forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False)
        global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)

    # Extract predictions from the sampler
    y_hat_orig = forest_container.predict(dataset)

    # "Round-trip" the forest to JSON string and back and check that the predictions agree
    forest_json_string = forest_container.dump_json_string()
    forest_container_reloaded = ForestContainer(num_trees, W.shape[1], False, False)
    forest_container_reloaded.load_from_json_string(forest_json_string)
    y_hat_reloaded = forest_container_reloaded.predict(dataset)
    np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)
30 changes: 26 additions & 4 deletions vignettes/MultiChain.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ The first step of this process is to run the sampler in parallel,
storing the resulting BART JSON strings in a list.

```{r}
bart_model_strings <- foreach (i = 1:num_chains) %dopar% {
bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
random_seed <- i
bart_model <- stochtree::bart(
X_train = X_train, W_train = W_train, y_train = y_train,
Expand All @@ -186,7 +186,8 @@ bart_model_strings <- foreach (i = 1:num_chains) %dopar% {
sample_sigma_leaf = T, random_seed = random_seed
)
bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model)
bart_model_string
y_hat_test <- bart_model$y_hat_test
list(model=bart_model_string, yhat=y_hat_test)
}
```

Expand All @@ -200,6 +201,12 @@ Now, if we want to combine the forests from each of these BART models into a
single forest, we can do so as follows

```{r}
# Unpack the parallel chain results: JSON model strings into a list and
# per-chain posterior-mean test-set predictions into a matrix (one column
# per chain).
bart_model_strings <- list()
bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains)
# seq_along() instead of 1:length(): safe if the output list is empty
for (i in seq_along(bart_model_outputs)) {
  bart_model_strings[[i]] <- bart_model_outputs[[i]]$model
  bart_model_yhats[, i] <- rowMeans(bart_model_outputs[[i]]$yhat)
}
combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings)
```

Expand All @@ -209,8 +216,23 @@ We can predict from this combined forest as follows
yhat_combined <- predict(combined_bart, X_test, W_test)$y_hat
```

Since we don't have access to the original $\hat{y}$ values, we instead
compare average predictions from each chain to the true $y$ values.
Compare average predictions from each chain to the original predictions.

```{r}
# One panel per chain: deserialized (combined-model) predictions vs. the
# predictions recorded by the original chain, with a y = x reference line.
par(mfrow = c(1, 2))
for (i in seq_len(num_chains)) {
  # Columns of yhat_combined are grouped by chain, num_mcmc columns each
  offset <- (i - 1) * num_mcmc
  inds_start <- offset + 1
  inds_end <- offset + num_mcmc
  plot(rowMeans(yhat_combined[, inds_start:inds_end]), bart_model_yhats[, i],
       xlab = "deserialized", ylab = "original",
       main = paste0("Chain ", i, "\nPredictions"))
  abline(0, 1, col = "red", lty = 3, lwd = 3)
}
par(mfrow = c(1, 1))
```

And to the true $y$ values.

```{r}
par(mfrow = c(1,2))
Expand Down
Loading