From bb77a17d5711b210ea1b73758952eec33d9619ec Mon Sep 17 00:00:00 2001
From: Jared Murray <30992825+jaredsmurray@users.noreply.github.com>
Date: Tue, 25 Nov 2025 17:02:16 -0600
Subject: [PATCH 1/3] Promote rfx_beta_draws to consistent array dimensions in predict.bcf

Ensure rfx_beta_draws has consistent dimensions when there's one rfx term.
---
 R/bcf.R | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/R/bcf.R b/R/bcf.R
index d6e11ad0..765347cb 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -3162,6 +3162,11 @@ predict.bcfmodel <- function(
     rfx_beta_draws <- rfx_param_list$beta_samples *
       object$model_params$outcome_scale
 
+    # Promote to an array with consistent dimensions when there's one rfx term
+    if (length(dim(rfx_beta_draws)) == 2) {
+      dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
+    }
+
     # Construct a matrix with the appropriate group random effects arranged for each observation
     rfx_predictions_raw <- array(
       NA,

From 84a52e28c4e5bb0b3c32792b466d52a6d1f10a76 Mon Sep 17 00:00:00 2001
From: Jared Murray <30992825+jaredsmurray@users.noreply.github.com>
Date: Tue, 25 Nov 2025 17:04:58 -0600
Subject: [PATCH 2/3] Promote rfx_beta_draws to array with consistent dimensions

Ensure rfx_beta_draws has consistent dimensions when there's one rfx term.
---
 R/bart.R | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/R/bart.R b/R/bart.R
index 177ed961..f5068f96 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -2219,6 +2219,11 @@ predict.bartmodel <- function(
     rfx_param_list <- object$rfx_samples$extract_parameter_samples()
     rfx_beta_draws <- rfx_param_list$beta_samples * y_std
 
+    # Promote to an array with consistent dimensions when there's one rfx term
+    if (length(dim(rfx_beta_draws)) == 2) {
+      dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
+    }
+
     # Construct a matrix with the appropriate group random effects arranged for each observation
     rfx_predictions_raw <- array(
       NA,

From 3cb48042221d82c10cd2b590fd64dd670dd99462 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Tue, 25 Nov 2025 19:33:50 -0500
Subject: [PATCH 3/3] Added unit tests for the intercept only model for R and Python BART and BCF

---
 test/R/testthat/test-predict.R | 238 +++++++++++++++++++++++++++++++++
 test/python/test_predict.py    | 108 +++++++++++++++
 2 files changed, 346 insertions(+)

diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R
index 63ff0f94..ac541be0 100644
--- a/test/R/testthat/test-predict.R
+++ b/test/R/testthat/test-predict.R
@@ -285,6 +285,76 @@ test_that("BART predictions with pre-summarization", {
   expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
 })
 
+test_that("BART predictions with random effects", {
+  # Generate data and test-train split
+  n <- 100
+  p <- 5
+  X <- matrix(runif(n * p), ncol = p)
+  # fmt: skip
+  f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
+    ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
+    ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
+    ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
+  noise_sd <- 1
+  rfx_group_ids <- sample(1:3, n, replace = TRUE)
+  rfx_coefs <- c(-2, 0, 2)
+  rfx_term <- rfx_coefs[rfx_group_ids]
+  rfx_basis <- matrix(1, nrow = n, ncol = 1)
+  y <- f_XW + rfx_term + rnorm(n, 0, noise_sd)
+  test_set_pct <- 0.2
+  n_test <- round(test_set_pct * n)
+  n_train <- n - n_test
+  test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+  train_inds <- (1:n)[!((1:n) %in% test_inds)]
+  X_test <- X[test_inds, ]
+  X_train <- X[train_inds, ]
+  rfx_group_ids_test <- rfx_group_ids[test_inds]
+  rfx_group_ids_train <- rfx_group_ids[train_inds]
+  rfx_basis_test <- rfx_basis[test_inds, , drop = FALSE]
+  rfx_basis_train <- rfx_basis[train_inds, , drop = FALSE]
+  y_test <- y[test_inds]
+  y_train <- y[train_inds]
+
+  # Fit a BART model with a group-level random intercept
+  rfx_params <- list(model_spec = "intercept_only")
+  bart_model <- bart(
+    X_train = X_train,
+    y_train = y_train,
+    rfx_group_ids_train = rfx_group_ids_train,
+    random_effects_params = rfx_params,
+    num_gfr = 10,
+    num_burnin = 0,
+    num_mcmc = 10
+  )
+
+  # Check that the default predict method returns a list
+  pred <- predict(bart_model, X = X_test, rfx_group_ids = rfx_group_ids_test)
+  y_hat_posterior_test <- pred$y_hat
+  expect_equal(dim(y_hat_posterior_test), c(20, 10))
+
+  # Check that the pre-aggregated predictions match with those computed by rowMeans
+  pred_mean <- predict(
+    bart_model,
+    X = X_test,
+    rfx_group_ids = rfx_group_ids_test,
+    type = "mean"
+  )
+  y_hat_mean_test <- pred_mean$y_hat
+  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
+
+  # Check that we warn and return a NULL when requesting terms that weren't fit
+  expect_warning({
+    pred_mean <- predict(
+      bart_model,
+      X = X_test,
+      rfx_group_ids = rfx_group_ids_test,
+      type = "mean",
+      terms = c("variance_forest")
+    )
+  })
+  expect_equal(NULL, pred_mean)
+})
+
 test_that("BCF predictions with pre-summarization", {
   # Generate data and test-train split
   n <- 100
@@ -443,3 +513,171 @@ test_that("BCF predictions with pre-summarization", {
   expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
   expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
 })
+
+test_that("BCF predictions with random effects", {
+  # Generate data and test-train split
+  n <- 100
+  g <- function(x) {
+    ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4))
+  }
+  x1 <- rnorm(n)
+  x2 <- rnorm(n)
+  x3 <- rnorm(n)
+  x4 <- as.numeric(rbinom(n, 1, 0.5))
+  x5 <- as.numeric(sample(1:3, n, replace = TRUE))
+  X <- cbind(x1, x2, x3, x4, x5)
+  p <- ncol(X)
+  mu_x <- 1 + g(X) + X[, 1] * X[, 3]
+  tau_x <- 1 + 2 * X[, 2] * X[, 4]
+  pi_x <- 0.8 *
+    pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
+    0.05 +
+    runif(n) / 10
+  Z <- rbinom(n, 1, pi_x)
+  E_XZ <- mu_x + Z * tau_x
+  rfx_group_ids <- sample(1:3, n, replace = TRUE)
+  rfx_basis <- cbind(1, Z)
+  rfx_coefs <- matrix(
+    c(
+      -2,
+      -0.5,
+      0,
+      0.0,
+      2,
+      0.5
+    ),
+    byrow = TRUE,
+    ncol = 2
+  )
+  rfx_term <- rowSums(rfx_basis * rfx_coefs[rfx_group_ids, ])
+  snr <- 2
+  y <- E_XZ + rfx_term + rnorm(n, 0, 1) * (sd(E_XZ + rfx_term) / snr)
+  X <- as.data.frame(X)
+  X$x4 <- factor(X$x4, ordered = TRUE)
+  X$x5 <- factor(X$x5, ordered = TRUE)
+  test_set_pct <- 0.2
+  n_test <- round(test_set_pct * n)
+  n_train <- n - n_test
+  test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+  train_inds <- (1:n)[!((1:n) %in% test_inds)]
+  X_test <- X[test_inds, ]
+  X_train <- X[train_inds, ]
+  pi_test <- pi_x[test_inds]
+  pi_train <- pi_x[train_inds]
+  rfx_group_ids_test <- rfx_group_ids[test_inds]
+  rfx_group_ids_train <- rfx_group_ids[train_inds]
+  rfx_basis_test <- rfx_basis[test_inds, ]
+  rfx_basis_train <- rfx_basis[train_inds, ]
+  Z_test <- Z[test_inds]
+  Z_train <- Z[train_inds]
+  y_test <- y[test_inds]
+  y_train <- y[train_inds]
+
+  # Fit a BCF model with random intercept and random slope on Z
+  rfx_params <- list(model_spec = "intercept_plus_treatment")
+  bcf_model <- bcf(
+    X_train = X_train,
+    Z_train = Z_train,
+    y_train = y_train,
+    propensity_train = pi_train,
+    rfx_group_ids_train = rfx_group_ids_train,
+    X_test = X_test,
+    Z_test = Z_test,
+    propensity_test = pi_test,
+    rfx_group_ids_test = rfx_group_ids_test,
+    random_effects_params = rfx_params,
+    num_gfr = 10,
+    num_burnin = 0,
+    num_mcmc = 10
+  )
+
+  # Check that the default predict method returns a list
+  pred <- predict(
+    bcf_model,
+    X = X_test,
+    Z = Z_test,
+    propensity = pi_test,
+    rfx_group_ids = rfx_group_ids_test
+  )
+  y_hat_posterior_test <- pred$y_hat
+  expect_equal(dim(y_hat_posterior_test), c(20, 10))
+
+  # Check that the pre-aggregated predictions match with those computed by rowMeans
+  pred_mean <- predict(
+    bcf_model,
+    X = X_test,
+    Z = Z_test,
+    propensity = pi_test,
+    rfx_group_ids = rfx_group_ids_test,
+    type = "mean"
+  )
+  y_hat_mean_test <- pred_mean$y_hat
+  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
+
+  # Check that we warn and return a NULL when requesting terms that weren't fit
+  expect_warning({
+    pred_mean <- predict(
+      bcf_model,
+      X = X_test,
+      Z = Z_test,
+      propensity = pi_test,
+      type = "mean",
+      terms = c("variance_forest")
+    )
+  })
+  expect_equal(NULL, pred_mean)
+
+  # Fit a BCF model with random intercept only
+  rfx_params <- list(model_spec = "intercept_only")
+  bcf_model <- bcf(
+    X_train = X_train,
+    Z_train = Z_train,
+    y_train = y_train,
+    propensity_train = pi_train,
+    rfx_group_ids_train = rfx_group_ids_train,
+    X_test = X_test,
+    Z_test = Z_test,
+    propensity_test = pi_test,
+    rfx_group_ids_test = rfx_group_ids_test,
+    random_effects_params = rfx_params,
+    num_gfr = 10,
+    num_burnin = 0,
+    num_mcmc = 10
+  )
+
+  # Check that the default predict method returns a list
+  pred <- predict(
+    bcf_model,
+    X = X_test,
+    Z = Z_test,
+    propensity = pi_test,
+    rfx_group_ids = rfx_group_ids_test
+  )
+  y_hat_posterior_test <- pred$y_hat
+  expect_equal(dim(y_hat_posterior_test), c(20, 10))
+
+  # Check that the pre-aggregated predictions match with those computed by rowMeans
+  pred_mean <- predict(
+    bcf_model,
+    X = X_test,
+    Z = Z_test,
+    propensity = pi_test,
+    rfx_group_ids = rfx_group_ids_test,
+    type = "mean"
+  )
+  y_hat_mean_test <- pred_mean$y_hat
+  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
+
+  # Check that we warn and return a NULL when requesting terms that weren't fit
+  expect_warning({
+    pred_mean <- predict(
+      bcf_model,
+      X = X_test,
+      Z = Z_test,
+      propensity = pi_test,
+      type = "mean",
+      terms = c("variance_forest")
+    )
+  })
+  expect_equal(NULL, pred_mean)
+})
diff --git a/test/python/test_predict.py b/test/python/test_predict.py
index 117cac04..cd567f83 100644
--- a/test/python/test_predict.py
+++ b/test/python/test_predict.py
@@ -275,6 +275,45 @@ def test_bart_prediction(self):
             sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
         )
 
+        # Generate data with random effects
+        rfx_group_ids = rng.choice(3, size=n)
+        rfx_basis = np.ones((n, 1))
+        rfx_coefs = np.array([-2.0, 0.0, 2.0])
+        rfx_term = rfx_coefs[rfx_group_ids]
+        noise_sd = 1
+        y = f_XW + rfx_term + rng.normal(0, noise_sd, size=n)
+        test_set_pct = 0.2
+        train_inds, test_inds = train_test_split(
+            np.arange(n), test_size=test_set_pct, random_state=1234
+        )
+        X_train = X[train_inds, :]
+        X_test = X[test_inds, :]
+        rfx_group_ids_train = rfx_group_ids[train_inds]
+        rfx_group_ids_test = rfx_group_ids[test_inds]
+        rfx_basis_train = rfx_basis[train_inds, :]
+        rfx_basis_test = rfx_basis[test_inds, :]
+        y_train = y[train_inds]
+        y_test = y[test_inds]
+
+        # Fit a BART model with random intercepts
+        rfx_params = {"model_spec": "intercept_only"}
+        bart_model = BARTModel()
+        bart_model.sample(
+            X_train=X_train, y_train=y_train, rfx_group_ids_train=rfx_group_ids_train, random_effects_params=rfx_params, num_gfr=10, num_burnin=0, num_mcmc=10
+        )
+
+        # Check that the default predict method returns a dictionary
+        pred = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test)
+        y_hat_posterior_test = pred["y_hat"]
+        assert y_hat_posterior_test.shape == (20, 10)
+
+        # Check that the pre-aggregated predictions match with those computed by np.mean
+        pred_mean = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test, type="mean")
+        y_hat_mean_test = pred_mean["y_hat"]
+        np.testing.assert_almost_equal(
+            y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
+        )
+
     def test_bcf_prediction(self):
         # Generate data and test/train split
         rng = np.random.default_rng(1234)
@@ -417,3 +456,72 @@ def g(x5):
         np.testing.assert_almost_equal(
             sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
         )
+
+        # Generate data with random effects
+        rfx_group_ids = rng.choice(3, size=n)
+        rfx_basis = np.concatenate((np.ones((n, 1)), np.expand_dims(Z, 1)), axis=1)
+        rfx_coefs = np.array([[-2.0, -0.5], [0.0, 0.0], [2.0, 0.5]])
+        rfx_term = np.multiply(rfx_coefs[rfx_group_ids, :], rfx_basis).sum(axis=1)
+        E_XZ = mu_x + tau_x * Z + rfx_term
+        snr = 2
+        y = E_XZ + rng.normal(loc=0.0, scale=np.std(E_XZ) / snr, size=(n,))
+        test_set_pct = 0.2
+        train_inds, test_inds = train_test_split(
+            np.arange(n), test_size=test_set_pct, random_state=1234
+        )
+        X_train = X.iloc[train_inds, :]
+        X_test = X.iloc[test_inds, :]
+        Z_train = Z[train_inds]
+        Z_test = Z[test_inds]
+        pi_x_train = pi_x[train_inds]
+        pi_x_test = pi_x[test_inds]
+        rfx_group_ids_train = rfx_group_ids[train_inds]
+        rfx_group_ids_test = rfx_group_ids[test_inds]
+        rfx_basis_train = rfx_basis[train_inds, :]
+        rfx_basis_test = rfx_basis[test_inds, :]
+        y_train = y[train_inds]
+        y_test = y[test_inds]
+
+        # Fit a BCF model with random intercepts
+        rfx_params = {"model_spec": "intercept_only"}
+        bcf_model = BCFModel()
+        bcf_model.sample(
+            X_train=X_train,
+            Z_train=Z_train,
+            y_train=y_train,
+            propensity_train=pi_x_train,
+            rfx_group_ids_train=rfx_group_ids_train,
+            X_test=X_test,
+            Z_test=Z_test,
+            propensity_test=pi_x_test,
+            rfx_group_ids_test=rfx_group_ids_test,
+            random_effects_params=rfx_params,
+            num_gfr=10,
+            num_burnin=0,
+            num_mcmc=10,
+        )
+
+        # Check that the default predict method returns a dictionary
+        pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test)
+        y_hat_posterior_test = pred["y_hat"]
+        assert y_hat_posterior_test.shape == (20, 10)
+
+        # Check that the pre-aggregated predictions match with those computed by np.mean
+        pred_mean = bcf_model.predict(
+            X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test, type="mean"
+        )
+        y_hat_mean_test = pred_mean["y_hat"]
+        np.testing.assert_almost_equal(
+            y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
+        )
+
+        # Check that we warn and return None when requesting terms that weren't fit
+        with pytest.warns(UserWarning):
+            pred_mean = bcf_model.predict(
+                X=X_test,
+                Z=Z_test,
+                propensity=pi_x_test,
+                rfx_group_ids=rfx_group_ids_test,
+                type="mean",
+                terms=["variance_forest"],
+            )