From a18f053adeebaa71aeed05aa9f7b6bc815966d0b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 7 Nov 2025 02:05:47 -0500 Subject: [PATCH 1/5] Updated BCF predict functions to split out tau and mu from cate and prognostic function when there are certain random effects terms --- R/bcf.R | 110 ++++++++++++++++++++++----- demo/debug/bcf_predict_debug.py | 86 +++++++++++++++++++++ man/predict.bcfmodel.Rd | 2 +- stochtree/bcf.py | 123 +++++++++++++++++++++++-------- tools/debug/bcf_predict_debug.R | 127 ++++++++++++++++++++++++++++++++ 5 files changed, 398 insertions(+), 50 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 12e63791..c91ed7e8 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2646,7 +2646,7 @@ bcf <- function( #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". -#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' @@ -2738,11 +2738,34 @@ predict.bcfmodel <- function( # Handle prediction terms rfx_model_spec = object$model_params$rfx_model_spec rfx_intercept_only <- rfx_model_spec == "intercept_only" - rfx_intercept_plus_treatment <- (rfx_model_spec == "intercept_plus_treatment") + rfx_intercept_plus_treatment <- rfx_model_spec == "intercept_plus_treatment" rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment + mu_prog_separate <- ifelse(rfx_intercept, TRUE, FALSE) + tau_cate_separate <- ifelse(rfx_intercept_plus_treatment, TRUE, FALSE) if (!is.character(terms)) { stop("type must be a string or character vector") } + for (term in terms) { + if ( + !(term %in% + c( + "y_hat", + "prognostic_function", + "mu", + "cate", + "tau", + "rfx", + "variance_forest", + "all" + )) + ) { + stop(paste0( + "Term '", + term, + "' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'." + )) + } + } num_terms <- length(terms) has_mu_forest <- T has_tau_forest <- T @@ -2751,10 +2774,14 @@ predict.bcfmodel <- function( has_y_hat <- T predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) || ((has_y_hat) && ("all" %in% terms))) - predict_mu_forest <- (((has_mu_forest) && + predict_mu_forest <- (((has_mu_forest) && ("all" %in% terms)) || + ((has_mu_forest) && ("mu" %in% terms))) + predict_tau_forest <- (((has_tau_forest) && ("tau" %in% terms)) || + ((has_tau_forest) && ("all" %in% terms))) + predict_prog_function <- (((has_mu_forest) && ("prognostic_function" %in% terms)) || ((has_mu_forest) && ("all" %in% terms))) - predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) || + predict_cate_function <- (((has_tau_forest) && ("cate" %in% terms)) || ((has_tau_forest) && ("all" %in% terms))) predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) || ((has_rfx) && ("all" %in% terms))) @@ -2764,7 +2791,9 @@ predict.bcfmodel <- function( predict_count <- sum(c( predict_y_hat, predict_mu_forest, + predict_prog_function, predict_tau_forest, + predict_cate_function, predict_rfx, predict_variance_forest )) @@ -2777,10 +2806,13 @@ predict.bcfmodel <- function( return(NULL) } predict_rfx_intermediate <- (predict_y_hat && has_rfx) - predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept) || - (predict_tau_forest && has_rfx && rfx_intercept_plus_treatment)) - predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) - predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) + predict_rfx_raw <- ((predict_prog_function && has_rfx && rfx_intercept) || + (predict_cate_function && has_rfx && rfx_intercept_plus_treatment)) + predict_mu_forest_intermediate <- ((predict_y_hat || predict_prog_function) && + has_mu_forest) + predict_tau_forest_intermediate <- ((predict_y_hat || + predict_cate_function) && + has_tau_forest) # Make sure covariates are matrix or data frame if ((!is.data.frame(X)) && (!is.matrix(X))) { @@ -2983,19 +3015,19 @@ predict.bcfmodel <- function( } # Add raw RFX predictions to mu and tau if warranted by the RFX model spec - if (predict_mu_forest || predict_mu_forest_intermediate) { - if (rfx_intercept && predict_rfx_raw) { - mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ] + if (predict_prog_function) { + if (mu_prog_separate) { + prognostic_function <- mu_hat_forest + rfx_predictions_raw[, 1, ] } else { - mu_hat_final <- mu_hat_forest + prognostic_function <- mu_hat_forest } } - if (predict_tau_forest || predict_tau_forest_intermediate) { - if (rfx_intercept_plus_treatment && predict_rfx_raw) { - tau_hat_final <- (tau_hat_forest + + if (predict_cate_function) { + if (tau_cate_separate) { + cate <- (tau_hat_forest + rfx_predictions_raw[, 2:ncol(rfx_basis), ]) } else { - tau_hat_final <- tau_hat_forest + cate <- tau_hat_forest } } @@ -3003,6 +3035,8 @@ predict.bcfmodel <- function( needs_mean_term_preds <- predict_y_hat || predict_mu_forest || predict_tau_forest || + predict_prog_function || + predict_cate_function || predict_rfx if (needs_mean_term_preds) { if (probability_scale) { @@ -3019,10 +3053,16 @@ predict.bcfmodel <- function( } } if (predict_mu_forest) { - mu_hat <- pnorm(mu_hat_final) + mu_hat <- pnorm(mu_hat_forest) } if (predict_tau_forest) { - tau_hat <- pnorm(tau_hat_final) + tau_hat <- pnorm(tau_hat_forest) + } + if (predict_prog_function) { + prognostic_function <- pnorm(prognostic_function) + } + if (predict_cate_function) { + cate <- pnorm(cate) } } else { if (has_rfx) { @@ -3035,10 +3075,16 @@ predict.bcfmodel <- function( } } if (predict_mu_forest) { - mu_hat <- mu_hat_final + mu_hat <- mu_hat_forest } if (predict_tau_forest) { - tau_hat <- tau_hat_final + tau_hat <- tau_hat_forest + } + if (predict_prog_function) { + prognostic_function <- prognostic_function + } + if (predict_cate_function) { + cate <- cate } } } @@ -3055,6 +3101,16 @@ predict.bcfmodel <- function( tau_hat <- rowMeans(tau_hat) } } + if (predict_prog_function) { + prognostic_function <- rowMeans(prognostic_function) + } + if (predict_cate_function) { + if (object$model_params$multivariate_treatment) { + cate <- apply(cate, c(1, 2), mean) + } else { + cate <- rowMeans(cate) + } + } if (predict_rfx) { rfx_predictions <- rowMeans(rfx_predictions) } @@ -3071,6 +3127,10 @@ predict.bcfmodel <- function( return(mu_hat) } else if (predict_tau_forest) { return(tau_hat) + } else if (predict_prog_function) { + return(prognostic_function) + } else if (predict_cate_function) { + return(cate) } else if (predict_rfx) { return(rfx_predictions) } else if (predict_variance_forest) { @@ -3093,6 +3153,16 @@ predict.bcfmodel <- function( } else { result[["tau_hat"]] <- NULL } + if (predict_prog_function) { + result[["prognostic_function"]] = prognostic_function + } else { + result[["prognostic_function"]] <- NULL + } + if (predict_cate_function) { + result[["cate"]] = cate + } else { + result[["cate"]] <- NULL + } if (predict_rfx) { result[["rfx_predictions"]] = rfx_predictions } else { diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index 2257684a..ec787ad6 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -137,3 +137,89 @@ (ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :]) ) print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}") + +# Generate data with random effects +X = rng.normal(loc=0.0, scale=1.0, size=(n, p)) +mu_X = X[:, 0] +tau_X = 0.25 * X[:, 1] +pi_X = norm.cdf(0.5 * X[:, 1]) +Z = rng.binomial(n=1, p=pi_X, size=(n,)) +rfx_group_ids = rng.choice(a=3, size=(n,)) +rfx_basis = np.concatenate((np.ones((n, 1)), np.expand_dims(Z, 1)), axis=1) +rfx_coefs = np.array([[-2.0, -0.5], [0.0, 0.0], [2.0, 0.5]]) +rfx_term = np.sum(rfx_coefs[rfx_group_ids, :] * rfx_basis, axis=1) +E_XZ = mu_X + tau_X * Z + rfx_term +snr = 2.0 +noise_sd = np.std(E_XZ) / snr +y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,)) + +# Train-test split +sample_inds = np.arange(n) +test_set_pct = 0.2 +train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +pi_train = pi_X[train_inds] +pi_test = pi_X[test_inds] +tau_train = tau_X[train_inds] +tau_test = tau_X[test_inds] +mu_train = mu_X[train_inds] +mu_test = mu_X[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +E_XZ_train = E_XZ[train_inds] +E_XZ_test = E_XZ[test_inds] +rfx_group_ids_train = rfx_group_ids[train_inds] +rfx_group_ids_test = rfx_group_ids[test_inds] +rfx_basis_train = rfx_basis[train_inds, :] +rfx_basis_test = rfx_basis[test_inds, :] + +# Fit simple BCF model +rfx_params = {"model_spec": "intercept_plus_treatment"} +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + pi_train=pi_train, + y_train=y_train, + rfx_group_ids_train=rfx_group_ids_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + random_effects_params=rfx_params +) + +# Check several predict approaches +bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test) + +# Check that mu + tau + rfx = prognostic + cate +np.allclose( + (bcf_preds["mu_hat"] + + np.multiply(bcf_preds["tau_hat"], np.expand_dims(Z_test, 1)) + + bcf_preds["rfx_predictions"]), + (bcf_preds["prognostic_function"] + + np.multiply(bcf_preds["cate"], np.expand_dims(Z_test, 1))), + atol=1e-4 +) + +# Retrieve just prognostic predictions +prog_fn_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, + rfx_group_ids=rfx_group_ids_test, + terms = "prognostic_function" +) + +# Compare to prognostic function returned from the larger prediction +np.allclose(prog_fn_test, bcf_preds["prognostic_function"], atol=1e-4) + +# Retrieve just prognostic predictions +mu_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, + rfx_group_ids=rfx_group_ids_test, + terms = "mu" +) + +# Compare to prognostic function returned from the larger prediction +np.allclose(mu_hat_test, bcf_preds["mu_hat"], atol=1e-4) diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index bda63aa5..e2bd4c27 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -34,7 +34,7 @@ that were not in the training set.} \item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} -\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 2372873b..d0eafdea 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2321,7 +2321,7 @@ def predict( type : str, optional Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional - Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". + Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". scale : str, optional Scale on which to return predictions. Options are "linear" (the default), which returns predictions on the original outcome scale, and "probit", which returns predictions on the probit (latent) scale. Only applicable for models fit with `probit_outcome_model=True`. @@ -2349,12 +2349,30 @@ def predict( predict_mean = type == "mean" # Handle prediction terms + if isinstance(terms, str): + terms = [terms] rfx_model_spec = self.rfx_model_spec rfx_intercept_only = rfx_model_spec == "intercept_only" rfx_intercept_plus_treatment = rfx_model_spec == "intercept_plus_treatment" rfx_intercept = rfx_intercept_only or rfx_intercept_plus_treatment + mu_prog_separate = rfx_intercept + tau_cate_separate = rfx_intercept_plus_treatment if not isinstance(terms, str) and not isinstance(terms, list): raise ValueError("type must be a string or list of strings") + for term in terms: + if term not in [ + "y_hat", + "prognostic_function", + "mu", + "cate", + "tau", + "rfx", + "variance_forest", + "all", + ]: + raise ValueError( + f"term '{term}' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'" + ) num_terms = 1 if isinstance(terms, str) else len(terms) has_mu_forest = True has_tau_forest = True @@ -2364,10 +2382,16 @@ def predict( predict_y_hat = (has_y_hat and ("y_hat" in terms)) or ( has_y_hat and ("all" in terms) ) - predict_mu_forest = (has_mu_forest and ("prognostic_function" in terms)) or ( + predict_mu_forest = (has_mu_forest and ("mu" in terms)) or ( + has_mu_forest and ("all" in terms) + ) + predict_tau_forest = (has_tau_forest and ("tau" in terms)) or ( + has_tau_forest and ("all" in terms) + ) + predict_prog_function = (has_mu_forest and ("prognostic_function" in terms)) or ( has_mu_forest and ("all" in terms) ) - predict_tau_forest = (has_tau_forest and ("cate" in terms)) or ( + predict_cate_function = (has_tau_forest and ("cate" in terms)) or ( has_tau_forest and ("all" in terms) ) predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms)) @@ -2377,7 +2401,9 @@ def predict( predict_count = ( predict_y_hat + predict_mu_forest + + predict_prog_function + predict_tau_forest + + predict_cate_function + predict_rfx + predict_variance_forest ) @@ -2388,11 +2414,11 @@ def predict( ) return None predict_rfx_intermediate = predict_y_hat and has_rfx - predict_rfx_raw = (predict_mu_forest and has_rfx and rfx_intercept) or ( - predict_tau_forest and has_rfx and rfx_intercept_plus_treatment + predict_rfx_raw = (predict_prog_function and has_rfx and rfx_intercept) or ( + predict_cate_function and has_rfx and rfx_intercept_plus_treatment ) - predict_mu_forest_intermediate = predict_y_hat and has_mu_forest - predict_tau_forest_intermediate = predict_y_hat and has_tau_forest + predict_mu_forest_intermediate = (predict_y_hat or predict_prog_function) and has_mu_forest + predict_tau_forest_intermediate = (predict_y_hat or predict_cate_function) and has_tau_forest if not self.is_sampled(): msg = ( @@ -2512,17 +2538,29 @@ def predict( raise ValueError( "rfx_group_ids must be provided if rfx_basis is provided" ) - if rfx_basis is not None: - if rfx_basis.ndim == 1: - rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError( - "X and rfx_basis must have the same number of rows" - ) - if rfx_basis.shape[1] != self.num_rfx_basis: + + if self.rfx_model_spec == "custom": + if rfx_basis is None: raise ValueError( - "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis is None: + rfx_basis = np.ones(shape=(X.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis is None: + rfx_basis = np.concatenate((np.ones(shape=(X.shape[0], 1)), Z), axis=1) + + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "X and rfx_basis must have the same number of rows" + ) + if rfx_basis.shape[1] != self.num_rfx_basis: + raise ValueError( + "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + ) # Random effects predictions if predict_rfx or predict_rfx_intermediate: @@ -2557,20 +2595,20 @@ def predict( ) # Add raw RFX predictions to mu and tau if warranted by the RFX model spec - if predict_mu_forest or predict_mu_forest_intermediate: - if rfx_intercept and predict_rfx_raw: - mu_x = mu_x_forest + np.squeeze(rfx_predictions_raw[:, 0, :]) + if predict_prog_function: + if mu_prog_separate: + prognostic_function = mu_x_forest + np.squeeze(rfx_predictions_raw[:, 0, :]) else: - mu_x = mu_x_forest - if predict_tau_forest or predict_tau_forest_intermediate: - if rfx_intercept_plus_treatment and predict_rfx_raw: - tau_x = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) + prognostic_function = mu_x_forest + if predict_cate_function: + if tau_cate_separate: + cate = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) else: - tau_x = tau_x_forest + cate = tau_x_forest # Combine into y hat predictions needs_mean_term_preds = ( - predict_y_hat or predict_mu_forest or predict_tau_forest or predict_rfx + predict_y_hat or predict_mu_forest or predict_prog_function or predict_tau_forest or predict_cate_function or predict_rfx ) if needs_mean_term_preds: if probability_scale: @@ -2583,9 +2621,13 @@ def predict( if predict_y_hat: y_hat = norm.cdf(mu_x_forest + treatment_term) if predict_mu_forest: - mu_x = norm.cdf(mu_x) + mu_x = norm.cdf(mu_x_forest) if predict_tau_forest: - tau_x = norm.cdf(tau_x) + tau_x = norm.cdf(tau_x_forest) + if predict_prog_function: + prognostic_function = norm.cdf(prognostic_function) + if predict_cate_function: + cate = norm.cdf(cate) else: if has_rfx: if predict_y_hat: @@ -2594,9 +2636,13 @@ def predict( if predict_y_hat: y_hat = mu_x_forest + treatment_term if predict_mu_forest: - mu_x = mu_x + mu_x = mu_x_forest if predict_tau_forest: - tau_x = tau_x + tau_x = tau_x_forest + if predict_prog_function: + prognostic_function = prognostic_function + if predict_cate_function: + cate = cate # Collapse to posterior mean predictions if requested if predict_mean: @@ -2607,6 +2653,13 @@ def predict( tau_x = np.mean(tau_x, axis=2) else: tau_x = np.mean(tau_x, axis=1) + if predict_prog_function: + prognostic_function = np.mean(prognostic_function, axis=1) + if predict_cate_function: + if Z.shape[1] > 1: + cate = np.mean(cate, axis=2) + else: + cate = np.mean(cate, axis=1) if predict_rfx: rfx_preds = np.mean(rfx_preds, axis=1) if predict_y_hat: @@ -2617,8 +2670,12 @@ def predict( return y_hat elif predict_mu_forest: return mu_x + elif predict_prog_function: + return prognostic_function elif predict_tau_forest: return tau_x + elif predict_cate_function: + return cate elif predict_rfx: return rfx_preds elif predict_variance_forest: @@ -2637,6 +2694,14 @@ def predict( result["tau_hat"] = tau_x else: result["tau_hat"] = None + if predict_prog_function: + result["prognostic_function"] = prognostic_function + else: + result["prognostic_function"] = None + if predict_cate_function: + result["cate"] = cate + else: + result["cate"] = None if predict_rfx: result["rfx_predictions"] = rfx_preds else: diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index 7a854707..fd8ee7bc 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -227,3 +227,130 @@ ppd_outcome_0 <- ppd_samples_prob[y_test == 0] ppd_outcome_1 <- ppd_samples_prob[y_test == 1] hist(ppd_outcome_0, breaks = 50, xlim = c(0, 1)) hist(ppd_outcome_1, breaks = 50, xlim = c(0, 1)) + +# Generate data with random effects +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +rfx_group_ids <- sample(1:3, n, replace = TRUE) +rfx_basis <- cbind(1, Z) +rfx_coefs <- matrix(c(-2, -0.5, 0, 0, 2, 0.5), byrow = T, nrow = 3) +rfx_term <- rowSums(rfx_basis * rfx_coefs[rfx_group_ids, ]) +E_XZ <- mu_x + Z * tau_x + rfx_term +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +rfx_group_ids_test <- rfx_group_ids[test_inds] +rfx_group_ids_train <- rfx_group_ids[train_inds] + +# Fit a simple BCF model +rfx_params = list( + model_spec = "intercept_plus_treatment" +) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + random_effects_params = rfx_params +) + +# Retrieve all predictions +posterior_preds_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test +) + +# Check that mu + tau + rfx = prognostic + cate +comp_mat <- (abs( + (posterior_preds_test$mu_hat + + posterior_preds_test$tau_hat * Z_test + + posterior_preds_test$rfx_predictions) - + (posterior_preds_test$prognostic_function + + posterior_preds_test$cate * Z_test) +) < + 0.0001) +all(comp_mat) + +# Retrieve just prognostic predictions +prog_fn_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + terms = c("prognostic_function") +) + +# Compare to prognostic function returned from the larger prediction +all(abs(prog_fn_test - posterior_preds_test$prognostic_function) < 0.0001) + +# Retrieve just mu predictions +mu_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + terms = c("mu") +) + +# Compare to prognostic function returned from the larger prediction +all(abs(mu_hat_test - posterior_preds_test$mu_hat) < 0.0001) + +# Retrieve just CATE predictions +cate_fn_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + terms = c("cate") +) + +# Compare to prognostic function returned from the larger prediction +all(abs(cate_fn_test - posterior_preds_test$cate) < 0.0001) + +# Retrieve just mu predictions +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + terms = c("tau") +) + +# Compare to prognostic function returned from the larger prediction +all(abs(tau_hat_test - posterior_preds_test$tau_hat) < 0.0001) From 729a24d9a26136eb18b2d9ab0ad37f9912facbf2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 7 Nov 2025 13:57:14 -0500 Subject: [PATCH 2/5] Updated posterior interval R functions to handle mu and tau separately from prognostic function and cate --- R/posterior_transformation.R | 85 +++++++++++++++++------ tools/debug/bcf_predict_debug.R | 116 ++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 21 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index b401cffd..6d7ff4f6 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -832,7 +832,7 @@ posterior_predictive_heuristic_multiplier <- function( #' #' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). @@ -895,6 +895,29 @@ compute_bcf_posterior_interval <- function( } # Check that all the necessary inputs were provided for interval computation + for (term in terms) { + if ( + !(term %in% + c( + "prognostic_function", + "mu", + "cate", + "tau", + "variance_forest", + "rfx", + "y_hat", + "all" + )) + ) { + stop( + paste0( + "Term '", + term, + "' was requested. Valid terms are 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', and 'all'." + ) + ) + } + } needs_covariates_intermediate <- ((("y_hat" %in% terms) || ("all" %in% terms))) needs_covariates <- (("prognostic_function" %in% terms) || @@ -975,16 +998,22 @@ compute_bcf_posterior_interval <- function( "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") + + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + + if (!is.null(rfx_basis)) { + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(covariates)) { + stop("'rfx_basis' must have the same number of rows as 'covariates'") + } } } @@ -1006,11 +1035,15 @@ compute_bcf_posterior_interval <- function( if (has_multiple_terms) { result <- list() for (term_name in names(predictions)) { - result[[term_name]] <- summarize_interval( - predictions[[term_name]], - sample_dim = 2, - level = level - ) + if (!is.null(predictions[[term_name]])) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } else { + result[[term_name]] <- NULL + } } return(result) } else { @@ -1161,11 +1194,15 @@ compute_bart_posterior_interval <- function( if (has_multiple_terms) { result <- list() for (term_name in names(predictions)) { - result[[term_name]] <- summarize_interval( - predictions[[term_name]], - sample_dim = 2, - level = level - ) + if (!is.null(predictions[[term_name]])) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } else { + result[[term_name]] <- NULL + } } return(result) } else { @@ -1253,8 +1290,12 @@ bart_model_has_term <- function(model_object, term) { bcf_model_has_term <- function(model_object, term) { if (term == "prognostic_function") { return(TRUE) + } else if (term == "mu") { + return(TRUE) } else if (term == "cate") { return(TRUE) + } else if (term == "tau") { + return(TRUE) } else if (term == "variance_forest") { return(model_object$model_params$include_variance_forest) } else if (term == "rfx") { @@ -1280,7 +1321,9 @@ validate_bart_term <- function(term) { validate_bcf_term <- function(term) { model_terms <- c( "prognostic_function", + "mu", "cate", + "tau", "variance_forest", "rfx", "y_hat", @@ -1288,7 +1331,7 @@ validate_bcf_term <- function(term) { ) if (!(term %in% model_terms)) { stop( - "'term' must be one of 'prognostic_function', 'cate', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects" + "'term' must be one of 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects" ) } } diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index fd8ee7bc..70bc71ed 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -354,3 +354,119 @@ tau_hat_test <- predict( # Compare to prognostic function returned from the larger prediction all(abs(tau_hat_test - posterior_preds_test$tau_hat) < 0.0001) + +# Compute intervals for all of the model terms +posterior_intervals_test <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = "all", + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + level = 0.95 +) + +# Compute intervals for just the prognostic term +prog_intervals_test <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = "prognostic_function", + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + level = 0.95 +) + +# Compute intervals for just the CATE term +cate_intervals_test <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = "cate", + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + level = 0.95 +) + +# Check that they match the corresponding terms from the full interval list +all( + abs( + posterior_intervals_test$prognostic_function$lower - + prog_intervals_test$lower + ) < + 0.0001 +) +all( + abs( + posterior_intervals_test$prognostic_function$upper - + prog_intervals_test$upper + ) < + 0.0001 +) +all( + abs( + posterior_intervals_test$cate$lower - + cate_intervals_test$lower + ) < + 0.0001 +) +all( + abs( + posterior_intervals_test$cate$upper - + cate_intervals_test$upper + ) < + 0.0001 +) + +# Check that the prog and CATE intervals are different from the mu and tau intervals +mu_intervals_test <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = "mu", + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + level = 0.95 +) +tau_intervals_test <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = "tau", + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + level = 0.95 +) +all( + abs( + mu_intervals_test$lower - + prog_intervals_test$lower + ) > + 0.0001 +) +all( + abs( + mu_intervals_test$upper - + prog_intervals_test$upper + ) > + 0.0001 +) +all( + abs( + tau_intervals_test$lower - + cate_intervals_test$lower + ) > + 0.0001 +) +all( + abs( + tau_intervals_test$upper - + cate_intervals_test$upper + ) > + 0.0001 +) From 3a32383241a08dc01b18b4f190b370c2d08048ef Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 7 Nov 2025 14:23:57 -0500 Subject: [PATCH 3/5] Reflected mu and tau through the compute_posterior_interval method for Python BCF --- demo/debug/bcf_predict_debug.py | 97 +++++++++++++++++++++++++++++++++ stochtree/bcf.py | 52 +++++++++++++----- 2 files changed, 135 insertions(+), 14 deletions(-) diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index ec787ad6..141f4ee8 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -223,3 +223,100 @@ # Compare to prognostic function returned from the larger prediction np.allclose(mu_hat_test, bcf_preds["mu_hat"], atol=1e-4) + +# Compute intervals for all of the model terms +posterior_intervals_test = bcf_model.compute_posterior_interval( + terms="all", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, + rfx_group_ids=rfx_group_ids_test, +) + +# Compute intervals for just the prognostic term +prog_intervals_test = bcf_model.compute_posterior_interval( + terms="prognostic_function", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, + rfx_group_ids=rfx_group_ids_test +) + +# Compute intervals for just the CATE term +cate_intervals_test = bcf_model.compute_posterior_interval( + terms="cate", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, + rfx_group_ids=rfx_group_ids_test +) + +# Check that they match the corresponding terms from the full interval list +(np.allclose( + posterior_intervals_test['prognostic_function']['lower'], + prog_intervals_test['lower'], + atol=1e-4 +) and +np.allclose( + posterior_intervals_test['prognostic_function']['upper'], + prog_intervals_test['upper'], + atol=1e-4 +) and +np.allclose( + posterior_intervals_test['cate']['lower'], + cate_intervals_test['lower'], + atol=1e-4 +) and +np.allclose( + posterior_intervals_test['cate']['upper'], + cate_intervals_test['upper'], + atol=1e-4 +)) + +# Check that the prog and CATE intervals are different from the mu and tau intervals +mu_intervals_test = bcf_model.compute_posterior_interval( + terms="mu", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, + rfx_group_ids=rfx_group_ids_test +) +tau_intervals_test = bcf_model.compute_posterior_interval( + terms="tau", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, + rfx_group_ids=rfx_group_ids_test +) + +(not (np.allclose( + mu_intervals_test['lower'], + prog_intervals_test['lower'], + atol=1e-4 +)) and +not (np.allclose( + mu_intervals_test['upper'], + prog_intervals_test['upper'], + atol=1e-4 +)) and +not (np.allclose( + tau_intervals_test['lower'], + cate_intervals_test['lower'], + atol=1e-4 +)) and +not (np.allclose( + tau_intervals_test['upper'], + cate_intervals_test['upper'], + atol=1e-4 +)) +) \ No newline at end of file diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d0eafdea..4120698d 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2358,7 +2358,7 @@ def predict( mu_prog_separate = rfx_intercept tau_cate_separate = rfx_intercept_plus_treatment if not isinstance(terms, str) and not isinstance(terms, list): - raise ValueError("type must be a string or list of strings") + raise ValueError("'terms' must be a string or list of strings") for term in terms: if term not in [ "y_hat", @@ -2856,7 +2856,7 @@ def compute_posterior_interval( Parameters ---------- terms : str, optional - Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. + Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. scale : str, optional Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional @@ -2880,6 +2880,10 @@ def compute_posterior_interval( # Check the provided model object and requested term if not self.is_sampled(): raise ValueError("Model has not yet been sampled") + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + if isinstance(terms, str): + terms = [terms] for term in terms: if not self.has_term(term): warnings.warn( @@ -2897,7 +2901,21 @@ def compute_posterior_interval( "scale cannot be 'probability' for models not fit with a probit outcome model" ) - # Check that all the necessary inputs were provided for interval computation + # Handle prediction terms + for term in terms: + if term not in [ + "y_hat", + "prognostic_function", + "mu", + "cate", + "tau", + "rfx", + "variance_forest", + "all", + ]: + raise ValueError( + f"term '{term}' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'" + ) needs_covariates_intermediate = ("y_hat" in terms) or ("all" in terms) needs_covariates = ( ("prognostic_function" in terms) @@ -2957,16 +2975,18 @@ def compute_posterior_interval( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: - raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" - ) + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if rfx_basis is not None: + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) # Compute posterior matrices for the requested model terms predictions = self.predict( @@ -3511,7 +3531,7 @@ def has_term(self, term: str) -> bool: Parameters ---------- term : str - Character string specifying the model term to check for. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. + Character string specifying the model term to check for. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Returns ------- @@ -3520,8 +3540,12 @@ def has_term(self, term: str) -> bool: """ if term == "prognostic_function": return True + if term == "mu": + return True if term == "cate": return True + if term == "tau": + return True elif term == "variance_forest": return self.include_variance_forest elif term == "rfx": From 74b4141eab0a9cef96ac2d4f64819d1968bf44c1 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 7 Nov 2025 14:32:25 -0500 Subject: [PATCH 4/5] Added warning about RFX and prognostic_function / CATE --- R/bcf.R | 16 ++++++++++++++++ stochtree/bcf.py | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/R/bcf.R b/R/bcf.R index c91ed7e8..82ee9cd1 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2735,6 +2735,22 @@ predict.bcfmodel <- function( } predict_mean <- type == "mean" + # Warn users about CATE / prognostic function when rfx_model_spec is "custom" + if (object$model_params$has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if (("prognostic_function" %in% terms) || ("cate" %in% terms)) { + warning(paste0( + "This BCF model was fit with a custom random effects model specification (i.e. a user-provided basis). ", + "As a result, 'prognostic_function' and 'cate' refer only to the prognostic ('mu') ", + "and treatment effect 'tau' forests, respectively, and do not include any random ", + "effects contributions. If your user-provided random effects basis includes a random intercept or a ", + "random slope on the treatment variable, you will need to compute the prognostic or CATE functions manually by predicting ", + "'yhat' for different covariate and rfx_basis values." + )) + } + } + } + # Handle prediction terms rfx_model_spec = object$model_params$rfx_model_spec rfx_intercept_only <- rfx_model_spec == "intercept_only" diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 4120698d..d66d24ff 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2901,6 +2901,14 @@ def compute_posterior_interval( "scale cannot be 'probability' for models not fit with a probit outcome model" ) + # Warn users about CATE / prognostic function when rfx_model_spec is "custom" + if self.has_rfx: + if self.rfx_model_spec == "custom": + if "prognostic_function" in terms or "cate" in terms: + warnings.warn( + "This BCF model was fit with a custom random effects model specification (i.e. a user-provided basis). As a result, 'prognostic_function' and 'cate' refer only to the prognostic ('mu') and treatment effect 'tau' forests, respectively, and do not include any random effects contributions. If your user-provided random effects basis includes a random intercept or a random slope on the treatment variable, you will need to compute the prognostic or CATE functions manually by predicting 'y_hat' for different covariate and rfx_basis values." + ) + # Handle prediction terms for term in terms: if term not in [ From d3950266a003062a9430e3f5e7e2dc2eab51e54d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 7 Nov 2025 14:37:53 -0500 Subject: [PATCH 5/5] Updated python warnings and checks --- stochtree/bart.py | 39 ++++++++++++++++++++++++++++++++------- stochtree/bcf.py | 15 +++++---------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 3f81c531..30c15c7e 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1700,6 +1700,21 @@ def predict( predict_mean = type == "mean" # Handle prediction terms + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + if isinstance(terms, str): + terms = [terms] + for term in terms: + if term not in [ + "y_hat", + "mean_forest", + "rfx", + "variance_forest", + "all", + ]: + raise ValueError( + f"term '{term}' was requested. Valid terms are 'y_hat', 'mean_forest', 'rfx', 'variance_forest', and 'all'" + ) rfx_model_spec = self.rfx_model_spec rfx_intercept = rfx_model_spec == "intercept_only" if not isinstance(terms, str) and not isinstance(terms, list): @@ -2116,14 +2131,9 @@ def compute_posterior_interval( dict A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. """ - # Check the provided model object and requested terms + # Check the provided model object if not self.is_sampled(): raise ValueError("Model has not yet been sampled") - for term in terms: - if not self.has_term(term): - warnings.warn( - f"Term {term} was not sampled in this model and its intervals will not be returned." - ) # Handle mean function scale if not isinstance(scale, str): @@ -2136,7 +2146,22 @@ def compute_posterior_interval( "scale cannot be 'probability' for models not fit with a probit outcome model" ) - # Check that all the necessary inputs were provided for interval computation + # Handle prediction terms + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + if isinstance(terms, str): + terms = [terms] + for term in terms: + if term not in [ + "y_hat", + "mean_forest", + "rfx", + "variance_forest", + "all", + ]: + raise ValueError( + f"term '{term}' was requested. Valid terms are 'y_hat', 'mean_forest', 'rfx', 'variance_forest', and 'all'" + ) needs_covariates_intermediate = ( ("y_hat" in terms) or ("all" in terms) ) and self.include_mean_forest diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d66d24ff..d51eaf9e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2877,18 +2877,9 @@ def compute_posterior_interval( dict A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. """ - # Check the provided model object and requested term + # Check the provided model object if not self.is_sampled(): raise ValueError("Model has not yet been sampled") - if not isinstance(terms, str) and not isinstance(terms, list): - raise ValueError("terms must be a string or list of strings") - if isinstance(terms, str): - terms = [terms] - for term in terms: - if not self.has_term(term): - warnings.warn( - f"Term {term} was not sampled in this model and its intervals will not be returned." - ) # Handle mean function scale if not isinstance(scale, str): @@ -2910,6 +2901,10 @@ def compute_posterior_interval( ) # Handle prediction terms + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + if isinstance(terms, str): + terms = [terms] for term in terms: if term not in [ "y_hat",