diff --git a/R/bart.R b/R/bart.R index e12eb041..d5acd6af 100644 --- a/R/bart.R +++ b/R/bart.R @@ -418,7 +418,11 @@ bart <- function( # Raise a warning if the data have ties and only GFR is being run if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) { num_values <- nrow(X_train) - max_grid_size <- floor(num_values / cutpoint_grid_size) + max_grid_size <- ifelse( + num_values > cutpoint_grid_size, + floor(num_values / cutpoint_grid_size), + 1 + ) covs_warning_1 <- NULL covs_warning_2 <- NULL covs_warning_3 <- NULL @@ -1924,7 +1928,7 @@ bart <- function( #' Predict from a sampled BART model on new data #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. +#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels @@ -1961,10 +1965,10 @@ bart <- function( #' y_train <- y[train_inds] #' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) -#' y_hat_test <- predict(bart_model, X_test)$y_hat +#' y_hat_test <- predict(bart_model, X=X_test)$y_hat predict.bartmodel <- function( object, - covariates, + X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -2047,8 +2051,8 @@ predict.bartmodel <- function( } # Check that covariates are matrix or data frame - if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { - stop("covariates must be a matrix or dataframe") + if ((!is.data.frame(X)) && (!is.matrix(X))) { + stop("X must be a matrix or dataframe") } # Convert all input data to matrices if not already converted @@ -2063,12 +2067,12 @@ predict.bartmodel <- function( if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { stop("Basis (leaf_basis) must be provided for this model") } - if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) { - stop("covariates and leaf_basis must have the same number of rows") + if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) { + stop("X and leaf_basis must have the same number of rows") } - if (object$model_params$num_covariates != ncol(covariates)) { + if (object$model_params$num_covariates != ncol(X)) { stop( - "covariates must contain the same number of columns as the BART model's training dataset" + "X must contain the same number of columns as the BART model's training dataset" ) } if ((predict_rfx) && (is.null(rfx_group_ids))) { @@ -2089,7 +2093,7 @@ predict.bartmodel <- function( # Preprocess covariates train_set_metadata <- object$train_set_metadata - covariates <- preprocessPredictionData(covariates, train_set_metadata) + X <- preprocessPredictionData(X, train_set_metadata) # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE @@ -2119,8 +2123,8 @@ predict.bartmodel <- function( # Only construct a basis if user-provided basis missing if (is.null(rfx_basis)) { rfx_basis <- matrix( - rep(1, nrow(covariates)), - nrow = nrow(covariates), + rep(1, nrow(X)), + nrow = nrow(X), ncol = 1 ) } @@ -2129,9 +2133,9 @@ predict.bartmodel <- function( # Create prediction dataset if (!is.null(leaf_basis)) { - prediction_dataset <- createForestDataset(covariates, leaf_basis) + prediction_dataset <- createForestDataset(X, leaf_basis) } else { - prediction_dataset <- createForestDataset(covariates) + prediction_dataset <- createForestDataset(X) } # Compute variance forest predictions @@ -2843,7 +2847,7 @@ createBARTModelFromJsonFile <- function(json_filename) { #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJsonString(bart_model) #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) -#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) +#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat) createBARTModelFromJsonString <- function(json_string) { # Load a `CppJson` object from string bart_json <- createCppJsonString(json_string) diff --git a/R/bcf.R b/R/bcf.R index a9e60d5d..5a80d5ec 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -522,7 +522,11 @@ bcf <- function( # Raise a warning if the data have ties and only GFR is being run if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) { num_values <- nrow(X_train) - max_grid_size <- floor(num_values / cutpoint_grid_size) + max_grid_size <- ifelse( + num_values > cutpoint_grid_size, + floor(num_values / cutpoint_grid_size), + 1 + ) covs_warning_1 <- NULL covs_warning_2 <- NULL covs_warning_3 <- NULL diff --git a/R/kernel.R b/R/kernel.R index 2b643b98..7ab21370 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -129,7 +129,7 @@ computeForestLeafIndices <- function( propensity <- rowMeans( predict( model_object$bart_propensity_model, - covariates + X = covariates )$y_hat ) } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 87c51d12..dca34be3 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -260,8 +260,8 @@ compute_contrast_bcf_model <- function( #' Only valid when there is either a mean forest or a random effects term in the BART model. #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. -#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. +#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. +#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. #' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`. #' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects @@ -306,8 +306,8 @@ compute_contrast_bcf_model <- function( #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' contrast_test <- compute_contrast_bart_model( #' bart_model, -#' covariates_0 = X_test, -#' covariates_1 = X_test, +#' X_0 = X_test, +#' X_1 = X_test, #' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), #' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), #' type = "posterior", @@ -315,8 +315,8 @@ compute_contrast_bcf_model <- function( #' ) compute_contrast_bart_model <- function( object, - covariates_0, - covariates_1, + X_0, + X_1, leaf_basis_0 = NULL, leaf_basis_1 = NULL, rfx_group_ids_0 = NULL, @@ -360,11 +360,11 @@ compute_contrast_bart_model <- function( } # Check that covariates are matrix or data frame - if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) { - stop("covariates_0 must be a matrix or dataframe") + if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) { + stop("X_0 must be a matrix or dataframe") } - if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) { - stop("covariates_1 must be a matrix or dataframe") + if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) { + stop("X_1 must be a matrix or dataframe") } # Convert all input data to matrices if not already converted @@ -388,20 +388,20 @@ compute_contrast_bart_model <- function( ) { stop("leaf_basis_0 and leaf_basis_1 must be provided for this model") } - if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) { - stop("covariates_0 and leaf_basis_0 must have the same number of rows") + if ((!is.null(leaf_basis_0)) && (nrow(X_0) != nrow(leaf_basis_0))) { + stop("X_0 and leaf_basis_0 must have the same number of rows") } - if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) { - stop("covariates_1 and leaf_basis_1 must have the same number of rows") + if ((!is.null(leaf_basis_1)) && (nrow(X_1) != nrow(leaf_basis_1))) { + stop("X_1 and leaf_basis_1 must have the same number of rows") } - if (object$model_params$num_covariates != ncol(covariates_0)) { + if (object$model_params$num_covariates != ncol(X_0)) { stop( - "covariates_0 must contain the same number of columns as the BART model's training dataset" + "X_0 must contain the same number of columns as the BART model's training dataset" ) } - if (object$model_params$num_covariates != ncol(covariates_1)) { + if (object$model_params$num_covariates != ncol(X_1)) { stop( - "covariates_1 must contain the same number of columns as the BART model's training dataset" + "X_1 must contain the same number of columns as the BART model's training dataset" ) } if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) { @@ -427,7 +427,7 @@ compute_contrast_bart_model <- function( # Predict for the control arm control_preds <- predict( object = object, - covariates = covariates_0, + X = X_0, leaf_basis = leaf_basis_0, rfx_group_ids = rfx_group_ids_0, rfx_basis = rfx_basis_0, @@ -439,7 +439,7 @@ compute_contrast_bart_model <- function( # Predict for the treatment arm treatment_preds <- predict( object = object, - covariates = covariates_1, + X = X_1, leaf_basis = leaf_basis_1, rfx_group_ids = rfx_group_ids_1, rfx_basis = rfx_basis_1, @@ -465,8 +465,8 @@ compute_contrast_bart_model <- function( #' Sample from the posterior predictive distribution for outcomes modeled by BCF #' #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param covariates A matrix or data frame of covariates. -#' @param treatment A vector or matrix of treatment assignments. +#' @param X A matrix or data frame of covariates. +#' @param Z A vector or matrix of treatment assignments. #' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects. #' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects. @@ -484,13 +484,13 @@ compute_contrast_bart_model <- function( #' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) #' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) #' ppd_samples <- sample_bcf_posterior_predictive( -#' model_object = bcf_model, covariates = X, -#' treatment = Z, propensity = pi_X +#' model_object = bcf_model, X = X, +#' Z = Z, propensity = pi_X #' ) sample_bcf_posterior_predictive <- function( model_object, - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -505,33 +505,33 @@ sample_bcf_posterior_predictive <- function( # Check that all the necessary inputs were provided for interval computation needs_covariates <- TRUE if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_treatment <- needs_covariates if (needs_treatment) { - if (is.null(treatment)) { + if (is.null(Z)) { stop( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(treatment) && !is.numeric(treatment)) { - stop("'treatment' must be a numeric vector or matrix") + if (!is.matrix(Z) && !is.numeric(Z)) { + stop("'Z' must be a numeric vector or matrix") } - if (is.matrix(treatment)) { - if (nrow(treatment) != nrow(covariates)) { - stop("'treatment' must have the same number of rows as 'covariates'") + if (is.matrix(Z)) { + if (nrow(Z) != nrow(X)) { + stop("'Z' must have the same number of rows as 'X'") } } else { - if (length(treatment) != nrow(covariates)) { + if (length(Z) != nrow(X)) { stop( - "'treatment' must have the same number of elements as 'covariates'" + "'Z' must have the same number of elements as 'X'" ) } } @@ -551,13 +551,13 @@ sample_bcf_posterior_predictive <- function( stop("'propensity' must be a numeric vector or matrix") } if (is.matrix(propensity)) { - if (nrow(propensity) != nrow(covariates)) { - stop("'propensity' must have the same number of rows as 'covariates'") + if (nrow(propensity) != nrow(X)) { + stop("'propensity' must have the same number of rows as 'X'") } } else { - if (length(propensity) != nrow(covariates)) { + if (length(propensity) != nrow(X)) { stop( - "'propensity' must have the same number of elements as 'covariates'" + "'propensity' must have the same number of elements as 'X'" ) } } @@ -569,9 +569,9 @@ sample_bcf_posterior_predictive <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -582,16 +582,16 @@ sample_bcf_posterior_predictive <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior samples bcf_preds <- predict( model_object, - X = covariates, - Z = treatment, + X = X, + Z = Z, propensity = propensity, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, @@ -605,7 +605,7 @@ sample_bcf_posterior_predictive <- function( has_variance_forest <- model_object$model_params$include_variance_forest samples_global_variance <- model_object$model_params$sample_sigma2_global num_posterior_draws <- model_object$model_params$num_samples - num_observations <- nrow(covariates) + num_observations <- nrow(X) ppd_mean <- bcf_preds$y_hat if (has_variance_forest) { ppd_variance <- bcf_preds$variance_forest_predictions @@ -659,8 +659,8 @@ sample_bcf_posterior_predictive <- function( #' Sample from the posterior predictive distribution for outcomes modeled by BART #' #' @param model_object A fitted BART model object of class `bartmodel`. -#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). -#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param X A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). +#' @param leaf_basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. #' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects. #' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects. #' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). @@ -675,12 +675,12 @@ sample_bcf_posterior_predictive <- function( #' y <- 2 * X[,1] + rnorm(n) #' bart_model <- bart(y_train = y, X_train = X) #' ppd_samples <- sample_bart_posterior_predictive( -#' model_object = bart_model, covariates = X +#' model_object = bart_model, X = X #' ) sample_bart_posterior_predictive <- function( model_object, - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, num_draws_per_sample = NULL @@ -694,32 +694,32 @@ sample_bart_posterior_predictive <- function( # Check that all the necessary inputs were provided for interval computation needs_covariates <- model_object$model_params$include_mean_forest if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_basis <- needs_covariates && model_object$model_params$has_basis if (needs_basis) { - if (is.null(basis)) { + if (is.null(leaf_basis)) { stop( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(basis)) { - stop("'basis' must be a matrix") + if (!is.matrix(leaf_basis)) { + stop("'leaf_basis' must be a matrix") } - if (is.matrix(basis)) { - if (nrow(basis) != nrow(covariates)) { - stop("'basis' must have the same number of rows as 'covariates'") + if (is.matrix(leaf_basis)) { + if (nrow(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of rows as 'X'") } } else { - if (length(basis) != nrow(covariates)) { - stop("'basis' must have the same number of elements as 'covariates'") + if (length(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of elements as 'X'") } } } @@ -730,9 +730,9 @@ sample_bart_posterior_predictive <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -743,16 +743,16 @@ sample_bart_posterior_predictive <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior samples bart_preds <- predict( model_object, - covariates = covariates, - leaf_basis = basis, + X = X, + leaf_basis = leaf_basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", @@ -766,7 +766,7 @@ sample_bart_posterior_predictive <- function( has_variance_forest <- model_object$model_params$include_variance_forest samples_global_variance <- model_object$model_params$sample_sigma2_global num_posterior_draws <- model_object$model_params$num_samples - num_observations <- nrow(covariates) + num_observations <- nrow(X) if (has_mean_term) { ppd_mean <- bart_preds$y_hat } else { @@ -840,8 +840,8 @@ posterior_predictive_heuristic_multiplier <- function( #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". -#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). -#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). +#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). +#' @param Z (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). #' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. @@ -863,8 +863,8 @@ posterior_predictive_heuristic_multiplier <- function( #' intervals <- compute_bcf_posterior_interval( #' model_object = bcf_model, #' terms = c("prognostic_function", "cate"), -#' covariates = X, -#' treatment = Z, +#' X = X, +#' Z = Z, #' propensity = pi_X, #' level = 0.90 #' ) @@ -873,8 +873,8 @@ compute_bcf_posterior_interval <- function( terms, level = 0.95, scale = "linear", - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL @@ -930,33 +930,33 @@ compute_bcf_posterior_interval <- function( ("variance_forest" %in% terms) || (needs_covariates_intermediate)) if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_treatment <- needs_covariates if (needs_treatment) { - if (is.null(treatment)) { + if (is.null(Z)) { stop( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(treatment) && !is.numeric(treatment)) { - stop("'treatment' must be a numeric vector or matrix") + if (!is.matrix(Z) && !is.numeric(Z)) { + stop("'Z' must be a numeric vector or matrix") } - if (is.matrix(treatment)) { - if (nrow(treatment) != nrow(covariates)) { - stop("'treatment' must have the same number of rows as 'covariates'") + if (is.matrix(Z)) { + if (nrow(Z) != nrow(X)) { + stop("'Z' must have the same number of rows as 'X'") } } else { - if (length(treatment) != nrow(covariates)) { + if (length(Z) != nrow(X)) { stop( - "'treatment' must have the same number of elements as 'covariates'" + "'Z' must have the same number of elements as 'X'" ) } } @@ -976,13 +976,13 @@ compute_bcf_posterior_interval <- function( stop("'propensity' must be a numeric vector or matrix") } if (is.matrix(propensity)) { - if (nrow(propensity) != nrow(covariates)) { - stop("'propensity' must have the same number of rows as 'covariates'") + if (nrow(propensity) != nrow(X)) { + stop("'propensity' must have the same number of rows as 'X'") } } else { - if (length(propensity) != nrow(covariates)) { + if (length(propensity) != nrow(X)) { stop( - "'propensity' must have the same number of elements as 'covariates'" + "'propensity' must have the same number of elements as 'X'" ) } } @@ -998,9 +998,9 @@ compute_bcf_posterior_interval <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } @@ -1016,8 +1016,8 @@ compute_bcf_posterior_interval <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } } @@ -1025,8 +1025,8 @@ compute_bcf_posterior_interval <- function( # Compute posterior matrices for the requested model terms predictions <- predict( model_object, - X = covariates, - Z = treatment, + X = X, + Z = Z, propensity = propensity, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, @@ -1068,8 +1068,8 @@ compute_bcf_posterior_interval <- function( #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". -#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). -#' @param basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param X A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). +#' @param leaf_basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. #' @@ -1085,7 +1085,7 @@ compute_bcf_posterior_interval <- function( #' intervals <- compute_bart_posterior_interval( #' model_object = bart_model, #' terms = c("mean_forest", "y_hat"), -#' covariates = X, +#' X = X, #' level = 0.90 #' ) #' @export @@ -1094,8 +1094,8 @@ compute_bart_posterior_interval <- function( terms, level = 0.95, scale = "linear", - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL ) { @@ -1127,32 +1127,32 @@ compute_bart_posterior_interval <- function( ("variance_forest" %in% terms) || (needs_covariates_intermediate)) if (needs_covariates) { - if (is.null(covariates)) { + if (is.null(X)) { stop( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(covariates) && !is.data.frame(covariates)) { - stop("'covariates' must be a matrix or data frame") + if (!is.matrix(X) && !is.data.frame(X)) { + stop("'X' must be a matrix or data frame") } } needs_basis <- needs_covariates && model_object$model_params$has_basis if (needs_basis) { - if (is.null(basis)) { + if (is.null(leaf_basis)) { stop( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) } - if (!is.matrix(basis)) { - stop("'basis' must be a matrix") + if (!is.matrix(leaf_basis)) { + stop("'leaf_basis' must be a matrix") } - if (is.matrix(basis)) { - if (nrow(basis) != nrow(covariates)) { - stop("'basis' must have the same number of rows as 'covariates'") + if (is.matrix(leaf_basis)) { + if (nrow(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of rows as 'X'") } } else { - if (length(basis) != nrow(covariates)) { - stop("'basis' must have the same number of elements as 'covariates'") + if (length(leaf_basis) != nrow(X)) { + stop("'leaf_basis' must have the same number of elements as 'X'") } } } @@ -1167,9 +1167,9 @@ compute_bart_posterior_interval <- function( "'rfx_group_ids' must be provided in order to compute the requested intervals" ) } - if (length(rfx_group_ids) != nrow(covariates)) { + if (length(rfx_group_ids) != nrow(X)) { stop( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } if (is.null(rfx_basis)) { @@ -1180,16 +1180,16 @@ compute_bart_posterior_interval <- function( if (!is.matrix(rfx_basis)) { stop("'rfx_basis' must be a matrix") } - if (nrow(rfx_basis) != nrow(covariates)) { - stop("'rfx_basis' must have the same number of rows as 'covariates'") + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") } } # Compute posterior matrices for the requested model terms predictions <- predict( model_object, - covariates = covariates, - leaf_basis = basis, + X = X, + leaf_basis = leaf_basis, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", diff --git a/R/utils.R b/R/utils.R index 25c2fa1a..d33fe2c5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1092,33 +1092,3 @@ expand_dims_2d_diag <- function(input, output_size) { } return(output) } - - -gfr_tie_checks <- function(covariates) { - num_vars <- ncol(covariates) - for (j in 1:num_vars) { - x_j <- covariates[, j] - if (has_few_unique_values(x_j)) { - warning_message <- paste0( - "Covariate column ", - j, - " has relatively few unique values. ", - "This may lead to tied values when sampling split points in BART/BCF, ", - "which can cause errors during model fitting. ", - "Consider adding small amounts of noise to this variable to break ties." - ) - warning(warning_message) - } - } -} - - -has_few_unique_values <- function( - x, - count_threshold = 15 -) { - x_unique <- unique(x) - num_unique_values <- length(unique_values) - unique_to_total_count_ratio <- num_unique_values / length(x) - return(num_unique_values <= threshold) -} diff --git a/demo/debug/bart_contrast_debug.py b/demo/debug/bart_contrast_debug.py index 15ce5705..f80100bc 100644 --- a/demo/debug/bart_contrast_debug.py +++ b/demo/debug/bart_contrast_debug.py @@ -55,25 +55,25 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), type="posterior", scale="linear", ) # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), type="posterior", terms="y_hat", scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), type="posterior", terms="y_hat", scale="linear", @@ -143,10 +143,10 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), rfx_group_ids_0=group_ids_test, rfx_group_ids_1=group_ids_test, rfx_basis_0=rfx_basis_test, @@ -157,8 +157,8 @@ # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -166,8 +166,8 @@ scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index d66b1110..ca617c8a 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -46,11 +46,11 @@ ) # # Check several predict approaches -bart_preds = bart_model.predict(covariates=X_test) -y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"] -y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"]) +bart_preds = bart_model.predict(X=X_test) +y_hat_posterior_test = bart_model.predict(X=X_test)["y_hat"] +y_hat_mean_test = bart_model.predict(X=X_test, type="mean", terms=["y_hat"]) y_hat_test = bart_model.predict( - covariates=X_test, type="mean", terms=["rfx", "variance"] + X=X_test, type="mean", terms=["rfx", "variance"] ) # Plot predicted versus actual @@ -63,7 +63,7 @@ # Compute posterior interval intervals = bart_model.compute_posterior_interval( - terms="all", scale="linear", level=0.95, covariates=X_test + terms="all", scale="linear", level=0.95, X=X_test ) # Check coverage @@ -75,7 +75,7 @@ # Sample from the posterior predictive distribution bart_ppd_samples = bart_model.sample_posterior_predictive( - covariates=X_test, num_draws_per_sample=10 + X=X_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual diff --git a/demo/debug/bcf_pred_rmse.py b/demo/debug/bcf_pred_rmse.py index 0706842f..721074ea 100644 --- a/demo/debug/bcf_pred_rmse.py +++ b/demo/debug/bcf_pred_rmse.py @@ -51,11 +51,11 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, ) # Predict out of sample diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index 141f4ee8..24b68031 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -45,7 +45,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -90,8 +90,8 @@ terms="all", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, ) @@ -118,7 +118,7 @@ # Sample from the posterior predictive distribution bcf_ppd_samples = bcf_model.sample_posterior_predictive( - covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10 + X=X_test, Z=Z_test, propensity=pi_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual @@ -182,7 +182,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=pi_train, + propensity_train=pi_train, y_train=y_train, rfx_group_ids_train=rfx_group_ids_train, num_gfr=10, @@ -229,8 +229,8 @@ terms="all", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test, ) @@ -240,8 +240,8 @@ terms="prognostic_function", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -251,8 +251,8 @@ terms="cate", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -284,8 +284,8 @@ terms="mu", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) @@ -293,8 +293,8 @@ terms="tau", scale="linear", level=0.95, - covariates=X_test, - treatment=Z_test, + X=X_test, + Z=Z_test, propensity=pi_test, rfx_group_ids=rfx_group_ids_test ) diff --git a/demo/debug/causal_inference_binary_outcome.py b/demo/debug/causal_inference_binary_outcome.py index c603927d..4d249cbd 100644 --- a/demo/debug/causal_inference_binary_outcome.py +++ b/demo/debug/causal_inference_binary_outcome.py @@ -1,7 +1,6 @@ # Load necessary libraries import numpy as np import pandas as pd -import seaborn as sns import matplotlib.pyplot as plt from stochtree import BCFModel from sklearn.model_selection import train_test_split @@ -101,8 +100,8 @@ def g(x5): # Run the sampler bcf_model = BCFModel() -bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, - X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, +bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, propensity_train=pi_train, + X_test=X_test, Z_test=Z_test, propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, prognostic_forest_params=prognostic_forest_params, treatment_effect_forest_params=treatment_effect_forest_params) diff --git a/demo/debug/causal_inference_feature_subsets.py b/demo/debug/causal_inference_feature_subsets.py index 00cab8b8..8fa0fb2f 100644 --- a/demo/debug/causal_inference_feature_subsets.py +++ b/demo/debug/causal_inference_feature_subsets.py @@ -44,7 +44,7 @@ bcf_model_a = BCFModel() prog_forest_config_a = {"num_trees": 100} trt_forest_config_a = {"num_trees": 50} -bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) +bcf_model_a.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) """ timing_no_subsampling = timeit.timeit(stmt=s, number=5, globals=globals()) print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}") @@ -54,7 +54,7 @@ bcf_model_b = BCFModel() prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5} trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5} -bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) +bcf_model_b.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) """ timing_subsampling = timeit.timeit(stmt=s, number=5, globals=globals()) print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}") @@ -63,11 +63,11 @@ bcf_model_a = BCFModel() prog_forest_config_a = {"num_trees": 100} trt_forest_config_a = {"num_trees": 50} -bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) +bcf_model_a.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a) bcf_model_b = BCFModel() prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5} trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5} -bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) +bcf_model_b.sample(X_train=X_train, Z_train=Z_train, propensity_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b) y_hat_test_a = np.squeeze(bcf_model_a.y_hat_test).mean(axis = 1) rmse_no_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_a,2))) print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}") diff --git a/demo/debug/gfr_ties_debug.py b/demo/debug/gfr_ties_debug.py index fabc70b5..0a194e77 100644 --- a/demo/debug/gfr_ties_debug.py +++ b/demo/debug/gfr_ties_debug.py @@ -38,7 +38,7 @@ ) # Inspect the model fit -y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat") plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) plt.xlabel("Predicted Outcome Mean") @@ -54,7 +54,7 @@ ) # Inspect the model fit -y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) @@ -95,7 +95,7 @@ ) # Inspect the model fit -y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) @@ -112,7 +112,7 @@ ) # Inspect the model fit -y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat") +y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat") plt.clf() plt.scatter(y_hat_test, y_test) plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5) @@ -157,7 +157,7 @@ xbcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -182,7 +182,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -237,7 +237,7 @@ xbcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, @@ -262,7 +262,7 @@ bcf_model.sample( X_train=X_train, Z_train=Z_train, - pi_train=propensity_train, + propensity_train=propensity_train, y_train=y_train, num_gfr=10, num_burnin=0, diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py index bb35ee9a..e5f621ba 100644 --- a/demo/debug/multi_chain.py +++ b/demo/debug/multi_chain.py @@ -3,7 +3,6 @@ # Load necessary libraries import matplotlib.pyplot as plt import numpy as np -import pandas as pd import arviz as az from sklearn.model_selection import train_test_split @@ -89,8 +88,8 @@ def outcome_mean(X, W): # Analyze model predictions collectively across all chains y_hat_test = bart_model.predict( - covariates = X_test, - basis = basis_test, + X = X_test, + leaf_basis = basis_test, type = "mean", terms = "y_hat" ) diff --git a/demo/debug/multiple_initializations.py b/demo/debug/multiple_initializations.py index c499f45b..ad3b60a3 100644 --- a/demo/debug/multiple_initializations.py +++ b/demo/debug/multiple_initializations.py @@ -118,14 +118,14 @@ def outcome_mean(X, W): ) # Inspect the model outputs -bart_preds_2 = bart_model_2.predict(X_test, basis_test) +bart_preds_2 = bart_model_2.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_2 = bart_preds_2['y_hat'] y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) -bart_preds_3 = bart_model_3.predict(X_test, basis_test) +bart_preds_3 = bart_model_3.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_3 = bart_preds_3['y_hat'] y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True) -bart_preds_4 = bart_model_4.predict(X_test, basis_test) +bart_preds_4 = bart_model_4.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc_4 = bart_preds_4['y_hat'] y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) y_df = pd.DataFrame( diff --git a/demo/debug/parallel_multi_chain.py b/demo/debug/parallel_multi_chain.py index ee618df5..e3148e5b 100644 --- a/demo/debug/parallel_multi_chain.py +++ b/demo/debug/parallel_multi_chain.py @@ -145,7 +145,7 @@ def outcome_mean(X, W): ) # Inspect the model outputs - bart_preds = combined_bart.predict(X_test, basis_test) + bart_preds = combined_bart.predict(X=X_test, leaf_basis=basis_test) y_hat_mcmc = bart_preds['y_hat'] y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True) y_df = pd.DataFrame( diff --git a/demo/debug/probit_bart_rfx_debug.py b/demo/debug/probit_bart_rfx_debug.py index ae2e8c10..de8d3953 100644 --- a/demo/debug/probit_bart_rfx_debug.py +++ b/demo/debug/probit_bart_rfx_debug.py @@ -72,10 +72,10 @@ # Compute contrast posterior contrast_posterior_test = bart_model.compute_contrast( - covariates_0=X_test, - covariates_1=X_test, - basis_0=np.zeros((n_test, 1)), - basis_1=np.ones((n_test, 1)), + X_0=X_test, + X_1=X_test, + leaf_basis_0=np.zeros((n_test, 1)), + leaf_basis_1=np.ones((n_test, 1)), rfx_group_ids_0=group_ids_test, rfx_group_ids_1=group_ids_test, rfx_basis_0=rfx_basis_test, @@ -86,8 +86,8 @@ # Compute the same quantity via two predict calls y_hat_posterior_test_0 = bart_model.predict( - covariates=X_test, - basis=np.zeros((n_test, 1)), + X=X_test, + leaf_basis=np.zeros((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -95,8 +95,8 @@ scale="linear", ) y_hat_posterior_test_1 = bart_model.predict( - covariates=X_test, - basis=np.ones((n_test, 1)), + X=X_test, + leaf_basis=np.ones((n_test, 1)), rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="posterior", @@ -111,8 +111,8 @@ # Plot predicted versus actual outcome Z_hat_test = bart_model.predict( - covariates=X_test, - basis=W_test, + X=X_test, + leaf_basis=W_test, rfx_group_ids=group_ids_test, rfx_basis=rfx_basis_test, type="mean", diff --git a/demo/debug/rfx_serialization.py b/demo/debug/rfx_serialization.py index fec857b6..b6fc3d97 100644 --- a/demo/debug/rfx_serialization.py +++ b/demo/debug/rfx_serialization.py @@ -60,13 +60,13 @@ def rfx_mean(group_labels, basis): rfx_basis_train=basis, num_gfr=10, num_mcmc=10) # Extract predictions from the sampler -bart_preds_orig = bart_orig.predict(X, W, group_labels, basis) +bart_preds_orig = bart_orig.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis) y_hat_orig = bart_preds_orig['y_hat'] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) -bart_preds_reloaded = bart_reloaded.predict(X, W, group_labels, basis) +bart_preds_reloaded = bart_reloaded.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis) y_hat_reloaded = bart_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) \ No newline at end of file diff --git a/demo/notebooks/causal_inference.ipynb b/demo/notebooks/causal_inference.ipynb index 511ce0c4..151356ae 100644 --- a/demo/notebooks/causal_inference.ipynb +++ b/demo/notebooks/causal_inference.ipynb @@ -109,10 +109,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params=general_params,\n", diff --git a/demo/notebooks/causal_inference_feature_subsets.ipynb b/demo/notebooks/causal_inference_feature_subsets.ipynb index 2bbfbad5..f4465568 100644 --- a/demo/notebooks/causal_inference_feature_subsets.ipynb +++ b/demo/notebooks/causal_inference_feature_subsets.ipynb @@ -113,10 +113,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params={\"keep_every\": 5},\n", @@ -242,10 +242,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " treatment_effect_forest_params=tau_params,\n", diff --git a/demo/notebooks/multi_chain.ipynb b/demo/notebooks/multi_chain.ipynb index 85aebd8e..afe4c741 100644 --- a/demo/notebooks/multi_chain.ipynb +++ b/demo/notebooks/multi_chain.ipynb @@ -161,8 +161,8 @@ "outputs": [], "source": [ "y_hat_test = bart_model.predict(\n", - " covariates = X_test,\n", - " basis = leaf_basis_test, \n", + " X = X_test,\n", + " leaf_basis = leaf_basis_test, \n", " type = \"mean\", \n", " terms = \"y_hat\"\n", ")\n", @@ -321,8 +321,8 @@ "outputs": [], "source": [ "y_hat_test = bart_model.predict(\n", - " covariates = X_test,\n", - " basis = leaf_basis_test, \n", + " X = X_test,\n", + " leaf_basis = leaf_basis_test, \n", " type = \"mean\", \n", " terms = \"y_hat\"\n", ")\n", diff --git a/demo/notebooks/multivariate_treatment_causal_inference.ipynb b/demo/notebooks/multivariate_treatment_causal_inference.ipynb index 3e345aa4..88b528cd 100644 --- a/demo/notebooks/multivariate_treatment_causal_inference.ipynb +++ b/demo/notebooks/multivariate_treatment_causal_inference.ipynb @@ -110,10 +110,10 @@ " X_train=X_train,\n", " Z_train=Z_train,\n", " y_train=y_train,\n", - " pi_train=pi_train,\n", + " propensity_train=pi_train,\n", " X_test=X_test,\n", " Z_test=Z_test,\n", - " pi_test=pi_test,\n", + " propensity_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", ")" diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd index 8a802e45..be383a8d 100644 --- a/man/compute_bart_posterior_interval.Rd +++ b/man/compute_bart_posterior_interval.Rd @@ -9,8 +9,8 @@ compute_bart_posterior_interval( terms, level = 0.95, scale = "linear", - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL ) @@ -24,9 +24,9 @@ compute_bart_posterior_interval( \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} -\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} +\item{X}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} -\item{basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} +\item{leaf_basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} \item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} @@ -47,7 +47,7 @@ bart_model <- bart(y_train = y, X_train = X) intervals <- compute_bart_posterior_interval( model_object = bart_model, terms = c("mean_forest", "y_hat"), - covariates = X, + X = X, level = 0.90 ) } diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd index 118c0256..00e12157 100644 --- a/man/compute_bcf_posterior_interval.Rd +++ b/man/compute_bcf_posterior_interval.Rd @@ -9,8 +9,8 @@ compute_bcf_posterior_interval( terms, level = 0.95, scale = "linear", - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL @@ -25,9 +25,9 @@ compute_bcf_posterior_interval( \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} -\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} +\item{X}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} -\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} +\item{Z}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} \item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} @@ -55,8 +55,8 @@ bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, intervals <- compute_bcf_posterior_interval( model_object = bcf_model, terms = c("prognostic_function", "cate"), - covariates = X, - treatment = Z, + X = X, + Z = Z, propensity = pi_X, level = 0.90 ) diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd index 0851c9b4..c09bf23a 100644 --- a/man/compute_contrast_bart_model.Rd +++ b/man/compute_contrast_bart_model.Rd @@ -6,8 +6,8 @@ \usage{ compute_contrast_bart_model( object, - covariates_0, - covariates_1, + X_0, + X_1, leaf_basis_0 = NULL, leaf_basis_1 = NULL, rfx_group_ids_0 = NULL, @@ -21,9 +21,9 @@ compute_contrast_bart_model( \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{covariates_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} +\item{X_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} -\item{covariates_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} +\item{X_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} \item{leaf_basis_0}{(Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: \code{NULL}.} @@ -88,8 +88,8 @@ bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_tr num_gfr = 10, num_burnin = 0, num_mcmc = 10) contrast_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), type = "posterior", diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 0748d97a..7a09d9c9 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -42,5 +42,5 @@ bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) -y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) +y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 0cb82678..c1bdfd09 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -6,7 +6,7 @@ \usage{ \method{predict}{bartmodel}( object, - covariates, + X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -19,7 +19,7 @@ \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{covariates}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} +\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} \item{leaf_basis}{(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: \code{NULL}.} @@ -66,5 +66,5 @@ y_test <- y[test_inds] y_train <- y[train_inds] bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) -y_hat_test <- predict(bart_model, X_test)$y_hat +y_hat_test <- predict(bart_model, X=X_test)$y_hat } diff --git a/man/sample_bart_posterior_predictive.Rd b/man/sample_bart_posterior_predictive.Rd index 5bce8442..5dffb782 100644 --- a/man/sample_bart_posterior_predictive.Rd +++ b/man/sample_bart_posterior_predictive.Rd @@ -6,8 +6,8 @@ \usage{ sample_bart_posterior_predictive( model_object, - covariates = NULL, - basis = NULL, + X = NULL, + leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, num_draws_per_sample = NULL @@ -16,9 +16,9 @@ sample_bart_posterior_predictive( \arguments{ \item{model_object}{A fitted BART model object of class \code{bartmodel}.} -\item{covariates}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} +\item{X}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} -\item{basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} +\item{leaf_basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} \item{rfx_group_ids}{A vector of group IDs for random effects model. Required if the BART model includes random effects.} @@ -39,6 +39,6 @@ X <- matrix(rnorm(n * p), nrow = n, ncol = p) y <- 2 * X[,1] + rnorm(n) bart_model <- bart(y_train = y, X_train = X) ppd_samples <- sample_bart_posterior_predictive( - model_object = bart_model, covariates = X + model_object = bart_model, X = X ) } diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd index 0c77d7c1..b6cb191d 100644 --- a/man/sample_bcf_posterior_predictive.Rd +++ b/man/sample_bcf_posterior_predictive.Rd @@ -6,8 +6,8 @@ \usage{ sample_bcf_posterior_predictive( model_object, - covariates = NULL, - treatment = NULL, + X = NULL, + Z = NULL, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, @@ -17,9 +17,9 @@ sample_bcf_posterior_predictive( \arguments{ \item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} -\item{covariates}{A matrix or data frame of covariates.} +\item{X}{A matrix or data frame of covariates.} -\item{treatment}{A vector or matrix of treatment assignments.} +\item{Z}{A vector or matrix of treatment assignments.} \item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} @@ -44,7 +44,7 @@ Z <- rbinom(n, 1, pi_X) y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) ppd_samples <- sample_bcf_posterior_predictive( - model_object = bcf_model, covariates = X, - treatment = Z, propensity = pi_X + model_object = bcf_model, X = X, + Z = Z, propensity = pi_X ) } diff --git a/stochtree/bart.py b/stochtree/bart.py index 656bcba7..906fe899 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -82,12 +82,12 @@ def sample( num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, random_effects_params: Optional[Dict[str, Any]] = None, - previous_model_json: Optional[str] = None, - previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -1743,8 +1743,8 @@ def sample( def predict( self, - covariates: Union[np.array, pd.DataFrame], - basis: np.array = None, + X: Union[np.array, pd.DataFrame], + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, type: str = "posterior", @@ -1757,9 +1757,9 @@ def predict( Parameters ---------- - covariates : np.array + X : np.array Test set covariates. - basis : np.array, optional + leaf_basis : np.array, optional Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. rfx_group_ids : np.array, optional Optional group labels used for an additive random effects model. @@ -1861,29 +1861,29 @@ def predict( raise NotSampledError(msg) # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray + if not isinstance(X, pd.DataFrame) and not isinstance( + X, np.ndarray ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - if basis is not None: - if not isinstance(basis, np.ndarray): - raise ValueError("basis must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + raise ValueError("X must be a pandas dataframe or numpy array") + if leaf_basis is not None: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("leaf_basis must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "covariates and basis must have the same number of rows" + "X and leaf_basis must have the same number of rows" ) # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if basis is not None: - if basis.ndim == 1: - basis = np.expand_dims(basis, 1) + if isinstance(X, np.ndarray): + if X.ndim == 1: + X = np.expand_dims(X, 1) + if leaf_basis is not None: + if leaf_basis.ndim == 1: + leaf_basis = np.expand_dims(leaf_basis, 1) # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): + if not isinstance(X, np.ndarray): raise ValueError( "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." ) @@ -1893,20 +1893,20 @@ def predict( RuntimeWarning, ) if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): + X.dtype, np.floating + ) and not np.issubdtype(X.dtype, np.integer): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." ) - covariates_processed = covariates + X_processed = X else: - covariates_processed = self._covariate_preprocessor.transform(covariates) + X_processed = self._covariate_preprocessor.transform(X) # Dataset construction pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) - if basis is not None: - pred_dataset.add_basis(basis) + pred_dataset.add_covariates(X_processed) + if leaf_basis is not None: + pred_dataset.add_basis(leaf_basis) # Variance forest predictions if predict_variance_forest: @@ -1946,7 +1946,7 @@ def predict( if rfx_basis is not None: if rfx_basis.ndim == 1: rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError("X and rfx_basis must have the same number of rows") if rfx_basis.shape[1] != self.num_rfx_basis: raise ValueError( @@ -1971,7 +1971,7 @@ def predict( rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std # Construct an array with the appropriate group random effects arranged for each observation - n_train = covariates.shape[0] + n_train = X.shape[0] if rfx_beta_draws.ndim != 2: raise ValueError( "BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices" @@ -2046,10 +2046,10 @@ def predict( def compute_contrast( self, - covariates_0: Union[np.array, pd.DataFrame], - covariates_1: Union[np.array, pd.DataFrame], - basis_0: np.array = None, - basis_1: np.array = None, + X_0: Union[np.array, pd.DataFrame], + X_1: Union[np.array, pd.DataFrame], + leaf_basis_0: np.array = None, + leaf_basis_1: np.array = None, rfx_group_ids_0: np.array = None, rfx_group_ids_1: np.array = None, rfx_basis_0: np.array = None, @@ -2068,13 +2068,13 @@ def compute_contrast( Parameters ---------- - covariates_0 : np.array or pd.DataFrame + X_0 : np.array or pd.DataFrame Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. - covariates_1 : np.array or pd.DataFrame + X_1 : np.array or pd.DataFrame Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. - basis_0 : np.array, optional + leaf_basis_0 : np.array, optional Bases used for prediction in the "control" case (by e.g. dot product with leaf values). - basis_1 : np.array, optional + leaf_basis_1 : np.array, optional Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). rfx_group_ids_0 : np.array, optional Test set group labels used for prediction from an additive random effects model in the "control" case. @@ -2135,33 +2135,33 @@ def compute_contrast( raise NotSampledError(msg) # Data checks - if not isinstance(covariates_0, pd.DataFrame) and not isinstance( - covariates_0, np.ndarray + if not isinstance(X_0, pd.DataFrame) and not isinstance( + X_0, np.ndarray ): - raise ValueError("covariates_0 must be a pandas dataframe or numpy array") - if not isinstance(covariates_1, pd.DataFrame) and not isinstance( - covariates_1, np.ndarray + raise ValueError("X_0 must be a pandas dataframe or numpy array") + if not isinstance(X_1, pd.DataFrame) and not isinstance( + X_1, np.ndarray ): - raise ValueError("covariates_1 must be a pandas dataframe or numpy array") - if basis_0 is not None: - if not isinstance(basis_0, np.ndarray): - raise ValueError("basis_0 must be a numpy array") - if basis_0.shape[0] != covariates_0.shape[0]: + raise ValueError("X_1 must be a pandas dataframe or numpy array") + if leaf_basis_0 is not None: + if not isinstance(leaf_basis_0, np.ndarray): + raise ValueError("leaf_basis_0 must be a numpy array") + if leaf_basis_0.shape[0] != X_0.shape[0]: raise ValueError( - "covariates_0 and basis_0 must have the same number of rows" + "X_0 and leaf_basis_0 must have the same number of rows" ) - if basis_1 is not None: - if not isinstance(basis_1, np.ndarray): - raise ValueError("basis_1 must be a numpy array") - if basis_1.shape[0] != covariates_1.shape[0]: + if leaf_basis_1 is not None: + if not isinstance(leaf_basis_1, np.ndarray): + raise ValueError("leaf_basis_1 must be a numpy array") + if leaf_basis_1.shape[0] != X_1.shape[0]: raise ValueError( - "covariates_1 and basis_1 must have the same number of rows" + "X_1 and leaf_basis_1 must have the same number of rows" ) # Predict for the control arm control_preds = self.predict( - covariates=covariates_0, - basis=basis_0, + X=X_0, + leaf_basis=leaf_basis_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", @@ -2171,8 +2171,8 @@ def compute_contrast( # Predict for the treatment arm treatment_preds = self.predict( - covariates=covariates_1, - basis=basis_1, + X=X_1, + leaf_basis=leaf_basis_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", @@ -2194,10 +2194,10 @@ def compute_contrast( def compute_posterior_interval( self, terms: Union[list[str], str] = "all", - scale: str = "linear", level: float = 0.95, - covariates: np.array = None, - basis: np.array = None, + scale: str = "linear", + X: np.array = None, + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, ) -> dict: @@ -2212,9 +2212,9 @@ def compute_posterior_interval( Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. - covariates : np.array, optional + X : np.array, optional Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). - basis : np.array, optional + leaf_basis : np.array, optional Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. rfx_group_ids : np.array, optional Optional vector of group IDs for random effects. Required if the requested term includes random effects. @@ -2266,25 +2266,25 @@ def compute_posterior_interval( or needs_covariates_intermediate ) if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_basis = needs_covariates and self.has_basis if needs_basis: - if basis is None: + if leaf_basis is None: raise ValueError( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) - if not isinstance(basis, np.ndarray): - raise ValueError("'basis' must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("'leaf_basis' must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "'basis' must have the same number of rows as 'covariates'" + "'leaf_basis' must have the same number of rows as 'X'" ) needs_rfx_data_intermediate = ( ("y_hat" in terms) or ("all" in terms) @@ -2297,9 +2297,9 @@ def compute_posterior_interval( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -2307,15 +2307,15 @@ def compute_posterior_interval( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior matrices for the requested model terms predictions = self.predict( - covariates=covariates, - basis=basis, + X=X, + leaf_basis=leaf_basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", @@ -2338,8 +2338,8 @@ def compute_posterior_interval( def sample_posterior_predictive( self, - covariates: np.array = None, - basis: np.array = None, + X: np.array = None, + leaf_basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None, @@ -2349,9 +2349,9 @@ def sample_posterior_predictive( Parameters ---------- - covariates : np.array, optional + X : np.array, optional An array or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). - basis : np.array, optional + leaf_basis : np.array, optional An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. rfx_group_ids : np.array, optional An array of group IDs for random effects. Required if the BART model includes random effects. @@ -2375,25 +2375,25 @@ def sample_posterior_predictive( # Check that all the necessary inputs were provided for interval computation needs_covariates = self.include_mean_forest if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_basis = needs_covariates and self.has_basis if needs_basis: - if basis is None: + if leaf_basis is None: raise ValueError( - "'basis' must be provided in order to compute the requested intervals" + "'leaf_basis' must be provided in order to compute the requested intervals" ) - if not isinstance(basis, np.ndarray): - raise ValueError("'basis' must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + if not isinstance(leaf_basis, np.ndarray): + raise ValueError("'leaf_basis' must be a numpy array") + if leaf_basis.shape[0] != X.shape[0]: raise ValueError( - "'basis' must have the same number of rows as 'covariates'" + "'leaf_basis' must have the same number of rows as 'X'" ) needs_rfx_data = self.has_rfx if needs_rfx_data: @@ -2403,9 +2403,9 @@ def sample_posterior_predictive( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -2413,15 +2413,15 @@ def sample_posterior_predictive( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior predictive samples bart_preds = self.predict( - covariates=covariates, - basis=basis, + X=X, + leaf_basis=leaf_basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", @@ -2433,7 +2433,7 @@ def sample_posterior_predictive( has_variance_forest = self.include_variance_forest samples_global_variance = self.sample_sigma2_global num_posterior_draws = self.num_samples - num_observations = covariates.shape[0] + num_observations = X.shape[0] if has_mean_term: ppd_mean = bart_preds["y_hat"] else: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 7f017c29..983791ab 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -84,24 +84,24 @@ def sample( X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array, - pi_train: np.array = None, + propensity_train: np.array = None, rfx_group_ids_train: np.array = None, rfx_basis_train: np.array = None, X_test: Union[pd.DataFrame, np.array] = None, Z_test: np.array = None, - pi_test: np.array = None, + propensity_test: np.array = None, rfx_group_ids_test: np.array = None, rfx_basis_test: np.array = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, general_params: Optional[Dict[str, Any]] = None, prognostic_forest_params: Optional[Dict[str, Any]] = None, treatment_effect_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, random_effects_params: Optional[Dict[str, Any]] = None, - previous_model_json: Optional[str] = None, - previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -114,7 +114,7 @@ def sample( Array of (continuous or binary; univariate or multivariate) treatment assignments. y_train : np.array Outcome to be modeled by the ensemble. - pi_train : np.array + propensity_train : np.array Optional vector of propensity scores. If not provided, this will be estimated from the data. rfx_group_ids_train : np.array, optional Optional group labels used for an additive random effects model. @@ -125,7 +125,7 @@ def sample( Z_test : np.array, optional Optional test set of (continuous or binary) treatment assignments. Must be provided if `X_test` is provided. - pi_test : np.array, optional + propensity_test : np.array, optional Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. rfx_group_ids_test : np.array, optional Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), @@ -541,9 +541,9 @@ def sample( raise ValueError("X_train must be a pandas dataframe or numpy array") if not isinstance(Z_train, np.ndarray): raise ValueError("Z_train must be a numpy array") - if pi_train is not None: - if not isinstance(pi_train, np.ndarray): - raise ValueError("pi_train must be a numpy array") + if propensity_train is not None: + if not isinstance(propensity_train, np.ndarray): + raise ValueError("propensity_train must be a numpy array") if not isinstance(y_train, np.ndarray): raise ValueError("y_train must be a numpy array") if X_test is not None: @@ -554,9 +554,9 @@ def sample( if Z_test is not None: if not isinstance(Z_test, np.ndarray): raise ValueError("Z_test must be a numpy array") - if pi_test is not None: - if not isinstance(pi_test, np.ndarray): - raise ValueError("pi_test must be a numpy array") + if propensity_test is not None: + if not isinstance(propensity_test, np.ndarray): + raise ValueError("propensity_test must be a numpy array") if rfx_group_ids_train is not None: if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") @@ -585,9 +585,9 @@ def sample( if Z_train is not None: if Z_train.ndim == 1: Z_train = np.expand_dims(Z_train, 1) - if pi_train is not None: - if pi_train.ndim == 1: - pi_train = np.expand_dims(pi_train, 1) + if propensity_train is not None: + if propensity_train.ndim == 1: + propensity_train = np.expand_dims(propensity_train, 1) if y_train.ndim == 1: y_train = np.expand_dims(y_train, 1) if X_test is not None: @@ -597,9 +597,9 @@ def sample( if Z_test is not None: if Z_test.ndim == 1: Z_test = np.expand_dims(Z_test, 1) - if pi_test is not None: - if pi_test.ndim == 1: - pi_test = np.expand_dims(pi_test, 1) + if propensity_test is not None: + if propensity_test.ndim == 1: + propensity_test = np.expand_dims(propensity_test, 1) if rfx_group_ids_train is not None: if rfx_group_ids_train.ndim != 1: rfx_group_ids_train = np.squeeze(rfx_group_ids_train) @@ -631,17 +631,17 @@ def sample( raise ValueError("X_train and Z_train must have the same number of rows") if y_train.shape[0] != X_train.shape[0]: raise ValueError("X_train and y_train must have the same number of rows") - if pi_train is not None: - if pi_train.shape[0] != X_train.shape[0]: + if propensity_train is not None: + if propensity_train.shape[0] != X_train.shape[0]: raise ValueError( - "X_train and pi_train must have the same number of rows" + "X_train and propensity_train must have the same number of rows" ) if X_test is not None and Z_test is not None: if X_test.shape[0] != Z_test.shape[0]: raise ValueError("X_test and Z_test must have the same number of rows") - if X_test is not None and pi_test is not None: - if X_test.shape[0] != pi_test.shape[0]: - raise ValueError("X_test and pi_test must have the same number of rows") + if X_test is not None and propensity_test is not None: + if X_test.shape[0] != propensity_test.shape[0]: + raise ValueError("X_test and propensity_test must have the same number of rows") # Raise a warning if the data have ties and only GFR is being run if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0): @@ -1311,10 +1311,10 @@ def sample( sample_sigma2_leaf_tau = False # Check if user has provided propensities that are needed in the model - if pi_train is None and propensity_covariate != "none": + if propensity_train is None and propensity_covariate != "none": if self.multivariate_treatment: raise ValueError( - "Propensities must be provided (via pi_train and / or pi_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" + "Propensities must be provided (via propensity_train and / or propensity_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" ) else: self.bart_propensity_model = BARTModel() @@ -1330,10 +1330,10 @@ def sample( num_burnin=num_burnin_propensity, num_mcmc=num_mcmc_propensity, ) - pi_train = np.mean( + propensity_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True ) - pi_test = np.mean( + propensity_test = np.mean( self.bart_propensity_model.y_hat_test, axis=1, keepdims=True ) else: @@ -1344,7 +1344,7 @@ def sample( num_burnin=num_burnin_propensity, num_mcmc=num_mcmc_propensity, ) - pi_train = np.mean( + propensity_train = np.mean( self.bart_propensity_model.y_hat_train, axis=1, keepdims=True ) self.internal_propensity_model = True @@ -1674,34 +1674,34 @@ def sample( ) if propensity_covariate != "none": feature_types = np.append( - feature_types, np.repeat(0, pi_train.shape[1]) + feature_types, np.repeat(0, propensity_train.shape[1]) ).astype("int") - X_train_processed = np.c_[X_train_processed, pi_train] + X_train_processed = np.c_[X_train_processed, propensity_train] if self.has_test: - X_test_processed = np.c_[X_test_processed, pi_test] + X_test_processed = np.c_[X_test_processed, propensity_test] if propensity_covariate == "prognostic": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(0.0, pi_train.shape[1]) + variable_weights_tau, np.repeat(0.0, propensity_train.shape[1]) ) elif propensity_covariate == "treatment_effect": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(0.0, pi_train.shape[1]) + variable_weights_mu, np.repeat(0.0, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) elif propensity_covariate == "both": variable_weights_mu = np.append( - variable_weights_mu, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_tau = np.append( - variable_weights_tau, np.repeat(1 / num_cov_orig, pi_train.shape[1]) + variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1]) ) variable_weights_variance = np.append( - variable_weights_variance, np.repeat(0.0, pi_train.shape[1]) + variable_weights_variance, np.repeat(0.0, propensity_train.shape[1]) ) # Renormalize variable weights @@ -3261,10 +3261,10 @@ def compute_contrast( def compute_posterior_interval( self, terms: Union[list[str], str] = "all", - scale: str = "linear", level: float = 0.95, - covariates: np.array = None, - treatment: np.array = None, + scale: str = "linear", + X: np.array = None, + Z: np.array = None, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, @@ -3280,9 +3280,9 @@ def compute_posterior_interval( Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. - covariates : np.array, optional + X : np.array, optional Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions). - treatment : np.array, optional + Z : np.array, optional Optional array of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). propensity : np.array, optional Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. @@ -3346,25 +3346,25 @@ def compute_posterior_interval( or needs_covariates_intermediate ) if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_treatment = needs_covariates if needs_treatment: - if treatment is None: + if Z is None: raise ValueError( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) - if not isinstance(treatment, np.ndarray): - raise ValueError("'treatment' must be a numpy array") - if treatment.shape[0] != covariates.shape[0]: + if not isinstance(Z, np.ndarray): + raise ValueError("'Z' must be a numpy array") + if Z.shape[0] != X.shape[0]: raise ValueError( - "'treatment' must have the same number of rows as 'covariates'" + "'Z' must have the same number of rows as 'X'" ) uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model @@ -3378,9 +3378,9 @@ def compute_posterior_interval( ) if not isinstance(propensity, np.ndarray): raise ValueError("'propensity' must be a numpy array") - if propensity.shape[0] != covariates.shape[0]: + if propensity.shape[0] != X.shape[0]: raise ValueError( - "'propensity' must have the same number of rows as 'covariates'" + "'propensity' must have the same number of rows as 'X'" ) needs_rfx_data_intermediate = ( ("y_hat" in terms) or ("all" in terms) @@ -3393,9 +3393,9 @@ def compute_posterior_interval( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if self.rfx_model_spec == "custom": if rfx_basis is None: @@ -3405,15 +3405,15 @@ def compute_posterior_interval( if rfx_basis is not None: if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior matrices for the requested model terms predictions = self.predict( - X=covariates, - Z=treatment, + X=X, + Z=Z, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, @@ -3437,8 +3437,8 @@ def compute_posterior_interval( def sample_posterior_predictive( self, - covariates: np.array, - treatment: np.array, + X: np.array, + Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, @@ -3449,9 +3449,9 @@ def sample_posterior_predictive( Parameters ---------- - covariates : np.array + X : np.array An array or data frame of covariates. - treatment : np.array + Z : np.array An array of treatment assignments. propensity : np.array, optional Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. @@ -3477,25 +3477,25 @@ def sample_posterior_predictive( # Check that all the necessary inputs were provided for interval computation needs_covariates = True if needs_covariates: - if covariates is None: + if X is None: raise ValueError( - "'covariates' must be provided in order to compute the requested intervals" + "'X' must be provided in order to compute the requested intervals" ) - if not isinstance(covariates, np.ndarray) and not isinstance( - covariates, pd.DataFrame + if not isinstance(X, np.ndarray) and not isinstance( + X, pd.DataFrame ): - raise ValueError("'covariates' must be a matrix or data frame") + raise ValueError("'X' must be a matrix or data frame") needs_treatment = needs_covariates if needs_treatment: - if treatment is None: + if Z is None: raise ValueError( - "'treatment' must be provided in order to compute the requested intervals" + "'Z' must be provided in order to compute the requested intervals" ) - if not isinstance(treatment, np.ndarray): - raise ValueError("'treatment' must be a numpy array") - if treatment.shape[0] != covariates.shape[0]: + if not isinstance(Z, np.ndarray): + raise ValueError("'Z' must be a numpy array") + if Z.shape[0] != X.shape[0]: raise ValueError( - "'treatment' must have the same number of rows as 'covariates'" + "'Z' must have the same number of rows as 'X'" ) uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model @@ -3509,9 +3509,9 @@ def sample_posterior_predictive( ) if not isinstance(propensity, np.ndarray): raise ValueError("'propensity' must be a numpy array") - if propensity.shape[0] != covariates.shape[0]: + if propensity.shape[0] != X.shape[0]: raise ValueError( - "'propensity' must have the same number of rows as 'covariates'" + "'propensity' must have the same number of rows as 'X'" ) needs_rfx_data = self.has_rfx if needs_rfx_data: @@ -3521,9 +3521,9 @@ def sample_posterior_predictive( ) if not isinstance(rfx_group_ids, np.ndarray): raise ValueError("'rfx_group_ids' must be a numpy array") - if rfx_group_ids.shape[0] != covariates.shape[0]: + if rfx_group_ids.shape[0] != X.shape[0]: raise ValueError( - "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) if rfx_basis is None: raise ValueError( @@ -3531,15 +3531,15 @@ def sample_posterior_predictive( ) if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != covariates.shape[0]: + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "'rfx_basis' must have the same number of rows as 'covariates'" + "'rfx_basis' must have the same number of rows as 'X'" ) # Compute posterior predictive samples bcf_preds = self.predict( - X=covariates, - Z=treatment, + X=X, + Z=Z, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, @@ -3552,7 +3552,7 @@ def sample_posterior_predictive( has_variance_forest = self.include_variance_forest samples_global_variance = self.sample_sigma2_global num_posterior_draws = self.num_samples - num_observations = covariates.shape[0] + num_observations = X.shape[0] ppd_mean = bcf_preds["y_hat"] if has_variance_forest: ppd_variance = bcf_preds["variance_forest_predictions"] diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 23013ec2..b42099ea 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -312,6 +312,19 @@ test_that("Warmstart BART", { # Run a new BART chain from the existing (X)BART model general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bart_model <- bart( X_train = X_train, y_train = y_train, @@ -376,6 +389,23 @@ test_that("Warmstart BART", { # Run a new BART chain from the existing (X)BART model general_param_list <- list(num_chains = 4, keep_every = 5) expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bart_model <- bart( X_train = X_train, y_train = y_train, @@ -433,7 +463,7 @@ test_that("BART Predictions", { ) # Check that cached predictions agree with results of predict() function - train_preds <- predict(bart_model, covariates = X_train) + train_preds <- predict(bart_model, X = X_train) train_preds_mean_cached <- bart_model$y_hat_train train_preds_mean_recomputed <- train_preds$mean_forest_predictions train_preds_variance_cached <- bart_model$sigma2_x_hat_train @@ -584,7 +614,7 @@ test_that("Random Effects BART", { ) preds <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 221c333f..ab1cb7b4 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -375,6 +375,23 @@ test_that("Warmstart BCF", { # Run a new BCF chain from the existing (X)BCF model general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 10, + general_params = general_param_list + ) + ) + expect_warning( bcf_model <- bcf( X_train = X_train, y_train = y_train, @@ -482,7 +499,7 @@ test_that("Warmstart BCF", { num_burnin = 10, num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, + previous_model_warmstart_sample_num = 10, general_params = general_param_list ) ) diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index bdd9d66b..63ff0f94 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -216,12 +216,12 @@ test_that("BART predictions with pre-summarization", { ) # Check that the default predict method returns a list - pred <- predict(bart_model, X_test) + pred <- predict(bart_model, X = X_test) y_hat_posterior_test <- pred$y_hat expect_equal(dim(y_hat_posterior_test), c(20, 10)) # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(bart_model, X_test, type = "mean") + pred_mean <- predict(bart_model, X = X_test, type = "mean") y_hat_mean_test <- pred_mean$y_hat expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) @@ -229,7 +229,7 @@ test_that("BART predictions with pre-summarization", { expect_warning({ pred_mean <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("rfx", "variance_forest") ) @@ -248,7 +248,7 @@ test_that("BART predictions with pre-summarization", { ) # Check that the default predict method returns a list - pred <- predict(het_bart_model, X_test) + pred <- predict(het_bart_model, X = X_test) y_hat_posterior_test <- pred$y_hat sigma2_hat_posterior_test <- pred$variance_forest_predictions @@ -257,7 +257,7 @@ test_that("BART predictions with pre-summarization", { expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(het_bart_model, X_test, type = "mean") + pred_mean <- predict(het_bart_model, X = X_test, type = "mean") y_hat_mean_test <- pred_mean$y_hat sigma2_hat_mean_test <- pred_mean$variance_forest_predictions @@ -269,13 +269,13 @@ test_that("BART predictions with pre-summarization", { # match those computed by pre-aggregated predictions returned in a list y_hat_mean_test_single_term <- predict( het_bart_model, - X_test, + X = X_test, type = "mean", terms = "y_hat" ) sigma2_hat_mean_test_single_term <- predict( het_bart_model, - X_test, + X = X_test, type = "mean", terms = "variance_forest" ) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index fe50af5f..2f0d4aaa 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -34,7 +34,7 @@ test_that("BART Serialization", { num_mcmc = 10, general_params = general_param_list ) - y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) + y_hat_orig <- rowMeans(predict(bart_model, X = X_test)$y_hat) # Save to JSON bart_json_string <- saveBARTModelToJsonString(bart_model) @@ -43,7 +43,7 @@ test_that("BART Serialization", { bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) # Predict from the roundtrip BART model - y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) + y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X = X_test)$y_hat) # Assertion expect_equal(y_hat_orig, y_hat_reloaded) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 3243b86a..b182524b 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from sklearn.model_selection import train_test_split from stochtree import BARTModel @@ -83,7 +82,7 @@ def outcome_mean(X): bart_model_3.from_json_string_list(bart_models_json) # Assertions - bart_preds_combined = bart_model_3.predict(covariates=X_train) + bart_preds_combined = bart_model_3.predict(X=X_train) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( @@ -190,7 +189,7 @@ def outcome_mean(X, W): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -298,7 +297,7 @@ def outcome_mean(X, W): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -410,7 +409,7 @@ def conditional_stddev(X): bart_model_3.from_json_string_list(bart_models_json) # Assertions - bart_preds_combined = bart_model_3.predict(covariates=X_train) + bart_preds_combined = bart_model_3.predict(X=X_train) y_hat_train_combined, sigma2_x_train_combined = ( bart_preds_combined["y_hat"], bart_preds_combined["variance_forest_predictions"], @@ -545,7 +544,7 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -670,7 +669,7 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, basis=basis_train + X=X_train, leaf_basis=basis_train ) y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) @@ -825,7 +824,7 @@ def rfx_term(group_labels, basis): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, + X=X_train, rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) @@ -998,8 +997,8 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict( - covariates=X_train, - basis=basis_train, + X=X_train, + leaf_basis=basis_train, rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) @@ -1196,8 +1195,8 @@ def conditional_stddev(X): random_effects_params=rfx_params, ) preds = bart_model_4.predict( - covariates=X_test, - basis=basis_test, + X=X_test, + leaf_basis=basis_test, rfx_group_ids=group_labels_test, type="posterior", terms="rfx", diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index dac1ea25..eca2a5ff 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -51,10 +51,10 @@ def test_binary_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -93,7 +93,7 @@ def test_binary_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -239,10 +239,10 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -281,10 +281,10 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -352,7 +352,7 @@ def test_continuous_univariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -560,10 +560,10 @@ def test_multivariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -602,7 +602,7 @@ def test_multivariate_bcf(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -706,10 +706,10 @@ def test_binary_bcf_heteroskedastic(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -752,7 +752,7 @@ def test_binary_bcf_heteroskedastic(self): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, @@ -918,10 +918,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, @@ -946,10 +946,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, @@ -974,10 +974,10 @@ def rfx_term(group_labels, basis): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_test, + propensity_test=pi_test, rfx_group_ids_train=group_labels_train, rfx_basis_train=rfx_basis_train, rfx_group_ids_test=group_labels_test, diff --git a/test/python/test_json.py b/test/python/test_json.py index 48d4845b..b6f9b36f 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -454,7 +454,7 @@ def test_bcf_string(self): # Run BCF bcf_orig = BCFModel() bcf_orig.sample( - X_train=X, Z_train=Z, y_train=y, pi_train=pi_X, num_gfr=10, num_mcmc=10 + X_train=X, Z_train=Z, y_train=y, propensity_train=pi_X, num_gfr=10, num_mcmc=10 ) # Extract predictions from the sampler @@ -529,7 +529,7 @@ def rfx_mean(group_labels, basis): X_train=X, Z_train=Z, y_train=y, - pi_train=pi_X, + propensity_train=pi_X, rfx_group_ids_train=group_labels, rfx_basis_train=basis, num_gfr=10, diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 03f36cb2..117cac04 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -221,12 +221,12 @@ def test_bart_prediction(self): ) # Check that the default predict method returns a dictionary - pred = bart_model.predict(covariates=X_test) + pred = bart_model.predict(X=X_test) y_hat_posterior_test = pred["y_hat"] assert y_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = bart_model.predict(covariates=X_test, type="mean") + pred_mean = bart_model.predict(X=X_test, type="mean") y_hat_mean_test = pred_mean["y_hat"] np.testing.assert_almost_equal( y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) @@ -245,14 +245,14 @@ def test_bart_prediction(self): ) # Check that the default predict method returns a dictionary - pred = het_bart_model.predict(covariates=X_test) + pred = het_bart_model.predict(X=X_test) y_hat_posterior_test = pred["y_hat"] sigma2_hat_posterior_test = pred["variance_forest_predictions"] assert y_hat_posterior_test.shape == (20, 10) assert sigma2_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = het_bart_model.predict(covariates=X_test, type="mean") + pred_mean = het_bart_model.predict(X=X_test, type="mean") y_hat_mean_test = pred_mean["y_hat"] sigma2_hat_mean_test = pred_mean["variance_forest_predictions"] np.testing.assert_almost_equal( @@ -265,10 +265,10 @@ def test_bart_prediction(self): # Check that the "single-term" pre-aggregated predictions # match those computed by pre-aggregated predictions returned in a dictionary y_hat_mean_test_single_term = het_bart_model.predict( - covariates=X_test, type="mean", terms="y_hat" + X=X_test, type="mean", terms="y_hat" ) sigma2_hat_mean_test_single_term = het_bart_model.predict( - covariates=X_test, type="mean", terms="variance_forest" + X=X_test, type="mean", terms="variance_forest" ) np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) np.testing.assert_almost_equal( @@ -279,7 +279,6 @@ def test_bcf_prediction(self): # Generate data and test/train split rng = np.random.default_rng(1234) n = 100 - g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) x1 = rng.normal(size=n) x2 = rng.normal(size=n) x3 = rng.normal(size=n) @@ -332,10 +331,10 @@ def g(x5): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_x_train, + propensity_train=pi_x_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_x_test, + propensity_test=pi_x_test, num_gfr=10, num_burnin=0, num_mcmc=10, @@ -372,10 +371,10 @@ def g(x5): X_train=X_train, Z_train=Z_train, y_train=y_train, - pi_train=pi_x_train, + propensity_train=pi_x_train, X_test=X_test, Z_test=Z_test, - pi_test=pi_x_test, + propensity_test=pi_x_test, num_gfr=10, num_burnin=0, num_mcmc=10, diff --git a/tools/debug/acic_bcf_surrogate_debug.R b/tools/debug/acic_bcf_surrogate_debug.R index 1678a8cb..32fece42 100644 --- a/tools/debug/acic_bcf_surrogate_debug.R +++ b/tools/debug/acic_bcf_surrogate_debug.R @@ -66,7 +66,7 @@ propensity_model <- stochtree::bart( ) propensity <- predict( propensity_model, - covariates = covariate_df, + X = covariate_df, type = "mean", terms = "y_hat" ) diff --git a/tools/debug/bart_contrast_debug.R b/tools/debug/bart_contrast_debug.R index 647d12b0..9b46fffe 100644 --- a/tools/debug/bart_contrast_debug.R +++ b/tools/debug/bart_contrast_debug.R @@ -45,8 +45,8 @@ bart_model <- bart( # Compute contrast posterior contrast_posterior_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), type = "posterior", @@ -56,7 +56,7 @@ contrast_posterior_test <- compute_contrast_bart_model( # Compute the same quantity via two predict calls y_hat_posterior_test_0 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(0, nrow = n_test, ncol = 1), type = "posterior", term = "y_hat", @@ -64,7 +64,7 @@ y_hat_posterior_test_0 <- predict( ) y_hat_posterior_test_1 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(1, nrow = n_test, ncol = 1), type = "posterior", term = "y_hat", @@ -128,8 +128,8 @@ bart_model <- bart( # Compute contrast posterior contrast_posterior_test <- compute_contrast_bart_model( bart_model, - covariates_0 = X_test, - covariates_1 = X_test, + X_0 = X_test, + X_1 = X_test, leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), rfx_group_ids_0 = group_ids_test, @@ -143,7 +143,7 @@ contrast_posterior_test <- compute_contrast_bart_model( # Compute the same quantity via two predict calls y_hat_posterior_test_0 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(0, nrow = n_test, ncol = 1), rfx_group_ids = group_ids_test, rfx_basis = rfx_basis_test, @@ -153,7 +153,7 @@ y_hat_posterior_test_0 <- predict( ) y_hat_posterior_test_1 <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = matrix(1, nrow = n_test, ncol = 1), rfx_group_ids = group_ids_test, rfx_basis = rfx_basis_test, diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index 89766a74..4e99f51d 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -38,16 +38,16 @@ bart_model <- bart( ) # Check several predict approaches -y_hat_posterior_test <- predict(bart_model, X_test)$y_hat +y_hat_posterior_test <- predict(bart_model, X = X_test)$y_hat y_hat_mean_test <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("y_hat") ) y_hat_test <- predict( bart_model, - X_test, + X = X_test, type = "mean", terms = c("rfx", "variance") ) @@ -56,7 +56,7 @@ y_hat_intervals <- compute_bart_posterior_interval( model_object = bart_model, transform = function(x) x, terms = c("y_hat", "mean_forest"), - covariates = X_test, + X = X_test, level = 0.95 ) @@ -67,7 +67,7 @@ y_hat_intervals <- compute_bart_posterior_interval( pred_intervals <- sample_bart_posterior_predictive( model_object = bart_model, - covariates = X_test, + X = X_test, level = 0.95 ) @@ -117,18 +117,18 @@ bart_model <- bart( # Predict on latent scale y_hat_post <- predict( object = bart_model, + X = X_test, type = "posterior", terms = c("y_hat"), - covariates = X_test, scale = "linear" ) # Predict on probability scale y_hat_post_prob <- predict( object = bart_model, + X = X_test, type = "posterior", terms = c("y_hat"), - covariates = X_test, scale = "probability" ) @@ -137,7 +137,7 @@ y_hat_intervals <- compute_bart_posterior_interval( model_object = bart_model, scale = "linear", terms = c("y_hat"), - covariates = X_test, + X = X_test, level = 0.95 ) @@ -146,7 +146,7 @@ y_hat_prob_intervals <- compute_bart_posterior_interval( model_object = bart_model, scale = "probability", terms = c("y_hat"), - covariates = X_test, + X = X_test, level = 0.95 ) @@ -169,7 +169,7 @@ lines(y_hat_prob_intervals$upper[sort_inds]) # Draw from posterior predictive for covariates in the test set ppd_samples <- sample_bart_posterior_predictive( model_object = bart_model, - covariates = X_test, + X = X_test, num_draws = 10 ) diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index 70bc71ed..3ed45a2c 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -78,8 +78,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = c("all"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -94,8 +94,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( quantiles <- c(0.05, 0.95) ppd_samples <- sample_bcf_posterior_predictive( model_object = bcf_model, - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, num_draws = 1 ) @@ -179,8 +179,8 @@ y_hat_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = c("y_hat"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -190,8 +190,8 @@ y_hat_prob_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "probability", terms = c("y_hat"), - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, level = 0.95 ) @@ -215,8 +215,8 @@ lines(y_hat_prob_intervals$upper[sort_inds]) # Draw from posterior predictive for covariates / treatment values in the test set ppd_samples <- sample_bcf_posterior_predictive( model_object = bcf_model, - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, num_draws = 10 ) @@ -360,8 +360,8 @@ posterior_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "all", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -372,8 +372,8 @@ prog_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "prognostic_function", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -384,8 +384,8 @@ cate_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "cate", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -426,8 +426,8 @@ mu_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "mu", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 @@ -436,8 +436,8 @@ tau_intervals_test <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", terms = "tau", - covariates = X_test, - treatment = Z_test, + X = X_test, + Z = Z_test, propensity = pi_test, rfx_group_ids = rfx_group_ids_test, level = 0.95 diff --git a/tools/debug/gfr_ties_debug.R b/tools/debug/gfr_ties_debug.R index 833bf533..757faaa6 100644 --- a/tools/debug/gfr_ties_debug.R +++ b/tools/debug/gfr_ties_debug.R @@ -38,7 +38,7 @@ xbart_model <- bart( # Inspect the model fit y_hat_test <- predict( xbart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -57,7 +57,7 @@ bart_model <- bart( # Inspect the model fit y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -100,7 +100,7 @@ xbart_model <- bart( # Inspect the model fit y_hat_test <- predict( xbart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) @@ -119,7 +119,7 @@ bart_model <- bart( # Inspect the model fit y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, type = "mean", terms = "y_hat" ) diff --git a/tools/debug/parallel_warmstart.R b/tools/debug/parallel_warmstart.R index 18b73574..243180db 100644 --- a/tools/debug/parallel_warmstart.R +++ b/tools/debug/parallel_warmstart.R @@ -14,35 +14,55 @@ num_trees <- 100 n <- 500 p_x <- 20 snr <- 2 -X <- matrix(runif(n*p_x), ncol = p_x) -f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) +cos(4*pi*X[,4]) +X <- matrix(runif(n * p_x), ncol = p_x) +f_XW <- sin(4 * pi * X[, 1]) + + cos(4 * pi * X[, 2]) + + sin(4 * pi * X[, 3]) + + cos(4 * pi * X[, 4]) noise_sd <- sd(f_XW) / snr -y <- f_XW + rnorm(n, 0, 1)*noise_sd +y <- f_XW + rnorm(n, 0, 1) * noise_sd # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- as.data.frame(X[test_inds,]) -X_train <- as.data.frame(X[train_inds,]) +X_test <- as.data.frame(X[test_inds, ]) +X_train <- as.data.frame(X[train_inds, ]) y_test <- y[test_inds] y_train <- y[train_inds] # Run the GFR algorithm -xbart_params <- list(sample_sigma_global = T, - num_trees_mean = num_trees, alpha_mean = 0.99, - beta_mean = 1, max_depth_mean = -1, - min_samples_leaf_mean = 1, sample_sigma_leaf = F, - sigma_leaf_init = 1/num_trees) +xbart_params <- list( + sample_sigma_global = T, + num_trees_mean = num_trees, + alpha_mean = 0.99, + beta_mean = 1, + max_depth_mean = -1, + min_samples_leaf_mean = 1, + sample_sigma_leaf = F, + sigma_leaf_init = 1 / num_trees +) xbart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + params = xbart_params ) -plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1) +plot(rowMeans(xbart_model$y_hat_test), y_test) +abline(0, 1) cat(sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(xbart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) # Parallel setup @@ -51,20 +71,32 @@ cl <- makeCluster(ncores) registerDoParallel(cl) # Run the parallel BART MCMC samplers -bart_model_outputs <- foreach (i = 1:num_chains) %dopar% { +bart_model_outputs <- foreach(i = 1:num_chains) %dopar% + { random_seed <- i - bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, random_seed = random_seed, - alpha_mean = 0.999, beta_mean = 1) + bart_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + random_seed = random_seed, + alpha_mean = 0.999, + beta_mean = 1 + ) bart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params, - previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1, + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + params = bart_params, + previous_model_json = xbart_model_string, + warmstart_sample_num = num_gfr - i + 1, ) bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) y_hat_test <- bart_model$y_hat_test - list(model=bart_model_string, yhat=y_hat_test) -} + list(model = bart_model_string, yhat = y_hat_test) + } # Close the cluster connection stopCluster(cl) @@ -73,43 +105,89 @@ stopCluster(cl) bart_model_strings <- list() bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) for (i in 1:length(bart_model_outputs)) { - bart_model_strings[[i]] <- bart_model_outputs[[i]]$model - bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat) + bart_model_strings[[i]] <- bart_model_outputs[[i]]$model + bart_model_yhats[, i] <- rowMeans(bart_model_outputs[[i]]$yhat) } combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) # Inspect the results -yhat_combined <- predict(combined_bart, X_test)$y_hat -par(mfrow = c(1,2)) +yhat_combined <- predict(combined_bart, X = X_test)$y_hat +par(mfrow = c(1, 2)) for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i], - xlab = "deserialized", ylab = "original", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + bart_model_yhats[, i], + xlab = "deserialized", + ylab = "original", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) } for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, - xlab = "predicted", ylab = "actual", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) - cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") - cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + y_test, + xlab = "predicted", + ylab = "actual", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) + cat( + sqrt(mean((rowMeans(yhat_combined[, inds_start:inds_end]) - y_test)^2)), + "\n" + ) + cat( + mean( + (apply(yhat_combined[, inds_start:inds_end], 1, quantile, probs = 0.05) <= + y_test) & + (apply( + yhat_combined[, inds_start:inds_end], + 1, + quantile, + probs = 0.95 + ) >= + y_test) + ), + "\n" + ) } -par(mfrow = c(1,1)) +par(mfrow = c(1, 1)) # Compare to a single chain of MCMC samples initialized at root -bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2) +bart_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + alpha_mean = 0.95, + beta_mean = 2 +) bart_model <- stochtree::bart( - X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bart_params + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + params = bart_params +) +plot( + rowMeans(bart_model$y_hat_test), + y_test, + xlab = "predicted", + ylab = "actual" ) -plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +abline(0, 1) cat(sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(bart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) diff --git a/tools/debug/parallel_warmstart_bcf.R b/tools/debug/parallel_warmstart_bcf.R index 9d002b32..1abf9213 100644 --- a/tools/debug/parallel_warmstart_bcf.R +++ b/tools/debug/parallel_warmstart_bcf.R @@ -16,28 +16,32 @@ n <- 500 x1 <- rnorm(n) x2 <- rnorm(n) x3 <- rnorm(n) -x4 <- rnorm(n,x2,1) -X <- cbind(x1,x2,x3,x4) +x4 <- rnorm(n, x2, 1) +X <- cbind(x1, x2, x3, x4) p <- ncol(X) -mu <- function(x) {-1*(x[,1]>(x[,2])) + 1*(x[,1]<(x[,2])) - 0.1} -tau <- function(x) {1/(1 + exp(-x[,3])) + x[,2]/10} +mu <- function(x) { + -1 * (x[, 1] > (x[, 2])) + 1 * (x[, 1] < (x[, 2])) - 0.1 +} +tau <- function(x) { + 1 / (1 + exp(-x[, 3])) + x[, 2] / 10 +} mu_x <- mu(X) tau_x <- tau(X) pi_x <- pnorm(mu_x) -Z <- rbinom(n,1,pi_x) -E_XZ <- mu_x + Z*tau_x -sigma <- diff(range(mu_x + tau_x*pi))/8 -y <- E_XZ + sigma*rnorm(n) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +sigma <- diff(range(mu_x + tau_x * pi)) / 8 +y <- E_XZ + sigma * rnorm(n) X <- as.data.frame(X) # Split data into test and train sets test_set_pct <- 0.2 -n_test <- round(test_set_pct*n) +n_test <- round(test_set_pct * n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) train_inds <- (1:n)[!((1:n) %in% test_inds)] -X_test <- X[test_inds,] -X_train <- X[train_inds,] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] pi_test <- pi_x[test_inds] pi_train <- pi_x[train_inds] Z_test <- Z[test_inds] @@ -50,17 +54,39 @@ tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] # Run the GFR algorithm -xbcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, - alpha_mu = 0.95, beta_mu = 1, max_depth_mu = -1, - alpha_tau = 0.8, beta_tau = 2, max_depth_tau = 10) +xbcf_params <- list( + num_trees_mu = num_trees_mu, + num_trees_tau = num_trees_tau, + alpha_mu = 0.95, + beta_mu = 1, + max_depth_mu = -1, + alpha_tau = 0.8, + beta_tau = 2, + max_depth_tau = 10 +) xbcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, num_gfr = num_gfr, - num_burnin = 0, num_mcmc = 0, params = xbcf_params + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0, + params = xbcf_params ) -plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1) +plot(rowMeans(xbcf_model$y_hat_test), y_test) +abline(0, 1) cat(sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(xbcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(xbcf_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(xbcf_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model) # Parallel setup @@ -69,20 +95,33 @@ cl <- makeCluster(ncores) registerDoParallel(cl) # Run the parallel BART MCMC samplers -bcf_model_outputs <- foreach (i = 1:num_chains) %dopar% { +bcf_model_outputs <- foreach(i = 1:num_chains) %dopar% + { random_seed <- i - bcf_params <- list(num_trees_mu = num_trees_mu, num_trees_tau = num_trees_tau, - random_seed = random_seed) + bcf_params <- list( + num_trees_mu = num_trees_mu, + num_trees_tau = num_trees_tau, + random_seed = random_seed + ) bcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, - num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bcf_params, - previous_model_json = xbcf_model_string, warmstart_sample_num = num_gfr - i + 1, + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = 0, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + params = bcf_params, + previous_model_json = xbcf_model_string, + warmstart_sample_num = num_gfr - i + 1, ) bcf_model_string <- stochtree::saveBCFModelToJsonString(bcf_model) y_hat_test <- bcf_model$y_hat_test - list(model=bcf_model_string, yhat=y_hat_test) -} + list(model = bcf_model_string, yhat = y_hat_test) + } # Close the cluster connection stopCluster(cl) @@ -91,44 +130,93 @@ stopCluster(cl) bcf_model_strings <- list() bcf_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) for (i in 1:length(bcf_model_outputs)) { - bcf_model_strings[[i]] <- bcf_model_outputs[[i]]$model - bcf_model_yhats[,i] <- rowMeans(bcf_model_outputs[[i]]$yhat) + bcf_model_strings[[i]] <- bcf_model_outputs[[i]]$model + bcf_model_yhats[, i] <- rowMeans(bcf_model_outputs[[i]]$yhat) } combined_bcf <- createBCFModelFromCombinedJsonString(bcf_model_strings) # Inspect the results -yhat_combined <- predict(combined_bcf, X_test)$y_hat -par(mfrow = c(1,2)) +yhat_combined <- predict(combined_bcf, X = X_test)$y_hat +par(mfrow = c(1, 2)) for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), bcf_model_yhats[,i], - xlab = "deserialized", ylab = "original", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + bcf_model_yhats[, i], + xlab = "deserialized", + ylab = "original", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) } for (i in 1:num_chains) { - offset <- (i-1)*num_mcmc - inds_start <- offset + 1 - inds_end <- offset + num_mcmc - plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test, - xlab = "predicted", ylab = "actual", - main = paste0("Chain ", i, "\nPredictions")) - abline(0,1,col="red",lty=3,lwd=3) - cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n") - cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n") + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + y_test, + xlab = "predicted", + ylab = "actual", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) + cat( + sqrt(mean((rowMeans(yhat_combined[, inds_start:inds_end]) - y_test)^2)), + "\n" + ) + cat( + mean( + (apply(yhat_combined[, inds_start:inds_end], 1, quantile, probs = 0.05) <= + y_test) & + (apply( + yhat_combined[, inds_start:inds_end], + 1, + quantile, + probs = 0.95 + ) >= + y_test) + ), + "\n" + ) } -par(mfrow = c(1,1)) +par(mfrow = c(1, 1)) # Compare to a single chain of MCMC samples initialized at root -bcf_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, - num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2) +bcf_params <- list( + sample_sigma_global = T, + sample_sigma_leaf = T, + num_trees_mean = num_trees, + alpha_mean = 0.95, + beta_mean = 2 +) bcf_model <- stochtree::bcf( - X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, - X_test = X_test, Z_test = Z_test, pi_test = pi_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bcf_params + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + params = bcf_params ) -plot(rowMeans(bcf_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1) +plot( + rowMeans(bcf_model$y_hat_test), + y_test, + xlab = "predicted", + ylab = "actual" +) +abline(0, 1) cat(sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n") -cat(mean((apply(bcf_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bcf_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n") +cat( + mean( + (apply(bcf_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) & + (apply(bcf_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test) + ), + "\n" +) diff --git a/tools/regression/bcf/individual_regression_test_bcf.py b/tools/regression/bcf/individual_regression_test_bcf.py index 591b24d2..f4279193 100644 --- a/tools/regression/bcf/individual_regression_test_bcf.py +++ b/tools/regression/bcf/individual_regression_test_bcf.py @@ -337,7 +337,7 @@ def main(): X_train=covariates_train, Z_train=treatment_train, y_train=outcome_train, - pi_train=propensity_train, + propensity_train=propensity_train, rfx_group_ids_train=rfx_group_ids_train, rfx_basis_train=rfx_basis_train, num_gfr=num_gfr, diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd index bedf5220..83c60e6c 100644 --- a/vignettes/MultiChain.Rmd +++ b/vignettes/MultiChain.Rmd @@ -129,7 +129,7 @@ predictions. ```{r} y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = leaf_basis_test, type = "mean", terms = "y_hat" @@ -217,7 +217,7 @@ abs_test_set_resid <- abs(y_test - y_hat_test) top5_resids <- order(abs_test_set_resid, decreasing = T)[1:5] y_hat_test_posterior <- predict( bart_model, - covariates = X_test[top5_resids, ], + X = X_test[top5_resids, ], leaf_basis = leaf_basis_test[top5_resids], type = "posterior", terms = "y_hat" @@ -345,7 +345,7 @@ predictions. ```{r} y_hat_test <- predict( bart_model, - covariates = X_test, + X = X_test, leaf_basis = leaf_basis_test, type = "mean", terms = "y_hat" @@ -433,7 +433,7 @@ abs_test_set_resid <- abs(y_test - y_hat_test) top5_resids <- order(abs_test_set_resid, decreasing = T)[1:5] y_hat_test_posterior <- predict( bart_model, - covariates = X_test[top5_resids, ], + X = X_test[top5_resids, ], leaf_basis = leaf_basis_test[top5_resids], type = "posterior", terms = "y_hat"