Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,19 +765,8 @@ def sample(
)
previous_bart_model = BARTModel()
previous_bart_model.from_json(previous_model_json)
previous_y_bar = previous_bart_model.y_bar
previous_y_scale = previous_bart_model.y_std
previous_model_num_samples = previous_bart_model.num_samples
if previous_bart_model.include_mean_forest:
previous_forest_samples_mean = previous_bart_model.forest_container_mean
else:
previous_forest_samples_mean = None
if previous_bart_model.include_variance_forest:
previous_forest_samples_variance = (
previous_bart_model.forest_container_variance
)
else:
previous_forest_samples_variance = None
if previous_bart_model.sample_sigma2_global:
previous_global_var_samples = previous_bart_model.global_var_samples / (
previous_y_scale * previous_y_scale
Expand All @@ -788,22 +777,14 @@ def sample(
previous_leaf_var_samples = previous_bart_model.leaf_scale_samples
else:
previous_leaf_var_samples = None
if previous_bart_model.has_rfx:
previous_rfx_samples = previous_bart_model.rfx_container
else:
previous_rfx_samples = None
if previous_model_warmstart_sample_num + 1 > previous_model_num_samples:
raise ValueError(
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
)
else:
previous_y_bar = None
previous_y_scale = None
previous_global_var_samples = None
previous_leaf_var_samples = None
previous_rfx_samples = None
previous_forest_samples_mean = None
previous_forest_samples_variance = None
previous_model_num_samples = 0

# Update variable weights if the covariates have been resized (by e.g. one-hot encoding)
Expand Down Expand Up @@ -1772,7 +1753,6 @@ def predict(
rfx_intercept = rfx_model_spec == "intercept_only"
if not isinstance(terms, str) and not isinstance(terms, list):
raise ValueError("type must be a string or list of strings")
num_terms = 1 if isinstance(terms, str) else len(terms)
has_mean_forest = self.include_mean_forest
has_variance_forest = self.include_variance_forest
has_rfx = self.has_rfx
Expand Down
2 changes: 0 additions & 2 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,6 @@ def sample(
if sample_sigma2_leaf_tau:
self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64)
muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
if self.include_variance_forest:
sigma2_x_train_raw = np.empty(
(self.n_train, self.num_samples), dtype=np.float64
Expand Down Expand Up @@ -2442,7 +2441,6 @@ def predict(
raise ValueError(
f"term '{term}' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'"
)
num_terms = 1 if isinstance(terms, str) else len(terms)
has_mu_forest = True
has_tau_forest = True
has_variance_forest = self.include_variance_forest
Expand Down
12 changes: 6 additions & 6 deletions stochtree/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def _fit_numpy(self, covariates: np.array) -> None:
self._onehot_feature_index = np.array(
[-1 for i in range(self._num_original_features)], dtype=int
)
self._original_feature_types = np.array(
["float" for i in range(self._num_original_features)]
)
self._original_feature_types = np.array([
"float" for i in range(self._num_original_features)
])

# Check whether the array is numeric
cov_dtype = covariates.dtype
Expand Down Expand Up @@ -443,9 +443,9 @@ def _transform_numpy(self, covariates: np.array) -> np.array:
raise ValueError(
"Attempting to call transform from a CovariateTransformer that was fit on a dataset with different dimensionality"
)
self._original_feature_indices = np.array(
[i for i in range(covariates.shape[1])]
)
self._original_feature_indices = np.array([
i for i in range(covariates.shape[1])
])
return covariates

def _transform(self, covariates: Union[pd.DataFrame, np.array]) -> np.array:
Expand Down
18 changes: 9 additions & 9 deletions stochtree/random_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,20 +367,20 @@ def predict(self, group_labels: np.array, basis: np.array) -> np.ndarray:
return self.rfx_container_cpp.Predict(
rfx_dataset.rfx_dataset_cpp, self.rfx_label_mapper_cpp
)

def extract_parameter_samples(self) -> dict[str, np.ndarray]:
"""
Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008),
this includes four parameters: alpha (the "working parameter" shared across every group), xi
(the "group parameter" sampled separately for each group), beta (the product of alpha and xi,
which corresponds to the overall group-level random effects), and sigma (group-independent prior
Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008),
this includes four parameters: alpha (the "working parameter" shared across every group), xi
(the "group parameter" sampled separately for each group), beta (the product of alpha and xi,
which corresponds to the overall group-level random effects), and sigma (group-independent prior
variance for each component of xi).

Returns
-------
dict[str, np.ndarray]
dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
"""
# num_samples = self.rfx_container_cpp.NumSamples()
Expand All @@ -391,10 +391,10 @@ def extract_parameter_samples(self) -> dict[str, np.ndarray]:
alpha_samples = np.squeeze(self.rfx_container_cpp.GetAlpha())
sigma_samples = np.squeeze(self.rfx_container_cpp.GetSigma())
output = {
"beta_samples": beta_samples,
"beta_samples": beta_samples,
"xi_samples": xi_samples,
"alpha_samples": alpha_samples,
"sigma_samples": sigma_samples
"sigma_samples": sigma_samples,
}
return output

Expand Down
16 changes: 13 additions & 3 deletions stochtree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,21 +334,31 @@ def _expand_dims_2d_diag(
)
return output

def _posterior_predictive_heuristic_multiplier(num_samples: int, num_observations: int) -> int:

def _posterior_predictive_heuristic_multiplier(
num_samples: int, num_observations: int
) -> int:
if num_samples >= 1000:
return 1
else:
return math.ceil(1000 / num_samples)

def _summarize_interval(array: np.ndarray, sample_dim: int = 2, level: float = 0.95) -> dict:

def _summarize_interval(
array: np.ndarray, sample_dim: int = 2, level: float = 0.95
) -> dict:
# Check that the array is numeric and at least 2 dimensional
if not isinstance(array, np.ndarray):
raise ValueError("`array` must be a numpy array")
if not _check_array_numeric(array):
raise ValueError("`array` must be a numeric numpy array")
if not len(array.shape) >= 2:
raise ValueError("`array` must be at least a 2-dimensional numpy array")
if not _check_is_int(sample_dim) or (sample_dim < 0) or (sample_dim >= len(array.shape)):
if (
not _check_is_int(sample_dim)
or (sample_dim < 0)
or (sample_dim >= len(array.shape))
):
raise ValueError(
"`sample_dim` must be an integer between 0 and the number of dimensions of `array` - 1"
)
Expand Down
Loading