Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ Forest <- R6::R6Class(
#' Return constant leaf status of trees in a `Forest` object
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
is_constant_leaf = function() {
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
return(is_leaf_constant_active_forest_cpp(self$forest_ptr))
},

#' @description
Expand Down
18 changes: 17 additions & 1 deletion R/kernel.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#'
#' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this
#'
#' @param propensity (Optional) Propensities used for prediction (BCF-only).
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided,
#' this function will return leaf indices for every sample of a forest.
#' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
Expand All @@ -46,7 +47,7 @@
#' computeForestLeafIndices(bart_model, X, "mean")
#' computeForestLeafIndices(bart_model, X, "mean", 0)
#' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9))
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, propensity=NULL, forest_inds=NULL) {
# Extract relevant forest container
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
Expand Down Expand Up @@ -93,6 +94,21 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
covariates_processed <- covariates
}

# Handle BCF propensity covariate
if (model_type == "bcf") {
# Add propensities to covariate set if necessary
if (model_object$model_params$propensity_covariate != "none") {
if (is.null(propensity)) {
if (!model_object$model_params$internal_propensity_model) {
stop("propensity must be provided for this model")
}
# Compute propensity score using the internal bart model
propensity <- rowMeans(predict(model_object$bart_propensity_model, covariates)$y_hat)
}
covariates_processed <- cbind(covariates_processed, propensity)
}
}

# Preprocess forest indices
num_forests <- forest_container$num_samples()
if (is.null(forest_inds)) {
Expand Down
3 changes: 3 additions & 0 deletions man/computeForestLeafIndices.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 16 additions & 1 deletion stochtree/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .forest import ForestContainer


def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None):
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None):
"""
Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset.

Expand Down Expand Up @@ -37,6 +37,8 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
* **ForestContainer**
* `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this

propensity : `np.array`, optional
Optional test set propensities. Must be provided if propensities were provided when the model was sampled.
forest_inds : int or np.ndarray
Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
Expand Down Expand Up @@ -88,6 +90,19 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
else:
covariates_processed = covariates
covariates_processed = np.asfortranarray(covariates_processed)

# Handle BCF propensity covariate
if model_type == "bcf":
if model_object.propensity_covariate != "none":
if propensity is None:
if not model_object.internal_propensity_model:
raise ValueError(
"Propensity scores not provided, but no propensity model was trained during sampling"
)
propensity = np.mean(
model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True
)
covariates_processed = np.c_[covariates_processed, propensity]

# Preprocess forest indices
num_forests = forest_container.num_samples()
Expand Down
Loading