Skip to content

Commit

Permalink
Add persistent flag to Metric.add_state (#4195)
Browse files Browse the repository at this point in the history
* add persistent flag to add_state in metrics

* wrap register_buffer with try catch

* pep8

* use loose version

* test

* pep8
  • Loading branch information
teddykoker committed Oct 16, 2020
1 parent 3fe479f commit 827a557
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
12 changes: 10 additions & 2 deletions pytorch_lightning/metrics/metric.py
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Mapping, Sequence
from collections import namedtuple
from copy import deepcopy
from distutils.version import LooseVersion

import os
import torch
Expand Down Expand Up @@ -78,7 +79,9 @@ def __init__(
self._reductions = {}
self._defaults = {}

def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
def add_state(
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
):
"""
Adds metric state variable. Only used by subclasses.
Expand All @@ -90,6 +93,7 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
Expand Down Expand Up @@ -130,7 +134,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
)

if isinstance(default, torch.Tensor):
self.register_buffer(name, default)
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
# persistent keyword is only supported in torch >= 1.6.0
self.register_buffer(name, default, persistent=persistent)
else:
self.register_buffer(name, default)
else:
setattr(self, name, default)

Expand Down
14 changes: 14 additions & 0 deletions tests/metrics/test_metric.py
@@ -1,5 +1,6 @@
import pickle

from distutils.version import LooseVersion
import cloudpickle
import numpy as np
import pytest
Expand Down Expand Up @@ -59,6 +60,19 @@ def custom_fx(x):
assert a._reductions["e"](torch.tensor([1, 1])) == -1


def test_add_state_persistent():
    """Check that ``Metric.add_state``'s ``persistent`` flag controls
    whether the state tensor appears in the module ``state_dict``."""
    a = Dummy()

    # A persistent state must always be present in the state dict.
    a.add_state("a", torch.tensor(0), "sum", persistent=True)
    assert "a" in a.state_dict()

    a.add_state("b", torch.tensor(0), "sum", persistent=False)

    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
        # torch >= 1.6 supports ``register_buffer(..., persistent=False)``,
        # so a non-persistent state is excluded from the state dict.
        assert "b" not in a.state_dict()
    else:
        # Older torch has no ``persistent`` kwarg; the implementation falls
        # back to a plain ``register_buffer`` call, so the state is always
        # saved regardless of the flag.
        assert "b" in a.state_dict()



def test_reset():
class A(Dummy):
pass
Expand Down

0 comments on commit 827a557

Please sign in to comment.