from collections import Counter
import math
from typing import Iterable, Tuple, Dict, Set, Optional

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum


@Metric.register("bleu")
class BLEU(Metric):
"""
Bilingual Evaluation Understudy (BLEU).
BLEU is a common metric used for evaluating the quality of machine translations
against a set of reference translations. See
[Papineni et. al., "BLEU: a method for automatic evaluation of machine translation", 2002][1].
# Parameters
ngram_weights : `Iterable[float]`, optional (default = `(0.25, 0.25, 0.25, 0.25)`)
Weights to assign to scores for each ngram size.
exclude_indices : `Set[int]`, optional (default = `None`)
Indices to exclude when calculating ngrams. This should usually include
the indices of the start, end, and pad tokens.

    # Notes

    We chose to implement this from scratch instead of wrapping an existing implementation
    (such as `nltk.translate.bleu_score`) for two reasons. First, it lets us pass
    tensors directly to this metric instead of first converting them to lists of strings.
    Second, functions like `nltk.translate.bleu_score.corpus_bleu()` are meant to be
    called once over the entire corpus, whereas in our use case it is more efficient
    to update the running precision counts with every batch.

    This implementation only considers a reference set of size 1, i.e. a single
    gold target sequence for each predicted sequence.
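
    # Example

    A minimal usage sketch (the token ids below are made up for illustration,
    with `0` assumed to be the padding index):

    ```python
    import torch
    from allennlp.training.metrics import BLEU

    metric = BLEU(exclude_indices={0})
    predictions = torch.tensor([[1, 2, 3, 4, 0, 0]])
    gold_targets = torch.tensor([[1, 2, 3, 5, 0, 0]])
    metric(predictions, gold_targets)
    print(metric.get_metric(reset=True))  # {"BLEU": ...}
    ```
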
[1]: https://www.semanticscholar.org/paper/8ff93cfd37dced279134c9d642337a2085b31f59/
"""
def __init__(
self,
ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
        exclude_indices: Optional[Set[int]] = None,
) -> None:
self._ngram_weights = ngram_weights
self._exclude_indices = exclude_indices or set()
self._precision_matches: Dict[int, int] = Counter()
self._precision_totals: Dict[int, int] = Counter()
self._prediction_lengths = 0
        self._reference_lengths = 0

    def reset(self) -> None:
self._precision_matches = Counter()
self._precision_totals = Counter()
self._prediction_lengths = 0
        self._reference_lengths = 0

    def _get_modified_precision_counts(
self,
predicted_tokens: torch.LongTensor,
reference_tokens: torch.LongTensor,
ngram_size: int,
) -> Tuple[int, int]:
"""
Compare the predicted tokens to the reference (gold) tokens at the desired
ngram size and calculate the numerator and denominator for a modified
form of precision.
The numerator is the number of ngrams in the predicted sentences that match
with an ngram in the corresponding reference sentence, clipped by the total
count of that ngram in the reference sentence. The denominator is just
the total count of predicted ngrams.
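
        For example, if the unigram `the` occurs 3 times in a predicted sentence
        but only 2 times in the corresponding reference sentence, it contributes
        2 to the numerator (the clipped count) and 3 to the denominator.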
"""
clipped_matches = 0
total_predicted = 0
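        # Imported locally rather than at the top of the module, presumably to
        # avoid a circular import with `allennlp.training.util`.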
from allennlp.training.util import ngrams
for predicted_row, reference_row in zip(predicted_tokens, reference_tokens):
predicted_ngram_counts = ngrams(predicted_row, ngram_size, self._exclude_indices)
reference_ngram_counts = ngrams(reference_row, ngram_size, self._exclude_indices)
for ngram, count in predicted_ngram_counts.items():
clipped_matches += min(count, reference_ngram_counts[ngram])
total_predicted += count
        return clipped_matches, total_predicted

    def _get_brevity_penalty(self) -> float:
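        """
        Compute the brevity penalty `BP = exp(1 - reference_length / prediction_length)`,
        clipped at 1.0 when the cumulative prediction length exceeds the cumulative
        reference length (Papineni et al., 2002). Returns 0.0 if either length is zero.
        """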
if self._prediction_lengths > self._reference_lengths:
return 1.0
if self._reference_lengths == 0 or self._prediction_lengths == 0:
return 0.0
        return math.exp(1.0 - self._reference_lengths / self._prediction_lengths)

    def __call__(
self, # type: ignore
predictions: torch.LongTensor,
gold_targets: torch.LongTensor,
mask: Optional[torch.BoolTensor] = None,
) -> None:
"""
Update precision counts.
# Parameters
predictions : `torch.LongTensor`, required
Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
references : `torch.LongTensor`, required
Batched reference (gold) translations with shape `(batch_size, max_gold_sequence_length)`.
# Returns
None
"""
if mask is not None:
raise NotImplementedError("This metric does not support a mask.")
predictions, gold_targets = self.detach_tensors(predictions, gold_targets)
if is_distributed():
world_size = dist.get_world_size()
else:
world_size = 1
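        # `dist_reduce_sum` sums each count across workers. Dividing both the
        # numerator and the denominator by `world_size` rescales the counts but
        # leaves every ngram precision ratio unchanged.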
for ngram_size, _ in enumerate(self._ngram_weights, start=1):
precision_matches, precision_totals = self._get_modified_precision_counts(
predictions, gold_targets, ngram_size
)
self._precision_matches[ngram_size] += dist_reduce_sum(precision_matches) / world_size
self._precision_totals[ngram_size] += dist_reduce_sum(precision_totals) / world_size
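        # Track cumulative lengths for the brevity penalty. When no indices are
        # excluded, every position in the batch counts toward the length.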
if not self._exclude_indices:
_prediction_lengths = predictions.size(0) * predictions.size(1)
_reference_lengths = gold_targets.size(0) * gold_targets.size(1)
else:
from allennlp.training.util import get_valid_tokens_mask
valid_predictions_mask = get_valid_tokens_mask(predictions, self._exclude_indices)
valid_gold_targets_mask = get_valid_tokens_mask(gold_targets, self._exclude_indices)
_prediction_lengths = valid_predictions_mask.sum().item()
_reference_lengths = valid_gold_targets_mask.sum().item()
self._prediction_lengths += dist_reduce_sum(_prediction_lengths)
        self._reference_lengths += dist_reduce_sum(_reference_lengths)

    def get_metric(self, reset: bool = False) -> Dict[str, float]:
brevity_penalty = self._get_brevity_penalty()
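        # BLEU = BP * exp(sum_n w_n * log(p_n)), where p_n is the modified ngram
        # precision. The 1e-13 terms guard against log(0) when an ngram size has
        # no matches yet.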
ngram_scores = (
weight
* (
math.log(self._precision_matches[n] + 1e-13)
- math.log(self._precision_totals[n] + 1e-13)
)
for n, weight in enumerate(self._ngram_weights, start=1)
)
bleu = brevity_penalty * math.exp(sum(ngram_scores))
if reset:
self.reset()
return {"BLEU": bleu}