From 1c42f6643f9241089e55a4d899f14da480021e16 Mon Sep 17 00:00:00 2001
From: Tobias Kupek
Date: Fri, 29 Oct 2021 12:42:36 +0200
Subject: [PATCH] MAP metric - fix empty predictions (#594)

* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                  |  8 +++++---
 tests/detection/test_map.py   | 17 +++++++++++++++++
 torchmetrics/detection/map.py |  2 +-
 3 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 47801a831fa..5707c02edb0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,16 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))
+
 
 ## [0.6.0] - 2021-10-28
 
 ### Added
 
 - Added audio metrics:
-  - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
-  - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
+  - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
+  - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
 - Added Information retrieval metrics:
-  - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))
+  - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577))
   - `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576))
 - Added NLP metrics:
   - `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546))
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index 00fbc19848c..610a275c63e 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -206,6 +206,23 @@ def test_error_on_wrong_init():
         MAP(class_metrics=0)
 
 
+@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
+def test_empty_preds():
+    """Test empty predictions."""
+
+    metric = MAP()
+
+    metric.update(
+        [
+            dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+        ],
+        [
+            dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
+        ],
+    )
+    metric.compute()
+
+
 @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
 def test_error_on_wrong_input():
     """Test class input validation."""
diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py
index 3e3b9cbb154..a5d2397de50 100644
--- a/torchmetrics/detection/map.py
+++ b/torchmetrics/detection/map.py
@@ -365,7 +365,7 @@ def _get_coco_format(
         annotations = []
         annotation_id = 1  # has to start with 1, otherwise COCOEval results are wrong
 
-        boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") for box in boxes]
+        boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if boxes[0].size(1) == 4 else box for box in boxes]
         for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
            image_boxes = image_boxes.cpu().tolist()
             image_labels = image_labels.cpu().tolist()
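
A note on the guard (not part of the patch): `torchvision.ops.box_convert` unbinds
the last tensor dimension into four coordinates, so the degenerate `(1, 0)` tensor
produced by an empty prediction makes it raise before this fix. A minimal sketch of
the guarded conversion in isolation, assuming a hypothetical helper name and
illustrative tensors:

    import torch
    from torchvision.ops import box_convert

    def convert_if_nonempty(boxes):
        # Mirrors the guard added in _get_coco_format: only convert when the
        # per-image tensors actually carry four xyxy coordinates.
        return [
            box_convert(box, in_fmt="xyxy", out_fmt="xywh") if boxes[0].size(1) == 4 else box
            for box in boxes
        ]

    print(convert_if_nonempty([torch.Tensor([[]])]))  # empty (1, 0) tensor passes through untouched
    print(convert_if_nonempty([torch.Tensor([[10.0, 20.0, 30.0, 60.0]])]))  # -> [[10., 20., 20., 40.]]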