Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add postfix arg to MetricCollection #188

Merged
merged 12 commits into from Apr 20, 2021
17 changes: 17 additions & 0 deletions CHANGELOG.md
Expand Up @@ -4,6 +4,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.4.0] - ????-??-??

### Added
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

- Added `postfix` arg to `MetricCollection` ([#188](https://github.com/PyTorchLightning/metrics/pull/188))

### Changed


### Deprecated


### Removed


### Fixed


## [0.3.0] - 2021-04-20

Expand Down
23 changes: 19 additions & 4 deletions tests/bases/test_collections.py
Expand Up @@ -133,30 +133,45 @@ def test_metric_collection_args_kwargs(tmpdir):
assert metric_collection['DummyMetricDiff'].x == -20


def test_metric_collection_prefix_arg(tmpdir):
@pytest.mark.parametrize("prefix, postfix", [
[None, None],
['prefix_', None],
[None, '_postfix'],
['prefix_', '_postfix']
])
def test_metric_collection_prefix_postfix_args(prefix, postfix):
""" Test that the prefix arg alters the keywords in the output"""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
names = ['DummyMetricSum', 'DummyMetricDiff']
names = [prefix + n if prefix is not None else n for n in names]
names = [n + postfix if postfix is not None else n for n in names]

metric_collection = MetricCollection([m1, m2], prefix='prefix_')
metric_collection = MetricCollection([m1, m2], prefix=prefix, postfix=postfix)

# test forward
out = metric_collection(5)
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
assert name in out, 'prefix or postfix argument not working as intended with forward method'

# test compute
out = metric_collection.compute()
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
assert name in out, 'prefix or postfix argument not working as intended with compute method'

# test clone
new_metric_collection = metric_collection.clone(prefix='new_prefix_')
out = new_metric_collection(5)
names = [n[7:] if prefix is not None else n for n in names] # strip away old prefix
Borda marked this conversation as resolved.
Show resolved Hide resolved
for name in names:
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
out = new_metric_collection(5)
names = [n[:-8] if postfix is not None else n for n in names] # strip away old postfix
Borda marked this conversation as resolved.
Show resolved Hide resolved
for name in names:
assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method'


def test_metric_collection_same_order():
m1 = DummyMetricSum()
Expand Down
41 changes: 28 additions & 13 deletions torchmetrics/collections.py
Expand Up @@ -40,6 +40,8 @@ class name as key for the output dict.

prefix: a string to append in front of the keys of the output dict

postfix: a string to append after the keys of the output dict

Raises:
ValueError:
If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
Expand All @@ -48,7 +50,11 @@ class name as key for the output dict.
ValueError:
If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
ValueError:
If ``metrics`` is is ``dict`` and passed any additional_metrics.
If ``metrics`` is ``dict`` and additional_metrics are passed in.
ValueError:
If ``prefix`` is set and it is not a string.
ValueError:
If ``postfix`` is set and it is not a string.

Example (input as list):
>>> import torch
Expand Down Expand Up @@ -85,6 +91,7 @@ def __init__(
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
*additional_metrics: Metric,
prefix: Optional[str] = None,
postfix: Optional[str] = None
):
super().__init__()
if isinstance(metrics, Metric):
Expand Down Expand Up @@ -128,15 +135,16 @@ def __init__(
else:
raise ValueError("Unknown input to MetricCollection.")

self.prefix = self._check_prefix_arg(prefix)
self.prefix = self._check_arg(prefix, 'prefix')
self.postfix = self._check_arg(postfix, 'postfix')

def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Expand All @@ -149,20 +157,25 @@ def update(self, *args, **kwargs): # pylint: disable=E0202
m.update(*args, **m_kwargs)

def compute(self) -> Dict[str, Any]:
return {self._set_prefix(k): m.compute() for k, m in self.items()}
return {self._set_name(k): m.compute() for k, m in self.items()}

def reset(self) -> None:
""" Iteratively call reset for each metric """
for _, m in self.items():
m.reset()

def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection':
""" Make a copy of the metric collection
Args:
prefix: a string to append in front of the metric keys
postfix: a string to append after the keys of the output dic
Borda marked this conversation as resolved.
Show resolved Hide resolved

"""
mc = deepcopy(self)
mc.prefix = self._check_prefix_arg(prefix)
if prefix is not None:
mc.prefix = self._check_arg(prefix, 'prefix')
if postfix is not None:
mc.postfix = self._check_arg(postfix, 'postfix')
Borda marked this conversation as resolved.
Show resolved Hide resolved
return mc

def persistent(self, mode: bool = True) -> None:
Expand All @@ -172,14 +185,16 @@ def persistent(self, mode: bool = True) -> None:
for _, m in self.items():
m.persistent(mode)

def _set_prefix(self, k: str) -> str:
return k if self.prefix is None else self.prefix + k
def _set_name(self, k: str) -> str:
out = k if self.prefix is None else self.prefix + k
out = out if self.postfix is None else out + self.postfix
return out
Borda marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _check_prefix_arg(prefix: str) -> Optional[str]:
if prefix is not None:
if isinstance(prefix, str):
return prefix
def _check_arg(arg: str, name: str) -> Optional[str]:
if arg is not None:
if isinstance(arg, str):
return arg
else:
Borda marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError('Expected input `prefix` to be a string')
raise ValueError(f'Expected input {name} to be a string')
Borda marked this conversation as resolved.
Show resolved Hide resolved
return None