/
auroc.py
269 lines (234 loc) · 11.9 KB
/
auroc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Sequence, Tuple
import torch
from torch import Tensor, tensor
from torchmetrics.functional.classification.auc import _auc_compute_without_check
from torchmetrics.functional.classification.roc import roc
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.data import _bincount
from torchmetrics.utilities.enums import AverageMethod, DataType
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]:
    """Validate the inputs and bring them into canonical shape for the AUROC computation.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor

    Returns:
        The (possibly reshaped) ``preds`` and ``target`` together with the detected data mode.
    """
    # _input_format_classification both validates the input pair and detects its mode
    _, _, mode = _input_format_classification(preds, target)

    def _collapse_sample_dims(t: Tensor, n_cls: int) -> Tensor:
        # fold every non-class dimension into a single sample dimension -> shape (num_samples, n_cls)
        return t.transpose(0, 1).reshape(n_cls, -1).transpose(0, 1)

    # the two modes are mutually exclusive, so if/elif matches the original pair of ifs
    if mode == "multi class multi dim":
        preds = _collapse_sample_dims(preds, preds.shape[1])
        target = target.flatten()
    elif mode == "multi-label" and preds.ndim > 2:
        n_cls = preds.shape[1]
        preds = _collapse_sample_dims(preds, n_cls)
        target = _collapse_sample_dims(target, n_cls)
    return preds, target, mode
def _auroc_compute(
    preds: Tensor,
    target: Tensor,
    mode: DataType,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> Tensor:
    """Computes Area Under the Receiver Operating Characteristic Curve.

    Args:
        preds: predictions from model (logits or probabilities)
        target: Ground truth labels
        mode: 'multi class multi dim' or 'multi-label' or 'binary'
        num_classes: integer with number of classes for multi-label and multiclass problems.
            Should be set to ``None`` for binary problems
        pos_label: integer determining the positive class.
            Should be set to ``None`` for binary problems
        average: Defines the reduction that is applied to the output:
            ``'macro'``, ``'weighted'``, ``'micro'`` (multilabel only) or ``None``
        max_fpr: If not ``None``, calculates standardized partial AUC over the
            range ``[0, max_fpr]``. Should be a float between 0 and 1.
        sample_weights: sample weights for each data point

    Raises:
        ValueError: If ``max_fpr`` is not a float in the range ``(0, 1]``.
        RuntimeError: If ``max_fpr`` is used on PyTorch below 1.6 (needs ``torch.bucketize``).
        ValueError: If ``max_fpr`` is set in a non-binary setting.
        ValueError: If ``num_classes`` is missing for multilabel/multiclass input.
        ValueError: If ``average`` is not one of ``None``, ``"macro"``, ``"weighted"``
            (or ``"micro"`` for multilabel).

    Example:
        >>> # binary case
        >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> preds, target, mode = _auroc_update(preds, target)
        >>> _auroc_compute(preds, target, mode, pos_label=1)
        tensor(0.5000)

        >>> # multiclass case
        >>> preds = torch.tensor([[0.90, 0.05, 0.05],
        ...                       [0.05, 0.90, 0.05],
        ...                       [0.05, 0.05, 0.90],
        ...                       [0.85, 0.05, 0.10],
        ...                       [0.10, 0.10, 0.80]])
        >>> target = torch.tensor([0, 1, 1, 2, 2])
        >>> preds, target, mode = _auroc_update(preds, target)
        >>> _auroc_compute(preds, target, mode, num_classes=3)
        tensor(0.7778)
    """
    # binary mode override num_classes: a single ROC curve is computed
    if mode == DataType.BINARY:
        num_classes = 1

    # check max_fpr parameter
    if max_fpr is not None:
        # BUGFIX: the original condition used `and`, so an out-of-range float
        # (e.g. 2.0) slipped through unvalidated and a non-float was only rejected
        # when it happened to compare inside (0, 1]. Raise whenever max_fpr is not
        # a float OR is outside (0, 1].
        if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
            raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

        if _TORCH_LOWER_1_6:
            raise RuntimeError(
                "`max_fpr` argument requires `torch.bucketize` which" " is not available below PyTorch version 1.6"
            )

        # max_fpr parameter is only support for binary
        if mode != DataType.BINARY:
            raise ValueError(
                "Partial AUC computation not available in multilabel/multiclass setting,"
                f" 'max_fpr' must be set to `None`, received `{max_fpr}`."
            )

    # calculate fpr, tpr
    if mode == DataType.MULTILABEL:
        if average == AverageMethod.MICRO:
            # micro average: treat every (sample, label) pair as one binary decision
            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
        elif num_classes:
            # for multilabel we iteratively evaluate roc in a binary fashion
            output = [
                roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
                for i in range(num_classes)
            ]
            fpr = [o[0] for o in output]
            tpr = [o[1] for o in output]
        else:
            raise ValueError("Detected input to be `multilabel` but you did not provide `num_classes` argument")
    else:
        if mode != DataType.BINARY:
            if num_classes is None:
                raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument")
            if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes:
                # If one or more classes has 0 observations, we should exclude them, as its weight will be 0
                target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool, device=target.device)
                target_bool_mat[torch.arange(len(target)), target.long()] = 1
                class_observed = target_bool_mat.sum(axis=0) > 0
                for c in range(num_classes):
                    if not class_observed[c]:
                        warnings.warn(f"Class {c} had 0 observations, omitted from AUROC calculation", UserWarning)
                # restrict preds/target to the observed classes and remap labels
                preds = preds[:, class_observed]
                target = target_bool_mat[:, class_observed]
                target = torch.where(target)[1]
                num_classes = class_observed.sum()
                if num_classes == 1:
                    raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation")
        fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

    # calculate standard roc auc score
    if max_fpr is None or max_fpr == 1:
        if mode == DataType.MULTILABEL and average == AverageMethod.MICRO:
            pass  # micro average produced a single curve; fall through to the plain AUC below
        elif num_classes != 1:
            # calculate auc scores per class
            auc_scores = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)]

            # calculate average
            if average == AverageMethod.NONE:
                return tensor(auc_scores)
            if average == AverageMethod.MACRO:
                return torch.mean(torch.stack(auc_scores))
            if average == AverageMethod.WEIGHTED:
                if mode == DataType.MULTILABEL:
                    support = torch.sum(target, dim=0)
                else:
                    support = _bincount(target.flatten(), minlength=num_classes)
                return torch.sum(torch.stack(auc_scores) * support / support.sum())

            allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
            raise ValueError(
                f"Argument `average` expected to be one of the following: {allowed_average} but got {average}"
            )

        return _auc_compute_without_check(fpr, tpr, 1.0)

    _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
    max_area: Tensor = tensor(max_fpr, device=_device)
    # Add a single point at max_fpr and interpolate its tpr value
    stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
    weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
    fpr = torch.cat([fpr[:stop], max_area.view(1)])

    # Compute partial AUC
    partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)

    # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal
    min_area: Tensor = 0.5 * max_area**2
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
def auroc(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> Tensor:
    """Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_)

    For non-binary input the data mode is inferred from the shapes: when ``preds``
    and ``target`` share the same shape the input is treated as multilabel, and when
    ``preds`` carries one extra dimension compared to ``target`` it is treated as
    multiclass.

    .. note::
        If either the positive class or negative class is completly missing in the target tensor,
        the auroc score is meaningless in this case and a score of 0 will be returned together
        with a warning.

    Args:
        preds: predictions from model (logits or probabilities)
        target: Ground truth labels
        num_classes: integer with number of classes for multi-label and multiclass problems.
            Should be set to ``None`` for binary problems
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problem is translate to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]
        average:
            - ``'micro'`` computes metric globally. Only works for multilabel problems
            - ``'macro'`` computes metric for each class and uniformly averages them
            - ``'weighted'`` computes metric for each class and does a weighted-average,
              where each class is weighted by their support (accounts for class imbalance)
            - ``None`` computes and returns the metric per class
        max_fpr:
            If not ``None``, calculates standardized partial AUC over the
            range ``[0, max_fpr]``. Should be a float between 0 and 1.
        sample_weights: sample weights for each data point

    Raises:
        ValueError:
            If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
        RuntimeError:
            If ``PyTorch version`` is below 1.6 since max_fpr requires ``torch.bucketize``
            which is not available below 1.6.
        ValueError:
            If ``max_fpr`` is not set to ``None`` and the mode is ``not binary``
            since partial AUC computation is not available in multilabel/multiclass.
        ValueError:
            If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.

    Example (binary case):
        >>> from torchmetrics.functional import auroc
        >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> auroc(preds, target, pos_label=1)
        tensor(0.5000)

    Example (multiclass case):
        >>> preds = torch.tensor([[0.90, 0.05, 0.05],
        ...                       [0.05, 0.90, 0.05],
        ...                       [0.05, 0.05, 0.90],
        ...                       [0.85, 0.05, 0.10],
        ...                       [0.10, 0.10, 0.80]])
        >>> target = torch.tensor([0, 1, 1, 2, 2])
        >>> auroc(preds, target, num_classes=3)
        tensor(0.7778)
    """
    # validate/reshape the inputs and detect the data mode, then delegate the computation
    formatted_preds, formatted_target, data_mode = _auroc_update(preds, target)
    return _auroc_compute(
        formatted_preds,
        formatted_target,
        data_mode,
        num_classes=num_classes,
        pos_label=pos_label,
        average=average,
        max_fpr=max_fpr,
        sample_weights=sample_weights,
    )