-
Notifications
You must be signed in to change notification settings - Fork 387
/
precision_recall_curve.py
331 lines (285 loc) · 13.4 KB
/
precision_recall_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor, tensor
from torch.nn import functional as F
from torchmetrics.utilities import rank_zero_warn
def _binary_clf_curve(
    preds: Tensor,
    target: Tensor,
    sample_weights: Optional[Sequence] = None,
    pos_label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Count false and true positives at every distinct prediction score.

    Adapted from
    https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py.

    Args:
        preds: prediction scores. When it has more dimensions than ``target``,
            only column 0 is used.
        target: ground truth labels, one per score.
        sample_weights: optional per-sample weights. Anything that is not already
            a tensor is converted to a float tensor on ``preds.device``.
        pos_label: label value that is counted as the positive class.

    Returns:
        Tuple ``(fps, tps, thresholds)``. ``thresholds`` holds the distinct scores
        in decreasing order; ``fps`` / ``tps`` hold the (weighted) number of
        negatives / positives whose score is at least the matching threshold.
    """
    weighted = sample_weights is not None
    if weighted and not isinstance(sample_weights, Tensor):
        sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)

    # drop the class dimension if there is one
    if preds.ndim > target.ndim:
        preds = preds[:, 0]

    # walk the samples from the highest score to the lowest
    order = torch.argsort(preds, descending=True)
    preds, target = preds[order], target[order]
    weight = sample_weights[order] if weighted else 1.0

    # Scores are usually heavily tied: keep only the last position of each run
    # of equal scores, and append the final position to close the curve.
    run_ends = torch.where(preds[1:] - preds[:-1])[0]
    threshold_idxs = F.pad(run_ends, [0, 1], value=target.size(0) - 1)

    is_positive = (target == pos_label).to(torch.long)
    tps = torch.cumsum(is_positive * weight, dim=0)[threshold_idxs]

    if weighted:
        # A cumulative sum keeps fps monotonically increasing even in the
        # presence of floating point errors.
        fps = torch.cumsum((1 - is_positive) * weight, dim=0)[threshold_idxs]
    else:
        fps = 1 + threshold_idxs - tps

    return fps, tps, preds[threshold_idxs]
def _precision_recall_curve_update(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, Optional[int]]:
    """Updates and returns variables required to compute the precision-recall pairs for different thresholds.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        num_classes: integer with number of classes for multi-label and multiclass problems.
            Should be set to ``None`` for binary problems.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problem is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]

    Returns:
        Tuple ``(preds, target, num_classes, pos_label)``:

        - binary: ``preds`` and ``target`` are flattened, ``num_classes`` is 1;
        - multilabel: ``preds`` and ``target`` are both reshaped to ``(-1, num_classes)``;
        - multiclass: ``preds`` is reshaped to ``(-1, num_classes)`` and ``target`` is flattened.

    Raises:
        ValueError:
            If ``num_classes`` does not match the number of classes found in ``preds``.
            Predictions without a class dimension count as a single class.
        ValueError:
            If ``preds`` has neither the same number of dimensions as ``target`` nor exactly one more.
    """
    if preds.ndim == target.ndim:
        if pos_label is None:
            pos_label = 1
        if num_classes is not None and num_classes != 1:
            # multilabel problem
            # Without a class axis there is a single class; report that as a
            # mismatch instead of failing with an IndexError on ``shape[1]``.
            detected_classes = preds.shape[1] if preds.ndim > 1 else 1
            if num_classes != detected_classes:
                raise ValueError(
                    f"Argument `num_classes` was set to {num_classes} in"
                    f" metric `precision_recall_curve` but detected {detected_classes}"
                    " number of classes from predictions"
                )
            # move the class axis first, merge all other axes, then move it last
            preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
            target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
        else:
            # binary problem
            preds = preds.flatten()
            target = target.flatten()
            num_classes = 1

    # multi class problem
    elif preds.ndim == target.ndim + 1:
        if pos_label is not None:
            rank_zero_warn(
                "Argument `pos_label` should be `None` when running"
                f" multiclass precision recall curve. Got {pos_label}"
            )
        if num_classes != preds.shape[1]:
            raise ValueError(
                f"Argument `num_classes` was set to {num_classes} in"
                f" metric `precision_recall_curve` but detected {preds.shape[1]}"
                " number of classes from predictions"
            )
        preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
        target = target.flatten()

    else:
        raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")

    return preds, target, num_classes, pos_label
def _precision_recall_curve_compute_single_class(
    preds: Tensor,
    target: Tensor,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Build the precision-recall curve of a single class.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        pos_label: integer determining the positive class.
        sample_weights: sample weights for each data point

    Returns:
        Tuple ``(precision, recall, thresholds)``. ``thresholds`` is in increasing
        order; ``precision`` and ``recall`` hold one entry per threshold followed
        by a final ``1`` and ``0`` respectively.
    """
    fps, tps, thresholds = _binary_clf_curve(
        preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
    )

    total_positives = tps[-1]
    precision = tps / (tps + fps)
    recall = tps / total_positives

    # Everything after the first point of full recall is dropped.
    first_full_recall = torch.where(tps == total_positives)[0][0]
    keep = first_full_recall.item() + 1

    # The curve is returned with decreasing recall. A negative-step slice is not
    # supported by pytorch, so the kept prefix is flipped explicitly.
    precision_end = torch.ones(1, dtype=precision.dtype, device=precision.device)
    recall_end = torch.zeros(1, dtype=recall.dtype, device=recall.device)
    precision = torch.cat([precision[:keep].flip(0), precision_end])
    recall = torch.cat([recall[:keep].flip(0), recall_end])
    thresholds = thresholds[:keep].flip(0).detach().clone()

    return precision, recall, thresholds
def _precision_recall_curve_compute_multi_class(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    sample_weights: Optional[Sequence] = None,
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
    """Build one precision-recall curve per class.

    Args:
        preds: Predicted tensor; column ``c`` is used as the scores of class ``c``.
        target: Ground truth tensor. When it has more than one dimension, column
            ``c`` is used as the target of class ``c`` with positive label 1;
            otherwise the same target is used for every class with positive label ``c``.
        num_classes: number of classes (columns of ``preds``) to iterate over.
        sample_weights: sample weights for each data point

    Returns:
        Tuple ``(precision, recall, thresholds)`` of lists holding one tensor per class.
    """
    per_class_target = target.ndim > 1

    precision, recall, thresholds = [], [], []
    for cls in range(num_classes):
        if per_class_target:
            cls_target, cls_pos_label = target[:, cls], 1
        else:
            cls_target, cls_pos_label = target, cls

        # each class is evaluated as an independent binary problem
        cls_precision, cls_recall, cls_thresholds = precision_recall_curve(
            preds=preds[:, cls],
            target=cls_target,
            num_classes=1,
            pos_label=cls_pos_label,
            sample_weights=sample_weights,
        )
        precision.append(cls_precision)
        recall.append(cls_recall)
        thresholds.append(cls_thresholds)

    return precision, recall, thresholds
def _precision_recall_curve_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Dispatch to the single-class or the per-class precision-recall computation.

    The computation runs with gradient tracking disabled.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        num_classes: number of classes. ``1`` selects the single-class computation,
            any other value the per-class one.
        pos_label: integer determining the positive class. Only used when
            ``num_classes`` is 1, where ``None`` is translated to 1.
        sample_weights: sample weights for each data point

    Returns:
        ``(precision, recall, thresholds)`` as tensors when ``num_classes`` is 1,
        otherwise as lists holding one tensor per class.

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> pos_label = 1
        >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, pos_label=pos_label)
        >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([1, 2, 3])

        >>> # multiclass case
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> num_classes = 5
        >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes)
        >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes)
        >>> precision
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
    """
    with torch.no_grad():
        if num_classes != 1:
            return _precision_recall_curve_compute_multi_class(preds, target, num_classes, sample_weights)

        positive = 1 if pos_label is None else pos_label
        return _precision_recall_curve_compute_single_class(preds, target, positive, sample_weights)
def precision_recall_curve(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """Computes precision-recall pairs for different thresholds.

    The inputs are first normalised by ``_precision_recall_curve_update`` and the
    result is handed to ``_precision_recall_curve_compute``.

    Args:
        preds: predictions from model (probabilities)
        target: ground truth labels
        num_classes: integer with number of classes for multi-label and multiclass problems.
            Should be set to ``None`` for binary problems.
        pos_label: integer determining the positive class. Default is ``None`` which for binary
            problem is translated to 1. For multiclass problems this argument should not be set
            as we iteratively change it in the range ``[0, num_classes-1]``
        sample_weights: sample weights for each data point

    Returns:
        3-element tuple containing

        precision:
            tensor where element ``i`` is the precision of predictions with
            ``score >= thresholds[i]`` and the last element is 1.
            If multiclass, this is a list of such tensors, one for each class.
        recall:
            tensor where element ``i`` is the recall of predictions with
            ``score >= thresholds[i]`` and the last element is 0.
            If multiclass, this is a list of such tensors, one for each class.
        thresholds:
            Thresholds used for computing precision/recall scores

    Raises:
        ValueError:
            If ``preds`` and ``target`` don't have the same number of dimensions,
            or one additional dimension for ``preds``.
        ValueError:
            If the number of classes deduced from ``preds`` is not the same as the ``num_classes`` provided.

    Example (binary case):
        >>> from torchmetrics.functional import precision_recall_curve
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([1, 2, 3])

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
        >>> precision
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
    """
    # (preds, target, num_classes, pos_label) after validation and reshaping
    prepared = _precision_recall_curve_update(preds, target, num_classes, pos_label)
    return _precision_recall_curve_compute(*prepared, sample_weights)