-
Notifications
You must be signed in to change notification settings - Fork 388
/
precision_recall.py
322 lines (268 loc) · 14.8 KB
/
precision_recall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
class Precision(StatScores):
    r"""
    Computes `Precision`_:

    .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}

    Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
    false positives respectively. With the use of ``top_k`` parameter, this metric can
    generalize to Precision@K.

    The reduction method (how the precision scores are aggregated) is controlled by the
    ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
        num_classes:
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        average:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
            - ``'macro'``: Calculate the metric for each class separately, and average the
              metrics across classes (with equal weights for each class).
            - ``'weighted'``: Calculate the metric for each class separately, and average the
              metrics across classes, weighting each class by its support (``tp + fn``).
            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
              the metric for every class.
            - ``'samples'``: Calculate the metric for each sample, and average the metrics
              across samples (with equal weights for each sample).

            .. note:: What is considered a sample in the multi-dimensional multi-class case
                depends on the value of ``mdmc_average``.

        mdmc_average:
            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
            ``average`` parameter). Should be one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class.

            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then averaged over samples.
              The computation for each sample is done by treating the flattened extra axes ``...``
              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
              and computing the metric for the sample based on that.

            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
              (see :ref:`references/modules:input types`)
              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.

        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
            or ``'none'``, the score for the ignored class will be returned as ``nan``.
        top_k:
            Number of highest probability or logit score predictions considered to find the correct label,
            relevant only for (multi-dimensional) multi-class inputs. The
            default value (``None``) will be interpreted as 1 for these inputs.

            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <references/modules:using the multiclass parameter>`
            for a more detailed explanation and examples.
        compute_on_step:
            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step
        process_group:
            Specify the process group on which synchronization is called.
            default: ``None`` (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather.

    Raises:
        ValueError:
            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.

    Example:
        >>> import torch
        >>> from torchmetrics import Precision
        >>> preds = torch.tensor([2, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> precision = Precision(average='macro', num_classes=3)
        >>> precision(preds, target)
        tensor(0.1667)
        >>> precision = Precision(average='micro')
        >>> precision(preds, target)
        tensor(0.2500)

    """

    is_differentiable = False
    higher_is_better = True

    def __init__(
        self,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        # ``None`` is an accepted value (see ``allowed_average``), hence Optional.
        average: Optional[str] = "micro",
        mdmc_average: Optional[str] = None,
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        multiclass: Optional[bool] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        # Validate before touching the parent so an invalid value fails fast.
        allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

        super().__init__(
            # 'weighted', 'none' and None are forwarded to the parent as 'macro';
            # presumably StatScores then keeps per-class statistics, which these
            # reductions are documented to need -- TODO confirm against StatScores.
            reduce="macro" if average in ["weighted", "none", None] else average,
            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
            multiclass=multiclass,
            ignore_index=ignore_index,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        # Keep the value the caller asked for (not the mapped ``reduce``);
        # ``compute`` hands it to ``_precision_compute``.
        self.average = average

    def compute(self) -> Tensor:
        """Computes the precision score based on inputs passed in to ``update`` previously.

        Return:
            The shape of the returned tensor depends on the ``average`` parameter

            - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
            - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
              of classes
        """
        # The third value returned by ``_get_final_stats`` is not needed here
        # (presumably the true negatives -- verify against StatScores).
        tp, fp, _, fn = self._get_final_stats()
        # ``self.mdmc_reduce`` is not assigned in this class; it is expected to be
        # provided by the StatScores parent.
        return _precision_compute(tp, fp, fn, self.average, self.mdmc_reduce)
class Recall(StatScores):
    r"""
    Computes `Recall`_:

    .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}

    Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
    false negatives respectively. With the use of ``top_k`` parameter, this metric can
    generalize to Recall@K.

    The reduction method (how the recall scores are aggregated) is controlled by the
    ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
        num_classes:
            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
        threshold:
            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
        average:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
            - ``'macro'``: Calculate the metric for each class separately, and average the
              metrics across classes (with equal weights for each class).
            - ``'weighted'``: Calculate the metric for each class separately, and average the
              metrics across classes, weighting each class by its support (``tp + fn``).
            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
              the metric for every class.
            - ``'samples'``: Calculate the metric for each sample, and average the metrics
              across samples (with equal weights for each sample).

            .. note:: What is considered a sample in the multi-dimensional multi-class case
                depends on the value of ``mdmc_average``.

        mdmc_average:
            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
            ``average`` parameter). Should be one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class.

            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then averaged over samples.
              The computation for each sample is done by treating the flattened extra axes ``...``
              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
              and computing the metric for the sample based on that.

            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
              (see :ref:`references/modules:input types`)
              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.

        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
            or ``'none'``, the score for the ignored class will be returned as ``nan``.
        top_k:
            Number of highest probability or logit score predictions considered to find the correct label,
            relevant only for (multi-dimensional) multi-class inputs. The
            default value (``None``) will be interpreted as 1 for these inputs.

            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <references/modules:using the multiclass parameter>`
            for a more detailed explanation and examples.
        compute_on_step:
            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step
        process_group:
            Specify the process group on which synchronization is called.
            default: ``None`` (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather.

    Raises:
        ValueError:
            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.

    Example:
        >>> import torch
        >>> from torchmetrics import Recall
        >>> preds = torch.tensor([2, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> recall = Recall(average='macro', num_classes=3)
        >>> recall(preds, target)
        tensor(0.3333)
        >>> recall = Recall(average='micro')
        >>> recall(preds, target)
        tensor(0.2500)

    """

    is_differentiable = False
    higher_is_better = True

    def __init__(
        self,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        # ``None`` is an accepted value (see ``allowed_average``), hence Optional.
        average: Optional[str] = "micro",
        mdmc_average: Optional[str] = None,
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        multiclass: Optional[bool] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        # Validate before touching the parent so an invalid value fails fast.
        allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

        super().__init__(
            # 'weighted', 'none' and None are forwarded to the parent as 'macro';
            # presumably StatScores then keeps per-class statistics, which these
            # reductions are documented to need -- TODO confirm against StatScores.
            reduce="macro" if average in ["weighted", "none", None] else average,
            mdmc_reduce=mdmc_average,
            threshold=threshold,
            top_k=top_k,
            num_classes=num_classes,
            multiclass=multiclass,
            ignore_index=ignore_index,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        # Keep the value the caller asked for (not the mapped ``reduce``);
        # ``compute`` hands it to ``_recall_compute``.
        self.average = average

    def compute(self) -> Tensor:
        """Computes the recall score based on inputs passed in to ``update`` previously.

        Return:
            The shape of the returned tensor depends on the ``average`` parameter

            - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
            - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
              of classes
        """
        # The third value returned by ``_get_final_stats`` is not needed here
        # (presumably the true negatives -- verify against StatScores).
        tp, fp, _, fn = self._get_final_stats()
        # ``self.mdmc_reduce`` is not assigned in this class; it is expected to be
        # provided by the StatScores parent.
        return _recall_compute(tp, fp, fn, self.average, self.mdmc_reduce)