From 745c471357875384ada1f059101798f75f65c553 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?= Date: Tue, 23 Apr 2024 12:53:34 +0100 Subject: [PATCH] [Segmentation] Added generalized dice score metric (#1090) * Adding generalized dice score metric * Apply suggestions from code review --------- Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: Jirka Borovec Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- CHANGELOG.md | 4 +- README.md | 1 + docs/source/index.rst | 8 + docs/source/links.rst | 1 + docs/source/segmentation/generalized_dice.rst | 22 +++ .../functional/classification/__init__.py | 1 + .../functional/segmentation/__init__.py | 4 + .../segmentation/generalized_dice.py | 138 +++++++++++++ .../functional/segmentation/utils.py | 7 + src/torchmetrics/segmentation/__init__.py | 17 ++ .../segmentation/generalized_dice.py | 184 ++++++++++++++++++ .../test_generalized_dice_score.py | 89 +++++++++ 12 files changed, 475 insertions(+), 1 deletion(-) create mode 100644 docs/source/segmentation/generalized_dice.rst create mode 100644 src/torchmetrics/functional/segmentation/generalized_dice.py create mode 100644 src/torchmetrics/segmentation/__init__.py create mode 100644 src/torchmetrics/segmentation/generalized_dice.py create mode 100644 tests/unittests/segmentation/test_generalized_dice_score.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ceaa5cbc95..66ebcc8803f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `GeneralizedDiceScore` to segmentation package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090)) + + - Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217)) @@ -34,7 +37,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated -- ### Fixed diff --git a/README.md b/README.md index 2144d89e010..82e370c91f5 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,7 @@ covers the following domains: - Multimodal (Image-Text) - Nominal - Regression +- Segmentation - Text Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`, diff --git a/docs/source/index.rst b/docs/source/index.rst index a51a1184a78..880a6a2657e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -222,6 +222,14 @@ Or directly from conda retrieval/* +.. toctree:: + :maxdepth: 2 + :name: segmentation + :caption: Segmentation + :glob: + + segmentation/* + .. toctree:: :maxdepth: 2 :name: text diff --git a/docs/source/links.rst b/docs/source/links.rst index 7034f764d65..04b53797c61 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -170,3 +170,4 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 +.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 diff --git a/docs/source/segmentation/generalized_dice.rst b/docs/source/segmentation/generalized_dice.rst new file mode 100644 index 00000000000..5c48fc670d1 --- /dev/null +++ b/docs/source/segmentation/generalized_dice.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Generalized Dice Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +###################### +Generalized Dice Score +###################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.segmentation.GeneralizedDiceScore + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.segmentation.generalized_dice_score + :noindex: diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index e1f74af896d..faf523844bc 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -159,6 +159,7 @@ "confusion_matrix", "multiclass_confusion_matrix", "multilabel_confusion_matrix", + "generalized_dice_score", "dice", "exact_match", "multiclass_exact_match", diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 94f1dec4a9f..eec2e4dfcf3 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score + +__all__ = ["generalized_dice_score"] diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py new file mode 100644 index 00000000000..6b740bcea53 --- /dev/null +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -0,0 +1,138 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.segmentation.utils import _ignore_background +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_divide + + +def _generalized_dice_validate_args( + num_classes: int, + include_background: bool, + per_class: bool, + weight_type: Literal["square", "simple", "linear"], +) -> None: + """Validate the arguments of the metric.""" + if num_classes <= 0: + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") + if not isinstance(include_background, bool): + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") + if not isinstance(per_class, bool): + raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") + if weight_type not in ["square", "simple", "linear"]: + raise ValueError( + f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}." + ) + + +def _generalized_dice_update( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool, + weight_type: Literal["square", "simple", "linear"] = "square", +) -> Tensor: + """Update the state with the current prediction and target.""" + _check_same_shape(preds, target) + if preds.ndim < 3: + raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") + + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target) + + reduce_axis = list(range(2, target.ndim)) + intersection = torch.sum(preds * target, dim=reduce_axis) + target_sum = torch.sum(target, dim=reduce_axis) + pred_sum = torch.sum(preds, dim=reduce_axis) + cardinality = target_sum + pred_sum + + if weight_type == "simple": + weights = 1.0 / target_sum + elif weight_type == "linear": + weights = torch.ones_like(target_sum) + elif weight_type == "square": + weights = 1.0 / (target_sum**2) + else: + raise ValueError( + f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}." + ) + + w_shape = weights.shape + weights_flatten = weights.flatten() + infs = torch.isinf(weights_flatten) + weights_flatten[infs] = 0 + w_max = torch.max(weights, 0).values.repeat(w_shape[0], 1).T.flatten() + weights_flatten[infs] = w_max[infs] + weights = weights_flatten.reshape(w_shape) + + numerator = 2.0 * intersection * weights + denominator = cardinality * weights + return numerator, denominator # type:ignore[return-value] + + +def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor: + """Compute the generalized dice score.""" + if not per_class: + numerator = torch.sum(numerator, 1) + denominator = torch.sum(denominator, 1) + return _safe_divide(numerator, denominator) + + +def generalized_dice_score( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool = True, + per_class: bool = False, + weight_type: Literal["square", "simple", "linear"] = "square", +) -> Tensor: + """Compute the Generalized Dice Score for semantic segmentation. + + Args: + preds: Predictions from model + target: Ground truth values + num_classes: Number of classes + include_background: Whether to include the background class in the computation + per_class: Whether to compute the IoU for each class separately, else average over all classes + weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"`` + + Returns: + The Generalized Dice Score + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.functional.segmentation import generalized_dice_score + >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> generalized_dice_score(preds, target, num_classes=5) + tensor([0.4830, 0.4935, 0.5044, 0.4880]) + >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) + tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], + [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], + [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], + [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) + + """ + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) + numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type) + return _generalized_dice_compute(numerator, denominator, per_class) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index bbf5c48ded3..e8427a69326 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -24,6 +24,13 @@ from torchmetrics.utilities.imports import _SCIPY_AVAILABLE +def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Ignore the background class in the computation.""" + preds = preds[:, 1:] if preds.shape[1] > 1 else preds + target = target[:, 1:] if target.shape[1] > 1 else target + return preds, target + + def check_if_binarized(x: Tensor) -> None: """Check if the input is binarized. diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py new file mode 100644 index 00000000000..24275594e4c --- /dev/null +++ b/src/torchmetrics/segmentation/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore + +__all__ = ["GeneralizedDiceScore"] diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py new file mode 100644 index 00000000000..646ba63fbcf --- /dev/null +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -0,0 +1,184 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.segmentation.generalized_dice import ( + _generalized_dice_compute, + _generalized_dice_update, + _generalized_dice_validate_args, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["GeneralizedDiceScore.plot"] + + +class GeneralizedDiceScore(Metric): + r"""Compute `Generalized Dice Score`_. + + The metric can be used to evaluate the performance of image segmentation models. The Generalized Dice Score is + defined as: + + .. math:: + GDS = \frac{2 \\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} p_{ij}}{ + \\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} + \\sum_{i=1}^{N} w_i \\sum_{j} p_{ij}} + + where :math:`N` is the number of classes, :math:`t_{ij}` is the target tensor, :math:`p_{ij}` is the prediction + tensor, and :math:`w_i` is the weight for class :math:`i`. The weight can be computed in three different ways: + + - `square`: :math:`w_i = 1 / (\\sum_{j} t_{ij})^2` + - `simple`: :math:`w_i = 1 / \\sum_{j} t_{ij}` + - `linear`: :math:`w_i = 1` + + Note that the generalized dice loss can be computed as one minus the generalized dice score. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``gds`` (:class:`~torch.Tensor`): The generalized dice score. If ``per_class`` is set to ``True``, the output + will be a tensor of shape ``(C,)`` with the generalized dice score for each class. If ``per_class`` is + set to ``False``, the output will be a scalar tensor. + + Args: + num_classes: The number of classes in the segmentation problem. + include_background: Whether to include the background class in the computation + per_class: Whether to compute the metric for each class separately. + weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or + ``"linear"``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``num_classes`` is not a positive integer + ValueError: + If ``include_background`` is not a boolean + ValueError: + If ``per_class`` is not a boolean + ValueError: + If ``weight_type`` is not one of ``"square"``, ``"simple"``, or ``"linear"`` + + Example: + >>> import torch + >>> _ = torch.manual_seed(0) + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> gds = GeneralizedDiceScore(num_classes=3) + >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) + >>> target = torch.randint(0, 2, (10, 3, 128, 128)) + >>> gds(preds, target) + tensor(0.4983) + >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) + >>> gds(preds, target) + tensor([0.4987, 0.4966, 0.4995]) + >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) + >>> gds(preds, target) + tensor([0.4966, 0.4995]) + + """ + + score: Tensor + samples: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__( + self, + num_classes: int, + include_background: bool = True, + per_class: bool = False, + weight_type: Literal["square", "simple", "linear"] = "square", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) + self.num_classes = num_classes + self.include_background = include_background + self.per_class = per_class + self.weight_type = weight_type + + num_classes = num_classes - 1 if not include_background else num_classes + self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") + self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update the state with new data.""" + numerator, denominator = _generalized_dice_update( + preds, target, self.num_classes, self.include_background, self.weight_type + ) + self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) + self.samples += preds.shape[0] + + def compute(self) -> Tensor: + """Compute the final generalized dice score.""" + return self.score / self.samples + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> metric = GeneralizedDiceScore(num_classes=3) + >>> metric.update(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> metric = GeneralizedDiceScore(num_classes=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append( + ... metric(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128))) + ... ) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py new file mode 100644 index 00000000000..ed80e6fd6d7 --- /dev/null +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -0,0 +1,89 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import pytest +import torch +from monai.metrics.generalized_dice import compute_generalized_dice +from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score +from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore + +from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests._helpers.testers import MetricTester + +_inputs1 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), +) +_inputs2 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), +) +_inputs3 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) + + +def _reference_generalized_dice( + preds: torch.Tensor, + target: torch.Tensor, + include_background: bool = True, + reduce: bool = True, +): + """Calculate reference metric for `MeanIoU`.""" + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) + val = compute_generalized_dice(preds, target, include_background=include_background) + if reduce: + val = val.mean() + return val + + +@pytest.mark.parametrize( + "preds, target", + [ + (_inputs1.preds, _inputs1.target), + (_inputs2.preds, _inputs2.target), + (_inputs3.preds, _inputs3.target), + ], +) +@pytest.mark.parametrize("include_background", [True, False]) +class TestMeanIoU(MetricTester): + """Test class for `MeanIoU` metric.""" + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_mean_iou_class(self, preds, target, include_background, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=GeneralizedDiceScore, + reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=True), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background}, + ) + + def test_mean_iou_functional(self, preds, target, include_background): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=generalized_dice_score, + reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=False), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": False}, + )