From 7143321493f42975e2d892a7c82c6191a7373d91 Mon Sep 17 00:00:00 2001
From: Dhananjai Sharma
Date: Sun, 27 Nov 2022 20:08:05 +0800
Subject: [PATCH 1/4] Add test for case when there are preds but no GT for a class

---
 tests/unittests/detection/test_map.py | 47 +++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py
index bd63e9676ed..c9a2bd8cdae 100644
--- a/tests/unittests/detection/test_map.py
+++ b/tests/unittests/detection/test_map.py
@@ -534,6 +534,53 @@ def test_missing_gt():
     assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions."


+@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
+def test_class_metrics_with_missing_gt():
+    """
+    Checks MAP for each class when there are 4 detections, each for a
+    different class. But there are targets for only 2 classes. Hence, MAP
+    should be lower than 1. MAP for classes with targets should be 1 and 0 for
+    the others.
+    """
+    # Example source: Issue https://github.com/Lightning-AI/metrics/issues/1184
+    preds = [
+        dict(
+            boxes=torch.Tensor(
+                [
+                    [0, 0, 20, 20],
+                    [30, 30, 50, 50],
+                    [70, 70, 90, 90],  # FP
+                    [100, 100, 120, 120],  # FP
+                ]
+            ),
+            scores=torch.Tensor([0.6, 0.6, 0.6, 0.6]),
+            labels=torch.IntTensor([0, 1, 2, 3]),
+        )
+    ]
+
+    targets = [
+        dict(
+            boxes=torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]),
+            labels=torch.IntTensor([0, 1]),
+        )
+    ]
+
+    metric = MeanAveragePrecision(class_metrics=True)
+    metric.update(preds, targets)
+    result = metric.compute()
+
+    assert result["map"] < 1, "MAP cannot be 1, as for class 2 and 3, there are some predictions, but not targets."
+
+    result_map_per_class = result.get("map_per_class", None)
+    assert result_map_per_class is not None, "map_per_class must be present in results."
+    assert isinstance(result_map_per_class, Tensor), "map_per_class must be a tensor"
+    assert len(result_map_per_class) == 4, "map_per_class must be of length 4, same as the number of classes."
+    assert result_map_per_class[0].item() == 1.0, "map for class 0 must be 1."
+    assert result_map_per_class[1].item() == 1.0, "map for class 1 must be 1."
+    assert result_map_per_class[2].item() == 0.0, "map for class 2 must be 0."
+    assert result_map_per_class[3].item() == 0.0, "map for class 3 must be 0."
+
+
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
 def test_segm_iou_empty_gt_mask():
     """Test empty ground truths."""

From 16219830abe553d0f73af628ca42d05e9815b350 Mon Sep 17 00:00:00 2001
From: Dhananjai Sharma
Date: Sun, 27 Nov 2022 21:36:13 +0800
Subject: [PATCH 2/4] Make Precision 0 when no GT but preds are made; -1 if no preds

---
 src/torchmetrics/detection/mean_ap.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index 69cc0edf574..049e51e9d12 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -827,21 +827,25 @@ def __calculate_recall_precision_scores(
             return recall, precision, scores

         det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])
-
         # different sorting method generates slightly different results.
         # mergesort is used to be consistent as Matlab implementation.
         # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0)
         dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype
         # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort
         inds = torch.argsort(det_scores.to(dtype), descending=True)
-        det_scores_sorted = det_scores[inds]
-
-        det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
         det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
         gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])
+
         npig = torch.count_nonzero(gt_ignore == False)  # noqa: E712
         if npig == 0:
+            # If there are any predictions, make Precision 0; otherwise, -1.
+            npreds = torch.count_nonzero(det_ignore == False)  # noqa: E712
+            if npreds != 0:
+                precision[:, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = 0.0
             return recall, precision, scores
+
+        det_scores_sorted = det_scores[inds]
+        det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]

         tps = torch.logical_and(det_matches, torch.logical_not(det_ignore))
         fps = torch.logical_and(torch.logical_not(det_matches), torch.logical_not(det_ignore))

From 8a12f5f6868dbf7361ee5f55f4cd0913ad8c8240 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 27 Nov 2022 13:40:06 +0000
Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/unittests/detection/test_map.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py
index c9a2bd8cdae..82c70640022 100644
--- a/tests/unittests/detection/test_map.py
+++ b/tests/unittests/detection/test_map.py
@@ -536,11 +536,10 @@ def test_missing_gt():

 @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
 def test_class_metrics_with_missing_gt():
-    """
-    Checks MAP for each class when there are 4 detections, each for a
-    different class. But there are targets for only 2 classes. Hence, MAP
-    should be lower than 1. MAP for classes with targets should be 1 and 0 for
-    the others.
+    """Checks MAP for each class when there are 4 detections, each for a different class.
+
+    But there are targets for only 2 classes. Hence, MAP should be lower than 1. MAP for classes with targets should be
+    1 and 0 for the others.
     """
     # Example source: Issue https://github.com/Lightning-AI/metrics/issues/1184
     preds = [

From d6f538e299e58142443d001c34e3222b78621821 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 6 Feb 2023 14:17:21 +0000
Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/unittests/detection/test_map.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py
index 83b44badcd8..83153f2d47f 100644
--- a/tests/unittests/detection/test_map.py
+++ b/tests/unittests/detection/test_map.py
@@ -566,8 +566,8 @@ def test_class_metrics_with_missing_gt():
     """
     # Example source: Issue https://github.com/Lightning-AI/metrics/issues/1184
     preds = [
-        dict(
-            boxes=torch.Tensor(
+        {
+            "boxes": torch.Tensor(
                 [
                     [0, 0, 20, 20],
                     [30, 30, 50, 50],
@@ -575,16 +575,16 @@ def test_class_metrics_with_missing_gt():
                     [100, 100, 120, 120],  # FP
                 ]
             ),
-            scores=torch.Tensor([0.6, 0.6, 0.6, 0.6]),
-            labels=torch.IntTensor([0, 1, 2, 3]),
-        )
+            "scores": torch.Tensor([0.6, 0.6, 0.6, 0.6]),
+            "labels": torch.IntTensor([0, 1, 2, 3]),
+        }
     ]

     targets = [
-        dict(
-            boxes=torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]),
-            labels=torch.IntTensor([0, 1]),
-        )
+        {
+            "boxes": torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]),
+            "labels": torch.IntTensor([0, 1]),
+        }
     ]

     metric = MeanAveragePrecision(class_metrics=True)