Skip to content

Commit

Permalink
Fix issue #470 (#477)
Browse files Browse the repository at this point in the history
* renamed function with reference to qc.matrix

* added changelog
  • Loading branch information
MarcAntoineSchmidtQC committed Oct 26, 2021
1 parent 87f6b01 commit 5b40ae1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Unreleased
**Bug fix:**

- Fixed the sign of the log likelihood of the Gaussian distribution (not used for fitting coefficients).

- Renamed functions checking for qc.matrix compliance to refer to tabmat.

2.0.1 - 2021-10-11
------------------
Expand Down
16 changes: 8 additions & 8 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@
]


def check_array_matrix_compliant(mat: ArrayLike, **kwargs):
def check_array_tabmat_compliant(mat: ArrayLike, **kwargs):
to_copy = kwargs.get("copy", False)

if isinstance(mat, pd.DataFrame) and any(mat.dtypes == "category"):
mat = tm.from_pandas(mat)

if isinstance(mat, tm.SplitMatrix):
kwargs.update({"ensure_min_features": 0})
new_matrices = [check_array_matrix_compliant(m, **kwargs) for m in mat.matrices]
new_matrices = [check_array_tabmat_compliant(m, **kwargs) for m in mat.matrices]
new_indices = [elt.copy() for elt in mat.indices] if to_copy else mat.indices
return tm.SplitMatrix(new_matrices, new_indices)

Expand All @@ -104,7 +104,7 @@ def check_array_matrix_compliant(mat: ArrayLike, **kwargs):

if isinstance(mat, tm.StandardizedMatrix):
return tm.StandardizedMatrix(
check_array_matrix_compliant(mat.mat, **kwargs),
check_array_tabmat_compliant(mat.mat, **kwargs),
check_array(mat.shift, **kwargs),
)

Expand All @@ -117,7 +117,7 @@ def check_array_matrix_compliant(mat: ArrayLike, **kwargs):
return res


def check_X_y_matrix_compliant(
def check_X_y_tabmat_compliant(
X: ArrayLike, y: Union[VectorLike, sparse.spmatrix], **kwargs
) -> Tuple[Union[tm.MatrixBase, sparse.spmatrix, np.ndarray], np.ndarray]:
"""
Expand All @@ -142,7 +142,7 @@ def check_X_y_matrix_compliant(
y = y.astype(np.float64)

check_consistent_length(X, y)
X = check_array_matrix_compliant(X, **kwargs)
X = check_array_tabmat_compliant(X, **kwargs)

return X, y

Expand Down Expand Up @@ -1186,7 +1186,7 @@ def linear_predictor(
elif alpha is not None:
alpha_index = [self._find_alpha_index(a) for a in alpha] # type: ignore

X = check_array_matrix_compliant(
X = check_array_tabmat_compliant(
X,
accept_sparse=["csr", "csc", "coo"],
dtype="numeric",
Expand Down Expand Up @@ -1258,7 +1258,7 @@ def predict(
if isinstance(X, pd.DataFrame) and hasattr(self, "feature_dtypes_"):
X = _align_df_categories(X, self.feature_dtypes_)

X = check_array_matrix_compliant(
X = check_array_tabmat_compliant(
X,
accept_sparse=["csr", "csc", "coo"],
dtype="numeric",
Expand Down Expand Up @@ -1770,7 +1770,7 @@ def _expand_categorical_penalties(penalty, X):
X = X.astype(np.float64) # type: ignore

if isinstance(X, tm.MatrixBase):
X, y = check_X_y_matrix_compliant(
X, y = check_X_y_tabmat_compliant(
X,
y,
accept_sparse=_stype,
Expand Down

0 comments on commit 5b40ae1

Please sign in to comment.