Loads metric states from state_dict #202

Merged (6 commits, Apr 27, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed auc calculation and add tests ([#197](https://github.com/PyTorchLightning/metrics/pull/197))


- Fixed loading persisted metric states using `load_state_dict()` ([#202](https://github.com/PyTorchLightning/metrics/pull/202))


## [0.3.1] - 2021-04-21

- Cleaning remaining inconsistency and fix PL develop integration (
10 changes: 10 additions & 0 deletions tests/bases/test_metric.py
@@ -227,6 +227,16 @@ def test_state_dict(tmpdir):
    assert metric.state_dict() == OrderedDict()


def test_load_state_dict(tmpdir):
    """ test that metric states can be loaded with state dict """
    metric = DummyMetricSum()
    metric.persistent(True)
    metric.update(5)
    loaded_metric = DummyMetricSum()
    loaded_metric.load_state_dict(metric.state_dict())
    assert loaded_metric.compute() == 5


def test_child_metric_state_dict():
""" test that child metric states will be added to parent state dict """

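For reference, a minimal sketch of the round trip this test exercises, written against the public torchmetrics Metric API (the SumMetric class and the values used here are illustrative, not part of the repository):

import torch
from torchmetrics import Metric

class SumMetric(Metric):
    """Toy metric that accumulates a running sum in a single state tensor."""

    def __init__(self):
        super().__init__()
        # persistent=True makes this state part of state_dict()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum", persistent=True)

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total

metric = SumMetric()
metric.update(5)

restored = SumMetric()
restored.load_state_dict(metric.state_dict())  # restores the "total" state
assert restored.compute() == 5

Before this fix, the load step would fail (or, with strict=False, silently skip the state), because metric states are plain attributes rather than registered buffers and the stock nn.Module loader treated them as unexpected keys.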
21 changes: 20 additions & 1 deletion torchmetrics/metric.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor, nn
@@ -324,6 +324,25 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
            destination[prefix + key] = current_val
    return destination

def _load_from_state_dict(
    self,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: List[str],
    unexpected_keys: List[str],
    error_msgs: List[str],
) -> None:
    """ Loads metric states from state_dict """
    for key in self._defaults.keys():
        name = prefix + key
        if name in state_dict:
            setattr(self, key, state_dict.pop(name))
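    # metric states were popped above; let nn.Module restore the remaining
    # parameters/buffers (True keeps recording of missing/unexpected keys enabled)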
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
    )

def _filter_kwargs(self, **kwargs):
""" filter kwargs such that they match the update signature of the metric """

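With this override in place, persisted metric states survive a checkpoint round trip. A short usage sketch with a built-in metric (Accuracy and the file name are illustrative choices; persistent(True) marks all states of the metric as persistent so they are written by state_dict()):

import torch
from torchmetrics import Accuracy

metric = Accuracy()
metric.persistent(True)  # include metric states in state_dict()
metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))

torch.save(metric.state_dict(), "metric.ckpt")  # illustrative path

restored = Accuracy()
restored.load_state_dict(torch.load("metric.ckpt"))
assert torch.allclose(restored.compute(), metric.compute())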