diff --git a/demo/notebooks/causal_inference.ipynb b/demo/notebooks/causal_inference.ipynb index 40ac4725..afe0cc5e 100644 --- a/demo/notebooks/causal_inference.ipynb +++ b/demo/notebooks/causal_inference.ipynb @@ -106,13 +106,13 @@ "bcf_model = BCFModel()\n", "general_params = {\"keep_every\": 5}\n", "bcf_model.sample(\n", - " X_train,\n", - " Z_train,\n", - " y_train,\n", - " pi_train,\n", - " X_test,\n", - " Z_test,\n", - " pi_test,\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " pi_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " pi_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params=general_params,\n", @@ -236,7 +236,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/causal_inference_feature_subsets.ipynb b/demo/notebooks/causal_inference_feature_subsets.ipynb index 2a0283b2..0cebe960 100644 --- a/demo/notebooks/causal_inference_feature_subsets.ipynb +++ b/demo/notebooks/causal_inference_feature_subsets.ipynb @@ -110,13 +110,13 @@ "source": [ "bcf_model = BCFModel()\n", "bcf_model.sample(\n", - " X_train,\n", - " Z_train,\n", - " y_train,\n", - " pi_train,\n", - " X_test,\n", - " Z_test,\n", - " pi_test,\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " pi_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " pi_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " general_params={\"keep_every\": 5},\n", @@ -239,13 +239,13 @@ "bcf_model_subset = BCFModel()\n", "tau_params = {\"keep_vars\": [0, 1]}\n", "bcf_model_subset.sample(\n", - " X_train,\n", - " Z_train,\n", - " y_train,\n", - " pi_train,\n", - " X_test,\n", - " Z_test,\n", - " pi_test,\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " pi_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " pi_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", " treatment_effect_forest_params=tau_params,\n", @@ -369,7 +369,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/heteroskedastic_supervised_learning.ipynb b/demo/notebooks/heteroskedastic_supervised_learning.ipynb index 82cb9a23..26ec87e2 100644 --- a/demo/notebooks/heteroskedastic_supervised_learning.ipynb +++ b/demo/notebooks/heteroskedastic_supervised_learning.ipynb @@ -250,7 +250,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/multivariate_treatment_causal_inference.ipynb b/demo/notebooks/multivariate_treatment_causal_inference.ipynb index 5c3d10b9..37cd71d4 100644 --- a/demo/notebooks/multivariate_treatment_causal_inference.ipynb +++ b/demo/notebooks/multivariate_treatment_causal_inference.ipynb @@ -107,13 +107,13 @@ "source": [ "bcf_model = BCFModel()\n", "bcf_model.sample(\n", - " X_train,\n", - " Z_train,\n", - " y_train,\n", - " pi_train,\n", - " X_test,\n", - " Z_test,\n", - " pi_test,\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " pi_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " pi_test=pi_test,\n", " num_gfr=10,\n", " num_mcmc=100,\n", ")" @@ -260,7 +260,7 @@ ], "metadata": { "kernelspec": { - "display_name": "stochtree-dev", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index 90a0c564..ca385e1b 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -70,7 +70,9 @@ " ForestSampler,\n", " GlobalVarianceModel,\n", " LeafVarianceModel,\n", - " Residual,\n", + " Residual, \n", + " ForestModelConfig, \n", + " GlobalModelConfig,\n", ")" ] }, @@ -98,7 +100,6 @@ "X = rng.uniform(0, 1, (n, p_X))\n", "W = rng.uniform(0, 1, (n, p_W))\n", "\n", - "\n", "# Define the outcome mean function\n", "def outcome_mean(X, W):\n", " return np.where(\n", @@ -138,6 +139,7 @@ "alpha = 0.9\n", "beta = 1.25\n", "min_samples_leaf = 1\n", + "max_depth = -1\n", "num_trees = 100\n", "cutpoint_grid_size = 100\n", "global_variance_init = 1.0\n", @@ -149,7 +151,16 @@ "b_leaf = 0.5\n", "leaf_regression = True\n", "feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric\n", - "var_weights = np.repeat(1 / p_X, p_X)" + "var_weights = np.repeat(1 / p_X, p_X)\n", + "if not leaf_regression:\n", + " leaf_model = 0\n", + " leaf_dimension = 1\n", + "elif leaf_regression and p_W == 1:\n", + " leaf_model = 1\n", + " leaf_dimension = 1\n", + "else:\n", + " leaf_model = 2\n", + " leaf_dimension = p_W" ] }, { @@ -189,12 +200,41 @@ "source": [ "forest_container = ForestContainer(num_trees, W.shape[1], False, False)\n", "active_forest = Forest(num_trees, W.shape[1], False, False)\n", + "global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)\n", + "forest_model_config = ForestModelConfig(\n", + " num_trees=num_trees,\n", + " num_features=p_X,\n", + " num_observations=n,\n", + " feature_types=feature_types,\n", + " variable_weights=var_weights,\n", + " leaf_dimension=leaf_dimension,\n", + " alpha=alpha,\n", + " beta=beta,\n", + " min_samples_leaf=min_samples_leaf,\n", + " max_depth=max_depth,\n", + " leaf_model_type=leaf_model,\n", + " leaf_model_scale=leaf_prior_scale,\n", + " cutpoint_grid_size=cutpoint_grid_size,\n", + ")\n", "forest_sampler = ForestSampler(\n", - " dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf\n", + " dataset, global_model_config, forest_model_config\n", ")\n", "cpp_rng = RNG(random_seed)\n", "global_var_model = GlobalVarianceModel()\n", - "leaf_var_model = LeafVarianceModel()" + "leaf_var_model = LeafVarianceModel()\n", + "\n", + "# Initialize the leaves of each tree in the mean forest\n", + "if leaf_regression:\n", + " forest_init_val = np.repeat(0.0, W.shape[1])\n", + "else:\n", + " forest_init_val = np.array([0.0])\n", + "forest_sampler.prepare_for_sampler(\n", + " dataset,\n", + " residual,\n", + " active_forest,\n", + " leaf_model,\n", + " forest_init_val,\n", + ")" ] }, { @@ -239,17 +279,10 @@ " dataset,\n", " residual,\n", " cpp_rng,\n", - " feature_types,\n", - " cutpoint_grid_size,\n", - " leaf_prior_scale,\n", - " var_weights,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 1,\n", + " global_model_config, \n", + " forest_model_config,\n", " True,\n", " True,\n", - " False,\n", " )\n", " global_var_samples[i + 1] = global_var_model.sample_one_iteration(\n", " residual, cpp_rng, a_global, b_global\n", @@ -280,17 +313,10 @@ " dataset,\n", " residual,\n", " cpp_rng,\n", - " feature_types,\n", - " cutpoint_grid_size,\n", - " leaf_prior_scale,\n", - " var_weights,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 1,\n", + " global_model_config, \n", + " forest_model_config,\n", " True,\n", " False,\n", - " False,\n", " )\n", " global_var_samples[i + 1] = global_var_model.sample_one_iteration(\n", " residual, cpp_rng, a_global, b_global\n", @@ -474,6 +500,7 @@ "alpha_mu = 0.95\n", "beta_mu = 2.0\n", "min_samples_leaf_mu = 1\n", + "max_depth_mu = -1\n", "num_trees_mu = 200\n", "cutpoint_grid_size_mu = 100\n", "tau_init_mu = 1 / num_trees_mu\n", @@ -481,14 +508,17 @@ "a_leaf_mu = 3.0\n", "b_leaf_mu = 1 / num_trees_mu\n", "leaf_regression_mu = False\n", - "feature_types_mu = np.repeat(0, p_X).astype(int) # 0 = numeric\n", + "feature_types_mu = np.repeat(0, p_X + 1).astype(int) # 0 = numeric\n", "var_weights_mu = np.repeat(1 / (p_X + 1), p_X + 1)\n", + "leaf_model_mu = 0\n", + "leaf_dimension_mu = 1\n", "\n", "# Treatment forest parameters\n", "alpha_tau = 0.75\n", "beta_tau = 3.0\n", "min_samples_leaf_tau = 1\n", - "num_trees_tau = 50\n", + "max_depth_tau = -1\n", + "num_trees_tau = 100\n", "cutpoint_grid_size_tau = 100\n", "tau_init_tau = 1 / num_trees_tau\n", "leaf_prior_scale_tau = np.array([[tau_init_tau]], order=\"C\")\n", @@ -497,6 +527,8 @@ "leaf_regression_tau = True\n", "feature_types_tau = np.repeat(0, p_X).astype(int) # 0 = numeric\n", "var_weights_tau = np.repeat(1 / p_X, p_X)\n", + "leaf_model_tau = 1\n", + "leaf_dimension_tau = 1\n", "\n", "# Global parameters\n", "a_global = 2.0\n", @@ -543,17 +575,33 @@ "metadata": {}, "outputs": [], "source": [ + "# Global classes\n", + "global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)\n", + "cpp_rng = RNG(random_seed)\n", + "global_var_model = GlobalVarianceModel()\n", + "\n", "# Prognostic forest sampling classes\n", "forest_container_mu = ForestContainer(num_trees_mu, 1, True, False)\n", "active_forest_mu = Forest(num_trees_mu, 1, True, False)\n", + "forest_model_config_mu = ForestModelConfig(\n", + " num_trees=num_trees_mu,\n", + " num_features=p_X + 1,\n", + " num_observations=n,\n", + " feature_types=feature_types_mu,\n", + " variable_weights=var_weights_mu,\n", + " leaf_dimension=leaf_dimension_mu,\n", + " alpha=alpha_mu,\n", + " beta=beta_mu,\n", + " min_samples_leaf=min_samples_leaf_mu,\n", + " max_depth=max_depth_mu,\n", + " leaf_model_type=leaf_model_mu,\n", + " leaf_model_scale=leaf_prior_scale_mu,\n", + " cutpoint_grid_size=cutpoint_grid_size_mu,\n", + ")\n", "forest_sampler_mu = ForestSampler(\n", " dataset_mu,\n", - " feature_types_mu,\n", - " num_trees_mu,\n", - " n,\n", - " alpha_mu,\n", - " beta_mu,\n", - " min_samples_leaf_mu,\n", + " global_model_config, \n", + " forest_model_config_mu\n", ")\n", "leaf_var_model_mu = LeafVarianceModel()\n", "\n", @@ -564,20 +612,59 @@ "active_forest_tau = Forest(\n", " num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False, False\n", ")\n", + "forest_model_config_tau = ForestModelConfig(\n", + " num_trees=num_trees_tau,\n", + " num_features=p_X,\n", + " num_observations=n,\n", + " feature_types=feature_types_tau,\n", + " variable_weights=var_weights_tau,\n", + " leaf_dimension=leaf_dimension_tau,\n", + " alpha=alpha_tau,\n", + " beta=beta_tau,\n", + " min_samples_leaf=min_samples_leaf_tau,\n", + " max_depth=max_depth_tau,\n", + " leaf_model_type=leaf_model_tau,\n", + " leaf_model_scale=leaf_prior_scale_tau,\n", + " cutpoint_grid_size=cutpoint_grid_size_tau,\n", + ")\n", "forest_sampler_tau = ForestSampler(\n", " dataset_tau,\n", - " feature_types_tau,\n", - " num_trees_tau,\n", - " n,\n", - " alpha_tau,\n", - " beta_tau,\n", - " min_samples_leaf_tau,\n", + " global_model_config, \n", + " forest_model_config_tau\n", + ")\n", + "leaf_var_model_tau = LeafVarianceModel()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the leaves of the prognostic and treatment forests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "init_mu = np.array([np.squeeze(np.mean(resid))])\n", + "forest_sampler_mu.prepare_for_sampler(\n", + " dataset_mu,\n", + " residual,\n", + " active_forest_mu,\n", + " leaf_model_mu,\n", + " init_mu,\n", ")\n", - "leaf_var_model_tau = LeafVarianceModel()\n", "\n", - "# Global classes\n", - "cpp_rng = RNG(random_seed)\n", - "global_var_model = GlobalVarianceModel()" + "init_tau = np.array([0.0])\n", + "forest_sampler_tau.prepare_for_sampler(\n", + " dataset_tau,\n", + " residual,\n", + " active_forest_tau,\n", + " leaf_model_tau,\n", + " init_tau,\n", + ")" ] }, { @@ -596,22 +683,16 @@ "num_warmstart = 10\n", "num_mcmc = 100\n", "num_samples = num_warmstart + num_mcmc\n", - "global_var_samples = np.concatenate(\n", - " (np.array([global_variance_init]), np.repeat(0, num_samples))\n", - ")\n", - "leaf_scale_samples_mu = np.concatenate(\n", - " (np.array([tau_init_mu]), np.repeat(0, num_samples))\n", - ")\n", - "leaf_scale_samples_tau = np.concatenate(\n", - " (np.array([tau_init_tau]), np.repeat(0, num_samples))\n", - ")\n", + "global_var_samples = np.empty(num_samples)\n", + "leaf_scale_samples_mu = np.empty(num_samples)\n", + "leaf_scale_samples_tau = np.empty(num_samples)\n", "leaf_prior_scale_mu = np.array([[tau_init_mu]])\n", "leaf_prior_scale_tau = np.array([[tau_init_tau]])\n", - "b_0_init = -0.5\n", - "b_1_init = 0.5\n", - "b_0_samples = np.concatenate((np.array([b_0_init]), np.repeat(0, num_samples)))\n", - "b_1_samples = np.concatenate((np.array([b_1_init]), np.repeat(0, num_samples)))\n", - "tau_basis = (1 - Z) * b_0_init + Z * b_1_init\n", + "current_b0 = -0.5\n", + "current_b1 = 0.5\n", + "b_0_samples = np.empty(num_samples)\n", + "b_1_samples = np.empty(num_samples)\n", + "tau_basis = (1 - Z) * current_b0 + Z * current_b1\n", "dataset_tau.update_basis(tau_basis)" ] }, @@ -636,23 +717,24 @@ " dataset_mu,\n", " residual,\n", " cpp_rng,\n", - " feature_types_mu,\n", - " cutpoint_grid_size_mu,\n", - " leaf_prior_scale_mu,\n", - " var_weights_mu,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 0,\n", + " global_model_config, \n", + " forest_model_config_mu,\n", " True,\n", " True,\n", - " False,\n", " )\n", - " leaf_scale_samples_mu[i + 1] = leaf_var_model_mu.sample_one_iteration(\n", + " # Sample global variance\n", + " current_sigma2 = global_var_model.sample_one_iteration(\n", + " residual, cpp_rng, a_global, b_global\n", + " )\n", + " global_model_config.update_global_error_variance(current_sigma2)\n", + " # Sample prognostic forest leaf scale\n", + " leaf_prior_scale_mu[0, 0] = leaf_var_model_mu.sample_one_iteration(\n", " active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu\n", " )\n", - " leaf_prior_scale_mu[0, 0] = leaf_scale_samples_mu[i + 1]\n", - " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", + " leaf_scale_samples_mu[i] = leaf_prior_scale_mu[0, 0]\n", + " forest_model_config_mu.update_leaf_model_scale(\n", + " leaf_prior_scale_mu\n", + " )\n", "\n", " # Sample the treatment effect forest\n", " forest_sampler_tau.sample_one_iteration(\n", @@ -661,41 +743,49 @@ " dataset_tau,\n", " residual,\n", " cpp_rng,\n", - " feature_types_tau,\n", - " cutpoint_grid_size_tau,\n", - " leaf_prior_scale_tau,\n", - " var_weights_tau,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 1,\n", + " global_model_config, \n", + " forest_model_config_tau,\n", " True,\n", " True,\n", - " False,\n", " )\n", + " \n", + " # Sample adaptive coding parameters\n", + " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", " tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))\n", " s_tt0 = np.sum(tau_x * tau_x * (Z == 0))\n", " s_tt1 = np.sum(tau_x * tau_x * (Z == 1))\n", " partial_resid_mu = resid - np.squeeze(mu_x)\n", " s_ty0 = np.sum(tau_x * partial_resid_mu * (Z == 0))\n", " s_ty1 = np.sum(tau_x * partial_resid_mu * (Z == 1))\n", - " b_0_samples[i + 1] = rng.normal(\n", - " loc=(s_ty0 / (s_tt0 + 2 * global_var_samples[i])),\n", - " scale=np.sqrt(global_var_samples[i] / (s_tt0 + 2 * global_var_samples[i])),\n", + " current_b0 = rng.normal(\n", + " loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),\n", + " scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)),\n", " size=1,\n", - " )\n", - " b_1_samples[i + 1] = rng.normal(\n", - " loc=(s_ty1 / (s_tt1 + 2 * global_var_samples[i])),\n", - " scale=np.sqrt(global_var_samples[i] / (s_tt1 + 2 * global_var_samples[i])),\n", + " )[0]\n", + " current_b1 = rng.normal(\n", + " loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)),\n", + " scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)),\n", " size=1,\n", - " )\n", - " tau_basis = (1 - Z) * b_0_samples[i + 1] + Z * b_1_samples[i + 1]\n", + " )[0]\n", + " tau_basis = (1 - Z) * current_b0 + Z * current_b1\n", " dataset_tau.update_basis(tau_basis)\n", " forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)\n", + " b_0_samples[i] = current_b0\n", + " b_1_samples[i] = current_b1\n", "\n", " # Sample global variance\n", - " global_var_samples[i + 1] = global_var_model.sample_one_iteration(\n", + " current_sigma2 = global_var_model.sample_one_iteration(\n", " residual, cpp_rng, a_global, b_global\n", + " )\n", + " global_model_config.update_global_error_variance(current_sigma2)\n", + " global_var_samples[i] = current_sigma2\n", + " # Sample treatment forest leaf scale\n", + " leaf_prior_scale_tau[0, 0] = leaf_var_model_tau.sample_one_iteration(\n", + " active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau\n", + " )\n", + " leaf_scale_samples_tau[i] = leaf_prior_scale_tau[0, 0]\n", + " forest_model_config_tau.update_leaf_model_scale(\n", + " leaf_prior_scale_tau\n", " )" ] }, @@ -720,23 +810,24 @@ " dataset_mu,\n", " residual,\n", " cpp_rng,\n", - " feature_types_mu,\n", - " cutpoint_grid_size_mu,\n", - " leaf_prior_scale_mu,\n", - " var_weights_mu,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 0,\n", + " global_model_config, \n", + " forest_model_config_mu,\n", " True,\n", " False,\n", - " False,\n", " )\n", - " leaf_scale_samples_mu[i + 1] = leaf_var_model_mu.sample_one_iteration(\n", + " # Sample global variance\n", + " current_sigma2 = global_var_model.sample_one_iteration(\n", + " residual, cpp_rng, a_global, b_global\n", + " )\n", + " global_model_config.update_global_error_variance(current_sigma2)\n", + " # Sample prognostic forest leaf scale\n", + " leaf_prior_scale_mu[0, 0] = leaf_var_model_mu.sample_one_iteration(\n", " active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu\n", " )\n", - " leaf_prior_scale_mu[0, 0] = leaf_scale_samples_mu[i + 1]\n", - " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", + " leaf_scale_samples_mu[i] = leaf_prior_scale_mu[0, 0]\n", + " forest_model_config_mu.update_leaf_model_scale(\n", + " leaf_prior_scale_mu\n", + " )\n", "\n", " # Sample the treatment effect forest\n", " forest_sampler_tau.sample_one_iteration(\n", @@ -745,41 +836,49 @@ " dataset_tau,\n", " residual,\n", " cpp_rng,\n", - " feature_types_tau,\n", - " cutpoint_grid_size_tau,\n", - " leaf_prior_scale_tau,\n", - " var_weights_tau,\n", - " 0.0,\n", - " 0.0,\n", - " global_var_samples[i],\n", - " 1,\n", + " global_model_config, \n", + " forest_model_config_tau,\n", " True,\n", " False,\n", - " False,\n", " )\n", + " \n", + " # Sample adaptive coding parameters\n", + " mu_x = active_forest_mu.predict_raw(dataset_mu)\n", " tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))\n", " s_tt0 = np.sum(tau_x * tau_x * (Z == 0))\n", " s_tt1 = np.sum(tau_x * tau_x * (Z == 1))\n", " partial_resid_mu = resid - np.squeeze(mu_x)\n", " s_ty0 = np.sum(tau_x * partial_resid_mu * (Z == 0))\n", " s_ty1 = np.sum(tau_x * partial_resid_mu * (Z == 1))\n", - " b_0_samples[i + 1] = rng.normal(\n", - " loc=(s_ty0 / (s_tt0 + 2 * global_var_samples[i])),\n", - " scale=np.sqrt(global_var_samples[i] / (s_tt0 + 2 * global_var_samples[i])),\n", + " current_b0 = rng.normal(\n", + " loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),\n", + " scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)),\n", " size=1,\n", - " )\n", - " b_1_samples[i + 1] = rng.normal(\n", - " loc=(s_ty1 / (s_tt1 + 2 * global_var_samples[i])),\n", - " scale=np.sqrt(global_var_samples[i] / (s_tt1 + 2 * global_var_samples[i])),\n", + " )[0]\n", + " current_b1 = rng.normal(\n", + " loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)),\n", + " scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)),\n", " size=1,\n", - " )\n", - " tau_basis = (1 - Z) * b_0_samples[i + 1] + Z * b_1_samples[i + 1]\n", + " )[0]\n", + " tau_basis = (1 - Z) * current_b0 + Z * current_b1\n", " dataset_tau.update_basis(tau_basis)\n", " forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)\n", + " b_0_samples[i] = current_b0\n", + " b_1_samples[i] = current_b1\n", "\n", " # Sample global variance\n", - " global_var_samples[i + 1] = global_var_model.sample_one_iteration(\n", + " current_sigma2 = global_var_model.sample_one_iteration(\n", " residual, cpp_rng, a_global, b_global\n", + " )\n", + " global_model_config.update_global_error_variance(current_sigma2)\n", + " global_var_samples[i] = current_sigma2\n", + " # Sample treatment forest leaf scale\n", + " leaf_prior_scale_tau[0, 0] = leaf_var_model_tau.sample_one_iteration(\n", + " active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau\n", + " )\n", + " leaf_scale_samples_tau[i] = leaf_prior_scale_tau[0, 0]\n", + " forest_model_config_tau.update_leaf_model_scale(\n", + " leaf_prior_scale_tau\n", " )" ] }, @@ -800,7 +899,7 @@ "forest_preds_mu = forest_container_mu.predict(dataset_mu) * y_std + y_bar\n", "forest_preds_mu_gfr = forest_preds_mu[:, :num_warmstart]\n", "forest_preds_mu_mcmc = forest_preds_mu[:, num_warmstart:num_samples]\n", - "treatment_coding_samples = b_1_samples[1:] - b_0_samples[1:]\n", + "treatment_coding_samples = b_1_samples - b_0_samples\n", "forest_preds_tau = (\n", " forest_container_tau.predict_raw(dataset_tau)\n", " * y_std\n", @@ -815,10 +914,10 @@ "sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]\n", "\n", "# Adaptive coding parameters\n", - "b_1_samples_gfr = b_1_samples[1 : (num_warmstart + 1)] * y_std\n", - "b_0_samples_gfr = b_0_samples[1 : (num_warmstart + 1)] * y_std\n", - "b_1_samples_mcmc = b_1_samples[(num_warmstart + 1) :] * y_std\n", - "b_0_samples_mcmc = b_0_samples[(num_warmstart + 1) :] * y_std" + "b_1_samples_gfr = b_1_samples[:num_warmstart] * y_std\n", + "b_0_samples_gfr = b_0_samples[:num_warmstart] * y_std\n", + "b_1_samples_mcmc = b_1_samples[num_warmstart:] * y_std\n", + "b_0_samples_mcmc = b_0_samples[num_warmstart:] * y_std" ] }, { @@ -1016,7 +1115,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/serialization.ipynb b/demo/notebooks/serialization.ipynb index 6cc5b019..f9be5709 100644 --- a/demo/notebooks/serialization.ipynb +++ b/demo/notebooks/serialization.ipynb @@ -395,7 +395,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/supervised_learning.ipynb b/demo/notebooks/supervised_learning.ipynb index f51f78e1..e1067247 100644 --- a/demo/notebooks/supervised_learning.ipynb +++ b/demo/notebooks/supervised_learning.ipynb @@ -385,7 +385,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/demo/notebooks/tree_inspection.ipynb b/demo/notebooks/tree_inspection.ipynb index 38a9f4ac..a7ad8fe2 100644 --- a/demo/notebooks/tree_inspection.ipynb +++ b/demo/notebooks/tree_inspection.ipynb @@ -372,7 +372,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.12.9" } }, "nbformat": 4,