From 8a9fbbb89acf28cc60588fd2893eb5480e9e7c35 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 19 Jun 2024 20:13:16 -0500 Subject: [PATCH] Updated multivariate treatment python demo to be an observational study --- ...tivariate_treatment_causal_inference.ipynb | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/demo/notebooks/multivariate_treatment_causal_inference.ipynb b/demo/notebooks/multivariate_treatment_causal_inference.ipynb index aae78581..6e175bd5 100644 --- a/demo/notebooks/multivariate_treatment_causal_inference.ipynb +++ b/demo/notebooks/multivariate_treatment_causal_inference.ipynb @@ -48,12 +48,12 @@ "n = 5000\n", "p_X = 5\n", "X = rng.uniform(0, 1, (n, p_X))\n", - "pi_X = 0.25 + 0.5*X[:,0]\n", + "pi_X = np.c_[0.25 + 0.5*X[:,0], 0.75 - 0.5*X[:,1]]\n", "# Z = rng.uniform(0, 1, (n, 2))\n", - "Z = rng.binomial(1, 0.5, (n, 2))\n", + "Z = rng.binomial(1, pi_X, (n, 2))\n", "\n", "# Define the outcome mean functions (prognostic and treatment effects)\n", - "mu_X = pi_X*5 + 2*X[:,2]\n", + "mu_X = pi_X[:,0]*5 + pi_X[:,1]*2 + 2*X[:,2]\n", "tau_X = np.stack((X[:,1], X[:,2]), axis=-1)\n", "\n", "# Generate outcome\n", @@ -129,6 +129,15 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -144,6 +153,21 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "treatment_idx = 1\n", + "forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:,:,treatment_idx])\n", + "tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)\n", + "tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test[:,treatment_idx],1), tau_avg_mcmc), axis = 1), columns=[\"True tau\", \"Average estimated tau\"])\n", + "sns.scatterplot(data=tau_df_mcmc, x=\"True tau\", y=\"Average estimated tau\")\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null,