Skip to content

Commit

Permalink
Suppress warnings in MeanAveragePrecision when intended (#2501)
Browse files Browse the repository at this point in the history
* suppress warnings when needed

* changelog
  • Loading branch information
SkafteNicki committed Apr 15, 2024
1 parent 9e83421 commit 5892316
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501))


## [1.3.2] - 2024-03-18

### Fixed
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,9 @@ def _get_safe_item_values(
rle = self.mask_utils.encode(np.asfortranarray(i))
masks.append((tuple(rle["size"]), rle["counts"]))
output[1] = tuple(masks) # type: ignore[call-overload]
if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or (
output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1]
if warn and (
(output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1])
or (output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1])
):
_warning_on_too_many_detections(self.max_detection_thresholds[-1])
return output # type: ignore[return-value]
Expand Down
12 changes: 9 additions & 3 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ def test_for_box_format(self, box_format, iou_val_expected, map_val_expected, ba
assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected

@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
def test_warning_on_many_detections(self, iou_type, backend):
@pytest.mark.parametrize("warn_on_many_detections", [False, True])
def test_warning_on_many_detections(self, iou_type, warn_on_many_detections, backend, recwarn):
"""Test that a warning is raised when there are many detections."""
if iou_type == "bbox":
preds = [
Expand All @@ -727,8 +728,13 @@ def test_warning_on_many_detections(self, iou_type, backend):
preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False)

metric = MeanAveragePrecision(iou_type=iou_type, backend=backend)
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)
metric.warn_on_many_detections = warn_on_many_detections

if warn_on_many_detections:
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)
else:
assert len(recwarn) == 0

@pytest.mark.parametrize(
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"),
Expand Down

0 comments on commit 5892316

Please sign in to comment.