Add support for SQ & RQ as well as per-class metrics #2381

Merged on Mar 25, 2024

Commits (40):
- 3c4388e Add support for SQ & RQ as well as per-class metrics (ChristophReich1996, Feb 14, 2024)
- bd62a9f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 14, 2024)
- 967dab8 Fix RQ and SQ (ChristophReich1996, Feb 14, 2024)
- e1d507f Merge remote-tracking branch 'origin/master' (ChristophReich1996, Feb 14, 2024)
- be1a4a0 Change return type and refactor flag name (ChristophReich1996, Feb 15, 2024)
- a7deffd Fix typing (ChristophReich1996, Feb 15, 2024)
- d897c9a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 15, 2024)
- 9be7aa4 Merge branch 'master' into master (SkafteNicki, Feb 15, 2024)
- 35f15d5 changelog (SkafteNicki, Feb 15, 2024)
- cf8e0f5 input/output docstring (SkafteNicki, Feb 15, 2024)
- e860aed [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 15, 2024)
- 88e548f Merge branch 'master' into master (SkafteNicki, Feb 15, 2024)
- 4e2316e Merge branch 'master' into master (SkafteNicki, Feb 16, 2024)
- 7f1859a Merge branch 'master' into master (Borda, Feb 27, 2024)
- 677836e Merge branch 'master' into master (SkafteNicki, Mar 6, 2024)
- 9b70dfd guard against older versions (SkafteNicki, Mar 6, 2024)
- f3d33c8 skip doctests on older versions (SkafteNicki, Mar 6, 2024)
- 43ec024 Merge branch 'master' into master (mergify[bot], Mar 7, 2024)
- ec12562 Merge branch 'master' into master (mergify[bot], Mar 7, 2024)
- 9275430 Merge branch 'master' into master (mergify[bot], Mar 13, 2024)
- a15b706 Merge branch 'master' into master (Borda, Mar 14, 2024)
- 91ad069 Merge branch 'master' into master (mergify[bot], Mar 14, 2024)
- 07524f2 ci/gpu: do not fail if HF cache is not present (Borda, Mar 14, 2024)
- 8d60f73 Merge branch 'master' into master (mergify[bot], Mar 14, 2024)
- ef0cc27 ci/gpu: do not update ref on PR (Borda, Mar 14, 2024)
- 9fb488d build(deps): update fire requirement from <=0.5.0 to <=0.6.0 in /requ… (dependabot[bot], Mar 14, 2024)
- bf4d523 Merge branch 'master' into master (mergify[bot], Mar 14, 2024)
- 0240110 build(deps): bump pytest-timeout from 2.2.0 to 2.3.1 in /requirements… (dependabot[bot], Mar 14, 2024)
- fd1eb60 Merge branch 'master' into master (mergify[bot], Mar 14, 2024)
- 090e3ed ci/mergify: rename label `ready` (Borda, Mar 14, 2024)
- 3b29931 Merge branch 'master' into master (mergify[bot], Mar 14, 2024)
- 8c0b0cc Merge branch 'master' into ChristophReich1996/master (Borda, Mar 19, 2024)
- 4b66e41 Merge branch 'master' into master (Borda, Mar 22, 2024)
- 0e04dbe fix skipping on older versions (SkafteNicki, Mar 24, 2024)
- 699673d change order of skipping (SkafteNicki, Mar 24, 2024)
- 78cbbf8 Fix return type documentation of inception score (#2467) (furkan-celik, Mar 24, 2024)
- 1d650bf fixes (SkafteNicki, Mar 24, 2024)
- be7b55b Merge branch 'master' into master (mergify[bot], Mar 24, 2024)
- bc3c36c Merge branch 'master' of https://github.com/ChristophReich1996/torchm… (SkafteNicki, Mar 24, 2024)
- 63419ce Merge branch 'master' into master (mergify[bot], Mar 24, 2024)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288))


- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
9 changes: 9 additions & 0 deletions src/torchmetrics/detection/_deprecated.py
@@ -1,8 +1,17 @@
from typing import Any, Collection

from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_class

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = [
"_PanopticQuality",
"_PanopticQuality.*",
"_ModifiedPanopticQuality",
"_ModifiedPanopticQuality.*",
]


class _ModifiedPanopticQuality(ModifiedPanopticQuality):
"""Wrapper for deprecated import.
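The `__doctest_skip__` guard added above relies on the doctest collector honoring a module-level skip list, which is the behavior provided by the pytest-doctestplus plugin. A minimal self-contained sketch of the pattern, with `TORCH_GE_1_12` and `Wrapper` as illustrative stand-ins for the real flag and class:

```python
# doctest_skip_sketch.py -- a minimal sketch of the skip-list convention,
# assuming doctests are collected by pytest-doctestplus, which honors a
# module-level ``__doctest_skip__`` list of name patterns.
TORCH_GE_1_12 = False  # stand-in for _TORCH_GREATER_EQUAL_1_12

if not TORCH_GE_1_12:
    # Skip the class-level doctest and all method doctests ("Wrapper.*").
    __doctest_skip__ = ["Wrapper", "Wrapper.*"]


class Wrapper:
    """Example with a doctest that only passes on newer PyTorch.

    >>> Wrapper().value()
    42
    """

    def value(self) -> int:
        return 42
```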
83 changes: 80 additions & 3 deletions src/torchmetrics/detection/panoptic_qualities.py
@@ -26,13 +26,17 @@
_validate_inputs,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"]


if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["PanopticQuality", "PanopticQuality.*", "ModifiedPanopticQuality", "ModifiedPanopticQuality.*"]


class PanopticQuality(Metric):
r"""Compute the `Panoptic Quality`_ for panoptic segmentations.

@@ -47,6 +51,23 @@ class PanopticQuality(Metric):
Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
computation.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
be at least one spatial dimension.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
be at least one spatial dimension.

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``quality`` (:class:`~torch.Tensor`): If ``return_sq_and_rq=False`` and ``return_per_class=False`` then a
single scalar tensor is returned with the average panoptic quality over all classes. If ``return_sq_and_rq=True``
and ``return_per_class=False`` a tensor of length 3 is returned with the panoptic, segmentation and recognition
quality (in that order). If ``return_sq_and_rq=False`` and ``return_per_class=True`` a tensor of length
equal to the number of classes is returned, with the panoptic quality for each class. Finally, if both arguments
are ``True`` a tensor of shape ``(C, 3)`` is returned with the individual panoptic, segmentation and recognition
quality for each class.

Args:
things:
Set of ``category_id`` for countable things.
@@ -55,6 +76,10 @@ class PanopticQuality(Metric):
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
return_sq_and_rq:
Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned.
return_per_class:
Boolean flag to specify if the per-class values should be returned or the class average.


Raises:
@@ -80,6 +105,40 @@ class PanopticQuality(Metric):
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)

You can also return the segmentation and recognition quality alongside the PQ:
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
>>> panoptic_quality(preds, target)
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)

You can also return the per-class metrics:
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
>>> panoptic_quality(preds, target)
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)

"""

is_differentiable: bool = False
@@ -98,16 +157,22 @@ def __init__(
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
return_sq_and_rq: bool = False,
return_per_class: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not _TORCH_GREATER_EQUAL_1_12:
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")

things, stuffs = _parse_categories(things, stuffs)
self.things = things
self.stuffs = stuffs
self.void_color = _get_void_color(things, stuffs)
self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
self.allow_unknown_preds_category = allow_unknown_preds_category
self.return_sq_and_rq = return_sq_and_rq
self.return_per_class = return_per_class

# per category intermediate metrics
num_categories = len(things) + len(stuffs)
@@ -154,7 +219,16 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
)
if self.return_per_class:
if self.return_sq_and_rq:
return torch.stack((pq, sq, rq), dim=-1)
return pq.view(1, -1)
if self.return_sq_and_rq:
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
return pq_avg

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -337,7 +411,10 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
)
return pq_avg

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
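To make the new output contract concrete, here is a short sketch (assuming this branch is installed) that exercises all four flag combinations on the doctest tensors above and prints the resulting output shapes; the expected shapes follow from the `compute` logic in the diff:

```python
import torch
from torchmetrics.detection import PanopticQuality

preds = torch.tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
                       [[0, 0], [0, 0], [6, 0], [0, 1]],
                       [[0, 0], [0, 0], [6, 0], [0, 1]],
                       [[0, 0], [7, 0], [6, 0], [1, 0]],
                       [[0, 0], [7, 0], [7, 0], [7, 0]]]])
target = torch.tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
                        [[0, 1], [0, 1], [6, 0], [0, 1]],
                        [[0, 1], [0, 1], [6, 0], [1, 0]],
                        [[0, 1], [7, 0], [1, 0], [1, 0]],
                        [[0, 1], [7, 0], [7, 0], [7, 0]]]])

# Expected shapes with 4 classes (things={0, 1}, stuffs={6, 7}):
#   per_class=False, sq_and_rq=False -> ()       scalar average PQ
#   per_class=False, sq_and_rq=True  -> (3,)     [PQ, SQ, RQ] averages
#   per_class=True,  sq_and_rq=False -> (1, 4)   per-class PQ
#   per_class=True,  sq_and_rq=True  -> (4, 3)   per-class [PQ, SQ, RQ]
for per_class in (False, True):
    for sq_and_rq in (False, True):
        metric = PanopticQuality(
            things={0, 1}, stuffs={6, 7},
            return_per_class=per_class, return_sq_and_rq=sq_and_rq,
        )
        print(per_class, sq_and_rq, tuple(metric(preds, target).shape))
```

Note that the per-class-only case keeps a leading singleton dimension (`pq.view(1, -1)`), matching the doctest output above.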
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/detection/_deprecated.py
@@ -3,8 +3,12 @@
from torch import Tensor

from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_func

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["_panoptic_quality", "_modified_panoptic_quality"]


def _modified_panoptic_quality(
preds: Tensor,
src/torchmetrics/functional/detection/_panoptic_quality_common.py
@@ -449,7 +449,7 @@ def _panoptic_quality_compute(
true_positives: Tensor,
false_positives: Tensor,
false_negatives: Tensor,
) -> Tensor:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Compute the final panoptic quality from interim values.

Args:
Expand All @@ -459,11 +459,17 @@ def _panoptic_quality_compute(
false_negatives: the FN value from the update step

Returns:
Panoptic quality as a tensor containing a single scalar.
A tuple containing the per-class panoptic, segmentation and recognition quality, followed by their averages.

"""
# per category calculation
denominator = (true_positives + 0.5 * false_positives + 0.5 * false_negatives).double()
panoptic_quality = torch.where(denominator > 0.0, iou_sum / denominator, 0.0)
# Reduce across categories. TODO: is it useful to have the option of returning per class metrics?
return torch.mean(panoptic_quality[denominator > 0])
# compute segmentation and recognition quality (per-class)
sq: Tensor = torch.where(true_positives > 0.0, iou_sum / true_positives, 0.0)
denominator: Tensor = true_positives + 0.5 * false_positives + 0.5 * false_negatives
rq: Tensor = torch.where(denominator > 0.0, true_positives / denominator, 0.0)
# compute per-class panoptic quality
pq: Tensor = sq * rq
# compute averages
pq_avg: Tensor = torch.mean(pq[denominator > 0])
sq_avg: Tensor = torch.mean(sq[denominator > 0])
rq_avg: Tensor = torch.mean(rq[denominator > 0])
return pq, sq, rq, pq_avg, sq_avg, rq_avg
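For reference, the quantities computed above are the standard decomposition of panoptic quality; restating the code in math, per class, with the true/false positives/negatives and the IoU sum coming from the update step:

```latex
\mathrm{SQ} = \frac{\sum_{(p, g) \in \mathit{TP}} \mathrm{IoU}(p, g)}{|\mathit{TP}|}, \qquad
\mathrm{RQ} = \frac{|\mathit{TP}|}{|\mathit{TP}| + \tfrac{1}{2}|\mathit{FP}| + \tfrac{1}{2}|\mathit{FN}|}, \qquad
\mathrm{PQ} = \mathrm{SQ} \cdot \mathrm{RQ}
            = \frac{\sum_{(p, g) \in \mathit{TP}} \mathrm{IoU}(p, g)}{|\mathit{TP}| + \tfrac{1}{2}|\mathit{FP}| + \tfrac{1}{2}|\mathit{FN}|}
```

The `*_avg` values are means over classes with a positive denominator, so classes that appear in neither the predictions nor the targets are excluded, matching the `denominator > 0` masks in the code.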
80 changes: 78 additions & 2 deletions src/torchmetrics/functional/detection/panoptic_qualities.py
@@ -13,6 +13,7 @@
# limitations under the License.
from typing import Collection

import torch
from torch import Tensor

from torchmetrics.functional.detection._panoptic_quality_common import (
@@ -24,6 +25,10 @@
_prepocess_inputs,
_validate_inputs,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["panoptic_quality", "modified_panoptic_quality"]


def panoptic_quality(
@@ -32,6 +37,8 @@ def panoptic_quality(
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
return_sq_and_rq: bool = False,
return_per_class: bool = False,
) -> Tensor:
r"""Compute `Panoptic Quality`_ for panoptic segmentations.

@@ -61,6 +68,10 @@
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
return_sq_and_rq:
Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned.
return_per_class:
Boolean flag to specify if the per-class values should be returned or the class average.

Raises:
ValueError:
@@ -91,7 +102,59 @@
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
tensor(0.5463, dtype=torch.float64)

You can also return the segmentation and recognition quality alongside the PQ:
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)

You can also return the per-class metrics:
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)

You can also return both the per-class metrics and the segmentation and recognition quality:
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7},
... return_per_class=True, return_sq_and_rq=True)
tensor([[0.5185, 0.7778, 0.6667],
[0.0000, 0.0000, 0.0000],
[0.6667, 0.6667, 1.0000],
[1.0000, 1.0000, 1.0000]], dtype=torch.float64)

"""
if not _TORCH_GREATER_EQUAL_1_12:
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")

things, stuffs = _parse_categories(things, stuffs)
_validate_inputs(preds, target)
void_color = _get_void_color(things, stuffs)
@@ -101,7 +164,19 @@
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
flatten_preds, flatten_target, cat_id_to_continuous_id, void_color
)
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
iou_sum,
true_positives,
false_positives,
false_negatives,
)
if return_per_class:
if return_sq_and_rq:
return torch.stack((pq, sq, rq), dim=-1)
return pq.view(1, -1)
if return_sq_and_rq:
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
return pq_avg


def modified_panoptic_quality(
@@ -177,4 +252,5 @@ def modified_panoptic_quality(
void_color,
modified_metric_stuffs=stuffs,
)
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
return pq_avg
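A short usage sketch for consuming the new functional outputs (assuming this branch; the tiny perfect-match input is illustrative, chosen so every quality is 1.0, as in the extreme-value tests below):

```python
import torch
from torchmetrics.functional.detection import panoptic_quality

# A tiny 2x2 image: the last dim is (category_id, instance_id).
target = torch.tensor([[[[0, 0], [0, 0]],
                        [[6, 0], [6, 0]]]])
preds = target.clone()  # perfect prediction -> PQ = SQ = RQ = 1.0

# Averages: a length-3 tensor [PQ, SQ, RQ] that can be unpacked directly.
pq_avg, sq_avg, rq_avg = panoptic_quality(
    preds, target, things={0}, stuffs={6}, return_sq_and_rq=True
)

# Per class: shape (C, 3); unbind the last dim into per-class PQ/SQ/RQ.
per_class = panoptic_quality(
    preds, target, things={0}, stuffs={6},
    return_per_class=True, return_sq_and_rq=True,
)
pq, sq, rq = per_class.unbind(dim=-1)
```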
6 changes: 6 additions & 0 deletions tests/unittests/detection/test_modified_panoptic_quality.py
@@ -18,6 +18,7 @@
import torch
from torchmetrics.detection import ModifiedPanopticQuality
from torchmetrics.functional.detection import modified_panoptic_quality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12

from unittests import _Input
from unittests._helpers import seed_all
@@ -76,6 +77,7 @@ def _reference_fn_1_2(preds, target) -> np.ndarray:
return np.array([23 / 30])


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
class TestModifiedPanopticQuality(MetricTester):
"""Test class for `ModifiedPanopticQuality` metric."""

@@ -111,6 +113,7 @@ def test_panoptic_quality_functional(self):
)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_empty_metric():
"""Test empty metric."""
with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"):
Expand All @@ -120,6 +123,7 @@ def test_empty_metric():
assert torch.isnan(metric.compute())


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_error_on_wrong_input():
"""Test class input validation."""
with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"):
@@ -162,6 +166,7 @@ def test_error_on_wrong_input():
metric.update(preds, preds)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_extreme_values():
"""Test that the metric returns expected values in trivial cases."""
# Exact match between preds and target => metric is 1
@@ -170,6 +175,7 @@ def test_extreme_values():
assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
@pytest.mark.parametrize(
("inputs", "args", "cat_dim"),
[
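The per-test `skipif` guards added in this file could equivalently be expressed once at module scope via pytest's `pytestmark` hook; a minimal sketch of that alternative (not what this PR does):

```python
import pytest
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12

# Applies the skip condition to every test collected from this module.
pytestmark = pytest.mark.skipif(
    not _TORCH_GREATER_EQUAL_1_12,
    reason="PanopticQuality metric only supports PyTorch >= 1.12",
)
```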