Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MAP metric for empty cases #624

Merged
merged 39 commits into from
Nov 25, 2021
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
13de2f7
add detection map code example
Oct 29, 2021
6940092
update setup
Borda Nov 1, 2021
fe4f1dc
simplify named tuple
Nov 1, 2021
276de8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
6b8d711
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
3b2f486
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
47dac62
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
79a06bf
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
ae201a0
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
297c286
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
473a2ee
Update tm_examples/detection_map.py
tkupek Nov 1, 2021
7865aa0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
e118db8
add some more comments
Nov 1, 2021
8e290b1
Merge branch 'master' into master
tkupek Nov 1, 2021
edec44f
add some more comments
Nov 1, 2021
ce90b25
add example hint in metric docstring
Nov 1, 2021
0cc8be9
Merge branch 'master' into master
mergify[bot] Nov 1, 2021
e628890
Merge branch 'PyTorchLightning:master' into master
tkupek Nov 8, 2021
c5b48cd
Merge branch 'PyTorchLightning:master' into master
tkupek Nov 10, 2021
0508f6c
Merge branch 'PyTorchLightning:master' into master
tkupek Nov 15, 2021
f9fe2c3
fix evaluation for empty metric
Nov 15, 2021
df2dfc2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2021
a42551a
Merge branch 'master' into master
tkupek Nov 18, 2021
9e49c58
Update torchmetrics/detection/map.py
tkupek Nov 18, 2021
608d258
Merge branch 'master' into master
Borda Nov 18, 2021
33f227c
Merge branch 'master' into master
Borda Nov 23, 2021
5c18985
Merge branch 'master' into master
SkafteNicki Nov 24, 2021
3f13aee
Merge branch 'master' into master
Borda Nov 24, 2021
5d74b51
fix deepsource stuff
Nov 24, 2021
89a9dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
ea2f5e5
Merge branch 'master' into master
tkupek Nov 24, 2021
fb75aff
update changelog
Nov 24, 2021
70c71c9
fix ddp issue in multi GPU setup for empty boxes
Nov 24, 2021
78d84ef
update doc
Nov 24, 2021
7660c6d
Merge branch 'master' into master
Borda Nov 24, 2021
c9c6b9a
Merge pull request #2 from tkupek/map-ddp-issue
tkupek Nov 24, 2021
053f487
simplify empty tensors fix
Nov 24, 2021
ee16652
fix mypy
Nov 24, 2021
ff78b44
fix failing unittests
Nov 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))
- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594), [#624](https://github.com/PyTorchLightning/metrics/pull/624))


- Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
Expand Down
20 changes: 16 additions & 4 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class TestMAP(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
def test_map(self, ddp):
"""Test modular implementation for correctness."""

self.run_class_metric_test(
ddp=ddp,
preds=_inputs.preds,
Expand All @@ -198,7 +197,6 @@ def test_map(self, ddp):
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""

MAP() # no error

with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"):
Expand All @@ -208,7 +206,6 @@ def test_error_on_wrong_init():
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""

metric = MAP()

metric.update(
Expand All @@ -219,13 +216,28 @@ def test_empty_preds():
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)

metric.update(
[
dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_metric():
"""Test empty metric."""
metric = MAP()
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""

metric = MAP()

metric.update([], []) # no error
Expand Down
20 changes: 12 additions & 8 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore

def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> None:
"""Ensure the correct input format of `preds` and `targets`"""

if not isinstance(preds, Sequence):
raise ValueError("Expected argument `preds` to be of type List")
if not isinstance(targets, Sequence):
Expand Down Expand Up @@ -325,7 +324,7 @@ def compute(self) -> dict:
if self.class_metrics:
map_per_class_list = []
mar_100_per_class_list = []
for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist():
for class_id in self._get_classes():
coco_eval.params.catIds = [class_id]
with _hide_prints():
coco_eval.evaluate()
Expand Down Expand Up @@ -363,12 +362,14 @@ def _get_coco_format(

Format is defined at https://cocodataset.org/#format-data
"""

images = []
annotations = []
annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong

boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.size(1) == 4 else box for box in boxes]
boxes = [
box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.size() == torch.Size([1, 4]) else box
for box in boxes
]
for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
image_boxes = image_boxes.cpu().tolist()
image_labels = image_labels.cpu().tolist()
Expand Down Expand Up @@ -405,8 +406,11 @@ def _get_coco_format(
annotations.append(annotation)
annotation_id += 1

classes = [
{"id": i, "name": str(i)}
for i in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
]
classes = [{"id": i, "name": str(i)} for i in self._get_classes()]
return {"images": images, "annotations": annotations, "categories": classes}

def _get_classes(self) -> list:
"""Get list of unique classes depending on groundtruth_labels and detection_labels."""
if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0:
return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
return []