Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,26 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis))
}

forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
}

forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights))
}

forest_dataset_get_covariates_cpp <- function(dataset_ptr) {
.Call(`_stochtree_forest_dataset_get_covariates_cpp`, dataset_ptr)
}

forest_dataset_get_basis_cpp <- function(dataset_ptr) {
.Call(`_stochtree_forest_dataset_get_basis_cpp`, dataset_ptr)
}

forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
.Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr)
}

create_column_vector_cpp <- function(outcome) {
.Call(`_stochtree_create_column_vector_cpp`, outcome)
}
Expand Down Expand Up @@ -68,6 +84,22 @@ create_rfx_dataset_cpp <- function() {
.Call(`_stochtree_create_rfx_dataset_cpp`)
}

rfx_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
invisible(.Call(`_stochtree_rfx_dataset_update_basis_cpp`, dataset_ptr, basis))
}

rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
}

rfx_dataset_update_group_labels_cpp <- function(dataset_ptr, group_labels) {
invisible(.Call(`_stochtree_rfx_dataset_update_group_labels_cpp`, dataset_ptr, group_labels))
}

rfx_dataset_num_basis_cpp <- function(dataset) {
.Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset)
}

rfx_dataset_num_rows_cpp <- function(dataset) {
.Call(`_stochtree_rfx_dataset_num_rows_cpp`, dataset)
}
Expand Down Expand Up @@ -96,6 +128,18 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights))
}

rfx_dataset_get_group_labels_cpp <- function(dataset_ptr) {
.Call(`_stochtree_rfx_dataset_get_group_labels_cpp`, dataset_ptr)
}

rfx_dataset_get_basis_cpp <- function(dataset_ptr) {
.Call(`_stochtree_rfx_dataset_get_basis_cpp`, dataset_ptr)
}

rfx_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
.Call(`_stochtree_rfx_dataset_get_variance_weights_cpp`, dataset_ptr)
}

rfx_container_cpp <- function(num_components, num_groups) {
.Call(`_stochtree_rfx_container_cpp`, num_components, num_groups)
}
Expand Down
77 changes: 76 additions & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@ ForestDataset <- R6::R6Class(
update_basis = function(basis) {
stopifnot(self$has_basis())
forest_dataset_update_basis_cpp(self$data_ptr, basis)
},
},

#' @description
#' Update variance_weights in a dataset
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F.
update_variance_weights = function(variance_weights, exponentiate = F) {
stopifnot(self$has_variance_weights())
forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
},

#' @description
#' Return number of observations in a `ForestDataset` object
Expand All @@ -59,6 +68,27 @@ ForestDataset <- R6::R6Class(
return(dataset_num_basis_cpp(self$data_ptr))
},

#' @description
#' Return covariates as an R matrix
#' @return Covariate data
get_covariates = function() {
return(forest_dataset_get_covariates_cpp(self$data_ptr))
},

#' @description
#' Return bases as an R matrix
#' @return Basis data
get_basis = function() {
return(forest_dataset_get_basis_cpp(self$data_ptr))
},

#' @description
#' Return variance weights as an R vector
#' @return Variance weight data
get_variance_weights = function() {
return(forest_dataset_get_variance_weights_cpp(self$data_ptr))
},

#' @description
#' Whether or not a dataset has a basis matrix
#' @return True if basis matrix is loaded, false otherwise
Expand Down Expand Up @@ -190,11 +220,56 @@ RandomEffectsDataset <- R6::R6Class(
}
},

#' @description
#' Update basis matrix in a dataset
#' @param basis Updated matrix of bases used to define random slopes / intercepts
update_basis = function(basis) {
stopifnot(self$has_basis())
rfx_dataset_update_basis_cpp(self$data_ptr, basis)
},

#' @description
#' Update variance_weights in a dataset
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F.
update_variance_weights = function(variance_weights, exponentiate = F) {
stopifnot(self$has_variance_weights())
rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
},

#' @description
#' Return number of observations in a `RandomEffectsDataset` object
#' @return Observation count
num_observations = function() {
return(rfx_dataset_num_rows_cpp(self$data_ptr))
},

#' @description
#' Return dimension of the basis matrix in a `RandomEffectsDataset` object
#' @return Basis vector count
num_basis = function() {
return(rfx_dataset_num_basis_cpp(self$data_ptr))
},

#' @description
#' Return group labels as an R vector
#' @return Group label data
get_group_labels = function() {
return(rfx_dataset_get_group_labels_cpp(self$data_ptr))
},

#' @description
#' Return bases as an R matrix
#' @return Basis data
get_basis = function() {
return(rfx_dataset_get_basis_cpp(self$data_ptr))
},

#' @description
#' Return variance weights as an R vector
#' @return Variance weight data
get_variance_weights = function() {
return(rfx_dataset_get_variance_weights_cpp(self$data_ptr))
},

#' @description
Expand Down
60 changes: 60 additions & 0 deletions include/stochtree/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ class RandomEffectsDataset {
*/
void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major);
num_basis_ = num_col;
has_basis_ = true;
}
/*!
Expand All @@ -509,6 +510,64 @@ class RandomEffectsDataset {
var_weights_ = ColumnVector(data_ptr, num_row);
has_var_weights_ = true;
}
/*!
* \brief Update the data in the internal basis matrix to new values stored in a raw double array
*
* \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix
* \param num_row Number of rows in the basis matrix
* \param num_col Number of columns in the basis matrix
* \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion
*/
void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
CHECK(has_basis_);
CHECK_EQ(num_col, num_basis_);
// Copy data from R / Python process memory to Eigen matrix
double temp_value;
for (data_size_t i = 0; i < num_row; ++i) {
for (int j = 0; j < num_col; ++j) {
if (is_row_major){
// Numpy 2-d arrays are stored in "row major" order
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_col) * i + j));
} else {
// R matrices are stored in "column major" order
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_row) * j + i));
}
basis_.SetElement(i, j, temp_value);
}
}
}
/*!
* \brief Update the data in the internal variance weight vector to new values stored in a raw double array
*
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
* \param num_row Number of rows in the weight vector
* \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
*/
void UpdateVarWeights(double* data_ptr, data_size_t num_row, bool exponentiate = true) {
CHECK(has_var_weights_);
// Copy data from R / Python process memory to Eigen vector
double temp_value;
for (data_size_t i = 0; i < num_row; ++i) {
if (exponentiate) temp_value = std::exp(static_cast<double>(*(data_ptr + i)));
else temp_value = static_cast<double>(*(data_ptr + i));
var_weights_.SetElement(i, temp_value);
}
}
/*!
* \brief Update a RandomEffectsDataset's group indices
*
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
* \param num_row Number of rows in the weight vector
* \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
*/
void UpdateGroupLabels(std::vector<int32_t>& group_labels, data_size_t num_row) {
CHECK(has_group_labels_);
CHECK_EQ(this->NumObservations(), num_row)
// Copy data from R / Python process memory to internal vector
for (data_size_t i = 0; i < num_row; ++i) {
group_labels_[i] = group_labels[i];
}
}
/*!
* \brief Copy / load group indices for random effects
*
Expand Down Expand Up @@ -570,6 +629,7 @@ class RandomEffectsDataset {
ColumnMatrix basis_;
ColumnVector var_weights_;
std::vector<int32_t> group_labels_;
int num_basis_{0};
bool has_basis_{false};
bool has_var_weights_{false};
bool has_group_labels_{false};
Expand Down
62 changes: 62 additions & 0 deletions man/ForestDataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/ForestModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading