From 39242f6acf1cad0c576f8e94bba8391d0048fc81 Mon Sep 17 00:00:00 2001 From: Alexis Asseman <33075224+aasseman@users.noreply.github.com> Date: Tue, 30 Apr 2019 09:27:31 -0700 Subject: [PATCH] Add option to ignore words in BLEU --- .../default/components/publishers/bleu_statistics.yml | 3 +++ ptp/components/publishers/bleu_statistics.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/configs/default/components/publishers/bleu_statistics.yml b/configs/default/components/publishers/bleu_statistics.yml index a79a245..c51f387 100644 --- a/configs/default/components/publishers/bleu_statistics.yml +++ b/configs/default/components/publishers/bleu_statistics.yml @@ -13,6 +13,9 @@ use_prediction_distributions: True # TODO! #use_masking: False +# Ignored words - useful for ignoring special tokens +ignored_words: ["", ""] + # Weights of n-grams used when calculating the score. weights: [0.25, 0.25, 0.25, 0.25] diff --git a/ptp/components/publishers/bleu_statistics.py b/ptp/components/publishers/bleu_statistics.py index b303ea9..6432c06 100644 --- a/ptp/components/publishers/bleu_statistics.py +++ b/ptp/components/publishers/bleu_statistics.py @@ -58,6 +58,9 @@ def __init__(self, name, config): # Get masking flag. #self.use_masking = self.config["use_masking"] + # Get ignored words + self.ignored_words = self.config["ignored_words"] + # Retrieve word mappings from globals. word_to_ix = self.globals["word_mappings"] # Construct reverse mapping for faster processing. @@ -144,12 +147,16 @@ def calculate_BLEU(self, data_dict): target_words = [] for t_ind in target_indices: if t_ind in self.ix_to_word.keys(): - target_words.append(self.ix_to_word[t_ind]) + w = self.ix_to_word[t_ind] + if w not in self.ignored_words: + target_words.append(w) # Change prediction indices to words. pred_words = [] for p_ind in pred_indices: if p_ind in self.ix_to_word.keys(): - pred_words.append(self.ix_to_word[p_ind]) + w = self.ix_to_word[p_ind] + if w not in self.ignored_words: + pred_words.append(w) # Calculate BLEU. scores.append(sentence_bleu([target_words], pred_words, self.weights)) #print("TARGET: {}\n".format(target_words))