diff --git a/R/bart.R b/R/bart.R
index 1b28ee83..f14b7501 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples)
+ if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
+ if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
sample_counter <- 0
# Initialize the leaves of each tree in the mean forest
@@ -757,6 +759,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
+ }
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
@@ -764,6 +771,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
+ }
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -910,6 +922,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
+ }
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
@@ -917,6 +934,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
+ }
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -949,6 +971,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_samples$delete_sample(0)
}
}
+ if (include_mean_forest) {
+ mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train),drop=FALSE]
+ }
+ if (include_variance_forest) {
+ variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train),drop=FALSE]
+ }
if (sample_sigma2_global) {
global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)]
}
@@ -960,13 +988,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
# Mean forest predictions
if (include_mean_forest) {
- y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
+ y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train
if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train
}
# Variance forest predictions
if (include_variance_forest) {
- sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
+ sigma2_x_hat_train <- exp(variance_forest_pred_train)
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
}
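
The bart.R hunks above swap a full forest_samples_mean$predict() pass for draws the sampler has already computed; the new test in test/R/testthat/test-bart.R below asserts that both routes agree. A minimal standalone sketch of the same equivalence check, using only public API calls that already appear elsewhere in this diff:

library(stochtree)
n <- 500; p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- 2 * sin(2 * pi * X[, 1]) + rnorm(n, 0, 0.5)
fit <- bart(X_train = X, y_train = y, num_gfr = 10, num_burnin = 0, num_mcmc = 10)
# With this patch, fit$y_hat_train is assembled from cached draws; it should
# match a fresh predict() pass over the same training data
all.equal(fit$y_hat_train, predict(fit, X = X)$mean_forest_predictions)
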
diff --git a/R/bcf.R b/R/bcf.R
index 11696f79..85cbcb20 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples)
+ muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
+ if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
sample_counter <- 0
# Prepare adaptive coding structure
@@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
+ }
+
# Sample variance parameters (if requested)
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)
+ # Cannot cache train set predictions for tau because the cached predictions in the
+ # tracking data structures are pre-multiplied by the basis (treatment); tau(X) is
+ # instead recomputed from the sampled forests via predict_raw() after sampling
+
# Sample coding parameters (if requested)
if (adaptive_coding) {
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
+ }
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
+ }
+
# Sample variance parameters (if requested)
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)
+ # Cannot cache train set predictions for tau because the cached predictions in the
+ # tracking data structures are pre-multiplied by the basis (treatment); tau(X) is
+ # instead recomputed from the sampled forests via predict_raw() after sampling
+
# Sample coding parameters (if requested)
if (adaptive_coding) {
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)
+
+ # Cache train set predictions since they are already computed during sampling
+ if (keep_sample) {
+ sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
+ }
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)]
b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)]
}
+ muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw),drop=FALSE]
+ if (include_variance_forest) {
+ sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw),drop=FALSE]
+ }
num_retained_samples <- num_retained_samples - num_gfr
}
# Forest predictions
- mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train
+ mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train
if (adaptive_coding) {
tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train)
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train
@@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
}
if (include_variance_forest) {
- sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
+ sigma2_x_hat_train <- exp(sigma2_x_train_raw)
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
}
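
Why exp() appears in both bart.R and bcf.R: the variance forest's leaf values are maintained on the log scale internally (the is_exponentiated flag handled in src/container.cpp below), so the sampler's cached raw predictions are log-variances, while predict() on a forest container exponentiates before returning. A toy illustration of that bookkeeping in plain R (raw_draws merely stands in for the cached matrix; it is not a stochtree object):

raw_draws <- matrix(rnorm(20), nrow = 5)  # n x num_samples, log-variance scale
sigma2_draws <- exp(raw_draws)            # the scale predict() reports
stopifnot(all(sigma2_draws > 0))          # exponentiation guarantees positivity
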
diff --git a/R/cpp11.R b/R/cpp11.R
index 943206f1..39802efe 100644
--- a/R/cpp11.R
+++ b/R/cpp11.R
@@ -640,6 +640,10 @@ forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
}
+get_cached_forest_predictions_cpp <- function(tracker_ptr) {
+ .Call(`_stochtree_get_cached_forest_predictions_cpp`, tracker_ptr)
+}
+
sample_without_replacement_integer_cpp <- function(population_vector, sampling_probs, sample_size) {
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
}
diff --git a/R/model.R b/R/model.R
index 955037b0..5b003055 100644
--- a/R/model.R
+++ b/R/model.R
@@ -126,6 +126,13 @@ ForestModel <- R6::R6Class(
}
},
+ #' @description
+ #' Extract the sampler's internally-cached predictions of the current forest on the training dataset.
+ #' @return Vector with as many elements as observations in the training dataset
+ get_cached_forest_predictions = function() {
+ get_cached_forest_predictions_cpp(self$tracker_ptr)
+ },
+
#' @description
#' Propagates basis update through to the (full/partial) residual by iteratively
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
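
For callers driving the low-level interface directly, the new method drops into a hand-rolled sampling loop. A sketch, assuming forest_model, forest_dataset, outcome, forest_samples, active_forest, rng, and the config objects have already been constructed with stochtree's low-level API, and yhat_draws is a pre-allocated n x num_samples matrix (the first three argument names follow the package's documented signature but are not shown in this diff's hunks):

for (i in 1:num_samples) {
  forest_model$sample_one_iteration(
    forest_dataset = forest_dataset, residual = outcome, forest_samples = forest_samples,
    active_forest = active_forest, rng = rng, forest_model_config = forest_model_config,
    global_model_config = global_model_config, keep_forest = TRUE, gfr = FALSE
  )
  # Reuse the predictions the sampler just computed -- no second pass over the data
  yhat_draws[, i] <- forest_model$get_cached_forest_predictions()
}
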
diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h
index 56b6c2e6..6546b593 100644
--- a/include/stochtree/partition_tracker.h
+++ b/include/stochtree/partition_tracker.h
@@ -91,6 +91,10 @@ class ForestTracker {
SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();}
UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();}
SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();}
+ int GetNumObservations() {return num_observations_;}
+ int GetNumTrees() {return num_trees_;}
+ int GetNumFeatures() {return num_features_;}
+ bool Initialized() {return initialized_;}
private:
/*! \brief Mapper from observations to predicted values summed over every tree in a forest */
diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd
index ad1181d5..3bb7a1db 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -22,6 +22,7 @@ trees, and exposes functionality to run a forest sampler
\itemize{
\item \href{#method-ForestModel-new}{\code{ForestModel$new()}}
\item \href{#method-ForestModel-sample_one_iteration}{\code{ForestModel$sample_one_iteration()}}
+\item \href{#method-ForestModel-get_cached_forest_predictions}{\code{ForestModel$get_cached_forest_predictions()}}
\item \href{#method-ForestModel-propagate_basis_update}{\code{ForestModel$propagate_basis_update()}}
\item \href{#method-ForestModel-propagate_residual_update}{\code{ForestModel$propagate_residual_update()}}
\item \href{#method-ForestModel-update_alpha}{\code{ForestModel$update_alpha()}}
@@ -121,6 +122,19 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
}
}
\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestModel-get_cached_forest_predictions"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestModel-get_cached_forest_predictions}{}}}
+\subsection{Method \code{get_cached_forest_predictions()}}{
+Extract the sampler's internally-cached predictions of the current forest on the training dataset.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestModel$get_cached_forest_predictions()}\if{html}{\out{</div>}}
+}
+
+\subsection{Returns}{
+Vector with as many elements as observations in the training dataset
+}
+}
+\if{html}{\out{<hr>}}
\if{html}{\out{<a id="method-ForestModel-propagate_basis_update"></a>}}
\if{latex}{\out{\hypertarget{method-ForestModel-propagate_basis_update}{}}}
\subsection{Method \code{propagate_basis_update()}}{
diff --git a/src/container.cpp b/src/container.cpp
index db10e53b..0d7d3548 100644
--- a/src/container.cpp
+++ b/src/container.cpp
@@ -206,6 +206,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees"));
CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension"));
CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant"));
+ CHECK_EQ(this->is_exponentiated_, forest_container_json.at("is_exponentiated"));
CHECK_EQ(this->initialized_, forest_container_json.at("initialized"));
int new_num_samples = forest_container_json.at("num_samples");
@@ -215,8 +216,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
for (int i = 0; i < forest_container_json.at("num_samples"); i++) {
forest_ind = this->num_samples_ + i;
forest_label = "forest_" + std::to_string(i);
- // forests_[forest_ind] = std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_);
- forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_));
+ forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_, this->is_exponentiated_));
forests_[forest_ind]->from_json(forest_container_json.at(forest_label));
}
this->num_samples_ += new_num_samples;
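
The new is_exponentiated assertion and constructor argument matter when heteroskedastic fits are serialized and recombined: without them, a variance forest appended from JSON would be rebuilt as a plain, non-exponentiated ensemble and its predictions would come back on the wrong (log) scale. A sketch of the multi-chain round trip this guards, assuming the R package's JSON helpers saveBARTModelToJsonString() and createBARTModelFromCombinedJsonString() (neither name appears in this diff; verify against the installed version):

json_list <- list(saveBARTModelToJsonString(fit_chain_1),
                  saveBARTModelToJsonString(fit_chain_2))
fit_combined <- createBARTModelFromCombinedJsonString(json_list)
# Variance forest predictions should survive the append, on the variance scale
sigma2_draws <- predict(fit_combined, X = X_train)$variance_forest_predictions
stopifnot(all(sigma2_draws > 0))
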
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index aea80bc6..873b0c25 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -1187,6 +1187,13 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX
END_CPP11
}
// sampler.cpp
+cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr);
+extern "C" SEXP _stochtree_get_cached_forest_predictions_cpp(SEXP tracker_ptr) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(get_cached_forest_predictions_cpp(cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestTracker>>(tracker_ptr)));
+  END_CPP11
+}
+// sampler.cpp
cpp11::writable::integers sample_without_replacement_integer_cpp(cpp11::integers population_vector, cpp11::doubles sampling_probs, int sample_size);
extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP population_vector, SEXP sampling_probs, SEXP sample_size) {
BEGIN_CPP11
@@ -1539,6 +1546,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},
+ {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1},
{"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
{"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2},
{"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},
diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp
index 5ff25e72..f90f5cc6 100644
--- a/src/py_stochtree.cpp
+++ b/src/py_stochtree.cpp
@@ -1166,6 +1166,16 @@ class ForestSamplerCpp {
}
}
+  py::array_t<double> GetCachedForestPredictions() {
+    int n_train = tracker_->GetNumObservations();
+    auto output = py::array_t<double>(py::detail::any_container<py::ssize_t>({n_train}));
+    auto accessor = output.mutable_unchecked<1>();
+    for (int i = 0; i < n_train; i++) {
+      accessor(i) = tracker_->GetSamplePrediction(i);
+    }
+    return output;
+  }
+
void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) {
// Perform the update operation
StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble());
@@ -2147,6 +2157,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest)
.def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration)
.def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel)
+ .def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions)
.def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate)
.def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate)
.def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha)
diff --git a/src/sampler.cpp b/src/sampler.cpp
index 1a5a5bb5..af45d6d6 100644
--- a/src/sampler.cpp
+++ b/src/sampler.cpp
@@ -284,6 +284,16 @@ cpp11::external_pointer<StochTree::ForestTracker> forest_tracker_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::integers feature_types, int num_trees, int n) {
return cpp11::external_pointer(tracker_ptr_.release());
}
+[[cpp11::register]]
+cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr) {
+ int n_train = tracker_ptr->GetNumObservations();
+ cpp11::writable::doubles output(n_train);
+ for (int i = 0; i < n_train; i++) {
+ output[i] = tracker_ptr->GetSamplePrediction(i);
+ }
+ return output;
+}
+
[[cpp11::register]]
cpp11::writable::integers sample_without_replacement_integer_cpp(
cpp11::integers population_vector,
diff --git a/stochtree/bart.py b/stochtree/bart.py
index 1b51491e..b9040eda 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -1005,6 +1005,10 @@ def sample(
self.global_var_samples = np.empty(self.num_samples, dtype=np.float64)
if sample_sigma2_leaf:
self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64)
+ if self.include_mean_forest:
+ yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
+ if self.include_variance_forest:
+ sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
sample_counter = -1
# Forest Dataset (covariates and optional basis)
@@ -1187,6 +1191,10 @@ def sample(
True,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
+
# Sample the variance forest
if self.include_variance_forest:
forest_sampler_variance.sample_one_iteration(
@@ -1201,6 +1209,10 @@ def sample(
True,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1379,6 +1391,9 @@ def sample(
False,
)
+ if keep_sample:
+ yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
+
# Sample the variance forest
if self.include_variance_forest:
forest_sampler_variance.sample_one_iteration(
@@ -1393,6 +1408,9 @@ def sample(
False,
)
+ if keep_sample:
+ sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1441,6 +1459,10 @@ def sample(
self.global_var_samples = self.global_var_samples[num_gfr:]
if self.sample_sigma2_leaf:
self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:]
+ if self.include_mean_forest:
+ yhat_train_raw = yhat_train_raw[:,num_gfr:]
+ if self.include_variance_forest:
+ sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:]
self.num_samples -= num_gfr
# Store predictions
@@ -1451,9 +1473,6 @@ def sample(
self.leaf_scale_samples = self.leaf_scale_samples
if self.include_mean_forest:
- yhat_train_raw = self.forest_container_mean.forest_container_cpp.Predict(
- forest_dataset_train.dataset_cpp
- )
self.y_hat_train = yhat_train_raw * self.y_std + self.y_bar
if self.has_test:
yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict(
@@ -1482,20 +1501,15 @@ def sample(
self.y_hat_test = rfx_preds_test
if self.include_variance_forest:
- sigma2_x_train_raw = (
- self.forest_container_variance.forest_container_cpp.Predict(
- forest_dataset_train.dataset_cpp
- )
- )
if self.sample_sigma2_global:
- self.sigma2_x_train = sigma2_x_train_raw
+ self.sigma2_x_train = np.empty_like(sigma2_x_train_raw)
for i in range(self.num_samples):
self.sigma2_x_train[:, i] = (
- sigma2_x_train_raw[:, i] * self.global_var_samples[i]
+ np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i]
)
else:
self.sigma2_x_train = (
- sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std
+ np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std
)
if self.has_test:
sigma2_x_test_raw = (
@@ -1621,14 +1635,14 @@ def predict(
)
)
if self.sample_sigma2_global:
- variance_pred = variance_pred_raw
+ variance_pred = np.empty_like(variance_pred_raw)
for i in range(self.num_samples):
- variance_pred[:, i] = np.sqrt(
+ variance_pred[:, i] = (
variance_pred_raw[:, i] * self.global_var_samples[i]
)
else:
variance_pred = (
- np.sqrt(variance_pred_raw * self.sigma2_init) * self.y_std
+ variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
)
has_mean_predictions = self.include_mean_forest or self.has_rfx
@@ -1810,7 +1824,7 @@ def predict_variance(self, covariates: np.array) -> np.array:
pred_dataset.dataset_cpp
)
if self.sample_sigma2_global:
- variance_pred = variance_pred_raw
+ variance_pred = np.empty_like(variance_pred_raw)
for i in range(self.num_samples):
variance_pred[:, i] = (
variance_pred_raw[:, i] * self.global_var_samples[i]
@@ -2017,11 +2031,11 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
for i in range(len(json_object_list)):
if i == 0:
self.forest_container_variance.forest_container_cpp.LoadFromJson(
- json_object_list[i].json_cpp, "forest_1"
+ json_object_list[i].json_cpp, "forest_0"
)
else:
self.forest_container_variance.forest_container_cpp.AppendFromJson(
- json_object_list[i].json_cpp, "forest_1"
+ json_object_list[i].json_cpp, "forest_0"
)
# Unpack random effects
@@ -2046,13 +2060,19 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
self.num_gfr = json_object_default.get_integer("num_gfr")
self.num_burnin = json_object_default.get_integer("num_burnin")
self.num_mcmc = json_object_default.get_integer("num_mcmc")
- self.num_samples = json_object_default.get_integer("num_samples")
self.num_basis = json_object_default.get_integer("num_basis")
self.has_basis = json_object_default.get_boolean("requires_basis")
self.probit_outcome_model = json_object_default.get_boolean(
"probit_outcome_model"
)
+ # Unpack number of samples
+ self.num_samples = 0
+ for json_object in json_object_list:
+     self.num_samples += json_object.get_integer("num_samples")
+
# Unpack parameter samples
if self.sample_sigma2_global:
for i in range(len(json_object_list)):
diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index c4b67232..ce2c5531 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -1480,6 +1480,10 @@ def sample(
self.leaf_scale_mu_samples = np.empty(self.num_samples, dtype=np.float64)
if sample_sigma2_leaf_tau:
self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64)
+ muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
+ tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
+ if self.include_variance_forest:
+ sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
sample_counter = -1
# Prepare adaptive coding structure
@@ -1692,6 +1696,10 @@ def sample(
True,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1725,6 +1733,10 @@ def sample(
True,
)
+ # Cannot cache train set predictions for tau because the cached predictions in the
+ # tracking data structures are pre-multiplied by the basis (treatment); tau(X) is
+ # instead recomputed from the sampled forests via PredictRaw() after sampling
+
# Sample coding parameters (if requested)
if self.adaptive_coding:
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
@@ -1782,6 +1794,10 @@ def sample(
True,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1873,6 +1889,10 @@ def sample(
False,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1906,6 +1926,10 @@ def sample(
False,
)
+ # Cannot cache train set predictions for tau because the cached predictions in the
+ # tracking data structures are pre-multiplied by the basis (treatment); tau(X) is
+ # instead recomputed from the sampled forests via PredictRaw() after sampling
+
# Sample coding parameters (if requested)
if self.adaptive_coding:
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
@@ -1963,6 +1987,10 @@ def sample(
True,
)
+ # Cache train set predictions since they are already computed during sampling
+ if keep_sample:
+ sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+
# Sample variance parameters (if requested)
if self.sample_sigma2_global:
current_sigma2 = global_var_model.sample_one_iteration(
@@ -2018,13 +2046,13 @@ def sample(
self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:]
if self.sample_sigma2_leaf_tau:
self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:]
+ muhat_train_raw = muhat_train_raw[:,num_gfr:]
+ if self.include_variance_forest:
+ sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:]
self.num_samples -= num_gfr
# Store predictions
- mu_raw = self.forest_container_mu.forest_container_cpp.Predict(
- forest_dataset_train.dataset_cpp
- )
- self.mu_hat_train = mu_raw * self.y_std + self.y_bar
+ self.mu_hat_train = muhat_train_raw * self.y_std + self.y_bar
tau_raw_train = self.forest_container_tau.forest_container_cpp.PredictRaw(
forest_dataset_train.dataset_cpp
)
@@ -2080,21 +2108,29 @@ def sample(
if self.has_test:
self.y_hat_test = self.y_hat_test + rfx_preds_test
+ if self.sample_sigma2_global:
+     self.global_var_samples = self.global_var_samples * self.y_std * self.y_std
+
if self.include_variance_forest:
- sigma2_x_train_raw = (
- self.forest_container_variance.forest_container_cpp.Predict(
- forest_dataset_train.dataset_cpp
- )
- )
if self.sample_sigma2_global:
- self.sigma2_x_train = sigma2_x_train_raw
+ self.sigma2_x_train = np.empty_like(sigma2_x_train_raw)
for i in range(self.num_samples):
self.sigma2_x_train[:, i] = (
- sigma2_x_train_raw[:, i] * self.global_var_samples[i]
+ np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i]
)
else:
self.sigma2_x_train = (
- sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std
+ np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std
)
if self.has_test:
sigma2_x_test_raw = (
@@ -2103,7 +2139,7 @@ def sample(
)
)
if self.sample_sigma2_global:
- self.sigma2_x_test = sigma2_x_test_raw
+ self.sigma2_x_test = np.empty_like(sigma2_x_test_raw)
for i in range(self.num_samples):
self.sigma2_x_test[:, i] = (
sigma2_x_test_raw[:, i] * self.global_var_samples[i]
@@ -2113,19 +2149,6 @@ def sample(
sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std
)
- if self.sample_sigma2_global:
- self.global_var_samples = self.global_var_samples * self.y_std * self.y_std
-
- if self.sample_sigma2_leaf_mu:
- self.leaf_scale_mu_samples = self.leaf_scale_mu_samples
-
- if self.sample_sigma2_leaf_tau:
- self.leaf_scale_tau_samples = self.leaf_scale_tau_samples
-
- if self.adaptive_coding:
- self.b0_samples = self.b0_samples
- self.b1_samples = self.b1_samples
-
def predict_tau(
self, X: np.array, Z: np.array, propensity: np.array = None
) -> np.array:
@@ -2311,7 +2334,7 @@ def predict_variance(
pred_dataset.dataset_cpp
)
if self.sample_sigma2_global:
- variance_pred = variance_pred_raw
+ variance_pred = np.empty_like(variance_pred_raw)
for i in range(self.num_samples):
variance_pred[:, i] = (
variance_pred_raw[:, i] * self.global_var_samples[i]
@@ -2463,7 +2486,7 @@ def predict(
forest_dataset_test.dataset_cpp
)
if self.sample_sigma2_global:
- sigma2_x = sigma2_x_raw
+ sigma2_x = np.empty_like(sigma2_x_raw)
for i in range(self.num_samples):
sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i]
else:
@@ -2736,7 +2759,6 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
self.num_gfr = json_object_default.get_scalar("num_gfr")
self.num_burnin = json_object_default.get_scalar("num_burnin")
self.num_mcmc = json_object_default.get_scalar("num_mcmc")
- self.num_samples = json_object_default.get_scalar("num_samples")
self.adaptive_coding = json_object_default.get_boolean("adaptive_coding")
self.propensity_covariate = json_object_default.get_string(
"propensity_covariate"
@@ -2744,6 +2766,13 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
self.internal_propensity_model = json_object_default.get_boolean(
"internal_propensity_model"
)
+
+ # Unpack number of samples
+ self.num_samples = 0
+ for json_object in json_object_list:
+     self.num_samples += json_object.get_integer("num_samples")
# Unpack parameter samples
if self.sample_sigma2_global:
diff --git a/stochtree/sampler.py b/stochtree/sampler.py
index be55286a..8ac4f013 100644
--- a/stochtree/sampler.py
+++ b/stochtree/sampler.py
@@ -266,6 +266,17 @@ def propagate_basis_update(
self.forest_sampler_cpp.PropagateBasisUpdate(
dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp
)
+
+ def get_cached_forest_predictions(self) -> np.array:
+     """
+     Extract the sampler's internally-cached predictions of the current forest on the training dataset.
+
+     Returns
+     -------
+     np.array
+         Numpy 1D array with as many elements as observations in the training dataset
+     """
+     return self.forest_sampler_cpp.GetCachedForestPredictions()
def update_alpha(self, alpha: float) -> None:
"""
diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R
index 325bdbcf..88cbcd6a 100644
--- a/test/R/testthat/test-bart.R
+++ b/test/R/testthat/test-bart.R
@@ -291,3 +291,48 @@ test_that("Warmstart BART", {
general_params = general_param_list)
)
})
+
+test_that("BART Predictions", {
+ skip_on_cran()
+
+ # Generate simulated data
+ n <- 100
+ p <- 5
+ X <- matrix(runif(n*p), ncol = p)
+ f_XW <- (
+ ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+ ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+ ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+ ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+ )
+ noise_sd <- 1
+ y <- f_XW + rnorm(n, 0, noise_sd)
+ test_set_pct <- 0.2
+ n_test <- round(test_set_pct*n)
+ n_train <- n - n_test
+ test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+ train_inds <- (1:n)[!((1:n) %in% test_inds)]
+ X_test <- X[test_inds,]
+ X_train <- X[train_inds,]
+ y_test <- y[test_inds]
+ y_train <- y[train_inds]
+
+ # Run a BART model with GFR warm-starts, MCMC sampling, and a variance forest
+ general_params <- list(num_chains = 1)
+ variance_forest_params <- list(num_trees = 50)
+ bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
+ num_gfr = 10, num_burnin = 0, num_mcmc = 10,
+ general_params = general_params,
+ variance_forest_params = variance_forest_params)
+
+ # Check that cached predictions agree with results of predict() function
+ train_preds <- predict(bart_model, X = X_train)
+ train_preds_mean_cached <- bart_model$y_hat_train
+ train_preds_mean_recomputed <- train_preds$mean_forest_predictions
+ train_preds_variance_cached <- bart_model$sigma2_x_hat_train
+ train_preds_variance_recomputed <- train_preds$variance_forest_predictions
+
+ # Assertion
+ expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
+ expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
+})
diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R
index 24fabcd1..6f0a9ce8 100644
--- a/test/R/testthat/test-bcf.R
+++ b/test/R/testthat/test-bcf.R
@@ -426,4 +426,71 @@ test_that("Multivariate Treatment MCMC BCF", {
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
num_mcmc = 10, general_params = general_param_list)
)
-})
\ No newline at end of file
+})
+
+test_that("BCF Predictions", {
+ skip_on_cran()
+
+ # Generate simulated data
+ n <- 100
+ p <- 5
+ X <- matrix(runif(n*p), ncol = p)
+ mu_X <- (
+ ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+ ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+ ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+ ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+ )
+ pi_X <- (
+ ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
+ ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
+ ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
+ ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
+ )
+ tau_X <- (
+ ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
+ ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
+ ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
+ ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
+ )
+ Z <- rbinom(n, 1, pi_X)
+ noise_sd <- 1
+ y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
+ test_set_pct <- 0.2
+ n_test <- round(test_set_pct*n)
+ n_train <- n - n_test
+ test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+ train_inds <- (1:n)[!((1:n) %in% test_inds)]
+ X_test <- X[test_inds,]
+ X_train <- X[train_inds,]
+ Z_test <- Z[test_inds]
+ Z_train <- Z[train_inds]
+ pi_test <- pi_X[test_inds]
+ pi_train <- pi_X[train_inds]
+ mu_test <- mu_X[test_inds]
+ mu_train <- mu_X[train_inds]
+ tau_test <- tau_X[test_inds]
+ tau_train <- tau_X[train_inds]
+ y_test <- y[test_inds]
+ y_train <- y[train_inds]
+
+ # Run a BCF model with GFR warm-starts, MCMC sampling, and a variance forest
+ general_params <- list(num_chains = 1, keep_every = 1)
+ variance_forest_params <- list(num_trees = 50)
+ bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
+ propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
+ propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
+ num_mcmc = 10, general_params = general_params,
+ variance_forest_params = variance_forest_params)
+
+ # Check that cached predictions agree with results of predict() function
+ train_preds <- predict(bcf_model, X = X_train, Z = Z_train, propensity = pi_train)
+ train_preds_mean_cached <- bcf_model$y_hat_train
+ train_preds_mean_recomputed <- train_preds$y_hat
+ train_preds_variance_cached <- bcf_model$sigma2_x_hat_train
+ train_preds_variance_recomputed <- train_preds$variance_forest_predictions
+
+ # Assertion
+ expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
+ expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
+})
diff --git a/test/python/test_bart.py b/test/python/test_bart.py
index a2f2e64c..31962891 100644
--- a/test/python/test_bart.py
+++ b/test/python/test_bart.py
@@ -406,14 +406,21 @@ def conditional_stddev(X):
bart_model_3.from_json_string_list(bart_models_json)
# Assertions
- y_hat_train_combined, _ = bart_model_3.predict(covariates=X_train)
+ y_hat_train_combined, sigma2_x_train_combined = bart_model_3.predict(covariates=X_train)
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
+ assert sigma2_x_train_combined.shape == (n_train, num_mcmc * 2)
np.testing.assert_allclose(
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
)
np.testing.assert_allclose(
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train
)
+ np.testing.assert_allclose(
+ sigma2_x_train_combined[:, 0:num_mcmc], bart_model.sigma2_x_train
+ )
+ np.testing.assert_allclose(
+ sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train
+ )
np.testing.assert_allclose(
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
)
diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py
index 96f25a34..65f39390 100644
--- a/test/python/test_bcf.py
+++ b/test/python/test_bcf.py
@@ -577,3 +577,194 @@ def test_multivariate_bcf(self):
num_mcmc=num_mcmc,
variance_forest_params=variance_forest_params,
)
+
+ def test_binary_bcf_heteroskedastic(self):
+ # RNG
+ random_seed = 101
+ rng = np.random.default_rng(random_seed)
+
+ # Generate covariates and basis
+ n = 100
+ p_X = 5
+ X = rng.uniform(0, 1, (n, p_X))
+ pi_X = 0.25 + 0.5 * X[:, 0]
+ Z = rng.binomial(1, pi_X, n).astype(float)
+
+ # Define the outcome mean functions (prognostic and treatment effects)
+ mu_X = pi_X * 5
+ tau_X = X[:, 1] * 2
+
+ # Generate outcome
+ epsilon = rng.normal(0, 1, n)
+ y = mu_X + tau_X * Z + epsilon
+
+ # Test-train split
+ sample_inds = np.arange(n)
+ train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
+ X_train = X[train_inds, :]
+ X_test = X[test_inds, :]
+ Z_train = Z[train_inds]
+ Z_test = Z[test_inds]
+ y_train = y[train_inds]
+ pi_train = pi_X[train_inds]
+ pi_test = pi_X[test_inds]
+ n_train = X_train.shape[0]
+ n_test = X_test.shape[0]
+
+ # BCF settings
+ num_gfr = 10
+ num_burnin = 0
+ num_mcmc = 10
+
+ # Run BCF with test set and propensity score
+ bcf_model = BCFModel()
+ variance_forest_params = {"num_trees": 50}
+ bcf_model.sample(
+ X_train=X_train,
+ Z_train=Z_train,
+ y_train=y_train,
+ pi_train=pi_train,
+ X_test=X_test,
+ Z_test=Z_test,
+ pi_test=pi_test,
+ num_gfr=num_gfr,
+ num_burnin=num_burnin,
+ num_mcmc=num_mcmc,
+ variance_forest_params=variance_forest_params,
+ )
+
+ # Assertions
+ assert bcf_model.y_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.sigma2_x_train.shape == (n_train, num_mcmc)
+ assert bcf_model.y_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.mu_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.tau_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.sigma2_x_test.shape == (n_test, num_mcmc)
+
+ # Check overall prediction method
+ tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test, pi_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+ assert mu_hat.shape == (n_test, num_mcmc)
+ assert y_hat.shape == (n_test, num_mcmc)
+ assert sigma2_hat.shape == (n_test, num_mcmc)
+
+ # Check treatment effect prediction method
+ tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+
+ # Run BCF without test set and with propensity score
+ bcf_model = BCFModel()
+ variance_forest_params = {"num_trees": 50}
+ bcf_model.sample(
+ X_train=X_train,
+ Z_train=Z_train,
+ y_train=y_train,
+ pi_train=pi_train,
+ num_gfr=num_gfr,
+ num_burnin=num_burnin,
+ num_mcmc=num_mcmc,
+ variance_forest_params=variance_forest_params,
+ )
+
+ # Assertions
+ assert bcf_model.y_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.sigma2_x_train.shape == (n_train, num_mcmc)
+
+ # Check overall prediction method
+ tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test, pi_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+ assert mu_hat.shape == (n_test, num_mcmc)
+ assert y_hat.shape == (n_test, num_mcmc)
+ assert sigma2_hat.shape == (n_test, num_mcmc)
+
+ # Check predictions match
+ tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_train, Z_train, pi_train)
+ assert tau_hat.shape == (n_train, num_mcmc)
+ assert mu_hat.shape == (n_train, num_mcmc)
+ assert y_hat.shape == (n_train, num_mcmc)
+ assert sigma2_hat.shape == (n_train, num_mcmc)
+ np.testing.assert_allclose(
+ y_hat, bcf_model.y_hat_train
+ )
+ np.testing.assert_allclose(
+ mu_hat, bcf_model.mu_hat_train
+ )
+ np.testing.assert_allclose(
+ tau_hat, bcf_model.tau_hat_train
+ )
+ np.testing.assert_allclose(
+ sigma2_hat, bcf_model.sigma2_x_train
+ )
+
+ # Check treatment effect prediction method
+ tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+
+ # Run BCF with test set and without propensity score
+ bcf_model = BCFModel()
+ variance_forest_params = {"num_trees": 50}
+ bcf_model.sample(
+ X_train=X_train,
+ Z_train=Z_train,
+ y_train=y_train,
+ X_test=X_test,
+ Z_test=Z_test,
+ num_gfr=num_gfr,
+ num_burnin=num_burnin,
+ num_mcmc=num_mcmc,
+ variance_forest_params=variance_forest_params,
+ )
+
+ # Assertions
+ assert bcf_model.y_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.sigma2_x_train.shape == (n_train, num_mcmc)
+ assert bcf_model.bart_propensity_model.y_hat_train.shape == (n_train, 10)
+ assert bcf_model.y_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.mu_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.tau_hat_test.shape == (n_test, num_mcmc)
+ assert bcf_model.bart_propensity_model.y_hat_test.shape == (n_test, 10)
+
+ # Check overall prediction method
+ tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+ assert mu_hat.shape == (n_test, num_mcmc)
+ assert y_hat.shape == (n_test, num_mcmc)
+ assert sigma2_hat.shape == (n_test, num_mcmc)
+
+ # Check treatment effect prediction method
+ tau_hat = bcf_model.predict_tau(X_test, Z_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+
+ # Run BCF without test set, without propensity score, and without a variance forest
+ bcf_model = BCFModel()
+ variance_forest_params = {"num_trees": 0}
+ bcf_model.sample(
+ X_train=X_train,
+ Z_train=Z_train,
+ y_train=y_train,
+ num_gfr=num_gfr,
+ num_burnin=num_burnin,
+ num_mcmc=num_mcmc,
+ variance_forest_params=variance_forest_params,
+ )
+
+ # Assertions
+ assert bcf_model.y_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc)
+ assert bcf_model.bart_propensity_model.y_hat_train.shape == (n_train, 10)
+
+ # Check overall prediction method
+ tau_hat, mu_hat, y_hat = bcf_model.predict(X_test, Z_test)
+ assert tau_hat.shape == (n_test, num_mcmc)
+ assert mu_hat.shape == (n_test, num_mcmc)
+ assert y_hat.shape == (n_test, num_mcmc)
+
+ # Check treatment effect prediction method
+ tau_hat = bcf_model.predict_tau(X_test, Z_test)
diff --git a/tools/perf/bart_profiling_script.R b/tools/perf/bart_profiling_script.R
new file mode 100644
index 00000000..7a60eed2
--- /dev/null
+++ b/tools/perf/bart_profiling_script.R
@@ -0,0 +1,57 @@
+# Load libraries
+library(stochtree)
+
+# Capture command line arguments
+args <- commandArgs(trailingOnly = TRUE)
+if (length(args) >= 5) {
+  n <- as.integer(args[1])
+  p <- as.integer(args[2])
+  num_gfr <- as.integer(args[3])
+  num_mcmc <- as.integer(args[4])
+  snr <- as.numeric(args[5])
+} else {
+ # Default arguments
+ n <- 1000
+ p <- 5
+ num_gfr <- 10
+ num_mcmc <- 100
+ snr <- 3.0
+}
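+# Example invocation (argument order: n, p, num_gfr, num_mcmc, snr; values below are illustrative):
+#   Rscript tools/perf/bart_profiling_script.R 100000 10 10 100 3.0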
+cat("n = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr,
+ "\nnum_mcmc = ", num_mcmc, "\nsnr = ", snr, "\n", sep = "")
+
+# Generate data needed to train BART model
+X <- matrix(runif(n*p), ncol = p)
+plm_term <- (
+ ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
+ ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
+ ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
+ ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
+)
+trig_term <- (
+ 2*sin(X[,3]*2*pi) -
+ 1.5*cos(X[,4]*2*pi)
+)
+f_XW <- plm_term + trig_term
+noise_sd <- sd(f_XW)/snr
+y <- f_XW + rnorm(n, 0, noise_sd)
+
+# Split into train and test sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+system.time({
+ # Sample BART model
+ bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
+ num_gfr = num_gfr, num_mcmc = num_mcmc)
+
+ # Predict on the test set
+ test_preds <- predict(bart_model, X = X_test)
+})
\ No newline at end of file