diff --git a/R/cpp11.R b/R/cpp11.R
index e8fbf36a..b95f07e6 100644
--- a/R/cpp11.R
+++ b/R/cpp11.R
@@ -252,6 +252,14 @@ average_max_depth_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_average_max_depth_forest_container_cpp`, forest_samples)
}
+num_leaves_ensemble_forest_container_cpp <- function(forest_samples, forest_num) {
+ .Call(`_stochtree_num_leaves_ensemble_forest_container_cpp`, forest_samples, forest_num)
+}
+
+sum_leaves_squared_ensemble_forest_container_cpp <- function(forest_samples, forest_num) {
+ .Call(`_stochtree_sum_leaves_squared_ensemble_forest_container_cpp`, forest_samples, forest_num)
+}
+
num_trees_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_num_trees_forest_container_cpp`, forest_samples)
}
diff --git a/R/forest.R b/R/forest.R
index c2a2b177..beee46aa 100644
--- a/R/forest.R
+++ b/R/forest.R
@@ -323,6 +323,22 @@ ForestSamples <- R6::R6Class(
#' @return Average maximum depth
average_max_depth = function() {
return(average_max_depth_forest_container_cpp(self$forest_container_ptr))
+ },
+
+ #' @description
+ #' Number of leaves in a given ensemble in a `ForestContainer` object
+ #' @param forest_num Index of the ensemble to be queried
+ #' @return Count of leaves in the ensemble stored at `forest_num`
+ num_leaves = function(forest_num) {
+ return(num_leaves_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num))
+ },
+
+ #' @description
+ #' Sum of squared (raw) leaf values in a given ensemble in a `ForestContainer` object
+ #' @param forest_num Index of the ensemble to be queried
+ #' @return Average maximum depth
+ sum_leaves_squared = function(forest_num) {
+ return(sum_leaves_squared_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num))
}
)
)
diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd
index 88d21f1f..67810072 100644
--- a/man/ForestSamples.Rd
+++ b/man/ForestSamples.Rd
@@ -43,6 +43,8 @@ Wrapper around a C++ container of tree ensembles
\item \href{#method-ForestSamples-ensemble_tree_max_depth}{\code{ForestSamples$ensemble_tree_max_depth()}}
\item \href{#method-ForestSamples-average_ensemble_max_depth}{\code{ForestSamples$average_ensemble_max_depth()}}
\item \href{#method-ForestSamples-average_max_depth}{\code{ForestSamples$average_max_depth()}}
+\item \href{#method-ForestSamples-num_leaves}{\code{ForestSamples$num_leaves()}}
+\item \href{#method-ForestSamples-sum_leaves_squared}{\code{ForestSamples$sum_leaves_squared()}}
}
}
\if{html}{\out{
}}
@@ -622,6 +624,46 @@ Average the maximum depth of each tree in each ensemble in a \code{ForestContain
\if{html}{\out{}}\preformatted{ForestSamples$average_max_depth()}\if{html}{\out{
}}
}
+\subsection{Returns}{
+Average maximum depth
+}
+}
+\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestSamples-num_leaves}{}}}
+\subsection{Method \code{num_leaves()}}{
+Number of leaves in a given ensemble in a \code{ForestContainer} object
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{ForestSamples$num_leaves(forest_num)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{forest_num}}{Index of the ensemble to be queried}
+}
+\if{html}{\out{
}}
+}
+\subsection{Returns}{
+Count of leaves in the ensemble stored at \code{forest_num}
+}
+}
+\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestSamples-sum_leaves_squared}{}}}
+\subsection{Method \code{sum_leaves_squared()}}{
+Sum of squared (raw) leaf values in a given ensemble in a \code{ForestContainer} object
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{ForestSamples$sum_leaves_squared(forest_num)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{forest_num}}{Index of the ensemble to be queried}
+}
+\if{html}{\out{
}}
+}
\subsection{Returns}{
Average maximum depth
}
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index 18dc931d..7d7b3a04 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -469,6 +469,20 @@ extern "C" SEXP _stochtree_average_max_depth_forest_container_cpp(SEXP forest_sa
END_CPP11
}
// forest.cpp
+int num_leaves_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num);
+extern "C" SEXP _stochtree_num_leaves_ensemble_forest_container_cpp(SEXP forest_samples, SEXP forest_num) {
+ BEGIN_CPP11
+ return cpp11::as_sexp(num_leaves_ensemble_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num)));
+ END_CPP11
+}
+// forest.cpp
+double sum_leaves_squared_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num);
+extern "C" SEXP _stochtree_sum_leaves_squared_ensemble_forest_container_cpp(SEXP forest_samples, SEXP forest_num) {
+ BEGIN_CPP11
+ return cpp11::as_sexp(sum_leaves_squared_ensemble_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(forest_num)));
+ END_CPP11
+}
+// forest.cpp
int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples);
extern "C" SEXP _stochtree_num_trees_forest_container_cpp(SEXP forest_samples) {
BEGIN_CPP11
@@ -1034,6 +1048,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2},
{"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2},
{"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2},
+ {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2},
{"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1},
{"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1},
{"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1},
@@ -1089,6 +1104,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2},
{"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2},
{"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
+ {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2},
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
{NULL, NULL, 0}
};
diff --git a/src/forest.cpp b/src/forest.cpp
index 87132b01..b0f47c96 100644
--- a/src/forest.cpp
+++ b/src/forest.cpp
@@ -96,6 +96,18 @@ double average_max_depth_forest_container_cpp(cpp11::external_pointerAverageMaxDepth();
}
+[[cpp11::register]]
+int num_leaves_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) {
+ StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num);
+ return forest->NumLeaves();
+}
+
+[[cpp11::register]]
+double sum_leaves_squared_ensemble_forest_container_cpp(cpp11::external_pointer forest_samples, int forest_num) {
+ StochTree::TreeEnsemble* forest = forest_samples->GetEnsemble(forest_num);
+ return forest->SumLeafSquared();
+}
+
[[cpp11::register]]
int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples) {
return forest_samples->NumTrees();
diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp
index 95d70831..1240627c 100644
--- a/src/py_stochtree.cpp
+++ b/src/py_stochtree.cpp
@@ -164,6 +164,16 @@ class ForestContainerCpp {
return forest_samples_->NumSamples();
}
+ int NumLeaves(int forest_num) {
+ StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num);
+ return forest->NumLeaves();
+ }
+
+ double SumLeafSquared(int forest_num) {
+ StochTree::TreeEnsemble* forest = forest_samples_->GetEnsemble(forest_num);
+ return forest->SumLeafSquared();
+ }
+
py::array_t Predict(ForestDatasetCpp& dataset) {
// Predict from the forest container
data_size_t n = dataset.NumRows();
@@ -1028,7 +1038,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("GetTreeSplitCounts", &ForestContainerCpp::GetTreeSplitCounts)
.def("GetForestSplitCounts", &ForestContainerCpp::GetForestSplitCounts)
.def("GetOverallSplitCounts", &ForestContainerCpp::GetOverallSplitCounts)
- .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts);
+ .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts)
+ .def("NumLeaves", &ForestContainerCpp::NumLeaves)
+ .def("SumLeafSquared", &ForestContainerCpp::SumLeafSquared);
py::class_(m, "ForestSamplerCpp")
.def(py::init, int, data_size_t, double, double, int, int>())
diff --git a/stochtree/forest.py b/stochtree/forest.py
index 13d91689..866208ee 100644
--- a/stochtree/forest.py
+++ b/stochtree/forest.py
@@ -148,4 +148,22 @@ def get_granular_split_counts(self, num_features: int) -> np.array:
Total number of features in the training set
"""
return self.forest_container_cpp.GetGranularSplitCounts(num_features)
+
+ def num_leaves(self, forest_num: int) -> int:
+ """
+ Return the total number of leaves for a given forest in the ``ForestContainer``
+
+ forest_num : :obj:`int`
+ Index of the forest to be queried
+ """
+ return self.forest_container_cpp.NumLeaves(forest_num)
+
+ def sum_leaves_squared(self, forest_num: int) -> float:
+ """
+ Return the total sum of squared leaf values for a given forest in the ``ForestContainer``
+
+ forest_num : :obj:`int`
+ Index of the forest to be queried
+ """
+ return self.forest_container_cpp.SumLeafSquared(forest_num)
\ No newline at end of file