Skip to content

Commit

Permalink
Clear list states (i.e. delete their contents), not reassign the de…
Browse files Browse the repository at this point in the history
…fault `[]` (#2493)

* Clear (i.e. delete) list state items, not simply overwrite. Previous behaviour produced memory leak from list[Tensor] states
* Added test to check list states elements are deleted (even when referenced, and hence not automatically garbage collected). Fixed failing test (want to check list state, but assigned Tensor)
* Updated documentation - highlighted reset clears list states, and that care must be taken when referencing them
* Add missing method (sphinx) role
* changelog
* Remove failing testcode example (fixing introduces too much complexity)
* Linting - Line break docstring
* copy internal states in forward
* Detach Tensor | list[Tensor] state values before copying.
* Use 'typing' type hints
* DO not clone (when caching) Tensor states, but retain references to avoid memory leakage
* Revert "DO not clone (when caching) Tensor states, but retain references to avoid memory leakage" (This reverts commit ef27215.)
* Added mypy type-hinting requirement/recommendation
* Moved update from test checking .__init__ memory leakage. Added test checking .reset clears memory allocated during update (memory should be allowed to grow, as long as discarded safely)
* Fix unused loop control variable for pre-commit

---------

Co-authored-by: dominicgkerr <dominicgkerr1@gmail.co>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: stancld <daniel.stancl@gmail.com>
  • Loading branch information
7 people committed Apr 16, 2024
1 parent 581c444 commit 5259c22
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


Expand Down
4 changes: 4 additions & 0 deletions docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ A few important things to note for this example:
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless
of the mode.

* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
must be taken when referencing list states. If you require the values after your metric is reset, you must first
copy the attribute to another object (e.g. using `deepcopy.copy`).

*****************
Metric attributes
*****************
Expand Down
29 changes: 25 additions & 4 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def add_state(
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format discussed in the above note.
Note:
The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
device memory to be automatically reallocated, but may produce unexpected effects when referencing list
states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
object.
Raises:
ValueError:
If ``default`` is not a ``tensor`` or an ``empty list``.
Expand Down Expand Up @@ -325,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
cache = self._copy_state_dict()

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
Expand Down Expand Up @@ -358,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""
# store global state and reset to default
global_state = {attr: getattr(self, attr) for attr in self._defaults}
global_state = self._copy_state_dict()
_update_count = self._update_count
self.reset()

Expand Down Expand Up @@ -525,7 +531,7 @@ def sync(
dist_sync_fn = gather_all_tensors

# cache prior to syncing
self._cache = {attr: getattr(self, attr) for attr in self._defaults}
self._cache = self._copy_state_dict()

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)
Expand Down Expand Up @@ -681,7 +687,7 @@ def reset(self) -> None:
if isinstance(default, Tensor):
setattr(self, attr, default.detach().clone().to(current_val.device))
else:
setattr(self, attr, [])
getattr(self, attr).clear() # delete/free list items

# reset internal states
self._cache = None
Expand Down Expand Up @@ -870,6 +876,21 @@ def state_dict( # type: ignore[override] # todo
destination[prefix + key] = deepcopy(current_val)
return destination

def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
"""Copy the current state values."""
cache: Dict[str, Union[Tensor, List[Any]]] = {}
for attr in self._defaults:
current_value = getattr(self, attr)

if isinstance(current_value, Tensor):
cache[attr] = current_value.detach().clone().to(current_value.device)
else:
cache[attr] = [ # safely copy (non-graph leaf) Tensor elements
_.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
]

return cache

def _load_from_state_dict(
self,
state_dict: dict,
Expand Down
32 changes: 27 additions & 5 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,17 @@ class B(DummyListMetric):
metric = B()
assert isinstance(metric.x, list)
assert len(metric.x) == 0
metric.x = tensor(5)
metric.x = [tensor(5)]
metric.reset()
assert isinstance(metric.x, list)
assert len(metric.x) == 0

metric = B()
metric.x = [1, 2, 3]
reference = metric.x # prevents garbage collection
metric.reset()
assert len(reference) == 0 # check list state is freed


def test_reset_compute():
"""Test that `reset`+`compute` methods works as expected."""
Expand Down Expand Up @@ -474,18 +480,34 @@ def test_constant_memory_on_repeat_init():
def mem():
return torch.cuda.memory_allocated() / 1024**2

x = torch.randn(10000).cuda()

for i in range(100):
m = DummyListMetric(compute_with_cache=False).cuda()
m(x)
_ = DummyListMetric(compute_with_cache=False).cuda()
if i == 0:
after_one_iter = mem()

# allow for 5% flucturation due to measuring
assert after_one_iter * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_freed_memory_on_reset():
"""Test that resetting a metric frees all the memory allocated when updating it."""

def mem():
return torch.cuda.memory_allocated() / 1024**2

m = DummyListMetric().cuda()
after_init = mem()

for _ in range(100):
m(x=torch.randn(10000).cuda())

m.reset()

# allow for 5% flucturation due to measuring
assert after_init * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
def test_specific_error_on_wrong_device():
"""Test that a specific error is raised if we detect input and metric are on different devices."""
Expand Down

0 comments on commit 5259c22

Please sign in to comment.