diff --git a/R/bart.R b/R/bart.R index a8e6ebe2..224432ed 100644 --- a/R/bart.R +++ b/R/bart.R @@ -151,7 +151,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # 1. General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size standardize <- general_params_updated$standardize - sample_sigma_global <- general_params_updated$sample_sigma2_global + sample_sigma2_global <- general_params_updated$sample_sigma2_global sigma2_init <- general_params_updated$sigma2_global_init a_global <- general_params_updated$sigma2_global_shape b_global <- general_params_updated$sigma2_global_scale @@ -169,8 +169,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train beta_mean <- mean_forest_params_updated$beta min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf max_depth_mean <- mean_forest_params_updated$max_depth - sample_sigma_leaf <- mean_forest_params_updated$sample_sigma2_leaf - sigma_leaf_init <- mean_forest_params_updated$sigma2_leaf_init + sample_sigma2_leaf <- mean_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_init <- mean_forest_params_updated$sigma2_leaf_init a_leaf <- mean_forest_params_updated$sigma2_leaf_shape b_leaf <- mean_forest_params_updated$sigma2_leaf_scale keep_vars_mean <- mean_forest_params_updated$keep_vars @@ -212,12 +212,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (previous_bart_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bart_model$variance_forests } else previous_forest_samples_variance <- NULL - if (previous_bart_model$model_params$sample_sigma_global) { + if (previous_bart_model$model_params$sample_sigma2_global) { previous_global_var_samples <- previous_bart_model$sigma2_global_samples / ( previous_y_scale*previous_y_scale ) } else previous_global_var_samples <- NULL - if (previous_bart_model$model_params$sample_sigma_leaf) { + if 
(previous_bart_model$model_params$sample_sigma2_leaf) { previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples } else previous_leaf_var_samples <- NULL if (previous_bart_model$model_params$has_rfx) { @@ -254,7 +254,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } # Override tau sampling if there is no mean forest - if (!include_mean_forest) sample_sigma_leaf <- FALSE + if (!include_mean_forest) sample_sigma2_leaf <- FALSE # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { @@ -481,9 +481,9 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_variance_forest) { stop("We do not support heteroskedasticity with a probit link") } - if (sample_sigma_global) { + if (sample_sigma2_global) { warning("Global error variance will not be sampled with a probit link as it is fixed at 1") - sample_sigma_global <- F + sample_sigma2_global <- F } } @@ -507,26 +507,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train b_leaf <- 1/(num_trees_mean) if (has_basis) { if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, ncol(leaf_basis_train))) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2/(num_trees_mean)) + if 
(!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2/(num_trees_mean)) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } current_sigma2 <- sigma2_init @@ -552,26 +552,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) if (has_basis) { if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, ncol(leaf_basis_train))) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } } else { - if 
(is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - current_leaf_scale <- sigma_leaf_init + current_leaf_scale <- sigma2_leaf_init } } current_sigma2 <- sigma2_init @@ -603,9 +603,9 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train leaf_dimension = ncol(leaf_basis_train) is_leaf_constant = FALSE leaf_regression = TRUE - if (sample_sigma_leaf) { + if (sample_sigma2_leaf) { warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.") - sample_sigma_leaf <- FALSE + sample_sigma2_leaf <- FALSE } } @@ -690,8 +690,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Delete GFR samples from these containers after the fact if desired # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains - if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples) - if (sample_sigma_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) + if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) + if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) sample_counter <- 0 # Initialize the leaves of each tree in the mean forest @@ -750,12 +750,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- 
sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf) { + if (sample_sigma2_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double @@ -776,7 +776,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { resetActiveForest(active_forest_mean, forest_samples_mean, forest_ind) resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma_leaf) { + if (sample_sigma2_leaf) { leaf_scale_double <- leaf_scale_samples[forest_ind + 1] current_leaf_scale <- as.matrix(leaf_scale_double) forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) @@ -790,7 +790,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- global_var_samples[forest_ind + 1] global_model_config$update_global_error_variance(current_sigma2) } @@ -798,7 +798,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) { + if (sample_sigma2_leaf && (!is.null(previous_leaf_var_samples))) { leaf_scale_double <- 
previous_leaf_var_samples[previous_model_warmstart_sample_num] current_leaf_scale <- as.matrix(leaf_scale_double) forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) @@ -819,7 +819,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) } } - if (sample_sigma_global) { + if (sample_sigma2_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] global_model_config$update_global_error_variance(current_sigma2) @@ -830,8 +830,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetActiveForest(active_forest_mean) active_forest_mean$set_root_leaves(init_values_mean_forest / num_trees_mean) resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma_leaf) { - current_leaf_scale <- as.matrix(sigma_leaf_init) + if (sample_sigma2_leaf) { + current_leaf_scale <- as.matrix(sigma2_leaf_init) forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } @@ -845,7 +845,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train sigma_xi_init, sigma_xi_shape, sigma_xi_scale) rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sigma2_init global_model_config$update_global_error_variance(current_sigma2) } @@ -903,12 +903,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) 
global_var_samples[sample_counter] <- current_sigma2 global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf) { + if (sample_sigma2_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double @@ -934,10 +934,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train rfx_samples$delete_sample(0) } } - if (sample_sigma_global) { + if (sample_sigma2_global) { global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] } - if (sample_sigma_leaf) { + if (sample_sigma2_leaf) { leaf_scale_samples <- leaf_scale_samples[(num_gfr+1):length(leaf_scale_samples)] } num_retained_samples <- num_retained_samples - num_gfr @@ -951,8 +951,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Variance forest predictions if (include_variance_forest) { - sigma_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) - if (has_test) sigma_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) + sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) } # Random effects predictions @@ -966,26 +966,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } # Global error variance - if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2) + if (sample_sigma2_global) sigma2_global_samples <- global_var_samples*(y_std_train^2) # Leaf parameter variance - if (sample_sigma_leaf) tau_samples <- leaf_scale_samples + if (sample_sigma2_leaf) tau_samples <- leaf_scale_samples # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { - if (sample_sigma_global) { - sigma_x_hat_train <- 
sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) - if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_train[,i]*sigma2_global_samples[i]) + if (has_test) sigma2_x_hat_test <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_test[,i]*sigma2_global_samples[i]) } else { - sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train - if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train + sigma2_x_hat_train <- sigma2_x_hat_train*sigma2_init*y_std_train*y_std_train + if (has_test) sigma2_x_hat_test <- sigma2_x_hat_test*sigma2_init*y_std_train*y_std_train } } # Return results as a list model_params <- list( "sigma2_init" = sigma2_init, - "sigma_leaf_init" = sigma_leaf_init, + "sigma2_leaf_init" = sigma2_leaf_init, "a_global" = a_global, "b_global" = b_global, "a_leaf" = a_leaf, @@ -1011,8 +1011,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train "has_rfx" = has_rfx, "has_rfx_basis" = has_basis_rfx, "num_rfx_basis" = num_basis_rfx, - "sample_sigma_global" = sample_sigma_global, - "sample_sigma_leaf" = sample_sigma_leaf, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf" = sample_sigma2_leaf, "include_mean_forest" = include_mean_forest, "include_variance_forest" = include_variance_forest, "probit_outcome_model" = probit_outcome_model @@ -1028,11 +1028,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (include_variance_forest) { result[["variance_forests"]] = forest_samples_variance - result[["sigma_x_hat_train"]] = sigma_x_hat_train - if (has_test) result[["sigma_x_hat_test"]] = sigma_x_hat_test + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test } - 
if (sample_sigma_global) result[["sigma2_global_samples"]] = sigma2_samples - if (sample_sigma_leaf) result[["sigma2_leaf_samples"]] = tau_samples + if (sample_sigma2_global) result[["sigma2_global_samples"]] = sigma2_global_samples + if (sample_sigma2_leaf) result[["sigma2_leaf_samples"]] = tau_samples if (has_rfx) { result[["rfx_samples"]] = rfx_samples result[["rfx_preds_train"]] = rfx_preds_train @@ -1170,11 +1170,11 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL # Scale variance forest predictions if (object$model_params$include_variance_forest) { - if (object$model_params$sample_sigma_global) { - sigma2_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i]) } else { - variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std + variance_forest_predictions <- s_x_raw*sigma2_init*y_std*y_std } } @@ -1341,8 +1341,8 @@ saveBARTModelToJson <- function(object){ jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) - jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global) - jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf) + jsonobj$add_boolean("sample_sigma2_global", object$model_params$sample_sigma2_global) + jsonobj$add_boolean("sample_sigma2_leaf", object$model_params$sample_sigma2_leaf) jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest) jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest) jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) @@ 
-1358,10 +1358,10 @@ saveBARTModelToJson <- function(object){ jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) - if (object$model_params$sample_sigma_global) { + if (object$model_params$sample_sigma2_global) { jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") } - if (object$model_params$sample_sigma_leaf) { + if (object$model_params$sample_sigma2_leaf) { jsonobj$add_vector("sigma2_leaf_samples", object$sigma2_leaf_samples, "parameters") } @@ -1533,8 +1533,8 @@ createBARTModelFromJson <- function(json_object){ model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") - model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean("sample_sigma2_leaf") model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") @@ -1554,10 +1554,10 @@ createBARTModelFromJson <- function(json_object){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { + if (model_params[["sample_sigma2_global"]]) { output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") } - if (model_params[["sample_sigma_leaf"]]) { + if (model_params[["sample_sigma2_leaf"]]) { output[["sigma2_leaf_samples"]] <- 
json_object$get_vector("sigma2_leaf_samples", "parameters") } @@ -1743,8 +1743,8 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object_default$get_boolean("standardize") model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") - model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf") + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean("sample_sigma2_leaf") model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") @@ -1776,7 +1776,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { + if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -1786,7 +1786,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){ } } } - if (model_params[["sample_sigma_leaf"]]) { + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -1897,8 +1897,8 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object_default$get_boolean("standardize") model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") - model_params[["sample_sigma_global"]] <- 
json_object_default$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf") + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean("sample_sigma2_leaf") model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") @@ -1930,7 +1930,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { + if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -1940,7 +1940,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ } } } - if (model_params[["sample_sigma_leaf"]]) { + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { diff --git a/R/bcf.R b/R/bcf.R index 8161fe69..b9842c5d 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -208,7 +208,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # 1. 
General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size standardize <- general_params_updated$standardize - sample_sigma_global <- general_params_updated$sample_sigma2_global + sample_sigma2_global <- general_params_updated$sample_sigma2_global sigma2_init <- general_params_updated$sigma2_global_init a_global <- general_params_updated$sigma2_global_shape b_global <- general_params_updated$sigma2_global_scale @@ -232,8 +232,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id beta_mu <- prognostic_forest_params_updated$beta min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf max_depth_mu <- prognostic_forest_params_updated$max_depth - sample_sigma_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf - sigma_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init + sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale keep_vars_mu <- prognostic_forest_params_updated$keep_vars @@ -245,8 +245,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id beta_tau <- treatment_effect_forest_params_updated$beta min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf max_depth_tau <- treatment_effect_forest_params_updated$max_depth - sample_sigma_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf - sigma_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init + sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale keep_vars_tau <- 
treatment_effect_forest_params_updated$keep_vars @@ -287,16 +287,16 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (previous_bcf_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bcf_model$forests_variance } else previous_forest_samples_variance <- NULL - if (previous_bcf_model$model_params$sample_sigma_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_samples / ( + if (previous_bcf_model$model_params$sample_sigma2_global) { + previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / ( previous_y_scale*previous_y_scale ) } else previous_global_var_samples <- NULL - if (previous_bcf_model$model_params$sample_sigma_leaf_mu) { - previous_leaf_var_mu_samples <- previous_bcf_model$sigma_leaf_mu_samples + if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { + previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples } else previous_leaf_var_mu_samples <- NULL - if (previous_bcf_model$model_params$sample_sigma_leaf_tau) { - previous_leaf_var_tau_samples <- previous_bcf_model$sigma_leaf_tau_samples + if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { + previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples } else previous_leaf_var_tau_samples <- NULL if (previous_bcf_model$model_params$has_rfx) { previous_rfx_samples <- previous_bcf_model$rfx_samples @@ -697,9 +697,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (include_variance_forest) { stop("We do not support heteroskedasticity with a probit link") } - if (sample_sigma_global) { + if (sample_sigma2_global) { warning("Global error variance will not be sampled with a probit link as it is fixed at 1") - sample_sigma_global <- F + sample_sigma2_global <- F } } @@ -716,23 +716,23 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Set initial value for the mu forest init_mu <- 0.0 
- # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau # Set sigma2_init to 1, ignoring any defaults provided sigma2_init <- 1.0 # Skip variance_forest_init, since variance forests are not supported with probit link if (is.null(b_leaf_mu)) b_leaf_mu <- 1/num_trees_mu if (is.null(b_leaf_tau)) b_leaf_tau <- 1/(2*num_trees_tau) - if (is.null(sigma_leaf_mu)) { - sigma_leaf_mu <- 2/(num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2/(num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { - if (!is.matrix(sigma_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { - current_leaf_scale_mu <- sigma_leaf_mu + current_leaf_scale_mu <- sigma2_leaf_mu } } - if (is.null(sigma_leaf_tau)) { + if (is.null(sigma2_leaf_tau)) { # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p # Use p = 0.9 as an internal default rather than adding another # user-facing "parameter" of the binary outcome BCF prior. @@ -740,15 +740,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # treatment_effect_forest_params. 
p <- 0.6827 q_quantile <- qnorm((p+1)/2) - sigma_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau - current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + sigma2_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau + current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) } else { - if (!is.matrix(sigma_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) } else { - if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - current_leaf_scale_tau <- sigma_leaf_tau + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + current_leaf_scale_tau <- sigma2_leaf_tau } } current_sigma2 <- sigma2_init @@ -768,31 +768,31 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Set initial value for the mu forest init_mu <- mean(resid_train) - # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) - if 
(is.null(sigma_leaf_mu)) { - sigma_leaf_mu <- 2.0*var(resid_train)/(num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2.0*var(resid_train)/(num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { - if (!is.matrix(sigma_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { - current_leaf_scale_mu <- sigma_leaf_mu + current_leaf_scale_mu <- sigma2_leaf_mu } } - if (is.null(sigma_leaf_tau)) { - sigma_leaf_tau <- var(resid_train)/(num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + if (is.null(sigma2_leaf_tau)) { + sigma2_leaf_tau <- var(resid_train)/(num_trees_tau) + current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) } else { - if (!is.matrix(sigma_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) } else { - if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - current_leaf_scale_tau <- sigma_leaf_tau + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + current_leaf_scale_tau <- sigma2_leaf_tau } } current_sigma2 <- sigma2_init @@ -800,9 +800,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, 
rfx_group_id # Switch off leaf scale sampling for multivariate treatments if (ncol(Z_train) > 1) { - if (sample_sigma_leaf_tau) { + if (sample_sigma2_leaf_tau) { warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.") - sample_sigma_leaf_tau <- FALSE + sample_sigma2_leaf_tau <- FALSE } } @@ -863,9 +863,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Delete GFR samples from these containers after the fact if desired # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains - if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples) - if (sample_sigma_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples) - if (sample_sigma_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples) + if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) + if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples) + if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples) sample_counter <- 0 # Prepare adaptive coding structure @@ -978,11 +978,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id ) # Sample variance parameters (if requested) - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf_mu) { + if (sample_sigma2_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- 
leaf_scale_mu_double @@ -1041,12 +1041,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE ) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf_tau) { + if (sample_sigma2_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double @@ -1070,12 +1070,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) resetActiveForest(active_forest_tau, forest_samples_tau, forest_ind) resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma_leaf_mu) { + if (sample_sigma2_leaf_mu) { leaf_scale_mu_double <- leaf_scale_mu_samples[forest_ind + 1] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } - if (sample_sigma_leaf_tau) { + if (sample_sigma2_leaf_tau) { leaf_scale_tau_double <- leaf_scale_tau_samples[forest_ind + 1] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) @@ -1099,7 +1099,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- 
global_var_samples[forest_ind + 1] global_model_config$update_global_error_variance(current_sigma2) } @@ -1112,12 +1112,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) } - if (sample_sigma_leaf_mu && (!is.null(previous_leaf_var_mu_samples))) { + if (sample_sigma2_leaf_mu && (!is.null(previous_leaf_var_mu_samples))) { leaf_scale_mu_double <- previous_leaf_var_mu_samples[previous_model_warmstart_sample_num] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } - if (sample_sigma_leaf_tau && (!is.null(previous_leaf_var_tau_samples))) { + if (sample_sigma2_leaf_tau && (!is.null(previous_leaf_var_tau_samples))) { leaf_scale_tau_double <- previous_leaf_var_tau_samples[previous_model_warmstart_sample_num] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) @@ -1148,7 +1148,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) } } - if (sample_sigma_global) { + if (sample_sigma2_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] } @@ -1161,12 +1161,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetActiveForest(active_forest_tau) active_forest_tau$set_root_leaves(init_tau / num_trees_tau) resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma_leaf_mu) { - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + if (sample_sigma2_leaf_mu) { + 
current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } - if (sample_sigma_leaf_tau) { - current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + if (sample_sigma2_leaf_tau) { + current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } if (include_variance_forest) { @@ -1190,7 +1190,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sigma2_init global_model_config$update_global_error_variance(current_sigma2) } @@ -1244,11 +1244,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id ) # Sample variance parameters (if requested) - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf_mu) { + if (sample_sigma2_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double @@ -1307,12 +1307,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE ) } - if (sample_sigma_global) { + if (sample_sigma2_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 global_model_config$update_global_error_variance(current_sigma2) } - if (sample_sigma_leaf_tau) { + if 
(sample_sigma2_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double @@ -1339,13 +1339,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rfx_samples$delete_sample(0) } } - if (sample_sigma_global) { + if (sample_sigma2_global) { global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] } - if (sample_sigma_leaf_mu) { + if (sample_sigma2_leaf_mu) { leaf_scale_mu_samples <- leaf_scale_mu_samples[(num_gfr+1):length(leaf_scale_mu_samples)] } - if (sample_sigma_leaf_tau) { + if (sample_sigma2_leaf_tau) { leaf_scale_tau_samples <- leaf_scale_tau_samples[(num_gfr+1):length(leaf_scale_tau_samples)] } if (adaptive_coding) { @@ -1375,8 +1375,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) } if (include_variance_forest) { - sigma_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) - if (has_test) sigma_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) + sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) } # Random effects predictions @@ -1390,22 +1390,22 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } # Global error variance - if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2) + if (sample_sigma2_global) sigma2_global_samples <- global_var_samples*(y_std_train^2) # Leaf parameter variance for prognostic forest - if (sample_sigma_leaf_mu) sigma_leaf_mu_samples <- leaf_scale_mu_samples + if (sample_sigma2_leaf_mu) sigma2_leaf_mu_samples <- leaf_scale_mu_samples # Leaf parameter variance for treatment effect 
forest - if (sample_sigma_leaf_tau) sigma_leaf_tau_samples <- leaf_scale_tau_samples + if (sample_sigma2_leaf_tau) sigma2_leaf_tau_samples <- leaf_scale_tau_samples # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { - if (sample_sigma_global) { - sigma_x_hat_train <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) - if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_train[,i]*sigma2_global_samples[i]) + if (has_test) sigma2_x_hat_test <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_test[,i]*sigma2_global_samples[i]) } else { - sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train - if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train + sigma2_x_hat_train <- sigma2_x_hat_train*sigma2_init*y_std_train*y_std_train + if (has_test) sigma2_x_hat_test <- sigma2_x_hat_test*sigma2_init*y_std_train*y_std_train } } @@ -1417,8 +1417,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } model_params <- list( "initial_sigma2" = sigma2_init, - "initial_sigma_leaf_mu" = sigma_leaf_mu, - "initial_sigma_leaf_tau" = sigma_leaf_tau, + "initial_sigma2_leaf_mu" = sigma2_leaf_mu, + "initial_sigma2_leaf_tau" = sigma2_leaf_tau, "initial_b_0" = b_0, "initial_b_1" = b_1, "a_global" = a_global, @@ -1451,9 +1451,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id "has_rfx_basis" = has_basis_rfx, "num_rfx_basis" = num_basis_rfx, "include_variance_forest" = include_variance_forest, - "sample_sigma_global" = sample_sigma_global, - "sample_sigma_leaf_mu" = sample_sigma_leaf_mu, - "sample_sigma_leaf_tau" = sample_sigma_leaf_tau, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf_mu" = 
sample_sigma2_leaf_mu, + "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, "probit_outcome_model" = probit_outcome_model ) result <- list( @@ -1470,12 +1470,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (has_test) result[["y_hat_test"]] = y_hat_test if (include_variance_forest) { result[["forests_variance"]] = forest_samples_variance - result[["sigma_x_hat_train"]] = sigma_x_hat_train - if (has_test) result[["sigma_x_hat_test"]] = sigma_x_hat_test + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test } - if (sample_sigma_global) result[["sigma2_samples"]] = sigma2_samples - if (sample_sigma_leaf_mu) result[["sigma_leaf_mu_samples"]] = sigma_leaf_mu_samples - if (sample_sigma_leaf_tau) result[["sigma_leaf_tau_samples"]] = sigma_leaf_tau_samples + if (sample_sigma2_global) result[["sigma2_global_samples"]] = sigma2_global_samples + if (sample_sigma2_leaf_mu) result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + if (sample_sigma2_leaf_tau) result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples if (adaptive_coding) { result[["b_0_samples"]] = b_0_samples result[["b_1_samples"]] = b_1_samples @@ -1636,7 +1636,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred)*y_std } if (object$model_params$include_variance_forest) { - s_x_raw <- object$variance_forests$predict(forest_dataset_pred) + s_x_raw <- object$forests_variance$predict(forest_dataset_pred) } # Compute rfx predictions (if needed) @@ -1650,11 +1650,11 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU # Scale variance forest predictions if (object$model_params$include_variance_forest) { - if (object$model_params$sample_sigma_global) { - sigma2_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) 
sqrt(s_x_raw[,i]*sigma2_samples[i])) + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i]) } else { - variance_forest_predictions <- sqrt(s_x_raw*initial_sigma2)*y_std + variance_forest_predictions <- s_x_raw*initial_sigma2*y_std*y_std } } @@ -1734,8 +1734,8 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -1827,8 +1827,8 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -1879,9 +1879,9 @@ saveBCFModelToJson <- function(object){ jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) - jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global) - jsonobj$add_boolean("sample_sigma_leaf_mu", object$model_params$sample_sigma_leaf_mu) - 
jsonobj$add_boolean("sample_sigma_leaf_tau", object$model_params$sample_sigma_leaf_tau) + jsonobj$add_boolean("sample_sigma2_global", object$model_params$sample_sigma2_global) + jsonobj$add_boolean("sample_sigma2_leaf_mu", object$model_params$sample_sigma2_leaf_mu) + jsonobj$add_boolean("sample_sigma2_leaf_tau", object$model_params$sample_sigma2_leaf_tau) jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest) jsonobj$add_string("propensity_covariate", object$model_params$propensity_covariate) jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) @@ -1897,14 +1897,14 @@ saveBCFModelToJson <- function(object){ jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) - if (object$model_params$sample_sigma_global) { - jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters") + if (object$model_params$sample_sigma2_global) { + jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") } - if (object$model_params$sample_sigma_leaf_mu) { - jsonobj$add_vector("sigma_leaf_mu_samples", object$sigma_leaf_mu_samples, "parameters") + if (object$model_params$sample_sigma2_leaf_mu) { + jsonobj$add_vector("sigma2_leaf_mu_samples", object$sigma2_leaf_mu_samples, "parameters") } - if (object$model_params$sample_sigma_leaf_tau) { - jsonobj$add_vector("sigma_leaf_tau_samples", object$sigma_leaf_tau_samples, "parameters") + if (object$model_params$sample_sigma2_leaf_tau) { + jsonobj$add_vector("sigma2_leaf_tau_samples", object$sigma2_leaf_tau_samples, "parameters") } if (object$model_params$adaptive_coding) { jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") @@ -1995,8 +1995,8 @@ saveBCFModelToJson <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' 
rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -2077,8 +2077,8 @@ saveBCFModelToJsonFile <- function(object, filename){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -2159,8 +2159,8 @@ saveBCFModelToJsonString <- function(object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -2209,9 +2209,9 @@ createBCFModelFromJson <- function(json_object){ model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") - model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu") - 
model_params[["sample_sigma_leaf_tau"]] <- json_object$get_boolean("sample_sigma_leaf_tau") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean("sample_sigma2_leaf_mu") + model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean("sample_sigma2_leaf_tau") model_params[["include_variance_forest"]] <- include_variance_forest model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate") model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") @@ -2228,14 +2228,14 @@ createBCFModelFromJson <- function(json_object){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { - output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + if (model_params[["sample_sigma2_global"]]) { + output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") } - if (model_params[["sample_sigma_leaf_mu"]]) { - output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") + if (model_params[["sample_sigma2_leaf_mu"]]) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") } - if (model_params[["sample_sigma_leaf_tau"]]) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + if (model_params[["sample_sigma2_leaf_tau"]]) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") } if (model_params[["adaptive_coding"]]) { output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") @@ -2327,8 +2327,8 @@ createBCFModelFromJson <- function(json_object){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' mu_params <- list(sample_sigma_leaf = 
TRUE) -#' tau_params <- list(sample_sigma_leaf = FALSE) +#' mu_params <- list(sample_sigma2_leaf = TRUE) +#' tau_params <- list(sample_sigma2_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, #' rfx_group_ids_train = rfx_group_ids_train, @@ -2545,9 +2545,9 @@ createBCFModelFromCombinedJson <- function(json_object_list){ model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object_default$get_boolean("standardize") model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") - model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma_leaf_mu") - model_params[["sample_sigma_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma_leaf_tau") + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma2_leaf_mu") + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma2_leaf_tau") model_params[["include_variance_forest"]] <- include_variance_forest model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") @@ -2579,43 +2579,43 @@ createBCFModelFromCombinedJson <- function(json_object_list){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { + if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", 
"parameters") } else { - output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters")) + output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_mu"]]) { + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") } else { - output[["sigma_leaf_mu_samples"]] <- c(output[["sigma_leaf_mu_samples"]], json_object$get_vector("sigma_leaf_mu_samples", "parameters")) + output[["sigma2_leaf_mu_samples"]] <- c(output[["sigma2_leaf_mu_samples"]], json_object$get_vector("sigma2_leaf_mu_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_tau"]]) { + if (model_params[["sample_sigma2_leaf_tau"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") } else { - output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_tau"]]) { + if (model_params[["sample_sigma2_leaf_tau"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + 
output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") } else { - output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) } } } @@ -2772,9 +2772,9 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object_default$get_boolean("standardize") model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") - model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global") - model_params[["sample_sigma_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma_leaf_mu") - model_params[["sample_sigma_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma_leaf_tau") + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma2_leaf_mu") + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma2_leaf_tau") model_params[["include_variance_forest"]] <- include_variance_forest model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") @@ -2806,43 +2806,43 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ output[["model_params"]] <- model_params # Unpack sampled parameters - if (model_params[["sample_sigma_global"]]) { + if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_samples"]] <- 
json_object$get_vector("sigma2_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") } else { - output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters")) + output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_mu"]]) { + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma_leaf_mu_samples"]] <- json_object$get_vector("sigma_leaf_mu_samples", "parameters") + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") } else { - output[["sigma_leaf_mu_samples"]] <- c(output[["sigma_leaf_mu_samples"]], json_object$get_vector("sigma_leaf_mu_samples", "parameters")) + output[["sigma2_leaf_mu_samples"]] <- c(output[["sigma2_leaf_mu_samples"]], json_object$get_vector("sigma2_leaf_mu_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_tau"]]) { + if (model_params[["sample_sigma2_leaf_tau"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") } else { - output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) } } } - if (model_params[["sample_sigma_leaf_tau"]]) { + if (model_params[["sample_sigma2_leaf_tau"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i 
== 1) { - output[["sigma_leaf_tau_samples"]] <- json_object$get_vector("sigma_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") } else { - output[["sigma_leaf_tau_samples"]] <- c(output[["sigma_leaf_tau_samples"]], json_object$get_vector("sigma_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) } } } diff --git a/R/kernel.R b/R/kernel.R index 3265a1b7..f20630b2 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -157,7 +157,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU if (!model_object$model_params$include_mean_forest) { stop("Mean forest was not sampled in the bart model provided") } - if (!model_object$model_params$sample_sigma_leaf) { + if (!model_object$model_params$sample_sigma2_leaf) { stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided") } leaf_scale_vector <- model_object$sigma2_leaf_samples @@ -170,15 +170,15 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU } else { stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) if (forest_type=="prognostic") { - if (!model_object$model_params$sample_sigma_leaf_mu) { + if (!model_object$model_params$sample_sigma2_leaf_mu) { stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided") } - leaf_scale_vector <- model_object$sigma_leaf_mu_samples + leaf_scale_vector <- model_object$sigma2_leaf_mu_samples } else if (forest_type=="treatment") { - if (!model_object$model_params$sample_sigma_leaf_tau) { + if (!model_object$model_params$sample_sigma2_leaf_tau) { stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided") } - leaf_scale_vector <- model_object$sigma_leaf_tau_samples + leaf_scale_vector <- 
model_object$sigma2_leaf_tau_samples } else if (forest_type=="variance") { if (!model_object$model_params$include_variance_forest) { stop("Variance forest was not sampled in the bcf model provided") diff --git a/R/serialization.R b/R/serialization.R index d5d4e046..812b752e 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -545,7 +545,7 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx #' Load a vector from json #' #' @param json_object Object of class `CppJson` -#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_samples") in the overall json hierarchy +#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_global_samples") in the overall json hierarchy #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which vector sits #' #' @return R vector diff --git a/R/stochtree-package.R b/R/stochtree-package.R index 83a5e477..f3fd5c43 100644 --- a/R/stochtree-package.R +++ b/R/stochtree-package.R @@ -1,11 +1,15 @@ ## usethis namespace: start #' @importFrom stats coef +#' @importFrom stats dnorm #' @importFrom stats lm #' @importFrom stats model.matrix #' @importFrom stats predict #' @importFrom stats qgamma +#' @importFrom stats qnorm +#' @importFrom stats pnorm #' @importFrom stats resid #' @importFrom stats rnorm +#' @importFrom stats runif #' @importFrom stats sd #' @importFrom stats sigma #' @importFrom stats var diff --git a/demo/debug/causal_inference.py b/demo/debug/causal_inference.py index fb77367e..4b29eba1 100644 --- a/demo/debug/causal_inference.py +++ b/demo/debug/causal_inference.py @@ -87,8 +87,8 @@ plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) -sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") 
+sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"]) +sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2") plt.show() b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"]) diff --git a/demo/debug/supervised_learning.py b/demo/debug/supervised_learning.py index 955d83ec..5a039b35 100644 --- a/demo/debug/supervised_learning.py +++ b/demo/debug/supervised_learning.py @@ -66,8 +66,8 @@ def outcome_mean(X, W): plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) -sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"]) +sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2") plt.show() # Compute the test set RMSE @@ -89,8 +89,8 @@ def outcome_mean(X, W): plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) -sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"]) 
+sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2") plt.show() # Compute the test set RMSE @@ -110,8 +110,8 @@ def outcome_mean(X, W): plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3))) plt.show() -sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) -sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"]) +sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2") plt.show() # Compute the test set RMSE diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index db2d2f22..064ce9b6 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -346,9 +346,9 @@ "forest_preds_mcmc = forest_preds[:, num_warmstart:num_samples]\n", "\n", "# Global error variance\n", - "sigma_samples = np.sqrt(global_var_samples) * y_std\n", - "sigma_samples_gfr = sigma_samples[:num_warmstart]\n", - "sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]" + "sigma2_samples = global_var_samples * y_std * y_std\n", + "sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n", + "sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]" ] }, { @@ -384,13 +384,13 @@ " np.concatenate(\n", " (\n", " np.expand_dims(np.arange(num_warmstart), axis=1),\n", - " np.expand_dims(sigma_samples_gfr, axis=1),\n", + " np.expand_dims(sigma2_samples_gfr, axis=1),\n", " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ 
-427,13 +427,13 @@ " np.concatenate(\n", " (\n", " np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n", - " np.expand_dims(sigma_samples_mcmc, axis=1),\n", + " np.expand_dims(sigma2_samples_mcmc, axis=1),\n", " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ -909,9 +909,9 @@ "forest_preds_tau_mcmc = forest_preds_tau[:, num_warmstart:num_samples]\n", "\n", "# Global error variance\n", - "sigma_samples = np.sqrt(global_var_samples) * y_std\n", - "sigma_samples_gfr = sigma_samples[:num_warmstart]\n", - "sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]\n", + "sigma2_samples = global_var_samples * y_std * y_std\n", + "sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n", + "sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]\n", "\n", "# Adaptive coding parameters\n", "b_1_samples_gfr = b_1_samples[:num_warmstart] * y_std\n", @@ -969,13 +969,13 @@ " np.concatenate(\n", " (\n", " np.expand_dims(np.arange(num_warmstart), axis=1),\n", - " np.expand_dims(sigma_samples_gfr, axis=1),\n", + " np.expand_dims(sigma2_samples_gfr, axis=1),\n", " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ -1050,13 +1050,13 @@ " np.concatenate(\n", " (\n", " np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n", - " np.expand_dims(sigma_samples_mcmc, axis=1),\n", + " np.expand_dims(sigma2_samples_mcmc, axis=1),\n", " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, 
x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, diff --git a/demo/notebooks/serialization.ipynb b/demo/notebooks/serialization.ipynb index f9be5709..3e023acf 100644 --- a/demo/notebooks/serialization.ipynb +++ b/demo/notebooks/serialization.ipynb @@ -173,9 +173,9 @@ " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, diff --git a/demo/notebooks/supervised_learning.ipynb b/demo/notebooks/supervised_learning.ipynb index e1067247..cffdce2e 100644 --- a/demo/notebooks/supervised_learning.ipynb +++ b/demo/notebooks/supervised_learning.ipynb @@ -172,9 +172,9 @@ " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ -260,9 +260,9 @@ " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ -346,9 +346,9 @@ " ),\n", " axis=1,\n", " ),\n", - " columns=[\"Sample\", \"Sigma\"],\n", + " columns=[\"Sample\", \"Sigma^2\"],\n", ")\n", - "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n", + "sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n", "plt.show()" ] }, @@ -371,7 +371,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "stochtree-dev", "language": "python", "name": "python3" }, @@ -385,7 +385,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", - "version": "3.12.9" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/stochtree/bart.py b/stochtree/bart.py index 3bbc8d8e..0539129a 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -233,7 +233,7 @@ def sample( # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] self.standardize = general_params_updated["standardize"] - sample_sigma_global = general_params_updated["sample_sigma2_global"] + sample_sigma2_global = general_params_updated["sample_sigma2_global"] sigma2_init = general_params_updated["sigma2_init"] a_global = general_params_updated["sigma2_global_shape"] b_global = general_params_updated["sigma2_global_scale"] @@ -251,8 +251,8 @@ def sample( beta_mean = mean_forest_params_updated["beta"] min_samples_leaf_mean = mean_forest_params_updated["min_samples_leaf"] max_depth_mean = mean_forest_params_updated["max_depth"] - sample_sigma_leaf = mean_forest_params_updated["sample_sigma2_leaf"] - sigma_leaf = mean_forest_params_updated["sigma2_leaf_init"] + sample_sigma2_leaf = mean_forest_params_updated["sample_sigma2_leaf"] + sigma2_leaf = mean_forest_params_updated["sigma2_leaf_init"] a_leaf = mean_forest_params_updated["sigma2_leaf_shape"] b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] @@ -662,13 +662,13 @@ def sample( ) else: previous_forest_samples_variance = None - if previous_bart_model.sample_sigma_global: + if previous_bart_model.sample_sigma2_global: previous_global_var_samples = previous_bart_model.global_var_samples / ( previous_y_scale * previous_y_scale ) else: previous_global_var_samples = None - if previous_bart_model.sample_sigma_leaf: + if previous_bart_model.sample_sigma2_leaf: previous_leaf_var_samples = previous_bart_model.leaf_scale_samples else: previous_leaf_var_samples = None @@ -731,11 +731,11 @@ def sample( raise ValueError( "We do not support heteroskedasticity with a probit link" ) - if 
sample_sigma_global: + if sample_sigma2_global: warnings.warn( "Global error variance will not be sampled with a probit link as it is fixed at 1" ) - sample_sigma_global = False + sample_sigma2_global = False # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes @@ -758,7 +758,7 @@ def sample( # Skip variance_forest_init, since variance forests are not supported with probit link b_leaf = 1.0 / num_trees_mean if b_leaf is None else b_leaf if self.has_basis: - if sigma_leaf is None: + if sigma2_leaf is None: current_leaf_scale = np.zeros( (self.num_basis, self.num_basis), dtype=float ) @@ -766,51 +766,51 @@ def sample( current_leaf_scale, 2.0 / num_trees_mean, ) - elif isinstance(sigma_leaf, float): + elif isinstance(sigma2_leaf, float): current_leaf_scale = np.zeros( (self.num_basis, self.num_basis), dtype=float ) - np.fill_diagonal(current_leaf_scale, sigma_leaf) - elif isinstance(sigma_leaf, np.ndarray): - if sigma_leaf.ndim != 2: + np.fill_diagonal(current_leaf_scale, sigma2_leaf) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != self.num_basis: + if sigma2_leaf.shape[0] != self.num_basis: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" ) - current_leaf_scale = sigma_leaf + current_leaf_scale = sigma2_leaf else: raise ValueError( - 
"sigma_leaf must be either a scalar or a 2d symmetric numpy array" + "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" ) else: - if sigma_leaf is None: + if sigma2_leaf is None: current_leaf_scale = np.array([[2.0 / num_trees_mean]]) - elif isinstance(sigma_leaf, float): - current_leaf_scale = np.array([[sigma_leaf]]) - elif isinstance(sigma_leaf, np.ndarray): - if sigma_leaf.ndim != 2: + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.array([[sigma2_leaf]]) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != 1: + if sigma2_leaf.shape[0] != 1: raise ValueError( - "sigma_leaf must be a 1x1 numpy array for this leaf model" + "sigma2_leaf must be a 1x1 numpy array for this leaf model" ) - current_leaf_scale = sigma_leaf + current_leaf_scale = sigma2_leaf else: raise ValueError( - "sigma_leaf must be either a scalar or a 2d numpy array" + "sigma2_leaf must be either a scalar or a 2d numpy array" ) else: # Standardize if requested @@ -827,7 +827,7 @@ def sample( # Compute initial value of root nodes in mean forest init_val_mean = np.squeeze(np.mean(resid_train)) - # Calibrate priors for global sigma^2 and sigma_leaf + # Calibrate priors for global sigma^2 and sigma2_leaf if not sigma2_init: sigma2_init = 1.0 * np.var(resid_train) if not variance_forest_leaf_init: @@ -841,7 +841,7 @@ def sample( else b_leaf ) if self.has_basis: - if sigma_leaf is None: + if sigma2_leaf is None: current_leaf_scale = np.zeros( (self.num_basis, self.num_basis), 
dtype=float ) @@ -849,53 +849,53 @@ def sample( current_leaf_scale, np.squeeze(np.var(resid_train)) / num_trees_mean, ) - elif isinstance(sigma_leaf, float): + elif isinstance(sigma2_leaf, float): current_leaf_scale = np.zeros( (self.num_basis, self.num_basis), dtype=float ) - np.fill_diagonal(current_leaf_scale, sigma_leaf) - elif isinstance(sigma_leaf, np.ndarray): - if sigma_leaf.ndim != 2: + np.fill_diagonal(current_leaf_scale, sigma2_leaf) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != self.num_basis: + if sigma2_leaf.shape[0] != self.num_basis: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + "sigma2_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" ) - current_leaf_scale = sigma_leaf + current_leaf_scale = sigma2_leaf else: raise ValueError( - "sigma_leaf must be either a scalar or a 2d symmetric numpy array" + "sigma2_leaf must be either a scalar or a 2d symmetric numpy array" ) else: - if sigma_leaf is None: + if sigma2_leaf is None: current_leaf_scale = np.array( [[np.squeeze(np.var(resid_train)) / num_trees_mean]] ) - elif isinstance(sigma_leaf, float): - current_leaf_scale = np.array([[sigma_leaf]]) - elif isinstance(sigma_leaf, np.ndarray): - if sigma_leaf.ndim != 2: + elif isinstance(sigma2_leaf, float): + current_leaf_scale = np.array([[sigma2_leaf]]) + elif isinstance(sigma2_leaf, np.ndarray): + if sigma2_leaf.ndim != 2: raise ValueError( - 
"sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + if sigma2_leaf.shape[0] != sigma2_leaf.shape[1]: raise ValueError( - "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf.shape[0] != 1: + if sigma2_leaf.shape[0] != 1: raise ValueError( - "sigma_leaf must be a 1x1 numpy array for this leaf model" + "sigma2_leaf must be a 1x1 numpy array for this leaf model" ) - current_leaf_scale = sigma_leaf + current_leaf_scale = sigma2_leaf else: raise ValueError( - "sigma_leaf must be either a scalar or a 2d numpy array" + "sigma2_leaf must be either a scalar or a 2d numpy array" ) else: current_leaf_scale = np.array([[1.0]]) @@ -987,11 +987,11 @@ def sample( if keep_burnin: num_retained_samples += num_burnin * num_chains self.num_samples = num_retained_samples - self.sample_sigma_global = sample_sigma_global - self.sample_sigma_leaf = sample_sigma_leaf - if sample_sigma_global: + self.sample_sigma2_global = sample_sigma2_global + self.sample_sigma2_leaf = sample_sigma2_leaf + if sample_sigma2_global: self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) - if sample_sigma_leaf: + if sample_sigma2_leaf: self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) sample_counter = -1 @@ -1097,9 +1097,9 @@ def sample( active_forest_variance = Forest(num_trees_variance, 1, True, True) # Variance samplers - if self.sample_sigma_global: + if self.sample_sigma2_global: global_var_model = GlobalVarianceModel() - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: leaf_var_model = LeafVarianceModel() # Initialize the leaves of each tree in the mean forest @@ -1188,14 +1188,14 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if 
self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( active_forest_mean, cpp_rng, a_leaf, b_leaf ) @@ -1240,7 +1240,7 @@ def sample( residual_train, False, ) - if sample_sigma_global: + if sample_sigma2_global: current_sigma2 = self.global_var_samples[forest_ind] elif has_prev_model: if self.include_mean_forest: @@ -1254,7 +1254,7 @@ def sample( residual_train, True, ) - if sample_sigma_leaf and previous_leaf_var_samples is not None: + if sample_sigma2_leaf and previous_leaf_var_samples is not None: leaf_scale_double = previous_leaf_var_samples[ previous_model_warmstart_sample_num ] @@ -1275,7 +1275,7 @@ def sample( ) # if self.has_rfx: # pass - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = previous_global_var_samples[ previous_model_warmstart_sample_num ] @@ -1380,14 +1380,14 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( active_forest_mean, cpp_rng, a_leaf, b_leaf ) @@ -1423,17 +1423,17 @@ def sample( self.forest_container_variance.delete_sample(0) if self.has_rfx: self.rfx_container.delete_sample(0) - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = self.global_var_samples[num_gfr:] - if self.sample_sigma_leaf: + if 
self.sample_sigma2_leaf: self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] self.num_samples -= num_gfr # Store predictions - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = self.global_var_samples * self.y_std * self.y_std - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: self.leaf_scale_samples = self.leaf_scale_samples if self.include_mean_forest: @@ -1468,36 +1468,36 @@ def sample( self.y_hat_test = rfx_preds_test if self.include_variance_forest: - sigma_x_train_raw = ( + sigma2_x_train_raw = ( self.forest_container_variance.forest_container_cpp.Predict( forest_dataset_train.dataset_cpp ) ) - if self.sample_sigma_global: - self.sigma2_x_train = sigma_x_train_raw + if self.sample_sigma2_global: + self.sigma2_x_train = sigma2_x_train_raw for i in range(self.num_samples): self.sigma2_x_train[:, i] = ( - sigma_x_train_raw[:, i] * self.global_var_samples[i] + sigma2_x_train_raw[:, i] * self.global_var_samples[i] ) else: self.sigma2_x_train = ( - sigma_x_train_raw * self.sigma2_init * self.y_std * self.y_std + sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std ) if self.has_test: - sigma_x_test_raw = ( + sigma2_x_test_raw = ( self.forest_container_variance.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) ) - if self.sample_sigma_global: - self.sigma2_x_test = sigma_x_test_raw + if self.sample_sigma2_global: + self.sigma2_x_test = sigma2_x_test_raw for i in range(self.num_samples): self.sigma2_x_test[:, i] = ( - sigma_x_test_raw[:, i] * self.global_var_samples[i] + sigma2_x_test_raw[:, i] * self.global_var_samples[i] ) else: self.sigma2_x_test = ( - sigma_x_test_raw * self.sigma2_init * self.y_std * self.y_std + sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) def predict( @@ -1606,7 +1606,7 @@ def predict( pred_dataset.dataset_cpp ) ) - if self.sample_sigma_global: + if self.sample_sigma2_global: variance_pred = variance_pred_raw for i in range(self.num_samples): 
variance_pred[:, i] = np.sqrt( @@ -1795,7 +1795,7 @@ def predict_variance(self, covariates: np.array) -> np.array: variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( pred_dataset.dataset_cpp ) - if self.sample_sigma_global: + if self.sample_sigma2_global: variance_pred = variance_pred_raw for i in range(self.num_samples): variance_pred[:, i] = ( @@ -1843,8 +1843,8 @@ def to_json(self) -> str: bart_json.add_scalar("outcome_mean", self.y_bar) bart_json.add_boolean("standardize", self.standardize) bart_json.add_scalar("sigma2_init", self.sigma2_init) - bart_json.add_boolean("sample_sigma_global", self.sample_sigma_global) - bart_json.add_boolean("sample_sigma_leaf", self.sample_sigma_leaf) + bart_json.add_boolean("sample_sigma2_global", self.sample_sigma2_global) + bart_json.add_boolean("sample_sigma2_leaf", self.sample_sigma2_leaf) bart_json.add_boolean("include_mean_forest", self.include_mean_forest) bart_json.add_boolean("include_variance_forest", self.include_variance_forest) bart_json.add_boolean("has_rfx", self.has_rfx) @@ -1857,11 +1857,11 @@ def to_json(self) -> str: bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) # Add parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: bart_json.add_numeric_vector( "sigma2_global_samples", self.global_var_samples, "parameters" ) - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: bart_json.add_numeric_vector( "sigma2_leaf_samples", self.leaf_scale_samples, "parameters" ) @@ -1918,8 +1918,8 @@ def from_json(self, json_string: str) -> None: self.y_bar = bart_json.get_scalar("outcome_mean") self.standardize = bart_json.get_boolean("standardize") self.sigma2_init = bart_json.get_scalar("sigma2_init") - self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global") - self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf") + self.sample_sigma2_global = bart_json.get_boolean("sample_sigma2_global") + 
self.sample_sigma2_leaf = bart_json.get_boolean("sample_sigma2_leaf") self.num_gfr = bart_json.get_integer("num_gfr") self.num_burnin = bart_json.get_integer("num_burnin") self.num_mcmc = bart_json.get_integer("num_mcmc") @@ -1929,11 +1929,11 @@ def from_json(self, json_string: str) -> None: self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") # Unpack parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = bart_json.get_numeric_vector( "sigma2_global_samples", "parameters" ) - if self.sample_sigma_leaf: + if self.sample_sigma2_leaf: self.leaf_scale_samples = bart_json.get_numeric_vector( "sigma2_leaf_samples", "parameters" ) @@ -2025,10 +2025,10 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.y_bar = json_object_default.get_scalar("outcome_mean") self.standardize = json_object_default.get_boolean("standardize") self.sigma2_init = json_object_default.get_scalar("sigma2_init") - self.sample_sigma_global = json_object_default.get_boolean( - "sample_sigma_global" + self.sample_sigma2_global = json_object_default.get_boolean( + "sample_sigma2_global" ) - self.sample_sigma_leaf = json_object_default.get_boolean("sample_sigma_leaf") + self.sample_sigma2_leaf = json_object_default.get_boolean("sample_sigma2_leaf") self.num_gfr = json_object_default.get_integer("num_gfr") self.num_burnin = json_object_default.get_integer("num_burnin") self.num_mcmc = json_object_default.get_integer("num_mcmc") @@ -2040,7 +2040,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: ) # Unpack parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: for i in range(len(json_object_list)): if i == 0: self.global_var_samples = json_object_list[i].get_numeric_vector( @@ -2054,7 +2054,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: (self.global_var_samples, global_var_samples) ) - if self.sample_sigma_leaf: + if 
self.sample_sigma2_leaf: for i in range(len(json_object_list)): if i == 0: self.leaf_scale_samples = json_object_list[i].get_numeric_vector( diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 6ee743a5..e97e1a88 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -288,7 +288,7 @@ def sample( # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] self.standardize = general_params_updated["standardize"] - sample_sigma_global = general_params_updated["sample_sigma2_global"] + sample_sigma2_global = general_params_updated["sample_sigma2_global"] sigma2_init = general_params_updated["sigma2_global_init"] a_global = general_params_updated["sigma2_global_shape"] b_global = general_params_updated["sigma2_global_scale"] @@ -310,8 +310,8 @@ def sample( beta_mu = prognostic_forest_params_updated["beta"] min_samples_leaf_mu = prognostic_forest_params_updated["min_samples_leaf"] max_depth_mu = prognostic_forest_params_updated["max_depth"] - sample_sigma_leaf_mu = prognostic_forest_params_updated["sample_sigma2_leaf"] - sigma_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_init"] + sample_sigma2_leaf_mu = prognostic_forest_params_updated["sample_sigma2_leaf"] + sigma2_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_init"] a_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_shape"] b_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_scale"] keep_vars_mu = prognostic_forest_params_updated["keep_vars"] @@ -325,10 +325,10 @@ def sample( "min_samples_leaf" ] max_depth_tau = treatment_effect_forest_params_updated["max_depth"] - sample_sigma_leaf_tau = treatment_effect_forest_params_updated[ + sample_sigma2_leaf_tau = treatment_effect_forest_params_updated[ "sample_sigma2_leaf" ] - sigma_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_init"] + sigma2_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_init"] a_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_shape"] b_leaf_tau 
= treatment_effect_forest_params_updated["sigma2_leaf_scale"] delta_max = treatment_effect_forest_params_updated["delta_max"] @@ -494,29 +494,29 @@ def sample( leaf_model_variance = 3 # Check parameters - if sigma_leaf_tau is not None: - if not isinstance(sigma_leaf_tau, float) and not isinstance( - sigma_leaf_tau, np.ndarray + if sigma2_leaf_tau is not None: + if not isinstance(sigma2_leaf_tau, float) and not isinstance( + sigma2_leaf_tau, np.ndarray ): - raise ValueError("sigma_leaf_tau must be a float or numpy array") + raise ValueError("sigma2_leaf_tau must be a float or numpy array") if self.multivariate_treatment: - if sigma_leaf_tau is not None: - if isinstance(sigma_leaf_tau, np.ndarray): - if sigma_leaf_tau.ndim != 2: + if sigma2_leaf_tau is not None: + if isinstance(sigma2_leaf_tau, np.ndarray): + if sigma2_leaf_tau.ndim != 2: raise ValueError( - "sigma_leaf_tau must be 2-dimensional if passed as a np.array" + "sigma2_leaf_tau must be 2-dimensional if passed as a np.array" ) if ( - self.treatment_dim != sigma_leaf_tau.shape[0] - or self.treatment_dim != sigma_leaf_tau.shape[1] + self.treatment_dim != sigma2_leaf_tau.shape[0] + or self.treatment_dim != sigma2_leaf_tau.shape[1] ): raise ValueError( - "sigma_leaf_tau must have the same number of rows and columns, which must match Z_train.shape[1]" + "sigma2_leaf_tau must have the same number of rows and columns, which must match Z_train.shape[1]" ) - if sigma_leaf_mu is not None: - sigma_leaf_mu = check_scalar( - x=sigma_leaf_mu, - name="sigma_leaf_mu", + if sigma2_leaf_mu is not None: + sigma2_leaf_mu = check_scalar( + x=sigma2_leaf_mu, + name="sigma2_leaf_mu", target_type=float, min_val=0.0, max_val=None, @@ -708,12 +708,12 @@ def sample( max_val=None, include_boundaries="neither", ) - if sample_sigma_leaf_mu is not None: - if not isinstance(sample_sigma_leaf_mu, bool): - raise ValueError("sample_sigma_leaf_mu must be a bool") - if sample_sigma_leaf_tau is not None: - if not 
isinstance(sample_sigma_leaf_tau, bool): - raise ValueError("sample_sigma_leaf_tau must be a bool") + if sample_sigma2_leaf_mu is not None: + if not isinstance(sample_sigma2_leaf_mu, bool): + raise ValueError("sample_sigma2_leaf_mu must be a bool") + if sample_sigma2_leaf_tau is not None: + if not isinstance(sample_sigma2_leaf_tau, bool): + raise ValueError("sample_sigma2_leaf_tau must be a bool") if propensity_covariate is not None: if propensity_covariate not in ["mu", "tau", "both", "none"]: raise ValueError( @@ -1083,9 +1083,9 @@ def sample( if adaptive_coding and self.multivariate_treatment: self.adaptive_coding = False - # Sampling sigma_leaf_tau will be ignored for multivariate treatments - if sample_sigma_leaf_tau and self.multivariate_treatment: - sample_sigma_leaf_tau = False + # Sampling sigma2_leaf_tau will be ignored for multivariate treatments + if sample_sigma2_leaf_tau and self.multivariate_treatment: + sample_sigma2_leaf_tau = False # Check if user has provided propensities that are needed in the model if pi_train is None and propensity_covariate != "none": @@ -1138,11 +1138,11 @@ def sample( raise ValueError( "We do not support heteroskedasticity with a probit link" ) - if sample_sigma_global: + if sample_sigma2_global: warnings.warn( "Global error variance will not be sampled with a probit link as it is fixed at 1" ) - sample_sigma_global = False + sample_sigma2_global = False # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes @@ -1173,15 +1173,15 @@ def sample( if b_leaf_tau is None else b_leaf_tau ) - sigma_leaf_mu = ( + sigma2_leaf_mu = ( 1 / num_trees_mu - if sigma_leaf_mu is None - else sigma_leaf_mu + if sigma2_leaf_mu is None + else sigma2_leaf_mu ) - if isinstance(sigma_leaf_mu, float): - current_leaf_scale_mu = np.array([[sigma_leaf_mu]]) + if isinstance(sigma2_leaf_mu, float): + current_leaf_scale_mu = np.array([[sigma2_leaf_mu]]) else: - raise 
ValueError("sigma_leaf_mu must be a scalar") + raise ValueError("sigma2_leaf_mu must be a scalar") # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p # Use p = 0.9 as an internal default rather than adding another # user-facing "parameter" of the binary outcome BCF prior. @@ -1189,38 +1189,38 @@ def sample( # treatment_effect_forest_params. p = 0.6827 q_quantile = norm.ppf((p + 1) / 2.0) - sigma_leaf_tau = ( + sigma2_leaf_tau = ( ((delta_max / (q_quantile*norm.pdf(0)))**2) / num_trees_tau - if sigma_leaf_tau is None - else sigma_leaf_tau + if sigma2_leaf_tau is None + else sigma2_leaf_tau ) if self.multivariate_treatment: - if not isinstance(sigma_leaf_tau, np.ndarray): - sigma_leaf_tau = np.diagflat( - np.repeat(sigma_leaf_tau, self.treatment_dim) + if not isinstance(sigma2_leaf_tau, np.ndarray): + sigma2_leaf_tau = np.diagflat( + np.repeat(sigma2_leaf_tau, self.treatment_dim) ) - if isinstance(sigma_leaf_tau, float): + if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) - np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau) + np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: - current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) - elif isinstance(sigma_leaf_tau, np.ndarray): - if sigma_leaf_tau.ndim != 2: + current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) + elif isinstance(sigma2_leaf_tau, np.ndarray): + if sigma2_leaf_tau.ndim != 2: raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: + if sigma2_leaf_tau.shape[0] != sigma2_leaf_tau.shape[1]: raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" ) - if 
sigma_leaf_tau.shape[0] != Z_train.shape[1]: + if sigma2_leaf_tau.shape[0] != Z_train.shape[1]: raise ValueError( - "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + "sigma2_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" ) - current_leaf_scale_tau = sigma_leaf_tau + current_leaf_scale_tau = sigma2_leaf_tau else: - raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") + raise ValueError("sigma2_leaf_tau must be a scalar or a 2d numpy array") else: # Standardize if requested if self.standardize: @@ -1236,7 +1236,7 @@ def sample( # Compute initial value of root nodes in mean forest init_mu = np.squeeze(np.mean(resid_train)) - # Calibrate priors for global sigma^2 and sigma_leaf + # Calibrate priors for global sigma^2 and sigma2_leaf if not sigma2_init: sigma2_init = 1.0 * np.var(resid_train) if not variance_forest_leaf_init: @@ -1253,47 +1253,47 @@ def sample( if b_leaf_tau is None else b_leaf_tau ) - sigma_leaf_mu = ( + sigma2_leaf_mu = ( np.squeeze(2 * np.var(resid_train)) / num_trees_mu - if sigma_leaf_mu is None - else sigma_leaf_mu + if sigma2_leaf_mu is None + else sigma2_leaf_mu ) - if isinstance(sigma_leaf_mu, float): - current_leaf_scale_mu = np.array([[sigma_leaf_mu]]) + if isinstance(sigma2_leaf_mu, float): + current_leaf_scale_mu = np.array([[sigma2_leaf_mu]]) else: - raise ValueError("sigma_leaf_mu must be a scalar") - sigma_leaf_tau = ( + raise ValueError("sigma2_leaf_mu must be a scalar") + sigma2_leaf_tau = ( np.squeeze(np.var(resid_train)) / (num_trees_tau) - if sigma_leaf_tau is None - else sigma_leaf_tau + if sigma2_leaf_tau is None + else sigma2_leaf_tau ) if self.multivariate_treatment: - if not isinstance(sigma_leaf_tau, np.ndarray): - sigma_leaf_tau = np.diagflat( - np.repeat(sigma_leaf_tau, self.treatment_dim) + if not isinstance(sigma2_leaf_tau, np.ndarray): + sigma2_leaf_tau = np.diagflat( + np.repeat(sigma2_leaf_tau, self.treatment_dim) 
) - if isinstance(sigma_leaf_tau, float): + if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) - np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau) + np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: - current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) - elif isinstance(sigma_leaf_tau, np.ndarray): - if sigma_leaf_tau.ndim != 2: + current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) + elif isinstance(sigma2_leaf_tau, np.ndarray): + if sigma2_leaf_tau.ndim != 2: raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: + if sigma2_leaf_tau.shape[0] != sigma2_leaf_tau.shape[1]: raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "sigma2_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" ) - if sigma_leaf_tau.shape[0] != Z_train.shape[1]: + if sigma2_leaf_tau.shape[0] != Z_train.shape[1]: raise ValueError( - "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + "sigma2_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" ) - current_leaf_scale_tau = sigma_leaf_tau + current_leaf_scale_tau = sigma2_leaf_tau else: - raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") + raise ValueError("sigma2_leaf_tau must be a scalar or a 2d numpy array") if self.include_variance_forest: if not a_forest: a_forest = num_trees_variance / a_0**2 + 0.5 @@ -1454,14 +1454,14 @@ def sample( if keep_burnin: num_retained_samples += num_burnin self.num_samples = num_retained_samples - self.sample_sigma_global = sample_sigma_global - self.sample_sigma_leaf_mu = sample_sigma_leaf_mu - self.sample_sigma_leaf_tau = 
sample_sigma_leaf_tau - if sample_sigma_global: + self.sample_sigma2_global = sample_sigma2_global + self.sample_sigma2_leaf_mu = sample_sigma2_leaf_mu + self.sample_sigma2_leaf_tau = sample_sigma2_leaf_tau + if sample_sigma2_global: self.global_var_samples = np.empty(self.num_samples, dtype=np.float64) - if sample_sigma_leaf_mu: + if sample_sigma2_leaf_mu: self.leaf_scale_mu_samples = np.empty(self.num_samples, dtype=np.float64) - if sample_sigma_leaf_tau: + if sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64) sample_counter = -1 @@ -1582,11 +1582,11 @@ def sample( active_forest_variance = Forest(num_trees_variance, 1, True, True) # Variance samplers - if self.sample_sigma_global: + if self.sample_sigma2_global: global_var_model = GlobalVarianceModel() - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: leaf_var_model_mu = LeafVarianceModel() - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: leaf_var_model_tau = LeafVarianceModel() # Initialize the leaves of each tree in the prognostic forest @@ -1673,12 +1673,12 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: current_leaf_scale_mu[0, 0] = ( leaf_var_model_mu.sample_one_iteration( active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu @@ -1763,14 +1763,14 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 - if 
self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: current_leaf_scale_tau[0, 0] = ( leaf_var_model_tau.sample_one_iteration( active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau @@ -1854,12 +1854,12 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: current_leaf_scale_mu[0, 0] = ( leaf_var_model_mu.sample_one_iteration( active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu @@ -1944,14 +1944,14 @@ def sample( ) # Sample variance parameters (if requested) - if self.sample_sigma_global: + if self.sample_sigma2_global: current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: current_leaf_scale_tau[0, 0] = ( leaf_var_model_tau.sample_one_iteration( active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau @@ -1992,11 +1992,11 @@ def sample( if self.adaptive_coding: self.b1_samples = self.b1_samples[num_gfr:] self.b0_samples = self.b0_samples[num_gfr:] - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = self.global_var_samples[num_gfr:] - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:] - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:] self.num_samples -= num_gfr @@ -2066,7 +2066,7 @@ def sample( forest_dataset_train.dataset_cpp ) ) - if self.sample_sigma_global: + if self.sample_sigma2_global: self.sigma2_x_train = 
sigma2_x_train_raw for i in range(self.num_samples): self.sigma2_x_train[:, i] = ( @@ -2082,7 +2082,7 @@ def sample( forest_dataset_test.dataset_cpp ) ) - if self.sample_sigma_global: + if self.sample_sigma2_global: self.sigma2_x_test = sigma2_x_test_raw for i in range(self.num_samples): self.sigma2_x_test[:, i] = ( @@ -2093,13 +2093,13 @@ def sample( sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = self.global_var_samples * self.y_std * self.y_std - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: self.leaf_scale_mu_samples = self.leaf_scale_mu_samples - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = self.leaf_scale_tau_samples if self.adaptive_coding: @@ -2290,7 +2290,7 @@ def predict_variance( variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( pred_dataset.dataset_cpp ) - if self.sample_sigma_global: + if self.sample_sigma2_global: variance_pred = variance_pred_raw for i in range(self.num_samples): variance_pred[:, i] = ( @@ -2442,7 +2442,7 @@ def predict( sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) - if self.sample_sigma_global: + if self.sample_sigma2_global: sigma2_x = sigma2_x_raw for i in range(self.num_samples): sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] @@ -2494,9 +2494,9 @@ def to_json(self) -> str: bcf_json.add_scalar("outcome_mean", self.y_bar) bcf_json.add_boolean("standardize", self.standardize) bcf_json.add_scalar("sigma2_init", self.sigma2_init) - bcf_json.add_boolean("sample_sigma_global", self.sample_sigma_global) - bcf_json.add_boolean("sample_sigma_leaf_mu", self.sample_sigma_leaf_mu) - bcf_json.add_boolean("sample_sigma_leaf_tau", self.sample_sigma_leaf_tau) + bcf_json.add_boolean("sample_sigma2_global", self.sample_sigma2_global) + 
bcf_json.add_boolean("sample_sigma2_leaf_mu", self.sample_sigma2_leaf_mu) + bcf_json.add_boolean("sample_sigma2_leaf_tau", self.sample_sigma2_leaf_tau) bcf_json.add_boolean("include_variance_forest", self.include_variance_forest) bcf_json.add_boolean("has_rfx", self.has_rfx) bcf_json.add_scalar("num_gfr", self.num_gfr) @@ -2513,15 +2513,15 @@ def to_json(self) -> str: ) # Add parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: bcf_json.add_numeric_vector( "sigma2_global_samples", self.global_var_samples, "parameters" ) - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: bcf_json.add_numeric_vector( "sigma2_leaf_mu_samples", self.leaf_scale_mu_samples, "parameters" ) - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: bcf_json.add_numeric_vector( "sigma2_leaf_tau_samples", self.leaf_scale_tau_samples, "parameters" ) @@ -2583,9 +2583,9 @@ def from_json(self, json_string: str) -> None: self.y_bar = bcf_json.get_scalar("outcome_mean") self.standardize = bcf_json.get_boolean("standardize") self.sigma2_init = bcf_json.get_scalar("sigma2_init") - self.sample_sigma_global = bcf_json.get_boolean("sample_sigma_global") - self.sample_sigma_leaf_mu = bcf_json.get_boolean("sample_sigma_leaf_mu") - self.sample_sigma_leaf_tau = bcf_json.get_boolean("sample_sigma_leaf_tau") + self.sample_sigma2_global = bcf_json.get_boolean("sample_sigma2_global") + self.sample_sigma2_leaf_mu = bcf_json.get_boolean("sample_sigma2_leaf_mu") + self.sample_sigma2_leaf_tau = bcf_json.get_boolean("sample_sigma2_leaf_tau") self.num_gfr = int(bcf_json.get_scalar("num_gfr")) self.num_burnin = int(bcf_json.get_scalar("num_burnin")) self.num_mcmc = int(bcf_json.get_scalar("num_mcmc")) @@ -2600,15 +2600,15 @@ def from_json(self, json_string: str) -> None: ) # Unpack parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: self.global_var_samples = bcf_json.get_numeric_vector( "sigma2_global_samples", "parameters" ) - if 
self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: self.leaf_scale_mu_samples = bcf_json.get_numeric_vector( "sigma2_leaf_mu_samples", "parameters" ) - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = bcf_json.get_numeric_vector( "sigma2_leaf_tau_samples", "parameters" ) @@ -2704,14 +2704,14 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.y_bar = json_object_default.get_scalar("outcome_mean") self.standardize = json_object_default.get_boolean("standardize") self.sigma2_init = json_object_default.get_scalar("sigma2_init") - self.sample_sigma_global = json_object_default.get_boolean( - "sample_sigma_global" + self.sample_sigma2_global = json_object_default.get_boolean( + "sample_sigma2_global" ) - self.sample_sigma_leaf_mu = json_object_default.get_boolean( - "sample_sigma_leaf_mu" + self.sample_sigma2_leaf_mu = json_object_default.get_boolean( + "sample_sigma2_leaf_mu" ) - self.sample_sigma_leaf_tau = json_object_default.get_boolean( - "sample_sigma_leaf_tau" + self.sample_sigma2_leaf_tau = json_object_default.get_boolean( + "sample_sigma2_leaf_tau" ) self.num_gfr = json_object_default.get_scalar("num_gfr") self.num_burnin = json_object_default.get_scalar("num_burnin") @@ -2726,7 +2726,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: ) # Unpack parameter samples - if self.sample_sigma_global: + if self.sample_sigma2_global: for i in range(len(json_object_list)): if i == 0: self.global_var_samples = json_object_list[i].get_numeric_vector( @@ -2740,7 +2740,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: (self.global_var_samples, global_var_samples) ) - if self.sample_sigma_leaf_mu: + if self.sample_sigma2_leaf_mu: for i in range(len(json_object_list)): if i == 0: self.leaf_scale_mu_samples = json_object_list[i].get_numeric_vector( @@ -2754,18 +2754,18 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: 
(self.leaf_scale_mu_samples, leaf_scale_mu_samples) ) - if self.sample_sigma_leaf_tau: + if self.sample_sigma2_leaf_tau: for i in range(len(json_object_list)): if i == 0: - self.sample_sigma_leaf_tau = json_object_list[i].get_numeric_vector( + self.sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( "sigma2_leaf_tau_samples", "parameters" ) else: - sample_sigma_leaf_tau = json_object_list[i].get_numeric_vector( + sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( "sigma2_leaf_tau_samples", "parameters" ) - self.sample_sigma_leaf_tau = np.concatenate( - (self.sample_sigma_leaf_tau, sample_sigma_leaf_tau) + self.sample_sigma2_leaf_tau = np.concatenate( + (self.sample_sigma2_leaf_tau, sample_sigma2_leaf_tau) ) # Unpack internal propensity model diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd index 2e0dbf89..50f08069 100644 --- a/vignettes/CausalInference.Rmd +++ b/vignettes/CausalInference.Rmd @@ -134,9 +134,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -184,9 +184,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_root$sigma2_samples, sigma_observed)), - 
max(c(bcf_model_root$sigma2_samples, sigma_observed))) -plot(bcf_model_root$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_root$sigma2_global_samples, sigma_observed))) +plot(bcf_model_root$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -303,9 +303,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -353,9 +353,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_root$sigma2_samples, sigma_observed)), - max(c(bcf_model_root$sigma2_samples, sigma_observed))) -plot(bcf_model_root$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_root$sigma2_global_samples, sigma_observed))) +plot(bcf_model_root$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -472,9 +472,9 @@ 
plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -522,9 +522,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_root$sigma2_samples, sigma_observed)), - max(c(bcf_model_root$sigma2_samples, sigma_observed))) -plot(bcf_model_root$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_root$sigma2_global_samples, sigma_observed))) +plot(bcf_model_root$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -639,9 +639,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + 
max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -689,9 +689,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_root$sigma2_samples, sigma_observed)), - max(c(bcf_model_root$sigma2_samples, sigma_observed))) -plot(bcf_model_root$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_root$sigma2_global_samples, sigma_observed))) +plot(bcf_model_root$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -806,9 +806,9 @@ plot(rowMeans(bcf_model_warmstart$rfx_preds_test), rfx_term_test, xlab = "predicted", ylab = "actual", main = "Random effects terms") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ-rfx_term) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -919,9 +919,9 @@ plot(rowMeans(bcf_model_mcmc$y_hat_test), y_test, xlab = "predicted", ylab = "actual", main = "Outcome") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- 
c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)), - max(c(bcf_model_mcmc$sigma2_samples, sigma_observed))) -plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_mcmc$sigma2_global_samples, sigma_observed))) +plot(bcf_model_mcmc$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -981,9 +981,9 @@ plot(rowMeans(bcf_model_mcmc$y_hat_test), y_test, xlab = "predicted", ylab = "actual", main = "Outcome") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)), - max(c(bcf_model_mcmc$sigma2_samples, sigma_observed))) -plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_mcmc$sigma2_global_samples, sigma_observed))) +plot(bcf_model_mcmc$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -1043,9 +1043,9 @@ plot(rowMeans(bcf_model_warmstart$y_hat_test), y_test, xlab = "predicted", ylab = "actual", main = "Outcome") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -1105,9 +1105,9 
@@ plot(rowMeans(bcf_model_warmstart$y_hat_test), y_test, xlab = "predicted", ylab = "actual", main = "Outcome") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -1133,6 +1133,193 @@ test_outcome_mean <- rowMeans(bcf_model_warmstart$y_hat_test) sqrt(mean((y_test - test_outcome_mean)^2)) ``` +## Demo 7: Probit Outcome Model, Heterogeneous Treatment Effect + +We consider a modified version of a data generating process from @hahn2020bayesian: + +\begin{equation*} +\begin{aligned} +y &= \mathbb{1}\left(w > 0\right)\\ +w &= \mu(X) + \tau(X) Z + \epsilon\\ +\epsilon &\sim N\left(0,1\right)\\ +\mu(X) &= 1 + g(X) + 6 \lvert X_3 - 1 \rvert\\ +\tau(X) &= 1 + 2 X_2 X_4\\ +g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\ +s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\ +\pi(X) &= 0.8 \phi\left(\frac{3\mu(X)}{s_{\mu}}\right) - \frac{X_1}{2} + \frac{2U+1}{20}\\ +X_1,X_2,X_3 &\sim N\left(0,1\right)\\ +X_4 &\sim \text{Bernoulli}(1/2)\\ +X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\ +U &\sim \text{Uniform}\left(0,1\right)\\ +Z &\sim \text{Bernoulli}\left(\pi(X)\right) +\end{aligned} +\end{equation*} + +### Simulation + +We draw from the DGP defined above + +```{r} +n <- 2000 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n,1,0.5)) +x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x 
<- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +w <- E_XZ + rnorm(n, 0, 1) +y <- 1*(w > 0) +delta_x <- pnorm(mu_x + tau_x) - pnorm(mu_x) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +w_test <- w[test_inds] +w_train <- w[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +delta_x_test <- delta_x[test_inds] +delta_x_train <- delta_x[train_inds] +``` + +### Sampling and Analysis + +#### Warmstart + +We first simulate from an ensemble model of $y \mid X$ using "warm-start" +initialization samples (@krantsevich2023stochastic). This is the default in +`stochtree`. 
+ +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(keep_every = 5, + probit_outcome_model = T, + sample_sigma2_global = F) +prognostic_forest_params <- list(sample_sigma2_leaf = F) +treatment_effect_forest_params <- list(sample_sigma2_leaf = F) +bcf_model_warmstart <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, + X_test = X_test, Z_test = Z_test, propensity_test = pi_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, prognostic_forest_params = prognostic_forest_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +Inspect the BART samples that were initialized with an XBART warm-start + +```{r} +mu_hat_test <- rowMeans(bcf_model_warmstart$mu_hat_test) +plot(mu_hat_test, mu_test, xlab = "predicted", + ylab = "actual", main = "Prognostic function") +abline(0,1,col="red",lty=3,lwd=3) +tau_hat_test <- rowMeans(bcf_model_warmstart$tau_hat_test) +plot(tau_hat_test, tau_test, xlab = "predicted", + ylab = "actual", main = "Treatment effect") +abline(0,1,col="red",lty=3,lwd=3) +delta_x_hat_test <- pnorm(mu_hat_test+tau_hat_test) - pnorm(mu_hat_test) +plot(delta_x_hat_test, delta_x_test, + xlab = "predicted", ylab = "actual", main = "Distributional treatment\neffect") +abline(0,1,col="red",lty=3,lwd=3) +``` + +Examine test set interval coverage + +```{r} +test_lb <- apply( + pnorm(bcf_model_warmstart$mu_hat_test + bcf_model_warmstart$tau_hat_test) - + pnorm(bcf_model_warmstart$mu_hat_test), 1, quantile, 0.025) +test_ub <- apply( + pnorm(bcf_model_warmstart$mu_hat_test + bcf_model_warmstart$tau_hat_test) - + pnorm(bcf_model_warmstart$mu_hat_test), 1, quantile, 0.975) +cover <- ( + (test_lb <= delta_x_test) & + (test_ub >= delta_x_test) +) +mean(cover) +``` + +#### BART MCMC without Warmstart + +Next, we simulate from this ensemble model without any warm-start 
initialization. + +```{r} +num_gfr <- 0 +num_burnin <- 2000 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(keep_every = 5, + probit_outcome_model = T, + sample_sigma2_global = F) +prognostic_forest_params <- list(sample_sigma2_leaf = F) +treatment_effect_forest_params <- list(sample_sigma2_leaf = F) +bcf_model_root <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, + X_test = X_test, Z_test = Z_test, propensity_test = pi_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, prognostic_forest_params = prognostic_forest_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +Inspect the BART samples that were initialized with an XBART warm-start + +```{r} +mu_hat_test <- rowMeans(bcf_model_root$mu_hat_test) +plot(mu_hat_test, mu_test, xlab = "predicted", + ylab = "actual", main = "Prognostic function") +abline(0,1,col="red",lty=3,lwd=3) +tau_hat_test <- rowMeans(bcf_model_root$tau_hat_test) +plot(tau_hat_test, tau_test, xlab = "predicted", + ylab = "actual", main = "Treatment effect") +abline(0,1,col="red",lty=3,lwd=3) +delta_x_hat_test <- pnorm(mu_hat_test+tau_hat_test) - pnorm(mu_hat_test) +plot(delta_x_hat_test, delta_x_test, + xlab = "predicted", ylab = "actual", main = "Distributional treatment\neffect") +abline(0,1,col="red",lty=3,lwd=3) +``` + +Examine test set interval coverage + +```{r} +test_lb <- apply( + pnorm(bcf_model_root$mu_hat_test + bcf_model_root$tau_hat_test) - + pnorm(bcf_model_root$mu_hat_test), 1, quantile, 0.025) +test_ub <- apply( + pnorm(bcf_model_root$mu_hat_test + bcf_model_root$tau_hat_test) - + pnorm(bcf_model_root$mu_hat_test), 1, quantile, 0.975) +cover <- ( + (test_lb <= delta_x_test) & + (test_ub >= delta_x_test) +) +mean(cover) +``` + # Continuous Treatment ## Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect @@ -1230,9 +1417,9 @@ 
plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)), - max(c(bcf_model_warmstart$sigma2_samples, sigma_observed))) -plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_warmstart$sigma2_global_samples, sigma_observed))) +plot(bcf_model_warmstart$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` @@ -1280,9 +1467,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") abline(0,1,col="red",lty=3,lwd=3) sigma_observed <- var(y-E_XZ) -plot_bounds <- c(min(c(bcf_model_root$sigma2_samples, sigma_observed)), - max(c(bcf_model_root$sigma2_samples, sigma_observed))) -plot(bcf_model_root$sigma2_samples, ylim = plot_bounds, +plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples, sigma_observed)), + max(c(bcf_model_root$sigma2_global_samples, sigma_observed))) +plot(bcf_model_root$sigma2_global_samples, ylim = plot_bounds, ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter") abline(h = sigma_observed, lty=3, lwd = 3, col = "blue") ``` diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index 22399325..569fc1d1 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -1341,7 +1341,7 @@ tau_hat <- t(t(tau_hat_raw) * (b_1_samples - b_0_samples))*y_std y_hat <- mu_hat + tau_hat * as.numeric(Z) # Global error variance -sigma2_samples <- global_var_samples*(y_std^2) +sigma2_global_samples <- global_var_samples*(y_std^2) ``` ## Results @@ -1349,7 +1349,7 @@ sigma2_samples <- 
global_var_samples*(y_std^2) Inspect the XBART results ```{r bcf_xbcf_plot} -plot(sigma2_samples[1:num_gfr], ylab="sigma^2") +plot(sigma2_global_samples[1:num_gfr], ylab="sigma^2") plot(rowMeans(mu_hat[,1:num_gfr]), mu_x, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "prognostic term") abline(0,1,col="red",lty=2,lwd=2.5) @@ -1362,7 +1362,7 @@ mean((rowMeans(tau_hat[,1:num_gfr]) - tau_x)^2) Inspect the warm start BART results ```{r bcf_warm_start_plot} -plot(sigma2_samples[(num_gfr+1):num_samples], ylab="sigma^2") +plot(sigma2_global_samples[(num_gfr+1):num_samples], ylab="sigma^2") plot(rowMeans(mu_hat[,(num_gfr+1):num_samples]), mu_x, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "prognostic term") abline(0,1,col="red",lty=2,lwd=2.5) diff --git a/vignettes/Heteroskedasticity.Rmd b/vignettes/Heteroskedasticity.Rmd index 30350410..e9f84dce 100644 --- a/vignettes/Heteroskedasticity.Rmd +++ b/vignettes/Heteroskedasticity.Rmd @@ -19,27 +19,29 @@ knitr::opts_chunk$set( ``` This vignette demonstrates how to use the `bart()` function for Bayesian -supervised learning (@chipman2010bart), with an additional "variance forest," +supervised learning (@chipman2010bart) and causal inference (@hahn2020bayesian), with an additional "variance forest," for modeling conditional variance (see @murray2021log). To begin, we load the `stochtree` package. ```{r setup} library(stochtree) ``` -# Demo 1: Variance-Only Simulation (simple DGP) +# Section 1: Supervised Learning -## Simulation +## Demo 1: Variance-Only Simulation (simple DGP) + +### Simulation Here, we generate data with a constant (zero) mean and a relatively simple covariate-modified variance function. 
\begin{equation*} \begin{aligned} -y &= 0 + \sigma^2(X) \epsilon\\ +y &= 0 + \sigma(X) \epsilon\\ \sigma^2(X) &= \begin{cases} -0.25 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ +0.5 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ 1 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ -4 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ -9 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ +2 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ +3 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ \end{cases}\\ X_1,\dots,X_p &\sim \text{U}\left(0,1\right)\\ \epsilon &\sim \mathcal{N}\left(0,1\right) @@ -76,9 +78,9 @@ s_x_test <- s_XW[test_inds] s_x_train <- s_XW[train_inds] ``` -## Sampling and Analysis +### Sampling and Analysis -### Warmstart +#### Warmstart We first sample the $\sigma^2(X)$ ensemble using "warm-start" initialization (@he2023stochastic). This is the default in @@ -105,12 +107,12 @@ bart_model_warmstart <- stochtree::bart( Inspect the MCMC samples ```{r} -plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_warmstart$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -### MCMC +#### MCMC We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in @chipman2010bart). 
@@ -134,20 +136,20 @@ bart_model_mcmc <- stochtree::bart( Inspect the MCMC samples ```{r} -plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_mcmc$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -# Demo 2: Variance-Only Simulation (complex DGP) +## Demo 2: Variance-Only Simulation (complex DGP) -## Simulation +### Simulation Here, we generate data with a constant (zero) mean and a more complex covariate-modified variance function. \begin{equation*} \begin{aligned} -y &= 0 + \sigma^2(X) \epsilon\\ +y &= 0 + \sigma(X) \epsilon\\ \sigma^2(X) &= \begin{cases} 0.25X_3^2 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ 1X_3^2 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ @@ -189,9 +191,9 @@ s_x_test <- s_XW[test_inds] s_x_train <- s_XW[train_inds] ``` -## Sampling and Analysis +### Sampling and Analysis -### Warmstart +#### Warmstart We first sample the $\sigma^2(X)$ ensemble using "warm-start" initialization (@he2023stochastic). This is the default in @@ -219,12 +221,12 @@ bart_model_warmstart <- stochtree::bart( Inspect the MCMC samples ```{r} -plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_warmstart$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -### MCMC +#### MCMC We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in @chipman2010bart). 
@@ -250,20 +252,20 @@ bart_model_mcmc <- stochtree::bart( Inspect the MCMC samples ```{r} -plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_mcmc$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -# Demo 3: Mean and Variance Simulation (simple DGP) +## Demo 3: Mean and Variance Simulation (simple DGP) -## Simulation +### Simulation Here, we generate data with (relatively simple) covariate-modified mean and variance functions. \begin{equation*} \begin{aligned} -y &= f(X) + \sigma^2(X) \epsilon\\ +y &= f(X) + \sigma(X) \epsilon\\ f(X) &= \begin{cases} -6 & X_2 \geq 0 \text{ and } X_2 < 0.25\\ -2 & X_2 \geq 0.25 \text{ and } X_2 < 0.5\\ @@ -316,9 +318,9 @@ s_x_test <- s_XW[test_inds] s_x_train <- s_XW[train_inds] ``` -## Sampling and Analysis +### Sampling and Analysis -### Warmstart +#### Warmstart We first sample the $\sigma^2(X)$ ensemble using "warm-start" initialization (@he2023stochastic). This is the default in @@ -348,12 +350,12 @@ Inspect the MCMC samples plot(rowMeans(bart_model_warmstart$y_hat_test), y_test, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "mean function") abline(0,1,col="red",lty=2,lwd=2.5) -plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_warmstart$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -### MCMC +#### MCMC We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in @chipman2010bart). 
@@ -383,20 +385,20 @@ plot(rowMeans(bart_model_mcmc$y_hat_test), y_test, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "mean function") abline(0,1,col="red",lty=2,lwd=2.5) -plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_mcmc$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -# Demo 4: Mean and Variance Simulation (complex DGP) +## Demo 4: Mean and Variance Simulation (complex DGP) -## Simulation +### Simulation Here, we generate data with more complex covariate-modified mean and variance functions. \begin{equation*} \begin{aligned} -y &= f(X) + \sigma^2(X) \epsilon\\ +y &= f(X) + \sigma(X) \epsilon\\ f(X) &= \begin{cases} -6X_4 & X_2 \geq 0 \text{ and } X_2 < 0.25\\ -2X_4 & X_2 \geq 0.25 \text{ and } X_2 < 0.5\\ @@ -449,9 +451,9 @@ s_x_test <- s_XW[test_inds] s_x_train <- s_XW[train_inds] ``` -## Sampling and Analysis +### Sampling and Analysis -### Warmstart +#### Warmstart We first sample the $\sigma^2(X)$ ensemble using "warm-start" initialization (@he2023stochastic). This is the default in @@ -481,12 +483,12 @@ Inspect the MCMC samples plot(rowMeans(bart_model_warmstart$y_hat_test), y_test, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "mean function") abline(0,1,col="red",lty=2,lwd=2.5) -plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_warmstart$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` -### MCMC +#### MCMC We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in @chipman2010bart). 
@@ -516,8 +518,161 @@ plot(rowMeans(bart_model_mcmc$y_hat_test), y_test, pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "mean function") abline(0,1,col="red",lty=2,lwd=2.5) -plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test, - pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "standard deviation function") +plot(rowMeans(bart_model_mcmc$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") +abline(0,1,col="red",lty=2,lwd=2.5) +``` + +# Section 2: Causal Inference + +## Demo 1: Heterogeneous Treatment Effect, Continuous Treatment, Heteroskedastic Errors + +We consider the following data generating process: + +\begin{equation*} +\begin{aligned} +y &= \mu(X) + \tau(X) Z + \sigma(X) \epsilon\\ +\sigma^2(X) &= \begin{cases} +0.25 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ +1 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ +4 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ +9 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ +\end{cases}\\ +\epsilon &\sim \mathcal{N}\left(0,1\right)\\ +\mu(X) &= 1 + 2 X_1 - \mathbb{1}\left(X_2 < 0\right) \times 4 + \mathbb{1}\left(X_2 \geq 0\right) \times 4 + 3 \left(\lvert X_3 \rvert - \sqrt{\frac{2}{\pi}} \right)\\ +\tau(X) &= 1 + 2 X_4\\ +X_1,X_2,X_3,X_4,X_5 &\sim N\left(0,1\right)\\ +U &\sim \text{Uniform}\left(0,1\right)\\ +\pi(X) &= \frac{\mu(X) - 1}{4} + 4 \left(U - \frac{1}{2}\right)\\ +Z &\sim \mathcal{N}\left(\pi(X), 1\right) +\end{aligned} +\end{equation*} + +### Simulation + +We draw from the DGP defined above + +```{r} +n <- 2000 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- rnorm(n) +x5 <- rnorm(n) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +mu_x <- 1 + 2*x1 - 4*(x2 < 0) + 4*(x2 >= 0) + 3*(abs(x3) - sqrt(2/pi)) +tau_x <- 1 + 2*x4 +u <- runif(n) +pi_x <- ((mu_x-1)/4) + 4*(u-0.5) +Z <- pi_x + rnorm(n,0,1) +E_XZ <- mu_x + Z*tau_x +s_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (1) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2) + 
+ ((0.75 <= X[,1]) & (1 > X[,1])) * (3) +) +y <- E_XZ + rnorm(n, 0, 1)*s_X +X <- as.data.frame(X) + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +s_x_test <- s_X[test_inds] +s_x_train <- s_X[train_inds] +``` + +### Sampling and Analysis + +#### Warmstart + +We first simulate from an ensemble model of $y \mid X$ using "warm-start" +initialization samples (@krantsevich2023stochastic). This is the default in +`stochtree`. + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(keep_every = 5) +prognostic_forest_params <- list(sample_sigma2_leaf = F) +treatment_effect_forest_params <- list(sample_sigma2_leaf = F) +variance_forest_params <- list(num_trees = num_trees) +bcf_model_warmstart <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, + X_test = X_test, Z_test = Z_test, propensity_test = pi_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, prognostic_forest_params = prognostic_forest_params, + treatment_effect_forest_params = treatment_effect_forest_params, + variance_forest_params = variance_forest_params +) +``` + +Inspect the BART samples that were initialized with an XBART warm-start + +```{r} +plot(rowMeans(bcf_model_warmstart$mu_hat_test), mu_test, + xlab = "predicted", ylab = "actual", main = "Prognostic function") +abline(0,1,col="red",lty=3,lwd=3) 
+plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test, + xlab = "predicted", ylab = "actual", main = "Treatment effect") +abline(0,1,col="red",lty=3,lwd=3) +plot(rowMeans(bcf_model_warmstart$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") +abline(0,1,col="red",lty=2,lwd=2.5) +``` + +#### BART MCMC without Warmstart + +Next, we simulate from this ensemble model without any warm-start initialization. + +```{r} +num_gfr <- 0 +num_burnin <- 2000 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(keep_every = 5) +prognostic_forest_params <- list(sample_sigma2_leaf = F) +treatment_effect_forest_params <- list(sample_sigma2_leaf = F) +variance_forest_params <- list(num_trees = num_trees) +bcf_model_root <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, + X_test = X_test, Z_test = Z_test, propensity_test = pi_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, prognostic_forest_params = prognostic_forest_params, + treatment_effect_forest_params = treatment_effect_forest_params, + variance_forest_params = variance_forest_params +) +``` + +Inspect the BART samples after burnin + +```{r} +plot(rowMeans(bcf_model_root$mu_hat_test), mu_test, + xlab = "predicted", ylab = "actual", main = "Prognostic function") +abline(0,1,col="red",lty=3,lwd=3) +plot(rowMeans(bcf_model_root$tau_hat_test), tau_test, + xlab = "predicted", ylab = "actual", main = "Treatment effect") +abline(0,1,col="red",lty=3,lwd=3) +plot(rowMeans(bcf_model_root$sigma2_x_hat_test), s_x_test^2, + pch=16, cex=0.75, xlab = "pred", ylab = "actual", main = "variance function") abline(0,1,col="red",lty=2,lwd=2.5) ``` diff --git a/vignettes/ModelSerialization.Rmd b/vignettes/ModelSerialization.Rmd index 9e80aa80..0fedb736 100644 --- a/vignettes/ModelSerialization.Rmd +++ b/vignettes/ModelSerialization.Rmd @@ -225,7 
+225,7 @@ bart_preds_reload <- predict(bart_model_reload, X_train) plot(rowMeans(bart_model$y_hat_train), rowMeans(bart_preds_reload$y_hat), xlab = "Original", ylab = "Deserialized", main = "Conditional Mean Estimates") abline(0,1,col="red",lwd=3,lty=3) -plot(rowMeans(bart_model$sigma_x_hat_train), rowMeans(bart_preds_reload$variance_forest_predictions), +plot(rowMeans(bart_model$sigma2_x_hat_train), rowMeans(bart_preds_reload$variance_forest_predictions), xlab = "Original", ylab = "Deserialized", main = "Conditional Variance Estimates") abline(0,1,col="red",lwd=3,lty=3) ```