From bad7dffe0213aa5dd130f083f1b3069bdd1cba6b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 31 Oct 2024 02:18:01 -0400 Subject: [PATCH] Added R method --- R/forest.R | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/R/forest.R b/R/forest.R index 3843192d..c2a2b177 100644 --- a/R/forest.R +++ b/R/forest.R @@ -101,7 +101,7 @@ ForestSamples <- R6::R6Class( #' @param forest_dataset `ForestDataset` R class #' @param forest_num Index of the forest sample within the container #' @return matrix of predictions with as many rows as in forest_dataset - #' and as many columns as samples in the `ForestContainer` + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` predict_raw_single_forest = function(forest_dataset, forest_num) { stopifnot(!is.null(forest_dataset$data_ptr)) # Unpack dimensions @@ -113,6 +113,21 @@ ForestSamples <- R6::R6Class( return(output) }, + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @param forest_num Index of the forest sample within the container + #' @param tree_num Index of the tree to be queried + #' @return matrix of predictions with as many rows as in `forest_dataset` + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` + predict_raw_single_tree = function(forest_dataset, forest_num, tree_num) { + stopifnot(!is.null(forest_dataset$data_ptr)) + + # Predict leaf values from forest + output <- predict_forest_raw_single_tree_cpp(self$forest_container_ptr, forest_dataset$data_ptr, forest_num, tree_num) + return(output) + }, + #' @description #' Set a constant predicted value for every tree in the ensemble. #' Stops program if any tree is more than a root node.