Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,11 @@ bart <- function(
# Raise a warning if the data have ties and only GFR is being run
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
num_values <- nrow(X_train)
max_grid_size <- floor(num_values / cutpoint_grid_size)
max_grid_size <- ifelse(
num_values > cutpoint_grid_size,
floor(num_values / cutpoint_grid_size),
1
)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
Expand Down Expand Up @@ -1924,7 +1928,7 @@ bart <- function(
#' Predict from a sampled BART model on new data
#'
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
#' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
#' We do not currently support (but plan to in the near future) test set evaluation for group labels
Expand Down Expand Up @@ -1961,10 +1965,10 @@ bart <- function(
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train,
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' y_hat_test <- predict(bart_model, X_test)$y_hat
#' y_hat_test <- predict(bart_model, X=X_test)$y_hat
predict.bartmodel <- function(
object,
covariates,
X,
leaf_basis = NULL,
rfx_group_ids = NULL,
rfx_basis = NULL,
Expand Down Expand Up @@ -2047,8 +2051,8 @@ predict.bartmodel <- function(
}

# Check that covariates are a matrix or data frame
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
stop("covariates must be a matrix or dataframe")
if ((!is.data.frame(X)) && (!is.matrix(X))) {
stop("X must be a matrix or dataframe")
}

# Convert all input data to matrices if not already converted
Expand All @@ -2063,12 +2067,12 @@ predict.bartmodel <- function(
if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
stop("Basis (leaf_basis) must be provided for this model")
}
if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) {
stop("covariates and leaf_basis must have the same number of rows")
if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
stop("X and leaf_basis must have the same number of rows")
}
if (object$model_params$num_covariates != ncol(covariates)) {
if (object$model_params$num_covariates != ncol(X)) {
stop(
"covariates must contain the same number of columns as the BART model's training dataset"
"X must contain the same number of columns as the BART model's training dataset"
)
}
if ((predict_rfx) && (is.null(rfx_group_ids))) {
Expand All @@ -2089,7 +2093,7 @@ predict.bartmodel <- function(

# Preprocess covariates
train_set_metadata <- object$train_set_metadata
covariates <- preprocessPredictionData(covariates, train_set_metadata)
X <- preprocessPredictionData(X, train_set_metadata)

# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
has_rfx <- FALSE
Expand Down Expand Up @@ -2119,8 +2123,8 @@ predict.bartmodel <- function(
# Only construct a basis if user-provided basis missing
if (is.null(rfx_basis)) {
rfx_basis <- matrix(
rep(1, nrow(covariates)),
nrow = nrow(covariates),
rep(1, nrow(X)),
nrow = nrow(X),
ncol = 1
)
}
Expand All @@ -2129,9 +2133,9 @@ predict.bartmodel <- function(

# Create prediction dataset
if (!is.null(leaf_basis)) {
prediction_dataset <- createForestDataset(covariates, leaf_basis)
prediction_dataset <- createForestDataset(X, leaf_basis)
} else {
prediction_dataset <- createForestDataset(covariates)
prediction_dataset <- createForestDataset(X)
}

# Compute variance forest predictions
Expand Down Expand Up @@ -2843,7 +2847,7 @@ createBARTModelFromJsonFile <- function(json_filename) {
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json <- saveBARTModelToJsonString(bart_model)
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat)
createBARTModelFromJsonString <- function(json_string) {
# Load a `CppJson` object from string
bart_json <- createCppJsonString(json_string)
Expand Down
6 changes: 5 additions & 1 deletion R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,11 @@ bcf <- function(
# Raise a warning if the data have ties and only GFR is being run
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
num_values <- nrow(X_train)
max_grid_size <- floor(num_values / cutpoint_grid_size)
max_grid_size <- ifelse(
num_values > cutpoint_grid_size,
floor(num_values / cutpoint_grid_size),
1
)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
Expand Down
2 changes: 1 addition & 1 deletion R/kernel.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ computeForestLeafIndices <- function(
propensity <- rowMeans(
predict(
model_object$bart_propensity_model,
covariates
X = covariates
)$y_hat
)
}
Expand Down
Loading
Loading