Merge pull request #52 from StochasticTree/bcf-feature-subsets-mu-tau
Using subsets of features in the prognostic and treatment forests in BCF
andrewherren committed Jun 20, 2024
2 parents 9aac95c + 4b827b4 commit 2b52159
Showing 11 changed files with 629 additions and 227 deletions.
31 changes: 23 additions & 8 deletions R/bart.R
@@ -33,9 +33,10 @@
#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
-#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as 0.5/num_trees if not set here.
+#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9.
#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' @param num_trees Number of trees in the ensemble. Default: 200.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
@@ -80,10 +81,19 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95,
beta = 2.0, min_samples_leaf = 5, leaf_model = 0,
nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
-                 q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
-                 num_burnin = 0, num_mcmc = 100, sample_sigma = T,
-                 sample_tau = T, random_seed = -1, keep_burnin = F,
-                 keep_gfr = F, verbose = F){
+                 q = 0.9, sigma2_init = NULL, variable_weights = NULL,
+                 num_trees = 200, num_gfr = 5, num_burnin = 0,
+                 num_mcmc = 100, sample_sigma = T, sample_tau = T,
+                 random_seed = -1, keep_burnin = F, keep_gfr = F,
+                 verbose = F){
+    # Variable weight preprocessing (and initialization if necessary)
+    if (is.null(variable_weights)) {
+        variable_weights = rep(1/ncol(X_train), ncol(X_train))
+    }
+    if (any(variable_weights < 0)) {
+        stop("variable_weights cannot have any negative weights")
+    }
+
# Preprocess covariates
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
stop("X_train must be a matrix or dataframe")
@@ -93,12 +103,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
stop("X_test must be a matrix or dataframe")
}
}
+    if (ncol(X_train) != length(variable_weights)) {
+        stop("length(variable_weights) must equal ncol(X_train)")
+    }
train_cov_preprocess_list <- preprocessTrainData(X_train)
X_train_metadata <- train_cov_preprocess_list$metadata
X_train <- train_cov_preprocess_list$data
original_var_indices <- X_train_metadata$original_var_indices
feature_types <- X_train_metadata$feature_types
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)

+    # Update variable weights
+    variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
+    variable_weights <- variable_weights[original_var_indices]*variable_weights_adj

# Convert all input data to matrices if not already converted
if ((is.null(dim(W_train))) && (!is.null(W_train))) {
W_train <- as.matrix(W_train)
@@ -295,9 +313,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
if (sample_sigma) global_var_samples <- rep(0, num_samples)
if (sample_tau) leaf_scale_samples <- rep(0, num_samples)

-    # Variable selection weights
-    variable_weights <- rep(1/ncol(X_train), ncol(X_train))

# Run GFR (warm start) if specified
if (num_gfr > 0){
gfr_indices = 1:num_gfr
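For orientation, here is a minimal sketch of how the new `variable_weights` argument could be called from R, inferred from the signature and checks above. The simulated data, the weight values, and loading the package as `stochtree` are illustrative assumptions, not part of this diff:

```
# Hypothetical usage sketch; assumes the R package loads as `stochtree`
library(stochtree)

# Toy data: 10 numeric covariates, only the first 3 drive the outcome
n <- 500
p <- 10
X <- matrix(runif(n * p), ncol = p)
y <- 5 * X[, 1] + 3 * X[, 2] - 2 * X[, 3] + rnorm(n)

# Upweight the informative covariates; weights need not sum to 1,
# but negative weights are rejected by the check above
w <- c(rep(3, 3), rep(1, p - 3))

# Omitting variable_weights recovers the old equal-weight default
fit <- bart(X_train = X, y_train = y, variable_weights = w)
```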
219 changes: 147 additions & 72 deletions R/bcf.R

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions R/utils.R
@@ -2,6 +2,7 @@
#' types. Matrices will be passed through assuming all columns are numeric.
#'
#' @param input_data Covariates, provided as either a dataframe or a matrix
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed (unmodified) data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
Expand Down Expand Up @@ -63,6 +64,7 @@ preprocessPredictionData <- function(input_data, metadata) {
#' Returns a list including a matrix of preprocessed covariate values and associated tracking.
#'
#' @param input_matrix Covariate matrix.
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed (unmodified) data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
Expand Down Expand Up @@ -97,7 +99,8 @@ preprocessTrainMatrix <- function(input_matrix) {
num_ordered_cat_vars = num_ordered_cat_vars,
num_unordered_cat_vars = num_unordered_cat_vars,
num_numeric_vars = num_numeric_vars,
-        numeric_vars = numeric_vars
+        numeric_vars = numeric_vars,
+        original_var_indices = 1:num_numeric_vars
)
output <- list(
data = X,
@@ -139,6 +142,7 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) {
#'
#' @param input_df Dataframe of covariates. Users must pre-process any
#' categorical variables as factors (ordered for ordered categorical).
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
@@ -164,6 +168,7 @@ preprocessTrainDataFrame <- function(input_df) {
ordered_mask <- sapply(input_df, is.ordered)
ordered_cat_matches <- factor_mask & ordered_mask
ordered_cat_vars <- df_vars[ordered_cat_matches]
+    ordered_cat_var_inds <- unname(which(ordered_cat_matches))
num_ordered_cat_vars <- length(ordered_cat_vars)
if (num_ordered_cat_vars > 0) ordered_cat_df <- input_df[,ordered_cat_vars,drop=F]

@@ -173,12 +178,14 @@ preprocessTrainDataFrame <- function(input_df) {
character_mask <- sapply(input_df, is.character)
unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask
unordered_cat_vars <- df_vars[unordered_cat_matches]
+    unordered_cat_var_inds <- unname(which(unordered_cat_matches))
num_unordered_cat_vars <- length(unordered_cat_vars)
if (num_unordered_cat_vars > 0) unordered_cat_df <- input_df[,unordered_cat_vars,drop=F]

# Numeric variables
numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches)
numeric_vars <- df_vars[numeric_matches]
+    numeric_var_inds <- unname(which(numeric_matches))
num_numeric_vars <- length(numeric_vars)
if (num_numeric_vars > 0) numeric_df <- input_df[,numeric_vars,drop=F]

@@ -187,6 +194,7 @@ preprocessTrainDataFrame <- function(input_df) {
unordered_unique_levels <- list()
ordered_unique_levels <- list()
feature_types <- integer(0)
+    original_var_indices <- integer(0)

# First, extract the numeric covariates
if (num_numeric_vars > 0) {
@@ -197,6 +205,7 @@
}
X <- cbind(X, unname(Xnum))
feature_types <- c(feature_types, rep(0, ncol(Xnum)))
+        original_var_indices <- c(original_var_indices, numeric_var_inds)
}

# Next, run some simple preprocessing on the ordered categorical covariates
Expand All @@ -210,6 +219,7 @@ preprocessTrainDataFrame <- function(input_df) {
}
X <- cbind(X, unname(Xordcat))
feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
+        original_var_indices <- c(original_var_indices, ordered_cat_var_inds)
}

# Finally, one-hot encode the unordered categorical covariates
Expand All @@ -220,6 +230,8 @@ preprocessTrainDataFrame <- function(input_df) {
encode_list <- oneHotInitializeAndEncode(unordered_cat_df[,i])
unordered_unique_levels[[var_name]] <- encode_list$unique_levels
one_hot_mats[[var_name]] <- encode_list$Xtilde
+            one_hot_var <- rep(unordered_cat_var_inds[i], ncol(encode_list$Xtilde))
+            original_var_indices <- c(original_var_indices, one_hot_var)
}
Xcat <- do.call(cbind, one_hot_mats)
X <- cbind(X, unname(Xcat))
@@ -231,7 +243,8 @@
feature_types = feature_types,
num_ordered_cat_vars = num_ordered_cat_vars,
num_unordered_cat_vars = num_unordered_cat_vars,
-        num_numeric_vars = num_numeric_vars
+        num_numeric_vars = num_numeric_vars,
+        original_var_indices = original_var_indices
)
if (num_ordered_cat_vars > 0) {
metadata[["ordered_cat_vars"]] = ordered_cat_vars
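To make the bookkeeping above concrete, a small standalone sketch of how `original_var_indices` lets per-variable weights be spread across one-hot-encoded columns (toy values only; mirrors the adjustment added in `R/bart.R` above):

```
# Suppose 3 original variables, where variable 3 is an unordered
# categorical that one-hot encodes into 3 columns; the processed
# matrix then has 5 columns mapping back to originals as:
original_var_indices <- c(1, 2, 3, 3, 3)

# User-supplied weights on the 3 original variables
variable_weights <- c(0.5, 0.25, 0.25)

# Each encoded column inherits its parent's weight divided by the
# number of columns that parent expanded into
variable_weights_adj <- 1 / sapply(original_var_indices,
                                   function(x) sum(original_var_indices == x))
variable_weights <- variable_weights[original_var_indices] * variable_weights_adj
print(variable_weights)  # 0.50 0.25 0.0833 0.0833 0.0833
```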
2 changes: 1 addition & 1 deletion README.md
@@ -93,7 +93,7 @@ pip install matplotlib seaborn jupyterlab
The package can be installed in R via

```
remotes::install_github("StochasticTree/stochtree-cpp")
remotes::install_github("StochasticTree/stochtree-cpp", ref="r-dev")
```

# C++ Core
9 changes: 4 additions & 5 deletions include/stochtree/ensemble.h
@@ -81,12 +81,11 @@ class TreeEnsemble {

inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
int tree_begin, int tree_end, data_size_t offset = 0) {
-    if (dataset.HasBasis()) {
-      CHECK(!is_leaf_constant_);
-      PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
-    } else {
-      CHECK(is_leaf_constant_);
+    if (is_leaf_constant_) {
      PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
+    } else {
+      CHECK(dataset.HasBasis());
+      PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
}
}

9 changes: 6 additions & 3 deletions include/stochtree/leaf_model.h
@@ -71,7 +71,8 @@ class GaussianConstantLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
@@ -136,7 +137,8 @@ class GaussianUnivariateRegressionLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
@@ -203,7 +205,8 @@ class GaussianMultivariateRegressionLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);