Skip to content

Commit

Permalink
Merge pull request #172 from IFCA/fix-msprt-fit
Browse files Browse the repository at this point in the history
Add missing on_fit_end in mSPRT
  • Loading branch information
jaime-cespedes-sisniega committed Apr 28, 2023
2 parents bdc4fa0 + f90c517 commit 18e9c67
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 7 deletions.
6 changes: 3 additions & 3 deletions frouros/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def set_detector(self, detector) -> None:
# )
# self._detector = value

def on_fit_start(self) -> None:
def on_fit_start(self, **kwargs) -> None:
"""On fit start method."""

def on_fit_end(self) -> None:
def on_fit_end(self, **kwargs) -> None:
"""On fit end method."""

def on_drift_detected(self) -> None:
def on_drift_detected(self, **kwargs) -> None:
"""On drift detected method."""

@abc.abstractmethod
Expand Down
6 changes: 5 additions & 1 deletion frouros/callbacks/streaming/msprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class mSPRT(StreamingCallback): # noqa: N801 # pylint: disable=invalid-name

def __init__(
self,
alpha: float = 0.05,
alpha: float,
sigma: float = 1.0,
tau: Optional[float] = None,
truncation: int = 1,
Expand Down Expand Up @@ -97,6 +97,10 @@ def tau(self, value: Optional[float]) -> None:
raise TypeError("tau must be a float or None")
self._tau = value

def on_fit_end(self, **kwargs) -> None:
"""On fit end method."""
self.incremental_mean.num_values = len(kwargs["X"])

def on_update_end(self, value: Union[int, float], **kwargs) -> None:
"""On update end method.
Expand Down
2 changes: 1 addition & 1 deletion frouros/detectors/data_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
callback.on_fit_start()
self._fit(X=X, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_fit_end()
callback.on_fit_end(X=X, **kwargs)

logs = self._get_callbacks_logs()
return logs
Expand Down
2 changes: 1 addition & 1 deletion frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_streaming_warning_samples_buffer_on_concept_drift(
" expected_p_value,"
" expected_likelihood",
[
(MMDStreaming, 70, 0.3622854, 0.0324733, 30.79452443),
(MMDStreaming, 40, 0.08821576, 0.00494882, 202.06836342),
],
)
def test_streaming_msprt_multivariate_different_distribution(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frouros"
version = "0.2.4"
version = "0.2.5"
description = "A Python library for drift detection in Machine Learning problems"
authors = [
{name = "Jaime Céspedes Sisniega", email = "cespedes@ifca.unican.es"}
Expand Down

0 comments on commit 18e9c67

Please sign in to comment.