Skip to content

Commit

Permalink
Add postfix arg to MetricCollection (#188)
Browse files Browse the repository at this point in the history
* postfix

* chglog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
SkafteNicki and Borda committed Apr 20, 2021
1 parent c1ae5ac commit 1aab672
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))
- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166))
- Added `postfix` arg to `MetricCollection` ([#188](https://github.com/PyTorchLightning/metrics/pull/188))

### Changed

Expand Down
25 changes: 21 additions & 4 deletions tests/bases/test_collections.py
Expand Up @@ -133,30 +133,47 @@ def test_metric_collection_args_kwargs(tmpdir):
assert metric_collection['DummyMetricDiff'].x == -20


def test_metric_collection_prefix_arg(tmpdir):
@pytest.mark.parametrize(
"prefix, postfix", [
[None, None],
['prefix_', None],
[None, '_postfix'],
['prefix_', '_postfix'],
]
)
def test_metric_collection_prefix_postfix_args(prefix, postfix):
""" Test that the prefix arg alters the keywords in the output"""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
names = ['DummyMetricSum', 'DummyMetricDiff']
names = [prefix + n if prefix is not None else n for n in names]
names = [n + postfix if postfix is not None else n for n in names]

metric_collection = MetricCollection([m1, m2], prefix='prefix_')
metric_collection = MetricCollection([m1, m2], prefix=prefix, postfix=postfix)

# test forward
out = metric_collection(5)
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
assert name in out, 'prefix or postfix argument not working as intended with forward method'

# test compute
out = metric_collection.compute()
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
assert name in out, 'prefix or postfix argument not working as intended with compute method'

# test clone
new_metric_collection = metric_collection.clone(prefix='new_prefix_')
out = new_metric_collection(5)
names = [n[len(prefix):] if prefix is not None else n for n in names] # strip away old prefix
for name in names:
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
out = new_metric_collection(5)
names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix
for name in names:
assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method'


def test_metric_collection_same_order():
m1 = DummyMetricSum()
Expand Down
42 changes: 28 additions & 14 deletions torchmetrics/collections.py
Expand Up @@ -40,6 +40,8 @@ class name as key for the output dict.
prefix: a string to append in front of the keys of the output dict
postfix: a string to append after the keys of the output dict
Raises:
ValueError:
If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
Expand All @@ -48,7 +50,11 @@ class name as key for the output dict.
ValueError:
If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
ValueError:
If ``metrics`` is is ``dict`` and passed any additional_metrics.
If ``metrics`` is ``dict`` and additional_metrics are passed in.
ValueError:
If ``prefix`` is set and it is not a string.
ValueError:
If ``postfix`` is set and it is not a string.
Example (input as list):
>>> import torch
Expand Down Expand Up @@ -85,6 +91,7 @@ def __init__(
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
*additional_metrics: Metric,
prefix: Optional[str] = None,
postfix: Optional[str] = None
):
super().__init__()
if isinstance(metrics, Metric):
Expand Down Expand Up @@ -128,15 +135,16 @@ def __init__(
else:
raise ValueError("Unknown input to MetricCollection.")

self.prefix = self._check_prefix_arg(prefix)
self.prefix = self._check_arg(prefix, 'prefix')
self.postfix = self._check_arg(postfix, 'postfix')

def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Expand All @@ -149,20 +157,25 @@ def update(self, *args, **kwargs): # pylint: disable=E0202
m.update(*args, **m_kwargs)

def compute(self) -> Dict[str, Any]:
return {self._set_prefix(k): m.compute() for k, m in self.items()}
return {self._set_name(k): m.compute() for k, m in self.items()}

def reset(self) -> None:
""" Iteratively call reset for each metric """
for _, m in self.items():
m.reset()

def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection':
""" Make a copy of the metric collection
Args:
prefix: a string to append in front of the metric keys
postfix: a string to append after the keys of the output dict
"""
mc = deepcopy(self)
mc.prefix = self._check_prefix_arg(prefix)
if prefix:
mc.prefix = self._check_arg(prefix, 'prefix')
if postfix:
mc.postfix = self._check_arg(postfix, 'postfix')
return mc

def persistent(self, mode: bool = True) -> None:
Expand All @@ -172,14 +185,15 @@ def persistent(self, mode: bool = True) -> None:
for _, m in self.items():
m.persistent(mode)

def _set_prefix(self, k: str) -> str:
return k if self.prefix is None else self.prefix + k
def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
name = name if self.postfix is None else name + self.postfix
return name

@staticmethod
def _check_prefix_arg(prefix: str) -> Optional[str]:
if prefix is not None:
if isinstance(prefix, str):
return prefix
else:
raise ValueError('Expected input `prefix` to be a string')
def _check_arg(arg: str, name: str) -> Optional[str]:
if arg is not None:
if isinstance(arg, str):
return arg
raise ValueError(f'Expected input {name} to be a string')
return None

0 comments on commit 1aab672

Please sign in to comment.