From 84719bdc1e8967529945be0f01a113f483d88a1d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 18 Oct 2024 23:58:14 -0500 Subject: [PATCH] Update to allow overwriting the outcome (and properly reflecting this in the full/partial residual) in the R prototype interface --- R/cpp11.R | 8 ++++++ R/data.R | 26 +++++++++++++++++++ include/stochtree/data.h | 1 + include/stochtree/tree_sampler.h | 16 ++++++++++-- man/Outcome.Rd | 43 ++++++++++++++++++++++++++++++++ src/R_data.cpp | 21 ++++++++++++++++ src/cpp11.cpp | 18 +++++++++++++ src/data.cpp | 8 ++++++ 8 files changed, 139 insertions(+), 2 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index 177b80d5..3a37913a 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -52,6 +52,14 @@ subtract_from_column_vector_cpp <- function(outcome, update_vector) { invisible(.Call(`_stochtree_subtract_from_column_vector_cpp`, outcome, update_vector)) } +overwrite_column_vector_cpp <- function(outcome, new_vector) { + invisible(.Call(`_stochtree_overwrite_column_vector_cpp`, outcome, new_vector)) +} + +propagate_trees_column_vector_cpp <- function(tracker, residual) { + invisible(.Call(`_stochtree_propagate_trees_column_vector_cpp`, tracker, residual)) +} + get_residual_cpp <- function(vector_ptr) { .Call(`_stochtree_get_residual_cpp`, vector_ptr) } diff --git a/R/data.R b/R/data.R index d7c7206d..fe9f7d2d 100644 --- a/R/data.R +++ b/R/data.R @@ -140,6 +140,32 @@ Outcome <- R6::R6Class( } } subtract_from_column_vector_cpp(self$data_ptr, update_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. 
partial residual) data by replacing each element with the elements of `new_vector` + #' @param new_vector Vector from which to overwrite the current data + #' @return NULL + update_data = function(new_vector) { + if (!is.numeric(new_vector)) { + stop("new_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(new_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) stop("if new_vector is provided as a matrix, it must be 2d") + new_vector <- as.numeric(new_vector) + } + } + overwrite_column_vector_cpp(self$data_ptr, new_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. + #' This function is run after the `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @return NULL + propagate_trees_new_outcome = function(forest_model) { + propagate_trees_column_vector_cpp(forest_model$tracker_ptr, self$data_ptr) } ) ) diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 2d80f56d..fa2d9494 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -126,6 +126,7 @@ class ColumnVector { void LoadData(double* data_ptr, data_size_t num_row); void AddToData(double* data_ptr, data_size_t num_row); void SubtractFromData(double* data_ptr, data_size_t num_row); + void OverwriteData(double* data_ptr, data_size_t num_row); inline data_size_t NumRows() {return data_.size();} inline Eigen::VectorXd& GetData() {return data_;} private: diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 9e0d0562..42c851e4 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -196,8 +196,6 @@ static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDatas tracker.SyncPredictions(); } - - static
inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function op) { data_size_t n = dataset.GetCovariates().rows(); @@ -225,6 +223,20 @@ static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestData tracker.SyncPredictions(); } +static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) { + data_size_t n = residual.NumRows(); + double pred_value; + double prev_outcome; + double new_resid; + for (data_size_t i = 0; i < n; i++) { + prev_outcome = residual.GetElement(i); + pred_value = tracker.GetSamplePrediction(i); + // Subtract the tracker's current prediction for sample i from the value stored in residual + new_resid = prev_outcome - pred_value; + residual.SetElement(i, new_resid); + } +} + static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); diff --git a/man/Outcome.Rd b/man/Outcome.Rd index 37a5d922..4b3c75ca 100644 --- a/man/Outcome.Rd +++ b/man/Outcome.Rd @@ -25,6 +25,8 @@ of the outcome minus the predictions of every other model term \item \href{#method-Outcome-get_data}{\code{Outcome$get_data()}} \item \href{#method-Outcome-add_vector}{\code{Outcome$add_vector()}} \item \href{#method-Outcome-subtract_vector}{\code{Outcome$subtract_vector()}} +\item \href{#method-Outcome-update_data}{\code{Outcome$update_data()}} +\item \href{#method-Outcome-propagate_trees_new_outcome}{\code{Outcome$propagate_trees_new_outcome()}} } } \if{html}{\out{
}} @@ -100,4 +102,45 @@ Update the current state of the outcome (i.e. partial residual) data by subtract NULL } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Outcome-update_data}{}}} +\subsection{Method \code{update_data()}}{ +Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of \code{new_vector} +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Outcome$update_data(new_vector)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{new_vector}}{Vector from which to overwrite the current data} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Outcome-propagate_trees_new_outcome}{}}} +\subsection{Method \code{propagate_trees_new_outcome()}}{ +Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. +This function is run after the \code{update_data} method, which overwrites the partial residual with an entirely new stream of outcome data. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Outcome$propagate_trees_new_outcome(forest_model)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{forest_model}}{\code{ForestModel} object storing tracking structures used in training / sampling} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +NULL +} +} } diff --git a/src/R_data.cpp b/src/R_data.cpp index 53b5c748..c6c75c29 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include #include #include @@ -136,6 +138,25 @@ void subtract_from_column_vector_cpp(cpp11::external_pointer outcome, cpp11::doubles new_vector) { + // Unpack pointers to data and dimensions + StochTree::data_size_t n = new_vector.size(); + double* update_data_ptr = REAL(PROTECT(new_vector)); + + // Overwrite the outcome data using the C++ API + outcome->OverwriteData(update_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); +} + +[[cpp11::register]] +void propagate_trees_column_vector_cpp(cpp11::external_pointer tracker, + cpp11::external_pointer residual) { + StochTree::UpdateResidualNewOutcome(*tracker, *residual); +} + [[cpp11::register]] cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer vector_ptr) { // Initialize output vector diff --git a/src/cpp11.cpp b/src/cpp11.cpp index e9145f26..fc40a21b 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -103,6 +103,22 @@ extern "C" SEXP _stochtree_subtract_from_column_vector_cpp(SEXP outcome, SEXP up END_CPP11 } // R_data.cpp +void overwrite_column_vector_cpp(cpp11::external_pointer outcome, cpp11::doubles new_vector); +extern "C" SEXP _stochtree_overwrite_column_vector_cpp(SEXP outcome, SEXP new_vector) { + BEGIN_CPP11 + overwrite_column_vector_cpp(cpp11::as_cpp>>(outcome), cpp11::as_cpp>(new_vector)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +void propagate_trees_column_vector_cpp(cpp11::external_pointer tracker, cpp11::external_pointer residual); +extern "C" SEXP _stochtree_propagate_trees_column_vector_cpp(SEXP tracker, SEXP residual) { + BEGIN_CPP11 + propagate_trees_column_vector_cpp(cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(residual)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer vector_ptr);
extern "C" SEXP _stochtree_get_residual_cpp(SEXP vector_ptr) { BEGIN_CPP11 @@ -1014,9 +1030,11 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, {"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1}, + {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, diff --git a/src/data.cpp b/src/data.cpp index 4d62af44..264f212d 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -110,6 +110,14 @@ void ColumnVector::SubtractFromData(double* data_ptr, data_size_t num_row) { UpdateData(data_ptr, num_row, std::minus()); } +void ColumnVector::OverwriteData(double* data_ptr, data_size_t num_row) { + double ptr_val; + for (data_size_t i = 0; i < num_row; ++i) { + ptr_val = static_cast(*(data_ptr + i)); + data_(i) = ptr_val; + } +} + void ColumnVector::UpdateData(double* data_ptr, data_size_t num_row, std::function op) { double ptr_val; double updated_val;