
MAP metric - fix empty predictions #594

Merged
8 commits merged on Oct 29, 2021
8 changes: 5 additions & 3 deletions CHANGELOG.md
@@ -23,16 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))


## [0.6.0] - 2021-10-28

### Added

- Added audio metrics:
-  - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
-  - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))
+  - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
+  - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/pull/353))
- Added Information retrieval metrics:
-  - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))
+  - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577))
- `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576))
- Added NLP metrics:
- `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546))
17 changes: 17 additions & 0 deletions tests/detection/test_map.py
@@ -206,6 +206,23 @@ def test_error_on_wrong_init():
MAP(class_metrics=0)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision>=0.8.0 are installed")
def test_empty_preds():
"""Test empty predictions."""

metric = MAP()

metric.update(
[
dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])),  # no predicted boxes for this image
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),  # one ground-truth box (xyxy)
],
)
metric.compute()  # should run without raising despite the empty predictions


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision>=0.8.0 are installed")
def test_error_on_wrong_input():
"""Test class input validation."""
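For reference, here is the scenario the new test exercises, as a standalone snippet that could be run outside the test suite. This is only a sketch, not code from the PR: the import path `torchmetrics.detection.map` matches the file touched below, and like the test it assumes pycocotools and torchvision>=0.8.0 are installed. The expectation is simply that `compute()` no longer raises when one of the images comes with an empty `boxes` tensor.

```python
import torch
from torchmetrics.detection.map import MAP  # import path inferred from the file changed below

metric = MAP()

# One image for which the model produced no detections at all.
preds = [
    dict(
        boxes=torch.Tensor([[]]),        # shape (1, 0): no predicted boxes
        scores=torch.Tensor([]),
        labels=torch.IntTensor([]),
    )
]

# The corresponding ground truth still contains one annotated box (xyxy format).
targets = [
    dict(
        boxes=torch.Tensor([[214.15, 41.29, 562.41, 285.07]]),
        labels=torch.IntTensor([4]),
    )
]

metric.update(preds, targets)
result = metric.compute()  # with this fix, the empty prediction no longer causes an error
print(result)
```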
2 changes: 1 addition & 1 deletion torchmetrics/detection/map.py
@@ -365,7 +365,7 @@ def _get_coco_format(
annotations = []
annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong

-boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") for box in boxes]
+boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if boxes[0].size(1) == 4 else box for box in boxes]
for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
image_boxes = image_boxes.cpu().tolist()
image_labels = image_labels.cpu().tolist()
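For context on the one-line change above: an image with no detections reaches `_get_coco_format` with a `boxes` tensor of shape `(1, 0)`, which has no coordinate columns, and that is what tripped up the previously unconditional `box_convert` call. The new guard only converts when the boxes really have 4 columns. Below is a minimal, torchvision-free sketch of that guard logic; the helper name and the manual xyxy-to-xywh math are illustrative, not the library's code.

```python
import torch


def convert_xyxy_to_xywh_safe(boxes: torch.Tensor) -> torch.Tensor:
    """Illustrative helper (hypothetical, not from torchmetrics) mirroring the guard above."""
    # Only convert when there really are 4 coordinate columns; an image with no
    # predictions yields an (N, 0) tensor that is passed through unchanged.
    if boxes.size(1) != 4:
        return boxes
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack((x1, y1, x2 - x1, y2 - y1), dim=-1)  # xyxy -> xywh


empty = torch.Tensor([[]])                               # shape (1, 0)
full = torch.Tensor([[214.15, 41.29, 562.41, 285.07]])   # shape (1, 4)

print(convert_xyxy_to_xywh_safe(empty).shape)  # torch.Size([1, 0]) -- left untouched
print(convert_xyxy_to_xywh_safe(full))         # tensor([[214.15, 41.29, 348.26, 243.78]])
```

Note that the actual change keys the check off `boxes[0].size(1)`, i.e. the first image's boxes, and applies the same decision to every image in the batch.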