diff --git a/R/cpp11.R b/R/cpp11.R index a71a7722..d77c7472 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -36,10 +36,26 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) { invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis)) } +forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) { + invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate)) +} + forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) { invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights)) } +forest_dataset_get_covariates_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_covariates_cpp`, dataset_ptr) +} + +forest_dataset_get_basis_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_basis_cpp`, dataset_ptr) +} + +forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) { + .Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr) +} + create_column_vector_cpp <- function(outcome) { .Call(`_stochtree_create_column_vector_cpp`, outcome) } @@ -68,6 +84,22 @@ create_rfx_dataset_cpp <- function() { .Call(`_stochtree_create_rfx_dataset_cpp`) } +rfx_dataset_update_basis_cpp <- function(dataset_ptr, basis) { + invisible(.Call(`_stochtree_rfx_dataset_update_basis_cpp`, dataset_ptr, basis)) +} + +rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) { + invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate)) +} + +rfx_dataset_update_group_labels_cpp <- function(dataset_ptr, group_labels) { + invisible(.Call(`_stochtree_rfx_dataset_update_group_labels_cpp`, dataset_ptr, group_labels)) +} + +rfx_dataset_num_basis_cpp <- function(dataset) { + .Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset) +} + rfx_dataset_num_rows_cpp <- function(dataset) { .Call(`_stochtree_rfx_dataset_num_rows_cpp`, dataset) } @@ -96,6 +128,18 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) { invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights)) } +rfx_dataset_get_group_labels_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_group_labels_cpp`, dataset_ptr) +} + +rfx_dataset_get_basis_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_basis_cpp`, dataset_ptr) +} + +rfx_dataset_get_variance_weights_cpp <- function(dataset_ptr) { + .Call(`_stochtree_rfx_dataset_get_variance_weights_cpp`, dataset_ptr) +} + rfx_container_cpp <- function(num_components, num_groups) { .Call(`_stochtree_rfx_container_cpp`, num_components, num_groups) } diff --git a/R/data.R b/R/data.R index 4f35efc0..8bea2823 100644 --- a/R/data.R +++ b/R/data.R @@ -36,7 +36,16 @@ ForestDataset <- R6::R6Class( update_basis = function(basis) { stopifnot(self$has_basis()) forest_dataset_update_basis_cpp(self$data_ptr, basis) - }, + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. 
+ update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + }, #' @description #' Return number of observations in a `ForestDataset` object @@ -59,6 +68,27 @@ ForestDataset <- R6::R6Class( return(dataset_num_basis_cpp(self$data_ptr)) }, + #' @description + #' Return covariates as an R matrix + #' @return Covariate data + get_covariates = function() { + return(forest_dataset_get_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(forest_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + #' @description #' Whether or not a dataset has a basis matrix #' @return True if basis matrix is loaded, false otherwise @@ -190,11 +220,56 @@ RandomEffectsDataset <- R6::R6Class( } }, + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define random slopes / intercepts + update_basis = function(basis) { + stopifnot(self$has_basis()) + rfx_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. + update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + }, + #' @description #' Return number of observations in a `RandomEffectsDataset` object #' @return Observation count num_observations = function() { return(rfx_dataset_num_rows_cpp(self$data_ptr)) + }, + + #' @description + #' Return dimension of the basis matrix in a `RandomEffectsDataset` object + #' @return Basis vector count + num_basis = function() { + return(rfx_dataset_num_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return group labels as an R vector + #' @return Group label data + get_group_labels = function() { + return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(rfx_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) }, #' @description diff --git a/include/stochtree/data.h b/include/stochtree/data.h index cc62ab06..a6061f4b 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -497,6 +497,7 @@ class RandomEffectsDataset { */ void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); + num_basis_ = num_col; has_basis_ = true; } /*! @@ -509,6 +510,64 @@ class RandomEffectsDataset { var_weights_ = ColumnVector(data_ptr, num_row); has_var_weights_ = true; } + /*! 
+   * \brief Update the data in the internal basis matrix to new values stored in a raw double array
+   *
+   * \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix
+   * \param num_row Number of rows in the basis matrix
+   * \param num_col Number of columns in the basis matrix
+   * \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion
+   */
+  void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
+    CHECK(has_basis_);
+    CHECK_EQ(num_col, num_basis_);
+    // Copy data from R / Python process memory to Eigen matrix
+    double temp_value;
+    for (data_size_t i = 0; i < num_row; ++i) {
+      for (int j = 0; j < num_col; ++j) {
+        if (is_row_major){
+          // Numpy 2-d arrays are stored in "row major" order
+          temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_col) * i + j));
+        } else {
+          // R matrices are stored in "column major" order
+          temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_row) * j + i));
+        }
+        basis_.SetElement(i, j, temp_value);
+      }
+    }
+  }
+  /*!
+   * \brief Update the data in the internal variance weight vector to new values stored in a raw double array
+   *
+   * \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
+   * \param num_row Number of rows in the weight vector
+   * \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
+   */
+  void UpdateVarWeights(double* data_ptr, data_size_t num_row, bool exponentiate = true) {
+    CHECK(has_var_weights_);
+    // Copy data from R / Python process memory to Eigen vector
+    double temp_value;
+    for (data_size_t i = 0; i < num_row; ++i) {
+      if (exponentiate) temp_value = std::exp(static_cast<double>(*(data_ptr + i)));
+      else temp_value = static_cast<double>(*(data_ptr + i));
+      var_weights_.SetElement(i, temp_value);
+    }
+  }
+  /*!
+   * \brief Update a RandomEffectsDataset's group indices
+   *
+   * \param group_labels Vector of integer-coded group labels
+   * \param num_row Number of elements in the group label vector
+   */
+  void UpdateGroupLabels(std::vector<int32_t>& group_labels, data_size_t num_row) {
+    CHECK(has_group_labels_);
+    CHECK_EQ(this->NumObservations(), num_row);
+    // Copy data from R / Python process memory to internal vector
+    for (data_size_t i = 0; i < num_row; ++i) {
+      group_labels_[i] = group_labels[i];
+    }
+  }
   /*!
    * \brief Copy / load group indices for random effects
    *
@@ -570,6 +629,7 @@ class RandomEffectsDataset {
   ColumnMatrix basis_;
   ColumnVector var_weights_;
   std::vector<int32_t> group_labels_;
+  int num_basis_{0};
   bool has_basis_{false};
   bool has_var_weights_{false};
   bool has_group_labels_{false};
diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd
index 08ec47ad..dfd7760f 100644
--- a/man/ForestDataset.Rd
+++ b/man/ForestDataset.Rd
@@ -20,9 +20,13 @@ weights are optional.
\itemize{ \item \href{#method-ForestDataset-new}{\code{ForestDataset$new()}} \item \href{#method-ForestDataset-update_basis}{\code{ForestDataset$update_basis()}} +\item \href{#method-ForestDataset-update_variance_weights}{\code{ForestDataset$update_variance_weights()}} \item \href{#method-ForestDataset-num_observations}{\code{ForestDataset$num_observations()}} \item \href{#method-ForestDataset-num_covariates}{\code{ForestDataset$num_covariates()}} \item \href{#method-ForestDataset-num_basis}{\code{ForestDataset$num_basis()}} +\item \href{#method-ForestDataset-get_covariates}{\code{ForestDataset$get_covariates()}} +\item \href{#method-ForestDataset-get_basis}{\code{ForestDataset$get_basis()}} +\item \href{#method-ForestDataset-get_variance_weights}{\code{ForestDataset$get_variance_weights()}} \item \href{#method-ForestDataset-has_basis}{\code{ForestDataset$has_basis()}} \item \href{#method-ForestDataset-has_variance_weights}{\code{ForestDataset$has_variance_weights()}} } @@ -69,6 +73,25 @@ Update basis matrix in a dataset } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-update_variance_weights}{}}} +\subsection{Method \code{update_variance_weights()}}{ +Update variance_weights in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$update_variance_weights(variance_weights, exponentiate = F)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_weights}}{Updated vector of variance weights used to define individual variance / case weights} + +\item{\code{exponentiate}}{Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestDataset-num_observations}{}}} \subsection{Method \code{num_observations()}}{ @@ -108,6 +131,45 @@ Basis count } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_covariates}{}}} +\subsection{Method \code{get_covariates()}}{ +Return covariates as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_covariates()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Covariate data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_basis}{}}} +\subsection{Method \code{get_basis()}}{ +Return bases as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_variance_weights}{}}} +\subsection{Method \code{get_variance_weights()}}{ +Return variance weights as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_variance_weights()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Variance weight data +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestDataset-has_basis}{}}} \subsection{Method \code{has_basis()}}{ diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index bec10621..3574ffad 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -115,7 +115,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) \item{\code{global_model_config}}{GlobalModelConfig object containing global model parameters and settings} -\item{\code{num_threads}}{Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads.} +\item{\code{num_threads}}{Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to \code{1}, otherwise to the maximum number of available threads.} \item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{TRUE}.} diff --git a/man/RandomEffectsDataset.Rd b/man/RandomEffectsDataset.Rd index 2a516321..e2da0227 100644 --- a/man/RandomEffectsDataset.Rd +++ b/man/RandomEffectsDataset.Rd @@ -18,7 +18,13 @@ bases, and variance weights. Variance weights are optional. \subsection{Public methods}{ \itemize{ \item \href{#method-RandomEffectsDataset-new}{\code{RandomEffectsDataset$new()}} +\item \href{#method-RandomEffectsDataset-update_basis}{\code{RandomEffectsDataset$update_basis()}} +\item \href{#method-RandomEffectsDataset-update_variance_weights}{\code{RandomEffectsDataset$update_variance_weights()}} \item \href{#method-RandomEffectsDataset-num_observations}{\code{RandomEffectsDataset$num_observations()}} +\item \href{#method-RandomEffectsDataset-num_basis}{\code{RandomEffectsDataset$num_basis()}} +\item \href{#method-RandomEffectsDataset-get_group_labels}{\code{RandomEffectsDataset$get_group_labels()}} +\item \href{#method-RandomEffectsDataset-get_basis}{\code{RandomEffectsDataset$get_basis()}} +\item \href{#method-RandomEffectsDataset-get_variance_weights}{\code{RandomEffectsDataset$get_variance_weights()}} \item \href{#method-RandomEffectsDataset-has_group_labels}{\code{RandomEffectsDataset$has_group_labels()}} \item \href{#method-RandomEffectsDataset-has_basis}{\code{RandomEffectsDataset$has_basis()}} \item \href{#method-RandomEffectsDataset-has_variance_weights}{\code{RandomEffectsDataset$has_variance_weights()}} @@ -49,6 +55,45 @@ A new \code{RandomEffectsDataset} object. } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-update_basis}{}}} +\subsection{Method \code{update_basis()}}{ +Update basis matrix in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$update_basis(basis)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{basis}}{Updated matrix of bases used to define random slopes / intercepts} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-update_variance_weights}{}}} +\subsection{Method \code{update_variance_weights()}}{ +Update variance_weights in a dataset +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$update_variance_weights( + variance_weights, + exponentiate = F +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_weights}}{Updated vector of variance weights used to define individual variance / case weights} + +\item{\code{exponentiate}}{Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsDataset-num_observations}{}}} \subsection{Method \code{num_observations()}}{ @@ -62,6 +107,58 @@ Observation count } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-num_basis}{}}} +\subsection{Method \code{num_basis()}}{ +Return dimension of the basis matrix in a \code{RandomEffectsDataset} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$num_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis vector count +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_group_labels}{}}} +\subsection{Method \code{get_group_labels()}}{ +Return group labels as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_group_labels()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Group label data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_basis}{}}} +\subsection{Method \code{get_basis()}}{ +Return bases as an R matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_basis()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Basis data +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsDataset-get_variance_weights}{}}} +\subsection{Method \code{get_variance_weights()}}{ +Return variance weights as an R vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsDataset$get_variance_weights()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Variance weight data +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsDataset-has_group_labels}{}}} \subsection{Method \code{has_group_labels()}}{ diff --git a/man/bcf.Rd b/man/bcf.Rd index ed7cd238..01e5fab8 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -97,6 +97,7 @@ that were not in the training set.} \item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. \item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} \item{prognostic_forest_params}{(Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. diff --git a/src/R_data.cpp b/src/R_data.cpp index 0f495436..39b77ab3 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -1,3 +1,4 @@ +#include "cpp11/integers.hpp" #include #include #include @@ -84,6 +85,17 @@ void forest_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); + + // Unprotect pointers to R data + UNPROTECT(1); +} + [[cpp11::register]] void forest_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights) { // Add weights @@ -95,6 +107,47 @@ void forest_dataset_add_weights_cpp(cpp11::external_pointer forest_dataset_get_covariates_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumCovariates(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->CovariateValue(i, j); + } + } + + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output matrix + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBasis(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { + // Initialize output vector + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; +} + [[cpp11::register]] cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome) { // Unpack pointers to data and dimensions @@ -180,6 +233,45 @@ cpp11::external_pointer create_rfx_dataset_cpp( return 
cpp11::external_pointer(dataset_ptr_.release()); } +[[cpp11::register]] +void rfx_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis) { + // TODO: add handling code on the R side to ensure matrices are column-major + bool row_major{false}; + + // Add basis + StochTree::data_size_t n = basis.nrow(); + int num_basis = basis.ncol(); + double* basis_data_ptr = REAL(PROTECT(basis)); + dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major); + + // Unprotect pointers to R data + UNPROTECT(1); +} + +[[cpp11::register]] +void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate) { + // Add weights + StochTree::data_size_t n = weights.size(); + double* weight_data_ptr = REAL(PROTECT(weights)); + dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate); + + // Unprotect pointers to R data + UNPROTECT(1); +} + +[[cpp11::register]] +void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels) { + // Update group labels + int n = group_labels.size(); + std::vector group_labels_vec(group_labels.begin(), group_labels.end()); + dataset_ptr->UpdateGroupLabels(group_labels_vec, n); +} + +[[cpp11::register]] +int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset) { + return dataset->NumBases(); +} + [[cpp11::register]] int rfx_dataset_num_rows_cpp(cpp11::external_pointer dataset) { return dataset->NumObservations(); @@ -232,3 +324,36 @@ void rfx_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::integers output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->GroupId(i); + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + int num_col = dataset_ptr->NumBases(); + cpp11::writable::doubles_matrix<> output(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + output(i, j) = dataset_ptr->BasisValue(i, j); + } + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr) { + int num_row = dataset_ptr->NumObservations(); + cpp11::writable::doubles output(num_row); + for (int i = 0; i < num_row; i++) { + output.at(i) = dataset_ptr->VarWeightValue(i); + } + return output; +} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 67f79ab2..ef98aac0 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -72,6 +72,14 @@ extern "C" SEXP _stochtree_forest_dataset_update_basis_cpp(SEXP dataset_ptr, SEX END_CPP11 } // R_data.cpp +void forest_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate); +extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) { + BEGIN_CPP11 + forest_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights), cpp11::as_cpp>(exponentiate)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp void forest_dataset_add_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights); extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP weights) { BEGIN_CPP11 @@ -80,6 +88,27 @@ extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP END_CPP11 } // R_data.cpp 
+cpp11::writable::doubles_matrix<> forest_dataset_get_covariates_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_covariates_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_covariates_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> forest_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_basis_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_basis_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_forest_dataset_get_variance_weights_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_variance_weights_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome); extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) { BEGIN_CPP11 @@ -133,6 +162,37 @@ extern "C" SEXP _stochtree_create_rfx_dataset_cpp() { END_CPP11 } // R_data.cpp +void rfx_dataset_update_basis_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles_matrix<> basis); +extern "C" SEXP _stochtree_rfx_dataset_update_basis_cpp(SEXP dataset_ptr, SEXP basis) { + BEGIN_CPP11 + rfx_dataset_update_basis_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>>(basis)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer dataset_ptr, cpp11::doubles weights, bool exponentiate); +extern "C" SEXP _stochtree_rfx_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) { + BEGIN_CPP11 + rfx_dataset_update_var_weights_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(weights), cpp11::as_cpp>(exponentiate)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer dataset_ptr, cpp11::integers group_labels); +extern "C" SEXP _stochtree_rfx_dataset_update_group_labels_cpp(SEXP dataset_ptr, SEXP group_labels) { + BEGIN_CPP11 + rfx_dataset_update_group_labels_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(group_labels)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +int rfx_dataset_num_basis_cpp(cpp11::external_pointer dataset); +extern "C" SEXP _stochtree_rfx_dataset_num_basis_cpp(SEXP dataset) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_num_basis_cpp(cpp11::as_cpp>>(dataset))); + END_CPP11 +} +// R_data.cpp int rfx_dataset_num_rows_cpp(cpp11::external_pointer dataset); extern "C" SEXP _stochtree_rfx_dataset_num_rows_cpp(SEXP dataset) { BEGIN_CPP11 @@ -184,6 +244,27 @@ extern "C" SEXP _stochtree_rfx_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP we return R_NilValue; END_CPP11 } +// R_data.cpp +cpp11::writable::integers rfx_dataset_get_group_labels_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_group_labels_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_group_labels_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> rfx_dataset_get_basis_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_basis_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_basis_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles 
rfx_dataset_get_variance_weights_cpp(cpp11::external_pointer dataset_ptr); +extern "C" SEXP _stochtree_rfx_dataset_get_variance_weights_cpp(SEXP dataset_ptr) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_dataset_get_variance_weights_cpp(cpp11::as_cpp>>(dataset_ptr))); + END_CPP11 +} // R_random_effects.cpp cpp11::external_pointer rfx_container_cpp(int num_components, int num_groups); extern "C" SEXP _stochtree_rfx_container_cpp(SEXP num_components, SEXP num_groups) { @@ -1540,7 +1621,11 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, + {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, + {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, @@ -1659,10 +1744,17 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, + {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, + {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, + {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, diff --git a/src/py_stochtree.cpp 
b/src/py_stochtree.cpp index 32bbd707..950caeb8 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -67,6 +67,56 @@ class ForestDatasetCpp { dataset_->AddVarianceWeights(data_ptr, num_row); } + void UpdateVarianceWeights(py::array_t weight_vector, data_size_t num_row, bool exponentiate) { + // Extract pointer to contiguous block of memory + double* data_ptr = static_cast(weight_vector.mutable_data()); + + // Load covariates + dataset_->UpdateVarWeights(data_ptr, num_row, exponentiate); + } + + py::array_t GetCovariates() { + // Initialize n x p numpy array to store the covariates + data_size_t n = dataset_->NumObservations(); + int num_covariates = dataset_->NumCovariates(); + auto result = py::array_t(py::detail::any_container({n, num_covariates})); + auto accessor = result.mutable_unchecked<2>(); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < num_covariates; j++) { + accessor(i,j) = dataset_->CovariateValue(i,j); + } + } + + return result; + } + + py::array_t GetBasis() { + // Initialize n x k numpy array to store the basis + data_size_t n = dataset_->NumObservations(); + int num_basis = dataset_->NumBasis(); + auto result = py::array_t(py::detail::any_container({n, num_basis})); + auto accessor = result.mutable_unchecked<2>(); + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < num_basis; j++) { + accessor(i,j) = dataset_->BasisValue(i,j); + } + } + + return result; + } + + py::array_t GetVarianceWeights() { + // Initialize n x 1 numpy array to store the variance weights + data_size_t n = dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({n})); + auto accessor = result.mutable_unchecked<1>(); + for (size_t i = 0; i < n; i++) { + accessor(i) = dataset_->VarWeightValue(i); + } + + return result; + } + data_size_t NumRows() { return dataset_->NumObservations(); } @@ -1289,6 +1339,52 @@ class RandomEffectsDatasetCpp { double* weight_data_ptr = static_cast(weights.mutable_data()); rfx_dataset_->AddVarianceWeights(weight_data_ptr, num_row); } + void UpdateBasis(py::array_t basis, data_size_t num_row, int num_col, bool row_major) { + double* basis_data_ptr = static_cast(basis.mutable_data()); + rfx_dataset_->UpdateBasis(basis_data_ptr, num_row, num_col, row_major); + } + void UpdateVarianceWeights(py::array_t weights, data_size_t num_row, bool exponentiate) { + double* weight_data_ptr = static_cast(weights.mutable_data()); + rfx_dataset_->UpdateVarWeights(weight_data_ptr, num_row, exponentiate); + } + void UpdateGroupLabels(py::array_t group_labels, data_size_t num_row) { + std::vector group_labels_vec(num_row); + auto accessor = group_labels.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + group_labels_vec[i] = accessor(i); + } + rfx_dataset_->UpdateGroupLabels(group_labels_vec, num_row); + } + py::array_t GetBasis() { + int num_row = rfx_dataset_->NumObservations(); + int num_col = rfx_dataset_->NumBases(); + auto result = py::array_t(py::detail::any_container({num_row, num_col})); + auto accessor = result.mutable_unchecked<2>(); + for (py::ssize_t i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + accessor(i,j) = rfx_dataset_->BasisValue(i,j); + } + } + return result; + } + py::array_t GetVarianceWeights() { + int num_row = rfx_dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({num_row})); + auto accessor = result.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + accessor(i) = rfx_dataset_->VarWeightValue(i); + } + return result; + } + 
py::array_t GetGroupLabels() { + int num_row = rfx_dataset_->NumObservations(); + auto result = py::array_t(py::detail::any_container({num_row})); + auto accessor = result.mutable_unchecked<1>(); + for (py::ssize_t i = 0; i < num_row; i++) { + accessor(i) = rfx_dataset_->GroupId(i); + } + return result; + } bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();} bool HasBasis() {return rfx_dataset_->HasBasis();} bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();} @@ -2043,9 +2139,13 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("AddBasis", &ForestDatasetCpp::AddBasis) .def("UpdateBasis", &ForestDatasetCpp::UpdateBasis) .def("AddVarianceWeights", &ForestDatasetCpp::AddVarianceWeights) + .def("UpdateVarianceWeights", &ForestDatasetCpp::UpdateVarianceWeights) .def("NumRows", &ForestDatasetCpp::NumRows) .def("NumCovariates", &ForestDatasetCpp::NumCovariates) .def("NumBasis", &ForestDatasetCpp::NumBasis) + .def("GetCovariates", &ForestDatasetCpp::GetCovariates) + .def("GetBasis", &ForestDatasetCpp::GetBasis) + .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) .def("HasBasis", &ForestDatasetCpp::HasBasis) .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights); @@ -2176,6 +2276,12 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("AddGroupLabels", &RandomEffectsDatasetCpp::AddGroupLabels) .def("AddBasis", &RandomEffectsDatasetCpp::AddBasis) .def("AddVarianceWeights", &RandomEffectsDatasetCpp::AddVarianceWeights) + .def("UpdateGroupLabels", &RandomEffectsDatasetCpp::UpdateGroupLabels) + .def("UpdateBasis", &RandomEffectsDatasetCpp::UpdateBasis) + .def("UpdateVarianceWeights", &RandomEffectsDatasetCpp::UpdateVarianceWeights) + .def("GetGroupLabels", &RandomEffectsDatasetCpp::GetGroupLabels) + .def("GetBasis", &RandomEffectsDatasetCpp::GetBasis) + .def("GetVarianceWeights", &RandomEffectsDatasetCpp::GetVarianceWeights) .def("HasGroupLabels", &RandomEffectsDatasetCpp::HasGroupLabels) .def("HasBasis", &RandomEffectsDatasetCpp::HasBasis) .def("HasVarianceWeights", &RandomEffectsDatasetCpp::HasVarianceWeights); diff --git a/stochtree/data.py b/stochtree/data.py index 8cbe76e0..4e40a282 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -58,9 +58,22 @@ def update_basis(self, basis: np.array): basis : np.array Numpy array of basis vectors. """ - basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis + if not self.has_basis(): + raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.") + if not isinstance(basis, np.ndarray): + raise ValueError("basis must be a numpy array.") + if np.ndim(basis) == 1: + basis_ = np.expand_dims(basis, 1) + elif np.ndim(basis) == 2: + basis_ = basis + else: + raise ValueError("basis must be a numpy array with one or two dimension.") n, p = basis_.shape basis_rowmajor = np.ascontiguousarray(basis_) + if self.num_basis() != p: + raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).") + if self.num_observations() != n: + raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).") self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True) def add_variance_weights(self, variance_weights: np.array): @@ -72,8 +85,38 @@ def add_variance_weights(self, variance_weights: np.array): variance_weights : np.array Univariate numpy array of variance weights. 
""" - n = variance_weights.size - self.dataset_cpp.AddVarianceWeights(variance_weights, n) + if not isinstance(variance_weights, np.ndarray): + raise ValueError("variance_weights must be a numpy array.") + variance_weights_ = np.squeeze(variance_weights) + n = variance_weights_.size + if variance_weights_.ndim != 1: + raise ValueError("variance_weights must be a 1-dimensional numpy array.") + + self.dataset_cpp.AddVarianceWeights(variance_weights_, n) + + def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): + """ + Update variance weights in a dataset. Allows users to build an ensemble that depends on + variance weights that are updated throughout the sampler. + + Parameters + ---------- + variance_weights : np.array + Univariate numpy array of variance weights. + exponentiate : bool + Whether to exponentiate the variance weights before storing them in the dataset. + """ + if not self.has_variance_weights(): + raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + if not isinstance(variance_weights, np.ndarray): + raise ValueError("variance_weights must be a numpy array.") + variance_weights_ = np.squeeze(variance_weights) + n = variance_weights_.size + if variance_weights_.ndim != 1: + raise ValueError("variance_weights must be a 1-dimensional numpy array.") + if self.num_observations() != n: + raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate) def num_observations(self) -> int: """ @@ -107,6 +150,39 @@ def num_basis(self) -> int: Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis """ return self.dataset_cpp.NumBasis() + + def get_covariates(self) -> np.array: + """ + Return the covariates in a Dataset as a numpy array + + Returns + ------- + np.array + Covariate data + """ + return self.dataset_cpp.GetCovariates() + + def get_basis(self) -> np.array: + """ + Return the bases in a Dataset as a numpy array + + Returns + ------- + np.array + Basis data + """ + return self.dataset_cpp.GetBasis() + + def get_variance_weights(self) -> np.array: + """ + Return the variance weights in a Dataset as a numpy array + + Returns + ------- + np.array + Variance weights data + """ + return self.dataset_cpp.GetVarianceWeights() def has_basis(self) -> bool: """ diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 18144044..6c4093d4 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -40,6 +40,23 @@ def add_group_labels(self, group_labels: np.array): n = group_labels_.shape[0] self.rfx_dataset_cpp.AddGroupLabels(group_labels_, n) + def update_group_labels(self, group_labels: np.array): + """ + Update group labels in a dataset + + Parameters + ---------- + group_labels : np.array + One-dimensional numpy array of group labels. 
+        """
+        group_labels_ = np.squeeze(group_labels)
+        if group_labels_.ndim > 1:
+            raise ValueError(
+                "group_labels must be a one-dimensional numpy array of group indices"
+            )
+        n = group_labels_.shape[0]
+        self.rfx_dataset_cpp.UpdateGroupLabels(group_labels_, n)
+
     def add_basis(self, basis: np.array):
         """
         Add basis matrix to a dataset
@@ -93,6 +110,65 @@ def add_variance_weights(self, variance_weights: np.array):
             )
         n = variance_weights_.shape[0]
         self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n)
+
+    def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False):
+        """
+        Update variance weights in a dataset. Allows users to build an ensemble that depends on
+        variance weights that are updated throughout the sampler.
+
+        Parameters
+        ----------
+        variance_weights : np.array
+            Univariate numpy array of variance weights.
+        exponentiate : bool
+            Whether to exponentiate the variance weights before storing them in the dataset.
+        """
+        if not self.has_variance_weights():
+            raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
+        if not isinstance(variance_weights, np.ndarray):
+            raise ValueError("variance_weights must be a numpy array.")
+        variance_weights_ = np.squeeze(variance_weights)
+        if variance_weights_.ndim > 1:
+            raise ValueError(
+                "variance_weights must be a one-dimensional numpy array of variance weights"
+            )
+        n = variance_weights_.shape[0]
+        if self.num_observations() != n:
+            raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
+        self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate)
+
+    def get_group_labels(self) -> np.array:
+        """
+        Return the group labels in a RandomEffectsDataset as a numpy array
+
+        Returns
+        -------
+        np.array
+            One-dimensional numpy array of group labels.
+        """
+        return self.rfx_dataset_cpp.GetGroupLabels()
+
+    def get_basis(self) -> np.array:
+        """
+        Return the bases in a RandomEffectsDataset as a numpy array
+
+        Returns
+        -------
+        np.array
+            Two-dimensional numpy array of basis vectors.
+        """
+        return self.rfx_dataset_cpp.GetBasis()
+
+    def get_variance_weights(self) -> np.array:
+        """
+        Return the variance weights in a RandomEffectsDataset as a numpy array
+
+        Returns
+        -------
+        np.array
+            One-dimensional numpy array of variance weights.
+ """ + return self.rfx_dataset_cpp.GetVarianceWeights() def num_observations(self) -> int: """ diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R new file mode 100644 index 00000000..80f8d8e0 --- /dev/null +++ b/test/R/testthat/test-dataset.R @@ -0,0 +1,66 @@ +test_that("ForestDataset can be constructed and updated", { + # Generate data + n <- 20 + num_covariates <- 10 + num_basis <- 5 + covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) + basis <- matrix(runif(n * num_basis), ncol = num_basis) + variance_weights <- runif(n) + + # Copy data to a ForestDataset object + forest_dataset <- createForestDataset(covariates, basis, variance_weights) + + # Run first round of expectations + expect_equal(forest_dataset$num_observations(), n) + expect_equal(forest_dataset$num_covariates(), num_covariates) + expect_equal(forest_dataset$num_basis(), num_basis) + expect_equal(forest_dataset$has_variance_weights(), T) + + # Update data + new_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + forest_dataset$update_basis(new_basis) + ) + expect_no_error( + forest_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + expect_equal(covariates, forest_dataset$get_covariates()) + expect_equal(new_basis, forest_dataset$get_basis()) + expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) +}) + +test_that("RandomEffectsDataset can be constructed and updated", { + # Generate data + n <- 20 + num_groups <- 4 + num_basis <- 5 + group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) + rfx_basis <- cbind(1, matrix(runif(n*(num_basis-1)), ncol = (num_basis-1))) + variance_weights <- runif(n) + + # Copy data to a RandomEffectsDataset object + rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis, variance_weights) + + # Run first round of expectations + expect_equal(rfx_dataset$num_observations(), n) + expect_equal(rfx_dataset$num_basis(), num_basis) + expect_equal(rfx_dataset$has_variance_weights(), T) + + # Update data + new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + rfx_dataset$update_basis(new_rfx_basis) + ) + expect_no_error( + rfx_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + expect_equal(group_ids, rfx_dataset$get_group_labels()) + expect_equal(new_rfx_basis, rfx_dataset$get_basis()) + expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) +}) diff --git a/test/python/test_data.py b/test/python/test_data.py new file mode 100644 index 00000000..09c75154 --- /dev/null +++ b/test/python/test_data.py @@ -0,0 +1,72 @@ +import numpy as np + +from stochtree import Dataset, RandomEffectsDataset + +class TestDataset: + def test_dataset_update(self): + # Generate data + n = 20 + num_covariates = 10 + num_basis = 5 + rng = np.random.default_rng() + covariates = rng.uniform(0, 1, size=(n, num_covariates)) + basis = rng.uniform(0, 1, size=(n, num_basis)) + variance_weights = rng.uniform(0, 1, size=n) + + # Construct dataset + forest_dataset = Dataset() + forest_dataset.add_covariates(covariates) + forest_dataset.add_basis(basis) + forest_dataset.add_variance_weights(variance_weights) + assert forest_dataset.num_observations() == n + assert 
forest_dataset.num_covariates() == num_covariates + assert forest_dataset.num_basis() == num_basis + assert forest_dataset.has_variance_weights() + + # Update dataset + new_basis = rng.uniform(0, 1, size=(n, num_basis)) + new_variance_weights = rng.uniform(0, 1, size=n) + with np.testing.assert_no_warnings(): + forest_dataset.update_basis(new_basis) + forest_dataset.update_variance_weights(new_variance_weights) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + np.testing.assert_array_equal(forest_dataset.get_covariates(), covariates) + np.testing.assert_array_equal(forest_dataset.get_basis(), new_basis) + np.testing.assert_array_equal(forest_dataset.get_variance_weights(), new_variance_weights) + +class TestRFXDataset: + def test_rfx_dataset_update(self): + # Generate data + n = 20 + num_groups = 4 + num_basis = 5 + rng = np.random.default_rng() + group_labels = rng.choice(num_groups, size=n) + basis = np.empty((n, num_basis)) + basis[:, 0] = 1.0 + if num_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_basis - 1)) + variance_weights = rng.uniform(0, 1, size=n) + + # Construct dataset + rfx_dataset = RandomEffectsDataset() + rfx_dataset.add_group_labels(group_labels) + rfx_dataset.add_basis(basis) + rfx_dataset.add_variance_weights(variance_weights) + assert rfx_dataset.num_observations() == n + assert rfx_dataset.num_basis() == num_basis + assert rfx_dataset.has_variance_weights() + + # Update dataset + new_basis = rng.uniform(0, 1, size=(n, num_basis)) + new_variance_weights = rng.uniform(0, 1, size=n) + with np.testing.assert_no_warnings(): + rfx_dataset.update_basis(new_basis) + rfx_dataset.update_variance_weights(new_variance_weights) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + np.testing.assert_array_equal(rfx_dataset.get_group_labels(), group_labels) + np.testing.assert_array_equal(rfx_dataset.get_basis(), new_basis) + np.testing.assert_array_equal(rfx_dataset.get_variance_weights(), new_variance_weights) + diff --git a/tools/debug/dataset_demo.R b/tools/debug/dataset_demo.R new file mode 100644 index 00000000..3510f0ab --- /dev/null +++ b/tools/debug/dataset_demo.R @@ -0,0 +1,39 @@ +# Load libraries +library(stochtree) + +# Generate "forest" data +n <- 20 +num_covariates <- 10 +num_basis <- 5 +covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) +basis <- matrix(runif(n * num_basis), ncol = num_basis) +variance_weights <- runif(n) + +# Create a ForestDataset object +forest_dataset <- createForestDataset(covariates, basis, variance_weights) + +# Update forest_dataset's basis +new_basis <- matrix(runif(n * num_basis), ncol = num_basis) +forest_dataset$update_basis(new_basis) + +# Update forest_dataset's variance_weights +new_variance_weights <- runif(n) +forest_dataset$update_variance_weights(new_variance_weights) + +# Generate RFX data +group_ids <- sample(as.integer(c(1,2)), size = n, replace = T) +rfx_basis <- cbind(1, runif(n)) + +# Create a RandomEffectsDataset object +rfx_dataset <- createRandomEffectsDataset( + group_labels = group_ids, basis = rfx_basis, + variance_weights = variance_weights +) + +# Update rfx_dataset's basis +new_rfx_basis <- cbind(1, runif(n)) +rfx_dataset$update_basis(new_rfx_basis) + +# Update rfx_dataset's variance weights +rfx_dataset$update_variance_weights(new_variance_weights) +
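
A minimal follow-on R sketch (supplementary, not part of the diff above) that continues tools/debug/dataset_demo.R and exercises the new accessor methods plus the optional `exponentiate` argument to `update_variance_weights`; it reuses the `forest_dataset`, `rfx_dataset`, `covariates`, `group_ids`, `new_basis`, `new_rfx_basis`, `new_variance_weights`, and `n` objects defined in that demo.

# Read data back out of the ForestDataset as base R objects
stopifnot(isTRUE(all.equal(forest_dataset$get_covariates(), covariates)))
stopifnot(isTRUE(all.equal(forest_dataset$get_basis(), new_basis)))
stopifnot(isTRUE(all.equal(forest_dataset$get_variance_weights(), new_variance_weights)))

# Variance weights can optionally be exponentiated on the way in, which is
# convenient when a sampler keeps the weights on the log scale
log_variance_weights <- log(runif(n))
forest_dataset$update_variance_weights(log_variance_weights, exponentiate = TRUE)
stopifnot(isTRUE(all.equal(forest_dataset$get_variance_weights(), exp(log_variance_weights))))

# The RandomEffectsDataset accessors work the same way
stopifnot(isTRUE(all.equal(rfx_dataset$get_group_labels(), group_ids)))
stopifnot(isTRUE(all.equal(rfx_dataset$get_basis(), new_rfx_basis)))
stopifnot(isTRUE(all.equal(rfx_dataset$get_variance_weights(), new_variance_weights)))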