Skip to content

Commit

Permalink
update code (#995)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Apr 29, 2022
1 parent 3a141ae commit 9011ec9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -76,6 +76,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `CalibrationError` to work on logit input ([#985](https://github.com/PyTorchLightning/metrics/pull/985))


- Fixed MAP metric when using custom list of thresholds ([#995](https://github.com/PyTorchLightning/metrics/issues/995))


## [0.8.0] - 2022-04-14

### Added
Expand Down
13 changes: 13 additions & 0 deletions tests/detection/test_map.py
Expand Up @@ -313,6 +313,19 @@ def test_map_gpu(inputs):
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
def test_map_with_custom_thresholds():
"""Test that map works with custom iou thresholds."""
metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2])
metric = metric.to("cuda")
for preds, targets in zip(_inputs.preds, _inputs.target):
metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
res = metric.compute()
assert res["map_50"].item() == -1
assert res["map_75"].item() == -1


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_metric():
"""Test empty metric."""
Expand Down
14 changes: 10 additions & 4 deletions torchmetrics/detection/mean_ap.py
Expand Up @@ -653,8 +653,14 @@ def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMe
map_metrics = MAPMetricResults()
map_metrics.map = self._summarize(results, True)
last_max_det_thr = self.max_detection_thresholds[-1]
map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr)
map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr)
if 0.5 in self.iou_thresholds:
map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr)
else:
map_metrics.map_50 = torch.tensor([-1])
if 0.75 in self.iou_thresholds:
map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr)
else:
map_metrics.map_75 = torch.tensor([-1])
map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_thr)
map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_thr)
map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_thr)
Expand Down Expand Up @@ -750,8 +756,6 @@ def compute(self) -> dict:
dict containing
- map: ``torch.Tensor``
- map_50: ``torch.Tensor``
- map_75: ``torch.Tensor``
- map_small: ``torch.Tensor``
- map_medium: ``torch.Tensor``
- map_large: ``torch.Tensor``
Expand All @@ -761,6 +765,8 @@ def compute(self) -> dict:
- mar_small: ``torch.Tensor``
- mar_medium: ``torch.Tensor``
- mar_large: ``torch.Tensor``
- map_50: ``torch.Tensor`` (-1 if 0.5 not in the list of iou thresholds)
- map_75: ``torch.Tensor`` (-1 if 0.75 not in the list of iou thresholds)
- map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
- mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
"""
Expand Down

0 comments on commit 9011ec9

Please sign in to comment.