diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a35b5b78a0..1eee142a1a5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,7 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
-- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594))
+- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594), [#624](https://github.com/PyTorchLightning/metrics/pull/624))
 
 - Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606))
 
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index 0607385a397..0e6eb22bf5a 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -181,7 +181,6 @@ class TestMAP(MetricTester):
     @pytest.mark.parametrize("ddp", [False, True])
     def test_map(self, ddp):
         """Test modular implementation for correctness."""
-
         self.run_class_metric_test(
             ddp=ddp,
             preds=_inputs.preds,
@@ -198,7 +197,6 @@ def test_map(self, ddp):
 @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
 def test_error_on_wrong_init():
     """Test class raises the expected errors."""
-
     MAP()  # no error
 
     with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"):
@@ -208,7 +206,6 @@ def test_error_on_wrong_init():
 @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
 def test_empty_preds():
     """Test empty predictions."""
-
     metric = MAP()
 
     metric.update(
@@ -219,13 +216,28 @@ def test_empty_preds():
             dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
         ],
     )
+
+    metric.update(
+        [
+            dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+        ],
+        [
+            dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
+        ],
+    )
+    metric.compute()
+
+
+@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
+def test_empty_metric():
+    """Test empty metric."""
+    metric = MAP()
     metric.compute()
 
 
 @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
 def test_error_on_wrong_input():
     """Test class input validation."""
-
     metric = MAP()
 
     metric.update([], [])  # no error
diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py
index eb4786d5eb7..24c1ba56a81 100644
--- a/torchmetrics/detection/map.py
+++ b/torchmetrics/detection/map.py
@@ -97,7 +97,6 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
 def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> None:
     """Ensure the correct input format of `preds` and `targets`"""
-
     if not isinstance(preds, Sequence):
         raise ValueError("Expected argument `preds` to be of type List")
     if not isinstance(targets, Sequence):
         raise ValueError("Expected argument `target` to be of type List")
@@ -139,6 +138,13 @@ def _input_validator(preds: List[Dict[str, torch.Tensor]], targets: List[Dict[st
         )
 
 
+def _fix_empty_tensors(boxes: torch.Tensor) -> torch.Tensor:
+    """Empty tensors can cause problems in DDP mode, this methods corrects them."""
+    if boxes.numel() == 0 and boxes.ndim == 1:
+        return boxes.unsqueeze(0)
+    return boxes
+
+
 class MAP(Metric):
     r"""
     Computes the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)\
@@ -273,12 +279,12 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]
         _input_validator(preds, target)
 
         for item in preds:
-            self.detection_boxes.append(item["boxes"])
+            self.detection_boxes.append(_fix_empty_tensors(item["boxes"]))
             self.detection_scores.append(item["scores"])
             self.detection_labels.append(item["labels"])
 
         for item in target:
-            self.groundtruth_boxes.append(item["boxes"])
+            self.groundtruth_boxes.append(_fix_empty_tensors(item["boxes"]))
             self.groundtruth_labels.append(item["labels"])
 
     def compute(self) -> dict:
@@ -325,7 +331,7 @@ def compute(self) -> dict:
         if self.class_metrics:
             map_per_class_list = []
             mar_100_per_class_list = []
-            for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist():
+            for class_id in self._get_classes():
                 coco_eval.params.catIds = [class_id]
                 with _hide_prints():
                     coco_eval.evaluate()
@@ -363,12 +369,14 @@ def _get_coco_format(
 
         Format is defined at https://cocodataset.org/#format-data
         """
-
         images = []
         annotations = []
         annotation_id = 1  # has to start with 1, otherwise COCOEval results are wrong
 
-        boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.size(1) == 4 else box for box in boxes]
+        boxes = [
+            box_convert(box, in_fmt="xyxy", out_fmt="xywh") if box.ndim > 1 and box.size(1) == 4 else box
+            for box in boxes
+        ]
         for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)):
             image_boxes = image_boxes.cpu().tolist()
             image_labels = image_labels.cpu().tolist()
@@ -405,8 +413,11 @@ def _get_coco_format(
                 annotations.append(annotation)
                 annotation_id += 1
 
-        classes = [
-            {"id": i, "name": str(i)}
-            for i in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
-        ]
+        classes = [{"id": i, "name": str(i)} for i in self._get_classes()]
         return {"images": images, "annotations": annotations, "categories": classes}
+
+    def _get_classes(self) -> list:
+        """Get list of unique classes depending on groundtruth_labels and detection_labels."""
+        if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0:
+            return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist()
+        return []
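
Note (not part of the patch): the short sketch below illustrates why `_fix_empty_tensors` reshapes empty boxes. An image with no predicted boxes arrives as `torch.Tensor([])`, a 1-d tensor of shape (0,); calling `.size(1)` on such a tensor raises an IndexError, which is what the old `box.size(1) == 4` check in `_get_coco_format` would hit, and per the helper's docstring empty tensors can also cause problems in DDP mode. Unsqueezing to shape (1, 0) gives the tensor a second dimension, so the new `box.ndim > 1 and box.size(1) == 4` guard can skip it safely. The helper body is copied from the patch; the surrounding script is only a demonstration.

import torch


def _fix_empty_tensors(boxes: torch.Tensor) -> torch.Tensor:
    """Copy of the helper added by this patch."""
    if boxes.numel() == 0 and boxes.ndim == 1:
        return boxes.unsqueeze(0)
    return boxes


empty = torch.Tensor([])           # shape (0,): an image with no predicted boxes
fixed = _fix_empty_tensors(empty)  # shape (1, 0): now has a second dim, so .size(1) is valid
print(empty.shape, fixed.shape)    # torch.Size([0]) torch.Size([1, 0])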
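A minimal usage sketch mirroring the new `test_empty_preds` / `test_empty_metric` cases, assuming pycocotools and torchvision >= 0.8.0 are installed as the test skip markers require: an update with ground-truth boxes but an empty prediction, followed by `compute()`, is exactly the situation this change is meant to handle.

import torch
from torchmetrics.detection.map import MAP

metric = MAP()
metric.update(
    [dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([]))],
    [dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4]))],
)
result = metric.compute()  # dict of mAP/mAR tensors
print(result["map"])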