-
Notifications
You must be signed in to change notification settings - Fork 387
/
avg_precision.py
147 lines (126 loc) · 6.13 KB
/
avg_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.average_precision import (
_average_precision_compute,
_average_precision_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
class AveragePrecision(Metric):
    """Computes the average precision score, which summarises the precision recall curve into one number.

    Works for both binary and multiclass problems. In the case of multiclass, the values will be
    calculated based on a one-vs-the-rest approach.

    Forward accepts

    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
      with probabilities, where C is the number of classes.

    - ``target`` (long tensor): ``(N, ...)`` with integer labels

    Args:
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]
        average:
            defines the reduction that is applied in the case of multiclass and multilabel input.
            Should be one of the following:

            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
              metrics across classes (with equal weights for each class).
            - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
              used with multiclass input.
            - ``'weighted'``: Calculate the metric for each class separately, and average the
              metrics across classes, weighting each class by its support.
            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
              the metric for every class.

        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)

    Raises:
        ValueError:
            If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.

    Example (binary case):
        >>> from torchmetrics import AveragePrecision
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> average_precision = AveragePrecision(pos_label=1)
        >>> average_precision(pred, target)
        tensor(1.)

    Example (multiclass case):
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> average_precision = AveragePrecision(num_classes=5, average=None)
        >>> average_precision(pred, target)
        [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
    """

    is_differentiable = False
    # Per-batch buffers registered through ``add_state`` in ``__init__``.
    preds: List[Tensor]
    target: List[Tensor]

    def __init__(
        self,
        num_classes: Optional[int] = None,
        pos_label: Optional[int] = None,
        average: Optional[str] = "macro",
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )

        self.num_classes = num_classes
        self.pos_label = pos_label

        # The class docstring documents both ``'none'`` and ``None``; accept the string form too.
        allowed_average = ("micro", "macro", "weighted", "none", None)
        if average not in allowed_average:
            raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")
        # Normalize the string alias so the functional helpers only ever receive ``None``
        # for the "no reduction" case.
        self.average = None if average == "none" else average

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `AveragePrecision` will save all targets and predictions in buffer."
            " For large datasets this may lead to large memory footprint."
        )

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        preds, target, num_classes, pos_label = _average_precision_update(
            preds, target, self.num_classes, self.pos_label, self.average
        )
        self.preds.append(preds)
        self.target.append(target)
        # Keep the values returned by the update helper so ``compute`` uses them.
        self.num_classes = num_classes
        self.pos_label = pos_label

    def compute(self) -> Union[Tensor, List[Tensor]]:
        """Compute the average precision score.

        Returns:
            tensor with average precision. If multiclass will return list
            of such tensors, one for each class

        Raises:
            ValueError:
                If ``num_classes`` is still unset (``None``) or zero when called.
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        if not self.num_classes:
            raise ValueError(f"`num_classes` has to be positive number, but got {self.num_classes}")
        return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)