From c956b7943fa111e03a8a0bed472d5a4d3e7a9065 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 17 Nov 2025 22:28:51 -0600 Subject: [PATCH] Updated Python data interface for parity with the R interface --- src/py_stochtree.cpp | 18 +++++++++++++++++- stochtree/data.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index b2fc4ded..05bcc876 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -182,6 +182,20 @@ class ResidualCpp { residual_->OverwriteData(data_ptr, num_row); } + void AddToData(py::array_t update_vector, data_size_t num_row) { + // Extract pointer to contiguous block of memory + double* data_ptr = static_cast(update_vector.mutable_data()); + // Add to data in residual_ + residual_->AddToData(data_ptr, num_row); + } + + void SubtractFromData(py::array_t update_vector, data_size_t num_row) { + // Extract pointer to contiguous block of memory + double* data_ptr = static_cast(update_vector.mutable_data()); + // Subtract from data in residual_ + residual_->SubtractFromData(data_ptr, num_row); + } + private: std::unique_ptr residual_; }; @@ -2224,7 +2238,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { py::class_(m, "ResidualCpp") .def(py::init,data_size_t>()) .def("GetResidualArray", &ResidualCpp::GetResidualArray) - .def("ReplaceData", &ResidualCpp::ReplaceData); + .def("ReplaceData", &ResidualCpp::ReplaceData) + .def("AddToData", &ResidualCpp::AddToData) + .def("SubtractFromData", &ResidualCpp::SubtractFromData); py::class_(m, "RngCpp") .def(py::init()); diff --git a/stochtree/data.py b/stochtree/data.py index da3ca735..3aaf740c 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -264,3 +264,37 @@ def update_data(self, new_vector: np.array) -> None: """ n = new_vector.size self.residual_cpp.ReplaceData(new_vector, n) + + def add_vector(self, update_vector: np.array) -> None: + """ + Update the current state of the outcome (i.e. partial residual) data by adding each element of `update_vector` + + Parameters + ---------- + update_vector : np.array + Univariate numpy array of values to add to the current residual. + """ + if not isinstance(update_vector, np.ndarray): + raise ValueError("update_vector must be a numpy array.") + update_vector_ = np.squeeze(update_vector) + if not update_vector_.ndim == 1: + raise ValueError("update_vector must be a 1-dimensional numpy array.") + n = update_vector_.size + self.residual_cpp.AddToData(update_vector_, n) + + def subtract_vector(self, update_vector: np.array) -> None: + """ + Update the current state of the outcome (i.e. partial residual) data by subtracting each element of `update_vector` + + Parameters + ---------- + update_vector : np.array + Univariate numpy array of values to subtracted from the current residual. + """ + if not isinstance(update_vector, np.ndarray): + raise ValueError("update_vector must be a numpy array.") + update_vector_ = np.squeeze(update_vector) + if not update_vector_.ndim == 1: + raise ValueError("update_vector must be a 1-dimensional numpy array.") + n = update_vector_.size + self.residual_cpp.SubtractFromData(update_vector_, n)