Skip to content

Commit

Permalink
Merge pull request #51 from StochasticTree/multivariate_example_fix
Browse files Browse the repository at this point in the history
Updated multivariate treatment python demo to be an observational study
  • Loading branch information
andrewherren committed Jun 20, 2024
2 parents 99b33b6 + 8a9fbbb commit 9aac95c
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions demo/notebooks/multivariate_treatment_causal_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@
"n = 5000\n",
"p_X = 5\n",
"X = rng.uniform(0, 1, (n, p_X))\n",
"pi_X = 0.25 + 0.5*X[:,0]\n",
"pi_X = np.c_[0.25 + 0.5*X[:,0], 0.75 - 0.5*X[:,1]]\n",
"# Z = rng.uniform(0, 1, (n, 2))\n",
"Z = rng.binomial(1, 0.5, (n, 2))\n",
"Z = rng.binomial(1, pi_X, (n, 2))\n",
"\n",
"# Define the outcome mean functions (prognostic and treatment effects)\n",
"mu_X = pi_X*5 + 2*X[:,2]\n",
"mu_X = pi_X[:,0]*5 + pi_X[:,1]*2 + 2*X[:,2]\n",
"tau_X = np.stack((X[:,1], X[:,2]), axis=-1)\n",
"\n",
"# Generate outcome\n",
Expand Down Expand Up @@ -129,6 +129,15 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -144,6 +153,21 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"treatment_idx = 1\n",
"forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:,:,treatment_idx])\n",
"tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)\n",
"tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test[:,treatment_idx],1), tau_avg_mcmc), axis = 1), columns=[\"True tau\", \"Average estimated tau\"])\n",
"sns.scatterplot(data=tau_df_mcmc, x=\"True tau\", y=\"Average estimated tau\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 9aac95c

Please sign in to comment.