Add persistent flag to Metric.add_state #4195

Merged: 6 commits, Oct 16, 2020
pytorch_lightning/metrics/metric.py (10 additions, 2 deletions)
@@ -17,6 +17,7 @@
 from collections.abc import Mapping, Sequence
 from collections import namedtuple
 from copy import deepcopy
+from distutils.version import LooseVersion

 import os
 import torch
@@ -78,7 +79,9 @@ def __init__(
         self._reductions = {}
         self._defaults = {}

-    def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
+    def add_state(
+        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+    ):
         """
         Adds metric state variable. Only used by subclasses.

@@ -90,6 +93,7 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
                 If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
                 and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
                 function in this parameter.
+            persistent (Optional): whether the state will be saved as part of the module's ``state_dict``.

         Note:
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -130,7 +134,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
             )

         if isinstance(default, torch.Tensor):
-            self.register_buffer(name, default)
+            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+                # persistent keyword is only supported in torch >= 1.6.0
+                self.register_buffer(name, default, persistent=persistent)
+            else:
+                self.register_buffer(name, default)
         else:
             setattr(self, name, default)
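For orientation, here is a minimal sketch of how the new flag is meant to be used from a ``Metric`` subclass. The ``AverageMetric`` class and its state names are illustrative, not part of this PR:

    import torch
    from pytorch_lightning.metrics import Metric


    class AverageMetric(Metric):
        """Toy running-average metric, for illustration only."""

        def __init__(self):
            super().__init__()
            # Saved in the module's state_dict (persistent defaults to True).
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
            # Omitted from state_dict on torch >= 1.6.0; on older torch the
            # flag cannot be honored and the buffer is saved anyway.
            self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum", persistent=False)

        def update(self, x: torch.Tensor):
            self.total += x.sum()
            self.count += x.numel()

        def compute(self):
            return self.total / self.count


    metric = AverageMetric()
    metric.update(torch.tensor([1.0, 2.0, 3.0]))
    assert "total" in metric.state_dict()
    # On torch >= 1.6.0 the non-persistent state is excluded:
    # assert "count" not in metric.state_dict()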
tests/metrics/test_metric.py (14 additions, 0 deletions)
@@ -1,5 +1,6 @@
 import pickle

+from distutils.version import LooseVersion
 import cloudpickle
 import numpy as np
 import pytest
@@ -59,6 +60,19 @@ def custom_fx(x):
     assert a._reductions["e"](torch.tensor([1, 1])) == -1


+def test_add_state_persistent():
+    a = Dummy()
+
+    a.add_state("a", torch.tensor(0), "sum", persistent=True)
+    assert "a" in a.state_dict()
+
+    a.add_state("b", torch.tensor(0), "sum", persistent=False)
+
+    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+        assert "b" not in a.state_dict()
+
+
+
 def test_reset():
     class A(Dummy):
         pass
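As background, the version guard exists because ``nn.Module.register_buffer`` only gained its ``persistent`` keyword in torch 1.6.0; on older versions every registered buffer is saved. A standalone sketch of that underlying plain-PyTorch behavior, assuming torch >= 1.6.0 (the ``Example`` module is made up here):

    import torch
    from torch import nn


    class Example(nn.Module):
        def __init__(self):
            super().__init__()
            # Persistent buffer: included in state_dict (the default).
            self.register_buffer("kept", torch.zeros(1))
            # Non-persistent buffer: still tracked by the module (moves with
            # .to()/.cuda()), but omitted from state_dict and checkpoints.
            self.register_buffer("dropped", torch.zeros(1), persistent=False)


    m = Example()
    assert "kept" in m.state_dict()
    assert "dropped" not in m.state_dict()

Defaulting ``persistent=True`` in ``add_state`` keeps the pre-existing checkpointing behavior for any metric that does not explicitly opt out.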