Skip to content

Commit

Permalink
[Segmentation] Add mean IoU (#1236)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
6 people committed Apr 23, 2024
1 parent ec2c246 commit af32fd0
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


- Added a new segmentation metric `MeanIoU` ([#1236](https://github.com/PyTorchLightning/metrics/pull/1236))


- Added `pretty-errors` for improving error prints ([#2431](https://github.com/Lightning-AI/torchmetrics/pull/2431))


Expand Down
19 changes: 19 additions & 0 deletions docs/source/segmentation/mean_iou.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.. customcarditem::
:header: Mean Intersection over Union (mIoU)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg
:tags: segmentation

###################################
Mean Intersection over Union (mIoU)
###################################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.MeanIoU
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.mean_iou
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.functional.segmentation.mean_iou import mean_iou

__all__ = ["generalized_dice_score"]
__all__ = ["generalized_dice_score", "mean_iou"]
109 changes: 109 additions & 0 deletions src/torchmetrics/functional/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide


def _mean_iou_validate_args(
num_classes: int,
include_background: bool,
per_class: bool,
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if not isinstance(per_class, bool):
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")


def _mean_iou_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Update the intersection and union counts for the mean IoU computation."""
_check_same_shape(preds, target)

if (preds.bool() != preds).any(): # preds is an index tensor
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

reduce_axis = list(range(2, preds.ndim))
intersection = torch.sum(preds & target, dim=reduce_axis)
target_sum = torch.sum(target, dim=reduce_axis)
pred_sum = torch.sum(preds, dim=reduce_axis)
union = target_sum + pred_sum - intersection
return intersection, union


def _mean_iou_compute(
intersection: Tensor,
union: Tensor,
per_class: bool = False,
) -> Tensor:
"""Compute the mean IoU metric."""
val = _safe_divide(intersection, union)
return val if per_class else torch.mean(val, 1)


def mean_iou(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = True,
per_class: bool = False,
) -> Tensor:
"""Calculates the mean Intersection over Union (mIoU) for semantic segmentation.
Args:
preds: Predictions from model
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
Returns:
The mean IoU score
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.segmentation import mean_iou
>>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> mean_iou(preds, target, num_classes=5)
tensor([0.3193, 0.3305, 0.3382, 0.3246])
>>> mean_iou(preds, target, num_classes=5, per_class=True)
tensor([[0.3093, 0.3500, 0.3081, 0.3389, 0.2903],
[0.2963, 0.3316, 0.3505, 0.2804, 0.3936],
[0.3724, 0.3249, 0.3660, 0.3184, 0.3093],
[0.3085, 0.3267, 0.3155, 0.3575, 0.3147]])
"""
_mean_iou_validate_args(num_classes, include_background, per_class)
intersection, union = _mean_iou_update(preds, target, num_classes, include_background)
return _mean_iou_compute(intersection, union, per_class=per_class)
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Ignore the background class in the computation."""
"""Ignore the background class in the computation assuming it is the first, index 0."""
preds = preds[:, 1:] if preds.shape[1] > 1 else preds
target = target[:, 1:] if target.shape[1] > 1 else target
return preds, target
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright The PyTorch Lightning team.
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
from torchmetrics.segmentation.mean_iou import MeanIoU

__all__ = ["GeneralizedDiceScore"]
__all__ = ["GeneralizedDiceScore", "MeanIoU"]
157 changes: 157 additions & 0 deletions src/torchmetrics/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MeanIoU.plot"]


class MeanIoU(Metric):
"""Computes Mean Intersection over Union (mIoU) for semantic segmentation.
The metric is defined by the overlap between the predicted segmentation and the ground truth, divided by the
total area covered by the union of the two. The metric can be computed for each class separately or for all
classes at once. The metric is optimal at a value of 1 and worst at a value of 0.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``miou`` (:class:`~torch.Tensor`): The mean Intersection over Union (mIoU) score. If ``per_class`` is set to
``True``, the output will be a tensor of shape ``(C,)`` with the IoU score for each class. If ``per_class`` is
set to ``False``, the output will be a scalar tensor.
Args:
num_classes: The number of classes in the segmentation problem.
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
compute the mean IoU over all classes.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError:
If ``num_classes`` is not a positive integer
ValueError:
If ``include_background`` is not a boolean
ValueError:
If ``per_class`` is not a boolean
Example:
>>> import torch
>>> _ = torch.manual_seed(0)
>>> from torchmetrics.segmentation import MeanIoU
>>> miou = MeanIoU(num_classes=3)
>>> preds = torch.randint(0, 2, (10, 3, 128, 128))
>>> target = torch.randint(0, 2, (10, 3, 128, 128))
>>> miou(preds, target)
tensor(0.3318)
>>> miou = MeanIoU(num_classes=3, per_class=True)
>>> miou(preds, target)
tensor([0.3322, 0.3303, 0.3329])
>>> miou = MeanIoU(num_classes=3, per_class=True, include_background=False)
>>> miou(preds, target)
tensor([0.3303, 0.3329])
"""

score: Tensor
num_batches: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
num_classes: int,
include_background: bool = True,
per_class: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_mean_iou_validate_args(num_classes, include_background, per_class)
self.num_classes = num_classes
self.include_background = include_background
self.per_class = per_class

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()

def compute(self) -> Tensor:
"""Update the state with the new data."""
return self.score # / self.num_batches

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
Loading

0 comments on commit af32fd0

Please sign in to comment.