Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ convertBARTModelToJson <- function(object){
jsonobj$add_scalar("variance_scale", object$model_params$variance_scale)
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init)
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf)
jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest)
Expand Down Expand Up @@ -1141,6 +1142,7 @@ createBARTModelFromJson <- function(json_object){
model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
model_params[["include_mean_forest"]] <- include_mean_forest
Expand Down Expand Up @@ -1336,6 +1338,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
model_params = list()
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
model_params[["include_mean_forest"]] <- include_mean_forest
Expand Down Expand Up @@ -1486,6 +1489,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
model_params[["include_mean_forest"]] <- include_mean_forest
Expand Down
15 changes: 13 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,7 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU
# Compute forest predictions
y_std <- bcf$model_params$outcome_scale
y_bar <- bcf$model_params$outcome_mean
sigma2_init <- bcf$model_params$initial_sigma2
initial_sigma2 <- bcf$model_params$initial_sigma2
mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar
if (bcf$model_params$adaptive_coding) {
tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau)
Expand Down Expand Up @@ -1224,7 +1224,7 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU
sigma2_samples <- bcf$sigma2_global_samples
variance_forest_predictions <- sapply(1:length(keep_indices), function(i) sqrt(s_x_raw[,i]*sigma2_samples[i]))
} else {
variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std
variance_forest_predictions <- sqrt(s_x_raw*initial_sigma2)*y_std
}
}

Expand Down Expand Up @@ -1406,6 +1406,9 @@ convertBCFModelToJson <- function(object){
# Add the forests
jsonobj$add_forest(object$forests_mu)
jsonobj$add_forest(object$forests_tau)
if (object$model_params$include_variance_forest) {
jsonobj$add_forest(object$forests_variance)
}

# Add metadata
jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars)
Expand All @@ -1426,9 +1429,11 @@ convertBCFModelToJson <- function(object){
# Add global parameters
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2)
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
jsonobj$add_boolean("sample_sigma_leaf_mu", object$model_params$sample_sigma_leaf_mu)
jsonobj$add_boolean("sample_sigma_leaf_tau", object$model_params$sample_sigma_leaf_tau)
jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest)
jsonobj$add_string("propensity_covariate", object$model_params$propensity_covariate)
jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
Expand Down Expand Up @@ -1686,6 +1691,10 @@ createBCFModelFromJson <- function(json_object){
# Unpack the forests
output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0")
output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1")
include_variance_forest <- json_object$get_boolean("include_variance_forest")
if (include_variance_forest) {
output[["forests_variance"]] <- loadForestContainerJson(json_object, "forest_2")
}

# Unpack metadata
train_set_metadata = list()
Expand All @@ -1710,9 +1719,11 @@ createBCFModelFromJson <- function(json_object){
model_params = list()
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu")
model_params[["sample_sigma_leaf_tau"]] <- json_object$get_boolean("sample_sigma_leaf_tau")
model_params[["include_variance_forest"]] <- include_variance_forest
model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate")
model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx")
model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
Expand Down
3 changes: 3 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ reference:
- CppRNG
- createRNG
- calibrate_inverse_gamma_error_variance
- preprocessBartParams
- preprocessBcfParams

- subtitle: Random Effects
desc: >
Expand Down Expand Up @@ -118,6 +120,7 @@ articles:
contents:
- BayesianSupervisedLearning
- CausalInference
- Heteroskedasticity

- title: Advanced Model Interface
navbar: Advanced Model Interface
Expand Down
28 changes: 14 additions & 14 deletions vignettes/CustomSamplingRoutine.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ for (i in 1:num_warmstart) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = T
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T
)

# Sample global variance parameter
Expand All @@ -186,7 +186,7 @@ for (i in (num_warmstart+1):num_samples) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = F
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F
)

# Sample global variance parameter
Expand Down Expand Up @@ -370,7 +370,7 @@ for (i in 1:num_warmstart) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = T
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T
)

# Sample global variance parameter
Expand Down Expand Up @@ -398,7 +398,7 @@ for (i in (num_warmstart+1):num_samples) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = F
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F
)

# Sample global variance parameter
Expand Down Expand Up @@ -599,7 +599,7 @@ for (i in 1:num_warmstart) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = T
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = T
)

# Sample global variance parameter
Expand Down Expand Up @@ -627,7 +627,7 @@ for (i in (num_warmstart+1):num_samples) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = F
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F
)

# Sample global variance parameter
Expand Down Expand Up @@ -824,12 +824,12 @@ for (i in 1:num_warmstart) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
sigma2, cutpoint_grid_size, gfr = T
1, 1, sigma2, cutpoint_grid_size, gfr = T
)

# Sample global variance parameter
global_var_samples[i+1] <- sample_sigma2_one_iteration(
outcome, rng, nu, lambda
outcome, forest_dataset, rng, nu, lambda
)
}
```
Expand Down Expand Up @@ -862,12 +862,12 @@ for (i in (num_warmstart+1):num_samples) {
forest_model$sample_one_iteration(
forest_dataset, outcome, forest_samples, rng, feature_types,
outcome_model_type, leaf_prior_scale, var_weights,
global_var_samples[i], cutpoint_grid_size, gfr = F
1, 1, global_var_samples[i], cutpoint_grid_size, gfr = F
)

# Sample global variance parameter
global_var_samples[i+1] <- sample_sigma2_one_iteration(
outcome, rng, nu, lambda
outcome, forest_dataset, rng, nu, lambda
)
}
```
Expand Down Expand Up @@ -1150,7 +1150,7 @@ if (num_gfr > 0){
forest_model_mu$sample_one_iteration(
forest_dataset_mu, outcome, forest_samples_mu, rng,
feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu,
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
)

# Sample variance parameters (if requested)
Expand All @@ -1163,7 +1163,7 @@ if (num_gfr > 0){
forest_model_tau$sample_one_iteration(
forest_dataset_tau, outcome, forest_samples_tau, rng,
feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau,
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
1, 1, current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
)

# Sample adaptive coding parameters
Expand Down Expand Up @@ -1198,7 +1198,7 @@ if (num_burnin + num_mcmc > 0) {
# Sample the prognostic forest
forest_model_mu$sample_one_iteration(
forest_dataset_mu, outcome, forest_samples_mu, rng, feature_types_mu,
0, current_leaf_scale_mu, variable_weights_mu, current_sigma2,
0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2,
cutpoint_grid_size, gfr = F, pre_initialized = T
)

Expand All @@ -1209,7 +1209,7 @@ if (num_burnin + num_mcmc > 0) {
# Sample the treatment forest
forest_model_tau$sample_one_iteration(
forest_dataset_tau, outcome, forest_samples_tau, rng, feature_types_tau,
1, current_leaf_scale_tau, variable_weights_tau, current_sigma2,
1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2,
cutpoint_grid_size, gfr = F, pre_initialized = T
)

Expand Down
14 changes: 8 additions & 6 deletions vignettes/ModelSerialization.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F)
bcf_model <- bcf(
X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train,
X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
rfx_basis_test = rfx_basis_test,
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F
params = bcf_params
)
```

Expand Down Expand Up @@ -189,14 +190,15 @@ num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
bart_params <- list(num_trees_mean = 100, num_trees_variance = 50,
alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5,
alpha_variance = 0.95, beta_variance = 1.25,
min_samples_leaf_variance = 1,
sample_sigma_global = F, sample_sigma_leaf = F)
bart_model <- stochtree::bart(
X_train = X_train, y_train = y_train, X_test = X_test,
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
num_trees_mean = 0, num_trees_variance = 50,
alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5,
alpha_variance = 0.95, beta_variance = 1.25,
min_samples_leaf_variance = 1,
sample_sigma_global = F, sample_sigma_leaf = F
params = bart_params
)
```

Expand Down
Loading