From 1134f621717099b50a19df962de332511f83a75e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 8 Nov 2024 15:20:10 -0600 Subject: [PATCH 1/3] Enable access to forest leaf values and counts in R --- R/cpp11.R | 8 ++++++++ R/forest.R | 16 ++++++++++++++++ man/ForestSamples.Rd | 42 ++++++++++++++++++++++++++++++++++++++++++ src/cpp11.cpp | 16 ++++++++++++++++ src/forest.cpp | 12 ++++++++++++ 5 files changed, 94 insertions(+) diff --git a/R/cpp11.R b/R/cpp11.R index e8fbf36a..b95f07e6 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -252,6 +252,14 @@ average_max_depth_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_average_max_depth_forest_container_cpp`, forest_samples) } +num_leaves_ensemble_forest_container_cpp <- function(forest_samples, forest_num) { + .Call(`_stochtree_num_leaves_ensemble_forest_container_cpp`, forest_samples, forest_num) +} + +sum_leaves_squared_ensemble_forest_container_cpp <- function(forest_samples, forest_num) { + .Call(`_stochtree_sum_leaves_squared_ensemble_forest_container_cpp`, forest_samples, forest_num) +} + num_trees_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_num_trees_forest_container_cpp`, forest_samples) } diff --git a/R/forest.R b/R/forest.R index c2a2b177..beee46aa 100644 --- a/R/forest.R +++ b/R/forest.R @@ -323,6 +323,22 @@ ForestSamples <- R6::R6Class( #' @return Average maximum depth average_max_depth = function() { return(average_max_depth_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Number of leaves in a given ensemble in a `ForestContainer` object + #' @param forest_num Index of the ensemble to be queried + #' @return Count of leaves in the ensemble stored at `forest_num` + num_leaves = function(forest_num) { + return(num_leaves_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num)) + }, + + #' @description + #' Sum of squared (raw) leaf values in a given ensemble in a `ForestContainer` object + #' @param forest_num Index of the ensemble to be queried + #' @return Average maximum depth + sum_leaves_squared = function(forest_num) { + return(sum_leaves_squared_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num)) } ) ) diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd index 88d21f1f..67810072 100644 --- a/man/ForestSamples.Rd +++ b/man/ForestSamples.Rd @@ -43,6 +43,8 @@ Wrapper around a C++ container of tree ensembles \item \href{#method-ForestSamples-ensemble_tree_max_depth}{\code{ForestSamples$ensemble_tree_max_depth()}} \item \href{#method-ForestSamples-average_ensemble_max_depth}{\code{ForestSamples$average_ensemble_max_depth()}} \item \href{#method-ForestSamples-average_max_depth}{\code{ForestSamples$average_max_depth()}} +\item \href{#method-ForestSamples-num_leaves}{\code{ForestSamples$num_leaves()}} +\item \href{#method-ForestSamples-sum_leaves_squared}{\code{ForestSamples$sum_leaves_squared()}} } } \if{html}{\out{
}} @@ -622,6 +624,46 @@ Average the maximum depth of each tree in each ensemble in a \code{ForestContain \if{html}{\out{
}}\preformatted{ForestSamples$average_max_depth()}\if{html}{\out{
}} } +\subsection{Returns}{ +Average maximum depth +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-num_leaves}{}}} +\subsection{Method \code{num_leaves()}}{ +Number of leaves in a given ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$num_leaves(forest_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_num}}{Index of the ensemble to be queried} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Count of leaves in the ensemble stored at \code{forest_num} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-sum_leaves_squared}{}}} +\subsection{Method \code{sum_leaves_squared()}}{ +Sum of squared (raw) leaf values in a given ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$sum_leaves_squared(forest_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_num}}{Index of the ensemble to be queried} +} +\if{html}{\out{
}} +} \subsection{Returns}{ Average maximum depth } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 18dc931d..7d7b3a04 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -469,6 +469,20 @@ extern "C" SEXP _stochtree_average_max_depth_forest_container_cpp(SEXP forest_sa END_CPP11 } // forest.cpp +int num_leaves_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num); +extern "C" SEXP _stochtree_num_leaves_ensemble_forest_container_cpp(SEXP forest_samples, SEXP forest_num) { + BEGIN_CPP11 + return cpp11::as_sexp(num_leaves_ensemble_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num))); + END_CPP11 +} +// forest.cpp +double sum_leaves_squared_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num); +extern "C" SEXP _stochtree_sum_leaves_squared_ensemble_forest_container_cpp(SEXP forest_samples, SEXP forest_num) { + BEGIN_CPP11 + return cpp11::as_sexp(sum_leaves_squared_ensemble_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num))); + END_CPP11 +} +// forest.cpp int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples); extern "C" SEXP _stochtree_num_trees_forest_container_cpp(SEXP forest_samples) { BEGIN_CPP11 @@ -1034,6 +1048,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, + {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, {"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1}, @@ -1089,6 +1104,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, + {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, {NULL, NULL, 0} }; diff --git a/src/forest.cpp b/src/forest.cpp index 87132b01..b0f47c96 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -96,6 +96,18 @@ double average_max_depth_forest_container_cpp(cpp11::external_pointerAverageMaxDepth(); } +[[cpp11::register]] +int num_leaves_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { + StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); + return forest->NumLeaves(); +} + +[[cpp11::register]] +double sum_leaves_squared_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) { + StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num); + return forest->SumLeafSquared(); +} + [[cpp11::register]] int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples) { return forest_samples->NumTrees(); From b3f44f08a2d640dcbe5813a8c658377a9a731900 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 8 Nov 2024 20:11:54 -0600 Subject: [PATCH 2/3] Added same methods to python interface --- src/py_stochtree.cpp | 14 +++++++++++++- stochtree/forest.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 95d70831..2adf5632 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -164,6 +164,16 @@ class ForestContainerCpp { return forest_samples_->NumSamples(); } + int NumLeaves(int forest_num) { + StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num); + return forest->NumLeaves(); + } + + double SumLeafSquared(int forest_num) { + StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num); + return forest->SumLeafSquared(); + } + py::array_t Predict(ForestDatasetCpp& dataset) { // Predict from the forest container data_size_t n = dataset.NumRows(); @@ -1028,7 +1038,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetTreeSplitCounts", &ForestContainerCpp::GetTreeSplitCounts) .def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts) .def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts) - .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts); + .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts) + .det("NumLeaves", &ForestContainerCpp::NumLeaves) + .det("SumLeafSquared", &ForestContainerCpp::SumLeafSquared); py::class_(m, "ForestSamplerCpp") .def(py::init, int, data_size_t, double, double, int, int>()) diff --git a/stochtree/forest.py b/stochtree/forest.py index 13d91689..866208ee 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -148,4 +148,22 @@ def get_granular_split_counts(self, num_features: int) -> np.array: Total number of features in the training set """ return self.forest_container_cpp.GetGranularSplitCounts(num_features) + + def num_leaves(self, forest_num: int) -> int: + """ + Return the total number of leaves for a given forest in the ``ForestContainer`` + + forest_num : :obj:`int` + Index of the forest to be queried + """ + return self.forest_container_cpp.NumLeaves(forest_num) + + def sum_leaves_squared(self, forest_num: int) -> float: + """ + Return the total sum of squared leaf values for a given forest in the ``ForestContainer`` + + forest_num : :obj:`int` + Index of the forest to be queried + """ + return self.forest_container_cpp.SumLeafSquared(forest_num) \ No newline at end of file From 454c28cc4ca6ce1bf4ff654eb9e5066ed03afa70 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 8 Nov 2024 23:31:23 -0600 Subject: [PATCH 3/3] Fixed typo --- src/py_stochtree.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 2adf5632..1240627c 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1039,8 +1039,8 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts) .def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts) .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts) - .det("NumLeaves", &ForestContainerCpp::NumLeaves) - .det("SumLeafSquared", &ForestContainerCpp::SumLeafSquared); + .def("NumLeaves", &ForestContainerCpp::NumLeaves) + .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared); py::class_(m, "ForestSamplerCpp") .def(py::init, int, data_size_t, double, double, int, int>())