diff --git a/R/bart.R b/R/bart.R index 6a246a91..47afeede 100644 --- a/R/bart.R +++ b/R/bart.R @@ -962,6 +962,7 @@ convertBARTModelToJson <- function(object){ jsonobj$add_scalar("variance_scale", object$model_params$variance_scale) jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global) jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf) jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest) @@ -1141,6 +1142,7 @@ createBARTModelFromJson <- function(json_object){ model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf") model_params[["include_mean_forest"]] <- include_mean_forest @@ -1336,6 +1338,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params = list() model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") + model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf") model_params[["include_mean_forest"]] <- include_mean_forest @@ -1486,6 +1489,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") + model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf") model_params[["include_mean_forest"]] <- include_mean_forest diff --git a/R/bcf.R b/R/bcf.R index 99866c1b..c84a919d 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1187,7 +1187,7 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU # Compute forest predictions y_std <- bcf$model_params$outcome_scale y_bar <- bcf$model_params$outcome_mean - sigma2_init <- bcf$model_params$initial_sigma2 + initial_sigma2 <- bcf$model_params$initial_sigma2 mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar if (bcf$model_params$adaptive_coding) { tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau) @@ -1224,7 +1224,7 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU sigma2_samples <- bcf$sigma2_global_samples variance_forest_predictions <- sapply(1:length(keep_indices), function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) } else { - variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std + variance_forest_predictions <- sqrt(s_x_raw*initial_sigma2)*y_std } } @@ -1406,6 +1406,9 @@ convertBCFModelToJson <- function(object){ # Add the forests jsonobj$add_forest(object$forests_mu) jsonobj$add_forest(object$forests_tau) + if (object$model_params$include_variance_forest) { + jsonobj$add_forest(object$forests_variance) + } # Add metadata jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars) @@ -1426,9 +1429,11 @@ convertBCFModelToJson <- function(object){ # Add global parameters jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global) jsonobj$add_boolean("sample_sigma_leaf_mu", object$model_params$sample_sigma_leaf_mu) jsonobj$add_boolean("sample_sigma_leaf_tau", object$model_params$sample_sigma_leaf_tau) + jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest) jsonobj$add_string("propensity_covariate", object$model_params$propensity_covariate) jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) @@ -1686,6 +1691,10 @@ createBCFModelFromJson <- function(json_object){ # Unpack the forests output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") + include_variance_forest <- json_object$get_boolean("include_variance_forest") + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerJson(json_object, "forest_2") + } # Unpack metadata train_set_metadata = list() @@ -1710,9 +1719,11 @@ createBCFModelFromJson <- function(json_object){ model_params = list() model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu") model_params[["sample_sigma_leaf_tau"]] <- json_object$get_boolean("sample_sigma_leaf_tau") + model_params[["include_variance_forest"]] <- include_variance_forest model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate") model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") diff --git a/_pkgdown.yml b/_pkgdown.yml index 9ce0f9d7..8845ca48 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -86,6 +86,8 @@ reference: - CppRNG - createRNG - calibrate_inverse_gamma_error_variance + - preprocessBartParams + - preprocessBcfParams - subtitle: Random Effects desc: > @@ -118,6 +120,7 @@ articles: contents: - BayesianSupervisedLearning - CausalInference + - Heteroskedasticity - title: Advanced Model Interface navbar: Advanced Model Interface diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index 89fe20fc..f78659ee 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -161,7 +161,7 @@ for (i in 1:num_warmstart) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T ) # Sample global variance parameter @@ -186,7 +186,7 @@ for (i in (num_warmstart+1):num_samples) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F ) # Sample global variance parameter @@ -370,7 +370,7 @@ for (i in 1:num_warmstart) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T ) # Sample global variance parameter @@ -398,7 +398,7 @@ for (i in (num_warmstart+1):num_samples) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F ) # Sample global variance parameter @@ -599,7 +599,7 @@ for (i in 1:num_warmstart) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = T + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T ) # Sample global variance parameter @@ -627,7 +627,7 @@ for (i in (num_warmstart+1):num_samples) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F ) # Sample global variance parameter @@ -824,12 +824,12 @@ for (i in 1:num_warmstart) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - sigma2, cutpoint_grid_size, gfr = T + 1, 1, sigma2, cutpoint_grid_size, gfr = T ) # Sample global variance parameter global_var_samples[i+1] <- sample_sigma2_one_iteration( - outcome, rng, nu, lambda + outcome, forest_dataset, rng, nu, lambda ) } ``` @@ -862,12 +862,12 @@ for (i in (num_warmstart+1):num_samples) { forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, rng, feature_types, outcome_model_type, leaf_prior_scale, var_weights, - global_var_samples[i], cutpoint_grid_size, gfr = F + 1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F ) # Sample global variance parameter global_var_samples[i+1] <- sample_sigma2_one_iteration( - outcome, rng, nu, lambda + outcome, forest_dataset, rng, nu, lambda ) } ``` @@ -1150,7 +1150,7 @@ if (num_gfr > 0){ forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, rng, feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T ) # Sample variance parameters (if requested) @@ -1163,7 +1163,7 @@ if (num_gfr > 0){ forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, rng, feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T ) # Sample adaptive coding parameters @@ -1198,7 +1198,7 @@ if (num_burnin + num_mcmc > 0) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, rng, feature_types_mu, - 0, current_leaf_scale_mu, variable_weights_mu, current_sigma2, + 0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T ) @@ -1209,7 +1209,7 @@ if (num_burnin + num_mcmc > 0) { # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, rng, feature_types_tau, - 1, current_leaf_scale_tau, variable_weights_tau, current_sigma2, + 1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T ) diff --git a/vignettes/ModelSerialization.Rmd b/vignettes/ModelSerialization.Rmd index cc1af62b..e22b199b 100644 --- a/vignettes/ModelSerialization.Rmd +++ b/vignettes/ModelSerialization.Rmd @@ -100,13 +100,14 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc +bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F) bcf_model <- bcf( X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, - sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F + params = bcf_params ) ``` @@ -189,14 +190,15 @@ num_gfr <- 10 num_burnin <- 0 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc +bart_params <- list(num_trees_mean = 100, num_trees_variance = 50, + alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, + alpha_variance = 0.95, beta_variance = 1.25, + min_samples_leaf_variance = 1, + sample_sigma_global = F, sample_sigma_leaf = F) bart_model <- stochtree::bart( X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, - num_trees_mean = 0, num_trees_variance = 50, - alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, - alpha_variance = 0.95, beta_variance = 1.25, - min_samples_leaf_variance = 1, - sample_sigma_global = F, sample_sigma_leaf = F + params = bart_params ) ```