From d67078bc9283f90c93842c6dece7a54aedc74d4d Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Fri, 21 Nov 2025 01:17:41 -0600
Subject: [PATCH 1/2] Added checks for variables already being treated as
 categorical and also explicitly flagging binary variables

---
 R/bart.R | 77 ++++++++++++++++++++++++++++++++++++++-----------------
 R/bcf.R  | 78 +++++++++++++++++++++++++++++++++++++++-----------------
 2 files changed, 109 insertions(+), 46 deletions(-)

diff --git a/R/bart.R b/R/bart.R
index d5acd6af..fe90bb7c 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -423,37 +423,56 @@ bart <- function(
     floor(num_values / cutpoint_grid_size),
     1
   )
+  x_is_df <- is.data.frame(X_train)
   covs_warning_1 <- NULL
   covs_warning_2 <- NULL
   covs_warning_3 <- NULL
+  covs_warning_4 <- NULL
   for (i in 1:num_cov_orig) {
-    # Determine the number of unique values
-    num_unique_values <- length(unique(X_train[, i]))
-
-    # Determine a "name" for the covariate
-    cov_name <- ifelse(
-      is.null(colnames(X_train)),
-      paste0("X", i),
-      colnames(X_train)[i]
-    )
-
-    # Check for a small relative number of unique values
-    unique_full_ratio <- num_unique_values / num_values
-    if (unique_full_ratio < 0.2) {
-      covs_warning_1 <- c(covs_warning_1, cov_name)
+    # Skip check for variables that are already treated as categorical
+    x_numeric <- TRUE
+    if (x_is_df) {
+      if (is.factor(X_train[, i])) {
+        x_numeric <- FALSE
+      }
     }
+    if (x_numeric) {
+      # Determine the number of unique values
+      num_unique_values <- length(unique(X_train[, i]))
+
+      # Determine a "name" for the covariate
+      cov_name <- ifelse(
+        is.null(colnames(X_train)),
+        paste0("X", i),
+        colnames(X_train)[i]
+      )
 
-    # Check for a small absolute number of unique values
-    if (num_values > 100) {
-      if (num_unique_values < 20) {
-        covs_warning_2 <- c(covs_warning_2, cov_name)
+      # Check for a small relative number of unique values
+      unique_full_ratio <- num_unique_values / num_values
+      if (unique_full_ratio < 0.2) {
+        covs_warning_1 <- c(covs_warning_1, cov_name)
+      }
+
+      # Check for a small absolute number of unique values
+      if (num_values > 100) {
+        if (num_unique_values < 20) {
+          covs_warning_2 <- c(covs_warning_2, cov_name)
+        }
+      }
+
+      # Check for a large number of duplicates of any individual value
+      x_j_hist <- table(X_train[, i])
+      if (any(x_j_hist > 2 * max_grid_size)) {
+        covs_warning_3 <- c(covs_warning_3, cov_name)
       }
-    }
 
-    # Check for a large number of duplicates of any individual value
-    x_j_hist <- table(X_train[, i])
-    if (any(x_j_hist > 2 * max_grid_size)) {
-      covs_warning_3 <- c(covs_warning_3, cov_name)
+      # Check for binary variables
+      if (num_unique_values == 2) {
+        already_flagged <- (num_values > 100) && (num_unique_values < 20)
+        if (!already_flagged) {
+          covs_warning_4 <- c(covs_warning_4, cov_name)
+        }
+      }
     }
   }
@@ -494,6 +513,18 @@ bart <- function(
         )
       )
     }
+
+    if (!is.null(covs_warning_4)) {
+      warning(
+        paste0(
+          "Covariates ",
+          paste(covs_warning_4, collapse = ", "),
+          " appear to be binary but are currently treated by stochtree as continuous. ",
+          "This may cause issues for the grow-from-root (GFR) algorithm. ",
+          "Consider converting binary variables to ordered factors (e.g., `factor(..., ordered = TRUE)`)."
+        )
+      )
+    }
   }
 
   # Standardize the keep variable lists to numeric indices
diff --git a/R/bcf.R b/R/bcf.R
index 5a80d5ec..5d360de3 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -527,37 +527,57 @@ bcf <- function(
     floor(num_values / cutpoint_grid_size),
     1
   )
+  x_is_df <- is.data.frame(X_train)
   covs_warning_1 <- NULL
   covs_warning_2 <- NULL
   covs_warning_3 <- NULL
+  covs_warning_4 <- NULL
   for (i in 1:num_cov_orig) {
-    # Determine the number of unique values
-    num_unique_values <- length(unique(X_train[, i]))
-
-    # Determine a "name" for the covariate
-    cov_name <- ifelse(
-      is.null(colnames(X_train)),
-      paste0("X", i),
-      colnames(X_train)[i]
-    )
-
-    # Check for a small relative number of unique values
-    unique_full_ratio <- num_unique_values / num_values
-    if (unique_full_ratio < 0.2) {
-      covs_warning_1 <- c(covs_warning_1, cov_name)
+    # Skip check for variables that are already treated as categorical
+    x_numeric <- TRUE
+    if (x_is_df) {
+      if (is.factor(X_train[, i])) {
+        x_numeric <- FALSE
+      }
     }
 
-    # Check for a small absolute number of unique values
-    if (num_values > 100) {
-      if (num_unique_values < 20) {
-        covs_warning_2 <- c(covs_warning_2, cov_name)
+    if (x_numeric) {
+      # Determine the number of unique values
+      num_unique_values <- length(unique(X_train[, i]))
+
+      # Determine a "name" for the covariate
+      cov_name <- ifelse(
+        is.null(colnames(X_train)),
+        paste0("X", i),
+        colnames(X_train)[i]
+      )
+
+      # Check for a small relative number of unique values
+      unique_full_ratio <- num_unique_values / num_values
+      if (unique_full_ratio < 0.2) {
+        covs_warning_1 <- c(covs_warning_1, cov_name)
+      }
+
+      # Check for a small absolute number of unique values
+      if (num_values > 100) {
+        if (num_unique_values < 20) {
+          covs_warning_2 <- c(covs_warning_2, cov_name)
+        }
+      }
+
+      # Check for a large number of duplicates of any individual value
+      x_j_hist <- table(X_train[, i])
+      if (any(x_j_hist > 2 * max_grid_size)) {
+        covs_warning_3 <- c(covs_warning_3, cov_name)
       }
-    }
 
-    # Check for a large number of duplicates of any individual value
-    x_j_hist <- table(X_train[, i])
-    if (any(x_j_hist > 2 * max_grid_size)) {
-      covs_warning_3 <- c(covs_warning_3, cov_name)
+      # Check for binary variables
+      if (num_unique_values == 2) {
+        already_flagged <- (num_values > 100) && (num_unique_values < 20)
+        if (!already_flagged) {
+          covs_warning_4 <- c(covs_warning_4, cov_name)
+        }
+      }
     }
   }
@@ -598,6 +618,18 @@ bcf <- function(
         )
       )
     }
+
+    if (!is.null(covs_warning_4)) {
+      warning(
+        paste0(
+          "Covariates ",
+          paste(covs_warning_4, collapse = ", "),
+          " appear to be binary but are currently treated by stochtree as continuous. ",
+          "This may cause issues for the grow-from-root (GFR) algorithm. ",
+          "Consider converting binary variables to ordered factors (e.g., `factor(..., ordered = TRUE)`)."
+        )
+      )
+    }
   }
 
   # Check delta_max is valid

From 3f0130789e79fd008cf4d5465a540b87000b5a76 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Fri, 21 Nov 2025 01:39:54 -0600
Subject: [PATCH 2/2] Reflected this change through the Python interface as
 well

---
 R/bart.R          |   5 +-
 R/bcf.R           |   5 +-
 stochtree/bart.py |  58 +++++++----
 stochtree/bcf.py  | 244 ++++++++++++++++++++++++++++------------------
 4 files changed, 189 insertions(+), 123 deletions(-)

diff --git a/R/bart.R b/R/bart.R
index fe90bb7c..177ed961 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -468,10 +468,7 @@ bart <- function(
 
       # Check for binary variables
       if (num_unique_values == 2) {
-        already_flagged <- (num_values > 100) && (num_unique_values < 20)
-        if (!already_flagged) {
-          covs_warning_4 <- c(covs_warning_4, cov_name)
-        }
+        covs_warning_4 <- c(covs_warning_4, cov_name)
       }
     }
   }

diff --git a/R/bcf.R b/R/bcf.R
index 5d360de3..d6e11ad0 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -573,10 +573,7 @@ bcf <- function(
 
      # Check for binary variables
      if (num_unique_values == 2) {
-        already_flagged <- (num_values > 100) && (num_unique_values < 20)
-        if (!already_flagged) {
-          covs_warning_4 <- c(covs_warning_4, cov_name)
-        }
+        covs_warning_4 <- c(covs_warning_4, cov_name)
      }
    }
  }

diff --git a/stochtree/bart.py b/stochtree/bart.py
index 1d5f38ce..b7bf1c88 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -456,32 +456,45 @@ def sample(
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
             max_grid_size = floor(num_values / cutpoint_grid_size)
+            x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []
             covs_warning_3 = []
+            covs_warning_4 = []
             for i in range(num_cov_orig):
-                # Determine the number of unique covariate values and a name for the covariate
-                if isinstance(X_train, np.ndarray):
-                    x_j_hist = np.unique_counts(X_train[:, i]).counts
-                    cov_name = f"X{i + 1}"
-                else:
-                    x_j_hist = (X_train.iloc[:, i]).value_counts()
-                    cov_name = X_train.columns[i]
+                # Skip check for variables that are already treated as categorical
+                x_numeric = True
+                if x_is_df:
+                    if isinstance(X_train.iloc[:, i].dtype, pd.CategoricalDtype):
+                        x_numeric = False
+
+                if x_numeric:
+                    # Determine the number of unique covariate values and a name for the covariate
+                    if isinstance(X_train, np.ndarray):
+                        x_j_hist = np.unique_counts(X_train[:, i]).counts
+                        cov_name = f"X{i + 1}"
+                    else:
+                        x_j_hist = (X_train.iloc[:, i]).value_counts()
+                        cov_name = X_train.columns[i]
 
-                # Check for a small relative number of unique values
-                num_unique_values = len(x_j_hist)
-                unique_full_ratio = num_unique_values / num_values
-                if unique_full_ratio < 0.2:
-                    covs_warning_1.append(cov_name)
+                    # Check for a small relative number of unique values
+                    num_unique_values = len(x_j_hist)
+                    unique_full_ratio = num_unique_values / num_values
+                    if unique_full_ratio < 0.2:
+                        covs_warning_1.append(cov_name)
 
-                # Check for a small absolute number of unique values
-                if num_values > 100:
-                    if num_unique_values < 20:
-                        covs_warning_2.append(cov_name)
+                    # Check for a small absolute number of unique values
+                    if num_values > 100:
+                        if num_unique_values < 20:
+                            covs_warning_2.append(cov_name)
 
-                # Check for a large number of duplicates of any individual value
-                if np.any(x_j_hist > 2 * max_grid_size):
-                    covs_warning_3.append(cov_name)
+                    # Check for a large number of duplicates of any individual value
+                    if np.any(x_j_hist > 2 * max_grid_size):
+                        covs_warning_3.append(cov_name)
+
+                    # Check for binary variables
+                    if num_unique_values == 2:
+                        covs_warning_4.append(cov_name)
 
             if covs_warning_1:
                 warnings.warn(
@@ -505,6 +518,13 @@ def sample(
                     "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
                 )
 
+            if covs_warning_4:
+                warnings.warn(
+                    f"Covariates {', '.join(covs_warning_4)} appear to be binary but are currently treated by stochtree as continuous. "
+                    "This may cause issues for the grow-from-root (GFR) algorithm. "
+                    "Consider converting binary variables to ordered categoricals (e.g., `pd.Categorical(..., ordered=True)`)."
+                )
+
         # Variable weight preprocessing (and initialization if necessary)
         p = X_train.shape[1]
         if variable_weights is None:

diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index 2e1b2ad2..ac98fdbb 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -222,7 +222,7 @@ def sample(
             * `group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
             * `variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
             * `variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
-        
+
         previous_model_json : str, optional
             JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
         previous_model_warmstart_sample_num : int, optional
@@ -436,7 +436,7 @@ def sample(
         # Override keep_gfr if there are no MCMC samples
         if num_mcmc == 0:
             keep_gfr = True
-        
+
         # Check that num_chains >= 1
         if not isinstance(num_chains, Integral) or num_chains < 1:
             raise ValueError("num_chains must be an integer greater than 0")
@@ -509,7 +509,9 @@ def sample(
             else:
                 previous_leaf_var_mu_samples = None
             if previous_bcf_model.sample_sigma2_leaf_tau:
-                previous_leaf_var_tau_samples = previous_bcf_model.leaf_scale_tau_samples
+                previous_leaf_var_tau_samples = (
+                    previous_bcf_model.leaf_scale_tau_samples
+                )
             else:
                 previous_leaf_var_tau_samples = None
             if previous_bcf_model.adaptive_coding:
@@ -641,38 +643,53 @@ def sample(
             raise ValueError("X_test and Z_test must have the same number of rows")
         if X_test is not None and propensity_test is not None:
             if X_test.shape[0] != propensity_test.shape[0]:
-                raise ValueError("X_test and propensity_test must have the same number of rows")
+                raise ValueError(
+                    "X_test and propensity_test must have the same number of rows"
+                )
 
         # Raise a warning if the data have ties and only GFR is being run
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
             max_grid_size = floor(num_values / cutpoint_grid_size)
+            x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []
             covs_warning_3 = []
+            covs_warning_4 = []
             for i in range(num_cov_orig):
-                # Determine the number of unique covariate values and a name for the covariate
-                if isinstance(X_train, np.ndarray):
-                    x_j_hist = np.unique_counts(X_train[:, i]).counts
-                    cov_name = f"X{i + 1}"
-                else:
-                    x_j_hist = (X_train.iloc[:, i]).value_counts()
-                    cov_name = X_train.columns[i]
+                # Skip check for variables that are already treated as categorical
+                x_numeric = True
+                if x_is_df:
+                    if isinstance(X_train.iloc[:, i].dtype, pd.CategoricalDtype):
+                        x_numeric = False
+
+                if x_numeric:
+                    # Determine the number of unique covariate values and a name for the covariate
+                    if isinstance(X_train, np.ndarray):
+                        x_j_hist = np.unique_counts(X_train[:, i]).counts
+                        cov_name = f"X{i + 1}"
+                    else:
+                        x_j_hist = (X_train.iloc[:, i]).value_counts()
+                        cov_name = X_train.columns[i]
+
+                    # Check for a small relative number of unique values
+                    num_unique_values = len(x_j_hist)
+                    unique_full_ratio = num_unique_values / num_values
+                    if unique_full_ratio < 0.2:
+                        covs_warning_1.append(cov_name)
 
-                # Check for a small relative number of unique values
-                num_unique_values = len(x_j_hist)
-                unique_full_ratio = num_unique_values / num_values
-                if unique_full_ratio < 0.2:
-                    covs_warning_1.append(cov_name)
+                    # Check for a small absolute number of unique values
+                    if num_values > 100:
+                        if num_unique_values < 20:
+                            covs_warning_2.append(cov_name)
 
-                # Check for a small absolute number of unique values
-                if num_values > 100:
-                    if num_unique_values < 20:
-                        covs_warning_2.append(cov_name)
+                    # Check for a large number of duplicates of any individual value
+                    if np.any(x_j_hist > 2 * max_grid_size):
+                        covs_warning_3.append(cov_name)
 
-                # Check for a large number of duplicates of any individual value
-                if np.any(x_j_hist > 2 * max_grid_size):
-                    covs_warning_3.append(cov_name)
+                    # Check for binary variables
+                    if num_unique_values == 2:
+                        covs_warning_4.append(cov_name)
 
             if covs_warning_1:
                 warnings.warn(
@@ -696,6 +713,13 @@ def sample(
                     "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
                 )
 
+            if covs_warning_4:
+                warnings.warn(
+                    f"Covariates {', '.join(covs_warning_4)} appear to be binary but are currently treated by stochtree as continuous. "
+                    "This may cause issues for the grow-from-root (GFR) algorithm. "
+                    "Consider converting binary variables to ordered categoricals (e.g., `pd.Categorical(..., ordered=True)`)."
+                )
+
         # Prognostic model details
         leaf_dimension_mu = 1
         leaf_model_mu = 0
@@ -1671,7 +1695,8 @@ def sample(
                 X_test_processed = np.c_[X_test_processed, propensity_test]
         if propensity_covariate == "prognostic":
             variable_weights_mu = np.append(
-                variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1])
+                variable_weights_mu,
+                np.repeat(1 / num_cov_orig, propensity_train.shape[1]),
             )
             variable_weights_tau = np.append(
                 variable_weights_tau, np.repeat(0.0, propensity_train.shape[1])
@@ -1681,14 +1706,17 @@ def sample(
                 variable_weights_mu, np.repeat(0.0, propensity_train.shape[1])
             )
             variable_weights_tau = np.append(
-                variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1])
+                variable_weights_tau,
+                np.repeat(1 / num_cov_orig, propensity_train.shape[1]),
            )
         elif propensity_covariate == "both":
             variable_weights_mu = np.append(
-                variable_weights_mu, np.repeat(1 / num_cov_orig, propensity_train.shape[1])
+                variable_weights_mu,
+                np.repeat(1 / num_cov_orig, propensity_train.shape[1]),
             )
             variable_weights_tau = np.append(
-                variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1])
+                variable_weights_tau,
+                np.repeat(1 / num_cov_orig, propensity_train.shape[1]),
             )
         # For now, propensities are not included in the variance forest
         variable_weights_variance = np.append(
@@ -2141,22 +2169,18 @@ def sample(
                         global_model_config.update_global_error_variance(current_sigma2)
                     # Reset mu forest leaf scale
                     if sample_sigma2_leaf_mu:
-                        leaf_scale_double_mu = self.leaf_scale_mu_samples[
-                            forest_ind
-                        ]
-                        current_leaf_scale_mu[0, 0] = leaf_scale_double_mu
-                        forest_model_config_mu.update_leaf_model_scale(
-                            current_leaf_scale_mu
-                        )
+                        leaf_scale_double_mu = self.leaf_scale_mu_samples[forest_ind]
+                        current_leaf_scale_mu[0, 0] = leaf_scale_double_mu
+                        forest_model_config_mu.update_leaf_model_scale(
+                            current_leaf_scale_mu
+                        )
                     # Reset tau forest leaf scale
                     if sample_sigma2_leaf_tau:
-                        leaf_scale_double_tau = self.leaf_scale_tau_samples[
-                            forest_ind
-                        ]
-                        current_leaf_scale_tau[0, 0] = leaf_scale_double_tau
-                        forest_model_config_tau.update_leaf_model_scale(
-                            current_leaf_scale_tau
-                        )
+                        leaf_scale_double_tau = self.leaf_scale_tau_samples[forest_ind]
+                        current_leaf_scale_tau[0, 0] = leaf_scale_double_tau
+                        forest_model_config_tau.update_leaf_model_scale(
+                            current_leaf_scale_tau
+                        )
                     # Reset adaptive coding parameters
                     if self.adaptive_coding:
                         if self.b0_samples is not None:
@@ -2181,12 +2205,25 @@ def sample(
                         )
                     # Reset random effects terms
                     if self.has_rfx:
-                        rfx_model.reset(self.rfx_container, forest_ind, sigma_alpha_init)
-                        rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container)
+                        rfx_model.reset(
+                            self.rfx_container, forest_ind, sigma_alpha_init
+                        )
+                        rfx_tracker.reset(
+                            rfx_model,
+                            rfx_dataset_train,
+                            residual_train,
+                            self.rfx_container,
+                        )
                 elif has_prev_model:
-                    warmstart_index = previous_model_warmstart_sample_num - chain_num if previous_model_decrement else previous_model_warmstart_sample_num
+                    warmstart_index = (
+                        previous_model_warmstart_sample_num - chain_num
+                        if previous_model_decrement
+                        else previous_model_warmstart_sample_num
+                    )
                     # Reset prognostic forest
-                    active_forest_mu.reset(previous_bcf_model.forest_container_mu, warmstart_index)
+                    active_forest_mu.reset(
+                        previous_bcf_model.forest_container_mu, warmstart_index
+                    )
                     forest_sampler_mu.reconstitute_from_forest(
                         active_forest_mu,
                         forest_dataset_train,
                         residual_train,
                         True,
                     )
                     # Reset CATE forest
-                    active_forest_tau.reset(previous_bcf_model.forest_container_tau, warmstart_index)
+                    active_forest_tau.reset(
+                        previous_bcf_model.forest_container_tau, warmstart_index
+                    )
                     forest_sampler_tau.reconstitute_from_forest(
                         active_forest_tau,
                         forest_dataset_train,
@@ -2215,12 +2254,13 @@ def sample(
                     )
                    # Reset global error scale
                     if self.sample_sigma2_global:
-                        current_sigma2 = previous_global_var_samples[
-                            warmstart_index
-                        ]
+                        current_sigma2 = previous_global_var_samples[warmstart_index]
                         global_model_config.update_global_error_variance(current_sigma2)
                     # Reset mu forest leaf scale
-                    if sample_sigma2_leaf_mu and previous_leaf_var_mu_samples is not None:
+                    if (
+                        sample_sigma2_leaf_mu
+                        and previous_leaf_var_mu_samples is not None
+                    ):
                         leaf_scale_double_mu = previous_leaf_var_mu_samples[
                             warmstart_index
                         ]
@@ -2229,7 +2269,10 @@ def sample(
                             current_leaf_scale_mu
                         )
                     # Reset mu forest leaf scale
-                    if sample_sigma2_leaf_tau and previous_leaf_var_tau_samples is not None:
+                    if (
+                        sample_sigma2_leaf_tau
+                        and previous_leaf_var_tau_samples is not None
+                    ):
                         leaf_scale_double_tau = previous_leaf_var_tau_samples[
                             warmstart_index
                         ]
@@ -2257,19 +2300,24 @@ def sample(
                        )
                     # Reset random effects terms
                     if self.has_rfx:
-                        rfx_model.reset(previous_bcf_model.rfx_container, warmstart_index, sigma_alpha_init)
-                        rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, previous_bcf_model.rfx_container)
+                        rfx_model.reset(
+                            previous_bcf_model.rfx_container,
+                            warmstart_index,
+                            sigma_alpha_init,
+                        )
+                        rfx_tracker.reset(
+                            rfx_model,
+                            rfx_dataset_train,
+                            residual_train,
+                            previous_bcf_model.rfx_container,
+                        )
                 else:
                     # Reset prognostic forest
                     active_forest_mu.reset_root()
                     if init_mu.shape[0] == 1:
-                        active_forest_mu.set_root_leaves(
-                            init_mu[0] / num_trees_mu
-                        )
+                        active_forest_mu.set_root_leaves(init_mu[0] / num_trees_mu)
                     else:
-                        active_forest_mu.set_root_leaves(
-                            init_mu / num_trees_mu
-                        )
+                        active_forest_mu.set_root_leaves(init_mu / num_trees_mu)
                     forest_sampler_mu.reconstitute_from_forest(
                         active_forest_mu,
                         forest_dataset_train,
@@ -2279,13 +2327,9 @@ def sample(
                     # Reset CATE forest
                     active_forest_tau.reset_root()
                     if init_tau.shape[0] == 1:
-                        active_forest_tau.set_root_leaves(
-                            init_tau[0] / num_trees_tau
-                        )
+                        active_forest_tau.set_root_leaves(init_tau[0] / num_trees_tau)
                     else:
-                        active_forest_tau.set_root_leaves(
-                            init_tau / num_trees_tau
-                        )
+                        active_forest_tau.set_root_leaves(init_tau / num_trees_tau)
                     forest_sampler_tau.reconstitute_from_forest(
                         active_forest_tau,
                         forest_dataset_train,
@@ -2309,13 +2353,19 @@ def sample(
                        current_sigma2 = sigma2_init
                         global_model_config.update_global_error_variance(current_sigma2)
                     # Reset mu forest leaf scale
-                    if sample_sigma2_leaf_mu and previous_leaf_var_mu_samples is not None:
+                    if (
+                        sample_sigma2_leaf_mu
+                        and previous_leaf_var_mu_samples is not None
+                    ):
                         current_leaf_scale_mu[0, 0] = sigma2_leaf_mu
                         forest_model_config_mu.update_leaf_model_scale(
                             current_leaf_scale_mu
                         )
                     # Reset mu forest leaf scale
-                    if sample_sigma2_leaf_tau and previous_leaf_var_tau_samples is not None:
+                    if (
+                        sample_sigma2_leaf_tau
+                        and previous_leaf_var_tau_samples is not None
+                    ):
                         current_leaf_scale_tau[0, 0] = sigma2_leaf_tau
                         forest_model_config_tau.update_leaf_model_scale(
                             current_leaf_scale_tau
                         )
@@ -2338,8 +2388,20 @@ def sample(
                        )
                     # Reset random effects terms
                     if self.has_rfx:
-                        rfx_model.root_reset(alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
-                        rfx_tracker.root_reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container)
+                        rfx_model.root_reset(
+                            alpha_init,
+                            xi_init,
+                            sigma_alpha_init,
+                            sigma_xi_init,
+                            sigma_xi_shape,
+                            sigma_xi_scale,
+                        )
+                        rfx_tracker.root_reset(
+                            rfx_model,
+                            rfx_dataset_train,
+                            residual_train,
+                            self.rfx_container,
+                        )
             # Sample MCMC and burnin for each chain
             for i in range(num_gfr, num_temp_samples):
                 is_mcmc = i + 1 > num_gfr + num_burnin
@@ -2360,7 +2422,9 @@ def sample(
                 if self.probit_outcome_model:
                     # Sample latent probit variable z | -
                     forest_pred_mu = active_forest_mu.predict(forest_dataset_train)
-                    forest_pred_tau = active_forest_tau.predict(forest_dataset_train)
+                    forest_pred_tau = active_forest_tau.predict(
+                        forest_dataset_train
+                    )
                     outcome_pred = forest_pred_mu + forest_pred_tau
                     if self.has_rfx:
                         rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
@@ -2466,12 +2530,16 @@ def sample(
                        )
                     current_b_0 = self.rng.normal(
                         loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
-                        scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)),
+                        scale=np.sqrt(
+                            current_sigma2 / (s_tt0 + 2 * current_sigma2)
+                        ),
                         size=1,
                     )[0]
                     current_b_1 = self.rng.normal(
                         loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)),
-                        scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)),
+                        scale=np.sqrt(
+                            current_sigma2 / (s_tt1 + 2 * current_sigma2)
+                        ),
                         size=1,
                     )[0]
                     tau_basis_train = (
@@ -2583,9 +2651,7 @@ def sample(
            adaptive_coding_weights = np.expand_dims(
                 self.b1_samples - self.b0_samples, axis=(0, 2)
             )
-            b0_weights = np.expand_dims(
-                self.b0_samples, axis=(0, 2)
-            )
+            b0_weights = np.expand_dims(self.b0_samples, axis=(0, 2))
             control_adj_train = self.tau_hat_train * b0_weights * self.y_std
             self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights
             self.mu_hat_train = self.mu_hat_train + np.squeeze(control_adj_train)
@@ -2610,9 +2676,7 @@ def sample(
                adaptive_coding_weights_test = np.expand_dims(
                     self.b1_samples - self.b0_samples, axis=(0, 2)
                 )
-                b0_weights = np.expand_dims(
-                    self.b0_samples, axis=(0, 2)
-                )
+                b0_weights = np.expand_dims(self.b0_samples, axis=(0, 2))
                 control_adj_test = self.tau_hat_test * b0_weights * self.y_std
                 self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test
                 self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test)
@@ -2921,9 +2985,7 @@ def predict(
                self.b1_samples - self.b0_samples, axis=(0, 2)
             )
             if predict_mu_forest or predict_mu_forest_intermediate:
-                b0_weights = np.expand_dims(
-                    self.b0_samples, axis=(0, 2)
-                )
+                b0_weights = np.expand_dims(self.b0_samples, axis=(0, 2))
                 control_adj = tau_raw * b0_weights * self.y_std
                 mu_x_forest = mu_x_forest + np.squeeze(control_adj)
             tau_raw = tau_raw * adaptive_coding_weights
@@ -3341,9 +3403,7 @@ def compute_posterior_interval(
                raise ValueError(
                     "'X' must be provided in order to compute the requested intervals"
                 )
-            if not isinstance(X, np.ndarray) and not isinstance(
-                X, pd.DataFrame
-            ):
+            if not isinstance(X, np.ndarray) and not isinstance(X, pd.DataFrame):
                 raise ValueError("'X' must be a matrix or data frame")
         needs_treatment = needs_covariates
         if needs_treatment:
@@ -3354,9 +3414,7 @@ def compute_posterior_interval(
            if not isinstance(Z, np.ndarray):
                 raise ValueError("'Z' must be a numpy array")
             if Z.shape[0] != X.shape[0]:
-                raise ValueError(
-                    "'Z' must have the same number of rows as 'X'"
-                )
+                raise ValueError("'Z' must have the same number of rows as 'X'")
         uses_propensity = self.propensity_covariate != "none"
         internal_propensity_model = self.internal_propensity_model
         needs_propensity = (
@@ -3472,9 +3530,7 @@ def sample_posterior_predictive(
                raise ValueError(
                     "'X' must be provided in order to compute the requested intervals"
                 )
-            if not isinstance(X, np.ndarray) and not isinstance(
-                X, pd.DataFrame
-            ):
+            if not isinstance(X, np.ndarray) and not isinstance(X, pd.DataFrame):
                 raise ValueError("'X' must be a matrix or data frame")
         needs_treatment = needs_covariates
         if needs_treatment:
@@ -3485,9 +3541,7 @@ def sample_posterior_predictive(
            if not isinstance(Z, np.ndarray):
                 raise ValueError("'Z' must be a numpy array")
             if Z.shape[0] != X.shape[0]:
-                raise ValueError(
-                    "'Z' must have the same number of rows as 'X'"
-                )
+                raise ValueError("'Z' must have the same number of rows as 'X'")
         uses_propensity = self.propensity_covariate != "none"
         internal_propensity_model = self.internal_propensity_model
         needs_propensity = (
@@ -3523,9 +3577,7 @@ def sample_posterior_predictive(
            if not isinstance(rfx_basis, np.ndarray):
                 raise ValueError("'rfx_basis' must be a numpy array")
             if rfx_basis.shape[0] != X.shape[0]:
-                raise ValueError(
-                    "'rfx_basis' must have the same number of rows as 'X'"
-                )
+                raise ValueError("'rfx_basis' must have the same number of rows as 'X'")
 
         # Compute posterior predictive samples
         bcf_preds = self.predict(
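
Usage note (an illustrative sketch, not part of the patch): the dtype check added in `stochtree/bart.py` and `stochtree/bcf.py` keys off `pd.CategoricalDtype`, so the new binary-covariate warning can be avoided by declaring binary columns as ordered categorical before calling the samplers. The column names and data below are hypothetical:

import numpy as np
import pandas as pd

# Hypothetical training frame: "x2" is a 0/1 indicator that the new check
# would flag as binary-but-continuous when only GFR sampling is requested.
rng = np.random.default_rng(0)
X_train = pd.DataFrame({
    "x1": rng.normal(size=200),
    "x2": rng.integers(0, 2, size=200),
})

# The conversion suggested by the warning message: mark the binary column as
# an ordered categorical so stochtree treats it as categorical rather than
# continuous.
X_train["x2"] = pd.Categorical(X_train["x2"], ordered=True)

# This mirrors the dtype test added in this patch, which skips all four
# covariate checks for categorical columns.
assert isinstance(X_train["x2"].dtype, pd.CategoricalDtype)

The R interface takes the analogous skip path via `is.factor()`, which returns `TRUE` for ordered factors, so `factor(x, ordered = TRUE)` is the equivalent conversion for `R/bart.R` and `R/bcf.R`.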