From 46c64d01e4e96f73b6d63661e755bf94342a49e1 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 5 Dec 2025 13:57:45 -0600 Subject: [PATCH 01/12] Bump version number to patch release --- CHANGELOG.md | 10 +--------- DESCRIPTION | 2 +- Doxyfile | 2 +- NEWS.md | 10 +--------- configure | 18 +++++++++--------- configure.ac | 2 +- pyproject.toml | 2 +- 7 files changed, 15 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e3be9f5..bf6c0ecd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,11 @@ # Changelog -# stochtree (development version) - -## New Features - -## Computational Improvements +# stochtree 0.2.1 ## Bug Fixes * Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248)) -## Documentation Improvements - -## Other Changes - # stochtree 0.2.0 ## New Features diff --git a/DESCRIPTION b/DESCRIPTION index aa7ae91a..02cc3bc6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: stochtree Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference -Version: 0.2.0.9000 +Version: 0.2.1 Authors@R: c( person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")), diff --git a/Doxyfile b/Doxyfile index 4c8168ae..4230f847 100644 --- a/Doxyfile +++ b/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree" # could be handy for archiving the generated documentation or if some version # control system is used. 
-PROJECT_NUMBER = 0.2.0.9000 +PROJECT_NUMBER = 0.2.1 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/NEWS.md b/NEWS.md index 8844d3c1..676ed749 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,17 +1,9 @@ -# stochtree (development version) - -## New Features - -## Computational Improvements +# stochtree 0.2.1 ## Bug Fixes * Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248)) -## Documentation Improvements - -## Other Changes - # stochtree 0.2.0 ## New Features diff --git a/configure b/configure index d862d747..54bfbeb1 100755 --- a/configure +++ b/configure @@ -1,6 +1,6 @@ #! /bin/sh # Guess values for system-dependent variables and create Makefiles. -# Generated by GNU Autoconf 2.72 for stochtree 0.2.0.9000. +# Generated by GNU Autoconf 2.72 for stochtree 0.2.1. # # # Copyright (C) 1992-1996, 1998-2017, 2020-2023 Free Software Foundation, @@ -600,8 +600,8 @@ MAKEFLAGS= # Identity of this package. PACKAGE_NAME='stochtree' PACKAGE_TARNAME='stochtree' -PACKAGE_VERSION='0.2.0.9000' -PACKAGE_STRING='stochtree 0.2.0.9000' +PACKAGE_VERSION='0.2.1' +PACKAGE_STRING='stochtree 0.2.1' PACKAGE_BUGREPORT='' PACKAGE_URL='' @@ -1204,7 +1204,7 @@ if test "$ac_init_help" = "long"; then # Omit some internal or obsolete options to make the list less imposing. # This message is too long to be a string in the A/UX 3.1 sh. cat <<_ACEOF -'configure' configures stochtree 0.2.0.9000 to adapt to many kinds of systems. +'configure' configures stochtree 0.2.1 to adapt to many kinds of systems. Usage: $0 [OPTION]... [VAR=VALUE]... 
@@ -1266,7 +1266,7 @@ fi if test -n "$ac_init_help"; then case $ac_init_help in - short | recursive ) echo "Configuration of stochtree 0.2.0.9000:";; + short | recursive ) echo "Configuration of stochtree 0.2.1:";; esac cat <<\_ACEOF @@ -1334,7 +1334,7 @@ fi test -n "$ac_init_help" && exit $ac_status if $ac_init_version; then cat <<\_ACEOF -stochtree configure 0.2.0.9000 +stochtree configure 0.2.1 generated by GNU Autoconf 2.72 Copyright (C) 2023 Free Software Foundation, Inc. @@ -1371,7 +1371,7 @@ cat >config.log <<_ACEOF This file contains any messages produced by compilers while running configure, to aid debugging if configure makes a mistake. -It was created by stochtree $as_me 0.2.0.9000, which was +It was created by stochtree $as_me 0.2.1, which was generated by GNU Autoconf 2.72. Invocation command line was $ $0$ac_configure_args_raw @@ -2380,7 +2380,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 # report actual input values of CONFIG_FILES etc. instead of their # values after options handling. ac_log=" -This file was extended by stochtree $as_me 0.2.0.9000, which was +This file was extended by stochtree $as_me 0.2.1, which was generated by GNU Autoconf 2.72. 
Invocation command line was CONFIG_FILES = $CONFIG_FILES @@ -2435,7 +2435,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\ cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 ac_cs_config='$ac_cs_config_escaped' ac_cs_version="\\ -stochtree config.status 0.2.0.9000 +stochtree config.status 0.2.1 configured by $0, generated by GNU Autoconf 2.72, with options \\"\$ac_cs_config\\" diff --git a/configure.ac b/configure.ac index 3d1143ba..33505e2c 100644 --- a/configure.ac +++ b/configure.ac @@ -3,7 +3,7 @@ # https://github.com/microsoft/LightGBM/blob/master/R-package/configure.ac AC_PREREQ(2.69) -AC_INIT([stochtree], [0.2.0.9000], [], [stochtree], []) +AC_INIT([stochtree], [0.2.1], [], [stochtree], []) # Note: consider making version number dynamic as in # https://github.com/microsoft/LightGBM/blob/195c26fc7b00eb0fec252dfe841e2e66d6833954/build-cran-package.sh diff --git a/pyproject.toml b/pyproject.toml index 0fe8a12a..2e992666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" [project] name = "stochtree" -version = "0.2.0-dev" +version = "0.2.1" dynamic = ["readme", "optional-dependencies", "license"] description = "Stochastic Tree Ensembles for Machine Learning and Causal Inference" requires-python = ">=3.8.0" From 7488c029325a7d3d723be02c196dc491b717cdfa Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 9 Dec 2025 12:00:40 -0500 Subject: [PATCH 02/12] Added simple functionality check for BART interface options --- test/R/testthat/test-bart-integration.R | 209 ++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 test/R/testthat/test-bart-integration.R diff --git a/test/R/testthat/test-bart-integration.R b/test/R/testthat/test-bart-integration.R new file mode 100644 index 00000000..c7bf2f2e --- /dev/null +++ b/test/R/testthat/test-bart-integration.R @@ -0,0 +1,209 @@ +run_bart_factorial <- function( + bart_data, + leaf_reg = "none", + 
variance_forest = FALSE, + random_effects = "none", + sampling_global_error_scale = FALSE, + sampling_leaf_scale = FALSE, + outcome_type = "continuous", + num_chains = 1 +) { + if ((leaf_reg == "multivariate") && (sampling_leaf_scale)) { + stop( + "Leaf error scale cannot be stochastic for multivariate leaf regression" + ) + } + + # Unpack BART data + y <- bart_data[["y"]] + X <- bart_data[["X"]] + if (leaf_reg != "none") { + leaf_basis <- bart_data[["leaf_basis"]] + } else { + leaf_basis <- NULL + } + if (random_effects != "none") { + rfx_group_ids <- bart_data[["rfx_group_ids"]] + } else { + rfx_group_ids <- NULL + } + if (random_effects == "custom") { + rfx_basis <- bart_data[["rfx_basis"]] + } else { + rfx_basis <- NULL + } + + # Run and return the bart model + general_params <- list( + num_chains = num_chains, + sample_sigma2_global = sampling_global_error_scale, + probit_outcome_model = outcome_type == "binary" + ) + mean_forest_params <- list( + sample_sigma2_leaf = sampling_leaf_scale + ) + variance_forest_params <- list( + num_trees = ifelse(variance_forest, 20, 0) + ) + rfx_params <- list( + model_spec = ifelse(random_effects == "none", "custom", random_effects) + ) + # cat("X = ", X) + # cat("y = ", y) + # cat("leaf_basis = ", leaf_basis) + # cat("rfx_group_ids = ", rfx_group_ids) + # cat("rfx_basis = ", rfx_basis) + return(stochtree::bart( + X_train = X, + y_train = y, + leaf_basis_train = leaf_basis, + rfx_group_ids_train = rfx_group_ids, + rfx_basis_train = rfx_basis, + general_params = general_params, + mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params, + random_effects_params = rfx_params + )) +} + +test_that("Quick check of interactions between components of BART functionality", { + skip_on_cran() + + # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: + # 1. Leaf regression: none, univariate, multivariate + # 2. Variance forest: no, yes + # 3. 
Random effects: no, custom basis, `intercept_only` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on mean forest: no, yes (only available for constant leaf or univariate leaf regression) + # 6. Outcome type: continuous (identity link), binary (probit link) + # 7. Number of chains: 1, >1 + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BART models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. + + # Generate data with random effects + n <- 50 + p <- 3 + num_basis <- 2 + num_rfx_groups <- 3 + num_rfx_basis <- 2 + X <- matrix(runif(n * p), ncol = p) + leaf_basis <- matrix(runif(n * num_basis), ncol = num_basis) + leaf_coefs <- runif(num_basis) + group_ids <- sample(1:num_rfx_groups, n, replace = T) + rfx_basis <- matrix(runif(n * num_rfx_basis), ncol = num_rfx_basis) + rfx_coefs <- matrix( + runif(num_rfx_groups * num_rfx_basis), + ncol = num_rfx_basis + ) + mean_term <- sin(X[, 1]) * rowSums(leaf_basis * leaf_coefs) + rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) + E_y <- sin(X[, 1]) + rfx_term + E_y <- E_y - mean(E_y) + epsilon <- rnorm(n, 0, 1) + y_continuous <- E_y + epsilon + y_binary <- 1 * (y_continuous > 0) + + # Split into test and train sets + test_set_pct <- 0.5 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + leaf_basis_test <- leaf_basis[test_inds, ] + leaf_basis_train <- leaf_basis[train_inds, ] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + group_ids_test <- group_ids[test_inds] + group_ids_train <- group_ids[train_inds] + y_continuous_test 
<- y_continuous[test_inds] + y_continuous_train <- y_continuous[train_inds] + y_binary_test <- y_binary[test_inds] + y_binary_train <- y_binary[train_inds] + + # Run the power set of models + leaf_reg_options <- c("none", "univariate", "multivariate") + variance_forest_options <- c(FALSE, TRUE) + random_effects_options <- c("none", "custom", "intercept_only") + sampling_global_error_scale_options <- c(FALSE, TRUE) + sampling_leaf_scale_options <- c(FALSE, TRUE) + outcome_type_options <- c("continuous", "binary") + num_chains_options <- c(1, 3) + model_options_df <- expand.grid( + leaf_reg = leaf_reg_options, + variance_forest = variance_forest_options, + random_effects = random_effects_options, + sampling_global_error_scale = sampling_global_error_scale_options, + sampling_leaf_scale = sampling_leaf_scale_options, + outcome_type = outcome_type_options, + num_chains = num_chains_options, + stringsAsFactors = FALSE + ) + for (i in 1:nrow(model_options_df)) { + error_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && + (model_options_df$leaf_reg[i] == "multivariate") + error_cond_2 <- (model_options_df$variance_forest[i]) && + (model_options_df$outcome_type[i] == "binary") + error_cond <- error_cond_1 || error_cond_2 + warning_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && + (model_options_df$leaf_reg[i] == "multivariate") + warning_cond_2 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_cond <- warning_cond_1 || warning_cond_2 + if (error_cond && warning_cond) { + test_fun <- function(x) expect_error(expect_warning(x)) + } else if (error_cond && !warning_cond) { + test_fun <- expect_error + } else if (!error_cond && warning_cond) { + test_fun <- expect_warning + } else { + test_fun <- expect_no_error + } + test_fun({ + bart_data <- list(X = X_train) + if (model_options_df$outcome_type[i] == "continuous") { + bart_data[["y"]] <- y_continuous_train + } else { + bart_data[["y"]] <- 
y_binary_train + } + if (model_options_df$leaf_reg[i] != "none") { + if (model_options_df$leaf_reg[i] == "univariate") { + bart_data[["leaf_basis"]] <- leaf_basis_train[, 1] + } else { + bart_data[["leaf_basis"]] <- leaf_basis_train + } + } else { + bart_data[["leaf_basis"]] <- NULL + } + if (model_options_df$random_effects[i] != "none") { + bart_data[["rfx_group_ids"]] <- group_ids_train + } else { + bart_data[["rfx_group_ids"]] <- NULL + } + if (model_options_df$random_effects[i] == "custom") { + bart_data[["rfx_basis"]] <- rfx_basis_train + } else { + bart_data[["rfx_basis"]] <- NULL + } + run_bart_factorial( + bart_data = bart_data, + leaf_reg = model_options_df$leaf_reg[i], + variance_forest = model_options_df$variance_forest[i], + random_effects = model_options_df$random_effects[i], + sampling_global_error_scale = model_options_df$sampling_global_error_scale[ + i + ], + sampling_leaf_scale = model_options_df$sampling_leaf_scale[ + i + ], + outcome_type = model_options_df$outcome_type[i], + num_chains = model_options_df$num_chains[i] + ) + }) + } +}) From 81d84bd4ea53ac35ef53200ffa2895b05bdb6dd1 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 9 Dec 2025 17:58:39 -0500 Subject: [PATCH 03/12] Updated BART integration tests in R --- test/R/testthat/test-bart-integration.R | 181 ++++++++++++++++++------ 1 file changed, 141 insertions(+), 40 deletions(-) diff --git a/test/R/testthat/test-bart-integration.R b/test/R/testthat/test-bart-integration.R index c7bf2f2e..b34d36c7 100644 --- a/test/R/testthat/test-bart-integration.R +++ b/test/R/testthat/test-bart-integration.R @@ -1,5 +1,6 @@ run_bart_factorial <- function( - bart_data, + bart_data_train, + bart_data_test, leaf_reg = "none", variance_forest = FALSE, random_effects = "none", @@ -8,32 +9,26 @@ run_bart_factorial <- function( outcome_type = "continuous", num_chains = 1 ) { - if ((leaf_reg == "multivariate") && (sampling_leaf_scale)) { - stop( - "Leaf error scale cannot be stochastic for multivariate 
leaf regression" - ) - } - - # Unpack BART data - y <- bart_data[["y"]] - X <- bart_data[["X"]] + # Unpack BART training data + y <- bart_data_train[["y"]] + X <- bart_data_train[["X"]] if (leaf_reg != "none") { - leaf_basis <- bart_data[["leaf_basis"]] + leaf_basis <- bart_data_train[["leaf_basis"]] } else { leaf_basis <- NULL } if (random_effects != "none") { - rfx_group_ids <- bart_data[["rfx_group_ids"]] + rfx_group_ids <- bart_data_train[["rfx_group_ids"]] } else { rfx_group_ids <- NULL } if (random_effects == "custom") { - rfx_basis <- bart_data[["rfx_basis"]] + rfx_basis <- bart_data_train[["rfx_basis"]] } else { rfx_basis <- NULL } - # Run and return the bart model + # Set BART model parameters general_params <- list( num_chains = num_chains, sample_sigma2_global = sampling_global_error_scale, @@ -46,14 +41,11 @@ run_bart_factorial <- function( num_trees = ifelse(variance_forest, 20, 0) ) rfx_params <- list( - model_spec = ifelse(random_effects == "none", "custom", random_effects) + model_spec = ifelse(random_effects == "custom", "custom", random_effects) ) - # cat("X = ", X) - # cat("y = ", y) - # cat("leaf_basis = ", leaf_basis) - # cat("rfx_group_ids = ", rfx_group_ids) - # cat("rfx_basis = ", rfx_basis) - return(stochtree::bart( + + # Sample BART model + bart_model <- stochtree::bart( X_train = X, y_train = y, leaf_basis_train = leaf_basis, @@ -63,7 +55,70 @@ run_bart_factorial <- function( mean_forest_params = mean_forest_params, variance_forest_params = variance_forest_params, random_effects_params = rfx_params - )) + ) + + # Unpack test set data + y_test <- bart_data_test[["y"]] + X_test <- bart_data_test[["X"]] + if (leaf_reg != "none") { + leaf_basis_test <- bart_data_test[["leaf_basis"]] + } else { + leaf_basis_test <- NULL + } + if (random_effects != "none") { + rfx_group_ids_test <- bart_data_test[["rfx_group_ids"]] + } else { + rfx_group_ids_test <- NULL + } + if (random_effects == "custom") { + rfx_basis_test <- bart_data_test[["rfx_basis"]] + 
} else { + rfx_basis_test <- NULL + } + + # Predict on test set + mean_preds <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "mean", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + posterior_preds <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + + # Compute intervals + posterior_interval <- compute_bart_posterior_interval( + bart_model, + terms = "all", + level = 0.95, + scale = ifelse(outcome_type == "binary", "probability", "linear"), + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test + ) + + # Sample posterior predictive + posterior_predictive_draws <- sample_bart_posterior_predictive( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + num_draws_per_sample = 5 + ) } test_that("Quick check of interactions between components of BART functionality", { @@ -101,7 +156,7 @@ test_that("Quick check of interactions between components of BART functionality" ) mean_term <- sin(X[, 1]) * rowSums(leaf_basis * leaf_coefs) rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) - E_y <- sin(X[, 1]) + rfx_term + E_y <- mean_term + rfx_term E_y <- E_y - mean(E_y) epsilon <- rnorm(n, 0, 1) y_continuous <- E_y + epsilon @@ -145,53 +200,99 @@ test_that("Quick check of interactions between components of BART functionality" stringsAsFactors = FALSE ) for (i in 1:nrow(model_options_df)) { - error_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && - (model_options_df$leaf_reg[i] == "multivariate") - error_cond_2 <- (model_options_df$variance_forest[i]) && + error_cond <- 
(model_options_df$variance_forest[i]) && (model_options_df$outcome_type[i] == "binary") - error_cond <- error_cond_1 || error_cond_2 warning_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && (model_options_df$leaf_reg[i] == "multivariate") + warning_message_1 <- "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." warning_cond_2 <- (model_options_df$sampling_global_error_scale[i]) && (model_options_df$outcome_type[i] == "binary") + warning_message_2 <- "Global error variance will not be sampled with a probit link as it is fixed at 1" warning_cond <- warning_cond_1 || warning_cond_2 if (error_cond && warning_cond) { - test_fun <- function(x) expect_error(expect_warning(x)) + if (warning_cond_1 && warning_cond_2) { + test_fun <- function(x) { + expect_error( + expect_warning( + expect_warning(x, warning_message_1), + warning_message_2 + ) + ) + } + } else if (warning_cond_1) { + test_fun <- function(x) { + expect_error( + expect_warning(x, warning_message_1) + ) + } + } else { + test_fun <- function(x) { + expect_error( + expect_warning(x, warning_message_2) + ) + } + } } else if (error_cond && !warning_cond) { test_fun <- expect_error } else if (!error_cond && warning_cond) { - test_fun <- expect_warning + if (warning_cond_1 && warning_cond_2) { + test_fun <- function(x) { + expect_warning( + expect_warning(x, warning_message_1), + warning_message_2 + ) + } + } else if (warning_cond_1) { + test_fun <- function(x) { + expect_warning(x, warning_message_1) + } + } else { + test_fun <- function(x) { + expect_warning(x, warning_message_2) + } + } } else { test_fun <- expect_no_error } test_fun({ - bart_data <- list(X = X_train) + bart_data_train <- list(X = X_train) + bart_data_test <- list(X = X_test) if (model_options_df$outcome_type[i] == "continuous") { - bart_data[["y"]] <- y_continuous_train + bart_data_train[["y"]] <- y_continuous_train + bart_data_test[["y"]] <- y_continuous_test 
} else { - bart_data[["y"]] <- y_binary_train + bart_data_train[["y"]] <- y_binary_train + bart_data_test[["y"]] <- y_binary_test } if (model_options_df$leaf_reg[i] != "none") { if (model_options_df$leaf_reg[i] == "univariate") { - bart_data[["leaf_basis"]] <- leaf_basis_train[, 1] + bart_data_train[["leaf_basis"]] <- leaf_basis_train[, 1, drop = FALSE] + bart_data_test[["leaf_basis"]] <- leaf_basis_test[, 1, drop = FALSE] } else { - bart_data[["leaf_basis"]] <- leaf_basis_train + bart_data_train[["leaf_basis"]] <- leaf_basis_train + bart_data_test[["leaf_basis"]] <- leaf_basis_test } } else { - bart_data[["leaf_basis"]] <- NULL + bart_data_train[["leaf_basis"]] <- NULL + bart_data_test[["leaf_basis"]] <- NULL } if (model_options_df$random_effects[i] != "none") { - bart_data[["rfx_group_ids"]] <- group_ids_train + bart_data_train[["rfx_group_ids"]] <- group_ids_train + bart_data_test[["rfx_group_ids"]] <- group_ids_test } else { - bart_data[["rfx_group_ids"]] <- NULL + bart_data_train[["rfx_group_ids"]] <- NULL + bart_data_test[["rfx_group_ids"]] <- NULL } if (model_options_df$random_effects[i] == "custom") { - bart_data[["rfx_basis"]] <- rfx_basis_train + bart_data_train[["rfx_basis"]] <- rfx_basis_train + bart_data_test[["rfx_basis"]] <- rfx_basis_test } else { - bart_data[["rfx_basis"]] <- NULL + bart_data_train[["rfx_basis"]] <- NULL + bart_data_test[["rfx_basis"]] <- NULL } run_bart_factorial( - bart_data = bart_data, + bart_data_train = bart_data_train, + bart_data_test = bart_data_test, leaf_reg = model_options_df$leaf_reg[i], variance_forest = model_options_df$variance_forest[i], random_effects = model_options_df$random_effects[i], From 69d0de26d33166e4b633f24697edec7e9ee5ab7b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 9 Dec 2025 17:59:14 -0500 Subject: [PATCH 04/12] Renamed API combination test file and added python tests --- R/bart.R | 2 - R/posterior_transformation.R | 82 +++-- stochtree/bart.py | 103 ++++--- ...-integration.R => 
test-api-combinations.R} | 0 test/python/test_api_combinations.py | 289 ++++++++++++++++++ 5 files changed, 394 insertions(+), 82 deletions(-) rename test/R/testthat/{test-bart-integration.R => test-api-combinations.R} (100%) create mode 100644 test/python/test_api_combinations.py diff --git a/R/bart.R b/R/bart.R index f5068f96..82d6dab6 100644 --- a/R/bart.R +++ b/R/bart.R @@ -2124,7 +2124,6 @@ predict.bartmodel <- function( X <- preprocessPredictionData(X, train_set_metadata) # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE if (predict_rfx) { if (!is.null(rfx_group_ids)) { rfx_unique_group_ids <- object$rfx_unique_group_ids @@ -2135,7 +2134,6 @@ predict.bartmodel <- function( ) } rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE } } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index dca34be3..010bd6ff 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -409,19 +409,31 @@ compute_contrast_bart_model <- function( "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model" ) } - if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) { - stop( - "rfx_basis_0 and rfx_basis_1 must be provided for this model" - ) - } - if ( - (object$model_params$num_rfx_basis > 0) && - ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || - (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) - ) { - stop( - "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" - ) + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) { + stop( + "A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) { + stop("'rfx_basis_0' and 'rfx_basis_1' must be 
matrices") + } + if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) { + stop( + "'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'" + ) + } + if ( + (object$model_params$num_rfx_basis > 0) && + ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || + (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) + ) { + stop( + "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" + ) + } + } } # Predict for the control arm @@ -735,16 +747,18 @@ sample_bart_posterior_predictive <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") - } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } @@ -1172,16 +1186,18 @@ compute_bart_posterior_interval <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") - } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled 
with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } diff --git a/stochtree/bart.py b/stochtree/bart.py index b7bf1c88..5d21c48e 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -70,15 +70,15 @@ def __init__(self) -> None: def sample( self, - X_train: Union[np.array, pd.DataFrame], - y_train: np.array, - leaf_basis_train: np.array = None, - rfx_group_ids_train: np.array = None, - rfx_basis_train: np.array = None, - X_test: Union[np.array, pd.DataFrame] = None, - leaf_basis_test: np.array = None, - rfx_group_ids_test: np.array = None, - rfx_basis_test: np.array = None, + X_train: Union[np.ndarray, pd.DataFrame], + y_train: np.ndarray, + leaf_basis_train: Optional[np.ndarray] = None, + rfx_group_ids_train: Optional[np.ndarray] = None, + rfx_basis_train: Optional[np.ndarray] = None, + X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None, + leaf_basis_test: Optional[np.ndarray] = None, + rfx_group_ids_test: Optional[np.ndarray] = None, + rfx_basis_test: Optional[np.ndarray] = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, @@ -859,6 +859,13 @@ def sample( if num_features_subsample_variance is None: num_features_subsample_variance = X_train.shape[1] + # Runtime check for multivariate leaf regression + if sample_sigma2_leaf and self.num_basis > 1: + warnings.warn( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." 
+ ) + sample_sigma2_leaf = False + # Preliminary runtime checks for probit link if not self.include_mean_forest: self.probit_outcome_model = False @@ -872,15 +879,15 @@ def sample( raise ValueError( "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" ) - if self.include_variance_forest: - raise ValueError( - "We do not support heteroskedasticity with a probit link" - ) if sample_sigma2_global: warnings.warn( "Global error variance will not be sampled with a probit link as it is fixed at 1" ) sample_sigma2_global = False + if self.include_variance_forest: + raise ValueError( + "We do not support heteroskedasticity with a probit link" + ) # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes @@ -1217,7 +1224,7 @@ def sample( else: leaf_model_mean_forest = 2 leaf_dimension_mean = self.num_basis - + # Sampling data structures global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) if self.include_mean_forest: @@ -1900,6 +1907,9 @@ def predict( if leaf_basis is not None: if leaf_basis.ndim == 1: leaf_basis = np.expand_dims(leaf_basis, 1) + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): @@ -1958,21 +1968,18 @@ def predict( mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar # Random effects data checks - if has_rfx: - if rfx_group_ids is None: - raise ValueError( - "rfx_group_ids must be provided if rfx_basis is provided" - ) - if rfx_basis is not None: - if rfx_basis.ndim == 1: - rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("X and rfx_basis must have the same number of rows") + if predict_rfx and rfx_group_ids is None: + raise ValueError( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) + if 
predict_rfx and rfx_basis is None and not rfx_intercept: + raise ValueError("Random effects basis (rfx_basis) must be provided for this model") + if self.num_rfx_basis > 0 and not rfx_intercept: if rfx_basis.shape[1] != self.num_rfx_basis: raise ValueError( - "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + "Random effects basis has a different dimension than the basis used to train this model" ) - + # Random effects predictions if predict_rfx or predict_rfx_intermediate: if rfx_basis is not None: @@ -1983,7 +1990,7 @@ def predict( # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" if not rfx_intercept: raise ValueError( - "rfx_basis must be provided for random effects models with random slopes" + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) # Extract the raw RFX samples and scale by train set outcome standard deviation @@ -2321,16 +2328,17 @@ def compute_posterior_interval( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError( - "'rfx_basis' must have the same number of rows as 'X'" - ) + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior matrices for the requested model 
terms predictions = self.predict( @@ -2427,16 +2435,17 @@ def sample_posterior_predictive( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError( - "'rfx_basis' must have the same number of rows as 'X'" - ) + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior predictive samples bart_preds = self.predict( diff --git a/test/R/testthat/test-bart-integration.R b/test/R/testthat/test-api-combinations.R similarity index 100% rename from test/R/testthat/test-bart-integration.R rename to test/R/testthat/test-api-combinations.R diff --git a/test/python/test_api_combinations.py b/test/python/test_api_combinations.py new file mode 100644 index 00000000..edbb0dd6 --- /dev/null +++ b/test/python/test_api_combinations.py @@ -0,0 +1,289 @@ +import itertools +import pytest +import numpy as np +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel + + +def run_bart_factorial( + bart_data_train, + bart_data_test, + leaf_reg="none", + variance_forest=False, + random_effects="none", + sampling_global_error_scale=False, + sampling_leaf_scale=False, + outcome_type="continuous", + num_chains=1, +): + # Unpack BART training data + y = bart_data_train["y"] + X = bart_data_train["X"] + if leaf_reg != "none": + leaf_basis = 
bart_data_train["leaf_basis"] + else: + leaf_basis = None + if random_effects != "none": + rfx_group_ids = bart_data_train["rfx_group_ids"] + else: + rfx_group_ids = None + if random_effects == "custom": + rfx_basis = bart_data_train["rfx_basis"] + else: + rfx_basis = None + + # Set BART model parameters + general_params = { + "num_chains": num_chains, + "sample_sigma2_global": sampling_global_error_scale, + "probit_outcome_model": outcome_type == "binary", + } + mean_forest_params = {"sample_sigma2_leaf": sampling_leaf_scale} + variance_forest_params = {"num_trees": 20 if variance_forest else 0} + rfx_params = { + "model_spec": "custom" if random_effects == "none" else random_effects + } + + # Sample BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + leaf_basis_train=leaf_basis, + rfx_group_ids_train=rfx_group_ids, + rfx_basis_train=rfx_basis, + general_params=general_params, + mean_forest_params=mean_forest_params, + variance_forest_params=variance_forest_params, + random_effects_params=rfx_params, + ) + + # Unpack test set data + y_test = bart_data_test["y"] + X_test = bart_data_test["X"] + if leaf_reg != "none": + leaf_basis_test = bart_data_test["leaf_basis"] + else: + leaf_basis_test = None + if random_effects != "none": + rfx_group_ids_test = bart_data_test["rfx_group_ids"] + else: + rfx_group_ids_test = None + if random_effects == "custom": + rfx_basis_test = bart_data_test["rfx_basis"] + else: + rfx_basis_test = None + + # Predict on test set + mean_preds = bart_model.predict( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + posterior_preds = bart_model.predict( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + 
) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="probability" if outcome_type == "binary" else "linear", + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + num_draws_per_sample=5, + ) + + +class TestAPICombinations: + def test_bart_api_combinations(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: + # 1. Leaf regression: none, univariate, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on mean forest: no, yes (only available for constant leaf or univariate leaf regression) + # 6. Outcome type: continuous (identity link), binary (probit link) + # 7. Number of chains: 1, >1 + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BART models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. 
+ + # Generate data with random effects + n = 50 + p = 3 + num_basis = 2 + num_rfx_groups = 3 + num_rfx_basis = 2 + X = rng.uniform(0, 1, (n, p)) + leaf_basis = rng.uniform(0, 1, (n, num_basis)) + leaf_coefs = rng.uniform(0, 1, num_basis) + group_ids = rng.choice(num_rfx_groups, size=n) + rfx_basis = rng.uniform(0, 1, (n, num_rfx_basis)) + rfx_coefs = rng.uniform(0, 1, (num_rfx_groups, num_rfx_basis)) + mean_term = np.sin(X[:, 0]) * np.sum(leaf_basis * leaf_coefs, axis=1) + rfx_term = np.sum(rfx_coefs[group_ids - 1, :] * rfx_basis, axis=1) + E_y = mean_term + rfx_term + E_y = E_y - np.mean(E_y) + epsilon = rng.normal(0, 1, n) + y_continuous = E_y + epsilon + y_binary = (y_continuous > 0).astype(int) + + # Split into test and train sets + test_set_pct = 0.5 + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + leaf_basis_train = leaf_basis[train_inds, :] + leaf_basis_test = leaf_basis[test_inds, :] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + group_ids_train = group_ids[train_inds] + group_ids_test = group_ids[test_inds] + y_continuous_train = y_continuous[train_inds] + y_continuous_test = y_continuous[test_inds] + y_binary_train = y_binary[train_inds] + y_binary_test = y_binary[test_inds] + + # Run the power set of models + leaf_reg_options = ["none", "univariate", "multivariate"] + variance_forest_options = [False, True] + random_effects_options = ["none", "custom", "intercept_only"] + sampling_global_error_scale_options = [False, True] + sampling_leaf_scale_options = [False, True] + outcome_type_options = ["continuous", "binary"] + num_chains_options = [1, 3] + model_options_iter = itertools.product( + leaf_reg_options, + variance_forest_options, + random_effects_options, + sampling_global_error_scale_options, + sampling_leaf_scale_options, + outcome_type_options, + num_chains_options, + ) + for i, options in 
enumerate(model_options_iter): + print(f"i = {i}, options = {options}") + # Unpack BART train and test data + bart_data_train = {} + bart_data_test = {} + bart_data_train["X"] = X_train + bart_data_test["X"] = X_test + if options[5] == "continuous": + bart_data_train["y"] = y_continuous_train + bart_data_test["y"] = y_continuous_test + else: + bart_data_train["y"] = y_binary_train + bart_data_test["y"] = y_binary_test + if options[0] != "none": + if options[0] == "univariate": + bart_data_train["leaf_basis"] = leaf_basis_train[:, 0] + bart_data_test["leaf_basis"] = leaf_basis_test[:, 0] + else: + bart_data_train["leaf_basis"] = leaf_basis_train + bart_data_test["leaf_basis"] = leaf_basis_test + else: + bart_data_train["leaf_basis"] = None + bart_data_test["leaf_basis"] = None + if options[2] != "none": + bart_data_train["rfx_group_ids"] = group_ids_train + bart_data_test["rfx_group_ids"] = group_ids_test + else: + bart_data_train["rfx_group_ids"] = None + bart_data_test["rfx_group_ids"] = None + if options[2] == "custom": + bart_data_train["rfx_basis"] = rfx_basis_train + bart_data_test["rfx_basis"] = rfx_basis_test + else: + bart_data_train["rfx_basis"] = None + bart_data_test["rfx_basis"] = None + + # Determine whether this combination should throw an error, raise a warning, or run as intended + error_cond = (options[1]) and (options[5] == "binary") + warning_cond_1 = (options[4]) and (options[0] == "multivariate") + warning_message_1 = "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." 
+ warning_cond_2 = (options[3]) and (options[5] == "binary") + warning_message_2 = "Global error variance will not be sampled with a probit link as it is fixed at 1" + warning_cond = warning_cond_1 or warning_cond_2 + print(f"error_cond = {error_cond}, warning_cond = {warning_cond}") + if error_cond and warning_cond: + with pytest.raises(ValueError) as excinfo: + with pytest.warns(UserWarning) as warninfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + elif error_cond and not warning_cond: + with pytest.raises(ValueError) as excinfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + elif not error_cond and warning_cond: + with pytest.warns(UserWarning) as warninfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + else: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) From 47dbf187ca1acb0195c34e5d9d482af19e1a0f8c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 01:47:30 -0500 Subject: [PATCH 05/12] Updated R package and API combination test 
--- R/bart.R | 10 + R/bcf.R | 25 +- R/posterior_transformation.R | 24 +- test/R/testthat/test-api-combinations.R | 742 +++++++++++++++++++++--- 4 files changed, 699 insertions(+), 102 deletions(-) diff --git a/R/bart.R b/R/bart.R index 82d6dab6..23fee012 100644 --- a/R/bart.R +++ b/R/bart.R @@ -835,6 +835,16 @@ bart <- function( } } + # Runtime checks for variance forest + if (include_variance_forest) { + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a heteroskedasticity forest" + ) + sample_sigma2_global <- F + } + } + # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes if (probit_outcome_model) { diff --git a/R/bcf.R b/R/bcf.R index 765347cb..94393656 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -897,14 +897,7 @@ bcf <- function( # Handle multivariate treatment has_multivariate_treatment <- ncol(Z_train) > 1 if (has_multivariate_treatment) { - # Disable adaptive coding, internal propensity model, and - # leaf scale sampling if treatment is multivariate - if (adaptive_coding) { - warning( - "Adaptive coding is incompatible with multivariate treatment and will be ignored" - ) - adaptive_coding <- FALSE - } + # Disable internal propensity model and leaf scale sampling if treatment is multivariate if (is.null(propensity_train)) { if (propensity_covariate != "none") { warning( @@ -1021,15 +1014,21 @@ bcf <- function( y_train <- as.matrix(y_train) } - # Check whether treatment is binary (specifically 0-1 binary) - binary_treatment <- length(unique(Z_train)) == 2 - if (binary_treatment) { - unique_treatments <- sort(unique(Z_train)) - if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + # Check whether treatment is binary and univariate (specifically 0-1 binary) + binary_treatment <- FALSE + if (!has_multivariate_treatment) { + binary_treatment <- length(unique(Z_train)) == 2 + if (binary_treatment) { + unique_treatments <- sort(unique(Z_train)) + 
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + } } # Adaptive coding will be ignored for continuous / ordered categorical treatments if ((!binary_treatment) && (adaptive_coding)) { + warning( + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model" + ) adaptive_coding <- FALSE } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 010bd6ff..be883198 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -586,16 +586,22 @@ sample_bcf_posterior_predictive <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") + + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + + if (!is.null(rfx_basis)) { + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } diff --git a/test/R/testthat/test-api-combinations.R b/test/R/testthat/test-api-combinations.R index b34d36c7..2fe27468 100644 --- a/test/R/testthat/test-api-combinations.R +++ b/test/R/testthat/test-api-combinations.R @@ -41,7 +41,7 @@ run_bart_factorial <- function( num_trees = ifelse(variance_forest, 20, 0) ) rfx_params <- list( - model_spec = ifelse(random_effects == "custom", "custom", random_effects) + model_spec = ifelse(random_effects == "none", "custom", random_effects) ) # Sample BART model @@ -121,8 +121,248 @@ run_bart_factorial <- 
function( ) } +run_bcf_factorial <- function( + bcf_data_train, + bcf_data_test, + treatment_type = "binary", + variance_forest = FALSE, + random_effects = "none", + sampling_global_error_scale = FALSE, + sampling_mu_leaf_scale = FALSE, + sampling_tau_leaf_scale = FALSE, + outcome_type = "continuous", + num_chains = 1, + adaptive_coding = TRUE, + include_propensity = TRUE +) { + # Unpack BART training data + y <- bcf_data_train[["y"]] + X <- bcf_data_train[["X"]] + Z <- bcf_data_train[["Z"]] + if (include_propensity) { + propensity_train <- bcf_data_train[["propensity"]] + } else { + propensity_train <- NULL + } + if (random_effects != "none") { + rfx_group_ids <- bcf_data_train[["rfx_group_ids"]] + } else { + rfx_group_ids <- NULL + } + if (random_effects == "custom") { + rfx_basis <- bcf_data_train[["rfx_basis"]] + } else { + rfx_basis <- NULL + } + + # Set BART model parameters + general_params <- list( + num_chains = num_chains, + sample_sigma2_global = sampling_global_error_scale, + probit_outcome_model = outcome_type == "binary", + adaptive_coding = adaptive_coding + ) + mu_forest_params <- list( + sample_sigma2_leaf = sampling_mu_leaf_scale + ) + tau_forest_params <- list( + sample_sigma2_leaf = sampling_tau_leaf_scale + ) + variance_forest_params <- list( + num_trees = ifelse(variance_forest, 20, 0) + ) + rfx_params <- list( + model_spec = ifelse(random_effects == "none", "custom", random_effects) + ) + + # Sample BART model + bcf_model <- stochtree::bcf( + X_train = X, + y_train = y, + Z_train = Z, + propensity_train = propensity_train, + rfx_group_ids_train = rfx_group_ids, + rfx_basis_train = rfx_basis, + general_params = general_params, + prognostic_forest_params = mu_forest_params, + treatment_effect_forest_params = tau_forest_params, + variance_forest_params = variance_forest_params, + random_effects_params = rfx_params + ) + + # Unpack test set data + y_test <- bcf_data_test[["y"]] + X_test <- bcf_data_test[["X"]] + Z_test <- bcf_data_test[["Z"]] + 
if (include_propensity) { + propensity_test <- bcf_data_test[["propensity"]] + } else { + propensity_test <- NULL + } + if (random_effects != "none") { + rfx_group_ids_test <- bcf_data_test[["rfx_group_ids"]] + } else { + rfx_group_ids_test <- NULL + } + if (random_effects == "custom") { + rfx_basis_test <- bcf_data_test[["rfx_basis"]] + } else { + rfx_basis_test <- NULL + } + + # Predict on test set + mean_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "mean", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + posterior_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + + # Compute intervals + posterior_interval <- compute_bcf_posterior_interval( + bcf_model, + terms = "all", + level = 0.95, + scale = ifelse(outcome_type == "binary", "probability", "linear"), + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test + ) + + # Sample posterior predictive + posterior_predictive_draws <- sample_bcf_posterior_predictive( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + num_draws_per_sample = 5 + ) +} + +# Construct chained expectations without writing out every combination of function calls +construct_chained_expectation_bart <- function( + error_cond, + warning_cond_1, + warning_cond_2, + warning_cond_3 +) { + # Build the chain from innermost to outermost + function_text <- "x" + if (warning_cond_1) { + function_text <- paste0( + "warning_fun_1(", + function_text, + ")" + ) + } + if (warning_cond_2) { + function_text <- paste0( + 
"warning_fun_2(", + function_text, + ")" + ) + } + if (warning_cond_3) { + function_text <- paste0( + "warning_fun_3(", + function_text, + ")" + ) + } + if (error_cond) { + function_text <- paste0( + "expect_error(", + function_text, + ")" + ) + } + return(as.function( + c(alist(x = ), parse(text = function_text)[[1]]), + envir = parent.frame() + )) +} + +construct_chained_expectation_bcf <- function( + error_cond, + warning_cond_1, + warning_cond_2, + warning_cond_3, + warning_cond_4, + warning_cond_5 +) { + # Build the chain from innermost to outermost + function_text <- "x" + if (warning_cond_1) { + function_text <- paste0( + "warning_fun_1(", + function_text, + ")" + ) + } + if (warning_cond_2) { + function_text <- paste0( + "warning_fun_2(", + function_text, + ")" + ) + } + if (warning_cond_3) { + function_text <- paste0( + "warning_fun_3(", + function_text, + ")" + ) + } + if (warning_cond_4) { + function_text <- paste0( + "warning_fun_4(", + function_text, + ")" + ) + } + if (warning_cond_5) { + function_text <- paste0( + "warning_fun_5(", + function_text, + ")" + ) + } + if (error_cond) { + function_text <- paste0( + "expect_error(", + function_text, + ")" + ) + } + return(as.function( + c(alist(x = ), parse(text = function_text)[[1]]), + envir = parent.frame() + )) +} + test_that("Quick check of interactions between components of BART functionality", { skip_on_cran() + skip_on_ci() # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: # 1. 
Leaf regression: none, univariate, multivariate @@ -200,96 +440,194 @@ test_that("Quick check of interactions between components of BART functionality" stringsAsFactors = FALSE ) for (i in 1:nrow(model_options_df)) { + # Determine which errors and warnings should be triggered error_cond <- (model_options_df$variance_forest[i]) && (model_options_df$outcome_type[i] == "binary") warning_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && (model_options_df$leaf_reg[i] == "multivariate") - warning_message_1 <- "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + warning_fun_1 <- function(x) { + expect_warning( + x, + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + ) + } warning_cond_2 <- (model_options_df$sampling_global_error_scale[i]) && (model_options_df$outcome_type[i] == "binary") - warning_message_2 <- "Global error variance will not be sampled with a probit link as it is fixed at 1" - warning_cond <- warning_cond_1 || warning_cond_2 - if (error_cond && warning_cond) { - if (warning_cond_1 && warning_cond_2) { - test_fun <- function(x) { - expect_error( - expect_warning( - expect_warning(x, warning_message_1), - warning_message_2 - ) - ) - } - } else if (warning_cond_1) { - test_fun <- function(x) { - expect_error( - expect_warning(x, warning_message_1) - ) - } - } else { - test_fun <- function(x) { - expect_error( - expect_warning(x, warning_message_2) - ) - } - } - } else if (error_cond && !warning_cond) { - test_fun <- expect_error - } else if (!error_cond && warning_cond) { - if (warning_cond_1 && warning_cond_2) { - test_fun <- function(x) { - expect_warning( - expect_warning(x, warning_message_1), - warning_message_2 - ) - } - } else if (warning_cond_1) { - test_fun <- function(x) { - expect_warning(x, warning_message_1) - } - } else { - test_fun <- function(x) { - expect_warning(x, 
warning_message_2) - } - } + warning_fun_2 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + } + warning_cond_3 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$variance_forest[i]) + warning_fun_3 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a heteroskedasticity" + ) + } + warning_cond <- warning_cond_1 || warning_cond_2 || warning_cond_3 + + if (error_cond || warning_cond) { + # test_fun <- create_test_function( + # error_cond = FALSE, + # warning_conditions = c(warning_cond_1, warning_cond_2, warning_cond_3), + # warning_functions = c(warning_fun_1, warning_fun_2, warning_fun_3) + # ) + # test_fun <- function(x) { + # result <- x + # if (warning_cond_1) { + # result <- warning_fun_3(result) + # } + # if (warning_cond_2) { + # result <- warning_fun_2(result) + # } + # if (warning_cond_3) { + # result <- warning_fun_1(result) + # } + # if (error_cond) { + # result <- expect_error(result) + # } + # } + # if (error_cond && warning_cond_1 && warning_cond_2 && warning_cond_3) { + # test_fun <- function(x) { + # expect_error(warning_fun_1(warning_fun_2(warning_fun_3(x)))) + # } + # } else if ( + # error_cond && warning_cond_1 && warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_1(warning_fun_2(x))) + # } + # } else if ( + # error_cond && warning_cond_1 && !warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_1(warning_fun_3(x))) + # } + # } else if ( + # error_cond && !warning_cond_1 && warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_2(warning_fun_3(x))) + # } + # } else if ( + # error_cond && !warning_cond_1 && !warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_3(x)) + # } + # } else if ( + # error_cond && !warning_cond_1 
&& warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_2(x)) + # } + # } else if ( + # error_cond && warning_cond_1 && !warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(warning_fun_1(x)) + # } + # } else if ( + # error_cond && !warning_cond_1 && !warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # expect_error(x) + # } + # } else if ( + # !error_cond && warning_cond_1 && warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_1(warning_fun_2(warning_fun_3(x))) + # } + # } else if ( + # !error_cond && warning_cond_1 && warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_1(warning_fun_2(x)) + # } + # } else if ( + # !error_cond && warning_cond_1 && !warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_1(warning_fun_3(x)) + # } + # } else if ( + # !error_cond && !warning_cond_1 && warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_2(warning_fun_3(x)) + # } + # } else if ( + # !error_cond && !warning_cond_1 && !warning_cond_2 && warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_3(x) + # } + # } else if ( + # !error_cond && !warning_cond_1 && warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_2(x) + # } + # } else if ( + # !error_cond && warning_cond_1 && !warning_cond_2 && !warning_cond_3 + # ) { + # test_fun <- function(x) { + # warning_fun_1(x) + # } + # } + test_fun <- construct_chained_expectation_bart( + error_cond = error_cond, + warning_cond_1 = warning_cond_1, + warning_cond_2 = warning_cond_2, + warning_cond_3 = warning_cond_3 + ) } else { test_fun <- expect_no_error } - test_fun({ - bart_data_train <- list(X = X_train) - bart_data_test <- list(X = X_test) - if (model_options_df$outcome_type[i] == "continuous") { - bart_data_train[["y"]] <- y_continuous_train 
- bart_data_test[["y"]] <- y_continuous_test - } else { - bart_data_train[["y"]] <- y_binary_train - bart_data_test[["y"]] <- y_binary_test - } - if (model_options_df$leaf_reg[i] != "none") { - if (model_options_df$leaf_reg[i] == "univariate") { - bart_data_train[["leaf_basis"]] <- leaf_basis_train[, 1, drop = FALSE] - bart_data_test[["leaf_basis"]] <- leaf_basis_test[, 1, drop = FALSE] - } else { - bart_data_train[["leaf_basis"]] <- leaf_basis_train - bart_data_test[["leaf_basis"]] <- leaf_basis_test - } - } else { - bart_data_train[["leaf_basis"]] <- NULL - bart_data_test[["leaf_basis"]] <- NULL - } - if (model_options_df$random_effects[i] != "none") { - bart_data_train[["rfx_group_ids"]] <- group_ids_train - bart_data_test[["rfx_group_ids"]] <- group_ids_test - } else { - bart_data_train[["rfx_group_ids"]] <- NULL - bart_data_test[["rfx_group_ids"]] <- NULL - } - if (model_options_df$random_effects[i] == "custom") { - bart_data_train[["rfx_basis"]] <- rfx_basis_train - bart_data_test[["rfx_basis"]] <- rfx_basis_test + + # Prepare test function arguments + bart_data_train <- list(X = X_train) + bart_data_test <- list(X = X_test) + if (model_options_df$outcome_type[i] == "continuous") { + bart_data_train[["y"]] <- y_continuous_train + bart_data_test[["y"]] <- y_continuous_test + } else { + bart_data_train[["y"]] <- y_binary_train + bart_data_test[["y"]] <- y_binary_test + } + if (model_options_df$leaf_reg[i] != "none") { + if (model_options_df$leaf_reg[i] == "univariate") { + bart_data_train[["leaf_basis"]] <- leaf_basis_train[, 1, drop = FALSE] + bart_data_test[["leaf_basis"]] <- leaf_basis_test[, 1, drop = FALSE] } else { - bart_data_train[["rfx_basis"]] <- NULL - bart_data_test[["rfx_basis"]] <- NULL + bart_data_train[["leaf_basis"]] <- leaf_basis_train + bart_data_test[["leaf_basis"]] <- leaf_basis_test } + } else { + bart_data_train[["leaf_basis"]] <- NULL + bart_data_test[["leaf_basis"]] <- NULL + } + if (model_options_df$random_effects[i] != "none") { + 
bart_data_train[["rfx_group_ids"]] <- group_ids_train + bart_data_test[["rfx_group_ids"]] <- group_ids_test + } else { + bart_data_train[["rfx_group_ids"]] <- NULL + bart_data_test[["rfx_group_ids"]] <- NULL + } + if (model_options_df$random_effects[i] == "custom") { + bart_data_train[["rfx_basis"]] <- rfx_basis_train + bart_data_test[["rfx_basis"]] <- rfx_basis_test + } else { + bart_data_train[["rfx_basis"]] <- NULL + bart_data_test[["rfx_basis"]] <- NULL + } + + # Apply testthat expectation(s) + test_fun({ run_bart_factorial( bart_data_train = bart_data_train, bart_data_test = bart_data_test, @@ -308,3 +646,247 @@ test_that("Quick check of interactions between components of BART functionality" }) } }) + +test_that("Quick check of interactions between components of BCF functionality", { + skip_on_cran() + skip_on_ci() + + # Overall, we have nine components of a BCF sampler which can be on / off or set to different levels: + # 1. treatment: binary, univariate continuous, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only`, `intercept_plus_treatment` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on prognostic forest: no, yes + # 6. Sampling leaf scale on treatment forest: no, yes (only available for univariate treatment) + # 7. Outcome type: continuous (identity link), binary (probit link) + # 8. Number of chains: 1, >1 + # 9. Adaptive coding: no, yes + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BART models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. 
+ + # Generate data with random effects + n <- 50 + p <- 3 + num_rfx_groups <- 3 + num_rfx_basis <- 2 + X <- matrix(runif(n * p), ncol = p) + binary_treatment <- rbinom(n, 1, 0.5) + continuous_treatment <- runif(n, 0, 1) + multivariate_treatment <- cbind( + binary_treatment, + continuous_treatment + ) + group_ids <- sample(1:num_rfx_groups, n, replace = T) + rfx_basis <- matrix(runif(n * num_rfx_basis), ncol = num_rfx_basis) + rfx_coefs <- matrix( + runif(num_rfx_groups * num_rfx_basis), + ncol = num_rfx_basis + ) + propensity <- runif(n) + prognostic_term <- sin(X[, 1]) + binary_treatment_effect <- X[, 2] + continuous_treatment_effect <- X[, 3] + rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) + E_y <- prognostic_term + + binary_treatment_effect * binary_treatment + + continuous_treatment_effect * continuous_treatment + + rfx_term + E_y <- E_y - mean(E_y) + epsilon <- rnorm(n, 0, 1) + y_continuous <- E_y + epsilon + y_binary <- 1 * (y_continuous > 0) + + # Split into test and train sets + test_set_pct <- 0.5 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + binary_treatment_test <- binary_treatment[test_inds] + binary_treatment_train <- binary_treatment[train_inds] + propensity_test <- propensity[test_inds] + propensity_train <- propensity[train_inds] + continuous_treatment_test <- continuous_treatment[test_inds] + continuous_treatment_train <- continuous_treatment[train_inds] + multivariate_treatment_test <- multivariate_treatment[test_inds, ] + multivariate_treatment_train <- multivariate_treatment[train_inds, ] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + group_ids_test <- group_ids[test_inds] + group_ids_train <- group_ids[train_inds] + y_continuous_test <- y_continuous[test_inds] + y_continuous_train <- y_continuous[train_inds] + 
y_binary_test <- y_binary[test_inds] + y_binary_train <- y_binary[train_inds] + + # Run the power set of models + treatment_options <- c("binary", "univariate_continuous", "multivariate") + variance_forest_options <- c(FALSE, TRUE) + random_effects_options <- c( + "none", + "custom", + "intercept_only", + "intercept_plus_treatment" + ) + sampling_global_error_scale_options <- c(FALSE, TRUE) + sampling_mu_leaf_scale_options <- c(FALSE, TRUE) + sampling_tau_leaf_scale_options <- c(FALSE, TRUE) + outcome_type_options <- c("continuous", "binary") + num_chains_options <- c(1, 3) + adaptive_coding_options <- c(FALSE, TRUE) + include_propensity_options <- c(FALSE, TRUE) + model_options_df <- expand.grid( + treatment_type = treatment_options, + variance_forest = variance_forest_options, + random_effects = random_effects_options, + sampling_global_error_scale = sampling_global_error_scale_options, + sampling_mu_leaf_scale = sampling_mu_leaf_scale_options, + sampling_tau_leaf_scale = sampling_tau_leaf_scale_options, + outcome_type = outcome_type_options, + num_chains = num_chains_options, + adaptive_coding = adaptive_coding_options, + include_propensity = include_propensity_options, + stringsAsFactors = FALSE + ) + for (i in 1:nrow(model_options_df)) { + # Determine which errors and warnings should be triggered + error_cond <- (model_options_df$variance_forest[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_cond_1 <- (model_options_df$sampling_tau_leaf_scale[i]) && + (model_options_df$treatment_type[i] == "multivariate") + warning_fun_1 <- function(x) { + expect_warning( + x, + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.", + fixed = TRUE + ) + } + warning_cond_2 <- (!model_options_df$include_propensity[i]) && + (model_options_df$treatment_type[i] == "multivariate") + warning_fun_2 <- function(x) { + expect_warning( + x, + "No propensities were 
provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'", + fixed = TRUE + ) + } + warning_cond_3 <- (model_options_df$adaptive_coding[i]) && + (model_options_df$treatment_type[i] != "binary") + warning_fun_3 <- function(x) { + expect_warning( + x, + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model", + fixed = TRUE + ) + } + warning_cond_4 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_fun_4 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a probit link as it is fixed at 1", + fixed = TRUE + ) + } + warning_cond_5 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$variance_forest[i]) + warning_fun_5 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a heteroskedasticity", + fixed = TRUE + ) + } + warning_cond <- (warning_cond_1 || + warning_cond_2 || + warning_cond_3 || + warning_cond_4 || + warning_cond_5) + + # Generate something like the below code but for all five warnings + if (error_cond || warning_cond) { + test_fun <- construct_chained_expectation_bcf( + error_cond = error_cond, + warning_cond_1 = warning_cond_1, + warning_cond_2 = warning_cond_2, + warning_cond_3 = warning_cond_3, + warning_cond_4 = warning_cond_4, + warning_cond_5 = warning_cond_5 + ) + } else { + test_fun <- expect_no_error + } + + # Prepare test function arguments + bcf_data_train <- list(X = X_train) + bcf_data_test <- list(X = X_test) + if (model_options_df$outcome_type[i] == "continuous") { + bcf_data_train[["y"]] <- y_continuous_train + bcf_data_test[["y"]] <- y_continuous_test + } else { + bcf_data_train[["y"]] <- y_binary_train + bcf_data_test[["y"]] <- y_binary_test + } + if (model_options_df$include_propensity[i]) { + 
bcf_data_train[["propensity"]] <- propensity_train + bcf_data_test[["propensity"]] <- propensity_test + } else { + bcf_data_train[["propensity"]] <- NULL + bcf_data_test[["propensity"]] <- NULL + } + if (model_options_df$treatment_type[i] == "binary") { + bcf_data_train[["Z"]] <- binary_treatment_train + bcf_data_test[["Z"]] <- binary_treatment_test + } else if (model_options_df$treatment_type[i] == "univariate_continuous") { + bcf_data_train[["Z"]] <- continuous_treatment_train + bcf_data_test[["Z"]] <- continuous_treatment_test + } else { + bcf_data_train[["Z"]] <- multivariate_treatment_train + bcf_data_test[["Z"]] <- multivariate_treatment_test + } + if (model_options_df$random_effects[i] != "none") { + bcf_data_train[["rfx_group_ids"]] <- group_ids_train + bcf_data_test[["rfx_group_ids"]] <- group_ids_test + } else { + bcf_data_train[["rfx_group_ids"]] <- NULL + bcf_data_test[["rfx_group_ids"]] <- NULL + } + if (model_options_df$random_effects[i] == "custom") { + bcf_data_train[["rfx_basis"]] <- rfx_basis_train + bcf_data_test[["rfx_basis"]] <- rfx_basis_test + } else { + bcf_data_train[["rfx_basis"]] <- NULL + bcf_data_test[["rfx_basis"]] <- NULL + } + + # Apply testthat expectation(s) + test_fun({ + run_bcf_factorial( + bcf_data_train = bcf_data_train, + bcf_data_test = bcf_data_test, + treatment_type = model_options_df$treatment_type[i], + variance_forest = model_options_df$variance_forest[i], + random_effects = model_options_df$random_effects[i], + sampling_global_error_scale = model_options_df$sampling_global_error_scale[ + i + ], + sampling_mu_leaf_scale = model_options_df$sampling_mu_leaf_scale[ + i + ], + sampling_tau_leaf_scale = model_options_df$sampling_tau_leaf_scale[ + i + ], + outcome_type = model_options_df$outcome_type[i], + num_chains = model_options_df$num_chains[i], + adaptive_coding = model_options_df$adaptive_coding[i], + include_propensity = model_options_df$include_propensity[i] + ) + }) + } +}) From 
33ad242db75b0cc3adf848423e25daf70f814736 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 01:47:48 -0500 Subject: [PATCH 06/12] Removed old code --- test/R/testthat/test-api-combinations.R | 109 ------------------------ 1 file changed, 109 deletions(-) diff --git a/test/R/testthat/test-api-combinations.R b/test/R/testthat/test-api-combinations.R index 2fe27468..e74d87d9 100644 --- a/test/R/testthat/test-api-combinations.R +++ b/test/R/testthat/test-api-combinations.R @@ -470,115 +470,6 @@ test_that("Quick check of interactions between components of BART functionality" warning_cond <- warning_cond_1 || warning_cond_2 || warning_cond_3 if (error_cond || warning_cond) { - # test_fun <- create_test_function( - # error_cond = FALSE, - # warning_conditions = c(warning_cond_1, warning_cond_2, warning_cond_3), - # warning_functions = c(warning_fun_1, warning_fun_2, warning_fun_3) - # ) - # test_fun <- function(x) { - # result <- x - # if (warning_cond_1) { - # result <- warning_fun_3(result) - # } - # if (warning_cond_2) { - # result <- warning_fun_2(result) - # } - # if (warning_cond_3) { - # result <- warning_fun_1(result) - # } - # if (error_cond) { - # result <- expect_error(result) - # } - # } - # if (error_cond && warning_cond_1 && warning_cond_2 && warning_cond_3) { - # test_fun <- function(x) { - # expect_error(warning_fun_1(warning_fun_2(warning_fun_3(x)))) - # } - # } else if ( - # error_cond && warning_cond_1 && warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_1(warning_fun_2(x))) - # } - # } else if ( - # error_cond && warning_cond_1 && !warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_1(warning_fun_3(x))) - # } - # } else if ( - # error_cond && !warning_cond_1 && warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_2(warning_fun_3(x))) - # } - # } else if ( - # error_cond && !warning_cond_1 && 
!warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_3(x)) - # } - # } else if ( - # error_cond && !warning_cond_1 && warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_2(x)) - # } - # } else if ( - # error_cond && warning_cond_1 && !warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(warning_fun_1(x)) - # } - # } else if ( - # error_cond && !warning_cond_1 && !warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # expect_error(x) - # } - # } else if ( - # !error_cond && warning_cond_1 && warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_1(warning_fun_2(warning_fun_3(x))) - # } - # } else if ( - # !error_cond && warning_cond_1 && warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_1(warning_fun_2(x)) - # } - # } else if ( - # !error_cond && warning_cond_1 && !warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_1(warning_fun_3(x)) - # } - # } else if ( - # !error_cond && !warning_cond_1 && warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_2(warning_fun_3(x)) - # } - # } else if ( - # !error_cond && !warning_cond_1 && !warning_cond_2 && warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_3(x) - # } - # } else if ( - # !error_cond && !warning_cond_1 && warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_2(x) - # } - # } else if ( - # !error_cond && warning_cond_1 && !warning_cond_2 && !warning_cond_3 - # ) { - # test_fun <- function(x) { - # warning_fun_1(x) - # } - # } test_fun <- construct_chained_expectation_bart( error_cond = error_cond, warning_cond_1 = warning_cond_1, From 5e01ae632909ed6faddef94e1678e672a704d361 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 16:02:50 -0500 Subject: [PATCH 07/12] 
Updated grid of model checks and R / Python errors and warnings --- R/bcf.R | 38 ++- stochtree/bart.py | 10 + stochtree/bcf.py | 163 ++++++----- test/R/testthat/test-api-combinations.R | 27 +- test/python/conftest.py | 21 ++ test/python/test_api_combinations.py | 343 +++++++++++++++++++++++- 6 files changed, 511 insertions(+), 91 deletions(-) create mode 100644 test/python/conftest.py diff --git a/R/bcf.R b/R/bcf.R index 94393656..c634295b 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -942,21 +942,31 @@ bcf <- function( } has_basis_rfx <- TRUE num_basis_rfx <- ncol(rfx_basis_train) - } else if (rfx_model_spec == "intercept_only") { - rfx_basis_train <- matrix( - rep(1, nrow(X_train)), - nrow = nrow(X_train), - ncol = 1 - ) - has_basis_rfx <- TRUE - num_basis_rfx <- 1 } else if (rfx_model_spec == "intercept_plus_treatment") { - rfx_basis_train <- cbind( - rep(1, nrow(X_train)), - Z_train - ) - has_basis_rfx <- TRUE - num_basis_rfx <- 1 + ncol(Z_train) + if (has_multivariate_treatment) { + warning( + "Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables." 
+ ) + rfx_model_spec <- "intercept_only" + } + } + if (is.null(rfx_basis_train)) { + if (rfx_model_spec == "intercept_only") { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + } else { + rfx_basis_train <- cbind( + rep(1, nrow(X_train)), + Z_train + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + ncol(Z_train) + } } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) diff --git a/stochtree/bart.py b/stochtree/bart.py index 5d21c48e..832d9d00 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -888,6 +888,14 @@ def sample( raise ValueError( "We do not support heteroskedasticity with a probit link" ) + + # Runtime checks for variance forest + if self.include_variance_forest: + if sample_sigma2_global: + warnings.warn( + "Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model." 
+ ) + sample_sigma2_global = False # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes @@ -2333,6 +2341,7 @@ def compute_posterior_interval( raise ValueError( "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) + if rfx_basis is not None: if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") if rfx_basis.shape[0] != X.shape[0]: @@ -2440,6 +2449,7 @@ def sample_posterior_predictive( raise ValueError( "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) + if rfx_basis is not None: if not isinstance(rfx_basis, np.ndarray): raise ValueError("'rfx_basis' must be a numpy array") if rfx_basis.shape[0] != X.shape[0]: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ac98fdbb..dfe0610e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1320,25 +1320,38 @@ def sample( self.p_x = X_train_processed.shape[1] # Check whether treatment is binary - self.binary_treatment = np.unique(Z_train).size == 2 + self.binary_treatment = False + if not self.multivariate_treatment: + self.binary_treatment = np.unique(Z_train).size == 2 + if self.binary_treatment: + unique_treatments = np.squeeze(np.unique(Z_train)).tolist() + if not all(i in [0,1] for i in unique_treatments): + self.binary_treatment = False # Adaptive coding will be ignored for continuous / ordered categorical treatments self.adaptive_coding = adaptive_coding if adaptive_coding and not self.binary_treatment: - self.adaptive_coding = False - if adaptive_coding and self.multivariate_treatment: + warnings.warn( + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model" + ) self.adaptive_coding = False # Sampling sigma2_leaf_tau will be ignored for multivariate treatments if 
sample_sigma2_leaf_tau and self.multivariate_treatment: + warnings.warn( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." + ) sample_sigma2_leaf_tau = False # Check if user has provided propensities that are needed in the model if propensity_train is None and propensity_covariate != "none": + # Disable internal propensity model if treatment is multivariate if self.multivariate_treatment: - raise ValueError( - "Propensities must be provided (via propensity_train and / or propensity_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" + warnings.warn( + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" ) + propensity_covariate = "none" + self.internal_propensity_model = True else: self.bart_propensity_model = BARTModel() num_gfr_propensity = 10 @@ -1373,6 +1386,64 @@ def sample( self.internal_propensity_model = True else: self.internal_propensity_model = False + + # Runtime checks on RFX group ids + self.has_rfx = False + has_rfx_test = False + if rfx_group_ids_train is not None: + self.has_rfx = True + if rfx_group_ids_test is not None: + has_rfx_test = True + if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 + if self.has_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_plus_treatment": + if self.multivariate_treatment: + warnings.warn( + "Random effects `intercept_plus_treatment` 
specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables." + ) + self.rfx_model_spec = "intercept_only" + if rfx_basis_train is None: + if self.rfx_model_spec == "intercept_only": + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + else: + rfx_basis_train = np.concatenate( + (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 + ) + + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] + num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] + num_rfx_components = rfx_basis_train.shape[1] + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) + if has_rfx_test: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: + raise ValueError( + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_test is None: + rfx_basis_test = np.concatenate( + (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 + ) # Preliminary runtime checks for probit link if self.probit_outcome_model: @@ -1385,13 +1456,21 @@ def sample( raise ValueError( "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" ) + if sample_sigma2_global: + warnings.warn( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global = False if self.include_variance_forest: raise ValueError( "We do not support heteroskedasticity with a probit link" ) + + # Runtime checks for variance forest + if 
self.include_variance_forest: if sample_sigma2_global: warnings.warn( - "Global error variance will not be sampled with a probit link as it is fixed at 1" + "Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model." ) sample_sigma2_global = False @@ -1550,58 +1629,6 @@ def sample( if not b_forest: b_forest = 1.0 - # Runtime checks on RFX group ids - self.has_rfx = False - has_rfx_test = False - if rfx_group_ids_train is not None: - self.has_rfx = True - if rfx_group_ids_test is not None: - has_rfx_test = True - if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - - # Handle the rfx basis matrices - self.has_rfx_basis = False - self.num_rfx_basis = 0 - if self.has_rfx: - if self.rfx_model_spec == "custom": - if rfx_basis_train is None: - raise ValueError( - "rfx_basis_train must be provided when rfx_model_spec = 'custom'" - ) - elif self.rfx_model_spec == "intercept_only": - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - elif self.rfx_model_spec == "intercept_plus_treatment": - if rfx_basis_train is None: - rfx_basis_train = np.concatenate( - (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 - ) - self.has_rfx_basis = True - self.num_rfx_basis = rfx_basis_train.shape[1] - num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] - num_rfx_components = rfx_basis_train.shape[1] - if num_rfx_groups == 1: - warnings.warn( - "Only one group was provided for random effect sampling, so the random effects model is likely overkill" - ) - if has_rfx_test: - if self.rfx_model_spec == "custom": - if rfx_basis_test is None: - raise ValueError( - "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" - ) - elif self.rfx_model_spec == 
"intercept_only": - if rfx_basis_test is None: - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) - elif self.rfx_model_spec == "intercept_plus_treatment": - if rfx_basis_test is None: - rfx_basis_test = np.concatenate( - (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 - ) - # Set up random effects structures if self.has_rfx: # Prior parameters @@ -3570,14 +3597,18 @@ def sample_posterior_predictive( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("'rfx_basis' must have the same number of rows as 'X'") + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if rfx_basis is not None: + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior predictive samples bcf_preds = self.predict( diff --git a/test/R/testthat/test-api-combinations.R b/test/R/testthat/test-api-combinations.R index e74d87d9..67ab2ec3 100644 --- a/test/R/testthat/test-api-combinations.R +++ b/test/R/testthat/test-api-combinations.R @@ -308,7 +308,8 @@ construct_chained_expectation_bcf <- function( warning_cond_2, warning_cond_3, warning_cond_4, - warning_cond_5 + warning_cond_5, + warning_cond_6 ) { # Build the chain from innermost to outermost function_text <- "x" @@ -347,6 +348,13 @@ construct_chained_expectation_bcf <- function( ")" ) } + if (warning_cond_6) { + function_text <- paste0( + "warning_fun_6(", + 
function_text, + ")" + ) + } if (error_cond) { function_text <- paste0( "expect_error(", @@ -554,7 +562,7 @@ test_that("Quick check of interactions between components of BCF functionality", # 9. Adaptive coding: no, yes # # For each of the possible models this implies, - # we'd like to be sure that stochtree functions that operate on BART models + # we'd like to be sure that stochtree functions that operate on BCF models # will run without error. Since there are so many possible models implied by the # options above, this test is designed to be quick (small sample size, low dimensional data) # and we are only interested in ensuring no errors are triggered. @@ -694,11 +702,21 @@ test_that("Quick check of interactions between components of BCF functionality", fixed = TRUE ) } + warning_cond_6 <- (model_options_df$treatment_type[i] == "multivariate") && + (model_options_df$random_effects[i] == "intercept_plus_treatment") + warning_fun_6 <- function(x) { + expect_warning( + x, + "Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. 
Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables.", + fixed = TRUE + ) + } warning_cond <- (warning_cond_1 || warning_cond_2 || warning_cond_3 || warning_cond_4 || - warning_cond_5) + warning_cond_5 || + warning_cond_6) # Generate something like the below code but for all five warnings if (error_cond || warning_cond) { @@ -708,7 +726,8 @@ test_that("Quick check of interactions between components of BCF functionality", warning_cond_2 = warning_cond_2, warning_cond_3 = warning_cond_3, warning_cond_4 = warning_cond_4, - warning_cond_5 = warning_cond_5 + warning_cond_5 = warning_cond_5, + warning_cond_6 = warning_cond_6 ) } else { test_fun <- expect_no_error diff --git a/test/python/conftest.py b/test/python/conftest.py new file mode 100644 index 00000000..e446d0a1 --- /dev/null +++ b/test/python/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/test/python/test_api_combinations.py b/test/python/test_api_combinations.py index edbb0dd6..36b1ef5d 100644 --- a/test/python/test_api_combinations.py +++ b/test/python/test_api_combinations.py @@ -3,7 +3,7 @@ import numpy as np from sklearn.model_selection import train_test_split -from stochtree import BARTModel +from stochtree import BARTModel, BCFModel def run_bart_factorial( @@ -116,7 +116,131 @@ def run_bart_factorial( ) +def run_bcf_factorial( + bcf_data_train, + bcf_data_test, + treatment_type="binary", + 
variance_forest=False, + random_effects="none", + sampling_global_error_scale=False, + sampling_mu_leaf_scale=False, + sampling_tau_leaf_scale=False, + outcome_type="continuous", + num_chains=1, + adaptive_coding=False, + include_propensity=False, +): + # Unpack BART training data + y = bcf_data_train["y"] + X = bcf_data_train["X"] + Z = bcf_data_train["Z"] + if include_propensity: + propensity = bcf_data_train["propensity"] + else: + propensity = None + if random_effects != "none": + rfx_group_ids = bcf_data_train["rfx_group_ids"] + else: + rfx_group_ids = None + if random_effects == "custom": + rfx_basis = bcf_data_train["rfx_basis"] + else: + rfx_basis = None + + # Set BART model parameters + general_params = { + "num_chains": num_chains, + "sample_sigma2_global": sampling_global_error_scale, + "probit_outcome_model": outcome_type == "binary", + "adaptive_coding": adaptive_coding, + } + mu_forest_params = {"sample_sigma2_leaf": sampling_mu_leaf_scale} + tau_forest_params = {"sample_sigma2_leaf": sampling_tau_leaf_scale} + variance_forest_params = {"num_trees": 20 if variance_forest else 0} + rfx_params = { + "model_spec": "custom" if random_effects == "none" else random_effects + } + + # Sample BART model + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + y_train=y, + Z_train=Z, + propensity_train=propensity, + rfx_group_ids_train=rfx_group_ids, + rfx_basis_train=rfx_basis, + general_params=general_params, + prognostic_forest_params=mu_forest_params, + treatment_effect_forest_params=tau_forest_params, + variance_forest_params=variance_forest_params, + random_effects_params=rfx_params, + ) + + # Unpack test set data + y_test = bcf_data_test["y"] + X_test = bcf_data_test["X"] + Z_test = bcf_data_test["Z"] + if include_propensity: + propensity_test = bcf_data_test["propensity"] + else: + propensity_test = None + if random_effects != "none": + rfx_group_ids_test = bcf_data_test["rfx_group_ids"] + else: + rfx_group_ids_test = None + if random_effects == 
"custom": + rfx_basis_test = bcf_data_test["rfx_basis"] + else: + rfx_basis_test = None + + # Predict on test set + mean_preds = bcf_model.predict( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + posterior_preds = bcf_model.predict( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="probability" if outcome_type == "binary" else "linear", + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + num_draws_per_sample=5, + ) + + class TestAPICombinations: + @pytest.mark.slow def test_bart_api_combinations(self): # RNG random_seed = 101 @@ -160,7 +284,7 @@ def test_bart_api_combinations(self): # Split into test and train sets test_set_pct = 0.5 sample_inds = np.arange(n) - train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) X_train = X[train_inds, :] X_test = X[test_inds, :] leaf_basis_train = leaf_basis[train_inds, :] @@ -192,7 +316,6 @@ def test_bart_api_combinations(self): num_chains_options, ) for i, options in enumerate(model_options_iter): - print(f"i = {i}, options = {options}") # Unpack BART train and test data bart_data_train = {} bart_data_test = {} @@ -230,11 +353,9 @@ def test_bart_api_combinations(self): # Determine 
whether this combination should throw an error, raise a warning, or run as intended error_cond = (options[1]) and (options[5] == "binary") warning_cond_1 = (options[4]) and (options[0] == "multivariate") - warning_message_1 = "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." warning_cond_2 = (options[3]) and (options[5] == "binary") - warning_message_2 = "Global error variance will not be sampled with a probit link as it is fixed at 1" - warning_cond = warning_cond_1 or warning_cond_2 - print(f"error_cond = {error_cond}, warning_cond = {warning_cond}") + warning_cond_3 = (options[3]) and (options[1]) + warning_cond = warning_cond_1 or warning_cond_2 or warning_cond_3 if error_cond and warning_cond: with pytest.raises(ValueError) as excinfo: with pytest.warns(UserWarning) as warninfo: @@ -287,3 +408,211 @@ def test_bart_api_combinations(self): outcome_type=options[5], num_chains=options[6], ) + + @pytest.mark.slow + def test_bcf_api_combinations(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Overall, we have nine components of a BCF sampler which can be on / off or set to different levels: + # 1. treatment: binary, univariate continuous, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only`, `intercept_plus_treatment` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on prognostic forest: no, yes + # 6. Sampling leaf scale on treatment forest: no, yes (only available for univariate treatment) + # 7. Outcome type: continuous (identity link), binary (probit link) + # 8. Number of chains: 1, >1 + # 9. Adaptive coding: no, yes + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BCF models + # will run without error. 
Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. + + # Generate data with random effects + n = 50 + p = 3 + num_rfx_groups = 3 + num_rfx_basis = 2 + X = rng.uniform(0, 1, (n, p)) + binary_treatment = rng.binomial(1, 0.5, n) + continuous_treatment = rng.uniform(0, 1, n) + multivariate_treatment = np.column_stack( + (binary_treatment, continuous_treatment) + ) + propensity = rng.uniform(0, 1, n) + group_ids = rng.choice(num_rfx_groups, size=n) + rfx_basis = rng.uniform(0, 1, (n, num_rfx_basis)) + rfx_coefs = rng.uniform(0, 1, (num_rfx_groups, num_rfx_basis)) + prognostic_term = np.sin(X[:, 0]) + binary_treatment_effect = X[:, 1] + continuous_treatment_effect = X[:, 2] + rfx_term = np.sum(rfx_coefs[group_ids - 1, :] * rfx_basis, axis=1) + E_y = (prognostic_term + + binary_treatment_effect * binary_treatment + + continuous_treatment_effect * continuous_treatment + + rfx_term) + E_y = E_y - np.mean(E_y) + epsilon = rng.normal(0, 1, n) + y_continuous = E_y + epsilon + y_binary = (y_continuous > 0).astype(int) + + # Split into test and train sets + test_set_pct = 0.5 + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + binary_treatment_test = binary_treatment[test_inds] + binary_treatment_train = binary_treatment[train_inds] + propensity_test = propensity[test_inds] + propensity_train = propensity[train_inds] + continuous_treatment_test = continuous_treatment[test_inds] + continuous_treatment_train = continuous_treatment[train_inds] + multivariate_treatment_test = multivariate_treatment[test_inds, ] + multivariate_treatment_train = multivariate_treatment[train_inds, ] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + group_ids_train = 
group_ids[train_inds] + group_ids_test = group_ids[test_inds] + y_continuous_train = y_continuous[train_inds] + y_continuous_test = y_continuous[test_inds] + y_binary_train = y_binary[train_inds] + y_binary_test = y_binary[test_inds] + + # Run the power set of models + treatment_options = ["binary", "univariate_continuous", "multivariate"] + variance_forest_options = [False, True] + random_effects_options = ["none", "custom", "intercept_only", "intercept_plus_treatment"] + sampling_global_error_scale_options = [False, True] + sampling_mu_leaf_scale_options = [False, True] + sampling_tau_leaf_scale_options = [False, True] + outcome_type_options = ["continuous", "binary"] + num_chains_options = [1, 3] + adaptive_coding_options = [False, True] + include_propensity_options = [False, True] + model_options_iter = itertools.product( + treatment_options, + variance_forest_options, + random_effects_options, + sampling_global_error_scale_options, + sampling_mu_leaf_scale_options, + sampling_tau_leaf_scale_options, + outcome_type_options, + num_chains_options, + adaptive_coding_options, + include_propensity_options + ) + for i, options in enumerate(model_options_iter): + # Unpack BCF train and test data (outcome type is options[6]; options[5] is the tau leaf scale flag) + bcf_data_train = {} + bcf_data_test = {} + bcf_data_train["X"] = X_train + bcf_data_test["X"] = X_test + bcf_data_train["propensity"] = propensity_train + bcf_data_test["propensity"] = propensity_test + if options[6] == "continuous": + bcf_data_train["y"] = y_continuous_train + bcf_data_test["y"] = y_continuous_test + else: + bcf_data_train["y"] = y_binary_train + bcf_data_test["y"] = y_binary_test + if options[0] == "binary": + bcf_data_train["Z"] = binary_treatment_train + bcf_data_test["Z"] = binary_treatment_test + elif options[0] == "univariate_continuous": + bcf_data_train["Z"] = continuous_treatment_train + bcf_data_test["Z"] = continuous_treatment_test + else: + bcf_data_train["Z"] = multivariate_treatment_train + bcf_data_test["Z"] = multivariate_treatment_test + if
options[2] != "none": + bcf_data_train["rfx_group_ids"] = group_ids_train + bcf_data_test["rfx_group_ids"] = group_ids_test + else: + bcf_data_train["rfx_group_ids"] = None + bcf_data_test["rfx_group_ids"] = None + if options[2] == "custom": + bcf_data_train["rfx_basis"] = rfx_basis_train + bcf_data_test["rfx_basis"] = rfx_basis_test + else: + bcf_data_train["rfx_basis"] = None + bcf_data_test["rfx_basis"] = None + + # Determine whether this combination should throw an error, raise a warning, or run as intended + error_cond = (options[1]) and (options[6] == "binary") + warning_cond_1 = (options[5]) and (options[0] == "multivariate") + warning_cond_2 = (options[3]) and (options[6] == "binary") + warning_cond_3 = (options[3]) and (options[1]) + warning_cond_4 = (options[8]) and (options[0] != "binary") + warning_cond_5 = (not options[9]) and (options[0] == "multivariate") + warning_cond_6 = (options[2] == "intercept_plus_treatment") and (options[0] == "multivariate") + warning_cond = warning_cond_1 or warning_cond_2 or warning_cond_3 or warning_cond_4 or warning_cond_5 or warning_cond_6 + print(f"error_cond: {error_cond}, warning_cond_1: {warning_cond_1}, warning_cond_2: {warning_cond_2}, warning_cond_3: {warning_cond_3}, warning_cond_4: {warning_cond_4}, warning_cond_5: {warning_cond_5}, warning_cond_6: {warning_cond_6}") + if error_cond and warning_cond: + with pytest.raises(ValueError) as excinfo: + with pytest.warns(UserWarning) as warninfo: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + elif error_cond and not warning_cond: + with pytest.raises(ValueError) as excinfo: + run_bcf_factorial( + 
bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + elif not error_cond and warning_cond: + with pytest.warns(UserWarning) as warninfo: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + else: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) From 855bb0eb751d0079f932513413fbeb065fa1ac56 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 16:12:23 -0500 Subject: [PATCH 08/12] Update R API combination tests so they can be enabled from GHA with a special environment variable --- test/R/testthat/test-api-combinations.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/R/testthat/test-api-combinations.R b/test/R/testthat/test-api-combinations.R index 67ab2ec3..d2c2f2ce 100644 --- a/test/R/testthat/test-api-combinations.R +++ b/test/R/testthat/test-api-combinations.R @@ -370,7 +370,11 @@ construct_chained_expectation_bcf <- function( test_that("Quick check of interactions between 
components of BART functionality", { skip_on_cran() - skip_on_ci() + # Adapted from https://github.com/r-lib/testthat/blob/main/R/skip.R; negated so the test runs only when RUN_SLOW_TESTS is true + skip_if( + !isTRUE(as.logical(Sys.getenv("RUN_SLOW_TESTS", "false"))), + "skipping slow tests" + ) # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: # 1. Leaf regression: none, univariate, multivariate @@ -548,7 +552,11 @@ test_that("Quick check of interactions between components of BART functionality" test_that("Quick check of interactions between components of BCF functionality", { skip_on_cran() - skip_on_ci() + # Adapted from https://github.com/r-lib/testthat/blob/main/R/skip.R; negated so the test runs only when RUN_SLOW_TESTS is true + skip_if( + !isTRUE(as.logical(Sys.getenv("RUN_SLOW_TESTS", "false"))), + "skipping slow tests" + ) # Overall, we have nine components of a BCF sampler which can be on / off or set to different levels: # 1. treatment: binary, univariate continuous, multivariate From e56b7963ef7a310632b318b961c937630b3cfb1b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 16:22:48 -0500 Subject: [PATCH 09/12] Updated python tests --- test/python/test_bart.py | 70 ++++++++++++++++++++-------------------- test/python/test_bcf.py | 4 +-- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index b182524b..7f49f5b4 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -429,13 +429,13 @@ def conditional_stddev(X): sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train, ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc :
(2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_univariate_leaf_regression_heteroskedastic(self): # RNG @@ -554,13 +554,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_multivariate_leaf_regression_heteroskedastic(self): # RNG @@ -679,13 +679,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_constant_leaf_heteroskedastic_rfx(self): # RNG @@ -836,13 +836,13 @@ def rfx_term(group_labels, basis): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # 
np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) np.testing.assert_allclose( rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 @@ -1010,13 +1010,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) np.testing.assert_allclose( rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index c5a1446f..bbfd55d5 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -645,7 +645,7 @@ def test_multivariate_bcf(self): assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF with test set and without propensity score - with pytest.raises(ValueError): + with pytest.warns(UserWarning): bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} bcf_model.sample( @@ -661,7 +661,7 @@ def test_multivariate_bcf(self): ) # Run BCF without test set and without propensity score - with pytest.raises(ValueError): + with pytest.warns(UserWarning): bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} bcf_model.sample( From 
87fb950e188825ec3eec9258bd572ad9ec7f42c6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 22:18:04 -0500 Subject: [PATCH 10/12] Added manually-dispatched workflow for slow-running unit tests --- .github/workflows/python-test.yml | 4 +- .github/workflows/slow-api-test.yml | 72 +++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/slow-api-test.yml diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 719160a6..6c9fff91 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -25,10 +25,10 @@ jobs: with: submodules: 'recursive' - - name: Setup Python 3.10 + - name: Setup Python 3.12 uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.12" cache: "pip" - name: Set up openmp (macos) diff --git a/.github/workflows/slow-api-test.yml b/.github/workflows/slow-api-test.yml new file mode 100644 index 00000000..1cb686fa --- /dev/null +++ b/.github/workflows/slow-api-test.yml @@ -0,0 +1,72 @@ +name: Unit Tests and Slow Running API Integration Tests for R and Python + +on: + workflow_dispatch: + +jobs: + testing: + name: test-slow-api-combinations + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + + steps: + - name: Prevent conversion of line endings on Windows + if: startsWith(matrix.os, 'windows') + shell: pwsh + run: git config --global core.autocrlf false + + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Setup Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" + + - name: Set up openmp (macos) + # Set up openMP on MacOS since it doesn't ship with the apple clang compiler suite + if: matrix.os == 'macos-latest' + run: | + brew install libomp + + - name: Install Package with Relevant Dependencies + run: | + pip install --upgrade pip + 
pip install -r requirements.txt + pip install . + + - name: Run Pytest with Slow Running API Tests Enabled + run: | + pytest --runslow test/python + + - name: Setup Pandoc for R + uses: r-lib/actions/setup-pandoc@v2 + + - name: Setup R + uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + + - name: Setup R Package Dependencies + uses: r-lib/actions/setup-r-dependencies@v2 + with: + extra-packages: any::testthat, any::decor, any::rcmdcheck + needs: check + + - name: Create a CRAN-ready version of the R package + run: | + Rscript cran-bootstrap.R 0 0 1 + + - name: Run CRAN Checks with Slow Running API Tests Enabled + uses: r-lib/actions/check-r-package@v2 + env: + RUN_SLOW_TESTS: true + with: + working-directory: 'stochtree_cran' From fb088f988fd85cd307e65d2f35af0d80b36fdeb9 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 10 Dec 2025 22:20:31 -0500 Subject: [PATCH 11/12] Decrement python version --- .github/workflows/python-test.yml | 4 ++-- .github/workflows/slow-api-test.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 6c9fff91..719160a6 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -25,10 +25,10 @@ jobs: with: submodules: 'recursive' - - name: Setup Python 3.12 + - name: Setup Python 3.10 uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.10" cache: "pip" - name: Set up openmp (macos) diff --git a/.github/workflows/slow-api-test.yml b/.github/workflows/slow-api-test.yml index 1cb686fa..308a5bf8 100644 --- a/.github/workflows/slow-api-test.yml +++ b/.github/workflows/slow-api-test.yml @@ -24,10 +24,10 @@ jobs: with: submodules: 'recursive' - - name: Setup Python 3.12 + - name: Setup Python 3.10 uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.10" cache: "pip" - name: Set up openmp (macos) From fe5d1a3dec73fc24181c3e4ea7e87365fe8880ab Mon 
Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 11 Dec 2025 10:36:41 -0500 Subject: [PATCH 12/12] Update changelog --- CHANGELOG.md | 6 +++++- NEWS.md | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf6c0ecd..af065869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,11 @@ ## Bug Fixes -* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248)) +* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248)) + +## Other Changes + +* Encode expectations about which combinations of BART / BCF features work together and ensure that unsupported combinations raise a warning or error ([#250](https://github.com/StochasticTree/stochtree/pull/250)) # stochtree 0.2.0 diff --git a/NEWS.md b/NEWS.md index 676ed749..644e6f72 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,7 +2,11 @@ ## Bug Fixes -* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248)) +* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248)) + +## Other Changes + +* Encode expectations about which combinations of BART / BCF features work together and ensure that unsupported combinations raise a warning or error ([#250](https://github.com/StochasticTree/stochtree/pull/250)) # stochtree 0.2.0