From ee29c0308557810b51f8017ad6da3fe869ac7bc7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 10 Nov 2025 13:40:19 -0600 Subject: [PATCH] Fixed bug with number of burning / MCMC samples in BCF with an internal propensity model (and made sampling behavior consistent across R and Python) --- R/bcf.R | 19 ++++++++----------- stochtree/bcf.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 82ee9cd1..605290ac 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -912,26 +912,23 @@ bcf <- function( if ((is.null(propensity_train)) && (propensity_covariate != "none")) { internal_propensity_model <- TRUE # Estimate using the last of several iterations of GFR BART - num_burnin <- 10 - num_total <- 50 + num_gfr_propensity <- 10 + num_burnin_propensity <- 0 + num_mcmc_propensity <- 10 bart_model_propensity <- bart( X_train = X_train, y_train = as.numeric(Z_train), X_test = X_test_raw, - num_gfr = num_total, - num_burnin = 0, - num_mcmc = 0 + num_gfr = num_gfr_propensity, + num_burnin = num_burnin_propensity, + num_mcmc = num_mcmc_propensity ) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, - (num_burnin + 1):num_total - ]) + propensity_train <- rowMeans(bart_model_propensity$y_hat_train) if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { propensity_train <- as.matrix(propensity_train) } if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, - (num_burnin + 1):num_total - ]) + propensity_test <- rowMeans(bart_model_propensity$y_hat_test) if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { propensity_test <- as.matrix(propensity_test) } diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d51eaf9e..b1e9c5a5 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1169,13 +1169,17 @@ def sample( ) else: self.bart_propensity_model = BARTModel() + num_gfr_propensity = 10 + num_burnin_propensity = 0 + num_mcmc_propensity = 10 if self.has_test: self.bart_propensity_model.sample( X_train=X_train_processed, y_train=Z_train, X_test=X_test_processed, - num_gfr=10, - num_mcmc=10, + num_gfr=num_gfr_propensity, + num_burnin=num_burnin_propensity, + num_mcmc=num_mcmc_propensity ) pi_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True @@ -1187,8 +1191,9 @@ def sample( self.bart_propensity_model.sample( X_train=X_train_processed, y_train=Z_train, - num_gfr=10, - num_mcmc=10, + num_gfr=num_gfr_propensity, + num_burnin=num_burnin_propensity, + num_mcmc=num_mcmc_propensity ) pi_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True