-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(reindex_missing_columns): add
ReindexMissingColumns
and `Repor…
…tNonFinite` (#27) * feat(reindex_missing_columns): add `ReindexMissingColumns` * feat: add `ReportNonFinite` * feat: export modules * docs(readme): update README.md * fix: import annotations from __future__
- Loading branch information
Showing
9 changed files
with
756 additions
and
463 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,3 +140,4 @@ dmypy.json | |
cython_debug/ | ||
.VSCodeCounter/ | ||
catboost_info/ | ||
sklearn_utilities_info/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
import warnings | ||
from typing import Any, Callable, Literal | ||
|
||
from pandas import DataFrame, Index | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from typing_extensions import Self | ||
|
||
from .types import TXPandas | ||
|
||
|
||
class ReindexMissingColumns(BaseEstimator, TransformerMixin): | ||
"""Reindex X to match the columns of the training data to avoid errors.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
if_missing: Literal["warn", "raise"] | ||
| Callable[[Index[Any], Index[Any]], None] = "warn", | ||
reindex_kwargs: dict[ | ||
Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any | ||
] = {}, | ||
) -> None: | ||
"""Reindex X to match the columns of the training data to avoid errors. | ||
Parameters | ||
---------- | ||
if_missing : Literal['warn', 'raise'] | Callable[[Index[Any], Index[Any]], None], optional | ||
If callable, the first argument is the expected columns and the | ||
second argument is the actual columns, by default 'warn' | ||
reindex_kwargs : dict[Literal['method', 'copy', 'level', 'fill_value', | ||
'limit', 'tolerance'], Any], optional | ||
Keyword arguments to pass to reindex, by default {} | ||
""" | ||
self.if_missing = if_missing | ||
self.reindex_kwargs = reindex_kwargs | ||
|
||
def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self: | ||
self.feature_names_in_ = X.columns | ||
return self | ||
|
||
def transform(self, X: TXPandas, y: Any = None, **fit_params: Any) -> TXPandas: | ||
expected_columns = self.feature_names_in_ | ||
actual_columns = X.columns | ||
if not expected_columns.equals(actual_columns): | ||
missing_columns = expected_columns.difference(actual_columns) | ||
if self.if_missing == "warn": | ||
warnings.warn(f"Missing columns: {missing_columns}") | ||
elif self.if_missing == "raise": | ||
raise ValueError(f"Missing columns: {missing_columns}") | ||
elif isinstance(self.if_missing, Callable): # type: ignore | ||
self.if_missing(expected_columns, actual_columns) | ||
else: | ||
raise ValueError(f"Invalid value for if_missing: {self.if_missing}") | ||
|
||
return X.reindex(columns=self.feature_names_in_, **self.reindex_kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
from __future__ import annotations | ||
|
||
from logging import getLogger | ||
from pathlib import Path | ||
from typing import Any, Callable | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from matplotlib.figure import Figure | ||
from pandas import DataFrame, Series, Timestamp | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from typing_extensions import Self | ||
|
||
from .types import TXPandas | ||
|
||
LOG = getLogger(__name__) | ||
|
||
|
||
class ReportNonFinite(BaseEstimator, TransformerMixin): | ||
"""Report non-finite values in X or y.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
on_fit: bool = False, | ||
on_fit_y: bool = False, | ||
on_transform: bool = True, | ||
plot: bool = True, | ||
calc_corr: bool = False, | ||
callback: Callable[[dict[str, DataFrame | Series]], None] | None = None, | ||
callback_figure: Callable[[Figure], None] | ||
| None = lambda fig: Path("sklearn_utilities_info/ReportNonFinite").mkdir( # type: ignore | ||
parents=True, exist_ok=True | ||
) | ||
or fig.savefig( | ||
Path("sklearn_utilities_info/ReportNonFinite") | ||
/ f"{Timestamp.now().isoformat().replace(':', '-')}.png" | ||
), | ||
) -> None: | ||
"""Report non-finite values in X or y. | ||
Parameters | ||
---------- | ||
on_fit : bool, optional | ||
Whether to report non-finite values in X during fit, by default False | ||
on_fit_y : bool, optional | ||
Whether to report non-finite values in y during fit, by default False | ||
on_transform : bool, optional | ||
Whether to report non-finite values in X during transform, by default True | ||
plot : bool, optional | ||
Whether to plot the report result, by default True | ||
calc_corr : bool, optional | ||
Whether to calculate the correlation of non-finite values, by default False | ||
callback : Callable[[dict[str, DataFrame | Series]], None] | None, optional | ||
The callback function, by default None | ||
callback_figure : _type_, optional | ||
The callback function for figure, by default | ||
`lambda fig: | ||
Path("sklearn-utilities/ReportNonFinite").mkdir(parents=True, exist_ok=True) | ||
or fig.savefig(Path("sklearn-utilities/ReportNonFinite") / | ||
f"{Timestamp.now().isoformat().replace(':', '-')}.png")` | ||
""" | ||
self.on_fit = on_fit | ||
self.on_fit_y = on_fit_y | ||
self.on_transform = on_transform | ||
self.plot = plot | ||
self.calc_corr = calc_corr | ||
self.callback = callback | ||
self.callback_figure = callback_figure | ||
|
||
def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self: | ||
if self.on_fit: | ||
try: | ||
self._report(X, "fit") | ||
except Exception as e: | ||
LOG.warning(f"Failed to report non-finite values in X during fit: {e}") | ||
LOG.exception(e) | ||
|
||
if self.on_fit_y: | ||
try: | ||
DataFrame(y) | ||
except Exception as e: | ||
LOG.warning(f"Failed to convert y to DataFrame during fit: {e}") | ||
LOG.exception(e) | ||
|
||
try: | ||
self._report(DataFrame(y), "fit_y") | ||
except Exception as e: | ||
LOG.warning(f"Failed to report non-finite values in y during fit: {e}") | ||
LOG.exception(e) | ||
return self | ||
|
||
def transform(self, X: TXPandas, y: Any = None, **fit_params: Any) -> TXPandas: | ||
if self.on_transform: | ||
try: | ||
self._report(X, "transform") | ||
except Exception as e: | ||
LOG.warning( | ||
f"Failed to report non-finite values in X during transform: {e}" | ||
) | ||
LOG.exception(e) | ||
return X | ||
|
||
def _report(self, X: TXPandas, caller: str = "") -> TXPandas: | ||
"""Report non-finite values in X. | ||
Parameters | ||
---------- | ||
X : TXPandas | ||
Input data. | ||
caller : str, optional | ||
The caller name used in the log message, by default "". | ||
Returns | ||
------- | ||
TXPandas | ||
Input data. | ||
""" | ||
is_na = X.isna() | ||
is_inf = X.isin([np.inf, -np.inf]) | ||
is_non_finite = is_na | is_inf | ||
|
||
d: dict[str, DataFrame | Series] = { | ||
"nan_rate_by_column": is_na.mean(), | ||
"inf_rate_by_column": is_inf.mean(), | ||
"nan_rate_by_row": is_na.mean(axis=1), | ||
"inf_rate_by_row": is_inf.mean(axis=1), | ||
} | ||
d = d | { | ||
"non_finite_rate_by_column": d["nan_rate_by_column"] | ||
+ d["inf_rate_by_column"], | ||
"non_finite_rate_by_row": d["nan_rate_by_row"] + d["inf_rate_by_row"], | ||
} | ||
|
||
if self.calc_corr: | ||
d["nan_rate_corr_by_column"] = is_na.corr() | ||
d["inf_rate_corr_by_column"] = is_inf.corr() | ||
d["non_finite_rate_corr_by_column"] = is_non_finite.corr() | ||
|
||
LOG.info(f"Non-finite values in X during {caller}: {d}") | ||
|
||
if self.plot: | ||
import seaborn as sns | ||
|
||
fig, axes = plt.subplots(3, 3 if self.calc_corr else 2, figsize=(20, 10)) | ||
fig.suptitle(f"Non-finite values in X during {caller}") | ||
d["nan_rate_by_column"].plot( | ||
ax=axes[0, 0], | ||
kind="bar", | ||
title="NaN rate By column", | ||
xlabel="column name", | ||
ylabel="NaN rate", | ||
) | ||
d["inf_rate_by_column"].plot( | ||
ax=axes[1, 0], | ||
kind="bar", | ||
title="Inf rate By column", | ||
xlabel="column name", | ||
ylabel="Inf rate", | ||
) | ||
d["non_finite_rate_by_column"].plot( | ||
ax=axes[2, 0], | ||
kind="bar", | ||
title="Non-finite rate By column", | ||
xlabel="column name", | ||
ylabel="Non-finite rate", | ||
) | ||
d["nan_rate_by_row"].plot( | ||
ax=axes[0, 1], | ||
kind="line", | ||
title="NaN rate By row", | ||
xlabel="row index", | ||
ylabel="NaN rate", | ||
) | ||
d["inf_rate_by_row"].plot( | ||
ax=axes[1, 1], | ||
kind="line", | ||
title="Inf rate By row", | ||
xlabel="row index", | ||
ylabel="Inf rate", | ||
) | ||
d["non_finite_rate_by_row"].plot( | ||
ax=axes[2, 1], | ||
kind="line", | ||
title="Non-finite rate By row", | ||
xlabel="row index", | ||
ylabel="Non-finite rate", | ||
) | ||
if self.calc_corr: | ||
sns.heatmap( | ||
d["nan_rate_corr_by_column"], ax=axes[0, 2], vmin=-1, vmax=1 | ||
) | ||
axes[0, 2].set_title("NaN rate Corr By column") | ||
sns.heatmap( | ||
d["inf_rate_corr_by_column"], ax=axes[1, 2], vmin=-1, vmax=1 | ||
) | ||
axes[1, 2].set_title("Inf rate Corr By column") | ||
sns.heatmap( | ||
d["non_finite_rate_corr_by_column"], ax=axes[2, 2], vmin=-1, vmax=1 | ||
) | ||
axes[2, 2].set_title("Non-finite rate Corr By column") | ||
|
||
# tight layout | ||
plt.tight_layout() | ||
|
||
# callback | ||
if self.callback_figure is not None: | ||
self.callback_figure(fig) | ||
|
||
if self.callback is not None: | ||
self.callback(d) | ||
return X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from pandas.testing import assert_frame_equal | ||
|
||
from sklearn_utilities.reindex_missing_columns import ReindexMissingColumns | ||
|
||
|
||
def test_reindex_missing_columns() -> None: | ||
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) | ||
df_missing = df.drop(columns=["a"], inplace=False) | ||
df_expected = pd.DataFrame({"a": [np.nan] * 3, "b": [4, 5, 6]}) | ||
estimator = ReindexMissingColumns().fit(df) | ||
|
||
with pytest.warns(UserWarning): | ||
df_out = estimator.transform(df_missing) | ||
|
||
assert_frame_equal(df_out, df_expected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Any | ||
|
||
from pandas import DataFrame | ||
from pandas.testing import assert_frame_equal | ||
|
||
from sklearn_utilities.report_non_finite import ReportNonFinite | ||
|
||
|
||
def test_fit(caplog: Any) -> None: | ||
caplog.set_level("DEBUG") | ||
X = DataFrame( | ||
{ | ||
"clean": [1, 2, 3], | ||
"inf": [4, float("inf"), 6], | ||
"nan": [7, 8, float("nan")], | ||
"both": [float("inf"), float("nan"), 9], | ||
} | ||
) | ||
transformer = ReportNonFinite(on_fit=True, calc_corr=True) | ||
transformer.fit(X) | ||
# Add assertions to check the expected behavior of the fit method | ||
assert "Non-finite values in X during fit" in caplog.text | ||
|
||
|
||
def test_transform(caplog: Any) -> None: | ||
caplog.set_level("DEBUG") | ||
X = DataFrame( | ||
{ | ||
"clean": [1, 2, 3], | ||
"inf": [4, float("inf"), 6], | ||
"nan": [7, 8, float("nan")], | ||
"both": [float("inf"), float("nan"), 9], | ||
} | ||
) | ||
transformer = ReportNonFinite(on_transform=True, calc_corr=True) | ||
X_transformed = transformer.transform(X) | ||
# Add assertions to check the expected behavior of the transform method | ||
assert "Non-finite values in X during transform" in caplog.text | ||
assert_frame_equal(X_transformed, X) |