Skip to content

Commit

Permalink
Improve backwards compatibility (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbittarello committed Oct 25, 2023
1 parent 6a89d52 commit 511b2df
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ def linear_predictor(
copy=True,
ensure_2d=True,
allow_nd=False,
drop_first=self.drop_first,
drop_first=getattr(self, "drop_first", False),
)

if alpha_index is None:
Expand Down Expand Up @@ -1321,7 +1321,7 @@ def predict(
copy=self._should_copy_X(),
ensure_2d=True,
allow_nd=False,
drop_first=self.drop_first,
drop_first=getattr(self, "drop_first", False),
)
eta = self.linear_predictor(
X, offset=offset, alpha_index=alpha_index, alpha=alpha
Expand Down Expand Up @@ -1889,12 +1889,12 @@ def covariance_matrix(
self.covariance_matrix_: Union[np.ndarray, None]

if robust is None:
_robust = self.robust
_robust = getattr(self, "robust", True)
else:
_robust = robust

if expected_information is None:
_expected_information = self.expected_information
_expected_information = getattr(self, "expected_information", False)
else:
_expected_information = expected_information

Expand Down Expand Up @@ -1960,7 +1960,7 @@ def covariance_matrix(
copy=self._should_copy_X(),
ensure_2d=True,
allow_nd=False,
drop_first=self.drop_first,
drop_first=getattr(self, "drop_first", False),
)

if isinstance(X, np.ndarray):
Expand Down Expand Up @@ -2255,7 +2255,7 @@ def _set_up_and_check_fit_args(
self.feature_names_ = list(
chain.from_iterable(
_name_categorical_variables(
dtype.categories, column, self.drop_first
dtype.categories, column, getattr(self, "drop_first", False)
)
if isinstance(dtype, pd.CategoricalDtype)
else [column]
Expand Down Expand Up @@ -2295,10 +2295,14 @@ def _expand_categorical_penalties(penalty, X, drop_first):
else:
return penalty

P1 = _expand_categorical_penalties(self.P1, X, self.drop_first)
P2 = _expand_categorical_penalties(self.P2, X, self.drop_first)
P1 = _expand_categorical_penalties(
self.P1, X, getattr(self, "drop_first", False)
)
P2 = _expand_categorical_penalties(
self.P2, X, getattr(self, "drop_first", False)
)

X = tm.from_pandas(X, drop_first=self.drop_first)
X = tm.from_pandas(X, drop_first=getattr(self, "drop_first", False))
else:
self.feature_names_ = X.columns
X = tm.from_pandas(X)
Expand Down Expand Up @@ -2336,7 +2340,7 @@ def _expand_categorical_penalties(penalty, X, drop_first):
dtype=_dtype,
copy=copy_X,
force_all_finite=force_all_finite,
drop_first=self.drop_first,
drop_first=getattr(self, "drop_first", False),
)
self._check_n_features(X, reset=True)
else:
Expand Down Expand Up @@ -3074,9 +3078,9 @@ def fit(
y=y,
offset=offset,
sample_weight=sample_weight * weights_sum,
robust=self.robust,
robust=getattr(self, "robust", True),
clusters=clusters,
expected_information=self.expected_information,
expected_information=getattr(self, "expected_information", False),
store_covariance_matrix=True,
skip_checks=True,
)
Expand Down

0 comments on commit 511b2df

Please sign in to comment.