Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 106 additions & 20 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -2646,7 +2646,7 @@ bcf <- function(
#' that were not in the training set.
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used.
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If a model has random effects fit with either the "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept. If the model_spec is "intercept_plus_treatment", then "cate" additionally refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param ... (Optional) Other prediction parameters.
#'
Expand Down Expand Up @@ -2735,14 +2735,53 @@ predict.bcfmodel <- function(
}
predict_mean <- type == "mean"

# Warn users about CATE / prognostic function when rfx_model_spec is "custom"
if (object$model_params$has_rfx) {
if (object$model_params$rfx_model_spec == "custom") {
if (("prognostic_function" %in% terms) || ("cate" %in% terms)) {
warning(paste0(
"This BCF model was fit with a custom random effects model specification (i.e. a user-provided basis). ",
"As a result, 'prognostic_function' and 'cate' refer only to the prognostic ('mu') ",
"and treatment effect 'tau' forests, respectively, and do not include any random ",
"effects contributions. If your user-provided random effects basis includes a random intercept or a ",
"random slope on the treatment variable, you will need to compute the prognostic or CATE functions manually by predicting ",
"'yhat' for different covariate and rfx_basis values."
))
}
}
}

# Handle prediction terms
rfx_model_spec = object$model_params$rfx_model_spec
rfx_intercept_only <- rfx_model_spec == "intercept_only"
rfx_intercept_plus_treatment <- (rfx_model_spec == "intercept_plus_treatment")
rfx_intercept_plus_treatment <- rfx_model_spec == "intercept_plus_treatment"
rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment
mu_prog_separate <- ifelse(rfx_intercept, TRUE, FALSE)
tau_cate_separate <- ifelse(rfx_intercept_plus_treatment, TRUE, FALSE)
if (!is.character(terms)) {
stop("type must be a string or character vector")
}
for (term in terms) {
if (
!(term %in%
c(
"y_hat",
"prognostic_function",
"mu",
"cate",
"tau",
"rfx",
"variance_forest",
"all"
))
) {
stop(paste0(
"Term '",
term,
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'."
))
}
}
num_terms <- length(terms)
has_mu_forest <- T
has_tau_forest <- T
Expand All @@ -2751,10 +2790,14 @@ predict.bcfmodel <- function(
has_y_hat <- T
predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) ||
((has_y_hat) && ("all" %in% terms)))
predict_mu_forest <- (((has_mu_forest) &&
predict_mu_forest <- (((has_mu_forest) && ("all" %in% terms)) ||
((has_mu_forest) && ("mu" %in% terms)))
predict_tau_forest <- (((has_tau_forest) && ("tau" %in% terms)) ||
((has_tau_forest) && ("all" %in% terms)))
predict_prog_function <- (((has_mu_forest) &&
("prognostic_function" %in% terms)) ||
((has_mu_forest) && ("all" %in% terms)))
predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) ||
predict_cate_function <- (((has_tau_forest) && ("cate" %in% terms)) ||
((has_tau_forest) && ("all" %in% terms)))
predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) ||
((has_rfx) && ("all" %in% terms)))
Expand All @@ -2764,7 +2807,9 @@ predict.bcfmodel <- function(
predict_count <- sum(c(
predict_y_hat,
predict_mu_forest,
predict_prog_function,
predict_tau_forest,
predict_cate_function,
predict_rfx,
predict_variance_forest
))
Expand All @@ -2777,10 +2822,13 @@ predict.bcfmodel <- function(
return(NULL)
}
predict_rfx_intermediate <- (predict_y_hat && has_rfx)
predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept) ||
(predict_tau_forest && has_rfx && rfx_intercept_plus_treatment))
predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest)
predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest)
predict_rfx_raw <- ((predict_prog_function && has_rfx && rfx_intercept) ||
(predict_cate_function && has_rfx && rfx_intercept_plus_treatment))
predict_mu_forest_intermediate <- ((predict_y_hat || predict_prog_function) &&
has_mu_forest)
predict_tau_forest_intermediate <- ((predict_y_hat ||
predict_cate_function) &&
has_tau_forest)

# Make sure covariates are matrix or data frame
if ((!is.data.frame(X)) && (!is.matrix(X))) {
Expand Down Expand Up @@ -2983,26 +3031,28 @@ predict.bcfmodel <- function(
}

# Add raw RFX predictions to the forest predictions to form the prognostic function and CATE if warranted by the RFX model spec
if (predict_mu_forest || predict_mu_forest_intermediate) {
if (rfx_intercept && predict_rfx_raw) {
mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ]
if (predict_prog_function) {
if (mu_prog_separate) {
prognostic_function <- mu_hat_forest + rfx_predictions_raw[, 1, ]
} else {
mu_hat_final <- mu_hat_forest
prognostic_function <- mu_hat_forest
}
}
if (predict_tau_forest || predict_tau_forest_intermediate) {
if (rfx_intercept_plus_treatment && predict_rfx_raw) {
tau_hat_final <- (tau_hat_forest +
if (predict_cate_function) {
if (tau_cate_separate) {
cate <- (tau_hat_forest +
rfx_predictions_raw[, 2:ncol(rfx_basis), ])
} else {
tau_hat_final <- tau_hat_forest
cate <- tau_hat_forest
}
}

# Combine into y hat predictions
needs_mean_term_preds <- predict_y_hat ||
predict_mu_forest ||
predict_tau_forest ||
predict_prog_function ||
predict_cate_function ||
predict_rfx
if (needs_mean_term_preds) {
if (probability_scale) {
Expand All @@ -3019,10 +3069,16 @@ predict.bcfmodel <- function(
}
}
if (predict_mu_forest) {
mu_hat <- pnorm(mu_hat_final)
mu_hat <- pnorm(mu_hat_forest)
}
if (predict_tau_forest) {
tau_hat <- pnorm(tau_hat_final)
tau_hat <- pnorm(tau_hat_forest)
}
if (predict_prog_function) {
prognostic_function <- pnorm(prognostic_function)
}
if (predict_cate_function) {
cate <- pnorm(cate)
}
} else {
if (has_rfx) {
Expand All @@ -3035,10 +3091,16 @@ predict.bcfmodel <- function(
}
}
if (predict_mu_forest) {
mu_hat <- mu_hat_final
mu_hat <- mu_hat_forest
}
if (predict_tau_forest) {
tau_hat <- tau_hat_final
tau_hat <- tau_hat_forest
}
if (predict_prog_function) {
prognostic_function <- prognostic_function
}
if (predict_cate_function) {
cate <- cate
}
}
}
Expand All @@ -3055,6 +3117,16 @@ predict.bcfmodel <- function(
tau_hat <- rowMeans(tau_hat)
}
}
if (predict_prog_function) {
prognostic_function <- rowMeans(prognostic_function)
}
if (predict_cate_function) {
if (object$model_params$multivariate_treatment) {
cate <- apply(cate, c(1, 2), mean)
} else {
cate <- rowMeans(cate)
}
}
if (predict_rfx) {
rfx_predictions <- rowMeans(rfx_predictions)
}
Expand All @@ -3071,6 +3143,10 @@ predict.bcfmodel <- function(
return(mu_hat)
} else if (predict_tau_forest) {
return(tau_hat)
} else if (predict_prog_function) {
return(prognostic_function)
} else if (predict_cate_function) {
return(cate)
} else if (predict_rfx) {
return(rfx_predictions)
} else if (predict_variance_forest) {
Expand All @@ -3093,6 +3169,16 @@ predict.bcfmodel <- function(
} else {
result[["tau_hat"]] <- NULL
}
if (predict_prog_function) {
result[["prognostic_function"]] = prognostic_function
} else {
result[["prognostic_function"]] <- NULL
}
if (predict_cate_function) {
result[["cate"]] = cate
} else {
result[["cate"]] <- NULL
}
if (predict_rfx) {
result[["rfx_predictions"]] = rfx_predictions
} else {
Expand Down
85 changes: 64 additions & 21 deletions R/posterior_transformation.R
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ posterior_predictive_heuristic_multiplier <- function(
#'
#' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
#' @param model_object A fitted BCF model object of class `bcfmodel`.
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
Expand Down Expand Up @@ -895,6 +895,29 @@ compute_bcf_posterior_interval <- function(
}

# Check that all the necessary inputs were provided for interval computation
for (term in terms) {
if (
!(term %in%
c(
"prognostic_function",
"mu",
"cate",
"tau",
"variance_forest",
"rfx",
"y_hat",
"all"
))
) {
stop(
paste0(
"Term '",
term,
"' was requested. Valid terms are 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', and 'all'."
)
)
}
}
needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
("all" %in% terms)))
needs_covariates <- (("prognostic_function" %in% terms) ||
Expand Down Expand Up @@ -975,16 +998,22 @@ compute_bcf_posterior_interval <- function(
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
)
}
if (is.null(rfx_basis)) {
stop(
"'rfx_basis' must be provided in order to compute the requested intervals"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")

if (model_object$model_params$rfx_model_spec == "custom") {
if (is.null(rfx_basis)) {
stop(
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
)
}
}
if (nrow(rfx_basis) != nrow(covariates)) {
stop("'rfx_basis' must have the same number of rows as 'covariates'")

if (!is.null(rfx_basis)) {
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(covariates)) {
stop("'rfx_basis' must have the same number of rows as 'covariates'")
}
}
}

Expand All @@ -1006,11 +1035,15 @@ compute_bcf_posterior_interval <- function(
if (has_multiple_terms) {
result <- list()
for (term_name in names(predictions)) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
if (!is.null(predictions[[term_name]])) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
} else {
result[[term_name]] <- NULL
}
}
return(result)
} else {
Expand Down Expand Up @@ -1161,11 +1194,15 @@ compute_bart_posterior_interval <- function(
if (has_multiple_terms) {
result <- list()
for (term_name in names(predictions)) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
if (!is.null(predictions[[term_name]])) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
} else {
result[[term_name]] <- NULL
}
}
return(result)
} else {
Expand Down Expand Up @@ -1253,8 +1290,12 @@ bart_model_has_term <- function(model_object, term) {
bcf_model_has_term <- function(model_object, term) {
if (term == "prognostic_function") {
return(TRUE)
} else if (term == "mu") {
return(TRUE)
} else if (term == "cate") {
return(TRUE)
} else if (term == "tau") {
return(TRUE)
} else if (term == "variance_forest") {
return(model_object$model_params$include_variance_forest)
} else if (term == "rfx") {
Expand All @@ -1280,15 +1321,17 @@ validate_bart_term <- function(term) {
# Check that `term` is one of the model terms recognized for bcfmodel objects.
#
# `term` is a single character string naming the requested term. The accepted
# values are listed in `model_terms` below and include both the composite
# terms ("prognostic_function", "cate") and the forest-only terms ("mu", "tau").
#
# Raises an error via `stop()` when `term` is not recognized; otherwise falls
# through without returning a meaningful value.
validate_bcf_term <- function(term) {
  model_terms <- c(
    "prognostic_function",
    "mu",
    "cate",
    "tau",
    "variance_forest",
    "rfx",
    "y_hat",
    "all"
  )
  if (!(term %in% model_terms)) {
    # Keep this message in sync with `model_terms` above
    stop(
      "'term' must be one of 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects"
    )
  }
}
Loading
Loading