diff --git a/R/bart.R b/R/bart.R index 1b28ee83..f14b7501 100644 --- a/R/bart.R +++ b/R/bart.R @@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) + if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples) + if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples) sample_counter <- 0 # Initialize the leaves of each tree in the mean forest @@ -757,6 +759,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions() + } } if (include_variance_forest) { forest_model_variance$sample_one_iteration( @@ -764,6 +771,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + } } if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -910,6 +922,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions() + } } if (include_variance_forest) { forest_model_variance$sample_one_iteration( @@ -917,6 +934,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + } } if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -949,6 +971,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train rfx_samples$delete_sample(0) } } + if (include_mean_forest) { + mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train)] + } + if (include_variance_forest) { + variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train)] + } if 
(sample_sigma2_global) { global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] } @@ -960,13 +988,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Mean forest predictions if (include_mean_forest) { - y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train } # Variance forest predictions if (include_variance_forest) { - sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + sigma2_x_hat_train <- exp(variance_forest_pred_train) if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) } diff --git a/R/bcf.R b/R/bcf.R index 11696f79..85cbcb20 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples) if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples) + muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) + if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) sample_counter <- 0 # Prepare adaptive coding structure @@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions() + } + # Sample variance parameters (if requested) if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... 
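Note on the caveat above: because the tracker's cached values for the treatment forest store Z * tau(X) rather than tau(X), this patch instead recovers tau after sampling from the retained forests. A minimal R sketch of that recovery, using objects already defined in bcf() (forest_samples_tau, the adaptive-coding draws b_0_samples/b_1_samples, and y_std_train):

    # Raw leaf predictions for the treatment forest (one column per retained sample)
    tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train)
    # Under adaptive coding, scale each sample's draws by (b1 - b0), then undo outcome standardization
    tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * y_std_train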
+ # Sample coding parameters (if requested) if (adaptive_coding) { # Estimate mu(X) and tau(X) and compute y - mu(X) @@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + } } if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions() + } + # Sample variance parameters (if requested) if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... + # Sample coding parameters (if requested) if (adaptive_coding) { # Estimate mu(X) and tau(X) and compute y - mu(X) @@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + } } if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) @@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)] b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)] } + muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw)] + if (include_variance_forest) { + sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw)] + } num_retained_samples <- num_retained_samples - num_gfr } # Forest predictions - mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train + mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train if (adaptive_coding) { tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train) tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train @@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) } if (include_variance_forest) { - sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + 
sigma2_x_hat_train <- exp(sigma2_x_train_raw) if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) } diff --git a/R/cpp11.R b/R/cpp11.R index 943206f1..39802efe 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -640,6 +640,10 @@ forest_tracker_cpp <- function(data, feature_types, num_trees, n) { .Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n) } +get_cached_forest_predictions_cpp <- function(tracker_ptr) { + .Call(`_stochtree_get_cached_forest_predictions_cpp`, tracker_ptr) +} + sample_without_replacement_integer_cpp <- function(population_vector, sampling_probs, sample_size) { .Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size) } diff --git a/R/model.R b/R/model.R index 955037b0..5b003055 100644 --- a/R/model.R +++ b/R/model.R @@ -126,6 +126,13 @@ ForestModel <- R6::R6Class( } }, + #' @description + #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. + #' @return Vector with as many elements as observations in the training dataset + get_cached_forest_predictions = function() { + get_cached_forest_predictions_cpp(self$tracker_ptr) + }, + #' @description #' Propagates basis update through to the (full/partial) residual by iteratively #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 56b6c2e6..6546b593 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -91,6 +91,10 @@ class ForestTracker { SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();} UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();} SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();} + int GetNumObservations() {return num_observations_;} + int GetNumTrees() {return num_trees_;} + int GetNumFeatures() {return num_features_;} + bool Initialized() {return initialized_;} private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index ad1181d5..3bb7a1db 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -22,6 +22,7 @@ trees, and exposes functionality to run a forest sampler \itemize{ \item \href{#method-ForestModel-new}{\code{ForestModel$new()}} \item \href{#method-ForestModel-sample_one_iteration}{\code{ForestModel$sample_one_iteration()}} +\item \href{#method-ForestModel-get_cached_forest_predictions}{\code{ForestModel$get_cached_forest_predictions()}} \item \href{#method-ForestModel-propagate_basis_update}{\code{ForestModel$propagate_basis_update()}} \item \href{#method-ForestModel-propagate_residual_update}{\code{ForestModel$propagate_residual_update()}} \item \href{#method-ForestModel-update_alpha}{\code{ForestModel$update_alpha()}} @@ -121,6 +122,19 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModel-get_cached_forest_predictions}{}}} +\subsection{Method \code{get_cached_forest_predictions()}}{ +Extract an internally-cached prediction of a forest on the training dataset in a sampler. +\subsection{Usage}{ +\if{html}{\out{
<div class="r">}}\preformatted{ForestModel$get_cached_forest_predictions()}\if{html}{\out{</div>
}} +} + +\subsection{Returns}{ +Vector with as many elements as observations in the training dataset +} +} +\if{html}{\out{<hr>
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestModel-propagate_basis_update}{}}} \subsection{Method \code{propagate_basis_update()}}{ diff --git a/src/container.cpp b/src/container.cpp index db10e53b..0d7d3548 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -206,6 +206,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) { CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees")); CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension")); CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant")); + CHECK_EQ(this->is_exponentiated_, forest_container_json.at("is_exponentiated")); CHECK_EQ(this->initialized_, forest_container_json.at("initialized")); int new_num_samples = forest_container_json.at("num_samples"); @@ -215,8 +216,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) { for (int i = 0; i < forest_container_json.at("num_samples"); i++) { forest_ind = this->num_samples_ + i; forest_label = "forest_" + std::to_string(i); - // forests_[forest_ind] = std::make_unique(this->num_trees_, this->output_dimension_, this->is_leaf_constant_); - forests_.push_back(std::make_unique(this->num_trees_, this->output_dimension_, this->is_leaf_constant_)); + forests_.push_back(std::make_unique(this->num_trees_, this->output_dimension_, this->is_leaf_constant_, this->is_exponentiated_)); forests_[forest_ind]->from_json(forest_container_json.at(forest_label)); } this->num_samples_ += new_num_samples; diff --git a/src/cpp11.cpp b/src/cpp11.cpp index aea80bc6..873b0c25 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1187,6 +1187,13 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX END_CPP11 } // sampler.cpp +cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer tracker_ptr); +extern "C" SEXP _stochtree_get_cached_forest_predictions_cpp(SEXP tracker_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(get_cached_forest_predictions_cpp(cpp11::as_cpp>>(tracker_ptr))); + END_CPP11 +} +// sampler.cpp cpp11::writable::integers sample_without_replacement_integer_cpp(cpp11::integers population_vector, cpp11::doubles sampling_probs, int sample_size); extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP population_vector, SEXP sampling_probs, SEXP sample_size) { BEGIN_CPP11 @@ -1539,6 +1546,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, + {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 5ff25e72..f90f5cc6 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1166,6 +1166,16 @@ class ForestSamplerCpp { } } + py::array_t GetCachedForestPredictions() { + int n_train = tracker_->GetNumObservations(); + auto output 
= py::array_t(py::detail::any_container({n_train})); + auto accessor = output.mutable_unchecked<1>(); + for (size_t i = 0; i < n_train; i++) { + accessor(i) = tracker_->GetSamplePrediction(i); + } + return output; + } + void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) { // Perform the update operation StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble()); @@ -2147,6 +2157,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest) .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration) .def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel) + .def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions) .def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate) .def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate) .def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha) diff --git a/src/sampler.cpp b/src/sampler.cpp index 1a5a5bb5..af45d6d6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -284,6 +284,16 @@ cpp11::external_pointer forest_tracker_cpp(cpp11::exte return cpp11::external_pointer(tracker_ptr_.release()); } +[[cpp11::register]] +cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer tracker_ptr) { + int n_train = tracker_ptr->GetNumObservations(); + cpp11::writable::doubles output(n_train); + for (int i = 0; i < n_train; i++) { + output[i] = tracker_ptr->GetSamplePrediction(i); + } + return output; +} + [[cpp11::register]] cpp11::writable::integers sample_without_replacement_integer_cpp( cpp11::integers population_vector, diff --git a/stochtree/bart.py b/stochtree/bart.py index 1b51491e..b9040eda 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1005,6 +1005,10 @@ def sample( self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) if sample_sigma2_leaf: self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) + if self.include_mean_forest: + yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + if self.include_variance_forest: + sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) sample_counter = -1 # Forest Dataset (covariates and optional basis) @@ -1187,6 +1191,10 @@ def sample( True, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + # Sample the variance forest if self.include_variance_forest: forest_sampler_variance.sample_one_iteration( @@ -1201,6 +1209,10 @@ def sample( True, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + # Sample variance parameters (if requested) if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -1379,6 +1391,9 @@ def sample( False, ) + if keep_sample: + yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + # Sample the variance forest if self.include_variance_forest: forest_sampler_variance.sample_one_iteration( @@ -1393,6 +1408,9 @@ def sample( False, ) + if keep_sample: + sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + # Sample variance parameters (if requested) if 
self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -1441,6 +1459,10 @@ def sample( self.global_var_samples = self.global_var_samples[num_gfr:] if self.sample_sigma2_leaf: self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] + if self.include_mean_forest: + yhat_train_raw = yhat_train_raw[:,num_gfr:] + if self.include_variance_forest: + sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] self.num_samples -= num_gfr # Store predictions @@ -1451,9 +1473,6 @@ def sample( self.leaf_scale_samples = self.leaf_scale_samples if self.include_mean_forest: - yhat_train_raw = self.forest_container_mean.forest_container_cpp.Predict( - forest_dataset_train.dataset_cpp - ) self.y_hat_train = yhat_train_raw * self.y_std + self.y_bar if self.has_test: yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict( @@ -1482,20 +1501,15 @@ def sample( self.y_hat_test = rfx_preds_test if self.include_variance_forest: - sigma2_x_train_raw = ( - self.forest_container_variance.forest_container_cpp.Predict( - forest_dataset_train.dataset_cpp - ) - ) if self.sample_sigma2_global: - self.sigma2_x_train = sigma2_x_train_raw + self.sigma2_x_train = np.empty_like(sigma2_x_train_raw) for i in range(self.num_samples): self.sigma2_x_train[:, i] = ( - sigma2_x_train_raw[:, i] * self.global_var_samples[i] + np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i] ) else: self.sigma2_x_train = ( - sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -1621,14 +1635,14 @@ def predict( ) ) if self.sample_sigma2_global: - variance_pred = variance_pred_raw + variance_pred = np.empty_like(variance_pred_raw) for i in range(self.num_samples): - variance_pred[:, i] = np.sqrt( + variance_pred[:, i] = ( variance_pred_raw[:, i] * self.global_var_samples[i] ) else: variance_pred = ( - np.sqrt(variance_pred_raw * self.sigma2_init) * self.y_std + variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) has_mean_predictions = self.include_mean_forest or self.has_rfx @@ -1810,7 +1824,7 @@ def predict_variance(self, covariates: np.array) -> np.array: pred_dataset.dataset_cpp ) if self.sample_sigma2_global: - variance_pred = variance_pred_raw + variance_pred = np.empty_like(variance_pred_raw) for i in range(self.num_samples): variance_pred[:, i] = ( variance_pred_raw[:, i] * self.global_var_samples[i] @@ -2017,11 +2031,11 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: for i in range(len(json_object_list)): if i == 0: self.forest_container_variance.forest_container_cpp.LoadFromJson( - json_object_list[i].json_cpp, "forest_1" + json_object_list[i].json_cpp, "forest_0" ) else: self.forest_container_variance.forest_container_cpp.AppendFromJson( - json_object_list[i].json_cpp, "forest_1" + json_object_list[i].json_cpp, "forest_0" ) # Unpack random effects @@ -2046,13 +2060,19 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_gfr = json_object_default.get_integer("num_gfr") self.num_burnin = json_object_default.get_integer("num_burnin") self.num_mcmc = json_object_default.get_integer("num_mcmc") - self.num_samples = json_object_default.get_integer("num_samples") self.num_basis = json_object_default.get_integer("num_basis") self.has_basis = json_object_default.get_boolean("requires_basis") self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) + # Unpack number of samples 
+ for i in range(len(json_object_list)): + if i == 0: + self.num_samples = json_object_list[i].get_integer("num_samples") + else: + self.num_samples += json_object_list[i].get_integer("num_samples") + # Unpack parameter samples if self.sample_sigma2_global: for i in range(len(json_object_list)): diff --git a/stochtree/bcf.py b/stochtree/bcf.py index c4b67232..ce2c5531 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1480,6 +1480,10 @@ def sample( self.leaf_scale_mu_samples = np.empty(self.num_samples, dtype=np.float64) if sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64) + muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + if self.include_variance_forest: + sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) sample_counter = -1 # Prepare adaptive coding structure @@ -1692,6 +1696,10 @@ def sample( True, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + # Sample variance parameters (if requested) if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -1725,6 +1733,10 @@ def sample( True, ) + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... + # Sample coding parameters (if requested) if self.adaptive_coding: mu_x = active_forest_mu.predict_raw(forest_dataset_train) @@ -1782,6 +1794,10 @@ def sample( True, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + # Sample variance parameters (if requested) if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -1873,6 +1889,10 @@ def sample( False, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + # Sample variance parameters (if requested) if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -1906,6 +1926,10 @@ def sample( False, ) + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... 
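The same caveat applies in the Python sampler: tau is recomputed from the forest container after sampling. Note also that the cached variance-forest values (held in sigma2_x_train_raw, allocated above) are raw leaf totals on the log scale; the forest container's predict() exponentiates internally (see the is_exponentiated_ flag in container.cpp), so cached values must be exponentiated before use. In R notation, a sketch of the conversion this patch applies; the column-wise global-variance rescaling is an assumption mirroring the Python loop np.exp(...) * global_var_samples[i]:

    # Cached log-scale variance-forest values -> variance scale
    sigma2_x_hat_train <- exp(variance_forest_pred_train)
    # If a global error variance was sampled, rescale each retained draw by it
    sigma2_x_hat_train <- sweep(sigma2_x_hat_train, 2, global_var_samples, "*")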
+ # Sample coding parameters (if requested) if self.adaptive_coding: mu_x = active_forest_mu.predict_raw(forest_dataset_train) @@ -1963,6 +1987,10 @@ def sample( True, ) + # Cache train set predictions since they are already computed during sampling + if keep_sample: + sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + # Sample variance parameters (if requested) if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( @@ -2018,13 +2046,13 @@ def sample( self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:] if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:] + muhat_train_raw = muhat_train_raw[:,num_gfr:] + if self.include_variance_forest: + sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] self.num_samples -= num_gfr # Store predictions - mu_raw = self.forest_container_mu.forest_container_cpp.Predict( - forest_dataset_train.dataset_cpp - ) - self.mu_hat_train = mu_raw * self.y_std + self.y_bar + self.mu_hat_train = muhat_train_raw * self.y_std + self.y_bar tau_raw_train = self.forest_container_tau.forest_container_cpp.PredictRaw( forest_dataset_train.dataset_cpp ) @@ -2080,21 +2108,29 @@ def sample( if self.has_test: self.y_hat_test = self.y_hat_test + rfx_preds_test + if self.sample_sigma2_global: + self.global_var_samples = self.global_var_samples * self.y_std * self.y_std + + if self.sample_sigma2_leaf_mu: + self.leaf_scale_mu_samples = self.leaf_scale_mu_samples + + if self.sample_sigma2_leaf_tau: + self.leaf_scale_tau_samples = self.leaf_scale_tau_samples + + if self.adaptive_coding: + self.b0_samples = self.b0_samples + self.b1_samples = self.b1_samples + if self.include_variance_forest: - sigma2_x_train_raw = ( - self.forest_container_variance.forest_container_cpp.Predict( - forest_dataset_train.dataset_cpp - ) - ) if self.sample_sigma2_global: - self.sigma2_x_train = sigma2_x_train_raw + self.sigma2_x_train = np.empty_like(sigma2_x_train_raw) for i in range(self.num_samples): self.sigma2_x_train[:, i] = ( - sigma2_x_train_raw[:, i] * self.global_var_samples[i] + np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i] ) else: self.sigma2_x_train = ( - sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -2103,7 +2139,7 @@ def sample( ) ) if self.sample_sigma2_global: - self.sigma2_x_test = sigma2_x_test_raw + self.sigma2_x_test = np.empty_like(sigma2_x_test_raw) for i in range(self.num_samples): self.sigma2_x_test[:, i] = ( sigma2_x_test_raw[:, i] * self.global_var_samples[i] @@ -2113,19 +2149,6 @@ def sample( sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) - if self.sample_sigma2_global: - self.global_var_samples = self.global_var_samples * self.y_std * self.y_std - - if self.sample_sigma2_leaf_mu: - self.leaf_scale_mu_samples = self.leaf_scale_mu_samples - - if self.sample_sigma2_leaf_tau: - self.leaf_scale_tau_samples = self.leaf_scale_tau_samples - - if self.adaptive_coding: - self.b0_samples = self.b0_samples - self.b1_samples = self.b1_samples - def predict_tau( self, X: np.array, Z: np.array, propensity: np.array = None ) -> np.array: @@ -2311,7 +2334,7 @@ def predict_variance( pred_dataset.dataset_cpp ) if self.sample_sigma2_global: - variance_pred = variance_pred_raw + variance_pred = np.empty_like(variance_pred_raw) for i in range(self.num_samples): variance_pred[:, i] = ( 
variance_pred_raw[:, i] * self.global_var_samples[i] @@ -2463,7 +2486,7 @@ def predict( forest_dataset_test.dataset_cpp ) if self.sample_sigma2_global: - sigma2_x = sigma2_x_raw + sigma2_x = np.empty_like(sigma2_x_raw) for i in range(self.num_samples): sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] else: @@ -2736,7 +2759,6 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_gfr = json_object_default.get_scalar("num_gfr") self.num_burnin = json_object_default.get_scalar("num_burnin") self.num_mcmc = json_object_default.get_scalar("num_mcmc") - self.num_samples = json_object_default.get_scalar("num_samples") self.adaptive_coding = json_object_default.get_boolean("adaptive_coding") self.propensity_covariate = json_object_default.get_string( "propensity_covariate" @@ -2744,6 +2766,13 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.internal_propensity_model = json_object_default.get_boolean( "internal_propensity_model" ) + + # Unpack number of samples + for i in range(len(json_object_list)): + if i == 0: + self.num_samples = json_object_list[i].get_integer("num_samples") + else: + self.num_samples += json_object_list[i].get_integer("num_samples") # Unpack parameter samples if self.sample_sigma2_global: diff --git a/stochtree/sampler.py b/stochtree/sampler.py index be55286a..8ac4f013 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -266,6 +266,17 @@ def propagate_basis_update( self.forest_sampler_cpp.PropagateBasisUpdate( dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp ) + + def get_cached_forest_predictions(self) -> np.array: + """ + Extract an internally-cached prediction of a forest on the training dataset in a sampler. + + Returns + ---------- + np.array + Numpy 1D array with as many elements as observations in the training dataset + """ + return self.forest_sampler_cpp.GetCachedForestPredictions() def update_alpha(self, alpha: float) -> None: """ diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 325bdbcf..88cbcd6a 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -291,3 +291,48 @@ test_that("Warmstart BART", { general_params = general_param_list) ) }) + +test_that("BART Predictions", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_params <- list(num_chains = 1) + variance_forest_params <- list(num_trees = 50) + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict(bart_model, X = X_train) + train_preds_mean_cached <- bart_model$y_hat_train + train_preds_mean_recomputed <- train_preds$mean_forest_predictions + train_preds_variance_cached <- 
bart_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) +}) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 24fabcd1..6f0a9ce8 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -426,4 +426,71 @@ test_that("Multivariate Treatment MCMC BCF", { propensity_test = pi_test, num_gfr = 0, num_burnin = 10, num_mcmc = 10, general_params = general_param_list) ) -}) \ No newline at end of file +}) + +test_that("BCF Predictions", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_params <- list(num_chains = 1, keep_every = 1) + variance_forest_params <- list(num_trees = 50) + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, + num_mcmc = 10, general_params = general_params, + variance_forest_params = variance_forest_params) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict(bcf_model, X = X_train, Z = Z_train, propensity = pi_train) + train_preds_mean_cached <- bcf_model$y_hat_train + train_preds_mean_recomputed <- train_preds$y_hat + train_preds_variance_cached <- bcf_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) +}) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index a2f2e64c..31962891 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -406,14 +406,21 @@ def conditional_stddev(X): bart_model_3.from_json_string_list(bart_models_json) # Assertions - y_hat_train_combined, _ = bart_model_3.predict(covariates=X_train) + y_hat_train_combined, sigma2_x_train_combined = bart_model_3.predict(covariates=X_train) assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + assert sigma2_x_train_combined.shape == (n_train, 
num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train ) np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) + np.testing.assert_allclose( + sigma2_x_train_combined[:, 0:num_mcmc], bart_model.sigma2_x_train + ) + np.testing.assert_allclose( + sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train + ) np.testing.assert_allclose( bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples ) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index 96f25a34..65f39390 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -577,3 +577,194 @@ def test_multivariate_bcf(self): num_mcmc=num_mcmc, variance_forest_params=variance_forest_params, ) + + def test_binary_bcf_heteroskedastic(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 5 + X = rng.uniform(0, 1, (n, p_X)) + pi_X = 0.25 + 0.5 * X[:, 0] + Z = rng.binomial(1, pi_X, n).astype(float) + + # Define the outcome mean functions (prognostic and treatment effects) + mu_X = pi_X * 5 + tau_X = X[:, 1] * 2 + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = mu_X + tau_X * Z + epsilon + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + Z_train = Z[train_inds] + Z_test = Z[test_inds] + y_train = y[train_inds] + pi_train = pi_X[train_inds] + pi_test = pi_X[test_inds] + n_train = X_train.shape[0] + n_test = X_test.shape[0] + + # BCF settings + num_gfr = 10 + num_burnin = 0 + num_mcmc = 10 + + # Run BCF with test set and propensity score + bcf_model = BCFModel() + variance_forest_params = {"num_trees": 50} + bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_train, + X_test=X_test, + Z_test=Z_test, + pi_test=pi_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.sigma2_x_train.shape == (n_train, num_mcmc) + assert bcf_model.y_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.mu_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.tau_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.sigma2_x_test.shape == (n_test, num_mcmc) + + # Check overall prediction method + tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test, pi_test) + assert tau_hat.shape == (n_test, num_mcmc) + assert mu_hat.shape == (n_test, num_mcmc) + assert y_hat.shape == (n_test, num_mcmc) + assert sigma2_hat.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + assert tau_hat.shape == (n_test, num_mcmc) + + # Run BCF without test set and with propensity score + bcf_model = BCFModel() + variance_forest_params = {"num_trees": 50} + bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_train, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model.y_hat_train.shape == (n_train, 
num_mcmc) + assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.sigma2_x_train.shape == (n_train, num_mcmc) + + # Check overall prediction method + tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test, pi_test) + assert tau_hat.shape == (n_test, num_mcmc) + assert mu_hat.shape == (n_test, num_mcmc) + assert y_hat.shape == (n_test, num_mcmc) + assert sigma2_hat.shape == (n_test, num_mcmc) + + # Check predictions match + tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_train, Z_train, pi_train) + assert tau_hat.shape == (n_train, num_mcmc) + assert mu_hat.shape == (n_train, num_mcmc) + assert y_hat.shape == (n_train, num_mcmc) + assert sigma2_hat.shape == (n_train, num_mcmc) + np.testing.assert_allclose( + y_hat, bcf_model.y_hat_train + ) + np.testing.assert_allclose( + mu_hat, bcf_model.mu_hat_train + ) + np.testing.assert_allclose( + tau_hat, bcf_model.tau_hat_train + ) + np.testing.assert_allclose( + sigma2_hat, bcf_model.sigma2_x_train + ) + + # Check treatment effect prediction method + tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + assert tau_hat.shape == (n_test, num_mcmc) + + # Run BCF with test set and without propensity score + bcf_model = BCFModel() + variance_forest_params = {"num_trees": 50} + bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + X_test=X_test, + Z_test=Z_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.bart_propensity_model.y_hat_train.shape == (n_train, 10) + assert bcf_model.y_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.mu_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.tau_hat_test.shape == (n_test, num_mcmc) + assert bcf_model.bart_propensity_model.y_hat_test.shape == (n_test, 10) + + # Check overall prediction method + tau_hat, mu_hat, y_hat, sigma2_hat = bcf_model.predict(X_test, Z_test) + assert tau_hat.shape == (n_test, num_mcmc) + assert mu_hat.shape == (n_test, num_mcmc) + assert y_hat.shape == (n_test, num_mcmc) + assert sigma2_hat.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat = bcf_model.predict_tau(X_test, Z_test) + assert tau_hat.shape == (n_test, num_mcmc) + + # Run BCF without test set and without propensity score + bcf_model = BCFModel() + variance_forest_params = {"num_trees": 0} + bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model.bart_propensity_model.y_hat_train.shape == (n_train, 10) + + # Check overall prediction method + tau_hat, mu_hat, y_hat = bcf_model.predict(X_test, Z_test) + assert tau_hat.shape == (n_test, num_mcmc) + assert mu_hat.shape == (n_test, num_mcmc) + assert y_hat.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat = bcf_model.predict_tau(X_test, Z_test) + assert tau_hat.shape == (n_test, num_mcmc) diff --git a/tools/perf/bart_profiling_script.R b/tools/perf/bart_profiling_script.R new 
file mode 100644 index 00000000..7a60eed2 --- /dev/null +++ b/tools/perf/bart_profiling_script.R @@ -0,0 +1,57 @@ +# Load libraries +library(stochtree) + +# Capture command line arguments +args <- commandArgs(trailingOnly = T) +if (length(args) > 0){ + n <- as.integer(args[1]) + p <- as.integer(args[2]) + num_gfr <- as.integer(args[3]) + num_mcmc <- as.integer(args[4]) + snr <- as.numeric(args[5]) +} else{ + # Default arguments + n <- 1000 + p <- 5 + num_gfr <- 10 + num_mcmc <- 100 + snr <- 3.0 +} +cat("n = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr, + "\nnum_mcmc = ", num_mcmc, "\nsnr = ", snr, "\n", sep = "") + +# Generate data needed to train BART model +X <- matrix(runif(n*p), ncol = p) +plm_term <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) +) +trig_term <- ( + 2*sin(X[,3]*2*pi) - + 1.5*cos(X[,4]*2*pi) +) +f_XW <- plm_term + trig_term +noise_sd <- sd(f_XW)/snr +y <- f_XW + rnorm(n, 0, noise_sd) + +# Split into train and test sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] + +system.time({ + # Sample BART model + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_mcmc = num_mcmc) + + # Predict on the test set + test_preds <- predict(bart_model, X = X_test) +}) \ No newline at end of file
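The script reads five positional arguments in the order n, p, num_gfr, num_mcmc, snr, falling back to the defaults above when none are supplied. Example invocations (the argument values here are illustrative only):

    # Defaults: n = 1000, p = 5, num_gfr = 10, num_mcmc = 100, snr = 3.0
    #   Rscript tools/perf/bart_profiling_script.R
    # Larger, hypothetical problem size for profiling:
    #   Rscript tools/perf/bart_profiling_script.R 10000 20 10 100 2.0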