Skip to content

Commit

Permalink
Merge pull request #60 from StochasticTree/feature_subset_bcf_hotfix
Browse files (browse the repository at this point in the history)
Updated feature subset code in Python BCF interface
  • Loading branch information
andrewherren authored Jun 25, 2024
2 parents 0f61602 + ce517d0 commit 1d8762d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions stochtree/bcf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -368,13 +368,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
if all(isinstance(i, str) for i in drop_vars_mu):
if not np.all(np.isin(drop_vars_mu, X_train.columns)):
raise ValueError("drop_vars_mu includes some variable names that are not in X_train")
variable_subset_mu = [i for i in X_train.shape[1] if drop_vars_mu.count(X_train.columns.array[i]) == 0]
variable_subset_mu = [i for i in range(X_train.shape[1]) if drop_vars_mu.count(X_train.columns.array[i]) == 0]
elif all(isinstance(i, int) for i in drop_vars_mu):
if any(i >= X_train.shape[1] for i in drop_vars_mu):
raise ValueError("drop_vars_mu includes some variable indices that exceed the number of columns in X_train")
if any(i < 0 for i in drop_vars_mu):
raise ValueError("drop_vars_mu includes some negative variable indices")
variable_subset_mu = [i for i in X_train.shape[1] if drop_vars_mu.count(i) == 0]
variable_subset_mu = [i for i in range(X_train.shape[1]) if drop_vars_mu.count(i) == 0]
else:
raise ValueError("drop_vars_mu must be a list of variable names (str) or column indices (int)")
elif isinstance(drop_vars_mu, np.ndarray):
Expand All @@ -399,7 +399,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
if all(isinstance(i, str) for i in keep_vars_tau):
if not np.all(np.isin(keep_vars_tau, X_train.columns)):
raise ValueError("keep_vars_tau includes some variable names that are not in X_train")
variable_subset_tau = [i for i in X_train.shape[1] if keep_vars_tau.count(X_train.columns.array[i]) > 0]
variable_subset_tau = [i for i in range(X_train.shape[1]) if keep_vars_tau.count(X_train.columns.array[i]) > 0]
elif all(isinstance(i, int) for i in keep_vars_tau):
if any(i >= X_train.shape[1] for i in keep_vars_tau):
raise ValueError("keep_vars_tau includes some variable indices that exceed the number of columns in X_train")
Expand All @@ -412,7 +412,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
if keep_vars_tau.dtype == np.str_:
if not np.all(np.isin(keep_vars_tau, X_train.columns)):
raise ValueError("keep_vars_tau includes some variable names that are not in X_train")
variable_subset_tau = [i for i in X_train.shape[1] if keep_vars_tau.count(X_train.columns.array[i]) > 0]
variable_subset_tau = [i for i in range(X_train.shape[1]) if keep_vars_tau.count(X_train.columns.array[i]) > 0]
else:
if np.any(keep_vars_tau >= X_train.shape[1]):
raise ValueError("keep_vars_tau includes some variable indices that exceed the number of columns in X_train")
Expand All @@ -426,13 +426,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
if all(isinstance(i, str) for i in drop_vars_tau):
if not np.all(np.isin(drop_vars_tau, X_train.columns)):
raise ValueError("drop_vars_tau includes some variable names that are not in X_train")
variable_subset_tau = [i for i in X_train.shape[1] if drop_vars_tau.count(X_train.columns.array[i]) == 0]
variable_subset_tau = [i for i in range(X_train.shape[1]) if drop_vars_tau.count(X_train.columns.array[i]) == 0]
elif all(isinstance(i, int) for i in drop_vars_tau):
if any(i >= X_train.shape[1] for i in drop_vars_tau):
raise ValueError("drop_vars_tau includes some variable indices that exceed the number of columns in X_train")
if any(i < 0 for i in drop_vars_tau):
raise ValueError("drop_vars_tau includes some negative variable indices")
variable_subset_tau = [i for i in X_train.shape[1] if drop_vars_tau.count(i) == 0]
variable_subset_tau = [i for i in range(X_train.shape[1]) if drop_vars_tau.count(i) == 0]
else:
raise ValueError("drop_vars_tau must be a list of variable names (str) or column indices (int)")
elif isinstance(drop_vars_tau, np.ndarray):
Expand Down

0 comments on commit 1d8762d

Please sign in to comment.