Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ ForestSamples <- R6::R6Class(
#' @param forest_dataset `ForestDataset` R class
#' @param forest_num Index of the forest sample within the container
#' @return matrix of predictions with as many rows as in forest_dataset
#' and as many columns as samples in the `ForestContainer`
#' and as many columns as dimensions in the leaves of trees in `ForestContainer`
predict_raw_single_forest = function(forest_dataset, forest_num) {
stopifnot(!is.null(forest_dataset$data_ptr))
# Unpack dimensions
Expand All @@ -113,6 +113,21 @@ ForestSamples <- R6::R6Class(
return(output)
},

#' @description
#' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset`
#' @param forest_dataset `ForestDataset` R class
#' @param forest_num Index of the forest sample within the container
#' @param tree_num Index of the tree to be queried
#' @return matrix of predictions with as many rows as in `forest_dataset`
#' and as many columns as dimensions in the leaves of trees in `ForestContainer`
predict_raw_single_tree = function(forest_dataset, forest_num, tree_num) {
stopifnot(!is.null(forest_dataset$data_ptr))

# Predict leaf values from forest
output <- predict_forest_raw_single_tree_cpp(self$forest_container_ptr, forest_dataset$data_ptr, forest_num, tree_num)
return(output)
},

#' @description
#' Set a constant predicted value for every tree in the ensemble.
#' Stops program if any tree is more than a root node.
Expand Down
Loading