From 51c7f40551bb21b914b306f1323e7d809631a5bf Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 19 Nov 2025 23:50:55 -0600 Subject: [PATCH 01/11] Switched order of python arguments to match R interface --- stochtree/bart.py | 6 +++--- stochtree/bcf.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 656bcba7..fe1a851e 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -82,12 +82,12 @@ def sample( num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, random_effects_params: Optional[Dict[str, Any]] = None, - previous_model_json: Optional[str] = None, - previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -2194,8 +2194,8 @@ def compute_contrast( def compute_posterior_interval( self, terms: Union[list[str], str] = "all", - scale: str = "linear", level: float = 0.95, + scale: str = "linear", covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 7f017c29..98442965 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -95,13 +95,13 @@ def sample( num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, general_params: Optional[Dict[str, Any]] = None, prognostic_forest_params: Optional[Dict[str, Any]] = None, treatment_effect_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, random_effects_params: Optional[Dict[str, Any]] = None, - previous_model_json: Optional[str] = None, - previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -3261,8 +3261,8 @@ def compute_contrast( def compute_posterior_interval( self, terms: Union[list[str], str] = "all", - scale: str = "linear", level: float = 0.95, + scale: str = "linear", covariates: np.array = None, treatment: np.array = None, propensity: np.array = None, From 13de68fc3203677bb0eff8d97899c293ff928948 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 00:12:50 -0600 Subject: [PATCH 02/11] Harmonized BART sampler arguments between R and Python (and wherever they are called) --- R/bart.R | 26 ++++++------ demo/debug/bart_contrast_debug.py | 16 ++++---- demo/debug/bart_predict_debug.py | 8 ++-- demo/debug/gfr_ties_debug.py | 8 ++-- demo/debug/multi_chain.py | 4 +- demo/debug/multiple_initializations.py | 6 +-- demo/debug/parallel_multi_chain.py | 2 +- demo/debug/probit_bart_rfx_debug.py | 12 +++--- demo/debug/rfx_serialization.py | 4 +- demo/notebooks/multi_chain.ipynb | 8 ++-- man/predict.bartmodel.Rd | 4 +- stochtree/bart.py | 56 +++++++++++++------------- test/R/testthat/test-bart.R | 2 +- tools/debug/acic_bcf_surrogate_debug.R | 2 +- tools/debug/bart_contrast_debug.R | 8 ++-- tools/debug/bart_predict_debug.R | 4 +- tools/debug/gfr_ties_debug.R | 8 ++-- vignettes/MultiChain.Rmd | 8 ++-- 18 files changed, 93 insertions(+), 93 deletions(-) diff --git a/R/bart.R b/R/bart.R index e12eb041..6d6e2e2e 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1924,7 +1924,7 @@ bart <- function( #' Predict from a sampled BART model on new data #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. +#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels @@ -1964,7 +1964,7 @@ bart <- function( #' y_hat_test <- predict(bart_model, X_test)$y_hat predict.bartmodel <- function( object, - covariates, + X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -2047,8 +2047,8 @@ predict.bartmodel <- function( } # Check that covariates are matrix or data frame - if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { - stop("covariates must be a matrix or dataframe") + if ((!is.data.frame(X)) && (!is.matrix(X))) { + stop("X must be a matrix or dataframe") } # Convert all input data to matrices if not already converted @@ -2063,12 +2063,12 @@ predict.bartmodel <- function( if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { stop("Basis (leaf_basis) must be provided for this model") } - if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) { - stop("covariates and leaf_basis must have the same number of rows") + if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) { + stop("X and leaf_basis must have the same number of rows") } - if (object$model_params$num_covariates != ncol(covariates)) { + if (object$model_params$num_covariates != ncol(X)) { stop( - "covariates must contain the same number of columns as the BART model's training dataset" + "X must contain the same number of columns as the BART model's training dataset" ) } if ((predict_rfx) && (is.null(rfx_group_ids))) { @@ -2089,7 +2089,7 @@ predict.bartmodel <- function( # Preprocess covariates train_set_metadata <- object$train_set_metadata - covariates <- preprocessPredictionData(covariates, train_set_metadata) + X <- preprocessPredictionData(X, train_set_metadata) # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE @@ -2119,8 +2119,8 @@ predict.bartmodel <- function( # Only construct a basis if user-provided basis missing if (is.null(rfx_basis)) { rfx_basis <- matrix( - rep(1, nrow(covariates)), - nrow = nrow(covariates), + rep(1, nrow(X)), + nrow = nrow(X), ncol = 1 ) } @@ -2129,9 +2129,9 @@ predict.bartmodel <- function( # Create prediction dataset if (!is.null(leaf_basis)) { - prediction_dataset <- createForestDataset(covariates, leaf_basis) + prediction_dataset <- createForestDataset(X, leaf_basis) } else { - prediction_dataset <- createForestDataset(covariates) + prediction_dataset <- createForestDataset(X) } # Compute variance forest predictions diff --git a/demo/debug/bart_contrast_debug.py b/demo/debug/bart_contrast_debug.py index 15ce5705..bbfd82d4 100644 --- a/demo/debug/bart_contrast_debug.py +++ b/demo/debug/bart_contrast_debug.py @@ -65,15 +65,15 @@ # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), type="posterior", terms="y_hat", scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), type="posterior", terms="y_hat", scale="linear", @@ -157,8 +157,8 @@ # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -166,8 +166,8 @@ scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index d66b1110..d58c5ef1 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -46,11 +46,11 @@ ) # # Check several predict approaches -bart_preds = bart_model.predict(covariates=X_test) -y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"] -y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"]) +bart_preds = bart_model.predict(X=X_test) +y_hat_posterior_test = bart_model.predict(X=X_test)["y_hat"] +y_hat_mean_test = bart_model.predict(X=X_test, type="mean", terms=["y_hat"]) y_hat_test = bart_model.predict( - covariates=X_test, type="mean", terms=["rfx", "variance"] + X=X_test, type="mean", terms=["rfx", "variance"] ) # Plot predicted versus actual diff --git a/demo/debug/gfr_ties_debug.py b/demo/debug/gfr_ties_debug.py index fabc70b5..c69e3d67 100644 --- a/demo/debug/gfr_ties_debug.py +++ b/demo/debug/gfr_ties_debug.py @@ -38,7 +38,7 @@ ) # Inspect the model fit -y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat") plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) plt.xlabel("Predicted Outcome Mean") @@ -54,7 +54,7 @@ ) # Inspect the model fit -y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) @@ -95,7 +95,7 @@ ) # Inspect the model fit -y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) @@ -112,7 +112,7 @@ ) # Inspect the model fit -y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py index bb35ee9a..59d6e11d 100644 --- a/demo/debug/multi_chain.py +++ b/demo/debug/multi_chain.py @@ -89,8 +89,8 @@ def outcome_mean(X, W): # Analyze model predictions collectively across all chains y_hat_test = bart_model.predict( - covariates = X_test, - basis = basis_test, + X = X_test, + leaf_basis = basis_test, type = "mean", terms = "y_hat" ) diff --git a/demo/debug/multiple_initializations.py b/demo/debug/multiple_initializations.py index c499f45b..b489ee80 100644 --- a/demo/debug/multiple_initializations.py +++ b/demo/debug/multiple_initializations.py @@ -118,14 +118,14 @@ def outcome_mean(X, W): ) # Inspect the model outputs -bart_preds_2 = bart_model_2.predict(X_test, basis_test) +bart_preds_2 = bart_model_2.predict(X=X_test, basis_test) y_hat_mcmc_2 = bart_preds_2['y_hat'] y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) -bart_preds_3 = bart_model_3.predict(X_test, basis_test) +bart_preds_3 = bart_model_3.predict(X=X_test, basis_test) y_hat_mcmc_3 = bart_preds_3['y_hat'] y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True) -bart_preds_4 = bart_model_4.predict(X_test, basis_test) +bart_preds_4 = bart_model_4.predict(X=X_test, basis_test) y_hat_mcmc_4 = bart_preds_4['y_hat'] y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) y_df = pd.DataFrame( diff --git a/demo/debug/parallel_multi_chain.py b/demo/debug/parallel_multi_chain.py index ee618df5..e3148e5b 100644 --- a/demo/debug/parallel_multi_chain.py +++ b/demo/debug/parallel_multi_chain.py @@ -145,7 +145,7 @@ def outcome_mean(X, W): ) # Inspect the model outputs - bart_preds = combined_bart.predict(X_test, basis_test) + bart_preds = combined_bart.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc = bart_preds['y_hat'] y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True) y_df = pd.DataFrame( diff --git a/demo/debug/probit_bart_rfx_debug.py b/demo/debug/probit_bart_rfx_debug.py index ae2e8c10..bfe0be6c 100644 --- a/demo/debug/probit_bart_rfx_debug.py +++ b/demo/debug/probit_bart_rfx_debug.py @@ -86,8 +86,8 @@ # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -95,8 +95,8 @@ scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -111,8 +111,8 @@ # Plot predicted versus actual outcome Z_hat_test = bart_model.predict( - covariates=X_test, - basis=W_test, + X=X_test, + leaf_basis=W_test, rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="mean", diff --git a/demo/debug/rfx_serialization.py b/demo/debug/rfx_serialization.py index fec857b6..b6fc3d97 100644 --- a/demo/debug/rfx_serialization.py +++ b/demo/debug/rfx_serialization.py @@ -60,13 +60,13 @@ def rfx_mean(group_labels, basis): rfx_basis_train=basis, num_gfr=10, num_mcmc=10) # Extract predictions from the sampler -bart_preds_orig = bart_orig.predict(X, W, group_labels, basis) +bart_preds_orig = bart_orig.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis) y_hat_orig = bart_preds_orig['y_hat'] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) -bart_preds_reloaded = bart_reloaded.predict(X, W, group_labels, basis) +bart_preds_reloaded = bart_reloaded.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis) y_hat_reloaded = bart_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) \ No newline at end of file diff --git a/demo/notebooks/multi_chain.ipynb b/demo/notebooks/multi_chain.ipynb index 85aebd8e..afe4c741 100644 --- a/demo/notebooks/multi_chain.ipynb +++ b/demo/notebooks/multi_chain.ipynb @@ -161,8 +161,8 @@ "outputs": [], "source": [ "y_hat_test = bart_model.predict(\n", - " covariates = X_test,\n", - " basis = leaf_basis_test, \n", + " X = X_test,\n", + " leaf_basis = leaf_basis_test, \n", " type = \"mean\", \n", " terms = \"y_hat\"\n", ")\n", @@ -321,8 +321,8 @@ "outputs": [], "source": [ "y_hat_test = bart_model.predict(\n", - " covariates = X_test,\n", - " basis = leaf_basis_test, \n", + " X = X_test,\n", + " leaf_basis = leaf_basis_test, \n", " type = \"mean\", \n", " terms = \"y_hat\"\n", ")\n", diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 0cb82678..1c520be8 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -6,7 +6,7 @@ \usage{ \method{predict}{bartmodel}( object, - covariates, + X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -19,7 +19,7 @@ \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{covariates}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} +\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} \item{leaf_basis}{(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: \code{NULL}.} diff --git a/stochtree/bart.py b/stochtree/bart.py index fe1a851e..2988a5fd 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1743,8 +1743,8 @@ def sample( def predict( self, - covariates: Union[np.array, pd.DataFrame], - basis: np.array = None, + X: Union[np.array, pd.DataFrame], + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, type: str = "posterior", @@ -1757,9 +1757,9 @@ def predict( Parameters ---------- - covariates : np.array + X : np.array Test set covariates. - basis : np.array, optional + leaf_basis : np.array, optional Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. rfx_group_ids : np.array, optional Optional group labels used for an additive random effects model. @@ -1861,29 +1861,29 @@ def predict( raise NotSampledError(msg) # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray + if not isinstance(X, pd.DataFrame) and not isinstance( + X, np.ndarray ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - if basis is not None: - if not isinstance(basis, np.ndarray): - raise ValueError("basis must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + raise ValueError("X must be a pandas dataframe or numpy array") + if leaf_basis is not None: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("leaf_basis must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "covariates and basis must have the same number of rows" + "X and leaf_basis must have the same number of rows" ) # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if basis is not None: - if basis.ndim == 1: - basis = np.expand_dims(basis, 1) + if isinstance(X, np.ndarray): + if X.ndim == 1: + X = np.expand_dims(X, 1) + if leaf_basis is not None: + if leaf_basis.ndim == 1: + leaf_basis = np.expand_dims(leaf_basis, 1) # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): + if not isinstance(X, np.ndarray): raise ValueError( "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." ) @@ -1893,20 +1893,20 @@ def predict( RuntimeWarning, ) if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): + X.dtype, np.floating + ) and not np.issubdtype(X.dtype, np.integer): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." ) - covariates_processed = covariates + X_processed = X else: - covariates_processed = self._covariate_preprocessor.transform(covariates) + X_processed = self._covariate_preprocessor.transform(X) # Dataset construction pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) - if basis is not None: - pred_dataset.add_basis(basis) + pred_dataset.add_covariates(X_processed) + if leaf_basis is not None: + pred_dataset.add_basis(leaf_basis) # Variance forest predictions if predict_variance_forest: @@ -1946,7 +1946,7 @@ def predict( if rfx_basis is not None: if rfx_basis.ndim == 1: rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError("X and rfx_basis must have the same number of rows") if rfx_basis.shape[1] != self.num_rfx_basis: raise ValueError( @@ -1971,7 +1971,7 @@ def predict( rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std # Construct an array with the appropriate group random effects arranged for each observation - n_train = covariates.shape[0] + n_train = X.shape[0] if rfx_beta_draws.ndim != 2: raise ValueError( "BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices" diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 23013ec2..7009ad3a 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -433,7 +433,7 @@ test_that("BART Predictions", { ) # Check that cached predictions agree with results of predict() function - train_preds <- predict(bart_model, covariates = X_train) + train_preds <- predict(bart_model, X = X_train) train_preds_mean_cached <- bart_model$y_hat_train train_preds_mean_recomputed <- train_preds$mean_forest_predictions train_preds_variance_cached <- bart_model$sigma2_x_hat_train diff --git a/tools/debug/acic_bcf_surrogate_debug.R b/tools/debug/acic_bcf_surrogate_debug.R index 1678a8cb..32fece42 100644 --- a/tools/debug/acic_bcf_surrogate_debug.R +++ b/tools/debug/acic_bcf_surrogate_debug.R @@ -66,7 +66,7 @@ propensity_model <- stochtree::bart( ) propensity <- predict( propensity_model, - covariates = covariate_df, + X = covariate_df, type = "mean", terms = "y_hat" ) diff --git a/tools/debug/bart_contrast_debug.R b/tools/debug/bart_contrast_debug.R index 647d12b0..79f6fe10 100644 --- a/tools/debug/bart_contrast_debug.R +++ b/tools/debug/bart_contrast_debug.R @@ -56,7 +56,7 @@ contrast_posterior_test <- compute_contrast_bart_model( # Compute the same quantity via two predict calls y_hat_posterior_test_0 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(0, nrow = n_test, ncol = 1), type = "posterior", term = "y_hat", @@ -64,7 +64,7 @@ y_hat_posterior_test_0 <- predict( ) y_hat_posterior_test_1 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(1, nrow = n_test, ncol = 1), type = "posterior", term = "y_hat", @@ -143,7 +143,7 @@ contrast_posterior_test <- compute_contrast_bart_model( # Compute the same quantity via two predict calls y_hat_posterior_test_0 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(0, nrow = n_test, ncol = 1), rfx_group_ids = group_ids_test, rfx_basis = rfx_basis_test, @@ -153,7 +153,7 @@ y_hat_posterior_test_0 <- predict( ) y_hat_posterior_test_1 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(1, nrow = n_test, ncol = 1), rfx_group_ids = group_ids_test, rfx_basis = rfx_basis_test, diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index 89766a74..bb4c5c1e 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -119,7 +119,7 @@ y_hat_post <- predict( object = bart_model, type = "posterior", terms = c("y_hat"), - covariates = X_test, + X = X_test, scale = "linear" ) @@ -128,7 +128,7 @@ y_hat_post_prob <- predict( object = bart_model, type = "posterior", terms = c("y_hat"), - covariates = X_test, + X = X_test, scale = "probability" ) diff --git a/tools/debug/gfr_ties_debug.R b/tools/debug/gfr_ties_debug.R index 833bf533..757faaa6 100644 --- a/tools/debug/gfr_ties_debug.R +++ b/tools/debug/gfr_ties_debug.R @@ -38,7 +38,7 @@ xbart_model <- bart( # Inspect the model fit y_hat_test <- predict( xbart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -57,7 +57,7 @@ bart_model <- bart( # Inspect the model fit y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -100,7 +100,7 @@ xbart_model <- bart( # Inspect the model fit y_hat_test <- predict( xbart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -119,7 +119,7 @@ bart_model <- bart( # Inspect the model fit y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd index bedf5220..83c60e6c 100644 --- a/vignettes/MultiChain.Rmd +++ b/vignettes/MultiChain.Rmd @@ -129,7 +129,7 @@ predictions. ```{r} y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = leaf_basis_test, type = "mean", terms = "y_hat" @@ -217,7 +217,7 @@ abs_test_set_resid <- abs(y_test - y_hat_test) top5_resids <- order(abs_test_set_resid, decreasing = T)[1:5] y_hat_test_posterior <- predict( bart_model, - covariates = X_test[top5_resids, ], + X = X_test[top5_resids, ], leaf_basis = leaf_basis_test[top5_resids], type = "posterior", terms = "y_hat" @@ -345,7 +345,7 @@ predictions. ```{r} y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = leaf_basis_test, type = "mean", terms = "y_hat" @@ -433,7 +433,7 @@ abs_test_set_resid <- abs(y_test - y_hat_test) top5_resids <- order(abs_test_set_resid, decreasing = T)[1:5] y_hat_test_posterior <- predict( bart_model, - covariates = X_test[top5_resids, ], + X = X_test[top5_resids, ], leaf_basis = leaf_basis_test[top5_resids], type = "posterior", terms = "y_hat" From 078ed72e48be778cd4974d846a74b2f169d7d43e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 00:24:37 -0600 Subject: [PATCH 03/11] Updated predict interface in R and Python and tests / demos that call it --- R/bart.R | 4 +- R/kernel.R | 2 +- man/createBARTModelFromJsonString.Rd | 2 +- man/predict.bartmodel.Rd | 2 +- test/R/testthat/test-bart.R | 2 +- test/R/testthat/test-predict.R | 14 +- test/R/testthat/test-serialization.R | 4 +- tools/debug/bart_predict_debug.R | 10 +- tools/debug/parallel_warmstart.R | 180 +++++++++++++++++------- tools/debug/parallel_warmstart_bcf.R | 200 +++++++++++++++++++-------- 10 files changed, 293 insertions(+), 127 deletions(-) diff --git a/R/bart.R b/R/bart.R index 6d6e2e2e..47ee41e0 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1961,7 +1961,7 @@ bart <- function( #' y_train <- y[train_inds] #' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) -#' y_hat_test <- predict(bart_model, X_test)$y_hat +#' y_hat_test <- predict(bart_model, X=X_test)$y_hat predict.bartmodel <- function( object, X, @@ -2843,7 +2843,7 @@ createBARTModelFromJsonFile <- function(json_filename) { #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJsonString(bart_model) #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) -#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) +#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat) createBARTModelFromJsonString <- function(json_string) { # Load a `CppJson` object from string bart_json <- createCppJsonString(json_string) diff --git a/R/kernel.R b/R/kernel.R index 2b643b98..7ab21370 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -129,7 +129,7 @@ computeForestLeafIndices <- function( propensity <- rowMeans( predict( model_object$bart_propensity_model, - covariates + X = covariates )$y_hat ) } diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 0748d97a..7a09d9c9 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -42,5 +42,5 @@ bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) -y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) +y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 1c520be8..c1bdfd09 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -66,5 +66,5 @@ y_test <- y[test_inds] y_train <- y[train_inds] bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) -y_hat_test <- predict(bart_model, X_test)$y_hat +y_hat_test <- predict(bart_model, X=X_test)$y_hat } diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 7009ad3a..36b7bfe8 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -584,7 +584,7 @@ test_that("Random Effects BART", { ) preds <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index bdd9d66b..63ff0f94 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -216,12 +216,12 @@ test_that("BART predictions with pre-summarization", { ) # Check that the default predict method returns a list - pred <- predict(bart_model, X_test) + pred <- predict(bart_model, X = X_test) y_hat_posterior_test <- pred$y_hat expect_equal(dim(y_hat_posterior_test), c(20, 10)) # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(bart_model, X_test, type = "mean") + pred_mean <- predict(bart_model, X = X_test, type = "mean") y_hat_mean_test <- pred_mean$y_hat expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) @@ -229,7 +229,7 @@ test_that("BART predictions with pre-summarization", { expect_warning({ pred_mean <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("rfx", "variance_forest") ) @@ -248,7 +248,7 @@ test_that("BART predictions with pre-summarization", { ) # Check that the default predict method returns a list - pred <- predict(het_bart_model, X_test) + pred <- predict(het_bart_model, X = X_test) y_hat_posterior_test <- pred$y_hat sigma2_hat_posterior_test <- pred$variance_forest_predictions @@ -257,7 +257,7 @@ test_that("BART predictions with pre-summarization", { expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(het_bart_model, X_test, type = "mean") + pred_mean <- predict(het_bart_model, X = X_test, type = "mean") y_hat_mean_test <- pred_mean$y_hat sigma2_hat_mean_test <- pred_mean$variance_forest_predictions @@ -269,13 +269,13 @@ test_that("BART predictions with pre-summarization", { # match those computed by pre-aggregated predictions returned in a list y_hat_mean_test_single_term <- predict( het_bart_model, - X_test, + X = X_test, type = "mean", terms = "y_hat" ) sigma2_hat_mean_test_single_term <- predict( het_bart_model, - X_test, + X = X_test, type = "mean", terms = "variance_forest" ) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index fe50af5f..2f0d4aaa 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -34,7 +34,7 @@ test_that("BART Serialization", { num_mcmc = 10, general_params = general_param_list ) - y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) + y_hat_orig <- rowMeans(predict(bart_model, X = X_test)$y_hat) # Save to JSON bart_json_string <- saveBARTModelToJsonString(bart_model) @@ -43,7 +43,7 @@ test_that("BART Serialization", { bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) # Predict from the roundtrip BART model - y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) + y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X = X_test)$y_hat) # Assertion expect_equal(y_hat_orig, y_hat_reloaded) diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index bb4c5c1e..49bea0b6 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -38,16 +38,16 @@ bart_model <- bart( ) # Check several predict approaches -y_hat_posterior_test <- predict(bart_model, X_test)$y_hat +y_hat_posterior_test <- predict(bart_model, X = X_test)$y_hat y_hat_mean_test <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("y_hat") ) y_hat_test <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("rfx", "variance") ) @@ -117,18 +117,18 @@ bart_model <- bart( # Predict on latent scale y_hat_post <- predict( object = bart_model, + X = X_test, type = "posterior", terms = c("y_hat"), - X = X_test, scale = "linear" ) # Predict on probability scale y_hat_post_prob <- predict( object = bart_model, + X = X_test, type = "posterior", terms = c("y_hat"), - X = X_test, scale = "probability" ) diff --git a/tools/debug/parallel_warmstart.R b/tools/debug/parallel_warmstart.R index 18b73574..243180db 100644 --- a/tools/debug/parallel_warmstart.R +++ b/tools/debug/parallel_warmstart.R @@ -14,35 +14,55 @@ num_trees <- 100 n <- 500 p_x <- 20 snr <- 2 -X <- matrix(runif(n*p_x), ncol = p_x) -f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) +cos(4*pi*X[,4]) +X <- matrix(runif(n * p_x), ncol = p_x) +f_XW <- sin(4 * pi * X[, 1]) + + cos(4 * pi * X[, 2]) + + sin(4 * pi * X[, 3]) + + cos(4 * pi * X[, 4]) noise_sd <- sd(f_XW) / snr -y <- f_XW + rnorm(n, 0, 1)*noise_sd +y <- f_XW + rnorm(n, 0, 1) * noise_sd # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- as.data.frame(X[test_inds,]) -X_train <- as.data.frame(X[train_inds,]) +X_test <- as.data.frame(X[test_inds, ]) +X_train <- as.data.frame(X[train_inds, ]) y_test <- y[test_inds] y_train <- y[train_inds] # Run the GFR algorithm -xbart_params <- list(sample_sigma_global = T, - num_trees_mean = num_trees, alpha_mean = 0.99, - beta_mean = 1, max_depth_mean = -1, - min_samples_leaf_mean = 1, sample_sigma_leaf = F, - sigma_leaf_init = 1/num_trees) +xbart_params <- list( + sample_sigma_global = T, + num_trees_mean = num_trees, + alpha_mean = 0.99, + beta_mean = 1, + max_depth_mean = -1, + min_samples_leaf_mean = 1, + sample_sigma_leaf = F, + sigma_leaf_init = 1 / num_trees +) xbart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + params = xbart_params ) -plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1) +plot(rowMeans(xbart_model$y_hat_test), y_test) +abline(0, 1) cat(sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) # Parallel setup @@ -51,20 +71,32 @@ cl <- makeCluster(ncores) registerDoParallel(cl) # Run the parallel BART MCMC samplers -bart_model_outputs <- foreach (i = 1:num_chains) %dopar% { +bart_model_outputs <- foreach(i = 1:num_chains) %dopar% + { random_seed <- i - bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, random_seed = random_seed, - alpha_mean = 0.999, beta_mean = 1) + bart_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + random_seed = random_seed, + alpha_mean = 0.999, + beta_mean = 1 + ) bart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params, - previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1, + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + params = bart_params, + previous_model_json = xbart_model_string, + warmstart_sample_num = num_gfr - i + 1, ) bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) y_hat_test <- bart_model$y_hat_test - list(model=bart_model_string, yhat=y_hat_test) -} + list(model = bart_model_string, yhat = y_hat_test) + } # Close the cluster connection stopCluster(cl) @@ -73,43 +105,89 @@ stopCluster(cl) bart_model_strings <- list() bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) for (i in 1:length(bart_model_outputs)) { - bart_model_strings[[i]] <- bart_model_outputs[[i]]$model - bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat) + bart_model_strings[[i]] <- bart_model_outputs[[i]]$model + bart_model_yhats[, i] <- rowMeans(bart_model_outputs[[i]]$yhat) } combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) # Inspect the results -yhat_combined <- predict(combined_bart, X_test)$y_hat -par(mfrow = c(1,2)) +yhat_combined <- predict(combined_bart, X = X_test)$y_hat +par(mfrow = c(1, 2)) for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i], - xlab = "deserialized", ylab = "original", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + bart_model_yhats[, i], + xlab = "deserialized", + ylab = "original", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) } for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, - xlab = "predicted", ylab = "actual", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) - cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") - cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + y_test, + xlab = "predicted", + ylab = "actual", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) + cat( + sqrt(mean((rowMeans(yhat_combined[, inds_start:inds_end]) - y_test)^2)), + "\n" + ) + cat( + mean( + (apply(yhat_combined[, inds_start:inds_end], 1, quantile, probs = 0.05) <= + y_test) & + (apply( + yhat_combined[, inds_start:inds_end], + 1, + quantile, + probs = 0.95 + ) >= + y_test) + ), + "\n" + ) } -par(mfrow = c(1,1)) +par(mfrow = c(1, 1)) # Compare to a single chain of MCMC samples initialized at root -bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2) +bart_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + alpha_mean = 0.95, + beta_mean = 2 +) bart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bart_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + params = bart_params +) +plot( + rowMeans(bart_model$y_hat_test), + y_test, + xlab = "predicted", + ylab = "actual" ) -plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +abline(0, 1) cat(sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) diff --git a/tools/debug/parallel_warmstart_bcf.R b/tools/debug/parallel_warmstart_bcf.R index 9d002b32..1abf9213 100644 --- a/tools/debug/parallel_warmstart_bcf.R +++ b/tools/debug/parallel_warmstart_bcf.R @@ -16,28 +16,32 @@ n <- 500 x1 <- rnorm(n) x2 <- rnorm(n) x3 <- rnorm(n) -x4 <- rnorm(n,x2,1) -X <- cbind(x1,x2,x3,x4) +x4 <- rnorm(n, x2, 1) +X <- cbind(x1, x2, x3, x4) p <- ncol(X) -mu <- function(x) {-1*(x[,1]>(x[,2])) + 1*(x[,1]<(x[,2])) - 0.1} -tau <- function(x) {1/(1 + exp(-x[,3])) + x[,2]/10} +mu <- function(x) { + -1 * (x[, 1] > (x[, 2])) + 1 * (x[, 1] < (x[, 2])) - 0.1 +} +tau <- function(x) { + 1 / (1 + exp(-x[, 3])) + x[, 2] / 10 +} mu_x <- mu(X) tau_x <- tau(X) pi_x <- pnorm(mu_x) -Z <- rbinom(n,1,pi_x) -E_XZ <- mu_x + Z*tau_x -sigma <- diff(range(mu_x + tau_x*pi))/8 -y <- E_XZ + sigma*rnorm(n) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +sigma <- diff(range(mu_x + tau_x * pi)) / 8 +y <- E_XZ + sigma * rnorm(n) X <- as.data.frame(X) # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- X[test_inds,] -X_train <- X[train_inds,] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] pi_test <- pi_x[test_inds] pi_train <- pi_x[train_inds] Z_test <- Z[test_inds] @@ -50,17 +54,39 @@ tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] # Run the GFR algorithm -xbcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, - alpha_mu = 0.95, beta_mu = 1, max_depth_mu = -1, - alpha_tau = 0.8, beta_tau = 2, max_depth_tau = 10) +xbcf_params <- list( + num_trees_mu = num_trees_mu, + num_trees_tau = num_trees_tau, + alpha_mu = 0.95, + beta_mu = 1, + max_depth_mu = -1, + alpha_tau = 0.8, + beta_tau = 2, + max_depth_tau = 10 +) xbcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, num_gfr = num_gfr, - num_burnin = 0, num_mcmc = 0, params = xbcf_params + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + params = xbcf_params ) -plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1) +plot(rowMeans(xbcf_model$y_hat_test), y_test) +abline(0, 1) cat(sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(xbcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(xbcf_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(xbcf_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model) # Parallel setup @@ -69,20 +95,33 @@ cl <- makeCluster(ncores) registerDoParallel(cl) # Run the parallel BART MCMC samplers -bcf_model_outputs <- foreach (i = 1:num_chains) %dopar% { +bcf_model_outputs <- foreach(i = 1:num_chains) %dopar% + { random_seed <- i - bcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, - random_seed = random_seed) + bcf_params <- list( + num_trees_mu = num_trees_mu, + num_trees_tau = num_trees_tau, + random_seed = random_seed + ) bcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, - num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bcf_params, - previous_model_json = xbcf_model_string, warmstart_sample_num = num_gfr - i + 1, + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + params = bcf_params, + previous_model_json = xbcf_model_string, + warmstart_sample_num = num_gfr - i + 1, ) bcf_model_string <- stochtree::saveBCFModelToJsonString(bcf_model) y_hat_test <- bcf_model$y_hat_test - list(model=bcf_model_string, yhat=y_hat_test) -} + list(model = bcf_model_string, yhat = y_hat_test) + } # Close the cluster connection stopCluster(cl) @@ -91,44 +130,93 @@ stopCluster(cl) bcf_model_strings <- list() bcf_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) for (i in 1:length(bcf_model_outputs)) { - bcf_model_strings[[i]] <- bcf_model_outputs[[i]]$model - bcf_model_yhats[,i] <- rowMeans(bcf_model_outputs[[i]]$yhat) + bcf_model_strings[[i]] <- bcf_model_outputs[[i]]$model + bcf_model_yhats[, i] <- rowMeans(bcf_model_outputs[[i]]$yhat) } combined_bcf <- createBCFModelFromCombinedJsonString(bcf_model_strings) # Inspect the results -yhat_combined <- predict(combined_bcf, X_test)$y_hat -par(mfrow = c(1,2)) +yhat_combined <- predict(combined_bcf, X = X_test)$y_hat +par(mfrow = c(1, 2)) for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), bcf_model_yhats[,i], - xlab = "deserialized", ylab = "original", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + bcf_model_yhats[, i], + xlab = "deserialized", + ylab = "original", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) } for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, - xlab = "predicted", ylab = "actual", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) - cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") - cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + y_test, + xlab = "predicted", + ylab = "actual", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) + cat( + sqrt(mean((rowMeans(yhat_combined[, inds_start:inds_end]) - y_test)^2)), + "\n" + ) + cat( + mean( + (apply(yhat_combined[, inds_start:inds_end], 1, quantile, probs = 0.05) <= + y_test) & + (apply( + yhat_combined[, inds_start:inds_end], + 1, + quantile, + probs = 0.95 + ) >= + y_test) + ), + "\n" + ) } -par(mfrow = c(1,1)) +par(mfrow = c(1, 1)) # Compare to a single chain of MCMC samples initialized at root -bcf_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2) +bcf_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + alpha_mean = 0.95, + beta_mean = 2 +) bcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bcf_params + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + params = bcf_params ) -plot(rowMeans(bcf_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +plot( + rowMeans(bcf_model$y_hat_test), + y_test, + xlab = "predicted", + ylab = "actual" +) +abline(0, 1) cat(sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(bcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(bcf_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(bcf_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) From 9679068fd1fed1a872e2445503c421365c58476c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 00:50:36 -0600 Subject: [PATCH 04/11] Updated tests and BART and BCF R functions --- R/bart.R | 6 +++++- R/bcf.R | 6 +++++- test/R/testthat/test-bart.R | 30 ++++++++++++++++++++++++++++++ test/R/testthat/test-bcf.R | 19 ++++++++++++++++++- test/python/test_bart.py | 22 +++++++++++----------- test/python/test_predict.py | 12 ++++++------ 6 files changed, 75 insertions(+), 20 deletions(-) diff --git a/R/bart.R b/R/bart.R index 47ee41e0..d5acd6af 100644 --- a/R/bart.R +++ b/R/bart.R @@ -418,7 +418,11 @@ bart <- function( # Raise a warning if the data have ties and only GFR is being run if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) { num_values <- nrow(X_train) - max_grid_size <- floor(num_values / cutpoint_grid_size) + max_grid_size <- ifelse( + num_values > cutpoint_grid_size, + floor(num_values / cutpoint_grid_size), + 1 + ) covs_warning_1 <- NULL covs_warning_2 <- NULL covs_warning_3 <- NULL diff --git a/R/bcf.R b/R/bcf.R index a9e60d5d..5a80d5ec 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -522,7 +522,11 @@ bcf <- function( # Raise a warning if the data have ties and only GFR is being run if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) { num_values <- nrow(X_train) - max_grid_size <- floor(num_values / cutpoint_grid_size) + max_grid_size <- ifelse( + num_values > cutpoint_grid_size, + floor(num_values / cutpoint_grid_size), + 1 + ) covs_warning_1 <- NULL covs_warning_2 <- NULL covs_warning_3 <- NULL diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 36b7bfe8..b42099ea 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -312,6 +312,19 @@ test_that("Warmstart BART", { # Run a new BART chain from the existing (X)BART model general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bart_model <- bart( X_train = X_train, y_train = y_train, @@ -376,6 +389,23 @@ test_that("Warmstart BART", { # Run a new BART chain from the existing (X)BART model general_param_list <- list(num_chains = 4, keep_every = 5) expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bart_model <- bart( X_train = X_train, y_train = y_train, diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 221c333f..ab1cb7b4 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -375,6 +375,23 @@ test_that("Warmstart BCF", { # Run a new BCF chain from the existing (X)BCF model general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bcf_model <- bcf( X_train = X_train, y_train = y_train, @@ -482,7 +499,7 @@ test_that("Warmstart BCF", { num_burnin = 10, num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, + previous_model_warmstart_sample_num = 10, general_params = general_param_list ) ) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 3243b86a..8abebfdb 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -83,7 +83,7 @@ def outcome_mean(X): bart_model_3.from_json_string_list(bart_models_json) # Assertions - bart_preds_combined = bart_model_3.predict(covariates=X_train) + bart_preds_combined = bart_model_3.predict(X=X_train) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( @@ -190,7 +190,7 @@ def outcome_mean(X, W): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -298,7 +298,7 @@ def outcome_mean(X, W): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -410,7 +410,7 @@ def conditional_stddev(X): bart_model_3.from_json_string_list(bart_models_json) # Assertions - bart_preds_combined = bart_model_3.predict(covariates=X_train) + bart_preds_combined = bart_model_3.predict(X=X_train) y_hat_train_combined, sigma2_x_train_combined = ( bart_preds_combined["y_hat"], bart_preds_combined["variance_forest_predictions"], @@ -545,7 +545,7 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -670,7 +670,7 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -825,7 +825,7 @@ def rfx_term(group_labels, basis): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, + X=X_train, rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) @@ -998,8 +998,8 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, - basis=basis_train, + X=X_train, + leaf_basis=basis_train, rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) @@ -1196,8 +1196,8 @@ def conditional_stddev(X): random_effects_params=rfx_params, ) preds = bart_model_4.predict( - covariates=X_test, - basis=basis_test, + X=X_test, + leaf_basis=basis_test, rfx_group_ids=group_labels_test, type="posterior", terms="rfx", diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 03f36cb2..618ccea6 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -221,12 +221,12 @@ def test_bart_prediction(self): ) # Check that the default predict method returns a dictionary - pred = bart_model.predict(covariates=X_test) + pred = bart_model.predict(X=X_test) y_hat_posterior_test = pred["y_hat"] assert y_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = bart_model.predict(covariates=X_test, type="mean") + pred_mean = bart_model.predict(X=X_test, type="mean") y_hat_mean_test = pred_mean["y_hat"] np.testing.assert_almost_equal( y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) @@ -245,14 +245,14 @@ def test_bart_prediction(self): ) # Check that the default predict method returns a dictionary - pred = het_bart_model.predict(covariates=X_test) + pred = het_bart_model.predict(X=X_test) y_hat_posterior_test = pred["y_hat"] sigma2_hat_posterior_test = pred["variance_forest_predictions"] assert y_hat_posterior_test.shape == (20, 10) assert sigma2_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = het_bart_model.predict(covariates=X_test, type="mean") + pred_mean = het_bart_model.predict(X=X_test, type="mean") y_hat_mean_test = pred_mean["y_hat"] sigma2_hat_mean_test = pred_mean["variance_forest_predictions"] np.testing.assert_almost_equal( @@ -265,10 +265,10 @@ def test_bart_prediction(self): # Check that the "single-term" pre-aggregated predictions # match those computed by pre-aggregated predictions returned in a dictionary y_hat_mean_test_single_term = het_bart_model.predict( - covariates=X_test, type="mean", terms="y_hat" + X=X_test, type="mean", terms="y_hat" ) sigma2_hat_mean_test_single_term = het_bart_model.predict( - covariates=X_test, type="mean", terms="variance_forest" + X=X_test, type="mean", terms="variance_forest" ) np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) np.testing.assert_almost_equal( From 9597886131e14f85fd2fc411a2a62ba049209f03 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 00:53:42 -0600 Subject: [PATCH 05/11] Updated compute_contrast in R --- R/posterior_transformation.R | 40 +++++++++++++++--------------- man/compute_contrast_bart_model.Rd | 12 ++++----- tools/debug/bart_contrast_debug.R | 8 +++--- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 87c51d12..6b4b4c79 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -260,8 +260,8 @@ compute_contrast_bcf_model <- function( #' Only valid when there is either a mean forest or a random effects term in the BART model. #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. -#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. +#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. +#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. #' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`. #' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects @@ -306,8 +306,8 @@ compute_contrast_bcf_model <- function( #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' contrast_test <- compute_contrast_bart_model( #' bart_model, -#' covariates_0 = X_test, -#' covariates_1 = X_test, +#' X_0 = X_test, +#' X_1 = X_test, #' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), #' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), #' type = "posterior", @@ -315,8 +315,8 @@ compute_contrast_bcf_model <- function( #' ) compute_contrast_bart_model <- function( object, - covariates_0, - covariates_1, + X_0, + X_1, leaf_basis_0 = NULL, leaf_basis_1 = NULL, rfx_group_ids_0 = NULL, @@ -360,11 +360,11 @@ compute_contrast_bart_model <- function( } # Check that covariates are matrix or data frame - if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) { - stop("covariates_0 must be a matrix or dataframe") + if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) { + stop("X_0 must be a matrix or dataframe") } - if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) { - stop("covariates_1 must be a matrix or dataframe") + if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) { + stop("X_1 must be a matrix or dataframe") } # Convert all input data to matrices if not already converted @@ -388,20 +388,20 @@ compute_contrast_bart_model <- function( ) { stop("leaf_basis_0 and leaf_basis_1 must be provided for this model") } - if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) { - stop("covariates_0 and leaf_basis_0 must have the same number of rows") + if ((!is.null(leaf_basis_0)) && (nrow(X_0) != nrow(leaf_basis_0))) { + stop("X_0 and leaf_basis_0 must have the same number of rows") } - if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) { - stop("covariates_1 and leaf_basis_1 must have the same number of rows") + if ((!is.null(leaf_basis_1)) && (nrow(X_1) != nrow(leaf_basis_1))) { + stop("X_1 and leaf_basis_1 must have the same number of rows") } - if (object$model_params$num_covariates != ncol(covariates_0)) { + if (object$model_params$num_covariates != ncol(X_0)) { stop( - "covariates_0 must contain the same number of columns as the BART model's training dataset" + "X_0 must contain the same number of columns as the BART model's training dataset" ) } - if (object$model_params$num_covariates != ncol(covariates_1)) { + if (object$model_params$num_covariates != ncol(X_1)) { stop( - "covariates_1 must contain the same number of columns as the BART model's training dataset" + "X_1 must contain the same number of columns as the BART model's training dataset" ) } if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) { @@ -427,7 +427,7 @@ compute_contrast_bart_model <- function( # Predict for the control arm control_preds <- predict( object = object, - covariates = covariates_0, + X = X_0, leaf_basis = leaf_basis_0, rfx_group_ids = rfx_group_ids_0, rfx_basis = rfx_basis_0, @@ -439,7 +439,7 @@ compute_contrast_bart_model <- function( # Predict for the treatment arm treatment_preds <- predict( object = object, - covariates = covariates_1, + X = X_1, leaf_basis = leaf_basis_1, rfx_group_ids = rfx_group_ids_1, rfx_basis = rfx_basis_1, diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd index 0851c9b4..c09bf23a 100644 --- a/man/compute_contrast_bart_model.Rd +++ b/man/compute_contrast_bart_model.Rd @@ -6,8 +6,8 @@ \usage{ compute_contrast_bart_model( object, - covariates_0, - covariates_1, + X_0, + X_1, leaf_basis_0 = NULL, leaf_basis_1 = NULL, rfx_group_ids_0 = NULL, @@ -21,9 +21,9 @@ compute_contrast_bart_model( \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{covariates_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} +\item{X_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} -\item{covariates_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} +\item{X_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} \item{leaf_basis_0}{(Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: \code{NULL}.} @@ -88,8 +88,8 @@ bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_tr num_gfr = 10, num_burnin = 0, num_mcmc = 10) contrast_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), type = "posterior", diff --git a/tools/debug/bart_contrast_debug.R b/tools/debug/bart_contrast_debug.R index 79f6fe10..9b46fffe 100644 --- a/tools/debug/bart_contrast_debug.R +++ b/tools/debug/bart_contrast_debug.R @@ -45,8 +45,8 @@ bart_model <- bart( # Compute contrast posterior contrast_posterior_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), type = "posterior", @@ -128,8 +128,8 @@ bart_model <- bart( # Compute contrast posterior contrast_posterior_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), rfx_group_ids_0 = group_ids_test, From 396931c939da2bb8486ef1a01087b9d45ef73c07 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:01:36 -0600 Subject: [PATCH 06/11] Updated compute_contrast in Python and corrected other R package issues --- R/posterior_transformation.R | 4 +-- R/utils.R | 30 ---------------- demo/debug/bart_contrast_debug.py | 16 ++++----- demo/debug/probit_bart_rfx_debug.py | 8 ++--- stochtree/bart.py | 56 ++++++++++++++--------------- 5 files changed, 42 insertions(+), 72 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 6b4b4c79..3eb25f12 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -751,7 +751,7 @@ sample_bart_posterior_predictive <- function( # Compute posterior samples bart_preds <- predict( model_object, - covariates = covariates, + X = covariates, leaf_basis = basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, @@ -1188,7 +1188,7 @@ compute_bart_posterior_interval <- function( # Compute posterior matrices for the requested model terms predictions <- predict( model_object, - covariates = covariates, + X = covariates, leaf_basis = basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, diff --git a/R/utils.R b/R/utils.R index 25c2fa1a..d33fe2c5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1092,33 +1092,3 @@ expand_dims_2d_diag <- function(input, output_size) { } return(output) } - - -gfr_tie_checks <- function(covariates) { - num_vars <- ncol(covariates) - for (j in 1:num_vars) { - x_j <- covariates[, j] - if (has_few_unique_values(x_j)) { - warning_message <- paste0( - "Covariate column ", - j, - " has relatively few unique values. ", - "This may lead to tied values when sampling split points in BART/BCF, ", - "which can cause errors during model fitting. ", - "Consider adding small amounts of noise to this variable to break ties." - ) - warning(warning_message) - } - } -} - - -has_few_unique_values <- function( - x, - count_threshold = 15 -) { - x_unique <- unique(x) - num_unique_values <- length(unique_values) - unique_to_total_count_ratio <- num_unique_values / length(x) - return(num_unique_values <= threshold) -} diff --git a/demo/debug/bart_contrast_debug.py b/demo/debug/bart_contrast_debug.py index bbfd82d4..f80100bc 100644 --- a/demo/debug/bart_contrast_debug.py +++ b/demo/debug/bart_contrast_debug.py @@ -55,10 +55,10 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), type="posterior", scale="linear", ) @@ -143,10 +143,10 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), rfx_group_ids_0=group_ids_test, rfx_group_ids_1=group_ids_test, rfx_basis_0=rfx_basis_test, diff --git a/demo/debug/probit_bart_rfx_debug.py b/demo/debug/probit_bart_rfx_debug.py index bfe0be6c..de8d3953 100644 --- a/demo/debug/probit_bart_rfx_debug.py +++ b/demo/debug/probit_bart_rfx_debug.py @@ -72,10 +72,10 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), rfx_group_ids_0=group_ids_test, rfx_group_ids_1=group_ids_test, rfx_basis_0=rfx_basis_test, diff --git a/stochtree/bart.py b/stochtree/bart.py index 2988a5fd..0bb9a6f7 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -2046,10 +2046,10 @@ def predict( def compute_contrast( self, - covariates_0: Union[np.array, pd.DataFrame], - covariates_1: Union[np.array, pd.DataFrame], - basis_0: np.array = None, - basis_1: np.array = None, + X_0: Union[np.array, pd.DataFrame], + X_1: Union[np.array, pd.DataFrame], + leaf_basis_0: np.array = None, + leaf_basis_1: np.array = None, rfx_group_ids_0: np.array = None, rfx_group_ids_1: np.array = None, rfx_basis_0: np.array = None, @@ -2068,13 +2068,13 @@ def compute_contrast( Parameters ---------- - covariates_0 : np.array or pd.DataFrame + X_0 : np.array or pd.DataFrame Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. - covariates_1 : np.array or pd.DataFrame + X_1 : np.array or pd.DataFrame Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. - basis_0 : np.array, optional + leaf_basis_0 : np.array, optional Bases used for prediction in the "control" case (by e.g. dot product with leaf values). - basis_1 : np.array, optional + leaf_basis_1 : np.array, optional Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). rfx_group_ids_0 : np.array, optional Test set group labels used for prediction from an additive random effects model in the "control" case. @@ -2135,33 +2135,33 @@ def compute_contrast( raise NotSampledError(msg) # Data checks - if not isinstance(covariates_0, pd.DataFrame) and not isinstance( - covariates_0, np.ndarray + if not isinstance(X_0, pd.DataFrame) and not isinstance( + X_0, np.ndarray ): - raise ValueError("covariates_0 must be a pandas dataframe or numpy array") - if not isinstance(covariates_1, pd.DataFrame) and not isinstance( - covariates_1, np.ndarray + raise ValueError("X_0 must be a pandas dataframe or numpy array") + if not isinstance(X_1, pd.DataFrame) and not isinstance( + X_1, np.ndarray ): - raise ValueError("covariates_1 must be a pandas dataframe or numpy array") - if basis_0 is not None: - if not isinstance(basis_0, np.ndarray): - raise ValueError("basis_0 must be a numpy array") - if basis_0.shape[0] != covariates_0.shape[0]: + raise ValueError("X_1 must be a pandas dataframe or numpy array") + if leaf_basis_0 is not None: + if not isinstance(leaf_basis_0, np.ndarray): + raise ValueError("leaf_basis_0 must be a numpy array") + if leaf_basis_0.shape[0] != X_0.shape[0]: raise ValueError( - "covariates_0 and basis_0 must have the same number of rows" + "X_0 and leaf_basis_0 must have the same number of rows" ) - if basis_1 is not None: - if not isinstance(basis_1, np.ndarray): - raise ValueError("basis_1 must be a numpy array") - if basis_1.shape[0] != covariates_1.shape[0]: + if leaf_basis_1 is not None: + if not isinstance(leaf_basis_1, np.ndarray): + raise ValueError("leaf_basis_1 must be a numpy array") + if leaf_basis_1.shape[0] != X_1.shape[0]: raise ValueError( - "covariates_1 and basis_1 must have the same number of rows" + "X_1 and leaf_basis_1 must have the same number of rows" ) # Predict for the control arm control_preds = self.predict( - covariates=covariates_0, - basis=basis_0, + X=X_0, + leaf_basis=leaf_basis_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", @@ -2171,8 +2171,8 @@ def compute_contrast( # Predict for the treatment arm treatment_preds = self.predict( - covariates=covariates_1, - basis=basis_1, + X=X_1, + leaf_basis=leaf_basis_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", From 8db1d8d7a5a613f251a368d01bcede31d09c8481 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:09:42 -0600 Subject: [PATCH 07/11] Updated BART posterior interval function / method in R and Python --- R/posterior_transformation.R | 46 +++++++++++++------------- man/compute_bart_posterior_interval.Rd | 10 +++--- stochtree/bart.py | 42 +++++++++++------------ tools/debug/bart_predict_debug.R | 6 ++-- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 3eb25f12..07e92ce5 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -1068,8 +1068,8 @@ compute_bcf_posterior_interval <- function( #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". -#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). -#' @param basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param X A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). +#' @param leaf_basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. #' @@ -1085,7 +1085,7 @@ compute_bcf_posterior_interval <- function( #' intervals <- compute_bart_posterior_interval( #' model_object = bart_model, #' terms = c("mean_forest", "y_hat"), -#' covariates = X, +#' X = X, #' level = 0.90 #' ) #' @export @@ -1094,8 +1094,8 @@ compute_bart_posterior_interval <- function( terms, level = 0.95, scale = "linear", - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL ) { @@ -1129,30 +1129,30 @@ compute_bart_posterior_interval <- function( if (needs_covariates) { if (is.null(covariates)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_basis <- needs_covariates && model_object$model_params$has_basis if (needs_basis) { - if (is.null(basis)) { + if (is.null(leaf_basis)) { stop( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(basis)) { - stop("'basis' must be a matrix") + if (!is.matrix(leaf_basis)) { + stop("'leaf_basis' must be a matrix") } - if (is.matrix(basis)) { - if (nrow(basis) != nrow(covariates)) { - stop("'basis' must have the same number of rows as 'covariates'") + if (is.matrix(leaf_basis)) { + if (nrow(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of rows as 'X'") } } else { - if (length(basis) != nrow(covariates)) { - stop("'basis' must have the same number of elements as 'covariates'") + if (length(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of elements as 'X'") } } } @@ -1167,9 +1167,9 @@ compute_bart_posterior_interval <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -1180,16 +1180,16 @@ compute_bart_posterior_interval <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior matrices for the requested model terms predictions <- predict( model_object, - X = covariates, - leaf_basis = basis, + X = X, + leaf_basis = leaf_basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd index 8a802e45..be383a8d 100644 --- a/man/compute_bart_posterior_interval.Rd +++ b/man/compute_bart_posterior_interval.Rd @@ -9,8 +9,8 @@ compute_bart_posterior_interval( terms, level = 0.95, scale = "linear", - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL ) @@ -24,9 +24,9 @@ compute_bart_posterior_interval( \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} -\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} +\item{X}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} -\item{basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} +\item{leaf_basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} \item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} @@ -47,7 +47,7 @@ bart_model <- bart(y_train = y, X_train = X) intervals <- compute_bart_posterior_interval( model_object = bart_model, terms = c("mean_forest", "y_hat"), - covariates = X, + X = X, level = 0.90 ) } diff --git a/stochtree/bart.py b/stochtree/bart.py index 0bb9a6f7..39e13332 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -2196,8 +2196,8 @@ def compute_posterior_interval( terms: Union[list[str], str] = "all", level: float = 0.95, scale: str = "linear", - covariates: np.array = None, - basis: np.array = None, + X: np.array = None, + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, ) -> dict: @@ -2212,9 +2212,9 @@ def compute_posterior_interval( Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. - covariates : np.array, optional + X : np.array, optional Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). - basis : np.array, optional + leaf_basis : np.array, optional Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. rfx_group_ids : np.array, optional Optional vector of group IDs for random effects. Required if the requested term includes random effects. @@ -2266,25 +2266,25 @@ def compute_posterior_interval( or needs_covariates_intermediate ) if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_basis = needs_covariates and self.has_basis if needs_basis: - if basis is None: + if leaf_basis is None: raise ValueError( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) - if not isinstance(basis, np.ndarray): - raise ValueError("'basis' must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("'leaf_basis' must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "'basis' must have the same number of rows as 'covariates'" + "'leaf_basis' must have the same number of rows as 'X'" ) needs_rfx_data_intermediate = ( ("y_hat" in terms) or ("all" in terms) @@ -2297,9 +2297,9 @@ def compute_posterior_interval( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -2307,15 +2307,15 @@ def compute_posterior_interval( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior matrices for the requested model terms predictions = self.predict( - covariates=covariates, - basis=basis, + X=X, + leaf_basis=leaf_basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index 49bea0b6..f131d930 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -56,7 +56,7 @@ y_hat_intervals <- compute_bart_posterior_interval( model_object = bart_model, transform = function(x) x, terms = c("y_hat", "mean_forest"), - covariates = X_test, + X = X_test, level = 0.95 ) @@ -137,7 +137,7 @@ y_hat_intervals <- compute_bart_posterior_interval( model_object = bart_model, scale = "linear", terms = c("y_hat"), - covariates = X_test, + X = X_test, level = 0.95 ) @@ -146,7 +146,7 @@ y_hat_prob_intervals <- compute_bart_posterior_interval( model_object = bart_model, scale = "probability", terms = c("y_hat"), - covariates = X_test, + X = X_test, level = 0.95 ) From d303fcb7ab4ee87670dc47ef0b68d99ce00f921a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:16:32 -0600 Subject: [PATCH 08/11] Updated posterior predictive method / function in R and Python --- R/posterior_transformation.R | 50 ++++++++++++------------- man/sample_bart_posterior_predictive.Rd | 10 ++--- stochtree/bart.py | 44 +++++++++++----------- tools/debug/bart_predict_debug.R | 4 +- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 07e92ce5..8884ac5c 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -659,8 +659,8 @@ sample_bcf_posterior_predictive <- function( #' Sample from the posterior predictive distribution for outcomes modeled by BART #' #' @param model_object A fitted BART model object of class `bartmodel`. -#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). -#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param X A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). +#' @param leaf_basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. #' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects. #' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects. #' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). @@ -675,12 +675,12 @@ sample_bcf_posterior_predictive <- function( #' y <- 2 * X[,1] + rnorm(n) #' bart_model <- bart(y_train = y, X_train = X) #' ppd_samples <- sample_bart_posterior_predictive( -#' model_object = bart_model, covariates = X +#' model_object = bart_model, X = X #' ) sample_bart_posterior_predictive <- function( model_object, - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, num_draws_per_sample = NULL @@ -694,32 +694,32 @@ sample_bart_posterior_predictive <- function( # Check that all the necessary inputs were provided for interval computation needs_covariates <- model_object$model_params$include_mean_forest if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_basis <- needs_covariates && model_object$model_params$has_basis if (needs_basis) { - if (is.null(basis)) { + if (is.null(leaf_basis)) { stop( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(basis)) { - stop("'basis' must be a matrix") + if (!is.matrix(leaf_basis)) { + stop("'leaf_basis' must be a matrix") } - if (is.matrix(basis)) { - if (nrow(basis) != nrow(covariates)) { - stop("'basis' must have the same number of rows as 'covariates'") + if (is.matrix(leaf_basis)) { + if (nrow(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of rows as 'X'") } } else { - if (length(basis) != nrow(covariates)) { - stop("'basis' must have the same number of elements as 'covariates'") + if (length(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of elements as 'X'") } } } @@ -730,9 +730,9 @@ sample_bart_posterior_predictive <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -743,16 +743,16 @@ sample_bart_posterior_predictive <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior samples bart_preds <- predict( model_object, - X = covariates, - leaf_basis = basis, + X = X, + leaf_basis = leaf_basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", @@ -766,7 +766,7 @@ sample_bart_posterior_predictive <- function( has_variance_forest <- model_object$model_params$include_variance_forest samples_global_variance <- model_object$model_params$sample_sigma2_global num_posterior_draws <- model_object$model_params$num_samples - num_observations <- nrow(covariates) + num_observations <- nrow(X) if (has_mean_term) { ppd_mean <- bart_preds$y_hat } else { diff --git a/man/sample_bart_posterior_predictive.Rd b/man/sample_bart_posterior_predictive.Rd index 5bce8442..5dffb782 100644 --- a/man/sample_bart_posterior_predictive.Rd +++ b/man/sample_bart_posterior_predictive.Rd @@ -6,8 +6,8 @@ \usage{ sample_bart_posterior_predictive( model_object, - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, num_draws_per_sample = NULL @@ -16,9 +16,9 @@ sample_bart_posterior_predictive( \arguments{ \item{model_object}{A fitted BART model object of class \code{bartmodel}.} -\item{covariates}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} +\item{X}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} -\item{basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} +\item{leaf_basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} \item{rfx_group_ids}{A vector of group IDs for random effects model. Required if the BART model includes random effects.} @@ -39,6 +39,6 @@ X <- matrix(rnorm(n * p), nrow = n, ncol = p) y <- 2 * X[,1] + rnorm(n) bart_model <- bart(y_train = y, X_train = X) ppd_samples <- sample_bart_posterior_predictive( - model_object = bart_model, covariates = X + model_object = bart_model, X = X ) } diff --git a/stochtree/bart.py b/stochtree/bart.py index 39e13332..906fe899 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -2338,8 +2338,8 @@ def compute_posterior_interval( def sample_posterior_predictive( self, - covariates: np.array = None, - basis: np.array = None, + X: np.array = None, + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None, @@ -2349,9 +2349,9 @@ def sample_posterior_predictive( Parameters ---------- - covariates : np.array, optional + X : np.array, optional An array or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). - basis : np.array, optional + leaf_basis : np.array, optional An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. rfx_group_ids : np.array, optional An array of group IDs for random effects. Required if the BART model includes random effects. @@ -2375,25 +2375,25 @@ def sample_posterior_predictive( # Check that all the necessary inputs were provided for interval computation needs_covariates = self.include_mean_forest if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_basis = needs_covariates and self.has_basis if needs_basis: - if basis is None: + if leaf_basis is None: raise ValueError( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) - if not isinstance(basis, np.ndarray): - raise ValueError("'basis' must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("'leaf_basis' must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "'basis' must have the same number of rows as 'covariates'" + "'leaf_basis' must have the same number of rows as 'X'" ) needs_rfx_data = self.has_rfx if needs_rfx_data: @@ -2403,9 +2403,9 @@ def sample_posterior_predictive( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -2413,15 +2413,15 @@ def sample_posterior_predictive( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior predictive samples bart_preds = self.predict( - covariates=covariates, - basis=basis, + X=X, + leaf_basis=leaf_basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", @@ -2433,7 +2433,7 @@ def sample_posterior_predictive( has_variance_forest = self.include_variance_forest samples_global_variance = self.sample_sigma2_global num_posterior_draws = self.num_samples - num_observations = covariates.shape[0] + num_observations = X.shape[0] if has_mean_term: ppd_mean = bart_preds["y_hat"] else: diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index f131d930..4e99f51d 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -67,7 +67,7 @@ y_hat_intervals <- compute_bart_posterior_interval( pred_intervals <- sample_bart_posterior_predictive( model_object = bart_model, - covariates = X_test, + X = X_test, level = 0.95 ) @@ -169,7 +169,7 @@ lines(y_hat_prob_intervals$upper[sort_inds]) # Draw from posterior predictive for covariates in the test set ppd_samples <- sample_bart_posterior_predictive( model_object = bart_model, - covariates = X_test, + X = X_test, num_draws = 10 ) From 19783696bcf9fb765b33588931c97914a367e78b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:23:59 -0600 Subject: [PATCH 09/11] R bugfix --- R/posterior_transformation.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 8884ac5c..16737922 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -1127,7 +1127,7 @@ compute_bart_posterior_interval <- function( ("variance_forest" %in% terms) || (needs_covariates_intermediate)) if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( "'X' must be provided in order to compute the requested intervals" ) From f55e349a325bb1a578cd9500d385ff581ca2d077 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:30:14 -0600 Subject: [PATCH 10/11] Updated Python BCF propensity arguments --- demo/debug/bcf_pred_rmse.py | 4 +- demo/debug/bcf_predict_debug.py | 4 +- demo/debug/causal_inference_binary_outcome.py | 4 +- .../debug/causal_inference_feature_subsets.py | 8 +- demo/debug/gfr_ties_debug.py | 8 +- demo/notebooks/causal_inference.ipynb | 4 +- .../causal_inference_feature_subsets.ipynb | 8 +- ...tivariate_treatment_causal_inference.ipynb | 4 +- stochtree/bcf.py | 74 +++++++++---------- test/python/test_bcf.py | 40 +++++----- test/python/test_json.py | 4 +- test/python/test_predict.py | 8 +- .../bcf/individual_regression_test_bcf.py | 2 +- 13 files changed, 86 insertions(+), 86 deletions(-) diff --git a/demo/debug/bcf_pred_rmse.py b/demo/debug/bcf_pred_rmse.py index 0706842f..721074ea 100644 --- a/demo/debug/bcf_pred_rmse.py +++ b/demo/debug/bcf_pred_rmse.py @@ -51,11 +51,11 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, ) # Predict out of sample diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index 141f4ee8..9a628bb4 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -45,7 +45,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -182,7 +182,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, rfx_group_ids_train=rfx_group_ids_train, num_gfr=10, diff --git a/demo/debug/causal_inference_binary_outcome.py b/demo/debug/causal_inference_binary_outcome.py index c603927d..6f3c75d4 100644 --- a/demo/debug/causal_inference_binary_outcome.py +++ b/demo/debug/causal_inference_binary_outcome.py @@ -101,8 +101,8 @@ def g(x5): # Run the sampler bcf_model = BCFModel() -bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, - X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, +bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, propensity_train=pi_train, + X_test=X_test, Z_test=Z_test, propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, prognostic_forest_params=prognostic_forest_params, treatment_effect_forest_params=treatment_effect_forest_params) diff --git a/demo/debug/causal_inference_feature_subsets.py b/demo/debug/causal_inference_feature_subsets.py index 00cab8b8..8fa0fb2f 100644 --- a/demo/debug/causal_inference_feature_subsets.py +++ b/demo/debug/causal_inference_feature_subsets.py @@ -44,7 +44,7 @@ bcf_model_a = BCFModel() prog_forest_config_a = {"num_trees": 100} trt_forest_config_a = {"num_trees": 50} -bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) +bcf_model_a.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) """ timing_no_subsampling = timeit.timeit(stmt=s, number=5, globals=globals()) print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}") @@ -54,7 +54,7 @@ bcf_model_b = BCFModel() prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5} trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5} -bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) +bcf_model_b.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) """ timing_subsampling = timeit.timeit(stmt=s, number=5, globals=globals()) print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}") @@ -63,11 +63,11 @@ bcf_model_a = BCFModel() prog_forest_config_a = {"num_trees": 100} trt_forest_config_a = {"num_trees": 50} -bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) +bcf_model_a.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) bcf_model_b = BCFModel() prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5} trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5} -bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) +bcf_model_b.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) y_hat_test_a = np.squeeze(bcf_model_a.y_hat_test).mean(axis = 1) rmse_no_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_a,2))) print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}") diff --git a/demo/debug/gfr_ties_debug.py b/demo/debug/gfr_ties_debug.py index c69e3d67..0a194e77 100644 --- a/demo/debug/gfr_ties_debug.py +++ b/demo/debug/gfr_ties_debug.py @@ -157,7 +157,7 @@ xbcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -182,7 +182,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -237,7 +237,7 @@ xbcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -262,7 +262,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, diff --git a/demo/notebooks/causal_inference.ipynb b/demo/notebooks/causal_inference.ipynb index 511ce0c4..151356ae 100644 --- a/demo/notebooks/causal_inference.ipynb +++ b/demo/notebooks/causal_inference.ipynb @@ -109,10 +109,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params=general_params,\n", diff --git a/demo/notebooks/causal_inference_feature_subsets.ipynb b/demo/notebooks/causal_inference_feature_subsets.ipynb index 2bbfbad5..f4465568 100644 --- a/demo/notebooks/causal_inference_feature_subsets.ipynb +++ b/demo/notebooks/causal_inference_feature_subsets.ipynb @@ -113,10 +113,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params={\"keep_every\": 5},\n", @@ -242,10 +242,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " treatment_effect_forest_params=tau_params,\n", diff --git a/demo/notebooks/multivariate_treatment_causal_inference.ipynb b/demo/notebooks/multivariate_treatment_causal_inference.ipynb index 3e345aa4..88b528cd 100644 --- a/demo/notebooks/multivariate_treatment_causal_inference.ipynb +++ b/demo/notebooks/multivariate_treatment_causal_inference.ipynb @@ -110,10 +110,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", ")" diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 98442965..4ce66288 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -84,12 +84,12 @@ def sample( X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array, - pi_train: np.array = None, + propensity_train: np.array = None, rfx_group_ids_train: np.array = None, rfx_basis_train: np.array = None, X_test: Union[pd.DataFrame, np.array] = None, Z_test: np.array = None, - pi_test: np.array = None, + propensity_test: np.array = None, rfx_group_ids_test: np.array = None, rfx_basis_test: np.array = None, num_gfr: int = 5, @@ -114,7 +114,7 @@ def sample( Array of (continuous or binary; univariate or multivariate) treatment assignments. y_train : np.array Outcome to be modeled by the ensemble. - pi_train : np.array + propensity_train : np.array Optional vector of propensity scores. If not provided, this will be estimated from the data. rfx_group_ids_train : np.array, optional Optional group labels used for an additive random effects model. @@ -125,7 +125,7 @@ def sample( Z_test : np.array, optional Optional test set of (continuous or binary) treatment assignments. Must be provided if `X_test` is provided. - pi_test : np.array, optional + propensity_test : np.array, optional Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. rfx_group_ids_test : np.array, optional Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), @@ -541,9 +541,9 @@ def sample( raise ValueError("X_train must be a pandas dataframe or numpy array") if not isinstance(Z_train, np.ndarray): raise ValueError("Z_train must be a numpy array") - if pi_train is not None: - if not isinstance(pi_train, np.ndarray): - raise ValueError("pi_train must be a numpy array") + if propensity_train is not None: + if not isinstance(propensity_train, np.ndarray): + raise ValueError("propensity_train must be a numpy array") if not isinstance(y_train, np.ndarray): raise ValueError("y_train must be a numpy array") if X_test is not None: @@ -554,9 +554,9 @@ def sample( if Z_test is not None: if not isinstance(Z_test, np.ndarray): raise ValueError("Z_test must be a numpy array") - if pi_test is not None: - if not isinstance(pi_test, np.ndarray): - raise ValueError("pi_test must be a numpy array") + if propensity_test is not None: + if not isinstance(propensity_test, np.ndarray): + raise ValueError("propensity_test must be a numpy array") if rfx_group_ids_train is not None: if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") @@ -585,9 +585,9 @@ def sample( if Z_train is not None: if Z_train.ndim == 1: Z_train = np.expand_dims(Z_train, 1) - if pi_train is not None: - if pi_train.ndim == 1: - pi_train = np.expand_dims(pi_train, 1) + if propensity_train is not None: + if propensity_train.ndim == 1: + propensity_train = np.expand_dims(propensity_train, 1) if y_train.ndim == 1: y_train = np.expand_dims(y_train, 1) if X_test is not None: @@ -597,9 +597,9 @@ def sample( if Z_test is not None: if Z_test.ndim == 1: Z_test = np.expand_dims(Z_test, 1) - if pi_test is not None: - if pi_test.ndim == 1: - pi_test = np.expand_dims(pi_test, 1) + if propensity_test is not None: + if propensity_test.ndim == 1: + propensity_test = np.expand_dims(propensity_test, 1) if rfx_group_ids_train is not None: if rfx_group_ids_train.ndim != 1: rfx_group_ids_train = np.squeeze(rfx_group_ids_train) @@ -631,17 +631,17 @@ def sample( raise ValueError("X_train and Z_train must have the same number of rows") if y_train.shape[0] != X_train.shape[0]: raise ValueError("X_train and y_train must have the same number of rows") - if pi_train is not None: - if pi_train.shape[0] != X_train.shape[0]: + if propensity_train is not None: + if propensity_train.shape[0] != X_train.shape[0]: raise ValueError( - "X_train and pi_train must have the same number of rows" + "X_train and propensity_train must have the same number of rows" ) if X_test is not None and Z_test is not None: if X_test.shape[0] != Z_test.shape[0]: raise ValueError("X_test and Z_test must have the same number of rows") - if X_test is not None and pi_test is not None: - if X_test.shape[0] != pi_test.shape[0]: - raise ValueError("X_test and pi_test must have the same number of rows") + if X_test is not None and propensity_test is not None: + if X_test.shape[0] != propensity_test.shape[0]: + raise ValueError("X_test and propensity_test must have the same number of rows") # Raise a warning if the data have ties and only GFR is being run if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0): @@ -1311,10 +1311,10 @@ def sample( sample_sigma2_leaf_tau = False # Check if user has provided propensities that are needed in the model - if pi_train is None and propensity_covariate != "none": + if propensity_train is None and propensity_covariate != "none": if self.multivariate_treatment: raise ValueError( - "Propensities must be provided (via pi_train and / or pi_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" + "Propensities must be provided (via propensity_train and / or propensity_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" ) else: self.bart_propensity_model = BARTModel() @@ -1330,10 +1330,10 @@ def sample( num_burnin=num_burnin_propensity, num_mcmc=num_mcmc_propensity, ) - pi_train = np.mean( + propensity_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True ) - pi_test = np.mean( + propensity_test = np.mean( self.bart_propensity_model.y_hat_test, axis=1, keepdims=True ) else: @@ -1344,7 +1344,7 @@ def sample( num_burnin=num_burnin_propensity, num_mcmc=num_mcmc_propensity, ) - pi_train = np.mean( + propensity_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True ) self.internal_propensity_model = True @@ -1674,34 +1674,34 @@ def sample( ) if propensity_covariate != "none": feature_types = np.append( - feature_types, np.repeat(0, pi_train.shape[1]) + feature_types, np.repeat(0, propensity_train.shape[1]) ).astype("int") - X_train_processed = np.c_[X_train_processed, pi_train] + X_train_processed = np.c_[X_train_processed, propensity_train] if self.has_test: - X_test_processed = np.c_[X_test_processed, pi_test] + X_test_processed = np.c_[X_test_processed, propensity_test] if propensity_covariate == "prognostic": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(0.0, pi_train.shape[1]) + variable_weights_tau, np.repeat(0.0, propensity_train.shape[1]) ) elif propensity_covariate == "treatment_effect": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(0.0, pi_train.shape[1]) + variable_weights_mu, np.repeat(0.0, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) elif propensity_covariate == "both": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_variance = np.append( - variable_weights_variance, np.repeat(0.0, pi_train.shape[1]) + variable_weights_variance, np.repeat(0.0, propensity_train.shape[1]) ) # Renormalize variable weights diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index dac1ea25..eca2a5ff 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -51,10 +51,10 @@ def test_binary_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -93,7 +93,7 @@ def test_binary_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -239,10 +239,10 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -281,10 +281,10 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -352,7 +352,7 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -560,10 +560,10 @@ def test_multivariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -602,7 +602,7 @@ def test_multivariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -706,10 +706,10 @@ def test_binary_bcf_heteroskedastic(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -752,7 +752,7 @@ def test_binary_bcf_heteroskedastic(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -918,10 +918,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, @@ -946,10 +946,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, @@ -974,10 +974,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, diff --git a/test/python/test_json.py b/test/python/test_json.py index 48d4845b..b6f9b36f 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -454,7 +454,7 @@ def test_bcf_string(self): # Run BCF bcf_orig = BCFModel() bcf_orig.sample( - X_train=X, Z_train=Z, y_train=y, pi_train=pi_X, num_gfr=10, num_mcmc=10 + X_train=X, Z_train=Z, y_train=y, propensity_train=pi_X, num_gfr=10, num_mcmc=10 ) # Extract predictions from the sampler @@ -529,7 +529,7 @@ def rfx_mean(group_labels, basis): X_train=X, Z_train=Z, y_train=y, - pi_train=pi_X, + propensity_train=pi_X, rfx_group_ids_train=group_labels, rfx_basis_train=basis, num_gfr=10, diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 618ccea6..ebc6fbc5 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -332,10 +332,10 @@ def g(x5): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_x_train, + propensity_train=pi_x_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_x_test, + propensity_test=pi_x_test, num_gfr=10, num_burnin=0, num_mcmc=10, @@ -372,10 +372,10 @@ def g(x5): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_x_train, + propensity_train=pi_x_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_x_test, + propensity_test=pi_x_test, num_gfr=10, num_burnin=0, num_mcmc=10, diff --git a/tools/regression/bcf/individual_regression_test_bcf.py b/tools/regression/bcf/individual_regression_test_bcf.py index 591b24d2..f4279193 100644 --- a/tools/regression/bcf/individual_regression_test_bcf.py +++ b/tools/regression/bcf/individual_regression_test_bcf.py @@ -337,7 +337,7 @@ def main(): X_train=covariates_train, Z_train=treatment_train, y_train=outcome_train, - pi_train=propensity_train, + propensity_train=propensity_train, rfx_group_ids_train=rfx_group_ids_train, rfx_basis_train=rfx_basis_train, num_gfr=num_gfr, From d296f3ef6f37b7f5323078d679616fad240bdae2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 20 Nov 2025 01:51:39 -0600 Subject: [PATCH 11/11] Updated remaining R and Python functions --- R/posterior_transformation.R | 118 +++++++++--------- demo/debug/bart_predict_debug.py | 4 +- demo/debug/bcf_predict_debug.py | 26 ++-- demo/debug/causal_inference_binary_outcome.py | 1 - demo/debug/multi_chain.py | 1 - demo/debug/multiple_initializations.py | 6 +- man/compute_bcf_posterior_interval.Rd | 12 +- man/sample_bcf_posterior_predictive.Rd | 12 +- stochtree/bcf.py | 94 +++++++------- test/python/test_bart.py | 1 - test/python/test_predict.py | 1 - tools/debug/bcf_predict_debug.R | 40 +++--- 12 files changed, 156 insertions(+), 160 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 16737922..dca34be3 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -465,8 +465,8 @@ compute_contrast_bart_model <- function( #' Sample from the posterior predictive distribution for outcomes modeled by BCF #' #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param covariates A matrix or data frame of covariates. -#' @param treatment A vector or matrix of treatment assignments. +#' @param X A matrix or data frame of covariates. +#' @param Z A vector or matrix of treatment assignments. #' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects. #' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects. @@ -484,13 +484,13 @@ compute_contrast_bart_model <- function( #' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) #' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) #' ppd_samples <- sample_bcf_posterior_predictive( -#' model_object = bcf_model, covariates = X, -#' treatment = Z, propensity = pi_X +#' model_object = bcf_model, X = X, +#' Z = Z, propensity = pi_X #' ) sample_bcf_posterior_predictive <- function( model_object, - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -505,33 +505,33 @@ sample_bcf_posterior_predictive <- function( # Check that all the necessary inputs were provided for interval computation needs_covariates <- TRUE if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_treatment <- needs_covariates if (needs_treatment) { - if (is.null(treatment)) { + if (is.null(Z)) { stop( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(treatment) && !is.numeric(treatment)) { - stop("'treatment' must be a numeric vector or matrix") + if (!is.matrix(Z) && !is.numeric(Z)) { + stop("'Z' must be a numeric vector or matrix") } - if (is.matrix(treatment)) { - if (nrow(treatment) != nrow(covariates)) { - stop("'treatment' must have the same number of rows as 'covariates'") + if (is.matrix(Z)) { + if (nrow(Z) != nrow(X)) { + stop("'Z' must have the same number of rows as 'X'") } } else { - if (length(treatment) != nrow(covariates)) { + if (length(Z) != nrow(X)) { stop( - "'treatment' must have the same number of elements as 'covariates'" + "'Z' must have the same number of elements as 'X'" ) } } @@ -551,13 +551,13 @@ sample_bcf_posterior_predictive <- function( stop("'propensity' must be a numeric vector or matrix") } if (is.matrix(propensity)) { - if (nrow(propensity) != nrow(covariates)) { - stop("'propensity' must have the same number of rows as 'covariates'") + if (nrow(propensity) != nrow(X)) { + stop("'propensity' must have the same number of rows as 'X'") } } else { - if (length(propensity) != nrow(covariates)) { + if (length(propensity) != nrow(X)) { stop( - "'propensity' must have the same number of elements as 'covariates'" + "'propensity' must have the same number of elements as 'X'" ) } } @@ -569,9 +569,9 @@ sample_bcf_posterior_predictive <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -582,16 +582,16 @@ sample_bcf_posterior_predictive <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior samples bcf_preds <- predict( model_object, - X = covariates, - Z = treatment, + X = X, + Z = Z, propensity = propensity, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, @@ -605,7 +605,7 @@ sample_bcf_posterior_predictive <- function( has_variance_forest <- model_object$model_params$include_variance_forest samples_global_variance <- model_object$model_params$sample_sigma2_global num_posterior_draws <- model_object$model_params$num_samples - num_observations <- nrow(covariates) + num_observations <- nrow(X) ppd_mean <- bcf_preds$y_hat if (has_variance_forest) { ppd_variance <- bcf_preds$variance_forest_predictions @@ -840,8 +840,8 @@ posterior_predictive_heuristic_multiplier <- function( #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". -#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). -#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). +#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). +#' @param Z (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). #' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. @@ -863,8 +863,8 @@ posterior_predictive_heuristic_multiplier <- function( #' intervals <- compute_bcf_posterior_interval( #' model_object = bcf_model, #' terms = c("prognostic_function", "cate"), -#' covariates = X, -#' treatment = Z, +#' X = X, +#' Z = Z, #' propensity = pi_X, #' level = 0.90 #' ) @@ -873,8 +873,8 @@ compute_bcf_posterior_interval <- function( terms, level = 0.95, scale = "linear", - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL @@ -930,33 +930,33 @@ compute_bcf_posterior_interval <- function( ("variance_forest" %in% terms) || (needs_covariates_intermediate)) if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_treatment <- needs_covariates if (needs_treatment) { - if (is.null(treatment)) { + if (is.null(Z)) { stop( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(treatment) && !is.numeric(treatment)) { - stop("'treatment' must be a numeric vector or matrix") + if (!is.matrix(Z) && !is.numeric(Z)) { + stop("'Z' must be a numeric vector or matrix") } - if (is.matrix(treatment)) { - if (nrow(treatment) != nrow(covariates)) { - stop("'treatment' must have the same number of rows as 'covariates'") + if (is.matrix(Z)) { + if (nrow(Z) != nrow(X)) { + stop("'Z' must have the same number of rows as 'X'") } } else { - if (length(treatment) != nrow(covariates)) { + if (length(Z) != nrow(X)) { stop( - "'treatment' must have the same number of elements as 'covariates'" + "'Z' must have the same number of elements as 'X'" ) } } @@ -976,13 +976,13 @@ compute_bcf_posterior_interval <- function( stop("'propensity' must be a numeric vector or matrix") } if (is.matrix(propensity)) { - if (nrow(propensity) != nrow(covariates)) { - stop("'propensity' must have the same number of rows as 'covariates'") + if (nrow(propensity) != nrow(X)) { + stop("'propensity' must have the same number of rows as 'X'") } } else { - if (length(propensity) != nrow(covariates)) { + if (length(propensity) != nrow(X)) { stop( - "'propensity' must have the same number of elements as 'covariates'" + "'propensity' must have the same number of elements as 'X'" ) } } @@ -998,9 +998,9 @@ compute_bcf_posterior_interval <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } @@ -1016,8 +1016,8 @@ compute_bcf_posterior_interval <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } } @@ -1025,8 +1025,8 @@ compute_bcf_posterior_interval <- function( # Compute posterior matrices for the requested model terms predictions <- predict( model_object, - X = covariates, - Z = treatment, + X = X, + Z = Z, propensity = propensity, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index d58c5ef1..ca617c8a 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -63,7 +63,7 @@ # Compute posterior interval intervals = bart_model.compute_posterior_interval( - terms="all", scale="linear", level=0.95, covariates=X_test + terms="all", scale="linear", level=0.95, X=X_test ) # Check coverage @@ -75,7 +75,7 @@ # Sample from the posterior predictive distribution bart_ppd_samples = bart_model.sample_posterior_predictive( - covariates=X_test, num_draws_per_sample=10 + X=X_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index 9a628bb4..24b68031 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -90,8 +90,8 @@ terms="all", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, ) @@ -118,7 +118,7 @@ # Sample from the posterior predictive distribution bcf_ppd_samples = bcf_model.sample_posterior_predictive( - covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10 + X=X_test, Z=Z_test, propensity=pi_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual @@ -229,8 +229,8 @@ terms="all", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test, ) @@ -240,8 +240,8 @@ terms="prognostic_function", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -251,8 +251,8 @@ terms="cate", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -284,8 +284,8 @@ terms="mu", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -293,8 +293,8 @@ terms="tau", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) diff --git a/demo/debug/causal_inference_binary_outcome.py b/demo/debug/causal_inference_binary_outcome.py index 6f3c75d4..4d249cbd 100644 --- a/demo/debug/causal_inference_binary_outcome.py +++ b/demo/debug/causal_inference_binary_outcome.py @@ -1,7 +1,6 @@ # Load necessary libraries import numpy as np import pandas as pd -import seaborn as sns import matplotlib.pyplot as plt from stochtree import BCFModel from sklearn.model_selection import train_test_split diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py index 59d6e11d..e5f621ba 100644 --- a/demo/debug/multi_chain.py +++ b/demo/debug/multi_chain.py @@ -3,7 +3,6 @@ # Load necessary libraries import matplotlib.pyplot as plt import numpy as np -import pandas as pd import arviz as az from sklearn.model_selection import train_test_split diff --git a/demo/debug/multiple_initializations.py b/demo/debug/multiple_initializations.py index b489ee80..ad3b60a3 100644 --- a/demo/debug/multiple_initializations.py +++ b/demo/debug/multiple_initializations.py @@ -118,14 +118,14 @@ def outcome_mean(X, W): ) # Inspect the model outputs -bart_preds_2 = bart_model_2.predict(X=X_test, basis_test) +bart_preds_2 = bart_model_2.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_2 = bart_preds_2['y_hat'] y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) -bart_preds_3 = bart_model_3.predict(X=X_test, basis_test) +bart_preds_3 = bart_model_3.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_3 = bart_preds_3['y_hat'] y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True) -bart_preds_4 = bart_model_4.predict(X=X_test, basis_test) +bart_preds_4 = bart_model_4.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_4 = bart_preds_4['y_hat'] y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) y_df = pd.DataFrame( diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd index 118c0256..00e12157 100644 --- a/man/compute_bcf_posterior_interval.Rd +++ b/man/compute_bcf_posterior_interval.Rd @@ -9,8 +9,8 @@ compute_bcf_posterior_interval( terms, level = 0.95, scale = "linear", - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL @@ -25,9 +25,9 @@ compute_bcf_posterior_interval( \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} -\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} +\item{X}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} -\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} +\item{Z}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} \item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} @@ -55,8 +55,8 @@ bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, intervals <- compute_bcf_posterior_interval( model_object = bcf_model, terms = c("prognostic_function", "cate"), - covariates = X, - treatment = Z, + X = X, + Z = Z, propensity = pi_X, level = 0.90 ) diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd index 0c77d7c1..b6cb191d 100644 --- a/man/sample_bcf_posterior_predictive.Rd +++ b/man/sample_bcf_posterior_predictive.Rd @@ -6,8 +6,8 @@ \usage{ sample_bcf_posterior_predictive( model_object, - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -17,9 +17,9 @@ sample_bcf_posterior_predictive( \arguments{ \item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} -\item{covariates}{A matrix or data frame of covariates.} +\item{X}{A matrix or data frame of covariates.} -\item{treatment}{A vector or matrix of treatment assignments.} +\item{Z}{A vector or matrix of treatment assignments.} \item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} @@ -44,7 +44,7 @@ Z <- rbinom(n, 1, pi_X) y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) ppd_samples <- sample_bcf_posterior_predictive( - model_object = bcf_model, covariates = X, - treatment = Z, propensity = pi_X + model_object = bcf_model, X = X, + Z = Z, propensity = pi_X ) } diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 4ce66288..983791ab 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -3263,8 +3263,8 @@ def compute_posterior_interval( terms: Union[list[str], str] = "all", level: float = 0.95, scale: str = "linear", - covariates: np.array = None, - treatment: np.array = None, + X: np.array = None, + Z: np.array = None, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, @@ -3280,9 +3280,9 @@ def compute_posterior_interval( Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. - covariates : np.array, optional + X : np.array, optional Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions). - treatment : np.array, optional + Z : np.array, optional Optional array of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). propensity : np.array, optional Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. @@ -3346,25 +3346,25 @@ def compute_posterior_interval( or needs_covariates_intermediate ) if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_treatment = needs_covariates if needs_treatment: - if treatment is None: + if Z is None: raise ValueError( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) - if not isinstance(treatment, np.ndarray): - raise ValueError("'treatment' must be a numpy array") - if treatment.shape[0] != covariates.shape[0]: + if not isinstance(Z, np.ndarray): + raise ValueError("'Z' must be a numpy array") + if Z.shape[0] != X.shape[0]: raise ValueError( - "'treatment' must have the same number of rows as 'covariates'" + "'Z' must have the same number of rows as 'X'" ) uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model @@ -3378,9 +3378,9 @@ def compute_posterior_interval( ) if not isinstance(propensity, np.ndarray): raise ValueError("'propensity' must be a numpy array") - if propensity.shape[0] != covariates.shape[0]: + if propensity.shape[0] != X.shape[0]: raise ValueError( - "'propensity' must have the same number of rows as 'covariates'" + "'propensity' must have the same number of rows as 'X'" ) needs_rfx_data_intermediate = ( ("y_hat" in terms) or ("all" in terms) @@ -3393,9 +3393,9 @@ def compute_posterior_interval( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if self.rfx_model_spec == "custom": if rfx_basis is None: @@ -3405,15 +3405,15 @@ def compute_posterior_interval( if rfx_basis is not None: if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior matrices for the requested model terms predictions = self.predict( - X=covariates, - Z=treatment, + X=X, + Z=Z, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, @@ -3437,8 +3437,8 @@ def compute_posterior_interval( def sample_posterior_predictive( self, - covariates: np.array, - treatment: np.array, + X: np.array, + Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, @@ -3449,9 +3449,9 @@ def sample_posterior_predictive( Parameters ---------- - covariates : np.array + X : np.array An array or data frame of covariates. - treatment : np.array + Z : np.array An array of treatment assignments. propensity : np.array, optional Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. @@ -3477,25 +3477,25 @@ def sample_posterior_predictive( # Check that all the necessary inputs were provided for interval computation needs_covariates = True if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_treatment = needs_covariates if needs_treatment: - if treatment is None: + if Z is None: raise ValueError( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) - if not isinstance(treatment, np.ndarray): - raise ValueError("'treatment' must be a numpy array") - if treatment.shape[0] != covariates.shape[0]: + if not isinstance(Z, np.ndarray): + raise ValueError("'Z' must be a numpy array") + if Z.shape[0] != X.shape[0]: raise ValueError( - "'treatment' must have the same number of rows as 'covariates'" + "'Z' must have the same number of rows as 'X'" ) uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model @@ -3509,9 +3509,9 @@ def sample_posterior_predictive( ) if not isinstance(propensity, np.ndarray): raise ValueError("'propensity' must be a numpy array") - if propensity.shape[0] != covariates.shape[0]: + if propensity.shape[0] != X.shape[0]: raise ValueError( - "'propensity' must have the same number of rows as 'covariates'" + "'propensity' must have the same number of rows as 'X'" ) needs_rfx_data = self.has_rfx if needs_rfx_data: @@ -3521,9 +3521,9 @@ def sample_posterior_predictive( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -3531,15 +3531,15 @@ def sample_posterior_predictive( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior predictive samples bcf_preds = self.predict( - X=covariates, - Z=treatment, + X=X, + Z=Z, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, @@ -3552,7 +3552,7 @@ def sample_posterior_predictive( has_variance_forest = self.include_variance_forest samples_global_variance = self.sample_sigma2_global num_posterior_draws = self.num_samples - num_observations = covariates.shape[0] + num_observations = X.shape[0] ppd_mean = bcf_preds["y_hat"] if has_variance_forest: ppd_variance = bcf_preds["variance_forest_predictions"] diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 8abebfdb..b182524b 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from sklearn.model_selection import train_test_split from stochtree import BARTModel diff --git a/test/python/test_predict.py b/test/python/test_predict.py index ebc6fbc5..117cac04 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -279,7 +279,6 @@ def test_bcf_prediction(self): # Generate data and test/train split rng = np.random.default_rng(1234) n = 100 - g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) x1 = rng.normal(size=n) x2 = rng.normal(size=n) x3 = rng.normal(size=n) diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index 70bc71ed..3ed45a2c 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -78,8 +78,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = c("all"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -94,8 +94,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( quantiles <- c(0.05, 0.95) ppd_samples <- sample_bcf_posterior_predictive( model_object = bcf_model, - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, num_draws = 1 ) @@ -179,8 +179,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = c("y_hat"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -190,8 +190,8 @@ y_hat_prob_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "probability", terms = c("y_hat"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -215,8 +215,8 @@ lines(y_hat_prob_intervals$upper[sort_inds]) # Draw from posterior predictive for covariates / treatment values in the test set ppd_samples <- sample_bcf_posterior_predictive( model_object = bcf_model, - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, num_draws = 10 ) @@ -360,8 +360,8 @@ posterior_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "all", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -372,8 +372,8 @@ prog_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "prognostic_function", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -384,8 +384,8 @@ cate_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "cate", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -426,8 +426,8 @@ mu_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "mu", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -436,8 +436,8 @@ tau_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "tau", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95