11 changes: 9 additions & 2 deletions R/bcf.R
@@ -2474,6 +2474,8 @@ bcf <- function(
)
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) *
y_std_train
control_adj_train <- t(t(tau_hat_train_raw) * b_0_samples) * y_std_train
mu_hat_train <- mu_hat_train + control_adj_train
} else {
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) *
y_std_train
@@ -2508,6 +2510,8 @@ bcf <- function(
t(tau_hat_test_raw) * (b_1_samples - b_0_samples)
) *
y_std_train
control_adj_test <- t(t(tau_hat_test_raw) * b_0_samples) * y_std_train
mu_hat_test <- mu_hat_test + control_adj_test
} else {
tau_hat_test <- forest_samples_tau$predict_raw(
forest_dataset_test
@@ -2849,10 +2853,11 @@ predict.bcfmodel <- function(
"all"
))
) {
stop(paste0(
warning(paste0(
"Term '",
term,
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'."
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'.",
" This term will be ignored and prediction will only proceed if other requested terms are available in the model."
))
}
}
@@ -3056,6 +3061,8 @@ predict.bcfmodel <- function(
t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
) *
y_std
control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
mu_hat_forest <- mu_hat_forest + control_adj
} else {
tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
y_std
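For context on the R/bcf.R changes above: under adaptive coding the raw treatment forest is multiplied by a sampled coding in each draw, so (as I read the code, with $b_0$, $b_1$ the per-draw codings, $\tau_{\mathrm{raw}}$ the unscaled treatment forest, and $\mu$ the prognostic forest)

$$
\mathbb{E}[y \mid x, z] = \mu(x) + \tau_{\mathrm{raw}}(x)\,\big(b_0 (1 - z) + b_1 z\big)
= \big(\mu(x) + b_0\,\tau_{\mathrm{raw}}(x)\big) + (b_1 - b_0)\,\tau_{\mathrm{raw}}(x)\,z.
$$

The added `control_adj_*` terms fold the control-arm piece $b_0\,\tau_{\mathrm{raw}}(x)$ into `mu_hat`, while `tau_hat` keeps the $(b_1 - b_0)$ scaling, so `mu_hat + tau_hat * z` still reproduces the full fitted mean on the outcome scale.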
68 changes: 68 additions & 0 deletions demo/debug/bcf_pred_rmse.py
@@ -0,0 +1,68 @@
# Load libraries
from stochtree import BCFModel
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats import norm

# Simulation parameters
n = 250
p = 50
n_sim = 100
test_set_pct = 0.2
rng = np.random.default_rng()

# Simulation containers
rmses_cached = np.empty(n_sim)
rmses_pred = np.empty(n_sim)

# Run the simulation
for i in range(n_sim):
    # Generate data
    X = rng.normal(loc=0.0, scale=1.0, size=(n, p))
    mu_X = X[:, 0]
    tau_X = 0.25 * X[:, 1]
    pi_X = norm.cdf(0.5 * X[:, 1])
    Z = rng.binomial(n=1, p=pi_X, size=(n,))
    E_XZ = mu_X + tau_X * Z
    snr = 2.0
    noise_sd = np.std(E_XZ) / snr
    y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,))

    # Train-test split
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct)
    X_train = X[train_inds, :]
    X_test = X[test_inds, :]
    Z_train = Z[train_inds]
    Z_test = Z[test_inds]
    pi_train = pi_X[train_inds]
    pi_test = pi_X[test_inds]
    tau_train = tau_X[train_inds]
    tau_test = tau_X[test_inds]
    mu_train = mu_X[train_inds]
    mu_test = mu_X[test_inds]
    y_train = y[train_inds]
    y_test = y[test_inds]
    E_XZ_train = E_XZ[train_inds]
    E_XZ_test = E_XZ[test_inds]

    # Fit simple BCF model
    bcf_model = BCFModel()
    bcf_model.sample(
        X_train=X_train,
        Z_train=Z_train,
        pi_train=pi_train,
        y_train=y_train,
        X_test=X_test,
        Z_test=Z_test,
        pi_test=pi_test,
    )

    # Predict out of sample
    y_hat_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms="y_hat")

    # Compute RMSE using both cached predictions and those returned by predict()
    rmses_cached[i] = np.sqrt(np.mean(np.power(np.mean(bcf_model.y_hat_test, axis=1) - E_XZ_test, 2.0)))
    rmses_pred[i] = np.sqrt(np.mean(np.power(y_hat_test - E_XZ_test, 2.0)))

print(f"Average RMSE, cached: {np.mean(rmses_cached):.4f}, out-of-sample pred: {np.mean(rmses_pred):.4f}")
15 changes: 15 additions & 0 deletions stochtree/bcf.py
@@ -2267,7 +2272,12 @@ def sample(
adaptive_coding_weights = np.expand_dims(
self.b1_samples - self.b0_samples, axis=(0, 2)
)
b0_weights = np.expand_dims(
self.b0_samples, axis=(0, 2)
)
control_adj_train = self.tau_hat_train * b0_weights * self.y_std
self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights
self.mu_hat_train = self.mu_hat_train + np.squeeze(control_adj_train)
self.tau_hat_train = np.squeeze(self.tau_hat_train * self.y_std)
if self.multivariate_treatment:
treatment_term_train = np.multiply(
@@ -2289,7 +2294,12 @@ def sample(
adaptive_coding_weights_test = np.expand_dims(
self.b1_samples - self.b0_samples, axis=(0, 2)
)
b0_weights = np.expand_dims(
self.b0_samples, axis=(0, 2)
)
control_adj_test = self.tau_hat_test * b0_weights * self.y_std
self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test
self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test)
self.tau_hat_test = np.squeeze(self.tau_hat_test * self.y_std)
if self.multivariate_treatment:
treatment_term_test = np.multiply(
@@ -2594,7 +2604,12 @@ def predict(
adaptive_coding_weights = np.expand_dims(
self.b1_samples - self.b0_samples, axis=(0, 2)
)
b0_weights = np.expand_dims(
self.b0_samples, axis=(0, 2)
)
control_adj = tau_raw * b0_weights * self.y_std
tau_raw = tau_raw * adaptive_coding_weights
mu_x_forest = mu_x_forest + np.squeeze(control_adj)
tau_x_forest = np.squeeze(tau_raw * self.y_std)
if Z.shape[1] > 1:
treatment_term = np.multiply(
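The stochtree/bcf.py edits mirror the R change; the `np.expand_dims(..., axis=(0, 2))` calls simply reshape the per-draw codings so they broadcast against the raw treatment array. A minimal numpy sketch of that arithmetic, using hypothetical shapes and variable names rather than the actual stochtree internals:

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_draws = 5, 4                                # hypothetical sizes
tau_raw = rng.normal(size=(n_obs, n_draws, 1))       # raw treatment forest draws
b0 = rng.normal(size=(n_draws,))                     # control-arm coding, one per draw
b1 = rng.normal(size=(n_draws,))                     # treated-arm coding, one per draw
y_std = 2.0                                          # outcome standardization factor

# Reshape the per-draw codings to (1, n_draws, 1) so they broadcast over observations
b0_weights = np.expand_dims(b0, axis=(0, 2))
coding_weights = np.expand_dims(b1 - b0, axis=(0, 2))

control_adj = np.squeeze(tau_raw * b0_weights * y_std)    # folded into the prognostic term
cate = np.squeeze(tau_raw * coding_weights * y_std)       # reported treatment effect

# Sanity check: prognostic adjustment + CATE recovers the treated-arm coding b1 * tau_raw
treated_total = np.squeeze(tau_raw * np.expand_dims(b1, axis=(0, 2)) * y_std)
assert np.allclose(control_adj + cate, treated_total)
```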
74 changes: 74 additions & 0 deletions tools/simulations/bcf-pred-rmse.R
@@ -0,0 +1,74 @@
# Load library
library(stochtree)

# Simulation parameters
n <- 250
p <- 50
n_sim <- 100
test_set_pct <- 0.2

# Simulation containers
rmses_cached <- rep(NA_real_, n_sim)
rmses_pred <- rep(NA_real_, n_sim)

# Run the simulation
for (i in 1:n_sim) {
  # Generate data
  X <- matrix(rnorm(n * p), ncol = p)
  mu_x <- X[, 1]
  tau_x <- 0.25 * X[, 2]
  pi_x <- pnorm(0.5 * X[, 1])
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  snr <- 2
  y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

  # Train-test split
  n_test <- round(test_set_pct * n)
  n_train <- n - n_test
  test_inds <- sort(sample(1:n, n_test, replace = FALSE))
  train_inds <- (1:n)[!((1:n) %in% test_inds)]
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_test <- y[test_inds]
  y_train <- y[train_inds]
  mu_test <- mu_x[test_inds]
  mu_train <- mu_x[train_inds]
  tau_test <- tau_x[test_inds]
  tau_train <- tau_x[train_inds]
  E_XZ_test <- E_XZ[test_inds]
  E_XZ_train <- E_XZ[train_inds]

  # Fit a simple BCF model
  bcf_model <- bcf(
    X_train = X_train,
    Z_train = Z_train,
    y_train = y_train,
    propensity_train = pi_train,
    X_test = X_test,
    Z_test = Z_test,
    propensity_test = pi_test
  )

  # Predict out of sample
  y_hat_test <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "y_hat"
  )

  # Compute RMSE using both cached predictions and those returned by predict()
  rmses_cached[i] <- sqrt(mean((rowMeans(bcf_model$y_hat_test) - E_XZ_test)^2))
  rmses_pred[i] <- sqrt(mean((y_hat_test - E_XZ_test)^2))
}

# Inspect results
mean(rmses_cached)
mean(rmses_pred)