From 9713f43c6b4dcd7f192a88e183f3d2822b09772f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 15:48:19 -0500 Subject: [PATCH 1/8] Added ability to update variance weights --- R/cpp11.R | 4 ++++ man/ForestModel.Rd | 2 +- man/bcf.Rd | 1 + src/R_data.cpp | 11 +++++++++++ src/cpp11.cpp | 9 +++++++++ src/py_stochtree.cpp | 8 ++++++++ stochtree/data.py | 31 +++++++++++++++++++++++++++++++ 7 files changed, 65 insertions(+), 1 deletion(-) diff --git a/R/cpp11.R b/R/cpp11.R index a71a7722..9504dca3 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -36,6 +36,10 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) { invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis)) } +forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights) { + invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights)) +} + forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) { invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights)) } diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index bec10621..3574ffad 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -115,7 +115,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) \item{\code{global_model_config}}{GlobalModelConfig object containing global model parameters and settings} -\item{\code{num_threads}}{Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads.} +\item{\code{num_threads}}{Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to \code{1}, otherwise to the maximum number of available threads.} \item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{TRUE}.} diff --git a/man/bcf.Rd b/man/bcf.Rd index ed7cd238..01e5fab8 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -97,6 +97,7 @@ that were not in the training set.} \item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. \item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} \item{prognostic_forest_params}{(Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. diff --git a/src/R_data.cpp b/src/R_data.cpp index 0f495436..7d2268e0 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -84,6 +84,17 @@ void forest_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->AddVarianceWeights(weight_data_ptr, n); + + // Unprotect pointers to R data + UNPROTECT(1); +} + [[cpp11::register]] void forest_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { // Add weights diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 67f79ab2..9a3f8a24 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -72,6 +72,14 @@ extern "C" SEXP _stochtree_forest_dataset_update_basis_cpp(SEXP dataset_ptr, SEX END_CPP11 } // R_data.cpp +void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights); +extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights) { + BEGIN_CPP11 + forest_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp void forest_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights); extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP weights) { BEGIN_CPP11 @@ -1541,6 +1549,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 2}, {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 32bbd707..1bd6da9e 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -67,6 +67,14 @@ class ForestDatasetCpp { dataset_->AddVarianceWeights(data_ptr, num_row); } + void UpdateVarianceWeights(py::array_t weight_vector, data_size_t num_row) { + // Extract pointer to contiguous block of memory + double* data_ptr = static_cast(weight_vector.mutable_data()); + + // Load covariates + dataset_->AddVarianceWeights(data_ptr, num_row); + } + data_size_t NumRows() { return dataset_->NumObservations(); } diff --git a/stochtree/data.py b/stochtree/data.py index 8cbe76e0..597299c6 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -61,6 +61,16 @@ def update_basis(self, basis: np.array): basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis n, p = basis_.shape basis_rowmajor = np.ascontiguousarray(basis_) + if not self.has_basis(): + raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.") + if not isinstance(basis, np.ndarray): + raise ValueError("basis must be a numpy array.") + if basis.ndim != 2: + raise ValueError("basis must be a 2-dimensional numpy array.") + if self.num_basis() != p: + raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).") + if self.num_observations() != n: + raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).") self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True) def add_variance_weights(self, variance_weights: np.array): @@ -74,6 +84,27 @@ def add_variance_weights(self, variance_weights: np.array): """ n = variance_weights.size self.dataset_cpp.AddVarianceWeights(variance_weights, n) + + def update_variance_weights(self, variance_weights: np.array): + """ + Update variance weights in a dataset. Allows users to build an ensemble that depends on + variance weights that are updated throughout the sampler. + + Parameters + ---------- + variance_weights : np.array + Univariate numpy array of variance weights. + """ + n = variance_weights.size + if not self.has_variance_weights(): + raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + if not isinstance(variance_weights, np.ndarray): + raise ValueError("variance_weights must be a numpy array.") + if variance_weights.ndim != 1: + raise ValueError("variance_weights must be a 1-dimensional numpy array.") + if self.num_observations() != n: + raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + self.dataset_cpp.AddVarianceWeights(variance_weights, n) def num_observations(self) -> int: """ From 50bcc41bb888617c77fdbaa5329520f891f8cb30 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 15:57:12 -0500 Subject: [PATCH 2/8] Fixed basis update python code --- stochtree/data.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/stochtree/data.py b/stochtree/data.py index 597299c6..4743269d 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -58,15 +58,18 @@ def update_basis(self, basis: np.array): basis : np.array Numpy array of basis vectors. """ - basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis - n, p = basis_.shape - basis_rowmajor = np.ascontiguousarray(basis_) if not self.has_basis(): raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.") if not isinstance(basis, np.ndarray): raise ValueError("basis must be a numpy array.") - if basis.ndim != 2: - raise ValueError("basis must be a 2-dimensional numpy array.") + if np.ndim(basis) == 1: + basis_ = np.expand_dims(basis, 1) + elif np.ndim(basis) == 2: + basis_ = basis + else: + raise ValueError("basis must be a numpy array with one or two dimension.") + n, p = basis_.shape + basis_rowmajor = np.ascontiguousarray(basis_) if self.num_basis() != p: raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).") if self.num_observations() != n: From 4d1baa92d3baa0a60dc69682570f72634cb593a6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 17:17:11 -0500 Subject: [PATCH 3/8] Added update methods to both ForestDataset and RandomEffectsDataset in R --- R/cpp11.R | 16 +++++++++-- R/data.R | 37 +++++++++++++++++++++++-- include/stochtree/data.h | 45 ++++++++++++++++++++++++++++++ man/ForestDataset.Rd | 20 ++++++++++++++ man/RandomEffectsDataset.Rd | 55 +++++++++++++++++++++++++++++++++++++ src/R_data.cpp | 35 +++++++++++++++++++++-- src/cpp11.cpp | 34 ++++++++++++++++++++--- tools/debug/dataset_demo.R | 39 ++++++++++++++++++++++++++ 8 files changed, 271 insertions(+), 10 deletions(-) create mode 100644 tools/debug/dataset_demo.R diff --git a/R/cpp11.R b/R/cpp11.R index 9504dca3..29a819e2 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -36,8 +36,8 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) { invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis)) } -forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights) { - invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights)) +forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) { + invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate)) } forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) { @@ -72,6 +72,18 @@ create_rfx_dataset_cpp <- function() { .Call(`_stochtree_create_rfx_dataset_cpp`) } +rfx_dataset_update_basis_cpp <- function(dataset_ptr, basis) { + invisible(.Call(`_stochtree_rfx_dataset_update_basis_cpp`, dataset_ptr, basis)) +} + +rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) { + invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate)) +} + +rfx_dataset_num_basis_cpp <- function(dataset) { + .Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset) +} + rfx_dataset_num_rows_cpp <- function(dataset) { .Call(`_stochtree_rfx_dataset_num_rows_cpp`, dataset) } diff --git a/R/data.R b/R/data.R index 4f35efc0..8e5dda03 100644 --- a/R/data.R +++ b/R/data.R @@ -36,7 +36,16 @@ ForestDataset <- R6::R6Class( update_basis = function(basis) { stopifnot(self$has_basis()) forest_dataset_update_basis_cpp(self$data_ptr, basis) - }, + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. + update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + }, #' @description #' Return number of observations in a `ForestDataset` object @@ -190,12 +199,36 @@ RandomEffectsDataset <- R6::R6Class( } }, + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define random slopes / intercepts + update_basis = function(basis) { + stopifnot(self$has_basis()) + rfx_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. + update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + }, + #' @description #' Return number of observations in a `RandomEffectsDataset` object #' @return Observation count num_observations = function() { return(rfx_dataset_num_rows_cpp(self$data_ptr)) - }, + }, + + #' @description + #' Return dimension of the basis matrix in a `RandomEffectsDataset` object + #' @return Basis vector count + num_basis = function() { + return(rfx_dataset_num_basis_cpp(self$data_ptr)) + }, #' @description #' Whether or not a dataset has group label indices diff --git a/include/stochtree/data.h b/include/stochtree/data.h index cc62ab06..5eb3d50c 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -497,6 +497,7 @@ class RandomEffectsDataset { */ void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); + num_basis_ = num_col; has_basis_ = true; } /*! @@ -509,6 +510,49 @@ class RandomEffectsDataset { var_weights_ = ColumnVector(data_ptr, num_row); has_var_weights_ = true; } + /*! + * \brief Update the data in the internal basis matrix to new values stored in a raw double array + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix + * \param num_row Number of rows in the basis matrix + * \param num_col Number of columns in the basis matrix + * \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion + */ + void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { + CHECK(has_basis_); + CHECK_EQ(num_col, num_basis_); + // Copy data from R / Python process memory to Eigen matrix + double temp_value; + for (data_size_t i = 0; i < num_row; ++i) { + for (int j = 0; j < num_col; ++j) { + if (is_row_major){ + // Numpy 2-d arrays are stored in "row major" order + temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); + } else { + // R matrices are stored in "column major" order + temp_value = static_cast(*(data_ptr + static_cast(num_row) * j + i)); + } + basis_.SetElement(i, j, temp_value); + } + } + } + /*! + * \brief Update the data in the internal variance weight vector to new values stored in a raw double array + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector + * \param num_row Number of rows in the weight vector + * \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector + */ + void UpdateVarWeights(double* data_ptr, data_size_t num_row, bool exponentiate = true) { + CHECK(has_var_weights_); + // Copy data from R / Python process memory to Eigen vector + double temp_value; + for (data_size_t i = 0; i < num_row; ++i) { + if (exponentiate) temp_value = std::exp(static_cast(*(data_ptr + i))); + else temp_value = static_cast(*(data_ptr + i)); + var_weights_.SetElement(i, temp_value); + } + } /*! * \brief Copy / load group indices for random effects * @@ -570,6 +614,7 @@ class RandomEffectsDataset { ColumnMatrix basis_; ColumnVector var_weights_; std::vector group_labels_; + int num_basis_{0}; bool has_basis_{false}; bool has_var_weights_{false}; bool has_group_labels_{false}; diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd index 08ec47ad..a560f350 100644 --- a/man/ForestDataset.Rd +++ b/man/ForestDataset.Rd @@ -20,6 +20,7 @@ weights are optional. \itemize{ \item \href{#method-ForestDataset-new}{\code{ForestDataset$new()}} \item \href{#method-ForestDataset-update_basis}{\code{ForestDataset$update_basis()}} +\item \href{#method-ForestDataset-update_variance_weights}{\code{ForestDataset$update_variance_weights()}} \item \href{#method-ForestDataset-num_observations}{\code{ForestDataset$num_observations()}} \item \href{#method-ForestDataset-num_covariates}{\code{ForestDataset$num_covariates()}} \item \href{#method-ForestDataset-num_basis}{\code{ForestDataset$num_basis()}} @@ -69,6 +70,25 @@ Update basis matrix in a dataset } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-update_variance_weights}{}}} +\subsection{Method \code{update_variance_weights()}}{ +Update variance_weights in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$update_variance_weights(variance_weights, exponentiate = F)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_weights}}{Updated vector of variance weights used to define individual variance / case weights} + +\item{\code{exponentiate}}{Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestDataset-num_observations}{}}} \subsection{Method \code{num_observations()}}{ diff --git a/man/RandomEffectsDataset.Rd b/man/RandomEffectsDataset.Rd index 2a516321..4bb4fdaa 100644 --- a/man/RandomEffectsDataset.Rd +++ b/man/RandomEffectsDataset.Rd @@ -18,7 +18,10 @@ bases, and variance weights. Variance weights are optional. \subsection{Public methods}{ \itemize{ \item \href{#method-RandomEffectsDataset-new}{\code{RandomEffectsDataset$new()}} +\item \href{#method-RandomEffectsDataset-update_basis}{\code{RandomEffectsDataset$update_basis()}} +\item \href{#method-RandomEffectsDataset-update_variance_weights}{\code{RandomEffectsDataset$update_variance_weights()}} \item \href{#method-RandomEffectsDataset-num_observations}{\code{RandomEffectsDataset$num_observations()}} +\item \href{#method-RandomEffectsDataset-num_basis}{\code{RandomEffectsDataset$num_basis()}} \item \href{#method-RandomEffectsDataset-has_group_labels}{\code{RandomEffectsDataset$has_group_labels()}} \item \href{#method-RandomEffectsDataset-has_basis}{\code{RandomEffectsDataset$has_basis()}} \item \href{#method-RandomEffectsDataset-has_variance_weights}{\code{RandomEffectsDataset$has_variance_weights()}} @@ -49,6 +52,45 @@ A new \code{RandomEffectsDataset} object. } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-update_basis}{}}} +\subsection{Method \code{update_basis()}}{ +Update basis matrix in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$update_basis(basis)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{basis}}{Updated matrix of bases used to define random slopes / intercepts} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-update_variance_weights}{}}} +\subsection{Method \code{update_variance_weights()}}{ +Update variance_weights in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$update_variance_weights( + variance_weights, + exponentiate = F +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_weights}}{Updated vector of variance weights used to define individual variance / case weights} + +\item{\code{exponentiate}}{Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsDataset-num_observations}{}}} \subsection{Method \code{num_observations()}}{ @@ -62,6 +104,19 @@ Observation count } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-num_basis}{}}} +\subsection{Method \code{num_basis()}}{ +Return dimension of the basis matrix in a \code{RandomEffectsDataset} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$num_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis vector count +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsDataset-has_group_labels}{}}} \subsection{Method \code{has_group_labels()}}{ diff --git a/src/R_data.cpp b/src/R_data.cpp index 7d2268e0..021be76a 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -85,11 +85,11 @@ void forest_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { +void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { // Add weights StochTree::data_size_t n = weights.size(); double* weight_data_ptr = REAL(PROTECT(weights)); - dataset_ptr->AddVarianceWeights(weight_data_ptr, n); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); // Unprotect pointers to R data UNPROTECT(1); @@ -191,6 +191,37 @@ cpp11::external_pointer create_rfx_dataset_cpp( return cpp11::external_pointer(dataset_ptr_.release()); } +[[cpp11::register]] +void rfx_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); +} + +[[cpp11::register]] +void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); + + // Unprotect pointers to R data + UNPROTECT(1); +} + +[[cpp11::register]] +int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset) { + return dataset->NumBases(); +} + [[cpp11::register]] int rfx_dataset_num_rows_cpp(cpp11::external_pointer dataset) { return dataset->NumObservations(); diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 9a3f8a24..5a64afc0 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -72,10 +72,10 @@ extern "C" SEXP _stochtree_forest_dataset_update_basis_cpp(SEXP dataset_ptr, SEX END_CPP11 } // R_data.cpp -void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights); -extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights) { +void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate); +extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) { BEGIN_CPP11 - forest_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights)); + forest_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights), cpp11::as_cpp>(exponentiate)); return R_NilValue; END_CPP11 } @@ -141,6 +141,29 @@ extern "C" SEXP _stochtree_create_rfx_dataset_cpp() { END_CPP11 } // R_data.cpp +void rfx_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis); +extern "C" SEXP _stochtree_rfx_dataset_update_basis_cpp(SEXP dataset_ptr, SEXP basis) { + BEGIN_CPP11 + rfx_dataset_update_basis_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>>(basis)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate); +extern "C" SEXP _stochtree_rfx_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) { + BEGIN_CPP11 + rfx_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights), cpp11::as_cpp>(exponentiate)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset); +extern "C" SEXP _stochtree_rfx_dataset_num_basis_cpp(SEXP dataset) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_num_basis_cpp(cpp11::as_cpp>>(dataset))); + END_CPP11 +} +// R_data.cpp int rfx_dataset_num_rows_cpp(cpp11::external_pointer dataset); extern "C" SEXP _stochtree_rfx_dataset_num_rows_cpp(SEXP dataset) { BEGIN_CPP11 @@ -1549,7 +1572,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, - {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, @@ -1671,7 +1694,10 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, + {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, diff --git a/tools/debug/dataset_demo.R b/tools/debug/dataset_demo.R new file mode 100644 index 00000000..3510f0ab --- /dev/null +++ b/tools/debug/dataset_demo.R @@ -0,0 +1,39 @@ +# Load libraries +library(stochtree) + +# Generate "forest" data +n <- 20 +num_covariates <- 10 +num_basis <- 5 +covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) +basis <- matrix(runif(n * num_basis), ncol = num_basis) +variance_weights <- runif(n) + +# Create a ForestDataset object +forest_dataset <- createForestDataset(covariates, basis, variance_weights) + +# Update forest_dataset's basis +new_basis <- matrix(runif(n * num_basis), ncol = num_basis) +forest_dataset$update_basis(new_basis) + +# Update forest_dataset's variance_weights +new_variance_weights <- runif(n) +forest_dataset$update_variance_weights(new_variance_weights) + +# Generate RFX data +group_ids <- sample(as.integer(c(1,2)), size = n, replace = T) +rfx_basis <- cbind(1, runif(n)) + +# Create a RandomEffectsDataset object +rfx_dataset <- createRandomEffectsDataset( + group_labels = group_ids, basis = rfx_basis, + variance_weights = variance_weights +) + +# Update rfx_dataset's basis +new_rfx_basis <- cbind(1, runif(n)) +rfx_dataset$update_basis(new_rfx_basis) + +# Update rfx_dataset's variance weights +rfx_dataset$update_variance_weights(new_variance_weights) + From 8cf4beddd1b494e75896dbe8b65077d88af829c6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 17:43:44 -0500 Subject: [PATCH 4/8] Updated python and R data interfaces --- include/stochtree/data.h | 15 ++++++++++++++ src/R_data.cpp | 8 ++++++++ src/py_stochtree.cpp | 18 +++++++++++++++- stochtree/data.py | 17 ++++++++++----- stochtree/random_effects.py | 41 +++++++++++++++++++++++++++++++++++++ test/python/test_data.py | 31 ++++++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 test/python/test_data.py diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 5eb3d50c..a6061f4b 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -553,6 +553,21 @@ class RandomEffectsDataset { var_weights_.SetElement(i, temp_value); } } + /*! + * \brief Update a RandomEffectsDataset's group indices + * + * \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector + * \param num_row Number of rows in the weight vector + * \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector + */ + void UpdateGroupLabels(std::vector& group_labels, data_size_t num_row) { + CHECK(has_group_labels_); + CHECK_EQ(this->NumObservations(), num_row) + // Copy data from R / Python process memory to internal vector + for (data_size_t i = 0; i < num_row; ++i) { + group_labels_[i] = group_labels[i]; + } + } /*! * \brief Copy / load group indices for random effects * diff --git a/src/R_data.cpp b/src/R_data.cpp index 021be76a..596516f3 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -217,6 +217,14 @@ void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels) { + // Update group labels + int n = group_labels.size(); + std::vector group_labels_vec(group_labels.begin(), group_labels.end()); + dataset_ptr->UpdateGroupLabels(group_labels_vec, n); +} + [[cpp11::register]] int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset) { return dataset->NumBases(); diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 1bd6da9e..09f6c259 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -72,7 +72,7 @@ class ForestDatasetCpp { double* data_ptr = static_cast(weight_vector.mutable_data()); // Load covariates - dataset_->AddVarianceWeights(data_ptr, num_row); + dataset_->UpdateVarWeights(data_ptr, num_row); } data_size_t NumRows() { @@ -1297,6 +1297,22 @@ class RandomEffectsDatasetCpp { double* weight_data_ptr = static_cast(weights.mutable_data()); rfx_dataset_->AddVarianceWeights(weight_data_ptr, num_row); } + void UpdateBasis(py::array_t basis, data_size_t num_row, int num_col, bool row_major) { + double* basis_data_ptr = static_cast(basis.mutable_data()); + rfx_dataset_->UpdateBasis(basis_data_ptr, num_row, num_col, row_major); + } + void UpdateVarianceWeights(py::array_t weights, data_size_t num_row, bool exponentiate) { + double* weight_data_ptr = static_cast(weights.mutable_data()); + rfx_dataset_->UpdateVarWeights(weight_data_ptr, num_row, exponentiate); + } + void UpdateGroupLabels(py::array_t group_labels, data_size_t num_row) { + std::vector group_labels_vec(num_row); + auto accessor = group_labels.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + group_labels_vec[i] = accessor(i); + } + rfx_dataset_->UpdateGroupLabels(group_labels_vec, num_row); + } bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();} bool HasBasis() {return rfx_dataset_->HasBasis();} bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();} diff --git a/stochtree/data.py b/stochtree/data.py index 4743269d..29c8280c 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -85,8 +85,14 @@ def add_variance_weights(self, variance_weights: np.array): variance_weights : np.array Univariate numpy array of variance weights. """ - n = variance_weights.size - self.dataset_cpp.AddVarianceWeights(variance_weights, n) + if not isinstance(variance_weights, np.ndarray): + raise ValueError("variance_weights must be a numpy array.") + variance_weights_ = np.squeeze(variance_weights) + n = variance_weights_.size + if variance_weights_.ndim != 1: + raise ValueError("variance_weights must be a 1-dimensional numpy array.") + + self.dataset_cpp.AddVarianceWeights(variance_weights_, n) def update_variance_weights(self, variance_weights: np.array): """ @@ -98,16 +104,17 @@ def update_variance_weights(self, variance_weights: np.array): variance_weights : np.array Univariate numpy array of variance weights. """ - n = variance_weights.size if not self.has_variance_weights(): raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") if not isinstance(variance_weights, np.ndarray): raise ValueError("variance_weights must be a numpy array.") - if variance_weights.ndim != 1: + variance_weights_ = np.squeeze(variance_weights) + n = variance_weights_.size + if variance_weights_.ndim != 1: raise ValueError("variance_weights must be a 1-dimensional numpy array.") if self.num_observations() != n: raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") - self.dataset_cpp.AddVarianceWeights(variance_weights, n) + self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n) def num_observations(self) -> int: """ diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 18144044..1597ff57 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -40,6 +40,23 @@ def add_group_labels(self, group_labels: np.array): n = group_labels_.shape[0] self.rfx_dataset_cpp.AddGroupLabels(group_labels_, n) + def update_group_labels(self, group_labels: np.array): + """ + Update group labels in a dataset + + Parameters + ---------- + group_labels : np.array + One-dimensional numpy array of group labels. + """ + group_labels_ = np.squeeze(group_labels) + if group_labels_.ndim > 1: + raise ValueError( + "group_labels must be a one-dimensional numpy array of group indices" + ) + n = group_labels_.shape[0] + self.rfx_dataset_cpp.UpdateGroupLabels(group_labels_, n) + def add_basis(self, basis: np.array): """ Add basis matrix to a dataset @@ -93,6 +110,30 @@ def add_variance_weights(self, variance_weights: np.array): ) n = variance_weights_.shape[0] self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n) + + def update_variance_weights(self, variance_weights: np.array): + """ + Update variance weights in a dataset. Allows users to build an ensemble that depends on + variance weights that are updated throughout the sampler. + + Parameters + ---------- + variance_weights : np.array + Univariate numpy array of variance weights. + """ + if not self.has_variance_weights(): + raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + if not isinstance(variance_weights, np.ndarray): + raise ValueError("variance_weights must be a numpy array.") + variance_weights_ = np.squeeze(variance_weights) + if variance_weights_.ndim > 1: + raise ValueError( + "variance_weights must be a one-dimensional numpy array of group indices" + ) + n = variance_weights_.shape[0] + if self.num_observations() != n: + raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n) def num_observations(self) -> int: """ diff --git a/test/python/test_data.py b/test/python/test_data.py new file mode 100644 index 00000000..7b012f24 --- /dev/null +++ b/test/python/test_data.py @@ -0,0 +1,31 @@ +import numpy as np + +from stochtree import Dataset + +class TestDataset: + def test_dataset_update(self): + # Generate data + n = 20 + num_covariates = 10 + num_basis = 5 + rng = np.random.default_rng() + covariates = rng.uniform(0, 1, size=(n, num_covariates)) + basis = rng.uniform(0, 1, size=(n, num_basis)) + variance_weights = rng.uniform(0, 1, size=n) + + # Construct dataset + forest_dataset = Dataset() + forest_dataset.add_covariates(covariates) + forest_dataset.add_basis(basis) + forest_dataset.add_variance_weights(variance_weights) + assert forest_dataset.num_observations() == n + assert forest_dataset.num_covariates() == num_covariates + assert forest_dataset.num_basis() == num_basis + assert forest_dataset.has_variance_weights() + + # Update dataset + new_basis = rng.uniform(0, 1, size=(n, num_basis)) + new_variance_weights = rng.uniform(0, 1, size=n) + with np.testing.assert_no_warnings(): + forest_dataset.update_basis(new_basis) + forest_dataset.update_variance_weights(new_variance_weights) From 69d60aa42875f619e94d9ed55c0f233920a36c24 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 18:01:35 -0500 Subject: [PATCH 5/8] Updated interfaces and added unit tests --- R/cpp11.R | 4 +++ src/cpp11.cpp | 9 ++++++ src/py_stochtree.cpp | 8 +++-- stochtree/data.py | 6 ++-- stochtree/random_effects.py | 6 ++-- test/R/testthat/test-dataset.R | 56 ++++++++++++++++++++++++++++++++++ test/python/test_data.py | 33 +++++++++++++++++++- 7 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 test/R/testthat/test-dataset.R diff --git a/R/cpp11.R b/R/cpp11.R index 29a819e2..f7e844b2 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -80,6 +80,10 @@ rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiat invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate)) } +rfx_dataset_update_group_labels_cpp <- function(dataset_ptr, group_labels) { + invisible(.Call(`_stochtree_rfx_dataset_update_group_labels_cpp`, dataset_ptr, group_labels)) +} + rfx_dataset_num_basis_cpp <- function(dataset) { .Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset) } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 5a64afc0..9d454531 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -157,6 +157,14 @@ extern "C" SEXP _stochtree_rfx_dataset_update_var_weights_cpp(SEXP dataset_ptr, END_CPP11 } // R_data.cpp +void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels); +extern "C" SEXP _stochtree_rfx_dataset_update_group_labels_cpp(SEXP dataset_ptr, SEXP group_labels) { + BEGIN_CPP11 + rfx_dataset_update_group_labels_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(group_labels)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset); extern "C" SEXP _stochtree_rfx_dataset_num_basis_cpp(SEXP dataset) { BEGIN_CPP11 @@ -1697,6 +1705,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 09f6c259..6ac7ee2e 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -67,12 +67,12 @@ class ForestDatasetCpp { dataset_->AddVarianceWeights(data_ptr, num_row); } - void UpdateVarianceWeights(py::array_t weight_vector, data_size_t num_row) { + void UpdateVarianceWeights(py::array_t weight_vector, data_size_t num_row, bool exponentiate) { // Extract pointer to contiguous block of memory double* data_ptr = static_cast(weight_vector.mutable_data()); // Load covariates - dataset_->UpdateVarWeights(data_ptr, num_row); + dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate); } data_size_t NumRows() { @@ -2067,6 +2067,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("AddBasis", &ForestDatasetCpp::AddBasis) .def("UpdateBasis", &ForestDatasetCpp::UpdateBasis) .def("AddVarianceWeights", &ForestDatasetCpp::AddVarianceWeights) + .def("UpdateVarianceWeights", &ForestDatasetCpp::UpdateVarianceWeights) .def("NumRows", &ForestDatasetCpp::NumRows) .def("NumCovariates", &ForestDatasetCpp::NumCovariates) .def("NumBasis", &ForestDatasetCpp::NumBasis) @@ -2200,6 +2201,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("AddGroupLabels", &RandomEffectsDatasetCpp::AddGroupLabels) .def("AddBasis", &RandomEffectsDatasetCpp::AddBasis) .def("AddVarianceWeights", &RandomEffectsDatasetCpp::AddVarianceWeights) + .def("UpdateGroupLabels", &RandomEffectsDatasetCpp::UpdateGroupLabels) + .def("UpdateBasis", &RandomEffectsDatasetCpp::UpdateBasis) + .def("UpdateVarianceWeights", &RandomEffectsDatasetCpp::UpdateVarianceWeights) .def("HasGroupLabels", &RandomEffectsDatasetCpp::HasGroupLabels) .def("HasBasis", &RandomEffectsDatasetCpp::HasBasis) .def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights); diff --git a/stochtree/data.py b/stochtree/data.py index 29c8280c..424767d3 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -94,7 +94,7 @@ def add_variance_weights(self, variance_weights: np.array): self.dataset_cpp.AddVarianceWeights(variance_weights_, n) - def update_variance_weights(self, variance_weights: np.array): + def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): """ Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. @@ -103,6 +103,8 @@ def update_variance_weights(self, variance_weights: np.array): ---------- variance_weights : np.array Univariate numpy array of variance weights. + exponentiate : bool + Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") @@ -114,7 +116,7 @@ def update_variance_weights(self, variance_weights: np.array): raise ValueError("variance_weights must be a 1-dimensional numpy array.") if self.num_observations() != n: raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") - self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n) + self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate) def num_observations(self) -> int: """ diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 1597ff57..c74de5db 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -111,7 +111,7 @@ def add_variance_weights(self, variance_weights: np.array): n = variance_weights_.shape[0] self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n) - def update_variance_weights(self, variance_weights: np.array): + def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): """ Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. @@ -120,6 +120,8 @@ def update_variance_weights(self, variance_weights: np.array): ---------- variance_weights : np.array Univariate numpy array of variance weights. + exponentiate : bool + Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") @@ -133,7 +135,7 @@ def update_variance_weights(self, variance_weights: np.array): n = variance_weights_.shape[0] if self.num_observations() != n: raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") - self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n) + self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate) def num_observations(self) -> int: """ diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R new file mode 100644 index 00000000..2753a3f9 --- /dev/null +++ b/test/R/testthat/test-dataset.R @@ -0,0 +1,56 @@ +test_that("ForestDataset can be constructed and updated", { + # Generate data + n <- 20 + num_covariates <- 10 + num_basis <- 5 + covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) + basis <- matrix(runif(n * num_basis), ncol = num_basis) + variance_weights <- runif(n) + + # Copy data to a ForestDataset object + forest_dataset <- createForestDataset(covariates, basis, variance_weights) + + # Run first round of expectations + expect_equal(forest_dataset$num_observations(), n) + expect_equal(forest_dataset$num_covariates(), num_covariates) + expect_equal(forest_dataset$num_basis(), num_basis) + expect_equal(forest_dataset$has_variance_weights(), T) + + # Update data + new_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + forest_dataset$update_basis(new_basis) + ) + expect_no_error( + forest_dataset$update_variance_weights(new_variance_weights) + ) +}) + +test_that("RandomEffectsDataset can be constructed and updated", { + # Generate data + n <- 20 + num_groups <- 4 + num_basis <- 5 + group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) + rfx_basis <- cbind(1, matrix(runif(n*(num_basis-1)), ncol = (num_basis-1))) + variance_weights <- runif(n) + + # Copy data to a RandomEffectsDataset object + rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis, variance_weights) + + # Run first round of expectations + expect_equal(rfx_dataset$num_observations(), n) + expect_equal(rfx_dataset$num_basis(), num_basis) + expect_equal(rfx_dataset$has_variance_weights(), T) + + # Update data + new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + rfx_dataset$update_basis(new_basis) + ) + expect_no_error( + rfx_dataset$update_variance_weights(new_variance_weights) + ) +}) diff --git a/test/python/test_data.py b/test/python/test_data.py index 7b012f24..0a964637 100644 --- a/test/python/test_data.py +++ b/test/python/test_data.py @@ -1,6 +1,6 @@ import numpy as np -from stochtree import Dataset +from stochtree import Dataset, RandomEffectsDataset class TestDataset: def test_dataset_update(self): @@ -29,3 +29,34 @@ def test_dataset_update(self): with np.testing.assert_no_warnings(): forest_dataset.update_basis(new_basis) forest_dataset.update_variance_weights(new_variance_weights) + +class TestRFXDataset: + def test_rfx_dataset_update(self): + # Generate data + n = 20 + num_groups = 4 + num_basis = 5 + rng = np.random.default_rng() + group_labels = rng.choice(num_groups, size=n) + basis = np.empty((n, num_basis)) + basis[:, 0] = 1.0 + if num_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_basis - 1)) + variance_weights = rng.uniform(0, 1, size=n) + + # Construct dataset + rfx_dataset = RandomEffectsDataset() + rfx_dataset.add_group_labels(group_labels) + rfx_dataset.add_basis(basis) + rfx_dataset.add_variance_weights(variance_weights) + assert rfx_dataset.num_observations() == n + assert rfx_dataset.num_basis() == num_basis + assert rfx_dataset.has_variance_weights() + + # Update dataset + new_basis = rng.uniform(0, 1, size=(n, num_basis)) + new_variance_weights = rng.uniform(0, 1, size=n) + with np.testing.assert_no_warnings(): + rfx_dataset.update_basis(new_basis) + rfx_dataset.update_variance_weights(new_variance_weights) + From 7423c76ddbea43760a75d3477d720c136cf2fed4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 18:16:26 -0500 Subject: [PATCH 6/8] Added methods to query data from a C++ object back to python as a numpy array --- src/py_stochtree.cpp | 78 +++++++++++++++++++++++++++++++++++++ stochtree/data.py | 33 ++++++++++++++++ stochtree/random_effects.py | 33 ++++++++++++++++ test/python/test_data.py | 10 +++++ 4 files changed, 154 insertions(+) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 6ac7ee2e..950caeb8 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -75,6 +75,48 @@ class ForestDatasetCpp { dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate); } + py::array_t GetCovariates() { + // Initialize n x p numpy array to store the covariates + data_size_t n = dataset_->NumObservations(); + int num_covariates = dataset_->NumCovariates(); + auto result = py::array_t(py::detail::any_container({n, num_covariates})); + auto accessor = result.mutable_unchecked<2>(); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < num_covariates; j++) { + accessor(i,j) = dataset_->CovariateValue(i,j); + } + } + + return result; + } + + py::array_t GetBasis() { + // Initialize n x k numpy array to store the basis + data_size_t n = dataset_->NumObservations(); + int num_basis = dataset_->NumBasis(); + auto result = py::array_t(py::detail::any_container({n, num_basis})); + auto accessor = result.mutable_unchecked<2>(); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < num_basis; j++) { + accessor(i,j) = dataset_->BasisValue(i,j); + } + } + + return result; + } + + py::array_t GetVarianceWeights() { + // Initialize n x 1 numpy array to store the variance weights + data_size_t n = dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({n})); + auto accessor = result.mutable_unchecked<1>(); + for (size_t i = 0; i < n; i++) { + accessor(i) = dataset_->VarWeightValue(i); + } + + return result; + } + data_size_t NumRows() { return dataset_->NumObservations(); } @@ -1313,6 +1355,36 @@ class RandomEffectsDatasetCpp { } rfx_dataset_->UpdateGroupLabels(group_labels_vec, num_row); } + py::array_t GetBasis() { + int num_row = rfx_dataset_->NumObservations(); + int num_col = rfx_dataset_->NumBases(); + auto result = py::array_t(py::detail::any_container({num_row, num_col})); + auto accessor = result.mutable_unchecked<2>(); + for (py::ssize_t i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + accessor(i,j) = rfx_dataset_->BasisValue(i,j); + } + } + return result; + } + py::array_t GetVarianceWeights() { + int num_row = rfx_dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({num_row})); + auto accessor = result.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + accessor(i) = rfx_dataset_->VarWeightValue(i); + } + return result; + } + py::array_t GetGroupLabels() { + int num_row = rfx_dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({num_row})); + auto accessor = result.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + accessor(i) = rfx_dataset_->GroupId(i); + } + return result; + } bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();} bool HasBasis() {return rfx_dataset_->HasBasis();} bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();} @@ -2071,6 +2143,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("NumRows", &ForestDatasetCpp::NumRows) .def("NumCovariates", &ForestDatasetCpp::NumCovariates) .def("NumBasis", &ForestDatasetCpp::NumBasis) + .def("GetCovariates", &ForestDatasetCpp::GetCovariates) + .def("GetBasis", &ForestDatasetCpp::GetBasis) + .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) .def("HasBasis", &ForestDatasetCpp::HasBasis) .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights); @@ -2204,6 +2279,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("UpdateGroupLabels", &RandomEffectsDatasetCpp::UpdateGroupLabels) .def("UpdateBasis", &RandomEffectsDatasetCpp::UpdateBasis) .def("UpdateVarianceWeights", &RandomEffectsDatasetCpp::UpdateVarianceWeights) + .def("GetGroupLabels", &RandomEffectsDatasetCpp::GetGroupLabels) + .def("GetBasis", &RandomEffectsDatasetCpp::GetBasis) + .def("GetVarianceWeights", &RandomEffectsDatasetCpp::GetVarianceWeights) .def("HasGroupLabels", &RandomEffectsDatasetCpp::HasGroupLabels) .def("HasBasis", &RandomEffectsDatasetCpp::HasBasis) .def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights); diff --git a/stochtree/data.py b/stochtree/data.py index 424767d3..4e40a282 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -150,6 +150,39 @@ def num_basis(self) -> int: Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis """ return self.dataset_cpp.NumBasis() + + def get_covariates(self) -> np.array: + """ + Return the covariates in a Dataset as a numpy array + + Returns + ------- + np.array + Covariate data + """ + return self.dataset_cpp.GetCovariates() + + def get_basis(self) -> np.array: + """ + Return the bases in a Dataset as a numpy array + + Returns + ------- + np.array + Basis data + """ + return self.dataset_cpp.GetBasis() + + def get_variance_weights(self) -> np.array: + """ + Return the variance weights in a Dataset as a numpy array + + Returns + ------- + np.array + Variance weights data + """ + return self.dataset_cpp.GetVarianceWeights() def has_basis(self) -> bool: """ diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index c74de5db..6c4093d4 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -136,6 +136,39 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool if self.num_observations() != n: raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate) + + def get_group_labels(self) -> np.array: + """ + Return the group labels in a RandomEffectsDataset as a numpy array + + Returns + ------- + np.array + One-dimensional numpy array of group labels. + """ + return self.rfx_dataset_cpp.GetGroupLabels() + + def get_basis(self) -> np.array: + """ + Return the bases in a RandomEffectsDataset as a numpy array + + Returns + ------- + np.array + Two-dimensional numpy array of basis vectors. + """ + return self.rfx_dataset_cpp.GetBasis() + + def get_variance_weights(self) -> np.array: + """ + Return the variance weights in a RandomEffectsDataset as a numpy array + + Returns + ------- + np.array + One-dimensional numpy array of variance weights. + """ + return self.rfx_dataset_cpp.GetVarianceWeights() def num_observations(self) -> int: """ diff --git a/test/python/test_data.py b/test/python/test_data.py index 0a964637..09c75154 100644 --- a/test/python/test_data.py +++ b/test/python/test_data.py @@ -29,6 +29,11 @@ def test_dataset_update(self): with np.testing.assert_no_warnings(): forest_dataset.update_basis(new_basis) forest_dataset.update_variance_weights(new_variance_weights) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + np.testing.assert_array_equal(forest_dataset.get_covariates(), covariates) + np.testing.assert_array_equal(forest_dataset.get_basis(), new_basis) + np.testing.assert_array_equal(forest_dataset.get_variance_weights(), new_variance_weights) class TestRFXDataset: def test_rfx_dataset_update(self): @@ -59,4 +64,9 @@ def test_rfx_dataset_update(self): with np.testing.assert_no_warnings(): rfx_dataset.update_basis(new_basis) rfx_dataset.update_variance_weights(new_variance_weights) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + np.testing.assert_array_equal(rfx_dataset.get_group_labels(), group_labels) + np.testing.assert_array_equal(rfx_dataset.get_basis(), new_basis) + np.testing.assert_array_equal(rfx_dataset.get_variance_weights(), new_variance_weights) From 9e9de2ae98d0616cd62fa61529b2e7e9f6273934 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 18:30:25 -0500 Subject: [PATCH 7/8] Added data retrieval methods and unit tests in R --- R/cpp11.R | 24 +++++++++++ R/data.R | 42 +++++++++++++++++++ man/ForestDataset.Rd | 42 +++++++++++++++++++ man/RandomEffectsDataset.Rd | 42 +++++++++++++++++++ src/R_data.cpp | 75 ++++++++++++++++++++++++++++++++++ src/cpp11.cpp | 48 ++++++++++++++++++++++ test/R/testthat/test-dataset.R | 10 +++++ 7 files changed, 283 insertions(+) diff --git a/R/cpp11.R b/R/cpp11.R index f7e844b2..d77c7472 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -44,6 +44,18 @@ forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) { invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights)) } +forest_dataset_get_covariates_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_covariates_cpp`, dataset_ptr) +} + +forest_dataset_get_basis_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_basis_cpp`, dataset_ptr) +} + +forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr) +} + create_column_vector_cpp <- function(outcome) { .Call(`_stochtree_create_column_vector_cpp`, outcome) } @@ -116,6 +128,18 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) { invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights)) } +rfx_dataset_get_group_labels_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_group_labels_cpp`, dataset_ptr) +} + +rfx_dataset_get_basis_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_basis_cpp`, dataset_ptr) +} + +rfx_dataset_get_variance_weights_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_variance_weights_cpp`, dataset_ptr) +} + rfx_container_cpp <- function(num_components, num_groups) { .Call(`_stochtree_rfx_container_cpp`, num_components, num_groups) } diff --git a/R/data.R b/R/data.R index 8e5dda03..8bea2823 100644 --- a/R/data.R +++ b/R/data.R @@ -68,6 +68,27 @@ ForestDataset <- R6::R6Class( return(dataset_num_basis_cpp(self$data_ptr)) }, + #' @description + #' Return covariates as an R matrix + #' @return Covariate data + get_covariates = function() { + return(forest_dataset_get_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(forest_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + #' @description #' Whether or not a dataset has a basis matrix #' @return True if basis matrix is loaded, false otherwise @@ -230,6 +251,27 @@ RandomEffectsDataset <- R6::R6Class( return(rfx_dataset_num_basis_cpp(self$data_ptr)) }, + #' @description + #' Return group labels as an R vector + #' @return Group label data + get_group_labels = function() { + return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(rfx_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + #' @description #' Whether or not a dataset has group label indices #' @return True if group label vector is loaded, false otherwise diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd index a560f350..dfd7760f 100644 --- a/man/ForestDataset.Rd +++ b/man/ForestDataset.Rd @@ -24,6 +24,9 @@ weights are optional. \item \href{#method-ForestDataset-num_observations}{\code{ForestDataset$num_observations()}} \item \href{#method-ForestDataset-num_covariates}{\code{ForestDataset$num_covariates()}} \item \href{#method-ForestDataset-num_basis}{\code{ForestDataset$num_basis()}} +\item \href{#method-ForestDataset-get_covariates}{\code{ForestDataset$get_covariates()}} +\item \href{#method-ForestDataset-get_basis}{\code{ForestDataset$get_basis()}} +\item \href{#method-ForestDataset-get_variance_weights}{\code{ForestDataset$get_variance_weights()}} \item \href{#method-ForestDataset-has_basis}{\code{ForestDataset$has_basis()}} \item \href{#method-ForestDataset-has_variance_weights}{\code{ForestDataset$has_variance_weights()}} } @@ -128,6 +131,45 @@ Basis count } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_covariates}{}}} +\subsection{Method \code{get_covariates()}}{ +Return covariates as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_covariates()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Covariate data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_basis}{}}} +\subsection{Method \code{get_basis()}}{ +Return bases as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_variance_weights}{}}} +\subsection{Method \code{get_variance_weights()}}{ +Return variance weights as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_variance_weights()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Variance weight data +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestDataset-has_basis}{}}} \subsection{Method \code{has_basis()}}{ diff --git a/man/RandomEffectsDataset.Rd b/man/RandomEffectsDataset.Rd index 4bb4fdaa..e2da0227 100644 --- a/man/RandomEffectsDataset.Rd +++ b/man/RandomEffectsDataset.Rd @@ -22,6 +22,9 @@ bases, and variance weights. Variance weights are optional. \item \href{#method-RandomEffectsDataset-update_variance_weights}{\code{RandomEffectsDataset$update_variance_weights()}} \item \href{#method-RandomEffectsDataset-num_observations}{\code{RandomEffectsDataset$num_observations()}} \item \href{#method-RandomEffectsDataset-num_basis}{\code{RandomEffectsDataset$num_basis()}} +\item \href{#method-RandomEffectsDataset-get_group_labels}{\code{RandomEffectsDataset$get_group_labels()}} +\item \href{#method-RandomEffectsDataset-get_basis}{\code{RandomEffectsDataset$get_basis()}} +\item \href{#method-RandomEffectsDataset-get_variance_weights}{\code{RandomEffectsDataset$get_variance_weights()}} \item \href{#method-RandomEffectsDataset-has_group_labels}{\code{RandomEffectsDataset$has_group_labels()}} \item \href{#method-RandomEffectsDataset-has_basis}{\code{RandomEffectsDataset$has_basis()}} \item \href{#method-RandomEffectsDataset-has_variance_weights}{\code{RandomEffectsDataset$has_variance_weights()}} @@ -117,6 +120,45 @@ Basis vector count } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_group_labels}{}}} +\subsection{Method \code{get_group_labels()}}{ +Return group labels as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_group_labels()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Group label data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_basis}{}}} +\subsection{Method \code{get_basis()}}{ +Return bases as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_variance_weights}{}}} +\subsection{Method \code{get_variance_weights()}}{ +Return variance weights as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_variance_weights()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Variance weight data +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsDataset-has_group_labels}{}}} \subsection{Method \code{has_group_labels()}}{ diff --git a/src/R_data.cpp b/src/R_data.cpp index 596516f3..39b77ab3 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -1,3 +1,4 @@ +#include "cpp11/integers.hpp" #include #include #include @@ -106,6 +107,47 @@ void forest_dataset_add_weights_cpp(cpp11::external_pointer forest_dataset_get_covariates_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumCovariates(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->CovariateValue(i, j); + } + } + + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBasis(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output vector + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; +} + [[cpp11::register]] cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome) { // Unpack pointers to data and dimensions @@ -282,3 +324,36 @@ void rfx_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::integers output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->GroupId(i); + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBases(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; +} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 9d454531..ef98aac0 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -88,6 +88,27 @@ extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP END_CPP11 } // R_data.cpp +cpp11::writable::doubles_matrix<> forest_dataset_get_covariates_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_covariates_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_covariates_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_basis_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_basis_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_variance_weights_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_variance_weights_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome); extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) { BEGIN_CPP11 @@ -223,6 +244,27 @@ extern "C" SEXP _stochtree_rfx_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP we return R_NilValue; END_CPP11 } +// R_data.cpp +cpp11::writable::integers rfx_dataset_get_group_labels_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_group_labels_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_group_labels_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_basis_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_basis_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_variance_weights_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_variance_weights_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} // R_random_effects.cpp cpp11::external_pointer rfx_container_cpp(int num_components, int num_groups); extern "C" SEXP _stochtree_rfx_container_cpp(SEXP num_components, SEXP num_groups) { @@ -1579,6 +1621,9 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, + {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, + {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, @@ -1699,6 +1744,9 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, + {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, + {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R index 2753a3f9..1a833d97 100644 --- a/test/R/testthat/test-dataset.R +++ b/test/R/testthat/test-dataset.R @@ -25,6 +25,11 @@ test_that("ForestDataset can be constructed and updated", { expect_no_error( forest_dataset$update_variance_weights(new_variance_weights) ) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + expect_equal(covariates, forest_dataset$get_covariates()) + expect_equal(new_basis, forest_dataset$get_basis()) + expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) }) test_that("RandomEffectsDataset can be constructed and updated", { @@ -53,4 +58,9 @@ test_that("RandomEffectsDataset can be constructed and updated", { expect_no_error( rfx_dataset$update_variance_weights(new_variance_weights) ) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + expect_equal(group_ids, rfx_dataset$get_group_labels()) + expect_equal(new_basis, rfx_dataset$get_basis()) + expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) }) From a05c9afe4620f196960706bcafd132d4c74e9c7a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Sep 2025 22:30:16 -0500 Subject: [PATCH 8/8] Fixed R unit tests --- test/R/testthat/test-dataset.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R index 1a833d97..80f8d8e0 100644 --- a/test/R/testthat/test-dataset.R +++ b/test/R/testthat/test-dataset.R @@ -53,7 +53,7 @@ test_that("RandomEffectsDataset can be constructed and updated", { new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) new_variance_weights <- runif(n) expect_no_error( - rfx_dataset$update_basis(new_basis) + rfx_dataset$update_basis(new_rfx_basis) ) expect_no_error( rfx_dataset$update_variance_weights(new_variance_weights) @@ -61,6 +61,6 @@ test_that("RandomEffectsDataset can be constructed and updated", { # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights expect_equal(group_ids, rfx_dataset$get_group_labels()) - expect_equal(new_basis, rfx_dataset$get_basis()) + expect_equal(new_rfx_basis, rfx_dataset$get_basis()) expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) })