Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,83 @@ bart <- function(
}
num_cov_orig <- ncol(X_train)

# Raise a warning if the data have ties and only GFR is being run
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
num_values <- nrow(X_train)
max_grid_size <- floor(num_values / cutpoint_grid_size)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
for (i in 1:num_cov_orig) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
}

# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
}
}

if (!is.null(covs_warning_1)) {
warning(
paste0(
"Covariate(s) ",
paste(covs_warning_1, collapse = ", "),
" have a ratio of unique to overall observations of less than 0.2. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}

if (!is.null(covs_warning_2)) {
warning(
paste0(
"Covariate(s) ",
paste(covs_warning_2, collapse = ", "),
" have fewer than 20 unique values. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}

if (!is.null(covs_warning_3)) {
warning(
paste0(
"Covariates ",
paste(covs_warning_3, collapse = ", "),
" have some observed values with more than ",
2 * max_grid_size,
" repeated observations. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}
}

# Standardize the keep variable lists to numeric indices
if (!is.null(keep_vars_mean)) {
if (is.character(keep_vars_mean)) {
Expand Down
77 changes: 77 additions & 0 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,83 @@ bcf <- function(
}
num_cov_orig <- ncol(X_train)

# Raise a warning if the data have ties and only GFR is being run
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
num_values <- nrow(X_train)
max_grid_size <- floor(num_values / cutpoint_grid_size)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
for (i in 1:num_cov_orig) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
}

# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
}
}

if (!is.null(covs_warning_1)) {
warning(
paste0(
"Covariate(s) ",
paste(covs_warning_1, collapse = ", "),
" have a ratio of unique to overall observations of less than 0.2. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}

if (!is.null(covs_warning_2)) {
warning(
paste0(
"Covariate(s) ",
paste(covs_warning_2, collapse = ", "),
" have fewer than 20 unique values. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}

if (!is.null(covs_warning_3)) {
warning(
paste0(
"Covariates ",
paste(covs_warning_3, collapse = ", "),
" have some observed values with more than ",
2 * max_grid_size,
" repeated observations. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)
)
}
}

# Check delta_max is valid
if ((delta_max <= 0) || (delta_max >= 1)) {
stop("delta_max must be > 0 and < 1")
Expand Down
30 changes: 30 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1084,3 +1084,33 @@ expand_dims_2d_diag <- function(input, output_size) {
}
return(output)
}


gfr_tie_checks <- function(covariates) {
num_vars <- ncol(covariates)
for (j in 1:num_vars) {
x_j <- covariates[, j]
if (has_few_unique_values(x_j)) {
warning_message <- paste0(
"Covariate column ",
j,
" has relatively few unique values. ",
"This may lead to tied values when sampling split points in BART/BCF, ",
"which can cause errors during model fitting. ",
"Consider adding small amounts of noise to this variable to break ties."
)
warning(warning_message)
}
}
}


has_few_unique_values <- function(
x,
count_threshold = 15
) {
x_unique <- unique(x)
num_unique_values <- length(unique_values)
unique_to_total_count_ratio <- num_unique_values / length(x)
return(num_unique_values <= threshold)
}
Loading
Loading