-
Notifications
You must be signed in to change notification settings - Fork 388
/
average_precision.py
169 lines (146 loc) · 6.92 KB
/
average_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.precision_recall_curve import (
_precision_recall_curve_compute,
_precision_recall_curve_update,
)
def _average_precision_update(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, Optional[int]]:
    """Prepare the inputs used to compute the average precision score.

    This is a pure delegation: all four arguments are handed, unchanged and in
    the same order, to ``_precision_recall_curve_update`` and its result is
    returned as is.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth values
        num_classes: integer with number of classes
        pos_label: integer determining the positive class

    Returns:
        The result of ``_precision_recall_curve_update``, annotated here as a
        ``(preds, target, num_classes, pos_label)`` tuple.
    """
    updated = _precision_recall_curve_update(preds, target, num_classes, pos_label)
    return updated
def _average_precision_compute(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Compute the average precision score from prepared predictions and targets.

    The precision and recall values are obtained from ``_precision_recall_curve_compute``
    and then reduced to a score by ``_average_precision_compute_with_precision_recall``.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth values
        num_classes: integer with number of classes.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]
        sample_weights: sample weights for each data point. Accepted for interface
            compatibility; this function currently does not use it.

    Returns:
        The value produced by ``_average_precision_compute_with_precision_recall``:
        a tensor, or a list of tensors (see the multiclass example below).

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> pos_label = 1
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
        >>> _average_precision_compute(preds, target, num_classes, pos_label)
        tensor(1.)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> num_classes = 5
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
        >>> _average_precision_compute(preds, target, num_classes)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """
    # todo: `sample_weights` is unused
    # The third element of the curve result is not needed for the score.
    precision, recall, _unused = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
    score = _average_precision_compute_with_precision_recall(
        precision=precision,
        recall=recall,
        num_classes=num_classes,
    )
    return score
def _average_precision_compute_with_precision_recall(
    precision: Tensor,
    recall: Tensor,
    num_classes: int,
) -> Union[List[Tensor], Tensor]:
    """Computes the average precision score from precision and recall.

    The score is the step function integral of the curve, evaluated as
    ``-sum((recall[1:] - recall[:-1]) * precision[:-1])``.

    Args:
        precision: precision values. Used directly as one curve when
            ``num_classes == 1``; otherwise iterated in lockstep with ``recall``,
            one curve per element.
        recall: recall values, laid out like ``precision``
        num_classes: integer with number of classes. ``1`` selects the
            single-curve (binary) path; any other value selects the per-class path.

    Returns:
        A scalar tensor when ``num_classes == 1``; otherwise a list holding one
        scalar tensor per ``(precision, recall)`` pair.

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> pos_label = 1
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
        >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
        tensor(1.)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> num_classes = 5
        >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
        >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes)
        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """

    def _step_integral(p: Tensor, r: Tensor) -> Tensor:
        # Return the step function integral
        # The following works because the last entry of precision is
        # guaranteed to be 1, as returned by precision_recall_curve
        # NOTE(review): the leading minus presumably compensates for recall
        # being ordered from high to low -- confirm against the curve helper.
        return -torch.sum((r[1:] - r[:-1]) * p[:-1])

    if num_classes == 1:
        return _step_integral(precision, recall)

    # One score per class: walk the per-class precision/recall curves together.
    return [_step_integral(p, r) for p, r in zip(precision, recall)]
def average_precision(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
    """Computes the average precision score.

    The inputs are first passed through ``_average_precision_update`` and the
    values it returns are then scored by ``_average_precision_compute``.

    Args:
        preds: predictions from model (logits or probabilities)
        target: ground truth values
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]
        sample_weights: sample weights for each data point. Forwarded to
            ``_average_precision_compute``; see the fixme in the body.

    Returns:
        tensor with average precision. If multiclass will return list
        of such tensors, one for each class

    Example (binary case):
        >>> from torchmetrics.functional import average_precision
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> average_precision(pred, target, pos_label=1)
        tensor(1.)

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> average_precision(pred, target, num_classes=5)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """
    # fixme: `sample_weights` is unused
    prepared = _average_precision_update(preds, target, num_classes, pos_label)
    # The update step hands back all four values; use those from here on.
    preds, target, num_classes, pos_label = prepared
    return _average_precision_compute(
        preds=preds,
        target=target,
        num_classes=num_classes,
        pos_label=pos_label,
        sample_weights=sample_weights,
    )