Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -423,37 +423,53 @@ bart <- function(
floor(num_values / cutpoint_grid_size),
1
)
x_is_df <- is.data.frame(X_train)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
covs_warning_4 <- NULL
for (i in 1:num_cov_orig) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
# Skip check for variables that are treated as categorical
x_numeric <- T
if (x_is_df) {
if (is.factor(X_train[, i])) {
x_numeric <- F
}
}
if (x_numeric) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
}

# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
# Check for binary variables
if (num_unique_values == 2) {
covs_warning_4 <- c(covs_warning_4, cov_name)
}
}
}

Expand Down Expand Up @@ -494,6 +510,18 @@ bart <- function(
)
)
}

if (!is.null(covs_warning_4)) {
warning(
paste0(
"Covariates ",
paste(covs_warning_4, collapse = ", "),
" appear to be binary but are currently treated by stochtree as continuous. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
)
)
}
}

# Standardize the keep variable lists to numeric indices
Expand Down
75 changes: 52 additions & 23 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -527,37 +527,54 @@ bcf <- function(
floor(num_values / cutpoint_grid_size),
1
)
x_is_df <- is.data.frame(X_train)
covs_warning_1 <- NULL
covs_warning_2 <- NULL
covs_warning_3 <- NULL
covs_warning_4 <- NULL
for (i in 1:num_cov_orig) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
# Skip check for variables that are treated as categorical
x_numeric <- T
if (x_is_df) {
if (is.factor(X_train[, i])) {
x_numeric <- F
}
}

# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
if (x_numeric) {
# Determine the number of unique values
num_unique_values <- length(unique(X_train[, i]))

# Determine a "name" for the covariate
cov_name <- ifelse(
is.null(colnames(X_train)),
paste0("X", i),
colnames(X_train)[i]
)

# Check for a small relative number of unique values
unique_full_ratio <- num_unique_values / num_values
if (unique_full_ratio < 0.2) {
covs_warning_1 <- c(covs_warning_1, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
# Check for a small absolute number of unique values
if (num_values > 100) {
if (num_unique_values < 20) {
covs_warning_2 <- c(covs_warning_2, cov_name)
}
}

# Check for a large number of duplicates of any individual value
x_j_hist <- table(X_train[, i])
if (any(x_j_hist > 2 * max_grid_size)) {
covs_warning_3 <- c(covs_warning_3, cov_name)
}

# Check for binary variables
if (num_unique_values == 2) {
covs_warning_4 <- c(covs_warning_4, cov_name)
}
}
}

Expand Down Expand Up @@ -598,6 +615,18 @@ bcf <- function(
)
)
}

if (!is.null(covs_warning_4)) {
warning(
paste0(
"Covariates ",
paste(covs_warning_4, collapse = ", "),
" appear to be binary but are currently treated by stochtree as continuous. ",
"This might present some issues with the grow-from-root (GFR) algorithm. ",
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
)
)
}
}

# Check delta_max is valid
Expand Down
58 changes: 39 additions & 19 deletions stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,32 +456,45 @@ def sample(
if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
num_values, num_cov_orig = X_train.shape
max_grid_size = floor(num_values / cutpoint_grid_size)
x_is_df = isinstance(X_train, pd.DataFrame)
covs_warning_1 = []
covs_warning_2 = []
covs_warning_3 = []
covs_warning_4 = []
for i in range(num_cov_orig):
# Determine the number of unique covariate values and a name for the covariate
if isinstance(X_train, np.ndarray):
x_j_hist = np.unique_counts(X_train[:, i]).counts
cov_name = f"X{i + 1}"
else:
x_j_hist = (X_train.iloc[:, i]).value_counts()
cov_name = X_train.columns[i]
# Skip check for variables that are treated as categorical
x_numeric = True
if x_is_df:
if isinstance(X_train.iloc[:,i].dtype, pd.CategoricalDtype):
x_numeric = False

if x_numeric:
# Determine the number of unique covariate values and a name for the covariate
if isinstance(X_train, np.ndarray):
x_j_hist = np.unique_counts(X_train[:, i]).counts
cov_name = f"X{i + 1}"
else:
x_j_hist = (X_train.iloc[:, i]).value_counts()
cov_name = X_train.columns[i]

# Check for a small relative number of unique values
num_unique_values = len(x_j_hist)
unique_full_ratio = num_unique_values / num_values
if unique_full_ratio < 0.2:
covs_warning_1.append(cov_name)
# Check for a small relative number of unique values
num_unique_values = len(x_j_hist)
unique_full_ratio = num_unique_values / num_values
if unique_full_ratio < 0.2:
covs_warning_1.append(cov_name)

# Check for a small absolute number of unique values
if num_values > 100:
if num_unique_values < 20:
covs_warning_2.append(cov_name)
# Check for a small absolute number of unique values
if num_values > 100:
if num_unique_values < 20:
covs_warning_2.append(cov_name)

# Check for a large number of duplicates of any individual value
if np.any(x_j_hist > 2 * max_grid_size):
covs_warning_3.append(cov_name)
# Check for a large number of duplicates of any individual value
if np.any(x_j_hist > 2 * max_grid_size):
covs_warning_3.append(cov_name)

# Check for binary variables
if num_unique_values == 2:
covs_warning_4.append(cov_name)

if covs_warning_1:
warnings.warn(
Expand All @@ -505,6 +518,13 @@ def sample(
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
)

if covs_warning_4:
warnings.warn(
f"Covariates {', '.join(covs_warning_4)} appear to be binary but are currently treated by stochtree as continuous. "
"This might present some issues with the grow-from-root (GFR) algorithm. "
"Consider converting binary variables to ordered categorical (i.e. `pd.Categorical(..., ordered = True)`."
)

# Variable weight preprocessing (and initialization if necessary)
p = X_train.shape[1]
if variable_weights is None:
Expand Down
Loading
Loading