Fix compatibility between ClasswiseWrapper and prefix/postfix arg in MetricCollection #843

Merged · 2 commits · Feb 18, 2022
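The gist of the change: when a `ClasswiseWrapper` is placed inside a `MetricCollection` that sets `prefix` or `postfix`, the affix was previously dropped from the flattened per-class result keys; after this PR it is applied to them. A minimal usage sketch (import paths and resulting key names are assumed from the torchmetrics 0.7 API and the test below):

```python
import torch
from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection

labels = ["horse", "fish", "cat"]
metric = MetricCollection(
    {"accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)},
    prefix="pre_",
)
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
# With this fix, the collection prefix reaches the per-class keys, e.g.
# {"pre_accuracy_horse": ..., "pre_accuracy_fish": ..., "pre_accuracy_cat": ...}
```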
3 changes: 3 additions & 0 deletions CHANGELOG.md
```diff
@@ -86,6 +86,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Improved testing speed ([#820](https://github.com/PyTorchLightning/metrics/pull/820))
 
 
+- Fixed compatibility of `ClasswiseWrapper` with the `prefix` argument of `MetricCollection` ([#843](https://github.com/PyTorchLightning/metrics/pull/843))
+
+
 ## [0.7.2] - 2022-02-10
 
 ### Fixed
```
20 changes: 16 additions & 4 deletions tests/wrappers/test_classwise.py
```diff
@@ -38,20 +38,32 @@ def test_output_with_labels():
         assert f"accuracy_{lab}" in val
 
 
-def test_using_metriccollection():
+@pytest.mark.parametrize("prefix", [None, "pre_"])
+@pytest.mark.parametrize("postfix", [None, "_post"])
+def test_using_metriccollection(prefix, postfix):
     """Test wrapper in combination with metric collection."""
     labels = ["horse", "fish", "cat"]
     metric = MetricCollection(
         {
             "accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels),
             "recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels),
-        }
+        },
+        prefix=prefix,
+        postfix=postfix,
     )
     preds = torch.randn(10, 3).softmax(dim=-1)
     target = torch.randint(3, (10,))
     val = metric(preds, target)
     assert isinstance(val, dict)
     assert len(val) == 6
+
+    def _get_correct_name(base):
+        name = base if prefix is None else prefix + base
+        name = name if postfix is None else name + postfix
+        return name
+
     for lab in labels:
-        assert f"accuracy_{lab}" in val
-        assert f"recall_{lab}" in val
+        name = _get_correct_name(f"accuracy_{lab}")
+        assert name in val
+        name = _get_correct_name(f"recall_{lab}")
+        assert name in val
```
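With both arguments set, the parametrization above expects flattened keys of the form `pre_accuracy_horse_post`: the collection-level `prefix` and `postfix` wrap the classwise key produced by the wrapper.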
9 changes: 6 additions & 3 deletions torchmetrics/collections.py
```diff
@@ -131,7 +131,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
         will be filtered based on the signature of the individual metric.
         """
-        return _flatten_dict({k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()})
+        res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)}
+        res = _flatten_dict(res)
+        return {self._set_name(k): v for k, v in res.items()}
 
     def update(self, *args: Any, **kwargs: Any) -> None:
         """Iteratively call update for each metric.
@@ -219,8 +221,9 @@ def compute(self) -> Dict[str, Any]:
             mi = getattr(self, cg[i])
             for state in m0._defaults:
                 setattr(mi, state, getattr(m0, state))
-        res = {k: m.compute() for k, m in self.items()}
-        return _flatten_dict(res)
+        res = {k: m.compute() for k, m in self.items(keep_base=True)}
+        res = _flatten_dict(res)
+        return {self._set_name(k): v for k, v in res.items()}
 
     def reset(self) -> None:
         """Iteratively call reset for each metric."""
```
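Both call sites now follow the same pattern: collect results under base keys (`items(keep_base=True)`), flatten the nested dicts produced by wrappers such as `ClasswiseWrapper`, and only then apply the collection's `prefix`/`postfix` via `_set_name`. A minimal standalone sketch of why that ordering matters, with simplified stand-ins for the internal helpers (illustrative only, not the torchmetrics source):

```python
from typing import Any, Dict, Optional


def _flatten_dict(x: Dict[str, Any]) -> Dict[str, Any]:
    # Wrapper metrics (e.g. ClasswiseWrapper) return a dict; its per-class
    # keys replace the collection-level key in the flattened result.
    out: Dict[str, Any] = {}
    for key, value in x.items():
        if isinstance(value, dict):
            out.update(value)
        else:
            out[key] = value
    return out


def _set_name(base: str, prefix: Optional[str], postfix: Optional[str]) -> str:
    # Simplified stand-in for MetricCollection._set_name.
    name = base if prefix is None else prefix + base
    return name if postfix is None else name + postfix


# Before the fix, prefix/postfix were baked into the collection keys, which
# _flatten_dict then discarded for dict-valued results, so the per-class keys
# came out bare. Renaming *after* flattening fixes that:
raw = {"accuracy": {"accuracy_horse": 0.9, "accuracy_fish": 0.8, "accuracy_cat": 0.7}}
flat = _flatten_dict(raw)
renamed = {_set_name(k, "pre_", None): v for k, v in flat.items()}
print(sorted(renamed))  # ['pre_accuracy_cat', 'pre_accuracy_fish', 'pre_accuracy_horse']
```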