categorical_accuracy.py
from typing import Optional

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.common.checks import ConfigurationError
from allennlp.training.metrics.metric import Metric


@Metric.register("categorical_accuracy")
class CategoricalAccuracy(Metric):
    """
    Categorical Top-K accuracy. Assumes integer labels, with
    each item to be classified having a single correct class.
    Tie break enables equal distribution of scores among the
    classes with the same maximum predicted score.
    """

    supports_distributed = True

    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
        if top_k > 1 and tie_break:
            raise ConfigurationError(
                "Tie break in Categorical Accuracy can be done only for maximum (top_k = 1)"
            )
        if top_k <= 0:
            raise ConfigurationError("top_k passed to Categorical Accuracy must be > 0")
        self._top_k = top_k
        self._tie_break = tie_break
        self.correct_count = 0.0
        self.total_count = 0.0

    def __call__(
        self,
        predictions: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ):
        """
        # Parameters

        predictions : `torch.Tensor`, required.
            A tensor of predictions of shape (batch_size, ..., num_classes).
        gold_labels : `torch.Tensor`, required.
            A tensor of integer class labels of shape (batch_size, ...). It must be the same
            shape as the `predictions` tensor without the `num_classes` dimension.
        mask : `torch.BoolTensor`, optional (default = `None`).
            A masking tensor the same size as `gold_labels`.
        """
        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

        # Some sanity checks.
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise ConfigurationError(
                "gold_labels must have dimension == predictions.dim() - 1 but "
                "found tensor of shape: {}".format(predictions.size())
            )
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to Categorical Accuracy contains an id >= {}, "
                "the number of classes.".format(num_classes)
            )

        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()
        if not self._tie_break:
            # Top K indexes of the predictions (or fewer, if there aren't K of them).
            # Special case top_k == 1, because it's common and .max() is much faster than .topk().
            if self._top_k == 1:
                top_k = predictions.max(-1)[1].unsqueeze(-1)
            else:
                _, sorted_indices = predictions.sort(dim=-1, descending=True)
                top_k = sorted_indices[..., : min(self._top_k, predictions.shape[-1])]

            # This is of shape (batch_size, ..., top_k).
            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
        else:
            # A prediction is correct if the gold label falls on any of the maximum scores;
            # the score is then distributed evenly across the tied classes via tie_counts.
            max_predictions = predictions.max(-1)[0]
            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
            # max_predictions_mask is (rows x num_classes) and gold_labels is (batch_size,).
            # The ith entry in gold_labels gives the class index (0 to num_classes - 1) for the
            # ith row; check whether that index is among the maximum-scored classes.
            correct = max_predictions_mask[
                torch.arange(gold_labels.numel(), device=gold_labels.device).long(), gold_labels
            ].float()
            tie_counts = max_predictions_mask.sum(-1)
            correct /= tie_counts.float()
            correct.unsqueeze_(-1)

        if mask is not None:
            correct *= mask.view(-1, 1)
            _total_count = mask.sum()
        else:
            _total_count = torch.tensor(gold_labels.numel())
        _correct_count = correct.sum()

        if is_distributed():
            device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
            _correct_count = _correct_count.to(device)
            _total_count = _total_count.to(device)
            dist.all_reduce(_correct_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(_total_count, op=dist.ReduceOp.SUM)

        self.correct_count += _correct_count.item()
        self.total_count += _total_count.item()

    def get_metric(self, reset: bool = False):
        """
        # Returns

        The accumulated accuracy.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0

        if reset:
            self.reset()

        return accuracy

    @overrides
    def reset(self):
        self.correct_count = 0.0
        self.total_count = 0.0
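
A minimal usage sketch, not part of the file above: it assumes the class is importable as
`from allennlp.training.metrics import CategoricalAccuracy`, and the tensors below are
made-up examples showing how the metric accumulates across calls.

    import torch

    from allennlp.training.metrics import CategoricalAccuracy

    # A hypothetical batch of 4 items over 3 classes.
    predictions = torch.tensor(
        [[0.1, 0.7, 0.2],
         [0.6, 0.3, 0.1],
         [0.2, 0.2, 0.6],
         [0.4, 0.4, 0.2]]
    )
    gold_labels = torch.tensor([1, 0, 1, 0])
    mask = torch.tensor([True, True, True, False])  # the last item is ignored

    metric = CategoricalAccuracy(top_k=1)
    metric(predictions, gold_labels, mask)
    # Two of the three unmasked items have the gold label as the argmax.
    print(metric.get_metric(reset=True))  # ~0.6667

    # With tie_break=True (only valid together with top_k=1), a tie at the
    # maximum is split evenly across the tied classes.
    tie_metric = CategoricalAccuracy(tie_break=True)
    tie_metric(torch.tensor([[0.5, 0.5]]), torch.tensor([0]))
    print(tie_metric.get_metric())  # 0.5, since two classes tie for the maximum

Calling `get_metric(reset=True)` returns the accumulated accuracy and clears the running
counts, which is typically how the metric is read at the end of an epoch.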