39 changes: 37 additions & 2 deletions R/bcf.R
@@ -613,11 +613,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}

# Estimate if pre-estimated propensity score is not provided
internal_propensity_model <- F
if ((is.null(pi_train)) && (propensity_covariate != "none")) {
internal_propensity_model <- T
# Estimate using the last of several iterations of GFR BART
num_burnin <- 10
num_total <- 50
bart_model_propensity <- bart(X_train = X_train_raw, y_train = as.numeric(Z_train), X_test = X_test_raw,
bart_model_propensity <- bart(X_train = X_train, y_train = as.numeric(Z_train), X_test = X_test_raw,
num_gfr = num_total, num_burnin = 0, num_mcmc = 0)
pi_train <- rowMeans(bart_model_propensity$y_hat_train[,(num_burnin+1):num_total])
if ((is.null(dim(pi_train))) && (!is.null(pi_train))) {
@@ -1233,6 +1235,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
"propensity_covariate" = propensity_covariate,
"binary_treatment" = binary_treatment,
"adaptive_coding" = adaptive_coding,
"internal_propensity_model" = internal_propensity_model,
"num_samples" = num_retained_samples,
"num_gfr" = num_gfr,
"num_burnin" = num_burnin,
@@ -1277,6 +1280,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
result[["rfx_unique_group_ids"]] = levels(group_ids_factor)
}
if ((has_rfx_test) && (has_test)) result[["rfx_preds_test"]] = rfx_preds_test
if (internal_propensity_model) {
result[["bart_propensity_model"]] = bart_model_propensity
}
class(result) <- "bcf"

return(result)
@@ -1366,7 +1372,11 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU

# Data checks
if ((bcf$model_params$propensity_covariate != "none") && (is.null(pi_test))) {
stop("pi_test must be provided for this model")
if (!bcf$model_params$internal_propensity_model) {
stop("pi_test must be provided for this model")
}
# Compute propensity score using the internal bart model
pi_test <- rowMeans(predict(bcf$bart_propensity_model, X_test)$y_hat)
}
if (nrow(X_test) != nrow(Z_test)) {
stop("X_test and Z_test must have the same number of rows")
@@ -1662,6 +1672,7 @@ convertBCFModelToJson <- function(object){
jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding)
jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model)
jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
jsonobj$add_scalar("num_burnin", object$model_params$num_burnin)
jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc)
@@ -1689,6 +1700,14 @@ convertBCFModelToJson <- function(object){
jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids)
}

# Add propensity model (if it exists)
if (object$model_params$internal_propensity_model) {
bart_propensity_string <- saveBARTModelToJsonString(
object$bart_propensity_model
)
jsonobj$add_string("bart_propensity_model", bart_propensity_string)
}

return(jsonobj)
}

@@ -1962,6 +1981,7 @@ createBCFModelFromJson <- function(json_object){
model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding")
model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model")
model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
@@ -1990,6 +2010,14 @@ createBCFModelFromJson <- function(json_object){
output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
}

# Unpack propensity model (if it exists)
if (model_params[["internal_propensity_model"]]) {
bart_propensity_string <- json_object$get_string("bart_propensity_model")
output[["bart_propensity_model"]] <- createBARTModelFromJsonString(
bart_propensity_string
)
}

class(output) <- "bcf"
return(output)
}
@@ -2229,6 +2257,12 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
for (i in 1:length(json_string_list)) {
json_string <- json_string_list[[i]]
json_object_list[[i]] <- createCppJsonString(json_string)
# Add runtime check for separately serialized propensity models
# We don't support merging BCF models with independent propensity models
# this way at the moment
if (json_object_list[[i]]$get_boolean("internal_propensity_model")) {
stop("Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling.")
}
}

# For scalar / preprocessing details which aren't sample-dependent,
@@ -2279,6 +2313,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")

# Combine values that are sample-specific
for (i in 1:length(json_object_list)) {
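The runtime check added to createBCFModelFromCombinedJsonString above points users at a workaround when several independently sampled BCF models need to be combined: fit a single propensity model up front and pass its predictions to each BCF run as data, so that no individual run caches an internal propensity model. A minimal sketch of that workflow, assuming the bart() and bcf() interfaces shown elsewhere in this diff (variable names and sampler settings are illustrative):

# Fit one propensity model for the treatment assignment and average its GFR draws
propensity_model <- bart(X_train = X_train, y_train = as.numeric(Z_train),
                         num_gfr = 50, num_burnin = 0, num_mcmc = 0)
pi_train <- rowMeans(propensity_model$y_hat_train)

# Pass the same propensities to each independently sampled BCF run
bcf_chain_1 <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                   pi_train = pi_train, num_gfr = 0, num_burnin = 100, num_mcmc = 100)
bcf_chain_2 <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                   pi_train = pi_train, num_gfr = 0, num_burnin = 100, num_mcmc = 100)

# Since neither run cached an internal propensity model, their JSON strings
# can be combined without triggering the new runtime check
json_strings <- list(saveBCFModelToJsonString(bcf_chain_1),
                     saveBCFModelToJsonString(bcf_chain_2))
bcf_combined <- createBCFModelFromCombinedJsonString(json_strings)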
13 changes: 13 additions & 0 deletions stochtree/bcf.py
@@ -1402,6 +1402,7 @@ def to_json(self) -> str:
bcf_json.add_scalar("num_samples", self.num_samples)
bcf_json.add_boolean("adaptive_coding", self.adaptive_coding)
bcf_json.add_string("propensity_covariate", self.propensity_covariate)
bcf_json.add_boolean("internal_propensity_model", self.internal_propensity_model)

# Add parameter samples
if self.sample_sigma_global:
@@ -1414,6 +1415,11 @@ def to_json(self) -> str:
bcf_json.add_numeric_vector("b0_samples", self.b0_samples, "parameters")
bcf_json.add_numeric_vector("b1_samples", self.b1_samples, "parameters")

# Add propensity model (if it exists)
if self.internal_propensity_model:
bart_propensity_string = self.bart_propensity_model.to_json()
bcf_json.add_string("bart_propensity_model", bart_propensity_string)

return bcf_json.return_json_string()

def from_json(self, json_string: str) -> None:
@@ -1457,6 +1463,7 @@ def from_json(self, json_string: str) -> None:
self.num_samples = int(bcf_json.get_scalar("num_samples"))
self.adaptive_coding = bcf_json.get_boolean("adaptive_coding")
self.propensity_covariate = bcf_json.get_string("propensity_covariate")
self.internal_propensity_model = bcf_json.get_boolean("internal_propensity_model")

# Unpack parameter samples
if self.sample_sigma_global:
@@ -1469,6 +1476,12 @@ def from_json(self, json_string: str) -> None:
self.b1_samples = bcf_json.get_numeric_vector("b1_samples", "parameters")
self.b0_samples = bcf_json.get_numeric_vector("b0_samples", "parameters")

# Unpack internal propensity model
if self.internal_propensity_model:
bart_propensity_string = bcf_json.get_string("bart_propensity_model")
self.bart_propensity_model = BARTModel()
self.bart_propensity_model.from_json(bart_propensity_string)

# Mark the deserialized model as "sampled"
self.sampled = True

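On the Python side, to_json() and from_json() handle the internally estimated propensity model the same way as the R code: the fitted BARTModel is serialized to its own JSON string, stored as a nested string field in the BCF JSON, and rebuilt on load. A small self-contained sketch of persisting and restoring such a model, assuming BCFModel is importable from the stochtree package (the file path and data-generating code are illustrative):

import numpy as np
from stochtree import BCFModel

# Simulate a small dataset with confounded treatment assignment
rng = np.random.default_rng(1234)
n, p = 100, 5
X = rng.uniform(0, 1, (n, p))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
y = 5 * pi_X + 2 * X[:, 1] * Z + rng.normal(0, 1, n)

# Fit without passing propensities, so an internal propensity model is estimated and cached
bcf_orig = BCFModel()
bcf_orig.sample(X_train=X, Z_train=Z, y_train=y, num_gfr=10, num_mcmc=10)

# to_json() now embeds the internal BART propensity model as a nested JSON string
with open("bcf_model.json", "w") as f:
    f.write(bcf_orig.to_json())

# A later session can restore the full model, including the cached propensity model
bcf_reloaded = BCFModel()
with open("bcf_model.json") as f:
    bcf_reloaded.from_json(f.read())
mu_hat, tau_hat, y_hat = bcf_reloaded.predict(X, Z, pi_X)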
162 changes: 162 additions & 0 deletions test/R/testthat/test-serialization.R
@@ -0,0 +1,162 @@
test_that("BART Serialization", {
skip_on_cran()

# Generate simulated data
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]

# Sample a BART model
general_param_list <- list(num_chains = 1, keep_every = 1)
bart_model <- bart(X_train = X_train, y_train = y_train,
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = general_param_list)
y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat)

# Save to JSON
bart_json_string <- saveBARTModelToJsonString(bart_model)

# Reload as a BART model
bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)

# Predict from the roundtrip BART model
y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat)

# Assertion
expect_equal(y_hat_orig, y_hat_reloaded)
})

test_that("BCF Serialization", {
skip_on_cran()

n <- 500
x1 <- runif(n)
x2 <- runif(n)
x3 <- runif(n)
x4 <- runif(n)
x5 <- runif(n)
X <- cbind(x1,x2,x3,x4,x5)
p <- ncol(X)
pi_x <- 0.25 + 0.5*X[,1]
mu_x <- pi_x * 5
tau_x <- X[,2] * 2
Z <- rbinom(n,1,pi_x)
E_XZ <- mu_x + Z*tau_x
y <- E_XZ + rnorm(n, 0, 1)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

# Sample a BCF model
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
pi_train = pi_train, num_gfr = 100, num_burnin = 0, num_mcmc = 100)
bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test)
mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]])
tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]])
y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]])

# Save to JSON
bcf_json_string <- saveBCFModelToJsonString(bcf_model)

# Reload as a BCF model
bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)

# Predict from the roundtrip BCF model
bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test)
mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]])
tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]])
y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]])

# Assertion
expect_equal(y_hat_orig, y_hat_reloaded)
})

test_that("BCF Serialization (no propensity)", {
skip_on_cran()

n <- 500
x1 <- runif(n)
x2 <- runif(n)
x3 <- runif(n)
x4 <- runif(n)
x5 <- runif(n)
X <- cbind(x1,x2,x3,x4,x5)
p <- ncol(X)
pi_x <- 0.25 + 0.5*X[,1]
mu_x <- pi_x * 5
tau_x <- X[,2] * 2
Z <- rbinom(n,1,pi_x)
E_XZ <- mu_x + Z*tau_x
y <- E_XZ + rnorm(n, 0, 1)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

# Sample a BCF model
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
num_gfr = 100, num_burnin = 0, num_mcmc = 100)
bcf_preds_orig <- predict(bcf_model, X_test, Z_test)
mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]])
tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]])
y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]])

# Save to JSON
bcf_json_string <- saveBCFModelToJsonString(bcf_model)

# Reload as a BCF model
bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)

# Predict from the roundtrip BCF model
bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test)
mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]])
tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]])
y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]])

# Assertion
expect_equal(y_hat_orig, y_hat_reloaded)
})
36 changes: 36 additions & 0 deletions test/python/test_json.py
@@ -242,3 +242,39 @@ def test_bcf_string(self):
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)
np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded)
np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded)

def test_bcf_propensity_string(self):
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis
n = 100
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5*X[:,0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X*5
tau_X = X[:,1]*2

# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X*Z + epsilon

# Run BCF without passing propensity scores (so an internal propensity model must be constructed)
bcf_orig = BCFModel()
bcf_orig.sample(X_train=X, Z_train=Z, y_train=y, num_gfr=10, num_mcmc=10)

# Extract predictions from the sampler
mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_orig.predict(X, Z, pi_X)

# "Round-trip" the model to JSON string and back and check that the predictions agree
bcf_json_string = bcf_orig.to_json()
bcf_reloaded = BCFModel()
bcf_reloaded.from_json(bcf_json_string)
mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_reloaded.predict(X, Z, pi_X)
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)
np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded)
np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded)