-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
metric.py
236 lines (193 loc) · 9.19 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import functools
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from collections.abc import Mapping, Sequence
from collections import namedtuple
from copy import deepcopy
import os
import torch
from torch import nn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
class Metric(nn.Module, ABC):
    """
    Base class for all metrics present in the Metrics API.

    Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
    handle distributed synchronization and per-step metric computation.

    Override ``update()`` and ``compute()`` functions to implement your own metric. Use
    ``add_state()`` to register metric state variables which keep track of state on each
    call of ``update()`` and are synchronized across processes when ``compute()`` is called.

    Note:
        Metric state variables can either be ``torch.Tensors`` or an empty list which can be used
        to store ``torch.Tensors``.

    Note:
        Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
        is valid, but it won't return the metric value at the current step. A call to ``forward()``
        automatically calls ``update()`` and also returns the metric value at the current step.

    Note:
        ``compute()`` caches its result until the next ``update()`` and resets the metric state
        to the registered defaults once the value has been computed.

    Args:
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False. default: True
        ddp_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
    """

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__()

        self.ddp_sync_on_step = ddp_sync_on_step
        self.compute_on_step = compute_on_step
        self.process_group = process_group
        # whether the next ``compute()`` call should synchronize state across processes
        self._to_sync = True

        # shadow the subclass implementations with instance-level wrappers that
        # maintain the cached result (see ``_wrap_update`` / ``_wrap_compute``)
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)
        self._computed = None
        self._forward_cache = None

        # initialize state
        self._reductions = {}
        self._defaults = {}

    def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
        """
        Adds metric state variable. Only used by subclasses.

        Args:
            name: The name of the state variable. The variable will then be accessible at ``self.name``.
            default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
                reset to this value when ``self.reset()`` is called.
            dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
                If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
                and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
                function in this parameter.

        Raises:
            ValueError: If ``default`` is neither a ``torch.Tensor`` nor an empty list, or if
                ``dist_reduce_fx`` is neither callable nor one of ``"sum"``, ``"mean"``, ``"cat"``, ``None``.

        Note:
            Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
            However, there won't be any reduction function applied to the synchronized metric state.

            The metric states would be synced as follows

            - If the metric state is a ``torch.Tensor``, the synced value will be a ``torch.Tensor`` stacked
              across the process dimension. The original ``torch.Tensor`` metric state retains its dimensions
              and hence the synchronized output will be of shape ``(num_process, ...)``.

            - If the metric state is a ``list``, the synced value will be a ``list`` containing the
              combined elements from all processes.

        Note:
            When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
            the format discussed in the above note.
        """
        is_list = isinstance(default, list)
        if not (isinstance(default, torch.Tensor) or is_list) or (is_list and default):
            raise ValueError(
                "state variable must be a tensor or any empty list (where you can append tensors)"
            )

        if dist_reduce_fx == "sum":
            dist_reduce_fx = dim_zero_sum
        elif dist_reduce_fx == "mean":
            dist_reduce_fx = dim_zero_mean
        elif dist_reduce_fx == "cat":
            dist_reduce_fx = dim_zero_cat
        elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
            raise ValueError(
                "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
            )

        if isinstance(default, torch.Tensor):
            self.register_buffer(name, default)
        else:
            setattr(self, name, default)

        # keep an independent copy so later changes to the live state never alter the default
        self._defaults[name] = deepcopy(default)
        self._reductions[name] = dist_reduce_fx

    def forward(self, *args, **kwargs):
        """
        Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.

        Returns:
            The metric value computed on the given inputs only when ``compute_on_step`` is True,
            otherwise ``None``.
        """
        # add current step
        self.update(*args, **kwargs)
        self._forward_cache = None

        if self.compute_on_step:
            self._to_sync = self.ddp_sync_on_step

            # save context before switch
            self._cache = {attr: getattr(self, attr) for attr in self._defaults}

            try:
                # call reset, update, compute, on single batch
                self.reset()
                self.update(*args, **kwargs)
                self._forward_cache = self.compute()
            finally:
                # restore context even if the per-step update/compute raised, so the
                # accumulated state and the sync flag are never left in the per-batch setting
                for attr, val in self._cache.items():
                    setattr(self, attr, val)
                self._to_sync = True
                self._computed = None

            return self._forward_cache

    def _sync_dist(self):
        """
        Gathers every registered state via ``gather_all_tensors_if_available`` and stores the
        (optionally reduced) result back on the metric.

        Raises:
            TypeError: If a registered reduction is neither callable nor ``None``.
        """
        input_dict = {attr: getattr(self, attr) for attr in self._reductions}
        output_dict = apply_to_collection(
            input_dict,
            torch.Tensor,
            gather_all_tensors_if_available,
            group=self.process_group,
        )

        for attr, reduction_fn in self._reductions.items():
            gathered = output_dict[attr]

            # a list state nothing was appended to has no first element to inspect and
            # nothing to reduce: keep it as an empty list
            if isinstance(gathered, (list, tuple)) and len(gathered) == 0:
                setattr(self, attr, [])
                continue

            # pre-processing ops (stack or flatten for inputs)
            if isinstance(gathered[0], torch.Tensor):
                gathered = torch.stack(gathered)
            elif isinstance(gathered[0], list):
                gathered = _flatten(gathered)

            if reduction_fn is not None and not callable(reduction_fn):
                raise TypeError(
                    f"reduction for state `{attr}` must be callable or None, got {type(reduction_fn).__name__}"
                )
            reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
            setattr(self, attr, reduced)

    def _wrap_update(self, update):
        """Wraps ``update`` so that every call invalidates the cached ``compute()`` result."""

        @functools.wraps(update)
        def wrapped_func(*args, **kwargs):
            self._computed = None
            return update(*args, **kwargs)

        return wrapped_func

    def _wrap_compute(self, compute):
        """
        Wraps ``compute`` so that it returns the cached value when one exists, synchronizes state
        across processes first when required, and resets the state after computing.
        """

        @functools.wraps(compute)
        def wrapped_func(*args, **kwargs):
            # return cached value
            if self._computed is not None:
                return self._computed

            if (
                self._to_sync
                and torch.distributed.is_available()  # noqa: W503
                and torch.distributed.is_initialized()  # noqa: W503
            ):
                self._sync_dist()

            self._computed = compute(*args, **kwargs)
            self.reset()

            return self._computed

        return wrapped_func

    @abstractmethod
    def update(self) -> None:  # pylint: disable=E0202
        """
        Override this method to update the state variables of your metric class.
        """
        pass

    @abstractmethod
    def compute(self):  # pylint: disable=E0202
        """
        Override this method to compute the final metric value from state variables
        synchronized across the distributed backend.
        """
        pass

    def reset(self):
        """
        This method automatically resets the metric state variables to their default value.
        """
        for attr, default in self._defaults.items():
            current_val = getattr(self, attr)
            if isinstance(current_val, torch.Tensor):
                # keep the fresh default on the device the state currently lives on
                setattr(self, attr, deepcopy(default).to(current_val.device))
            else:
                setattr(self, attr, deepcopy(default))

    def __getstate__(self):
        # ignore update and compute functions for pickling
        return {k: v for k, v in self.__dict__.items() if k not in ("update", "compute")}

    def __setstate__(self, state):
        # manually restore update and compute functions for pickling
        self.__dict__.update(state)
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)