Commit 0a65826

Jeff Yang and Borda authored
metrics: add BLEU (Lightning-AI#2535)
* metrics: added bleu score and test bleu
* metrics: fixed type hints in bleu
* bleu score moved to metrics/functional/nlp.py
* refactor with torch.Tensor
* Update test_sequence.py
* refactor as Borda requests and nltk==3.2
* locked nltk==3.3
* nltk>=3.3, parametrized smooth argument for test
* fix bleu_score example
* added class BLEUScore metrics and test
* added class BLEUScore metrics and test
* update CHANGELOG
* refactor with torchtext
* torchtext changed to optional import
* fix E501 line too long
* add else: in optional import
* remove pragma: no-cover
* constants changed to CAPITALS
* remove class in tests
* List -> Sequence, conda -> pip, cast with tensor
* add torchtext in test.txt
* remove torchtext from test.txt
* bump torchtext to 0.5.0
* bump torchtext to 0.5.0
* Apply suggestions from code review
* ignore bleu score in doctest, renamed to nlp.py
* back to implementation with torch
* remove --ignore in CI test, proper reference format
* apply justus comment

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 5025be7 commit 0a65826
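In short, the commit adds a BLEU metric in two forms: a functional bleu_score under pytorch_lightning.metrics.functional and a BLEUScore class under pytorch_lightning.metrics. A minimal usage sketch, assuming the package layout introduced by the diffs below; the inputs follow the doctests added in this commit:

    from pytorch_lightning.metrics import BLEUScore                # class metric added in pytorch_lightning/metrics/nlp.py
    from pytorch_lightning.metrics.functional import bleu_score    # functional form added in pytorch_lightning/metrics/functional/nlp.py

    # Tokenized hypothesis and references, taken from the added doctests.
    translate_corpus = ['the cat is on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

    print(bleu_score(translate_corpus, reference_corpus))   # tensor(0.7598)

    metric = BLEUScore(n_gram=4, smooth=False)
    print(metric(translate_corpus, reference_corpus))        # tensor(0.7598)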

File tree: 11 files changed (+287 −36 lines)


CHANGELOG.md (+1)

@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

 ### Changed

docs/source/metrics.rst (+23 −11)

@@ -28,7 +28,7 @@ Example::

 .. warning::
     The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!
-    to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.
+    to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.

 ----------------

@@ -73,7 +73,7 @@ Here's an example showing how to implement a NumpyMetric
     class RMSE(NumpyMetric):
         def forward(self, x, y):
             return np.sqrt(np.mean(np.power(x-y, 2.0)))
-
+

 .. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
     :noindex:

@@ -138,6 +138,12 @@ AUROC
 .. autoclass:: pytorch_lightning.metrics.classification.AUROC
     :noindex:

+BLEUScore
+^^^^^^^^^
+
+.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore
+    :noindex:
+
 ConfusionMatrix
 ^^^^^^^^^^^^^^^

@@ -283,6 +289,12 @@ average_precision (F)
 .. autofunction:: pytorch_lightning.metrics.functional.average_precision
     :noindex:

+bleu_score (F)
+^^^^^^^^^^^^^^
+
+.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
+    :noindex:
+
 confusion_matrix (F)
 ^^^^^^^^^^^^^^^^^^^^

@@ -418,22 +430,22 @@ to_onehot (F)

 Sklearn interface
 -----------------
-
-Lightning supports `sklearns metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
-as a backend for calculating metrics. Sklearns metrics are well tested and robust,
+
+Lightning supports `sklearns metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
+as a backend for calculating metrics. Sklearns metrics are well tested and robust,
 but requires conversion between pytorch and numpy thus may slow down your computations.

 To use the sklearn backend of metrics simply import as

 .. code-block:: python
-
+
     import pytorch_lightning.metrics.sklearns import plm
     metric = plm.Accuracy(normalize=True)
     val = metric(pred, target)
-
-Each converted sklearn metric comes has the same interface as its
-original counterpart (e.g. accuracy takes the additional `normalize` keyword).
-Like the native Lightning metrics, these converted sklearn metrics also come
+
+Each converted sklearn metric comes has the same interface as its
+original counterpart (e.g. accuracy takes the additional `normalize` keyword).
+Like the native Lightning metrics, these converted sklearn metrics also come
 with built-in distributed (ddp) support.

@@ -460,7 +472,7 @@ AveragePrecision (sk)
 .. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
     :noindex:

-
+
 ConfusionMatrix (sk)
 ^^^^^^^^^^^^^^^^^^^^
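One note on the docs context above: the import line shown in the sklearn-backend snippet (`import pytorch_lightning.metrics.sklearns import plm`) is not valid Python, and this commit leaves it untouched since it only reflows whitespace in that section. A hypothetical corrected form of the snippet (my own sketch, not part of the diff; the alias name plm and the `normalize` keyword come from the original docs text):

    import torch

    # Hypothetical correction of the docs snippet; not part of this commit.
    from pytorch_lightning.metrics import sklearns as plm

    pred = torch.tensor([0, 1, 1, 0])    # dummy predictions
    target = torch.tensor([0, 1, 0, 0])  # dummy labels

    metric = plm.Accuracy(normalize=True)
    val = metric(pred, target)
    print(val)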

environment.yml (+1)

@@ -30,6 +30,7 @@ dependencies:
   - twine==1.13.0
   - pillow<7.0.0
   - scikit-image
+  - nltk>=3.3

   # Optional
   - scipy>=0.13.3

pytorch_lightning/metrics/__init__.py (+25 −23)

@@ -5,7 +5,7 @@
     MSE,
     PSNR,
     RMSE,
-    RMSLE
+    RMSLE,
 )
 from pytorch_lightning.metrics.classification import (
     Accuracy,
@@ -28,30 +28,32 @@
     PrecisionRecallCurve,
     SklearnMetric,
 )
+from pytorch_lightning.metrics.nlp import BLEUScore

 __classification_metrics = [
-    'AUC',
-    'AUROC',
-    'Accuracy',
-    'AveragePrecision',
-    'ConfusionMatrix',
-    'DiceCoefficient',
-    'F1',
-    'FBeta',
-    'MulticlassPrecisionRecall',
-    'MulticlassROC',
-    'Precision',
-    'PrecisionRecall',
-    'PrecisionRecallCurve',
-    'ROC',
-    'Recall',
-    'IoU',
+    "AUC",
+    "AUROC",
+    "Accuracy",
+    "AveragePrecision",
+    "ConfusionMatrix",
+    "DiceCoefficient",
+    "F1",
+    "FBeta",
+    "MulticlassPrecisionRecall",
+    "MulticlassROC",
+    "Precision",
+    "PrecisionRecall",
+    "PrecisionRecallCurve",
+    "ROC",
+    "Recall",
+    "IoU",
 ]
 __regression_metrics = [
-    'MAE',
-    'MSE',
-    'PSNR',
-    'RMSE',
-    'RMSLE'
+    "MAE",
+    "MSE",
+    "PSNR",
+    "RMSE",
+    "RMSLE",
 ]
-__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
+__sequence_metrics = ["BLEUScore"]
+__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics

pytorch_lightning/metrics/functional/__init__.py (+2 −1)

@@ -25,5 +25,6 @@
     mse,
     psnr,
     rmse,
-    rmsle
+    rmsle,
 )
+from pytorch_lightning.metrics.functional.nlp import bleu_score
pytorch_lightning/metrics/functional/nlp.py (new file, +92)

@@ -0,0 +1,92 @@
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Sequence, List
from collections import Counter

import torch


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
    """Counting how many times each word appears in a given text with ngram

    Args:
        ngram_input_list: A list of translated text or reference texts
        n_gram: gram value ranged 1 to 4

    Return:
        ngram_counter: a collections.Counter object of ngram
    """

    ngram_counter = Counter()

    for i in range(1, n_gram + 1):
        for j in range(len(ngram_input_list) - i + 1):
            ngram_key = tuple(ngram_input_list[j : i + j])
            ngram_counter[ngram_key] += 1

    return ngram_counter


def bleu_score(
    translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
) -> torch.Tensor:
    """Calculate BLEU score of machine translated text with one or more references.

    Args:
        translate_corpus: An iterable of machine translated corpus
        reference_corpus: An iterable of iterables of reference corpus
        n_gram: Gram value ranged from 1 to 4 (Default 4)
        smooth: Whether or not to apply smoothing – Lin et al. 2004

    Return:
        A Tensor with BLEU Score

    Example:

        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> bleu_score(translate_corpus, reference_corpus)
        tensor(0.7598)
    """

    assert len(translate_corpus) == len(reference_corpus)
    numerator = torch.zeros(n_gram)
    denominator = torch.zeros(n_gram)
    precision_scores = torch.zeros(n_gram)
    c = 0.0
    r = 0.0
    for (translation, references) in zip(translate_corpus, reference_corpus):
        c += len(translation)
        ref_len_list = [len(ref) for ref in references]
        ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
        r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
        translation_counter = _count_ngram(translation, n_gram)
        reference_counter = Counter()
        for ref in references:
            reference_counter |= _count_ngram(ref, n_gram)

        ngram_counter_clip = translation_counter & reference_counter
        for counter_clip in ngram_counter_clip:
            numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

        for counter in translation_counter:
            denominator[len(counter) - 1] += translation_counter[counter]

    trans_len = torch.tensor(c)
    ref_len = torch.tensor(r)
    if min(numerator) == 0.0:
        return torch.tensor(0.0)

    if smooth:
        precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
    else:
        precision_scores = numerator / denominator
    log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
    geometric_mean = torch.exp(torch.sum(log_precision_scores))
    brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
    bleu = brevity_penalty * geometric_mean

    return bleu
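As a sanity check on the algorithm above, here is a minimal hand computation of the doctest example (my own worked sketch, not part of the commit). The clipped n-gram precisions for the hypothesis 'the cat is on the mat' against the two references are 5/6, 4/5, 3/4 and 2/3; the brevity penalty is 1 because the hypothesis length (6) equals the closest reference length (6); the uniformly weighted geometric mean then reproduces the doctest value:

    import torch

    # Clipped n-gram precisions p_1..p_4 for the doctest example, counted by hand:
    #   unigrams: 5 of 6 match ("the" occurs twice in the hypothesis but is clipped to 1)
    #   bigrams:  4 of 5, trigrams: 3 of 4, 4-grams: 2 of 3
    precisions = torch.tensor([5 / 6, 4 / 5, 3 / 4, 2 / 3])

    # BLEU = BP * exp(sum_n w_n * log p_n) with uniform weights w_n = 1/4 and BP = 1 here.
    bleu = torch.exp(torch.mean(torch.log(precisions)))
    print(bleu)  # tensor(0.7598), matching the doctest above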

pytorch_lightning/metrics/nlp.py (new file, +46)

@@ -0,0 +1,46 @@
import torch

from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.metric import Metric


class BLEUScore(Metric):
    """
    Calculate BLEU score of machine translated text with one or more references.

    Example:

        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> metric = BLEUScore()
        >>> metric(translate_corpus, reference_corpus)
        tensor(0.7598)
    """

    def __init__(self, n_gram: int = 4, smooth: bool = False):
        """
        Args:
            n_gram: Gram value ranged from 1 to 4 (Default 4)
            smooth: Whether or not to apply smoothing – Lin et al. 2004
        """
        super().__init__(name="bleu")
        self.n_gram = n_gram
        self.smooth = smooth

    def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            translate_corpus: An iterable of machine translated corpus
            reference_corpus: An iterable of iterables of reference corpus

        Return:
            torch.Tensor: BLEU Score
        """
        return bleu_score(
            translate_corpus=translate_corpus,
            reference_corpus=reference_corpus,
            n_gram=self.n_gram,
            smooth=self.smooth,
        ).to(self.device, self.dtype)
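A short usage sketch of the class interface with non-default arguments (my own example; the approximate value in the comment is hand-computed from the unigram and bigram precisions 5/6 and 4/5, not taken from the commit):

    from pytorch_lightning.metrics import BLEUScore

    translate_corpus = ['the cat is on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

    # Restrict the score to unigrams and bigrams; smoothing stays off.
    bleu2 = BLEUScore(n_gram=2, smooth=False)
    print(bleu2(translate_corpus, reference_corpus))  # roughly tensor(0.8165), i.e. sqrt(5/6 * 4/5)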

requirements/extra.txt (+1 −1)

@@ -11,4 +11,4 @@ horovod>=0.19.1
 omegaconf>=2.0.0
 # scipy>=0.13.3
 scikit-learn>=0.20.0
-torchtext>=0.3.1
+torchtext>=0.3.1

requirements/test.txt (+1)

@@ -12,3 +12,4 @@ black==19.10b0
 pre-commit>=1.0

 cloudpickle>=1.2
+nltk>=3.3

tests/metrics/functional/test_nlp.py (new file, +66)

@@ -0,0 +1,66 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

from pytorch_lightning.metrics.functional.nlp import bleu_score

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
    "It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
    "It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())


# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()

REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()

LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2


@pytest.mark.parametrize(
    ["weights", "n_gram", "smooth_func", "smooth"],
    [
        pytest.param([1], 1, None, False),
        pytest.param([0.5, 0.5], 2, smooth_func, True),
        pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
        pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
    ],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
    nltk_output = sentence_bleu(
        [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
    )
    pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))

    nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
    pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))


def test_bleu_empty():
    hyp = [[]]
    ref = [[[]]]
    assert bleu_score(hyp, ref) == torch.tensor(0.0)


def test_no_4_gram():
    hyps = [["My", "full", "pytorch-lightning"]]
    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
    assert bleu_score(hyps, refs) == torch.tensor(0.0)
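The tests above check the new implementation against NLTK's reference BLEU for matching weight settings. A standalone sketch of the same parity check (my own, mirroring the n_gram=3, no-smoothing parametrization on the doctest corpus; it assumes nltk is installed, as added to requirements/test.txt):

    import torch
    from nltk.translate.bleu_score import corpus_bleu

    from pytorch_lightning.metrics.functional.nlp import bleu_score

    hypotheses = ["the cat is on the mat".split()]
    list_of_references = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]

    # Uniform 1/3 weights in NLTK correspond to n_gram=3 in the Lightning implementation.
    nltk_out = corpus_bleu(list_of_references, hypotheses, weights=[0.333333, 0.333333, 0.333333])
    pl_out = bleu_score(hypotheses, list_of_references, n_gram=3, smooth=False)

    print(nltk_out, pl_out)  # both roughly 0.7937 for this corpus
    assert torch.allclose(pl_out, torch.tensor(nltk_out))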
