diff --git a/R/bcf.R b/R/bcf.R index 82ee9cd1..605290ac 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -912,26 +912,23 @@ bcf <- function( if ((is.null(propensity_train)) && (propensity_covariate != "none")) { internal_propensity_model <- TRUE # Estimate using the last of several iterations of GFR BART - num_burnin <- 10 - num_total <- 50 + num_gfr_propensity <- 10 + num_burnin_propensity <- 0 + num_mcmc_propensity <- 10 bart_model_propensity <- bart( X_train = X_train, y_train = as.numeric(Z_train), X_test = X_test_raw, - num_gfr = num_total, - num_burnin = 0, - num_mcmc = 0 + num_gfr = num_gfr_propensity, + num_burnin = num_burnin_propensity, + num_mcmc = num_mcmc_propensity ) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, - (num_burnin + 1):num_total - ]) + propensity_train <- rowMeans(bart_model_propensity$y_hat_train) if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { propensity_train <- as.matrix(propensity_train) } if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, - (num_burnin + 1):num_total - ]) + propensity_test <- rowMeans(bart_model_propensity$y_hat_test) if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { propensity_test <- as.matrix(propensity_test) } diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d51eaf9e..b1e9c5a5 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1169,13 +1169,17 @@ def sample( ) else: self.bart_propensity_model = BARTModel() + num_gfr_propensity = 10 + num_burnin_propensity = 0 + num_mcmc_propensity = 10 if self.has_test: self.bart_propensity_model.sample( X_train=X_train_processed, y_train=Z_train, X_test=X_test_processed, - num_gfr=10, - num_mcmc=10, + num_gfr=num_gfr_propensity, + num_burnin=num_burnin_propensity, + num_mcmc=num_mcmc_propensity ) pi_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True @@ -1187,8 +1191,9 @@ def sample( self.bart_propensity_model.sample( X_train=X_train_processed, y_train=Z_train, - num_gfr=10, - num_mcmc=10, + num_gfr=num_gfr_propensity, + num_burnin=num_burnin_propensity, + num_mcmc=num_mcmc_propensity ) pi_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True