Skip to content

Commit

Permalink
fix quantile detector when low/high threshold are the same (unit8co#1553
Browse files Browse the repository at this point in the history
)

* fix quantile detector when low/high threshold are the same

* add test

* Update darts/ad/detectors/threshold_detector.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* change syntax

* pre commit

* pre commit

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
4 people authored and alexcolpitts96 committed May 31, 2023
1 parent 107fdc4 commit efd1616
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
5 changes: 3 additions & 2 deletions darts/ad/detectors/quantile_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ def _prep_quantile(q):
raise_if_not(
all(
[
l < h
l <= h
for (l, h) in zip(self.low_quantile, self.high_quantile)
if ((l is not None) and (h is not None))
]
),
"all values in `low_quantile` must be lower than their corresponding value in `high_quantile`.",
"all values in `low_quantile` must be lower than or equal"
+ "to their corresponding value in `high_quantile`.",
)

def _fit_core(self, list_series: Sequence[TimeSeries]) -> None:
Expand Down
5 changes: 3 additions & 2 deletions darts/ad/detectors/threshold_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@ def _prep_thresholds(q):
raise_if_not(
all(
[
l < h
l <= h
for (l, h) in zip(self.low_threshold, self.high_threshold)
if ((l is not None) and (h is not None))
]
),
"all values in `low_threshold` must be lower than their corresponding value in `high_threshold`.",
"all values in `low_threshold` must be lower than or equal"
+ "to their corresponding value in `high_threshold`.",
)

def _detect_core(self, series: TimeSeries) -> TimeSeries:
Expand Down
9 changes: 7 additions & 2 deletions darts/tests/ad/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_QuantileDetector(self):
with self.assertRaises(ValueError):
QuantileDetector(low_quantile=[-0.2, 0.3])

# Parameter high must be higher than parameter low
# Parameter high must be higher or equal than parameter low
with self.assertRaises(ValueError):
QuantileDetector(low_quantile=0.7, high_quantile=0.2)
with self.assertRaises(ValueError):
Expand All @@ -217,6 +217,11 @@ def test_QuantileDetector(self):
with self.assertRaises(ValueError):
QuantileDetector(low_quantile=[None], high_quantile=[None, None, None])

# check that low_threshold and high_threshold are the same and no errors are raised
detector = QuantileDetector(low_quantile=0.5, high_quantile=0.5)
detector.fit(self.train)
self.assertEqual(detector.low_threshold, detector.high_threshold)

# widths of series used for fitting must match the number of values given for high or/and low,
# if high and low have a length higher than 1

Expand Down Expand Up @@ -541,7 +546,7 @@ def test_ThresholdDetector(self):
with self.assertRaises(ValueError):
ThresholdDetector(low_threshold=[0.2, 0.1, 0.7], high_threshold=[0.95, 0.8])

# Parameter high must be higher than parameter low
# Parameter high must be higher or equal than parameter low
with self.assertRaises(ValueError):
ThresholdDetector(low_threshold=0.7, high_threshold=0.2)
with self.assertRaises(ValueError):
Expand Down

0 comments on commit efd1616

Please sign in to comment.