diff --git a/configs/default/components/publishers/bleu_statistics.yml b/configs/default/components/publishers/bleu_statistics.yml index a79a245..c51f387 100644 --- a/configs/default/components/publishers/bleu_statistics.yml +++ b/configs/default/components/publishers/bleu_statistics.yml @@ -13,6 +13,9 @@ use_prediction_distributions: True # TODO! #use_masking: False +# Ignored words - useful for ignoring special tokens +ignored_words: ["", ""] + # Weights of n-grams used when calculating the score. weights: [0.25, 0.25, 0.25, 0.25] diff --git a/ptp/components/publishers/bleu_statistics.py b/ptp/components/publishers/bleu_statistics.py index b303ea9..6432c06 100644 --- a/ptp/components/publishers/bleu_statistics.py +++ b/ptp/components/publishers/bleu_statistics.py @@ -58,6 +58,9 @@ def __init__(self, name, config): # Get masking flag. #self.use_masking = self.config["use_masking"] + # Get ignored words + self.ignored_words = self.config["ignored_words"] + # Retrieve word mappings from globals. word_to_ix = self.globals["word_mappings"] # Construct reverse mapping for faster processing. @@ -144,12 +147,16 @@ def calculate_BLEU(self, data_dict): target_words = [] for t_ind in target_indices: if t_ind in self.ix_to_word.keys(): - target_words.append(self.ix_to_word[t_ind]) + w = self.ix_to_word[t_ind] + if w not in self.ignored_words: + target_words.append(w) # Change prediction indices to words. pred_words = [] for p_ind in pred_indices: if p_ind in self.ix_to_word.keys(): - pred_words.append(self.ix_to_word[p_ind]) + w = self.ix_to_word[p_ind] + if w not in self.ignored_words: + pred_words.append(w) # Calculate BLEU. scores.append(sentence_bleu([target_words], pred_words, self.weights)) #print("TARGET: {}\n".format(target_words))