diff --git a/R/forest.R b/R/forest.R index a6abf982..09f202ff 100644 --- a/R/forest.R +++ b/R/forest.R @@ -785,7 +785,7 @@ Forest <- R6::R6Class( #' Return constant leaf status of trees in a `Forest` object #' @return `TRUE` if leaves are constant, `FALSE` otherwise is_constant_leaf = function() { - return(is_leaf_constant_forest_container_cpp(self$forest_ptr)) + return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) }, #' @description diff --git a/R/kernel.R b/R/kernel.R index f20630b2..381e13bf 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -32,6 +32,7 @@ #' #' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this #' +#' @param propensity (Optional) Propensities used for prediction (BCF-only). #' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, #' this function will return leaf indices for every sample of a forest. #' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. 
@@ -46,7 +47,7 @@ #' computeForestLeafIndices(bart_model, X, "mean") #' computeForestLeafIndices(bart_model, X, "mean", 0) #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) -computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { +computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL, propensity=NULL) { # Extract relevant forest container stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) @@ -93,6 +94,21 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, covariates_processed <- covariates } + # Handle BCF propensity covariate + if (model_type == "bcf") { + # Add propensities to covariate set if necessary + if (model_object$model_params$propensity_covariate != "none") { + if (is.null(propensity)) { + if (!model_object$model_params$internal_propensity_model) { + stop("propensity must be provided for this model") + } + # Compute propensity score using the internal bart model + propensity <- rowMeans(predict(model_object$bart_propensity_model, covariates)$y_hat) + } + covariates_processed <- cbind(covariates_processed, propensity) + } + } + # Preprocess forest indices num_forests <- forest_container$num_samples() if (is.null(forest_inds)) { diff --git a/man/computeForestLeafIndices.Rd b/man/computeForestLeafIndices.Rd index 169b1ea8..12c18fff 100644 --- a/man/computeForestLeafIndices.Rd +++ b/man/computeForestLeafIndices.Rd @@ -8,6 +8,7 @@ computeForestLeafIndices( model_object, covariates, forest_type = NULL, - forest_inds = NULL + forest_inds = NULL, + propensity = NULL ) } @@ -37,6 +38,8 @@ Valid inputs depend on the model type, and whether or not a given forest was sam \item \code{NULL}: It is not necessary to disambiguate when this function is
called directly on a \code{ForestSamples} object. This is the default value of this }} +\item{propensity}{(Optional) Propensities used for prediction (BCF-only).} + \item{forest_inds}{(Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest. This function uses 0-indexing, so the first forest sample corresponds to \code{forest_num = 0}, and so on.} diff --git a/stochtree/kernel.py b/stochtree/kernel.py index ec902303..86d9d8a5 100644 --- a/stochtree/kernel.py +++ b/stochtree/kernel.py @@ -9,7 +9,7 @@ from .forest import ForestContainer -def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None): +def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None, propensity: np.array = None): """ Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset. @@ -37,6 +37,8 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC * **ForestContainer** * `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this + propensity : `np.array`, optional + Optional test set propensities. Must be provided if propensities were provided when the model was sampled. forest_inds : int or np.ndarray Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest. This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
@@ -88,6 +90,19 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC else: covariates_processed = covariates covariates_processed = np.asfortranarray(covariates_processed) + + # Handle BCF propensity covariate + if model_type == "bcf": + if model_object.propensity_covariate != "none": + if propensity is None: + if not model_object.internal_propensity_model: + raise ValueError( + "Propensity scores not provided, but no propensity model was trained during sampling" + ) + propensity = np.mean( + model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True + ) + covariates_processed = np.c_[covariates_processed, propensity] # Preprocess forest indices num_forests = forest_container.num_samples()