-
Notifications
You must be signed in to change notification settings - Fork 387
/
binned_precision_recall.py
324 lines (273 loc) · 13.9 KB
/
binned_precision_recall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import METRIC_EPS, to_onehot
def _recall_at_precision(
    precision: Tensor,
    recall: Tensor,
    thresholds: Tensor,
    min_precision: float,
) -> Tuple[Tensor, Tensor]:
    """Pick the highest recall whose precision is at least ``min_precision``.

    ``precision``, ``recall`` and ``thresholds`` are walked in lockstep. ``zip`` stops at the shortest
    input, so trailing precision/recall entries without a matching threshold are ignored. Among the
    qualifying points, the largest recall wins. Ties are broken by larger precision, then larger threshold.

    Args:
        precision: precision value per threshold.
        recall: recall value per threshold.
        thresholds: thresholds at which ``precision`` / ``recall`` were evaluated.
        min_precision: minimum precision a point must reach to be considered.

    Returns:
        ``(max_recall, best_threshold)``. If no point qualifies, ``max_recall`` is a zero tensor with
        ``recall``'s dtype and device. Whenever ``max_recall`` is zero, ``best_threshold`` is the
        sentinel value ``1e6`` with ``thresholds``' dtype and device.
    """
    qualifying = [
        (rec, prec, thr)
        for prec, rec, thr in zip(precision, recall, thresholds)
        if prec >= min_precision
    ]

    if not qualifying:
        # Nothing reaches ``min_precision``: report zero recall and the sentinel threshold.
        zero_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype)
        return zero_recall, torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)

    # Tuples compare element-wise: recall first, then precision, then threshold.
    top_recall, _, top_threshold = max(qualifying)
    if top_recall == 0.0:
        # A zero recall is not a useful operating point; signal it with the sentinel threshold.
        top_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
    return top_recall, top_threshold
class BinnedPrecisionRecallCurve(Metric):
    """Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In
    the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

    Computation is performed in constant-memory by computing precision and recall
    for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

    Forward accepts

    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
      with probabilities, where C is the number of classes.

    - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels

    Args:
        num_classes: integer with number of classes. For binary, set to 1.
        thresholds: list or tensor with specific thresholds, or a number of bins from linear sampling
            over ``[0, 1]``. More thresholds give a more detailed curve and more accurate estimates,
            but computation is slower and consumes more memory.
            If ``None`` (the default), 100 linearly spaced thresholds over ``[0, 1]`` are used.
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.

    Raises:
        ValueError:
            If ``thresholds`` is not ``None``, an int, a list or a tensor

    Example (binary case):
        >>> from torchmetrics import BinnedPrecisionRecallCurve
        >>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
        >>> precision, recall, thresholds = pr_curve(pred, target)
        >>> precision
        tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
        >>> precision, recall, thresholds = pr_curve(pred, target)
        >>> precision
        [tensor([0.2500, 1.0000, 1.0000, 1.0000]),
        tensor([0.2500, 1.0000, 1.0000, 1.0000]),
        tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
        tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
        tensor([2.5000e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00])]
        >>> recall
        [tensor([1.0000, 1.0000, 0.0000, 0.0000]),
        tensor([1.0000, 1.0000, 0.0000, 0.0000]),
        tensor([1.0000, 0.0000, 0.0000, 0.0000]),
        tensor([1.0000, 0.0000, 0.0000, 0.0000]),
        tensor([0., 0., 0., 0.])]
        >>> thresholds
        [tensor([0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 1.0000])]
    """

    # Per-class, per-threshold counts of true positives, false positives and false negatives,
    # each of shape (num_classes, num_thresholds) and summed across processes.
    TPs: Tensor
    FPs: Tensor
    FNs: Tensor

    def __init__(
        self,
        num_classes: int,
        thresholds: Union[int, Tensor, List[float], None] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )

        self.num_classes = num_classes

        # Previously ``None`` skipped both branches below, leaving ``num_thresholds`` and the
        # ``thresholds`` buffer unset, so ``add_state`` crashed with an AttributeError.
        # Fall back to a default number of linearly spaced bins instead.
        if thresholds is None:
            thresholds = 100

        if isinstance(thresholds, int):
            self.num_thresholds = thresholds
            thresholds = torch.linspace(0, 1.0, thresholds)
        elif isinstance(thresholds, (list, Tensor)):
            thresholds = torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
            self.num_thresholds = thresholds.numel()
        else:
            raise ValueError("Expected argument `thresholds` to either be an integer, list of floats or a tensor")
        self.register_buffer("thresholds", thresholds)

        for name in ("TPs", "FPs", "FNs"):
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, self.num_thresholds, dtype=torch.float32),
                dist_reduce_fx="sum",
            )

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate TP/FP/FN counts for every threshold.

        Args:
            preds: ``(N,)`` tensor (binary case) or ``(N, num_classes)`` tensor of scores.
            target: ``(N,)`` tensor of labels. It is one-hot encoded when ``preds`` has one more
                dimension than ``target``. Alternatively, a tensor with the same shape as ``preds``
                whose entries equal to 1 mark positives.
        """
        # Binary case: treat 1-D inputs as a single-column (N, 1) problem.
        if len(preds.shape) == len(target.shape) == 1:
            preds = preds.reshape(-1, 1)
            target = target.reshape(-1, 1)

        # Integer class labels -> one-hot so that ``target`` lines up with ``preds``.
        if len(preds.shape) == len(target.shape) + 1:
            target = to_onehot(target, num_classes=self.num_classes)

        target = target == 1
        not_target = ~target  # loop-invariant, computed once

        # Iterate one threshold at a time to conserve memory. ``.sum(dim=0)`` reduces over the
        # sample axis and must yield one count per class to fit into ``TPs[:, i]``.
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[:, i] += (target & predictions).sum(dim=0)
            self.FPs[:, i] += (not_target & predictions).sum(dim=0)
            self.FNs[:, i] += (target & ~predictions).sum(dim=0)

    def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
        """Return the binned precision-recall curve.

        Returns:
            ``(precision, recall, thresholds)``. Precision and recall each have ``num_thresholds + 1``
            entries, and the final entries are fixed to 1 and 0 respectively. For ``num_classes == 1``
            these are tensors. Otherwise each element is a list with one tensor per class.
        """
        # METRIC_EPS in the numerator makes precision ~1 (not 0/0) when nothing is predicted positive.
        precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS)
        recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)

        # Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve
        t_ones = torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)
        precisions = torch.cat([precisions, t_ones], dim=1)
        t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)
        recalls = torch.cat([recalls, t_zeros], dim=1)
        if self.num_classes == 1:
            return precisions[0, :], recalls[0, :], self.thresholds
        return list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)]
class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
    """Computes the average precision score, which summarises the precision recall curve into one number. Works for
    both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-
    vs-the-rest approach.

    Computation is performed in constant-memory by computing precision and recall
    for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

    Forward accepts

    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
      with probabilities, where C is the number of classes.

    - ``target`` (long tensor): ``(N, ...)`` with integer labels

    Args:
        num_classes: integer with number of classes. This argument is required; provide 1 for
            binary problems.
        thresholds: list or tensor with specific thresholds, or a number of bins from linear sampling
            over ``[0, 1]``. More thresholds give a more detailed curve and more accurate estimates,
            but computation is slower and consumes more memory.
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.

    Raises:
        ValueError:
            If ``thresholds`` is given but is not an int, list or tensor

    Example (binary case):
        >>> from torchmetrics import BinnedAveragePrecision
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> average_precision = BinnedAveragePrecision(num_classes=1, thresholds=10)
        >>> average_precision(pred, target)
        tensor(1.0000)

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> average_precision = BinnedAveragePrecision(num_classes=5, thresholds=10)
        >>> average_precision(pred, target)
        [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
    """

    def compute(self) -> Union[List[Tensor], Tensor]:  # type: ignore
        """Reduce the accumulated binned precision-recall curve to average precision.

        Returns:
            A scalar tensor when ``num_classes == 1``, otherwise a list with one average-precision
            tensor per class (as in the examples above). No averaging across classes is done
            (``average=None``).
        """
        precisions, recalls, _ = super().compute()
        return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None)
class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
    """Computes the highest possible recall value given the minimum precision threshold provided.

    Computation is performed in constant-memory by computing precision and recall
    for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

    Forward accepts

    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
      with probabilities, where C is the number of classes.

    - ``target`` (long tensor): ``(N, ...)`` with integer labels

    Args:
        num_classes: integer with number of classes. Provide 1 for binary problems.
        min_precision: float value specifying minimum precision threshold.
        thresholds: list or tensor with specific thresholds, or a number of bins from linear sampling
            over ``[0, 1]``. More thresholds give a more detailed curve and more accurate estimates,
            but computation is slower and consumes more memory.
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.

    Raises:
        ValueError:
            If ``thresholds`` is given but is not an int, list or tensor

    Example (binary case):
        >>> from torchmetrics import BinnedRecallAtFixedPrecision
        >>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> average_precision = BinnedRecallAtFixedPrecision(num_classes=1, thresholds=10, min_precision=0.5)
        >>> average_precision(pred, target)
        (tensor(1.0000), tensor(0.1111))

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> average_precision = BinnedRecallAtFixedPrecision(num_classes=5, thresholds=10, min_precision=0.5)
        >>> average_precision(pred, target)
        (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]),
        tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
    """

    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        thresholds: Union[int, Tensor, List[float], None] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ) -> None:
        super().__init__(
            num_classes=num_classes,
            thresholds=thresholds,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.min_precision = min_precision

    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore
        """Return the best recall reachable at ``min_precision`` and the threshold that achieves it.

        Returns:
            ``(recall, threshold)``. These are 0-d tensors when ``num_classes == 1``, otherwise tensors
            of size ``num_classes``. A threshold of ``1e6`` marks a class for which no threshold gives
            non-zero recall at ``min_precision``.
        """
        precisions, recalls, thresholds = super().compute()

        if self.num_classes == 1:
            return _recall_at_precision(precisions, recalls, thresholds, self.min_precision)

        # One-vs-rest: evaluate each class's curve independently.
        recalls_at_p = torch.zeros(self.num_classes, device=recalls[0].device, dtype=recalls[0].dtype)
        thresholds_at_p = torch.zeros(self.num_classes, device=thresholds[0].device, dtype=thresholds[0].dtype)
        for i in range(self.num_classes):
            recalls_at_p[i], thresholds_at_p[i] = _recall_at_precision(
                precisions[i], recalls[i], thresholds[i], self.min_precision
            )
        return recalls_at_p, thresholds_at_p