diff --git a/CHANGELOG.md b/CHANGELOG.md index de5223d60d5..f4c5308ae76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) - Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154)) - Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166)) +- Added `postfix` arg to `MetricCollection` ([#188](https://github.com/PyTorchLightning/metrics/pull/188)) ### Changed diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 37c8cfbac50..96a04019389 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -133,30 +133,47 @@ def test_metric_collection_args_kwargs(tmpdir): assert metric_collection['DummyMetricDiff'].x == -20 -def test_metric_collection_prefix_arg(tmpdir): +@pytest.mark.parametrize( + "prefix, postfix", [ + [None, None], + ['prefix_', None], + [None, '_postfix'], + ['prefix_', '_postfix'], + ] +) +def test_metric_collection_prefix_postfix_args(prefix, postfix): """ Test that the prefix arg alters the keywords in the output""" m1 = DummyMetricSum() m2 = DummyMetricDiff() names = ['DummyMetricSum', 'DummyMetricDiff'] + names = [prefix + n if prefix is not None else n for n in names] + names = [n + postfix if postfix is not None else n for n in names] - metric_collection = MetricCollection([m1, m2], prefix='prefix_') + metric_collection = MetricCollection([m1, m2], prefix=prefix, postfix=postfix) # test forward out = metric_collection(5) for name in names: - assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method' + assert name in out, 'prefix or postfix argument not working as intended with forward method' # test compute out = metric_collection.compute() for name in names: - assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method' + assert name in out, 'prefix or postfix argument not working as intended with compute method' # test clone new_metric_collection = metric_collection.clone(prefix='new_prefix_') out = new_metric_collection(5) + names = [n[len(prefix):] if prefix is not None else n for n in names] # strip away old prefix for name in names: assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method' + new_metric_collection = new_metric_collection.clone(postfix='_new_postfix') + out = new_metric_collection(5) + names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix + for name in names: + assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method' + def test_metric_collection_same_order(): m1 = DummyMetricSum() diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 80234738706..b84b1d59a8a 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -40,6 +40,8 @@ class name as key for the output dict. prefix: a string to append in front of the keys of the output dict + postfix: a string to append after the keys of the output dict + Raises: ValueError: If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``. @@ -48,7 +50,11 @@ class name as key for the output dict. ValueError: If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. ValueError: - If ``metrics`` is is ``dict`` and passed any additional_metrics. + If ``metrics`` is ``dict`` and additional_metrics are passed in. + ValueError: + If ``prefix`` is set and it is not a string. + ValueError: + If ``postfix`` is set and it is not a string. Example (input as list): >>> import torch @@ -85,6 +91,7 @@ def __init__( metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric, prefix: Optional[str] = None, + postfix: Optional[str] = None ): super().__init__() if isinstance(metrics, Metric): @@ -128,7 +135,8 @@ def __init__( else: raise ValueError("Unknown input to MetricCollection.") - self.prefix = self._check_prefix_arg(prefix) + self.prefix = self._check_arg(prefix, 'prefix') + self.postfix = self._check_arg(postfix, 'postfix') def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 """ @@ -136,7 +144,7 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. """ - return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} def update(self, *args, **kwargs): # pylint: disable=E0202 """ @@ -149,20 +157,25 @@ def update(self, *args, **kwargs): # pylint: disable=E0202 m.update(*args, **m_kwargs) def compute(self) -> Dict[str, Any]: - return {self._set_prefix(k): m.compute() for k, m in self.items()} + return {self._set_name(k): m.compute() for k, m in self.items()} def reset(self) -> None: """ Iteratively call reset for each metric """ for _, m in self.items(): m.reset() - def clone(self, prefix: Optional[str] = None) -> 'MetricCollection': + def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection': """ Make a copy of the metric collection Args: prefix: a string to append in front of the metric keys + postfix: a string to append after the keys of the output dict + """ mc = deepcopy(self) - mc.prefix = self._check_prefix_arg(prefix) + if prefix: + mc.prefix = self._check_arg(prefix, 'prefix') + if postfix: + mc.postfix = self._check_arg(postfix, 'postfix') return mc def persistent(self, mode: bool = True) -> None: @@ -172,14 +185,15 @@ def persistent(self, mode: bool = True) -> None: for _, m in self.items(): m.persistent(mode) - def _set_prefix(self, k: str) -> str: - return k if self.prefix is None else self.prefix + k + def _set_name(self, base: str) -> str: + name = base if self.prefix is None else self.prefix + base + name = name if self.postfix is None else name + self.postfix + return name @staticmethod - def _check_prefix_arg(prefix: str) -> Optional[str]: - if prefix is not None: - if isinstance(prefix, str): - return prefix - else: - raise ValueError('Expected input `prefix` to be a string') + def _check_arg(arg: str, name: str) -> Optional[str]: + if arg is not None: + if isinstance(arg, str): + return arg + raise ValueError(f'Expected input {name} to be a string') return None