diff --git a/demo/notebooks/serialization.ipynb b/demo/notebooks/serialization.ipynb new file mode 100644 index 00000000..a29018fe --- /dev/null +++ b/demo/notebooks/serialization.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Serialization Demo Notebook" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo 1: Supervised Learning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from stochtree import BARTModel\n", + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate sample data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# RNG\n", + "random_seed = 1234\n", + "rng = np.random.default_rng(random_seed)\n", + "\n", + "# Generate covariates and basis\n", + "n = 1000\n", + "p_X = 10\n", + "p_W = 1\n", + "X = rng.uniform(0, 1, (n, p_X))\n", + "W = rng.uniform(0, 1, (n, p_W))\n", + "\n", + "# Define the outcome mean function\n", + "def outcome_mean(X, W):\n", + " return np.where(\n", + " (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0], \n", + " np.where(\n", + " (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0], \n", + " np.where(\n", + " (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0], \n", + " 7.5 * W[:,0]\n", + " )\n", + " )\n", + " )\n", + "\n", + "# Generate outcome\n", + "epsilon = rng.normal(0, 1, n)\n", + "y = outcome_mean(X, W) + epsilon\n", + "\n", + "# Standardize outcome\n", + "y_bar = np.mean(y)\n", + "y_std = np.std(y)\n", + "resid = (y-y_bar)/y_std" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"Test-train split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample_inds = np.arange(n)\n", + "train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)\n", + "X_train = X[train_inds,:]\n", + "X_test = X[test_inds,:]\n", + "basis_train = W[train_inds,:]\n", + "basis_test = W[test_inds,:]\n", + "y_train = y[train_inds]\n", + "y_test = y[test_inds]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run BART" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bart_model = BARTModel()\n", + "bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspect the MCMC (BART) samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "forest_preds_y_mcmc = bart_model.y_hat_test\n", + "y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n", + "y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n", + "sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ 
+ "Compute the test set RMSE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc),2)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Serialize the BART model to JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bart_json_string = bart_model.to_json()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Deserialize BART model from JSON string" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bart_model_deserialized = BARTModel()\n", + "bart_model_deserialized.from_json(bart_json_string)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compare predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_hat_deserialized = bart_model_deserialized.predict(X_test, basis_test)\n", + "y_avg_mcmc_deserialized = np.squeeze(y_hat_deserialized).mean(axis = 1, keepdims = True)\n", + "y_df = pd.DataFrame(np.concatenate((y_avg_mcmc, y_avg_mcmc_deserialized), axis = 1), columns=[\"Original model\", \"Deserialized model\"])\n", + "sns.scatterplot(data=y_df, x=\"Original model\", y=\"Deserialized model\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compare parameter samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_df = pd.DataFrame(np.c_[bart_model.global_var_samples, bart_model_deserialized.global_var_samples], columns=[\"Original model\", \"Deserialized model\"])\n", + "sns.scatterplot(data=sigma2_df, x=\"Original model\", y=\"Deserialized model\")\n", + "plt.axline((0, 0), 
slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.17"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp
index 1240627c..0420e396 100644
--- a/src/py_stochtree.cpp
+++ b/src/py_stochtree.cpp
@@ -714,6 +714,11 @@ class JsonCpp {
     output_file << *json_ << std::endl;
   }
 
+  // Parse an in-memory JSON string and overwrite the wrapped json object with the result (input is only read)
+  void LoadFromString(const std::string& json_string) {
+    *json_ = nlohmann::json::parse(json_string);
+  }
+
   std::string DumpJson() {
     return json_->dump();
   }
@@ -973,6 +978,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
         .def(py::init<>())
         .def("LoadFile", &JsonCpp::LoadFile)
         .def("SaveFile", &JsonCpp::SaveFile)
+        .def("LoadFromString", &JsonCpp::LoadFromString)
         .def("DumpJson", &JsonCpp::DumpJson)
         .def("AddDouble", &JsonCpp::AddDouble)
         .def("AddDoubleSubfolder", &JsonCpp::AddDoubleSubfolder)
diff --git a/stochtree/bart.py b/stochtree/bart.py
index 5b79b9cd..44483e54 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -8,6 +8,7 @@
 from .forest import ForestContainer
 from .preprocessing import CovariateTransformer, _preprocess_bart_params
 from .sampler import ForestSampler, RNG, GlobalVarianceModel, LeafVarianceModel
+from .serialization import JSONSerializer
 from .utils import NotSampledError
 
 class BARTModel:
@@ -29,9 +30,9 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
 
         Parameters
         ----------
-        X_train : np.array
+        X_train : :obj:`np.array`
             Training set covariates on which trees may be partitioned.
-        y_train : np.array
+        y_train : :obj:`np.array`
             Training set outcome.
basis_train : :obj:`np.array`, optional
             Optional training set basis vector used to define a regression to be run in the leaves of each tree.
@@ -438,7 +439,7 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array:
 
         Parameters
         ----------
-        covariates : np.array
+        covariates : :obj:`np.array`
             Test set covariates.
         basis_train : :obj:`np.array`, optional
             Optional test set basis vector, must be provided if the model was trained with a leaf regression basis.
@@ -495,7 +496,7 @@ def predict_mean(self, covariates: np.array, basis: np.array = None) -> np.array
 
         Parameters
         ----------
-        covariates : np.array
+        covariates : :obj:`np.array`
             Test set covariates.
         basis_train : :obj:`np.array`, optional
             Optional test set basis vector, must be provided if the model was trained with a leaf regression basis.
@@ -594,3 +595,117 @@ def predict_variance(self, covariates: np.array, basis: np.array = None) -> np.a
         variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale)
 
         return variance_pred
+
+    def to_json(self) -> str:
+        """
+        Converts a sampled BART model to JSON string representation (which can then be saved to a file or
+        processed using the ``json`` library)
+
+        Returns
+        -------
+        :obj:`str`
+            JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
+
+        Raises
+        ------
+        NotSampledError
+            If the model has not been sampled yet.
+        """
+        # Check the same ``sampled`` flag that ``from_json`` sets, so reloaded models can be re-serialized
+        # NOTE(review): assumes ``__init__`` initializes ``self.sampled`` -- confirm
+        if not self.sampled:
+            msg = (
+                "This BARTModel instance has not yet been sampled. "
+                "Call 'sample' with appropriate arguments before using this model."
+            )
+            raise NotSampledError(msg)
+
+        # Initialize JSONSerializer object
+        bart_json = JSONSerializer()
+
+        # Add the forests
+        if self.include_mean_forest:
+            bart_json.add_forest(self.forest_container_mean)
+        if self.include_variance_forest:
+            bart_json.add_forest(self.forest_container_variance)
+
+        # Add global parameters
+        bart_json.add_scalar("variance_scale", self.variance_scale)
+        bart_json.add_scalar("outcome_scale", self.y_std)
+        bart_json.add_scalar("outcome_mean", self.y_bar)
+        bart_json.add_scalar("sigma2_init", self.sigma2_init)
+        bart_json.add_boolean("sample_sigma_global", self.sample_sigma_global)
+        bart_json.add_boolean("sample_sigma_leaf", self.sample_sigma_leaf)
+        bart_json.add_boolean("include_mean_forest", self.include_mean_forest)
+        bart_json.add_boolean("include_variance_forest", self.include_variance_forest)
+        bart_json.add_scalar("num_gfr", self.num_gfr)
+        bart_json.add_scalar("num_burnin", self.num_burnin)
+        bart_json.add_scalar("num_mcmc", self.num_mcmc)
+        bart_json.add_scalar("num_samples", self.num_samples)
+        bart_json.add_scalar("num_basis", self.num_basis)
+        bart_json.add_boolean("requires_basis", self.has_basis)
+        bart_json.add_numeric_vector("keep_indices", self.keep_indices)
+
+        # Add parameter samples
+        if self.sample_sigma_global:
+            bart_json.add_numeric_vector("sigma2_global_samples", self.global_var_samples, "parameters")
+        if self.sample_sigma_leaf:
+            bart_json.add_numeric_vector("sigma2_leaf_samples", self.leaf_scale_samples, "parameters")
+
+        return bart_json.return_json_string()
+
+    def from_json(self, json_string: str) -> None:
+        """
+        Converts a JSON string to an in-memory BART model.
+
+        Parameters
+        ----------
+        json_string : :obj:`str`
+            JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
+        """
+        # Parse string to a JSON object in C++
+        bart_json = JSONSerializer()
+        bart_json.load_from_json_string(json_string)
+
+        # Unpack forests
+        # Forests are looked up by label; the mean forest is written before the variance forest in ``to_json``,
+        # so a variance forest is "forest_1" when a mean forest exists and "forest_0" otherwise
+        self.include_mean_forest = bart_json.get_boolean("include_mean_forest")
+        self.include_variance_forest = bart_json.get_boolean("include_variance_forest")
+        if self.include_mean_forest:
+            # TODO: don't just make this a placeholder that we overwrite
+            self.forest_container_mean = ForestContainer(0, 0, False, False)
+            self.forest_container_mean.forest_container_cpp.LoadFromJson(bart_json.json_cpp, "forest_0")
+            if self.include_variance_forest:
+                # TODO: don't just make this a placeholder that we overwrite
+                self.forest_container_variance = ForestContainer(0, 0, False, False)
+                self.forest_container_variance.forest_container_cpp.LoadFromJson(bart_json.json_cpp, "forest_1")
+        else:
+            # TODO: don't just make this a placeholder that we overwrite
+            self.forest_container_variance = ForestContainer(0, 0, False, False)
+            self.forest_container_variance.forest_container_cpp.LoadFromJson(bart_json.json_cpp, "forest_0")
+
+        # Unpack global parameters
+        self.variance_scale = bart_json.get_scalar("variance_scale")
+        self.y_std = bart_json.get_scalar("outcome_scale")
+        self.y_bar = bart_json.get_scalar("outcome_mean")
+        self.sigma2_init = bart_json.get_scalar("sigma2_init")
+        self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global")
+        self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf")
+        # NOTE(review): counts are cast to int on the assumption that get_scalar returns a float -- confirm
+        self.num_gfr = int(bart_json.get_scalar("num_gfr"))
+        self.num_burnin = int(bart_json.get_scalar("num_burnin"))
+        self.num_mcmc = int(bart_json.get_scalar("num_mcmc"))
+        self.num_samples = int(bart_json.get_scalar("num_samples"))
+        self.num_basis = int(bart_json.get_scalar("num_basis"))
+        self.has_basis = bart_json.get_boolean("requires_basis")
+        self.keep_indices = bart_json.get_numeric_vector("keep_indices").astype(int)
+
+        # Unpack parameter samples
+        if self.sample_sigma_global:
+            self.global_var_samples = bart_json.get_numeric_vector("sigma2_global_samples", "parameters")
+        if self.sample_sigma_leaf:
+            self.leaf_scale_samples = bart_json.get_numeric_vector("sigma2_leaf_samples", "parameters")
+
+        # Mark the deserialized model as "sampled"
+        self.sampled = True
diff --git a/stochtree/serialization.py b/stochtree/serialization.py
index a9f9e0d1..22ee1176 100644
--- a/stochtree/serialization.py
+++ b/stochtree/serialization.py
@@ -15,6 +15,28 @@ def __init__(self) -> None:
         self.num_forests = 0
         self.forest_labels = []
 
+    def return_json_string(self) -> str:
+        """
+        Convert JSON object to in-memory string
+
+        Returns
+        -------
+        :obj:`str`
+            JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
+        """
+        return self.json_cpp.DumpJson()
+
+    def load_from_json_string(self, json_string: str) -> None:
+        """
+        Parse in-memory JSON string to ``JsonCpp`` object
+
+        Parameters
+        ----------
+        json_string : :obj:`str`
+            JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
+        """
+        self.json_cpp.LoadFromString(json_string)
+
     def add_forest(self, forest_samples: ForestContainer) -> None:
         """Adds a container of forest samples to a json object
 
diff --git a/test/python/test_json.py b/test/python/test_json.py
index 71cdc563..262e4ffa 100644
--- a/test/python/test_json.py
+++ b/test/python/test_json.py
@@ -162,3 +162,46 @@ def outcome_mean(X, W):
         forest_container_reloaded.load_from_json_string(forest_json_string)
         y_hat_reloaded = forest_container_reloaded.predict(dataset)
         np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)
+
+    def test_bart_string(self):
+        # RNG
+        random_seed = 1234
+        rng = np.random.default_rng(random_seed)
+
+        # Generate covariates and basis
+        n = 1000
+        p_X = 10
+        p_W = 1
+        X = rng.uniform(0, 1, (n, p_X))
+        W = rng.uniform(0, 1, (n, p_W))
+
+        # Define the outcome mean function
+        def outcome_mean(X, W):
+            return np.where(
+                (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0],
+                np.where(
+                    (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0],
+                    np.where(
+                        (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0],
+                        7.5 * W[:,0]
+                    )
+                )
+            )
+
+        # Generate outcome
+        epsilon = rng.normal(0, 1, n)
+        y = outcome_mean(X, W) + epsilon
+
+        # Run BART
+        bart_orig = BARTModel()
+        bart_orig.sample(X_train=X, y_train=y, basis_train=W, num_gfr=10, num_mcmc=10)
+
+        # Extract predictions from the sampler
+        y_hat_orig = bart_orig.predict(X, W)
+
+        # "Round-trip" the model to JSON string and back and check that the predictions agree
+        bart_json_string = bart_orig.to_json()
+        bart_reloaded = BARTModel()
+        bart_reloaded.from_json(bart_json_string)
+        y_hat_reloaded = bart_reloaded.predict(X, W)
+        np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)