From 06c803a3c3b9ccb8b90a39d1b7bb05d15d207bee Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 16 Mar 2025 18:57:14 -0500 Subject: [PATCH 1/2] Override keep_gfr if no MCMC samples --- stochtree/bart.py | 4 ++++ stochtree/bcf.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/stochtree/bart.py b/stochtree/bart.py index 0159fc92..0d0601ba 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -226,6 +226,10 @@ def sample(self, X_train: Union[np.array, pd.DataFrame], y_train: np.array, basi keep_vars_variance = variance_forest_params_updated['keep_vars'] drop_vars_variance = variance_forest_params_updated['drop_vars'] + # Override keep_gfr if there are no MCMC samples + if num_mcmc == 0: + keep_gfr = True + # Check that num_chains >= 1 if not isinstance(num_chains, Integral) or num_chains < 1: raise ValueError("num_chains must be an integer greater than 0") diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 4f24234b..2c518590 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -299,6 +299,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr keep_vars_variance = variance_forest_params_updated['keep_vars'] drop_vars_variance = variance_forest_params_updated['drop_vars'] + # Override keep_gfr if there are no MCMC samples + if num_mcmc == 0: + keep_gfr = True + # Variable weight preprocessing (and initialization if necessary) if variable_weights is None: if X_train.ndim > 1: From fc2404052c3efff6f1121aaf6f476a8dbfb644af Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 16 Mar 2025 18:59:29 -0500 Subject: [PATCH 2/2] Fixed indexing issue in post-sampler cleanup in R and Python --- R/bart.R | 6 +++--- R/bcf.R | 8 ++++---- stochtree/bart.py | 4 ++-- stochtree/bcf.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/bart.R b/R/bart.R index 64422438..893b3bbb 100644 --- a/R/bart.R +++ b/R/bart.R @@ -826,13 +826,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if ((!keep_gfr) && (num_gfr > 0)) { for (i in 1:num_gfr) { if (include_mean_forest) { - forest_samples_mean$delete_sample(i-1) + forest_samples_mean$delete_sample(0) } if (include_variance_forest) { - forest_samples_variance$delete_sample(i-1) + forest_samples_variance$delete_sample(0) } if (has_rfx) { - rfx_samples$delete_sample(i-1) + rfx_samples$delete_sample(0) } } if (sample_sigma_global) { diff --git a/R/bcf.R b/R/bcf.R index 60821e79..fcbd1347 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1214,13 +1214,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Remove GFR samples if they are not to be retained if ((!keep_gfr) && (num_gfr > 0)) { for (i in 1:num_gfr) { - forest_samples_mu$delete_sample(i-1) - forest_samples_tau$delete_sample(i-1) + forest_samples_mu$delete_sample(0) + forest_samples_tau$delete_sample(0) if (include_variance_forest) { - forest_samples_variance$delete_sample(i-1) + forest_samples_variance$delete_sample(0) } if (has_rfx) { - rfx_samples$delete_sample(i-1) + rfx_samples$delete_sample(0) } } if (sample_sigma_global) { diff --git a/stochtree/bart.py b/stochtree/bart.py index 0d0601ba..dfb8749c 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -684,9 +684,9 @@ def sample(self, X_train: Union[np.array, pd.DataFrame], y_train: np.array, basi if not keep_gfr and num_gfr > 0: for i in range(num_gfr): if self.include_mean_forest: - self.forest_container_mean.delete_sample(i) + self.forest_container_mean.delete_sample(0) if self.include_variance_forest: - self.forest_container_variance.delete_sample(i) + self.forest_container_variance.delete_sample(0) if self.sample_sigma_global: self.global_var_samples = self.global_var_samples[num_gfr:] if self.sample_sigma_leaf: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 2c518590..d35999f9 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1055,10 +1055,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr # Remove GFR samples if they are not to be retained if not keep_gfr and num_gfr > 0: for i in range(num_gfr): - self.forest_container_mu.delete_sample(i) - self.forest_container_tau.delete_sample(i) + self.forest_container_mu.delete_sample(0) + self.forest_container_tau.delete_sample(0) if self.include_variance_forest: - self.forest_container_variance.delete_sample(i) + self.forest_container_variance.delete_sample(0) if self.adaptive_coding: self.b1_samples = self.b1_samples[num_gfr:] self.b0_samples = self.b0_samples[num_gfr:]