Use texar.evals directly in Transformer example (#104)
gpengzhi authored and huzecong committed Jul 15, 2019
1 parent b23b5f3 · commit db30e8b
Showing 3 changed files with 23 additions and 112 deletions.
109 changes: 7 additions & 102 deletions examples/transformer/bleu_tool.py
@@ -26,104 +26,12 @@
# BLEU score will be similar to the one obtained using: mteval-v14.pl
# Note: compound splitting is not implemented in this module

-import collections
-import math
import re
import sys
import unicodedata
from argparse import ArgumentParser

-import numpy as np
-
-
-def _get_ngrams(segment, max_order):
-    """Extracts all n-grams up to a given maximum order from an input segment.
-    Args:
-        segment: text segment from which n-grams will be extracted.
-        max_order: maximum length in tokens of the n-grams returned by this
-            method.
-    Returns:
-        The Counter containing all n-grams up to max_order in segment
-        with a count of how many times each n-gram occurred.
-    """
-    ngram_counts = collections.Counter()
-    for order in range(1, max_order + 1):
-        for i in range(0, len(segment) - order + 1):
-            ngram = tuple(segment[i: i + order])
-            ngram_counts[ngram] += 1
-    return ngram_counts
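For orientation, a quick sketch of what this removed helper produces; counts include multiplicity:

    tokens = ["the", "cat", "the", "cat"]
    counts = _get_ngrams(tokens, max_order=2)
    # counts[("the",)] == 2, counts[("cat",)] == 2,
    # counts[("the", "cat")] == 2, counts[("cat", "the")] == 1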


-def compute_bleu(reference_corpus, translation_corpus,
-                 max_order=4, use_bp=True):
-    """Computes BLEU score of translated segments against references.
-    Args:
-        reference_corpus: list of references for each translation. Each
-            reference should be tokenized into a list of tokens.
-        translation_corpus: list of translations to score. Each translation
-            should be tokenized into a list of tokens.
-        max_order: Maximum n-gram order to use when computing BLEU score.
-        use_bp: boolean, whether to apply brevity penalty.
-    Returns:
-        BLEU score.
-    """
-
-    reference_length = 0
-    translation_length = 0
-    bp = 1.0
-    geo_mean = 0
-
-    matches_by_order = [0] * max_order
-    possible_matches_by_order = [0] * max_order
-    precisions = []
-
-    for (references, translations) in zip(reference_corpus, translation_corpus):
-        reference_length += len(references)
-        translation_length += len(translations)
-        ref_ngram_counts = _get_ngrams(references, max_order)
-        translation_ngram_counts = _get_ngrams(translations, max_order)
-
-        overlap = dict(
-            (ngram, min(count, translation_ngram_counts[ngram]))
-            for ngram, count in ref_ngram_counts.items()
-        )
-
-        for ngram in overlap:
-            matches_by_order[len(ngram) - 1] += overlap[ngram]
-        for ngram in translation_ngram_counts:
-            possible_matches_by_order[len(ngram) - 1] += \
-                translation_ngram_counts[ngram]
-    precisions = [0] * max_order
-    smooth = 1.0
-    for i in range(max_order):
-        if possible_matches_by_order[i] > 0:
-            precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
-            if matches_by_order[i] > 0:
-                precisions[i] = (matches_by_order[i] /
-                                 possible_matches_by_order[i])
-            else:
-                smooth *= 2
-                precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
-        else:
-            precisions[i] = 0.0
-
-    if max(precisions) > 0:
-        p_log_sum = sum(math.log(p) for p in precisions if p)
-        geo_mean = math.exp(p_log_sum / max_order)
-
-    if use_bp:
-        ratio = translation_length / reference_length
-        if ratio <= 0:
-            bp = 0
-        elif ratio < 1.0:
-            bp = math.exp(1 - 1.0 / ratio)
-        else:
-            bp = 1.0
-    bleu = geo_mean * bp
-    return np.float32(bleu)
+from texar.evals.bleu import corpus_bleu


class UnicodeRegex:
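The deleted _get_ngrams/compute_bleu pair is replaced wholesale by texar's own corpus_bleu. Judging by the removal of the 100-times factors later in this commit, corpus_bleu reports BLEU on the 0-100 scale, whereas the removed compute_bleu returned a fraction in [0, 1]. A minimal sketch of the replacement call, assuming that scale:

    from texar.evals.bleu import corpus_bleu

    ref = "the cat sat on the mat".split()
    hyp = "the cat sat on the mat".split()
    # One list of references per hypothesis; a perfect match scores 100.0.
    print(corpus_bleu(list_of_references=[[ref]], hypotheses=[hyp]))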
@@ -182,12 +90,11 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
ref_lines = open(ref_filename, encoding="utf-8").read().splitlines()
hyp_lines = open(hyp_filename, encoding="utf-8").read().splitlines()
assert len(ref_lines) == len(hyp_lines)
-    if not case_sensitive:
-        ref_lines = [x.lower() for x in ref_lines]
-        hyp_lines = [x.lower() for x in hyp_lines]
-    ref_tokens = [bleu_tokenize(x) for x in ref_lines]
+    ref_tokens = [[bleu_tokenize(x)] for x in ref_lines]
    hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
-    return compute_bleu(ref_tokens, hyp_tokens)
+    return corpus_bleu(list_of_references=ref_tokens,
+                       hypotheses=hyp_tokens,
+                       lowercase=(not case_sensitive))
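The extra brackets in ref_tokens are there because corpus_bleu accepts several references per hypothesis; bleu_wrapper has exactly one, so each tokenized reference is wrapped in a singleton list. A hypothetical multi-reference call, with ref_a, ref_b, and hyp as placeholder token lists:

    refs = [[ref_a, ref_b]]  # two references for the first (only) hypothesis
    score = corpus_bleu(list_of_references=refs, hypotheses=[hyp])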


def main():
@@ -200,11 +107,9 @@ def main():
parser.add_argument("--reference", type=str)
args = parser.parse_args()

-    bleu = 100 * bleu_wrapper(args.reference, args.translation,
-                              case_sensitive=False)
+    bleu = bleu_wrapper(args.reference, args.translation, case_sensitive=False)
    print("BLEU_uncased = %6.2f" % bleu)
-    bleu = 100 * bleu_wrapper(args.reference, args.translation,
-                              case_sensitive=True)
+    bleu = bleu_wrapper(args.reference, args.translation, case_sensitive=True)
print("BLEU_cased = %6.2f" % bleu)
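Invocation is unchanged by this commit. Assuming a --translation argument is defined in the elided argparse lines above (bleu_wrapper is called with args.translation), the script would run as, e.g.:

    python bleu_tool.py --reference ref.txt --translation hyp.txt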


1 change: 0 additions & 1 deletion examples/transformer/transformer_main.py
@@ -157,7 +157,6 @@ def _eval_epoch(epoch, mode, print_fn=None):
src_fname_suffix="hyp", tgt_fname_suffix="ref",
)
eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
-        eval_bleu = 100.0 * eval_bleu
logger.info("epoch: %d, eval_bleu %.4f", epoch, eval_bleu)
print_fn(f"epoch: {epoch:d}, eval_bleu {eval_bleu:.4f}")
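The 100.0 factor is dropped because bleu_wrapper now delegates to texar's corpus_bleu, which by all appearances already returns a score on the 0-100 scale.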

25 changes: 16 additions & 9 deletions texar/evals/bleu.py
@@ -76,6 +76,7 @@ def sentence_bleu(references: List[MaybeList[str]],
max_order: int = 4,
lowercase: bool = False,
smooth: bool = False,
+                  use_bp: bool = True,
return_all: bool = False) -> MaybeList[float]:
r"""Calculates BLEU score of a hypothesis sentence.
@@ -93,6 +94,7 @@ def sentence_bleu(references: List[MaybeList[str]],
max_order (int): Maximum n-gram order to use when computing
BLEU score.
smooth (bool): Whether or not to apply `(Lin et al. 2004)` smoothing.
+        use_bp (bool): Whether to apply brevity penalty.
return_all (bool): If `True`, returns BLEU and all
n-gram precisions.
@@ -109,6 +111,7 @@ def sentence_bleu(references: List[MaybeList[str]],
max_order=max_order,
lowercase=lowercase,
smooth=smooth,
+                       use_bp=use_bp,
return_all=return_all)
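A hypothetical call exercising the new flag, with tokenized inputs as elsewhere in this file; if corpus_bleu reports on the 0-100 scale, the first result is dampened by roughly exp(-0.5) relative to the second:

    from texar.evals.bleu import sentence_bleu

    refs = [["the", "cat", "sat", "on", "the", "mat"]]
    hyp = ["the", "cat", "sat", "on"]
    with_bp = sentence_bleu(refs, hyp)                   # short hypothesis penalized
    without_bp = sentence_bleu(refs, hyp, use_bp=False)  # penalty disabled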


@@ -117,7 +120,8 @@ def corpus_bleu(list_of_references: List[List[MaybeList[str]]],
max_order: int = 4,
lowercase: bool = False,
smooth: bool = False,
-                return_all: bool = True) -> MaybeList[float]:
+                use_bp: bool = True,
+                return_all: bool = False) -> MaybeList[float]:
r"""Computes corpus-level BLEU score.
Args:
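Note that, besides threading through use_bp, this hunk also flips the default of return_all from True to False, so corpus_bleu now returns only the BLEU score unless the per-order n-gram precisions are requested explicitly.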
@@ -134,6 +138,7 @@ def corpus_bleu(list_of_references: List[List[MaybeList[str]]],
max_order (int): Maximum n-gram order to use when computing
BLEU score.
smooth (bool): Whether or not to apply `(Lin et al. 2004)` smoothing.
+        use_bp (bool): Whether to apply brevity penalty.
return_all (bool): If `True`, returns BLEU and all
n-gram precisions.
@@ -196,15 +201,17 @@ def corpus_bleu(list_of_references: List[List[MaybeList[str]]],
else:
geo_mean = 0

-    ratio = float(hypothesis_length) / reference_length
-
-    if ratio > 1.0:
-        bp = 1.
+    if use_bp:
+        ratio = float(hypothesis_length) / reference_length
+        if ratio > 1.0:
+            bp = 1.
+        else:
+            if abs(ratio) < 1e-8:
+                bp = 0.
+            else:
+                bp = math.exp(1 - 1. / ratio)
    else:
-        try:
-            bp = math.exp(1 - 1. / ratio)
-        except ZeroDivisionError:
-            bp = math.exp(1 - 1. / (ratio + 1e-8))
+        bp = 1.

bleu = geo_mean * bp
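The brevity penalty here is the standard Papineni et al. (2002) one: with ratio the hypothesis-to-reference length ratio, bp = exp(1 - 1/ratio) penalizes hypotheses shorter than the reference and is 1 otherwise. The explicit abs(ratio) < 1e-8 test replaces the old try/except ZeroDivisionError guard for empty hypotheses, and use_bp=False now skips the penalty outright. A worked value:

    import math

    ratio = 4 / 6                   # hypothesis_length / reference_length
    bp = math.exp(1 - 1. / ratio)   # exp(-0.5), roughly 0.61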
