Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 268 additions & 0 deletions demo/notebooks/serialization.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Serialization Demo Notebook"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Demo 1: Supervised Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from stochtree import BARTModel\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate sample data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# RNG\n",
"random_seed = 1234\n",
"rng = np.random.default_rng(random_seed)\n",
"\n",
"# Generate covariates and basis\n",
"n = 1000\n",
"p_X = 10\n",
"p_W = 1\n",
"X = rng.uniform(0, 1, (n, p_X))\n",
"W = rng.uniform(0, 1, (n, p_W))\n",
"\n",
"# Define the outcome mean function\n",
"def outcome_mean(X, W):\n",
" return np.where(\n",
" (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * W[:,0], \n",
" np.where(\n",
" (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * W[:,0], \n",
" np.where(\n",
" (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * W[:,0], \n",
" 7.5 * W[:,0]\n",
" )\n",
" )\n",
" )\n",
"\n",
"# Generate outcome\n",
"epsilon = rng.normal(0, 1, n)\n",
"y = outcome_mean(X, W) + epsilon\n",
"\n",
"# Standardize outcome\n",
"y_bar = np.mean(y)\n",
"y_std = np.std(y)\n",
"resid = (y-y_bar)/y_std"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test-train split"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample_inds = np.arange(n)\n",
"train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)\n",
"X_train = X[train_inds,:]\n",
"X_test = X[test_inds,:]\n",
"basis_train = W[train_inds,:]\n",
"basis_test = W[test_inds,:]\n",
"y_train = y[train_inds]\n",
"y_test = y[test_inds]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run BART"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bart_model = BARTModel()\n",
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect the MCMC (BART) samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forest_preds_y_mcmc = bart_model.y_hat_test\n",
"y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)\n",
"y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=[\"True outcome\", \"Average estimated outcome\"])\n",
"sns.scatterplot(data=y_df_mcmc, x=\"Average estimated outcome\", y=\"True outcome\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples - bart_model.num_gfr),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=[\"Sample\", \"Sigma\"])\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the test set RMSE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc),2)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Serialize the BART model to JSON"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bart_json_string = bart_model.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deserialize BART model from JSON string"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bart_model_deserialized = BARTModel()\n",
"bart_model_deserialized.from_json(bart_json_string)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compare predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_hat_deserialized = bart_model_deserialized.predict(X_test, basis_test)\n",
"y_avg_mcmc_deserialized = np.squeeze(y_hat_deserialized).mean(axis = 1, keepdims = True)\n",
"y_df = pd.DataFrame(np.concatenate((y_avg_mcmc, y_avg_mcmc_deserialized), axis = 1), columns=[\"Original model\", \"Deserialized model\"])\n",
"sns.scatterplot(data=y_df, x=\"Original model\", y=\"Deserialized model\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compare parameter samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sigma2_df = pd.DataFrame(np.c_[bart_model.global_var_samples, bart_model_deserialized.global_var_samples], columns=[\"Original model\", \"Deserialized model\"])\n",
"sns.scatterplot(data=sigma2_df, x=\"Original model\", y=\"Deserialized model\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
5 changes: 5 additions & 0 deletions src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,10 @@ class JsonCpp {
output_file << *json_ << std::endl;
}

/*!
 * \brief Replace the wrapped JSON document with the result of parsing a string.
 *
 * The input is only read, so it is taken by const reference; this also lets
 * callers pass const strings and temporaries.
 *
 * \param json_string Text to parse as JSON.
 *
 * NOTE(review): with nlohmann's default settings, parse() throws
 * nlohmann::json::parse_error on malformed input; nothing is caught here,
 * so the exception propagates to the caller.
 */
void LoadFromString(const std::string& json_string) {
  *json_ = nlohmann::json::parse(json_string);
}

/*! \brief Serialize the wrapped JSON document and return it as a string (default dump() formatting). */
std::string DumpJson() {
  std::string serialized = json_->dump();
  return serialized;
}
Expand Down Expand Up @@ -973,6 +977,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def(py::init<>())
.def("LoadFile", &JsonCpp::LoadFile)
.def("SaveFile", &JsonCpp::SaveFile)
.def("LoadFromString", &JsonCpp::LoadFromString)
.def("DumpJson", &JsonCpp::DumpJson)
.def("AddDouble", &JsonCpp::AddDouble)
.def("AddDoubleSubfolder", &JsonCpp::AddDoubleSubfolder)
Expand Down
Loading