Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ create_column_vector_cpp <- function(outcome) {
.Call(`_stochtree_create_column_vector_cpp`, outcome)
}

add_to_column_vector_cpp <- function(outcome, update_vector) {
invisible(.Call(`_stochtree_add_to_column_vector_cpp`, outcome, update_vector))
}

subtract_from_column_vector_cpp <- function(outcome, update_vector) {
invisible(.Call(`_stochtree_subtract_from_column_vector_cpp`, outcome, update_vector))
}

get_residual_cpp <- function(vector_ptr) {
.Call(`_stochtree_get_residual_cpp`, vector_ptr)
}
Expand Down
34 changes: 34 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,40 @@ Outcome <- R6::R6Class(
#' @return R vector containing (copy of) the values in `Outcome` object
get_data = function() {
return(get_residual_cpp(self$data_ptr))
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector`
#' @param update_vector Vector to be added to outcome
#' @return NULL
add_vector = function(update_vector) {
if (!is.numeric(update_vector)) {
stop("update_vector must be a numeric vector or 2d matrix")
} else {
dim_vec <- dim(update_vector)
if (!is.null(dim_vec)) {
if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d")
update_vector <- as.numeric(update_vector)
}
}
add_to_column_vector_cpp(self$data_ptr, update_vector)
},

#' @description
#' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector`
#' @param update_vector Vector to be subtracted from outcome
#' @return NULL
subtract_vector = function(update_vector) {
if (!is.numeric(update_vector)) {
stop("update_vector must be a numeric vector or 2d matrix")
} else {
dim_vec <- dim(update_vector)
if (!is.null(dim_vec)) {
if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d")
update_vector <- as.numeric(update_vector)
}
}
subtract_from_column_vector_cpp(self$data_ptr, update_vector)
}
)
)
Expand Down
3 changes: 3 additions & 0 deletions include/stochtree/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ class ColumnVector {
double GetElement(data_size_t row_num) {return data_(row_num);}
void SetElement(data_size_t row_num, double value) {data_(row_num) = value;}
void LoadData(double* data_ptr, data_size_t num_row);
void AddToData(double* data_ptr, data_size_t num_row);
void SubtractFromData(double* data_ptr, data_size_t num_row);
inline data_size_t NumRows() {return data_.size();}
inline Eigen::VectorXd& GetData() {return data_;}
private:
Eigen::VectorXd data_;
void UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op);
};

class ForestDataset {
Expand Down
42 changes: 42 additions & 0 deletions man/Outcome.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/bart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/bcf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions src/R_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,32 @@ cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11:
return cpp11::external_pointer<StochTree::ColumnVector>(vector_ptr_.release());
}

[[cpp11::register]]
void add_to_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector) {
// Unpack pointers to data and dimensions
StochTree::data_size_t n = update_vector.size();
double* update_data_ptr = REAL(PROTECT(update_vector));

// Add to the outcome data using the C++ API
outcome->AddToData(update_data_ptr, n);

// Unprotect pointers to R data
UNPROTECT(1);
}

[[cpp11::register]]
void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector) {
// Unpack pointers to data and dimensions
StochTree::data_size_t n = update_vector.size();
double* update_data_ptr = REAL(PROTECT(update_vector));

// Add to the outcome data using the C++ API
outcome->SubtractFromData(update_data_ptr, n);

// Unprotect pointers to R data
UNPROTECT(1);
}

[[cpp11::register]]
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr) {
// Initialize output vector
Expand Down
18 changes: 18 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) {
END_CPP11
}
// R_data.cpp
void add_to_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector);
extern "C" SEXP _stochtree_add_to_column_vector_cpp(SEXP outcome, SEXP update_vector) {
BEGIN_CPP11
add_to_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(update_vector));
return R_NilValue;
END_CPP11
}
// R_data.cpp
void subtract_from_column_vector_cpp(cpp11::external_pointer<StochTree::ColumnVector> outcome, cpp11::doubles update_vector);
extern "C" SEXP _stochtree_subtract_from_column_vector_cpp(SEXP outcome, SEXP update_vector) {
BEGIN_CPP11
subtract_from_column_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(update_vector));
return R_NilValue;
END_CPP11
}
// R_data.cpp
cpp11::writable::doubles get_residual_cpp(cpp11::external_pointer<StochTree::ColumnVector> vector_ptr);
extern "C" SEXP _stochtree_get_residual_cpp(SEXP vector_ptr) {
BEGIN_CPP11
Expand Down Expand Up @@ -881,6 +897,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1},
{"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2},
{"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2},
{"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2},
{"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7},
{"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2},
{"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1},
Expand Down Expand Up @@ -992,6 +1009,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5},
{"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2},
{"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2},
{"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
{"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5},
{NULL, NULL, 0}
Expand Down
24 changes: 24 additions & 0 deletions src/data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ void ColumnVector::LoadData(double* data_ptr, data_size_t num_row) {
}
}

void ColumnVector::AddToData(double* data_ptr, data_size_t num_row) {
data_size_t num_existing_rows = NumRows();
CHECK_EQ(num_row, num_existing_rows);
// std::function<double(double, double)> op = std::plus<double>();
UpdateData(data_ptr, num_row, std::plus<double>());
}

void ColumnVector::SubtractFromData(double* data_ptr, data_size_t num_row) {
data_size_t num_existing_rows = NumRows();
CHECK_EQ(num_row, num_existing_rows);
// std::function<double(double, double)> op = std::minus<double>();
UpdateData(data_ptr, num_row, std::minus<double>());
}

void ColumnVector::UpdateData(double* data_ptr, data_size_t num_row, std::function<double(double, double)> op) {
double ptr_val;
double updated_val;
for (data_size_t i = 0; i < num_row; ++i) {
ptr_val = static_cast<double>(*(data_ptr + i));
updated_val = op(data_(i), ptr_val);
data_(i) = updated_val;
}
}

void LoadData(double* data_ptr, int num_row, int num_col, bool is_row_major, Eigen::MatrixXd& data_matrix) {
data_matrix.resize(num_row, num_col);

Expand Down
Loading