From b2e940e24b4ee2b34ca0a8dfa2845d62f3585f74 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 11 Nov 2024 12:27:36 -0600 Subject: [PATCH] Added demo for json file roundtrip --- demo/notebooks/serialization.ipynb | 79 ++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/demo/notebooks/serialization.ipynb b/demo/notebooks/serialization.ipynb index a29018fe..b5672b27 100644 --- a/demo/notebooks/serialization.ipynb +++ b/demo/notebooks/serialization.ipynb @@ -27,6 +27,7 @@ "metadata": {}, "outputs": [], "source": [ + "import json\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", @@ -242,6 +243,84 @@ "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", "plt.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save to JSON file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('bart.json', 'w') as f:\n", + " bart_json_python = json.loads(bart_json_string)\n", + " json.dump(bart_json_python, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reload from JSON file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('bart.json', 'r') as f:\n", + " bart_json_python_reload = json.load(f)\n", + "bart_json_string_reload = json.dumps(bart_json_python_reload)\n", + "bart_model_file_deserialized = BARTModel()\n", + "bart_model_file_deserialized.from_json(bart_json_string_reload)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compare predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_hat_file_deserialized = bart_model_file_deserialized.predict(X_test, basis_test)\n", + "y_avg_mcmc_file_deserialized = np.squeeze(y_hat_file_deserialized).mean(axis = 1, keepdims = True)\n", + "y_df = pd.DataFrame(np.concatenate((y_avg_mcmc, y_avg_mcmc_file_deserialized), axis = 1), columns=[\"Original model\", \"Deserialized model\"])\n", + "sns.scatterplot(data=y_df, x=\"Original model\", y=\"Deserialized model\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compare parameter samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_df = pd.DataFrame(np.c_[bart_model.global_var_samples, bart_model_file_deserialized.global_var_samples], columns=[\"Original model\", \"Deserialized model\"])\n", + "sns.scatterplot(data=sigma2_df, x=\"Original model\", y=\"Deserialized model\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] } ], "metadata": {