Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -2219,6 +2219,11 @@ predict.bartmodel <- function(
rfx_param_list <- object$rfx_samples$extract_parameter_samples()
rfx_beta_draws <- rfx_param_list$beta_samples * y_std

# Promote to an array with consistent dimensions when there's one rfx term
if (length(dim(rfx_beta_draws)) == 2) {
dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
}

# Construct a matrix with the appropriate group random effects arranged for each observation
rfx_predictions_raw <- array(
NA,
Expand Down
5 changes: 5 additions & 0 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -3162,6 +3162,11 @@ predict.bcfmodel <- function(
rfx_beta_draws <- rfx_param_list$beta_samples *
object$model_params$outcome_scale

# Promote to an array with consistent dimensions when there's one rfx term
if (length(dim(rfx_beta_draws)) == 2) {
dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
}

# Construct a matrix with the appropriate group random effects arranged for each observation
rfx_predictions_raw <- array(
NA,
Expand Down
238 changes: 238 additions & 0 deletions test/R/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,76 @@ test_that("BART predictions with pre-summarization", {
expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})

test_that("BART predictions with random effects", {
# Generate data and test-train split
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
# fmt: skip
f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
noise_sd <- 1
rfx_group_ids <- sample(1:3, n, replace = TRUE)
rfx_coefs <- c(-2, 0, 2)
rfx_term <- rfx_coefs[rfx_group_ids]
rfx_basis <- matrix(1, nrow = n, ncol = 1)
y <- f_XW + rfx_term + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
rfx_group_ids_test <- rfx_group_ids[test_inds]
rfx_group_ids_train <- rfx_group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds, ]
rfx_basis_train <- rfx_basis[train_inds, ]
y_test <- y[test_inds]
y_train <- y[train_inds]

# Fit a "classic" BART model
rfx_params <- list(model_spec = "intercept_only")
bart_model <- bart(
X_train = X_train,
y_train = y_train,
rfx_group_ids_train = rfx_group_ids_train,
random_effects_params = rfx_params,
num_gfr = 10,
num_burnin = 0,
num_mcmc = 10
)

# Check that the default predict method returns a list
pred <- predict(bart_model, X = X_test, rfx_group_ids = rfx_group_ids_test)
y_hat_posterior_test <- pred$y_hat
expect_equal(dim(y_hat_posterior_test), c(20, 10))

# Check that the pre-aggregated predictions match with those computed by rowMeans
pred_mean <- predict(
bart_model,
X = X_test,
rfx_group_ids = rfx_group_ids_test,
type = "mean"
)
y_hat_mean_test <- pred_mean$y_hat
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

# Check that we warn and return a NULL when requesting terms that weren't fit
expect_warning({
pred_mean <- predict(
bart_model,
X = X_test,
rfx_group_ids = rfx_group_ids_test,
type = "mean",
terms = c("variance_forest")
)
})
expect_equal(NULL, pred_mean)
})

test_that("BCF predictions with pre-summarization", {
# Generate data and test-train split
n <- 100
Expand Down Expand Up @@ -443,3 +513,171 @@ test_that("BCF predictions with pre-summarization", {
expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})

test_that("BCF predictions with random effects", {
# Generate data and test-train split
n <- 100
g <- function(x) {
ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4))
}
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)
mu_x <- 1 + g(X) + X[, 1] * X[, 3]
tau_x <- 1 + 2 * X[, 2] * X[, 4]
pi_x <- 0.8 *
pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
0.05 +
runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
rfx_group_ids <- sample(1:3, n, replace = TRUE)
rfx_basis <- cbind(1, Z)
rfx_coefs <- matrix(
c(
-2,
-0.5,
0,
0.0,
2,
0.5
),
byrow = T,
ncol = 2
)
rfx_term <- rowSums(rfx_basis * rfx_coefs[rfx_group_ids, ])
snr <- 2
y <- E_XZ + rfx_term + rnorm(n, 0, 1) * (sd(E_XZ + rfx_term) / snr)
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
rfx_group_ids_test <- rfx_group_ids[test_inds]
rfx_group_ids_train <- rfx_group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds, ]
rfx_basis_train <- rfx_basis[train_inds, ]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]

# Fit a BCF model with random intercept and random slope on Z
rfx_params = list(model_spec = "intercept_plus_treatment")
bcf_model <- bcf(
X_train = X_train,
Z_train = Z_train,
y_train = y_train,
propensity_train = pi_train,
rfx_group_ids_train = rfx_group_ids_train,
X_test = X_test,
Z_test = Z_test,
propensity_test = pi_test,
rfx_group_ids_test = rfx_group_ids_test,
random_effects_params = rfx_params,
num_gfr = 10,
num_burnin = 0,
num_mcmc = 10
)

# Check that the default predict method returns a list
pred <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
rfx_group_ids = rfx_group_ids_test
)
y_hat_posterior_test <- pred$y_hat
expect_equal(dim(y_hat_posterior_test), c(20, 10))

# Check that the pre-aggregated predictions match with those computed by rowMeans
pred_mean <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
rfx_group_ids = rfx_group_ids_test,
type = "mean"
)
y_hat_mean_test <- pred_mean$y_hat
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

# Check that we warn and return a NULL when requesting terms that weren't fit
expect_warning({
pred_mean <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
type = "mean",
terms = c("variance_forest")
)
})
expect_equal(NULL, pred_mean)

# Fit a BCF model with random intercept only
# Fit a BCF model with random intercept and random slope on Z
rfx_params = list(model_spec = "intercept_only")
bcf_model <- bcf(
X_train = X_train,
Z_train = Z_train,
y_train = y_train,
propensity_train = pi_train,
rfx_group_ids_train = rfx_group_ids_train,
X_test = X_test,
Z_test = Z_test,
propensity_test = pi_test,
rfx_group_ids_test = rfx_group_ids_test,
random_effects_params = rfx_params,
num_gfr = 10,
num_burnin = 0,
num_mcmc = 10
)

# Check that the default predict method returns a list
pred <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
rfx_group_ids = rfx_group_ids_test
)
y_hat_posterior_test <- pred$y_hat
expect_equal(dim(y_hat_posterior_test), c(20, 10))

# Check that the pre-aggregated predictions match with those computed by rowMeans
pred_mean <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
rfx_group_ids = rfx_group_ids_test,
type = "mean"
)
y_hat_mean_test <- pred_mean$y_hat
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

# Check that we warn and return a NULL when requesting terms that weren't fit
expect_warning({
pred_mean <- predict(
bcf_model,
X = X_test,
Z = Z_test,
propensity = pi_test,
type = "mean",
terms = c("variance_forest")
)
})
})
108 changes: 108 additions & 0 deletions test/python/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,45 @@ def test_bart_prediction(self):
sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
)

# Generate data with random effects
rfx_group_ids = rng.choice(3, size=n)
rfx_basis = np.ones((n, 1))
rfx_coefs = np.array([-2.0, 0.0, 2.0])
rfx_term = rfx_coefs[rfx_group_ids]
noise_sd = 1
y = f_XW + rfx_term + rng.normal(0, noise_sd, size=n)
test_set_pct = 0.2
train_inds, test_inds = train_test_split(
np.arange(n), test_size=test_set_pct, random_state=1234
)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
rfx_group_ids_train = rfx_group_ids[train_inds]
rfx_group_ids_test = rfx_group_ids[test_inds]
rfx_basis_train = rfx_basis[train_inds,:]
rfx_basis_test = rfx_basis[test_inds,:]
y_train = y[train_inds]
y_test = y[test_inds]

# Fit a BART model with random intercepts
rfx_params = {"model_spec": "intercept_only"}
bart_model = BARTModel()
bart_model.sample(
X_train=X_train, y_train=y_train, rfx_group_ids_train=rfx_group_ids_train, random_effects_params=rfx_params, num_gfr=10, num_burnin=0, num_mcmc=10
)

# Check that the default predict method returns a dictionary
pred = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test)
y_hat_posterior_test = pred["y_hat"]
assert y_hat_posterior_test.shape == (20, 10)

# Check that the pre-aggregated predictions match with those computed by np.mean
pred_mean = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test, type="mean")
y_hat_mean_test = pred_mean["y_hat"]
np.testing.assert_almost_equal(
y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
)

def test_bcf_prediction(self):
# Generate data and test/train split
rng = np.random.default_rng(1234)
Expand Down Expand Up @@ -417,3 +456,72 @@ def g(x5):
np.testing.assert_almost_equal(
sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
)

# Generate data with random effects
rfx_group_ids = rng.choice(3, size=n)
rfx_basis = np.concatenate((np.ones((n, 1)), np.expand_dims(Z, 1)), axis=1)
rfx_coefs = np.array([[-2.0, -0.5], [0.0, 0.0], [2.0, 0.5]])
rfx_term = np.multiply(rfx_coefs[rfx_group_ids,:], rfx_basis).sum(axis=1)
E_XZ = mu_x + tau_x * Z + rfx_term
snr = 2
y = E_XZ + rng.normal(loc=0.0, scale=np.std(E_XZ) / snr, size=(n,))
test_set_pct = 0.2
train_inds, test_inds = train_test_split(
np.arange(n), test_size=test_set_pct, random_state=1234
)
X_train = X.iloc[train_inds, :]
X_test = X.iloc[test_inds, :]
Z_train = Z[train_inds]
Z_test = Z[test_inds]
pi_x_train = pi_x[train_inds]
pi_x_test = pi_x[test_inds]
rfx_group_ids_train = rfx_group_ids[train_inds]
rfx_group_ids_test = rfx_group_ids[test_inds]
rfx_basis_train = rfx_basis[train_inds,:]
rfx_basis_test = rfx_basis[test_inds,:]
y_train = y[train_inds]
y_test = y[test_inds]

# Fit a "classic" BCF model
rfx_params = {"model_spec": "intercept_only"}
bcf_model = BCFModel()
bcf_model.sample(
X_train=X_train,
Z_train=Z_train,
y_train=y_train,
propensity_train=pi_x_train,
rfx_group_ids_train=rfx_group_ids_train,
X_test=X_test,
Z_test=Z_test,
propensity_test=pi_x_test,
rfx_group_ids_test=rfx_group_ids_test,
random_effects_params=rfx_params,
num_gfr=10,
num_burnin=0,
num_mcmc=10,
)

# Check that the default predict method returns a dictionary
pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test)
y_hat_posterior_test = pred["y_hat"]
assert y_hat_posterior_test.shape == (20, 10)

# Check that the pre-aggregated predictions match with those computed by np.mean
pred_mean = bcf_model.predict(
X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test, type="mean"
)
y_hat_mean_test = pred_mean["y_hat"]
np.testing.assert_almost_equal(
y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
)

# Check that we warn and return None when requesting terms that weren't fit
with pytest.warns(UserWarning):
pred_mean = bcf_model.predict(
X=X_test,
Z=Z_test,
propensity=pi_x_test,
rfx_group_ids=rfx_group_ids_test,
type="mean",
terms=["variance_forest"],
)
Loading