Skip to content

Commit

Permalink
Minor fix in BLEU (#109)
Browse files Browse the repository at this point in the history
* Resolve comments
  • Loading branch information
gpengzhi committed Jul 22, 2019
1 parent 859ec78 commit bf2d655
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 81 deletions.
12 changes: 6 additions & 6 deletions docs/code/evals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ BLEU
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.evals.corpus_bleu_moses

:hidden:`compute_bleu`
:hidden:`corpus_bleu_transformer`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.evals.compute_bleu
.. autofunction:: texar.evals.corpus_bleu_transformer

:hidden:`bleu_tokenize`
:hidden:`bleu_transformer_tokenize`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.evals.bleu_tokenize
.. autofunction:: texar.evals.bleu_transformer_tokenize

:hidden:`bleu_wrapper`
:hidden:`file_bleu`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.evals.bleu_wrapper
.. autofunction:: texar.evals.file_bleu


Accuracy
Expand Down
6 changes: 3 additions & 3 deletions examples/transformer/bleu_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from argparse import ArgumentParser

from texar.evals.bleu_tool import bleu_wrapper
from texar.evals.bleu_transformer import file_bleu


def main():
Expand All @@ -42,9 +42,9 @@ def main():
parser.add_argument("--reference", type=str)
args = parser.parse_args()

bleu = bleu_wrapper(args.reference, args.translation, case_sensitive=False)
bleu = file_bleu(args.reference, args.translation, case_sensitive=False)
print("BLEU_uncased = %6.2f" % bleu)
bleu = bleu_wrapper(args.reference, args.translation, case_sensitive=True)
bleu = file_bleu(args.reference, args.translation, case_sensitive=True)
print("BLEU_cased = %6.2f" % bleu)


Expand Down
2 changes: 1 addition & 1 deletion texar/evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@

from texar.evals.bleu import *
from texar.evals.bleu_moses import *
from texar.evals.bleu_tool import *
from texar.evals.bleu_transformer import *
from texar.evals.metrics import *
56 changes: 0 additions & 56 deletions texar/evals/bleu_tool_test.py

This file was deleted.

57 changes: 42 additions & 15 deletions texar/evals/bleu_tool.py → texar/evals/bleu_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
`https://github.com/tensorflow/models/blob/master/official/transformer/compute_bleu.py`
"""

from typing import Counter, List, Tuple
from typing import Callable, Counter, List, Tuple

import re
import sys
Expand All @@ -25,10 +25,14 @@
import math
import numpy as np

from texar.evals.bleu import corpus_bleu
from texar.evals.bleu_moses import corpus_bleu_moses
from texar.utils.types import MaybeList

__all__ = [
"compute_bleu",
"bleu_tokenize",
"bleu_wrapper",
"corpus_bleu_transformer",
"bleu_transformer_tokenize",
"file_bleu",
]


Expand All @@ -54,10 +58,10 @@ def _get_ngrams(segment: List[str],
return ngram_counts


def compute_bleu(reference_corpus: List[List[str]],
translation_corpus: List[List[str]],
max_order: int = 4,
use_bp: bool = True) -> float:
def corpus_bleu_transformer(reference_corpus: List[List[str]],
translation_corpus: List[List[str]],
max_order: int = 4,
use_bp: bool = True) -> float:
r"""Computes BLEU score of translated segments against references.
This BLEU has been used in evaluating Transformer (Vaswani et al.)
Expand Down Expand Up @@ -155,9 +159,14 @@ def property_chars(prefix):
uregex = UnicodeRegex()


def bleu_tokenize(string: str) -> List[str]:
def bleu_transformer_tokenize(string: str) -> List[str]:
r"""Tokenize a string following the official BLEU implementation.
The BLEU scores from `multi-bleu.perl` depend on your `tokenizer`, which is
unlikely to be reproducible from your experiment or consistent across
different users. This function provides a standard tokenization following
`mteval-v14.pl`.
See
`https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983`.
In our case, the input string is expected to be just one line
Expand Down Expand Up @@ -185,14 +194,18 @@ def bleu_tokenize(string: str) -> List[str]:
return string.split()


def bleu_wrapper(ref_filename: str,
hyp_filename: str,
case_sensitive: bool = False) -> float:
def file_bleu(ref_filename: str,
hyp_filename: str,
bleu_version: str = "corpus_bleu_transformer",
case_sensitive: bool = False) -> float:
r"""Compute BLEU for two files (reference and hypothesis translation).
Args:
ref_filename: Reference file path.
hyp_filename: Hypothesis file path.
bleu_version: A str with the name of a BLEU computing method selected
in the list of: `corpus_bleu`, `corpus_bleu_moses`,
`corpus_bleu_transformer`.
case_sensitive: If `False`, lowercase reference and hypothesis
tokens.
Expand All @@ -212,6 +225,20 @@ def bleu_wrapper(ref_filename: str,
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return compute_bleu(ref_tokens, hyp_tokens)

ref_tokens: List[MaybeList[List[str]]]
if bleu_version == "corpus_bleu_transformer":
ref_tokens = [bleu_transformer_tokenize(x) for x in ref_lines]
else:
ref_tokens = [[bleu_transformer_tokenize(x)] for x in ref_lines]
hyp_tokens = [bleu_transformer_tokenize(x) for x in hyp_lines]

bleu_dict = {
"corpus_bleu": corpus_bleu,
"corpus_bleu_moses": corpus_bleu_moses,
"corpus_bleu_transformer": corpus_bleu_transformer
}
fn: Callable[[List[MaybeList[List[str]]], List[List[str]]], float]
fn = bleu_dict[bleu_version] # type: ignore

return fn(ref_tokens, hyp_tokens)
77 changes: 77 additions & 0 deletions texar/evals/bleu_transformer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
#
"""
Unit tests for bleu_tool.
"""

import unittest

import tempfile

from texar.evals.bleu_transformer import *


class BLEUToolTest(unittest.TestCase):
r"""Test bleu_tool.
"""

def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with open(temp_file.name, "w") as f:
f.write(text)
return temp_file.name

def test_bleu_same(self):
ref = self._create_temp_file("test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nmore tests!")

uncased_score = file_bleu(ref, hyp, case_sensitive=False)
cased_score = file_bleu(ref, hyp, case_sensitive=True)
self.assertEqual(100, uncased_score)
self.assertEqual(100, cased_score)

def test_bleu_same_different_case(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = file_bleu(ref, hyp, case_sensitive=False)
cased_score = file_bleu(ref, hyp, case_sensitive=True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)

def test_bleu_different(self):
ref = self._create_temp_file("Testing\nmore tests!")
hyp = self._create_temp_file("Dog\nCat")
uncased_score = file_bleu(ref, hyp, case_sensitive=False)
cased_score = file_bleu(ref, hyp, case_sensitive=True)
self.assertLess(uncased_score, 100)
self.assertLess(cased_score, 100)

def test_bleu_tokenize(self):
s = "Test0, 1 two, 3"
tokenized = bleu_transformer_tokenize(s)
self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)

def test_bleu_version(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = file_bleu(ref, hyp,
bleu_version="corpus_bleu",
case_sensitive=False)
cased_score = file_bleu(ref, hyp,
bleu_version="corpus_bleu",
case_sensitive=True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)

uncased_score = file_bleu(ref, hyp,
bleu_version="corpus_bleu_moses",
case_sensitive=False)
cased_score = file_bleu(ref, hyp,
bleu_version="corpus_bleu_moses",
case_sensitive=True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)


if __name__ == "__main__":
unittest.main()

0 comments on commit bf2d655

Please sign in to comment.