From eb619875a7ef44500159d9637510607479d645ba Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 29 Aug 2024 22:48:30 -0500 Subject: [PATCH 1/4] Refactoring outcome error variance sampler to default to Jeffreys' prior --- R/bart.R | 27 ++++++++++-------------- R/bcf.R | 19 ++++++----------- R/variance.R | 10 ++++----- debug/api_debug.cpp | 8 +++---- demo/notebooks/prototype_interface.ipynb | 16 +++++++------- include/stochtree/variance_model.h | 16 ++++++-------- src/py_stochtree.cpp | 4 ++-- src/sampler.cpp | 4 ++-- stochtree/sampler.py | 8 +++---- 9 files changed, 50 insertions(+), 62 deletions(-) diff --git a/R/bart.R b/R/bart.R index 08f79549..eddea63c 100644 --- a/R/bart.R +++ b/R/bart.R @@ -36,7 +36,8 @@ #' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. #' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9. -#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. +#' @param sigma2_init Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set. +#' @param pct_var_sigma2_init Percentage of standardized outcome variance used to initialize global error variance parameter. Default: 0.25. Superseded by `sigma2_init`. #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. #' @param num_trees Number of trees in the ensemble. Default: 200. #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. 
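The `sigma2_init` and `pct_var_sigma2_init` defaults documented above are easiest to reason about on the standardized scale that `bart()` uses internally. A minimal sketch of the implied initialization (illustration only, not part of the diff; it simply restates the documented formula in runnable R with a made-up outcome `y`):

```r
# bart() standardizes the outcome before sampling, so the sample variance of the
# standardized residual is exactly 1 and the default initial error variance
# reduces to pct_var_sigma2_init itself on that scale.
y <- rnorm(100, mean = 5, sd = 3)
resid_train <- (y - mean(y)) / sd(y)     # internal standardization
pct_var_sigma2_init <- 0.25              # new default
sigma2_init <- pct_var_sigma2_init * var(resid_train)
sigma2_init                              # 0.25
sigma2_init * var(y)                     # equivalent value on the original outcome scale
```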
@@ -81,12 +82,12 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, - nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, - q = 0.9, sigma2_init = NULL, variable_weights = NULL, - num_trees = 200, num_gfr = 5, num_burnin = 0, - num_mcmc = 100, sample_sigma = T, sample_tau = T, - random_seed = -1, keep_burnin = F, keep_gfr = F, - verbose = F){ + a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, pct_var_sigma2_init = 0.25, + variable_weights = NULL, num_trees = 200, num_gfr = 5, + num_burnin = 0, num_mcmc = 100, sample_sigma = T, + sample_tau = T, random_seed = -1, keep_burnin = F, + keep_gfr = F, verbose = F) { # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { variable_weights = rep(1/ncol(X_train), ncol(X_train)) @@ -216,13 +217,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, resid_train <- (y_train-y_bar_train)/y_std_train # Calibrate priors for sigma^2 and tau - reg_basis <- cbind(W_train, X_train) - sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 - quantile_cutoff <- 0.9 - if (is.null(lambda)) { - lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu - } - if (is.null(sigma2_init)) sigma2_init <- sigma2hat + if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train) if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) current_leaf_scale <- as.matrix(tau_init) @@ -331,7 +326,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = F ) if (sample_sigma) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_tau) { @@ -373,7 +368,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = F ) if (sample_sigma) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_tau) { diff --git a/R/bcf.R b/R/bcf.R index 8e2287e2..b1760bad 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -39,7 +39,8 @@ #' @param b_leaf_mu Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Calibrated internally as 0.5/num_trees if not set here. #' @param b_leaf_tau Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Calibrated internally as 0.5/num_trees if not set here. #' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9. -#' @param sigma2 Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. +#' @param sigma2 Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set. +#' @param pct_var_sigma2_init Percentage of standardized outcome variance used to initialize global error variance parameter. Default: 0.25. 
Superseded by `sigma2`. #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train` and then set `propensity_covariate` to `'none'` adjust `keep_vars_mu` and `keep_vars_tau` accordingly. #' @param keep_vars_mu Vector of variable names or column indices denoting variables that should be included in the prognostic (`mu(X)`) forest. Default: NULL. #' @param drop_vars_mu Vector of variable names or column indices denoting variables that should be excluded from the prognostic (`mu(X)`) forest. Default: NULL. If both `drop_vars_mu` and `keep_vars_mu` are set, `drop_vars_mu` will be ignored. @@ -119,10 +120,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU sigma_leaf_mu = NULL, sigma_leaf_tau = NULL, alpha_mu = 0.95, alpha_tau = 0.25, beta_mu = 2.0, beta_tau = 3.0, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, max_depth_mu = 10, max_depth_tau = 5, nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, - b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, variable_weights = NULL, - keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, drop_vars_tau = NULL, - num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, num_mcmc = 100, - sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, + b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, pct_var_sigma2_init = 0.25, + variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, + drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, + num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) { # Variable weight preprocessing (and initialization if necessary) @@ -413,13 +414,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU resid_train <- (y_train-y_bar_train)/y_std_train # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau - reg_basis <- X_train - sigma2hat <- mean(resid(lm(resid_train~reg_basis))^2) - quantile_cutoff <- 0.9 - if (is.null(lambda)) { - lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu - } - if (is.null(sigma2)) sigma2 <- sigma2hat + if (is.null(sigma2)) sigma2 <- pct_var_sigma2_init*var(resid_train) if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu) diff --git a/R/variance.R b/R/variance.R index 8c6b0814..c8933b3d 100644 --- a/R/variance.R +++ b/R/variance.R @@ -1,13 +1,13 @@ -#' Sample one iteration of the global variance model +#' Sample one iteration of the (inverse gamma) global variance model #' #' @param residual Outcome class #' @param rng C++ random number generator -#' @param nu Global variance shape parameter -#' @param lambda Constitutes the scale parameter for the global variance along with nu (i.e. 
scale is nu*lambda) +#' @param a Global variance shape parameter +#' @param b Global variance scale parameter #' #' @export -sample_sigma2_one_iteration <- function(residual, rng, nu, lambda) { - return(sample_sigma2_one_iteration_cpp(residual$data_ptr, rng$rng_ptr, nu, lambda)) +sample_sigma2_one_iteration <- function(residual, rng, a, b) { + return(sample_sigma2_one_iteration_cpp(residual$data_ptr, rng$rng_ptr, a, b)) } #' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!) diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index d827d8cb..1f276fd6 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -418,8 +418,8 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int double b_rfx = 1.; double a_leaf = 2.; double b_leaf = 0.5; - double nu = 4.; - double lamb = 0.5; + double a_global = 4.; + double b_global = 2.; // Set leaf model parameters double leaf_scale_init = 1.; @@ -473,7 +473,7 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen)); // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu * lamb, gen)); + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); } } @@ -503,7 +503,7 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen)); // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu * lamb, gen)); + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen)); } } diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index 79e0586e..854cd484 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -157,8 +157,8 @@ "global_variance_init = 1.\n", "tau_init = 0.5\n", "leaf_prior_scale = np.array([[tau_init]], order='C')\n", - "nu = 4.\n", - "lamb = 0.5\n", + "a_global = 4.\n", + "b_global = 2.\n", "a_leaf = 2.\n", "b_leaf = 0.5\n", "leaf_regression = True\n", @@ -243,7 +243,7 @@ "source": [ "for i in range(num_warmstart):\n", " forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, True, False)\n", - " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)\n", + " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)\n", " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)\n", " leaf_prior_scale[0,0] = leaf_scale_samples[i+1]" ] @@ -263,7 +263,7 @@ "source": [ "for i in range(num_warmstart, num_samples):\n", " forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, global_var_samples[i], 1, False, False)\n", - " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)\n", + " global_var_samples[i+1] = 
global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)\n", " leaf_scale_samples[i+1] = leaf_var_model.sample_one_iteration(forest_container, cpp_rng, a_leaf, b_leaf, i)\n", " leaf_prior_scale[0,0] = leaf_scale_samples[i+1]" ] @@ -442,8 +442,8 @@ "var_weights_tau = np.repeat(1/p_X, p_X)\n", "\n", "# Global parameters\n", - "nu = 2.\n", - "lamb = 0.5\n", + "a_global = 2.\n", + "b_global = 1.\n", "global_variance_init = 1." ] }, @@ -566,7 +566,7 @@ " dataset_tau.update_basis(tau_basis)\n", " \n", " # Sample global variance\n", - " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)" + " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)" ] }, { @@ -605,7 +605,7 @@ " dataset_tau.update_basis(tau_basis)\n", " \n", " # Sample global variance\n", - " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, nu, lamb)" + " global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)" ] }, { diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index 53140dd5..003d164e 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -24,29 +24,27 @@ class GlobalHomoskedasticVarianceModel { public: GlobalHomoskedasticVarianceModel() {ig_sampler_ = InverseGammaSampler();} ~GlobalHomoskedasticVarianceModel() {} - double PosteriorShape(Eigen::VectorXd& residuals, double nu, double lambda) { + double PosteriorShape(Eigen::VectorXd& residuals, double a, double b) { data_size_t n = residuals.rows(); - return (nu/2.0) + n; + return (a/2.0) + n; } - double PosteriorScale(Eigen::VectorXd& residuals, double nu, double lambda) { + double PosteriorScale(Eigen::VectorXd& residuals, double a, double b) { data_size_t n = residuals.rows(); - double nu_lambda = nu*lambda; double sum_sq_resid = 0.; for (data_size_t i = 0; i < n; i++) { sum_sq_resid += std::pow(residuals(i, 0), 2); } - return (nu_lambda/2.0) + sum_sq_resid; + return (b/2.0) + sum_sq_resid; } - double SampleVarianceParameter(Eigen::VectorXd& residuals, double nu, double lambda, std::mt19937& gen) { - double ig_shape = PosteriorShape(residuals, nu, lambda); - double ig_scale = PosteriorScale(residuals, nu, lambda); + double SampleVarianceParameter(Eigen::VectorXd& residuals, double a, double b, std::mt19937& gen) { + double ig_shape = PosteriorShape(residuals, a, b); + double ig_scale = PosteriorScale(residuals, a, b); return ig_sampler_.Sample(ig_shape, ig_scale, gen); } private: InverseGammaSampler ig_sampler_; }; - /*! 
\brief Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model */ class LeafNodeHomoskedasticVarianceModel { public: diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 3c8ca606..88739990 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -551,10 +551,10 @@ class GlobalVarianceModelCpp { } ~GlobalVarianceModelCpp() {} - double SampleOneIteration(ResidualCpp& residual, RngCpp& rng, double nu, double lamb) { + double SampleOneIteration(ResidualCpp& residual, RngCpp& rng, double a, double b) { StochTree::ColumnVector* residual_ptr = residual.GetData(); std::mt19937* rng_ptr = rng.GetRng(); - return var_model_.SampleVarianceParameter(residual_ptr->GetData(), nu, lamb, *rng_ptr); + return var_model_.SampleVarianceParameter(residual_ptr->GetData(), a, b, *rng_ptr); } private: diff --git a/src/sampler.cpp b/src/sampler.cpp index 0edf6a7a..482fb781 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -141,11 +141,11 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer residual, cpp11::external_pointer rng, - double nu, double lambda + double a, double b ) { // Run one iteration of the sampler StochTree::GlobalHomoskedasticVarianceModel var_model = StochTree::GlobalHomoskedasticVarianceModel(); - return var_model.SampleVarianceParameter(residual->GetData(), nu, lambda, *rng); + return var_model.SampleVarianceParameter(residual->GetData(), a, b, *rng); } [[cpp11::register]] diff --git a/stochtree/sampler.py b/stochtree/sampler.py index 21977619..d7c1a01d 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -53,11 +53,11 @@ def __init__(self) -> None: # Initialize a GlobalVarianceModelCpp object self.variance_model_cpp = GlobalVarianceModelCpp() - def sample_one_iteration(self, residual: Residual, rng: RNG, nu: float, lamb: float) -> float: + def sample_one_iteration(self, residual: Residual, rng: RNG, a: float, b: float) -> float: """ - Sample one iteration of a forest using the specified model and tree sampling algorithm + Sample one iteration of a global error variance parameter """ - return self.variance_model_cpp.SampleOneIteration(residual.residual_cpp, rng.rng_cpp, nu, lamb) + return self.variance_model_cpp.SampleOneIteration(residual.residual_cpp, rng.rng_cpp, a, b) class LeafVarianceModel: @@ -67,6 +67,6 @@ def __init__(self) -> None: def sample_one_iteration(self, forest_container: ForestContainer, rng: RNG, a: float, b: float, sample_num: int) -> float: """ - Sample one iteration of a forest using the specified model and tree sampling algorithm + Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly ``N(0, tau)``) """ return self.variance_model_cpp.SampleOneIteration(forest_container.forest_container_cpp, rng.rng_cpp, a, b, sample_num) From 1e7b2a2fc55a50a8dacfd12092d8d61ed7dd1576 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 29 Aug 2024 22:51:19 -0500 Subject: [PATCH 2/4] Cleaned up commented code --- R/bcf.R | 6 ------ 1 file changed, 6 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index b1760bad..bfee3d7b 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -501,16 +501,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Initialize the leaves of each tree in the prognostic forest forest_samples_mu$set_root_leaves(0, mean(resid_train) / num_trees_mu) forest_samples_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, F, 0, F) - # 
adjust_residual_forest_container_cpp(forest_dataset_train$data_ptr, outcome_train$data_ptr, - # forest_samples_mu$forest_container_ptr, forest_model_mu$tracker_ptr, - # F, 0, F) # Initialize the leaves of each tree in the treatment effect forest forest_samples_tau$set_root_leaves(0, 0.) forest_samples_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, T, 0, F) - # adjust_residual_forest_container_cpp(forest_dataset_train$data_ptr, outcome_train$data_ptr, - # forest_samples_tau$forest_container_ptr, forest_model_tau$tracker_ptr, - # T, 0, F) # Run GFR (warm start) if specified if (num_gfr > 0){ From 376eb849cb5802bc165e5a03946795310c9480b7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 30 Aug 2024 02:10:17 -0500 Subject: [PATCH 3/4] Updated interface and added vignette to demonstrate calibrating lambda --- NAMESPACE | 1 + R/bart.R | 14 +-- R/bcf.R | 20 ++-- R/calibration.R | 34 ++++++ R/cpp11.R | 4 +- _pkgdown.yml | 12 +- include/stochtree/variance_model.h | 8 +- man/bart.Rd | 13 ++- man/bcf.Rd | 13 ++- man/calibrate_inverse_gamma_error_variance.Rd | 44 +++++++ man/sample_sigma2_one_iteration.Rd | 10 +- src/cpp11.cpp | 6 +- vignettes/PriorCalibration.Rmd | 110 ++++++++++++++++++ 13 files changed, 245 insertions(+), 44 deletions(-) create mode 100644 R/calibration.R create mode 100644 man/calibrate_inverse_gamma_error_variance.Rd create mode 100644 vignettes/PriorCalibration.Rmd diff --git a/NAMESPACE b/NAMESPACE index ab87b7b9..aa8ee365 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ S3method(predict,bartmodel) S3method(predict,bcf) export(bart) export(bcf) +export(calibrate_inverse_gamma_error_variance) export(computeForestKernels) export(computeForestLeafIndices) export(convertBCFModelToJson) diff --git a/R/bart.R b/R/bart.R index eddea63c..1b498115 100644 --- a/R/bart.R +++ b/R/bart.R @@ -31,8 +31,8 @@ #' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. #' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. #' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with ``-1`` which does not enforce any depth limits on trees. -#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. -#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). +#' @param a_global Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: 0. +#' @param b_global Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: 0. #' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. #' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9. @@ -43,7 +43,7 @@ #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. 
-#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T.
+#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: T.
 #' @param sample_tau Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: T.
 #' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
 #' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
@@ -437,11 +437,11 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     # Return results as a list
     model_params <- list(
         "sigma2_init" = sigma2_init,
-        "nu" = nu,
-        "lambda" = lambda,
+        "a_global" = a_global,
+        "b_global" = b_global,
         "tau_init" = tau_init,
-        "a" = a_leaf,
-        "b" = b_leaf,
+        "a_leaf" = a_leaf,
+        "b_leaf" = b_leaf,
         "outcome_mean" = y_bar_train,
         "outcome_scale" = y_std_train,
         "output_dimension" = output_dimension,
diff --git a/R/bcf.R b/R/bcf.R
index bfee3d7b..f99c3df2 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -32,8 +32,8 @@
 #' @param min_samples_leaf_tau Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: 5.
 #' @param max_depth_mu Maximum depth of any tree in the mu ensemble. Default: 10. Can be overriden with ``-1`` which does not enforce any depth limits on trees.
 #' @param max_depth_tau Maximum depth of any tree in the tau ensemble. Default: 5. Can be overriden with ``-1`` which does not enforce any depth limits on trees.
-#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
-#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+#' @param a_global Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
+#' @param b_global Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
 #' @param a_leaf_mu Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Default: 3.
 #' @param a_leaf_tau Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Default: 3.
 #' @param b_leaf_mu Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Calibrated internally as 0.5/num_trees if not set here.
 #' @param b_leaf_tau Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Calibrated internally as 0.5/num_trees if not set here.
@@ -51,7 +51,7 @@
 #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
 #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
 #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
-#' @param sample_sigma_global Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T.
+#' @param sample_sigma_global Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: T.
 #' @param sample_sigma_leaf_mu Whether or not to update the `sigma_leaf_mu` leaf scale variance parameter in the prognostic forest based on `IG(a_leaf_mu, b_leaf_mu)`. Default: T.
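With both hyperparameters defaulting to zero, the default prior on `sigma^2` is the Jeffreys prior that the first patch in this series introduces. Reading the shape/scale convention off the updated `PosteriorShape`/`PosteriorScale` methods later in this patch, passing `a_global` and `b_global` amounts to (a sketch of the algebra behind the defaults, not new behavior):

\begin{equation*}
\sigma^2 \sim \text{IG}\!\left(\tfrac{a_{\mathrm{global}}}{2}, \tfrac{b_{\mathrm{global}}}{2}\right), \qquad
\sigma^2 \mid r_1, \ldots, r_n \sim \text{IG}\!\left(\tfrac{a_{\mathrm{global}}}{2} + \tfrac{n}{2}, \; \tfrac{b_{\mathrm{global}}}{2} + \tfrac{1}{2}\sum_{i=1}^{n} r_i^2\right),
\end{equation*}

where $r_1, \ldots, r_n$ are the current residuals. Setting `a_global = 0` and `b_global = 0` collapses the prior to $p(\sigma^2) \propto 1/\sigma^2$ while still yielding a proper conditional posterior, and setting `a_global = nu`, `b_global = nu*lambda` recovers the `sigma^2 ~ IG(nu/2, nu*lambda/2)` prior from the earlier `nu`/`lambda` parameterization, with `lambda` calibrated via the helper added later in this patch.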
#' @param sample_sigma_leaf_tau Whether or not to update the `sigma_leaf_tau` leaf scale variance parameter in the treatment effect forest based on `IG(a_leaf_tau, b_leaf_tau)`. Default: T. #' @param propensity_covariate Whether to include the propensity score as a covariate in either or both of the forests. Enter "none" for neither, "mu" for the prognostic forest, "tau" for the treatment forest, and "both" for both forests. If this is not "none" and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: "mu". @@ -119,7 +119,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100, sigma_leaf_mu = NULL, sigma_leaf_tau = NULL, alpha_mu = 0.95, alpha_tau = 0.25, beta_mu = 2.0, beta_tau = 3.0, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, - max_depth_mu = 10, max_depth_tau = 5, nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, + max_depth_mu = 10, max_depth_tau = 5, a_global = 0, b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, pct_var_sigma2_init = 0.25, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, @@ -526,7 +526,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample variance parameters (if requested) if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_sigma_leaf_mu) { @@ -578,7 +578,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample variance parameters (if requested) if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_sigma_leaf_tau) { @@ -625,7 +625,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample variance parameters (if requested) if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_sigma_leaf_mu) { @@ -677,7 +677,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample variance parameters (if requested) if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) + global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global) current_sigma2 <- global_var_samples[i] } if (sample_sigma_leaf_tau) { @@ -775,8 +775,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "initial_sigma_leaf_tau" = sigma_leaf_tau, "initial_b_0" = b_0, "initial_b_1" = b_1, - "nu" = nu, - "lambda" = lambda, + "a_global" = a_global, + "b_global" = b_global, "a_leaf_mu" = a_leaf_mu, "b_leaf_mu" = b_leaf_mu, "a_leaf_tau" = a_leaf_tau, diff --git a/R/calibration.R b/R/calibration.R new file mode 100644 index 00000000..998ec419 --- /dev/null +++ b/R/calibration.R 
@@ -0,0 +1,34 @@ +#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) [1] +#' +#' [1] Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 +#' +#' @param y Outcome to be modeled using BART, BCF or another nonparametric ensemble method. +#' @param X Covariates to be used to partition trees in an ensemble or series of ensemble. +#' @param W [Optional] Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`. +#' @param nu The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as `nu*lambda` where `lambda` is the output of this function. Default: `3`. +#' @param quant [Optional] Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`. +#' @param standardize [Optional] Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`. +#' +#' @return Value of `lambda` which determines the scale parameter of the global error variance prior (`sigma^2 ~ IG(nu,nu*lambda)`) +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' y <- 10*X[,1] - 20*X[,2] + rnorm(n) +#' nu <- 3 +#' lambda <- calibrate_inverse_gamma_error_variance(y, X, nu = nu) +#' sigma2hat <- mean(resid(lm(y~X))^2) +#' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat) +calibrate_inverse_gamma_error_variance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) { + # Compute regression basis + if (!is.null(W)) basis <- cbind(X, W) + else basis <- X + # Standardize outcome if requested + if (standardize) y <- (y-mean(y))/sd(y) + # Compute the "regression-based" overestimate of sigma^2 + sigma2hat <- mean(resid(lm(y~basis))^2) + # Calibrate lambda based on the implied quantile of sigma2hat + return((sigma2hat*qgamma(1-quant,nu))/nu) +} diff --git a/R/cpp11.R b/R/cpp11.R index 16d8b449..c718ca1c 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -332,8 +332,8 @@ sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracke invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized)) } -sample_sigma2_one_iteration_cpp <- function(residual, rng, nu, lambda) { - .Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, rng, nu, lambda) +sample_sigma2_one_iteration_cpp <- function(residual, rng, a, b) { + .Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, rng, a, b) } sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) { diff --git a/_pkgdown.yml b/_pkgdown.yml index bffe900a..48e74c21 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -70,6 +70,7 @@ reference: - createForestKernel - CppRNG - createRNG + - calibrate_inverse_gamma_error_variance - subtitle: 
Random Effects desc: > @@ -97,15 +98,20 @@ reference: - stochtree-package articles: -- title: High-Level Interface - navbar: High-Level Interface +- title: High-Level Model Fitting + navbar: High-Level Model Fitting contents: - BayesianSupervisedLearning - CausalInference + +- title: Advanced Model Interface + navbar: Advanced Model Interface + contents: - ModelSerialization + - PriorCalibration + - EnsembleKernel - title: Prototype Interface navbar: Prototype Interface contents: - CustomSamplingRoutine - - EnsembleKernel diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index 003d164e..e74b2964 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -26,7 +26,7 @@ class GlobalHomoskedasticVarianceModel { ~GlobalHomoskedasticVarianceModel() {} double PosteriorShape(Eigen::VectorXd& residuals, double a, double b) { data_size_t n = residuals.rows(); - return (a/2.0) + n; + return (a/2.0) + (n/2.0); } double PosteriorScale(Eigen::VectorXd& residuals, double a, double b) { data_size_t n = residuals.rows(); @@ -34,7 +34,7 @@ class GlobalHomoskedasticVarianceModel { for (data_size_t i = 0; i < n; i++) { sum_sq_resid += std::pow(residuals(i, 0), 2); } - return (b/2.0) + sum_sq_resid; + return (b/2.0) + (sum_sq_resid/2.0); } double SampleVarianceParameter(Eigen::VectorXd& residuals, double a, double b, std::mt19937& gen) { double ig_shape = PosteriorShape(residuals, a, b); @@ -52,11 +52,11 @@ class LeafNodeHomoskedasticVarianceModel { ~LeafNodeHomoskedasticVarianceModel() {} double PosteriorShape(TreeEnsemble* ensemble, double a, double b) { data_size_t num_leaves = ensemble->NumLeaves(); - return (a/2.0) + num_leaves; + return (a/2.0) + (num_leaves/2.0); } double PosteriorScale(TreeEnsemble* ensemble, double a, double b) { double mu_sq = ensemble->SumLeafSquared(); - return (b/2.0) + mu_sq; + return (b/2.0) + (mu_sq/2.0); } double SampleVarianceParameter(TreeEnsemble* ensemble, double a, double b, std::mt19937& gen) { double ig_shape = PosteriorShape(ensemble, a, b); diff --git a/man/bart.Rd b/man/bart.Rd index f274891b..eb4dfe59 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -21,12 +21,13 @@ bart( min_samples_leaf = 5, max_depth = 10, leaf_model = 0, - nu = 3, - lambda = NULL, + a_global = 0, + b_global = 0, a_leaf = 3, b_leaf = NULL, q = 0.9, sigma2_init = NULL, + pct_var_sigma2_init = 0.25, variable_weights = NULL, num_trees = 200, num_gfr = 5, @@ -88,9 +89,9 @@ that were not in the training set.} \item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} -\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} +\item{a_global}{Shape parameter in the \code{IG(a_global, b_global)} global error variance model. Default: 0.} -\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} +\item{b_global}{Scale parameter in the \code{IG(a_global, b_global)} global error variance model. Default: 0.} \item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.} @@ -98,7 +99,9 @@ that were not in the training set.} \item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.} -\item{sigma2_init}{Starting value of global variance parameter. 
Calibrated internally as in Sparapani et al (2021) if not set here.} +\item{sigma2_init}{Starting value of global error variance parameter. Calibrated internally as \code{pct_var_sigma2_init*var((y-mean(y))/sd(y))} if not set.} + +\item{pct_var_sigma2_init}{Percentage of standardized outcome variance used to initialize global error variance parameter. Default: 0.25. Superseded by \code{sigma2_init}.} \item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.} diff --git a/man/bcf.Rd b/man/bcf.Rd index 9dac1939..af2fbf33 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -27,14 +27,15 @@ bcf( min_samples_leaf_tau = 5, max_depth_mu = 10, max_depth_tau = 5, - nu = 3, - lambda = NULL, + a_global = 0, + b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, + pct_var_sigma2_init = 0.25, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, @@ -114,9 +115,9 @@ that were not in the training set.} \item{max_depth_tau}{Maximum depth of any tree in the tau ensemble. Default: 5. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} -\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} +\item{a_global}{Shape parameter in the \code{IG(a_global, b_global)} global error variance model. Default: 0.} -\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} +\item{b_global}{Scale parameter in the \code{IG(a_global, b_global)} global error variance model. Default: 0.} \item{a_leaf_mu}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model for the prognostic forest. Default: 3.} @@ -128,7 +129,9 @@ that were not in the training set.} \item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.} -\item{sigma2}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.} +\item{sigma2}{Starting value of global error variance parameter. Calibrated internally as \code{pct_var_sigma2_init*var((y-mean(y))/sd(y))} if not set.} + +\item{pct_var_sigma2_init}{Percentage of standardized outcome variance used to initialize global error variance parameter. Default: 0.25. Superseded by \code{sigma2}.} \item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to \code{1/ncol(X_train)}. 
A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in \code{X_train} and then set \code{propensity_covariate} to \code{'none'} adjust \code{keep_vars_mu} and \code{keep_vars_tau} accordingly.} diff --git a/man/calibrate_inverse_gamma_error_variance.Rd b/man/calibrate_inverse_gamma_error_variance.Rd new file mode 100644 index 00000000..9d4b2713 --- /dev/null +++ b/man/calibrate_inverse_gamma_error_variance.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/calibration.R +\name{calibrate_inverse_gamma_error_variance} +\alias{calibrate_inverse_gamma_error_variance} +\title{Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) \link{1}} +\usage{ +calibrate_inverse_gamma_error_variance( + y, + X, + W = NULL, + nu = 3, + quant = 0.9, + standardize = TRUE +) +} +\arguments{ +\item{y}{Outcome to be modeled using BART, BCF or another nonparametric ensemble method.} + +\item{X}{Covariates to be used to partition trees in an ensemble or series of ensemble.} + +\item{W}{\link{Optional} Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: \code{NULL}.} + +\item{nu}{The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as \code{nu*lambda} where \code{lambda} is the output of this function. Default: \code{3}.} + +\item{quant}{\link{Optional} Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of \code{sigma^2}. Default: \code{0.9}.} + +\item{standardize}{\link{Optional} Whether or not outcome should be standardized (\code{(y-mean(y))/sd(y)}) before calibration of \code{lambda}. Default: \code{TRUE}.} +} +\value{ +Value of \code{lambda} which determines the scale parameter of the global error variance prior (\code{sigma^2 ~ IG(nu,nu*lambda)}) +} +\description{ +\link{1} Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). 
https://doi.org/10.1002/9781118445112.stat08288 +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +y <- 10*X[,1] - 20*X[,2] + rnorm(n) +nu <- 3 +lambda <- calibrate_inverse_gamma_error_variance(y, X, nu = nu) +sigma2hat <- mean(resid(lm(y~X))^2) +mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat) +} diff --git a/man/sample_sigma2_one_iteration.Rd b/man/sample_sigma2_one_iteration.Rd index eba62ee5..559ccc66 100644 --- a/man/sample_sigma2_one_iteration.Rd +++ b/man/sample_sigma2_one_iteration.Rd @@ -2,19 +2,19 @@ % Please edit documentation in R/variance.R \name{sample_sigma2_one_iteration} \alias{sample_sigma2_one_iteration} -\title{Sample one iteration of the global variance model} +\title{Sample one iteration of the (inverse gamma) global variance model} \usage{ -sample_sigma2_one_iteration(residual, rng, nu, lambda) +sample_sigma2_one_iteration(residual, rng, a, b) } \arguments{ \item{residual}{Outcome class} \item{rng}{C++ random number generator} -\item{nu}{Global variance shape parameter} +\item{a}{Global variance shape parameter} -\item{lambda}{Constitutes the scale parameter for the global variance along with nu (i.e. scale is nu*lambda)} +\item{b}{Global variance scale parameter} } \description{ -Sample one iteration of the global variance model +Sample one iteration of the (inverse gamma) global variance model } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 53423c30..d20046bb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -616,10 +616,10 @@ extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residua END_CPP11 } // sampler.cpp -double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, cpp11::external_pointer rng, double nu, double lambda); -extern "C" SEXP _stochtree_sample_sigma2_one_iteration_cpp(SEXP residual, SEXP rng, SEXP nu, SEXP lambda) { +double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, cpp11::external_pointer rng, double a, double b); +extern "C" SEXP _stochtree_sample_sigma2_one_iteration_cpp(SEXP residual, SEXP rng, SEXP a, SEXP b) { BEGIN_CPP11 - return cpp11::as_sexp(sample_sigma2_one_iteration_cpp(cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rng), cpp11::as_cpp>(nu), cpp11::as_cpp>(lambda))); + return cpp11::as_sexp(sample_sigma2_one_iteration_cpp(cpp11::as_cpp>>(residual), cpp11::as_cpp>>(rng), cpp11::as_cpp>(a), cpp11::as_cpp>(b))); END_CPP11 } // sampler.cpp diff --git a/vignettes/PriorCalibration.Rmd b/vignettes/PriorCalibration.Rmd new file mode 100644 index 00000000..f2ae5d59 --- /dev/null +++ b/vignettes/PriorCalibration.Rmd @@ -0,0 +1,110 @@ +--- +title: "Prior Calibration Approaches for Parametric Components of Stochastic Tree Ensembles" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Ensemble-Kernel} + %\VignetteEncoding{UTF-8} + %\VignetteEngine{knitr::rmarkdown} +bibliography: vignettes.bib +editor_options: + markdown: + wrap: 72 +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +# Background + +The "classic" BART model of @chipman2010bart + +\begin{equation*} +\begin{aligned} +y &= f(X) + \epsilon\\ +f(X) &\sim \text{BART}\left(\alpha, \beta\right)\\ +\epsilon &\sim \mathcal{N}\left(0,\sigma^2\right)\\ +\sigma^2 &\sim \text{IG}\left(a,b\right) +\end{aligned} +\end{equation*} + +is semiparametric, with a nonparametric tree ensemble $f(X)$ and a homoskedastic error variance parameter $\sigma^2$. 
Note that in @chipman2010bart, $a$ and $b$ are parameterized with $a = \frac{\nu}{2}$ and $b = \frac{\nu\lambda}{2}$. + +# Setting Priors on Variance Parameters in `stochtree` + +By default, `stochtree` employs a Jeffreys' prior for $\sigma^2$ +\begin{equation*} +\begin{aligned} +\sigma^2 &\propto \frac{1}{\sigma^2} +\end{aligned} +\end{equation*} +which corresponds to an improper prior with $a = 0$ and $b = 0$. + +We provide convenience functions for users wishing to set the $\sigma^2$ prior as in @chipman2010bart. +In this case, $\nu$ is set by default to 3 and $\lambda$ is calibrated as follows: + +1. An "overestimate," $\hat{\sigma}^2$, of $\sigma^2$ is obtained via simple linear regression of $y$ on $X$ +2. $\lambda$ is chosen to ensure that $p(\sigma^2 < \hat{\sigma}^2) = q$ for some value $q$, typically set to a default value of 0.9. + +This is done in `stochtree` via the `calibrate_inverse_gamma_error_variance` function. + +```{r} +# Load library +library(stochtree) + +# Generate data +n <- 1000 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) + +# Test/train split +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Calibrate the scale parameter for the variance term as in Chipman et al (2010) +nu <- 3 +lambda <- calibrate_inverse_gamma_error_variance(y_train, X_train, nu = nu) +``` + +Now we run a BART model with this variance parameterization + +```{r} +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + a_global = nu, b_global = nu*lambda, num_gfr = 0, + num_burnin = 500, num_mcmc = 100) +``` + +Inspect the out-of-sample predictions of the model + +```{r} +plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +abline(0,1,col="red",lty=3,lwd=3) +``` + +Inspect the posterior samples of $\sigma^2$ + +```{r} +plot(bart_model$sigma2_samples, ylab = "sigma^2", xlab = "iteration") +abline(h = noise_sd^2, col = "red", lty = 3, lwd = 3) +``` + + +# References From 9a4b46b8b17e5e31402e547de9e8463979a05630 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 30 Aug 2024 02:39:57 -0500 Subject: [PATCH 4/4] Updated python initialization and prior setting for global error variance --- stochtree/bart.py | 30 ++++----- stochtree/bcf.py | 51 ++++++++-------- stochtree/calibration.py | 104 +++++++++++++++----------------- test/python/test_calibration.py | 26 ++++---- 4 files changed, 104 insertions(+), 107 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index df88fb01..e1553eaa 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -27,9 +27,9 @@ def is_sampled(self) -> bool: def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None, X_test: np.array = None, basis_test: np.array = None, cutpoint_grid_size = 100, sigma_leaf: float = None, alpha: float = 0.95, beta: float = 2.0, min_samples_leaf: int = 5, max_depth: int = 10, - nu: float = 3, lamb: float = None, a_leaf: float = 3, b_leaf: float = None, q: float = 0.9, sigma2: float = None, - num_trees: int = 200, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, 
sample_sigma_global: bool = True, - sample_sigma_leaf: bool = True, random_seed: int = -1, keep_burnin: bool = False, keep_gfr: bool = False) -> None: + a_global: float = 0, b_global: float = 0, a_leaf: float = 3, b_leaf: float = None, q: float = 0.9, sigma2: float = None, + pct_var_sigma2_init: float = 0.25, num_trees: int = 200, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, + sample_sigma_global: bool = True, sample_sigma_leaf: bool = True, random_seed: int = -1, keep_burnin: bool = False, keep_gfr: bool = False) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -60,10 +60,10 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N Minimum allowable size of a leaf, in terms of training samples. Defaults to ``5``. max_depth : :obj:`int`, optional Maximum depth of any tree in the ensemble. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - nu : :obj:`float`, optional - Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``. - lamb : :obj:`float`, optional - Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). + a_global : :obj:`float`, optional + Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``. + b_global : :obj:`float`, optional + Component of the scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``. a_leaf : :obj:`float`, optional Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model. Defaults to ``3``. b_leaf : :obj:`float`, optional @@ -71,7 +71,9 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N q : :obj:`float`, optional Quantile used to calibrated ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``. sigma2 : :obj:`float`, optional - Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. + Starting value of global variance parameter. Set internally as a percentage of the standardized outcome variance if not set here. + pct_var_sigma2_init : :obj:`float`, optional + Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``0.25``. num_trees : :obj:`int`, optional Number of trees in the ensemble. Defaults to ``200``. num_gfr : :obj:`int`, optional @@ -81,7 +83,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N num_mcmc : :obj:`int`, optional Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. sample_sigma_global : :obj:`bool`, optional - Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(nu, nu*lambda)``. Defaults to ``True``. + Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``. sample_sigma_leaf : :obj:`bool`, optional Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``. 
random_seed : :obj:`int`, optional @@ -176,10 +178,8 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N resid_train = (y_train-self.y_bar)/self.y_std # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART) - if num_gfr > 0: - sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False) - else: - sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True) + if not sigma2: + sigma2 = pct_var_sigma2_init*np.var(resid_train) b_leaf = np.squeeze(np.var(resid_train)) / num_trees if b_leaf is None else b_leaf sigma_leaf = np.squeeze(np.var(resid_train)) / num_trees if sigma_leaf is None else sigma_leaf current_sigma2 = sigma2 @@ -254,7 +254,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N # Sample variance parameters (if requested) if self.sample_sigma_global: - current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb) + current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std if self.sample_sigma_leaf: self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container, cpp_rng, a_leaf, b_leaf, i) @@ -275,7 +275,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N # Sample variance parameters (if requested) if self.sample_sigma_global: - current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb) + current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global) self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std if self.sample_sigma_leaf: self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container, cpp_rng, a_leaf, b_leaf, i) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 96f3ae84..72ea2d3c 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -34,8 +34,9 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr cutpoint_grid_size = 100, sigma_leaf_mu: float = None, sigma_leaf_tau: float = None, alpha_mu: float = 0.95, alpha_tau: float = 0.25, beta_mu: float = 2.0, beta_tau: float = 3.0, min_samples_leaf_mu: int = 5, min_samples_leaf_tau: int = 5, max_depth_mu: int = 10, max_depth_tau: int = 5, - nu: float = 3, lamb: float = None, a_leaf_mu: float = 3, a_leaf_tau: float = 3, b_leaf_mu: float = None, b_leaf_tau: float = None, - q: float = 0.9, sigma2: float = None, variable_weights: np.array = None, + a_global: float = 0, b_global: float = 0, a_leaf_mu: float = 3, a_leaf_tau: float = 3, + b_leaf_mu: float = None, b_leaf_tau: float = None, q: float = 0.9, sigma2: float = None, + pct_var_sigma2_init: float = 0.25, variable_weights: np.array = None, keep_vars_mu: Union[list, np.array] = None, drop_vars_mu: Union[list, np.array] = None, keep_vars_tau: Union[list, np.array] = None, drop_vars_tau: Union[list, np.array] = None, num_trees_mu: int = 200, num_trees_tau: int = 50, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, @@ -93,10 +94,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr Maximum depth of any tree in the mu ensemble. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. 
max_depth_tau : :obj:`int`, optional Maximum depth of any tree in the tau ensemble. Defaults to ``5``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - nu : :obj:`float`, optional - Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``. - lamb : :obj:`float`, optional - Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). + a_global : :obj:`float`, optional + Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``. + b_global : :obj:`float`, optional + Component of the scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``. a_leaf_mu : :obj:`float`, optional Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the prognostic forest. Defaults to ``3``. a_leaf_tau : :obj:`float`, optional @@ -109,6 +110,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr Quantile used to calibrated ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``. sigma2 : :obj:`float`, optional Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. + pct_var_sigma2_init : :obj:`float`, optional + Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``0.25``. variable_weights : :obj:`np.array`, optional Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to ``np.repeat(1/X_train.shape[1], X_train.shape[1])`` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to ``1/X_train.shape[1]``. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in ``X_train`` and then set ``propensity_covariate`` to ``'none'`` and adjust ``keep_vars_mu`` and ``keep_vars_tau`` accordingly. keep_vars_mu : obj:`list` or :obj:`np.array`, optional @@ -130,7 +133,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr num_mcmc : :obj:`int`, optional Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. sample_sigma_global : :obj:`bool`, optional - Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(nu, nu*lambda)``. Defaults to ``True``. + Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``. sample_sigma_leaf_mu : :obj:`bool`, optional Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)`` for the prognostic forest. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``. 
@@ -294,24 +297,24 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
         if beta_tau is not None:
             beta_tau = check_scalar(x=beta_tau, name="beta_tau", target_type=(float,int),
                                     min_val=1, max_val=None, include_boundaries="left")
-        if nu is not None:
-            nu = check_scalar(x=nu, name="nu", target_type=(float,int),
-                              min_val=0, max_val=None, include_boundaries="neither")
-        if lamb is not None:
-            lamb = check_scalar(x=lamb, name="lamb", target_type=(float,int),
-                                min_val=0, max_val=None, include_boundaries="neither")
+        if a_global is not None:
+            a_global = check_scalar(x=a_global, name="a_global", target_type=(float,int),
+                                    min_val=0, max_val=None, include_boundaries="left")
+        if b_global is not None:
+            b_global = check_scalar(x=b_global, name="b_global", target_type=(float,int),
+                                    min_val=0, max_val=None, include_boundaries="left")
         if a_leaf_mu is not None:
             a_leaf_mu = check_scalar(x=a_leaf_mu, name="a_leaf_mu", target_type=(float,int),
-                                     min_val=0, max_val=None, include_boundaries="neither")
+                                     min_val=0, max_val=None, include_boundaries="left")
         if a_leaf_tau is not None:
             a_leaf_tau = check_scalar(x=a_leaf_tau, name="a_leaf_tau", target_type=(float,int),
-                                      min_val=0, max_val=None, include_boundaries="neither")
+                                      min_val=0, max_val=None, include_boundaries="left")
         if b_leaf_mu is not None:
             b_leaf_mu = check_scalar(x=b_leaf_mu, name="b_leaf_mu", target_type=(float,int),
-                                     min_val=0, max_val=None, include_boundaries="neither")
+                                     min_val=0, max_val=None, include_boundaries="left")
         if b_leaf_tau is not None:
             b_leaf_tau = check_scalar(x=b_leaf_tau, name="b_leaf_tau", target_type=(float,int),
-                                      min_val=0, max_val=None, include_boundaries="neither")
+                                      min_val=0, max_val=None, include_boundaries="left")
         if q is not None:
             q = check_scalar(x=q, name="q", target_type=float,
                              min_val=0, max_val=1, include_boundaries="neither")
@@ -512,10 +515,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
         resid_train = (y_train-self.y_bar)/self.y_std
 
         # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
-        if num_gfr > 0:
-            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
-        else:
-            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
+        if sigma2 is None:
+            sigma2 = pct_var_sigma2_init*np.var(resid_train)
         b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if b_leaf_mu is None else b_leaf_mu
         b_leaf_tau = np.squeeze(np.var(resid_train)) / (2*num_trees_tau) if b_leaf_tau is None else b_leaf_tau
         sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if sigma_leaf_mu is None else sigma_leaf_mu
@@ -657,7 +658,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
 
             # Sample variance parameters (if requested)
             if self.sample_sigma_global:
-                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
+                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
                 self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
             if self.sample_sigma_leaf_mu:
                 self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
@@ -671,7 +672,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
 
             # Sample variance parameters (if requested)
             if self.sample_sigma_global:
-                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
+                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
                 self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
             if self.sample_sigma_leaf_tau:
                 self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)
@@ -716,7 +717,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
 
             # Sample variance parameters (if requested)
             if self.sample_sigma_global:
-                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
+                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
                 self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
             if self.sample_sigma_leaf_mu:
                 self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
@@ -730,7 +731,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
 
             # Sample variance parameters (if requested)
             if self.sample_sigma_global:
-                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
+                current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
                 self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
             if self.sample_sigma_leaf_tau:
                 self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)
diff --git a/stochtree/calibration.py b/stochtree/calibration.py
index 20a3a27f..3e4d1fbe 100644
--- a/stochtree/calibration.py
+++ b/stochtree/calibration.py
@@ -6,9 +6,9 @@
 from scipy.stats import gamma
 
 
-def calibrate_global_error_variance(X: np.array, y: np.array, sigma2: float = None, nu: float = 3, lamb: float = None, q: float = 0.9, lm_calibrate: bool = True) -> tuple:
-    """Calibrates global error variance model by setting an initial value of sigma^2 (the parameter itself) and setting a value of lambda, part of the scale parameter in the
-    ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` prior.
+def calibrate_global_error_variance(X: np.array, y: np.array, nu: float = 3, q: float = 0.9, standardize: bool = True) -> float:
+    """Calibrates the scale parameter of the global error variance model as in Chipman et al (2010) by setting a value of lambda,
+    part of the scale parameter in the ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` prior.
 
     Parameters
     ----------
@@ -16,69 +16,65 @@ def calibrate_global_error_variance(X: np.array, y: np.array, sigma2: float = No
         Covariates to be used as split candidates for constructing trees.
     y : :obj:`np.array`
         Outcome to be used as target for constructing trees.
-    sigma2 : :obj:`float`, optional
-        Starting value of global variance parameter. Calibrated internally if not set here.
     nu : :obj:`float`, optional
         Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``.
-    lamb : :obj:`float`, optional
-        Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
     q : :obj:`float`, optional
         Quantile used to calibrated ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``.
-    lm_calibrate : :obj:`bool`, optional
-        Whether or not to calibrate sigma2 based on a linear model of ``y`` given ``X``.
-        If ``True``, uses the linear model calibration technique in Sparapani et al (2021), otherwise uses `np.var(y)`. Defaults to ``True``.
+    standardize : :obj:`bool`, optional
+        Whether or not ``y`` should be standardized before calibration. Defaults to ``True``.
 
     Returns
     -------
-    (sigma2, lamb) : :obj:`tuple` of :obj:`float`
-        Tuple containing an initial value of sigma^2 (global error variance) and lambda (part of scale parameter of global error variance model)
+    lamb : :obj:`float`
+        Part of the scale parameter of the global error variance model
     """
-    # Initialize sigma if no initial value is provided
-    var_y = np.var(y)
-    if sigma2 is None:
-        if lm_calibrate:
-            # Convert X and y to expected dimensions
-            if X.ndim == 2:
-                X_processed = X
-            elif X.ndim == 1:
-                X_processed = np.expand_dims(X, 1)
-            else:
-                raise ValueError("X must be a 1 or 2 dimensional numpy array")
-            n, p = X_processed.shape
+    # Convert X and y to expected dimensions
+    if X.ndim == 2:
+        X_processed = X
+    elif X.ndim == 1:
+        X_processed = np.expand_dims(X, 1)
+    else:
+        raise ValueError("X must be a 1 or 2 dimensional numpy array")
+    n, p = X_processed.shape
+
+    if y.ndim == 2:
+        y_processed = np.squeeze(y)
+    elif y.ndim == 1:
+        y_processed = y
+    else:
+        raise ValueError("y must be a 1 or 2 dimensional numpy array")
 
-            if y.ndim == 2:
-                y_processed = np.squeeze(y)
-            elif y.ndim == 1:
-                y_processed = y
-            else:
-                raise ValueError("y must be a 1 or 2 dimensional numpy array")
-
-            # Fit a linear model of y ~ X
-            lm_calibrator = linear_model.LinearRegression()
-            lm_calibrator.fit(X_processed, y_processed)
-
-            # Compute MSE
-            y_hat_processed = lm_calibrator.predict(X_processed)
-            mse = mean_squared_error(y_processed, y_hat_processed)
-
-            # Check for overdetermination, revert to variance of y if model is overdetermined
-            eps = np.finfo("double").eps
-            if _is_model_overdetermined(lm_calibrator, n, mse, eps):
-                sigma2 = var_y
-                warnings.warn("Default calibration method for global error variance failed; covariate dimension exceeds number of samples. "
-                              "Initializing global error variance based on the variance of the standardized outcome.", UserWarning)
-            else:
-                sigma2 = mse
-                if _is_model_rank_deficient(lm_calibrator, p):
-                    warnings.warn("Default calibration method for global error variance detected rank deficiency in covariate matrix. "
-                                  "This should not impact the calibrated values, but may indicate the presence of duplicated covariates.", UserWarning)
-        else:
-            sigma2 = var_y
+    # Standardize outcome if necessary; var_y is computed on the same (possibly standardized) scale used in the fallback below
+    mean_y = np.mean(y)
+    sd_y = np.std(y)
+    if standardize:
+        y_processed = (y_processed - mean_y) / sd_y
+    var_y = np.var(y_processed)
+
+    # Fit a linear model of y ~ X
+    lm_calibrator = linear_model.LinearRegression()
+    lm_calibrator.fit(X_processed, y_processed)
+
+    # Compute MSE
+    y_hat_processed = lm_calibrator.predict(X_processed)
+    mse = mean_squared_error(y_processed, y_hat_processed)
+
+    # Check for overdetermination, revert to variance of y if model is overdetermined
+    eps = np.finfo("double").eps
+    if _is_model_overdetermined(lm_calibrator, n, mse, eps):
+        sigma2hat = var_y
+        warnings.warn("Default calibration method for global error variance failed; covariate dimension exceeds number of samples. "
+                      "Initializing global error variance scale parameter based on the variance of the standardized outcome.", UserWarning)
+    else:
+        sigma2hat = mse
+        if _is_model_rank_deficient(lm_calibrator, p):
+            warnings.warn("Default calibration method for global error variance detected rank deficiency in covariate matrix. "
" + "This should not impact the calibrated values, but may indicate the presence of duplicated covariates.", UserWarning) # Calibrate lamb if no initial value is provided - if lamb is None: - lamb = (sigma2*gamma.ppf(1-q,nu))/nu + lamb = (sigma2hat*gamma.ppf(1-q,nu))/nu - return (sigma2, lamb) + return lamb def _is_model_overdetermined(reg_model: linear_model.LinearRegression, n: int, mse: float, eps: float) -> bool: diff --git a/test/python/test_calibration.py b/test/python/test_calibration.py index 99292cce..312b9632 100644 --- a/test/python/test_calibration.py +++ b/test/python/test_calibration.py @@ -15,11 +15,11 @@ def test_full_rank(self): q = 0.9 X = np.random.uniform(size=(n,p)) y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n) + y_std = (y - np.mean(y)) / np.std(y) reg_model = linear_model.LinearRegression() - reg_model.fit(X, y) - mse = mean_squared_error(y, reg_model.predict(X)) - sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True) - assert sigma2 == pytest.approx(mse) + reg_model.fit(X, y_std) + mse = mean_squared_error(y_std, reg_model.predict(X)) + lamb = calibrate_global_error_variance(X = X, y = y, nu = nu, q = q, standardize = True) assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu) def test_rank_deficient(self): @@ -30,15 +30,15 @@ def test_rank_deficient(self): X = np.random.uniform(size=(n,p)) X[:,4] = X[:,2] y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n) + y_std = (y - np.mean(y)) / np.std(y) reg_model = linear_model.LinearRegression() - reg_model.fit(X, y) - mse = mean_squared_error(y, reg_model.predict(X)) + reg_model.fit(X, y_std) + mse = mean_squared_error(y_std, reg_model.predict(X)) if reg_model.rank_ < p: with pytest.warns(UserWarning): - sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True) + lamb = calibrate_global_error_variance(X = X, y = y, nu = nu, q = q, standardize = True) else: - sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True) - assert sigma2 == pytest.approx(mse) + lamb = calibrate_global_error_variance(X = X, y = y, nu = nu, q = q, standardize = True) assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu) def test_overdetermined(self): @@ -48,10 +48,10 @@ def test_overdetermined(self): q = 0.9 X = np.random.uniform(size=(n,p)) y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n) + y_std = (y - np.mean(y)) / np.std(y) reg_model = linear_model.LinearRegression() - reg_model.fit(X, y) - mse = mean_squared_error(y, reg_model.predict(X)) + reg_model.fit(X, y_std) + mse = mean_squared_error(y_std, reg_model.predict(X)) with pytest.warns(UserWarning): - sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True) - assert sigma2 == pytest.approx(np.var(y)) + lamb = calibrate_global_error_variance(X = X, y = y, nu = nu, q = q, standardize = True) assert lamb == pytest.approx(np.var(y)*(gamma.ppf(1-q,nu))/nu)