From 1aa5a0e756dd69f37600c2b3cc3f5f203e036f07 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Fri, 19 Apr 2019 10:56:16 -0700
Subject: [PATCH 1/5] first version of component calculating BLEU score

---
 .../components/publishers/bleu_statistics.yml |  50 ++++
 .../wikitext_language_modeling_rnn.yml        |  23 +-
 ptp/components/publishers/__init__.py         |   2 +
 ptp/components/publishers/bleu_statistics.py  | 223 ++++++++++++++++++
 ptp/components/text/sentence_indexer.py       |   2 +-
 5 files changed, 296 insertions(+), 4 deletions(-)
 create mode 100644 configs/default/components/publishers/bleu_statistics.yml
 create mode 100644 ptp/components/publishers/bleu_statistics.py

diff --git a/configs/default/components/publishers/bleu_statistics.yml b/configs/default/components/publishers/bleu_statistics.yml
new file mode 100644
index 0000000..2a52fa3
--- /dev/null
+++ b/configs/default/components/publishers/bleu_statistics.yml
@@ -0,0 +1,50 @@
+# This file defines the default values for the BLEU statistics.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Flag indicating whether predictions are represented as distributions or indices (LOADED)
+# Options: True (expects a distribution for each prediction)
+#          False (expects indices (argmax))
+use_prediction_distributions: True
+
+# When set to True, performs masking of selected samples from batch (LOADED)
+# TODO!
+#use_masking: False
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing targets (label ids) (INPUT)
+  targets: targets
+
+  # Stream containing batch of predictions (INPUT)
+  predictions: predictions
+
+  # Stream containing masks used for masking of selected samples from batch (INPUT)
+  #masks: masks
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  # Word mappings used for mapping predictions/targets into lists of words (RETRIEVED)
+  word_mappings: word_mappings
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+statistics:
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
+  # Name used for collected statistics (ADDED).
+  bleu: bleu
+
+
diff --git a/configs/wikitext/wikitext_language_modeling_rnn.yml b/configs/wikitext/wikitext_language_modeling_rnn.yml
index 88274b9..d2806d5 100644
--- a/configs/wikitext/wikitext_language_modeling_rnn.yml
+++ b/configs/wikitext/wikitext_language_modeling_rnn.yml
@@ -86,12 +86,29 @@ pipeline:
       num_targets_dims: 2
     streams:
       targets: indexed_targets
-      loss: loss
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.0
+
+  #accuracy:
+  #  type: AccuracyStatistics
+  #  priority: 100.1
+  #  streams:
+  #    targets: indexed_targets
+
+  bleu:
+    type: BLEUStatistics
+    priority: 100.2
+    streams:
+      targets: indexed_targets
+
+  # Viewers.
   viewer:
     type: StreamViewer
-    priority: 100.1
-    input_streams: sources,indexed_targets,targets,predictions
+    priority: 100.3
+    input_streams: sources,targets,indexed_targets,predictions
 
 #: pipeline
diff --git a/ptp/components/publishers/__init__.py b/ptp/components/publishers/__init__.py
index a412f6f..1db7f75 100644
--- a/ptp/components/publishers/__init__.py
+++ b/ptp/components/publishers/__init__.py
@@ -1,11 +1,13 @@
 from .accuracy_statistics import AccuracyStatistics
 from .batch_size_statistics import BatchSizeStatistics
+from .bleu_statistics import BLEUStatistics
 from .global_variable_publisher import GlobalVariablePublisher
 from .precision_recall_statistics import PrecisionRecallStatistics
 
 __all__ = [
     'AccuracyStatistics',
     'BatchSizeStatistics',
+    'BLEUStatistics',
     'GlobalVariablePublisher',
     'PrecisionRecallStatistics',
     ]
diff --git a/ptp/components/publishers/bleu_statistics.py b/ptp/components/publishers/bleu_statistics.py
new file mode 100644
index 0000000..eafb7ea
--- /dev/null
+++ b/ptp/components/publishers/bleu_statistics.py
@@ -0,0 +1,223 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) tkornuta, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+import torch
+import math
+import numpy as np
+from nltk.translate.bleu_score import sentence_bleu
+
+from ptp.components.component import Component
+from ptp.data_types.data_definition import DataDefinition
+
+
+class BLEUStatistics(Component):
+    """
+    Class collecting statistics: BLEU (Bilingual Evaluation Understudy Score).
+
+    It accepts targets and predictions represented as indices of words and uses the provided word mappings to change those into words, which are then used for the calculation of the BLEU score.
+
+    """
+
+    def __init__(self, name, config):
+        """
+        Initializes the object.
+
+        :param name: Component name.
+        :type name: str
+
+        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
+        :type config: :py:class:`ptp.configuration.ConfigInterface`
+
+        """
+        # Call constructors of parent classes.
+        Component.__init__(self, name, BLEUStatistics, config)
+
+        # Get stream key mappings.
+        self.key_targets = self.stream_keys["targets"]
+        self.key_predictions = self.stream_keys["predictions"]
+        self.key_masks = self.stream_keys["masks"]
+
+        # Get prediction distributions/indices flag.
+        self.use_prediction_distributions = self.config["use_prediction_distributions"]
+
+        # Get masking flag.
+        #self.use_masking = self.config["use_masking"]
+
+        # Retrieve word mappings from globals.
+        word_to_ix = self.globals["word_mappings"]
+        # Construct reverse mapping for faster processing.
+        self.ix_to_word = dict((v,k) for k,v in word_to_ix.items())
+
+
+        # Get statistics key mappings.
+        self.key_bleu = self.statistics_keys["bleu"]
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        # Add targets.
+        input_defs = {
+            self.key_targets: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices of particular words [BATCH_SIZE x SEQ_LENGTH]"),
+            }
+        # Add predictions.
+        if self.use_prediction_distributions:
+            input_defs[self.key_predictions] = DataDefinition([-1, -1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with sequences of probability distributions over classes [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES]")
+        else:
+            input_defs[self.key_predictions] = DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with sequences of indices of predicted answers [BATCH_SIZE x SEQ_LENGTH]")
+        # Add masks.
+        #if self.use_masking:
+        #    input_defs[self.key_masks] = DataDefinition([-1, -1], [torch.Tensor], "Batch of masks (separate mask for each sequence in the batch) [BATCH_SIZE x SEQ_LENGTH]")
+        return input_defs
+
+
+    def output_data_definitions(self):
+        """
+        Function returns an empty dictionary with definitions of output data produced by the component.
+
+        :return: Empty dictionary.
+        """
+        return {}
+
+
+    def __call__(self, data_dict):
+        """
+        Call method - empty for all statistics.
+        """
+        pass
+
+
+    def calculate_BLEU(self, data_dict):
+        """
+        Calculates the BLEU score for predictions of a given batch.
+
+        :param data_dict: DataDict containing the targets and predictions (and optionally masks).
+        :type data_dict: DataDict
+
+        :return: BLEU score averaged over the batch.
+
+        """
+        # Get targets.
+        targets = data_dict[self.key_targets].data.cpu().numpy().tolist()
+
+        if self.use_prediction_distributions:
+            # Get indices of the max log-probability.
+            preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy().tolist()
+        else:
+            preds = data_dict[self.key_predictions].data.cpu().numpy().tolist()
+
+        #if self.use_masking:
+        #    # Get masks from inputs.
+        #    masks = data_dict[self.key_masks].data.cpu().numpy().tolist()
+        #else:
+        #    batch_size = preds.shape[0]
+
+        # Calculate the scores, sample by sample.
+        scores = []
+
+        for target_indices, pred_indices in zip(targets, preds):
+            # Change target indices to words.
+            target_words = []
+            for t_ind in target_indices:
+                if t_ind in self.ix_to_word.keys():
+                    target_words.append(self.ix_to_word[t_ind])
+            # Change prediction indices to words.
+            pred_words = []
+            for p_ind in pred_indices:
+                if p_ind in self.ix_to_word.keys():
+                    pred_words.append(self.ix_to_word[p_ind])
+            # Calculate BLEU.
+            scores.append(sentence_bleu(target_words, pred_words))
+
+        # Get batch size.
+        batch_size = len(targets)
+
+        # Normalize by batch size.
+        if batch_size > 0:
+            score = sum(scores) / batch_size
+        else:
+            score = 0
+
+        return score
+
+
+    def add_statistics(self, stat_col):
+        """
+        Adds 'bleu' statistics to ``StatisticsCollector``.
+
+        :param stat_col: ``StatisticsCollector``.
+
+        """
+        stat_col.add_statistics(self.key_bleu, '{:6.4f}')
+
+    def collect_statistics(self, stat_col, data_dict):
+        """
+        Collects the BLEU score for a given episode.
+
+        :param stat_col: ``StatisticsCollector``.
+
+        :param data_dict: ``DataDict`` containing the targets and predictions.
+
+        """
+        stat_col[self.key_bleu] = self.calculate_BLEU(data_dict)
+
+    def add_aggregators(self, stat_agg):
+        """
+        Adds aggregators (average and standard deviation of the BLEU score) to ``StatisticsAggregator``.
+
+        :param stat_agg: ``StatisticsAggregator``.
+
+        """
+        stat_agg.add_aggregator(self.key_bleu, '{:7.5f}')  # represents the average BLEU score
+        #stat_agg.add_aggregator(self.key_bleu+'_min', '{:7.5f}')
+        #stat_agg.add_aggregator(self.key_bleu+'_max', '{:7.5f}')
+        stat_agg.add_aggregator(self.key_bleu+'_std', '{:7.5f}')
+
+
+    def aggregate_statistics(self, stat_col, stat_agg):
+        """
+        Aggregates samples from all collected batches.
+
+        :param stat_col: ``StatisticsCollector``
+
+        :param stat_agg: ``StatisticsAggregator``
+
+        """
+        scores = stat_col[self.key_bleu]
+
+        # Check if batch size was collected.
+        if "batch_size" in stat_col.keys():
+            batch_sizes = stat_col['batch_size']
+
+            # Calculate the batch-size-weighted average.
+            scores_avg = np.average(scores, weights=batch_sizes)
+            scores_var = np.average((scores-scores_avg)**2, weights=batch_sizes)
+
+            stat_agg[self.key_bleu] = scores_avg
+            #stat_agg[self.key_bleu+'_min'] = np.min(scores)
+            #stat_agg[self.key_bleu+'_max'] = np.max(scores)
+            stat_agg[self.key_bleu+'_std'] = math.sqrt(scores_var)
+        else:
+            # Else: use a simple mean.
+            stat_agg[self.key_bleu] = np.mean(scores)
+            #stat_agg[self.key_bleu+'_min'] = np.min(scores)
+            #stat_agg[self.key_bleu+'_max'] = np.max(scores)
+            stat_agg[self.key_bleu+'_std'] = np.std(scores)
+            # But inform the user about that!
+            self.logger.warning("Aggregated statistics might contain errors due to the lack of information about the sizes of the aggregated batches.")
diff --git a/ptp/components/text/sentence_indexer.py b/ptp/components/text/sentence_indexer.py
index 18394f9..abaf94a 100644
--- a/ptp/components/text/sentence_indexer.py
+++ b/ptp/components/text/sentence_indexer.py
@@ -61,7 +61,7 @@ def output_data_definitions(self):
         :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
         """
         return {
-            self.key_outputs: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices [BATCH_SIZE x SEQ_LENGTH]"),
+            self.key_outputs: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices of particular words [BATCH_SIZE x SEQ_LENGTH]"),
             }
 
     def __call__(self, data_dict):
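The batch-size-weighted aggregation implemented in aggregate_statistics() above is worth seeing in isolation. Below is a minimal sketch, assuming plain NumPy and made-up per-batch scores (a short final batch is exactly the case where the weighting matters):

import math
import numpy as np

# Hypothetical per-batch BLEU scores and the sizes of the batches they came from.
scores = np.array([0.21, 0.25, 0.19])
batch_sizes = np.array([64, 64, 17])  # last batch is smaller

# Weighted mean: larger batches contribute proportionally more.
scores_avg = np.average(scores, weights=batch_sizes)

# Weighted variance around the weighted mean, then standard deviation.
scores_var = np.average((scores - scores_avg) ** 2, weights=batch_sizes)
scores_std = math.sqrt(scores_var)

print("bleu: {:7.5f} (std {:7.5f})".format(scores_avg, scores_std))

Without the collected batch sizes the component falls back to a simple mean (and warns about it), which slightly biases the aggregate whenever the last batch is smaller than the rest.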
From d3373c0b9600a79e14ab3f157d341c52aac41b48 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Fri, 19 Apr 2019 10:57:27 -0700
Subject: [PATCH 2/5] first version of component calculating BLEU score

---
 ptp/components/publishers/bleu_statistics.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ptp/components/publishers/bleu_statistics.py b/ptp/components/publishers/bleu_statistics.py
index eafb7ea..4c19564 100644
--- a/ptp/components/publishers/bleu_statistics.py
+++ b/ptp/components/publishers/bleu_statistics.py
@@ -146,6 +146,9 @@ def calculate_BLEU(self, data_dict):
                     pred_words.append(self.ix_to_word[p_ind])
             # Calculate BLEU.
             scores.append(sentence_bleu(target_words, pred_words))
+            print("TARGET: {}\n".format(target_words))
+            print("PREDICTION: {}\n".format(pred_words))
+            print("BLEU: {}\n".format(scores[-1]))
 
         # Get batch size.
         batch_size = len(targets)

From a564f74ea158a7239054398fbc2ed4666d26e32e Mon Sep 17 00:00:00 2001
From: tkornut
Date: Fri, 19 Apr 2019 14:13:28 -0700
Subject: [PATCH 3/5] BLEU with n-gram weights, fixed bug with max along item
 axis

---
 .../components/publishers/bleu_statistics.yml |  3 +++
 ptp/components/publishers/bleu_statistics.py  | 17 ++++++++++++-----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/configs/default/components/publishers/bleu_statistics.yml b/configs/default/components/publishers/bleu_statistics.yml
index 2a52fa3..a79a245 100644
--- a/configs/default/components/publishers/bleu_statistics.yml
+++ b/configs/default/components/publishers/bleu_statistics.yml
@@ -13,6 +13,9 @@ use_prediction_distributions: True
 # TODO!
 #use_masking: False
 
+# Weights of n-grams used when calculating the score.
+weights: [0.25, 0.25, 0.25, 0.25]
+
 streams:
   ####################################################################
   # 2. Keymappings associated with INPUT and OUTPUT streams.
diff --git a/ptp/components/publishers/bleu_statistics.py b/ptp/components/publishers/bleu_statistics.py
index 4c19564..b303ea9 100644
--- a/ptp/components/publishers/bleu_statistics.py
+++ b/ptp/components/publishers/bleu_statistics.py
@@ -63,6 +63,9 @@ def __init__(self, name, config):
         # Construct reverse mapping for faster processing.
         self.ix_to_word = dict((v,k) for k,v in word_to_ix.items())
 
+        # Get n-gram weights.
+        self.weights = self.config["weights"]
+
 
         # Get statistics key mappings.
         self.key_bleu = self.statistics_keys["bleu"]
@@ -120,7 +123,7 @@ def calculate_BLEU(self, data_dict):
 
         if self.use_prediction_distributions:
             # Get indices of the max log-probability.
-            preds = data_dict[self.key_predictions].max(1)[1].data.cpu().numpy().tolist()
+            preds = data_dict[self.key_predictions].max(-1)[1].data.cpu().numpy().tolist()
         else:
             preds = data_dict[self.key_predictions].data.cpu().numpy().tolist()
 
@@ -133,6 +136,9 @@ def calculate_BLEU(self, data_dict):
         # Calculate the scores, sample by sample.
         scores = []
 
+        #print("targets ({}): {}\n".format(len(targets), targets[0]))
+        #print("preds ({}): {}\n".format(len(preds), preds[0]))
+
         for target_indices, pred_indices in zip(targets, preds):
             # Change target indices to words.
             target_words = []
@@ -145,11 +151,12 @@ def calculate_BLEU(self, data_dict):
             for p_ind in pred_indices:
                 if p_ind in self.ix_to_word.keys():
                     pred_words.append(self.ix_to_word[p_ind])
             # Calculate BLEU.
-            scores.append(sentence_bleu(target_words, pred_words))
-            print("TARGET: {}\n".format(target_words))
-            print("PREDICTION: {}\n".format(pred_words))
-            print("BLEU: {}\n".format(scores[-1]))
+            scores.append(sentence_bleu([target_words], pred_words, self.weights))
+            #print("TARGET: {}\n".format(target_words))
+            #print("PREDICTION: {}\n".format(pred_words))
+            #print("BLEU: {}\n".format(scores[-1]))
 
+
         # Get batch size.
         batch_size = len(targets)
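Both fixes in the patch above touch easy-to-miss details of the APIs involved, so here is a minimal standalone sketch (the token lists and tensor shapes are hypothetical): nltk's sentence_bleu() expects a list of references, each itself a list of tokens, and for prediction distributions the argmax has to run over the last (class) axis:

import torch
from nltk.translate.bleu_score import sentence_bleu

# Hypothetical tokenized reference and hypothesis.
reference = ["the", "cat", "sat", "on", "the", "mat"]
hypothesis = ["the", "cat", "sat", "on", "a", "mat"]

# Note the extra brackets: a LIST of references, each a list of tokens.
weights = (0.25, 0.25, 0.25, 0.25)  # uniform 1- to 4-gram weights
print("BLEU: {:6.4f}".format(sentence_bleu([reference], hypothesis, weights)))

# For distributions shaped [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES],
# max(-1) reduces over the class axis; max(1) would (incorrectly here)
# reduce over the sequence axis instead.
predictions = torch.rand(2, 6, 10)  # hypothetical batch
indices = predictions.max(-1)[1]    # [BATCH_SIZE x SEQ_LENGTH]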
From f1ff3f7eee85a1bf0c614f2175cbf4f495cc69a6 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Fri, 19 Apr 2019 14:36:49 -0700
Subject: [PATCH 4/5] Fixed multiple inheritance issue with the mixin
 WordMappings class

---
 .../components/text/sentence_indexer.yml      |  3 +++
 .../wikitext_language_modeling_rnn.yml        |  4 +--
 ptp/components/mixins/word_mappings.py        | 25 ++++++++-----------
 ptp/components/models/sentence_embeddings.py  |  2 +-
 ptp/components/text/label_indexer.py          |  8 +++---
 ptp/components/text/sentence_indexer.py       | 14 ++++++++---
 .../text/sentence_one_hot_encoder.py          |  8 +++---
 ptp/components/text/word_decoder.py           |  8 +++---
 8 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/configs/default/components/text/sentence_indexer.yml b/configs/default/components/text/sentence_indexer.yml
index 65d5d03..25c2f5e 100644
--- a/configs/default/components/text/sentence_indexer.yml
+++ b/configs/default/components/text/sentence_indexer.yml
@@ -25,6 +25,9 @@ import_word_mappings_from_globals: False
 # Flag informing whether word mappings will be exported to globals (LOADED)
 export_word_mappings_to_globals: False
 
+# Operation mode. If 'reverse' is True, then it will change indices into words (LOADED)
+reverse: False
+
 streams:
   ####################################################################
   # 2. Keymappings associated with INPUT and OUTPUT streams.
diff --git a/configs/wikitext/wikitext_language_modeling_rnn.yml b/configs/wikitext/wikitext_language_modeling_rnn.yml
index d2806d5..3e87643 100644
--- a/configs/wikitext/wikitext_language_modeling_rnn.yml
+++ b/configs/wikitext/wikitext_language_modeling_rnn.yml
@@ -5,7 +5,7 @@ training:
     data_folder: &data_folder ~/data/language_modeling/wikitext-2
     dataset: &dataset wikitext-2
     subset: train
-    sentence_length: 50
+    sentence_length: 10
     batch_size: 64
 
   # optimizer parameters:
@@ -27,7 +27,7 @@ validation:
     data_folder: *data_folder
     dataset: *dataset
     subset: valid
-    sentence_length: 50
+    sentence_length: 20
     batch_size: 64
 
 # Testing parameters:
diff --git a/ptp/components/mixins/word_mappings.py b/ptp/components/mixins/word_mappings.py
index 53bcf0c..1920574 100644
--- a/ptp/components/mixins/word_mappings.py
+++ b/ptp/components/mixins/word_mappings.py
@@ -17,31 +17,26 @@
 import os
 
 import ptp.components.utils.word_mappings as wm
-from ptp.components.component import Component
 
 
-class WordMappings(Component):
+class WordMappings(object):
     """
     Mixin class that handles the initialization of (word:index) mappings.
+
+    Assumes that it is mixed into a class that is derived from the Component class.
+
+    .. warning::
+        The constructor (__init__) of the Component class has to be called before the constructor of the WordMappings mixin class.
+
     """
-    def __init__(self, name, class_type, config):
+    def __init__(self): #, name, class_type, config):
         """
         Initializes the (word:index) mappings.
-        Loads parameters from configuration,
-
-        :param name: Component name (read from configuration file).
-        :type name: str
-
-        :param class_type: Class type of the component (derrived from this class).
-
-        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
-        :type config: :py:class:`ptp.configuration.ConfigInterface`
+
+        Assumes that the Component constructor was called in advance, which means that the self object already possesses the following attributes:
+        - self.config
+        - self.globals
+        - self.logger
 
         """
-        # Call constructors of parent classes.
-        Component.__init__(self, name, class_type, config)
-
         # Read the actual configuration.
         self.data_folder = os.path.expanduser(self.config['data_folder'])
diff --git a/ptp/components/models/sentence_embeddings.py b/ptp/components/models/sentence_embeddings.py
index 5d44bd4..6004e2a 100644
--- a/ptp/components/models/sentence_embeddings.py
+++ b/ptp/components/models/sentence_embeddings.py
@@ -50,7 +50,7 @@ def __init__(self, name, config):
         """
         # Call base class constructors.
         Model.__init__(self, name, SentenceEmbeddings, config)
-        WordMappings.__init__(self, name, SentenceEmbeddings, config)
+        WordMappings.__init__(self)
 
         # Set key mappings.
         self.key_inputs = self.stream_keys["inputs"]
diff --git a/ptp/components/text/label_indexer.py b/ptp/components/text/label_indexer.py
index 410aa46..c3090cd 100644
--- a/ptp/components/text/label_indexer.py
+++ b/ptp/components/text/label_indexer.py
@@ -16,11 +16,12 @@
 
 import torch
 
+from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
 
 
-class LabelIndexer(WordMappings):
+class LabelIndexer(Component, WordMappings):
     """
     Class responsible for changing of samples consisting of single words/labels into indices (that e.g. can be latter used for loss calculation, PyTorch-style).
     """
@@ -35,8 +36,9 @@ def __init__(self, name, config):
         :type config: :py:class:`ptp.configuration.ConfigInterface`
 
         """
-        # Call constructor(s) of parent class(es).
-        WordMappings.__init__(self, name, LabelIndexer, config)
+        # Call constructor(s) of parent class(es) - in the right order!
+        Component.__init__(self, name, LabelIndexer, config)
+        WordMappings.__init__(self)
 
         # Set key mappings.
         self.key_inputs = self.stream_keys["inputs"]
diff --git a/ptp/components/text/sentence_indexer.py b/ptp/components/text/sentence_indexer.py
index abaf94a..b21e0f4 100644
--- a/ptp/components/text/sentence_indexer.py
+++ b/ptp/components/text/sentence_indexer.py
@@ -16,11 +16,12 @@
 
 import torch
 
+from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
 
 
-class SentenceIndexer(WordMappings):
+class SentenceIndexer(Component, WordMappings):
     """
     Class responsible for encoding of sequences of words into list of indices.
     Those can be letter embedded, encoded with 1-hot encoding or else.
@@ -36,13 +37,18 @@ def __init__(self, name, config):
         :type config: :py:class:`ptp.configuration.ConfigInterface`
 
         """
-        # Call constructor(s) of parent class(es).
-        WordMappings.__init__(self, name, SentenceIndexer, config)
+        # Call constructor(s) of parent class(es) - in the right order!
+        Component.__init__(self, name, SentenceIndexer, config)
+        WordMappings.__init__(self)
 
         # Set key mappings.
         self.key_inputs = self.stream_keys["inputs"]
         self.key_outputs = self.stream_keys["outputs"]
-
+
+        # Read mode from the configuration.
+        self.mode_reverse = self.config['reverse']
+
+
 
     def input_data_definitions(self):
         """
diff --git a/ptp/components/text/sentence_one_hot_encoder.py b/ptp/components/text/sentence_one_hot_encoder.py
index c25100f..b25a3e8 100644
--- a/ptp/components/text/sentence_one_hot_encoder.py
+++ b/ptp/components/text/sentence_one_hot_encoder.py
@@ -16,11 +16,12 @@
 
 import torch
 
+from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
 
 
-class SentenceOneHotEncoder(WordMappings):
+class SentenceOneHotEncoder(Component, WordMappings):
     """
     Class responsible for encoding of samples being sequences of words using 1-hot encoding.
     """
@@ -35,8 +36,9 @@ def __init__(self, name, config):
         :type config: :py:class:`ptp.configuration.ConfigInterface`
 
         """
-        # Call constructor(s) of parent class(es).
-        WordMappings.__init__(self, name, SentenceOneHotEncoder, config)
+        # Call constructor(s) of parent class(es) - in the right order!
+        Component.__init__(self, name, SentenceOneHotEncoder, config)
+        WordMappings.__init__(self)
 
         # Set key mappings.
         self.key_inputs = self.stream_keys["inputs"]
diff --git a/ptp/components/text/word_decoder.py b/ptp/components/text/word_decoder.py
index 0e5a052..e75dd15 100644
--- a/ptp/components/text/word_decoder.py
+++ b/ptp/components/text/word_decoder.py
@@ -16,11 +16,12 @@
 
 import torch
 
+from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
 
 
-class WordDecoder(WordMappings):
+class WordDecoder(Component, WordMappings):
     """
     Class responsible for decoding of samples encoded in the form of vectors ("probability distributions").
     """
@@ -35,8 +36,9 @@ def __init__(self, name, config):
         :type config: :py:class:`ptp.configuration.ConfigInterface`
 
         """
-        # Call constructor(s) of parent class(es).
-        WordMappings.__init__(self, name, WordDecoder, config)
+        # Call constructor(s) of parent class(es) - in the right order!
+        Component.__init__(self, name, WordDecoder, config)
+        WordMappings.__init__(self)
 
         # Construct reverse mapping for faster processing.
         self.ix_to_word = dict((v,k) for k,v in self.word_to_ix.items())

From d01b8ebdda6da417cdf527945124eda16f8b25e0 Mon Sep 17 00:00:00 2001
From: tkornut
Date: Fri, 19 Apr 2019 15:14:13 -0700
Subject: [PATCH 5/5] Deindexing mode added to sentence indexer

---
 .../components/text/sentence_indexer.yml      |   5 +
 .../wikitext_language_modeling_rnn.yml        |  21 ++-
 ptp/components/text/sentence_indexer.py       | 131 ++++++++++++++++--
 3 files changed, 142 insertions(+), 15 deletions(-)

diff --git a/configs/default/components/text/sentence_indexer.yml b/configs/default/components/text/sentence_indexer.yml
index 25c2f5e..0921bc7 100644
--- a/configs/default/components/text/sentence_indexer.yml
+++ b/configs/default/components/text/sentence_indexer.yml
@@ -28,6 +28,11 @@ export_word_mappings_to_globals: False
 # Operation mode. If 'reverse' is True, then it will change indices into words (LOADED)
 reverse: False
 
+# Flag indicating whether inputs are represented as distributions or indices (LOADED)
+# Options: True (expects a distribution for each input item in the sequence)
+#          False (expects indices (argmax))
+use_input_distributions: False
+
 streams:
   ####################################################################
   # 2. Keymappings associated with INPUT and OUTPUT streams.
diff --git a/configs/wikitext/wikitext_language_modeling_rnn.yml b/configs/wikitext/wikitext_language_modeling_rnn.yml
index 3e87643..811dbb5 100644
--- a/configs/wikitext/wikitext_language_modeling_rnn.yml
+++ b/configs/wikitext/wikitext_language_modeling_rnn.yml
@@ -46,7 +46,7 @@ pipeline:
   # Source encoding - model 1.
   source_sentence_embedding:
     type: SentenceEmbeddings
-    priority: 1.1
+    priority: 1
    embeddings_size: 50
     pretrained_embeddings: glove.6B.50d.txt
     data_folder: *data_folder
@@ -61,7 +61,7 @@ pipeline:
   # Target encoding.
   target_indexer:
     type: SentenceIndexer
-    priority: 2.1
+    priority: 2
     data_folder: *data_folder
     import_word_mappings_from_globals: True
     streams:
@@ -87,6 +87,21 @@ pipeline:
     streams:
       targets: indexed_targets
 
+  # Prediction decoding.
+  prediction_decoder:
+    type: SentenceIndexer
+    priority: 10
+    # Reverse mode.
+    reverse: True
+    # Use distributions as inputs.
+    use_input_distributions: True
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: prediction_sentences
+
+
   # Statistics.
   batch_size:
     type: BatchSizeStatistics
@@ -109,6 +124,6 @@ pipeline:
   viewer:
     type: StreamViewer
     priority: 100.3
-    input_streams: sources,targets,indexed_targets,predictions
+    input_streams: sources,targets,indexed_targets,prediction_sentences
 
 #: pipeline
diff --git a/ptp/components/text/sentence_indexer.py b/ptp/components/text/sentence_indexer.py
index b21e0f4..7cb0ece 100644
--- a/ptp/components/text/sentence_indexer.py
+++ b/ptp/components/text/sentence_indexer.py
@@ -25,6 +25,8 @@ class SentenceIndexer(Component, WordMappings):
     """
     Class responsible for encoding of sequences of words into list of indices.
     Those can be letter embedded, encoded with 1-hot encoding or else.
+
+    Additionally, when the 'reverse' mode is on, it works in the opposite direction, i.e. it changes a tensor with indices into lists of words.
     """
     def __init__(self, name, config):
         """
@@ -48,6 +50,13 @@ def __init__(self, name, config):
         # Read mode from the configuration.
         self.mode_reverse = self.config['reverse']
 
+        if self.mode_reverse:
+            # We will need reverse (index:word) mapping.
+            self.ix_to_word = dict((v,k) for k,v in self.word_to_ix.items())
+
+        # Get input distributions/indices flag.
+        self.use_input_distributions = self.config["use_input_distributions"]
+
 
@@ -56,9 +65,19 @@ def __init__(self, name, config):
     def input_data_definitions(self):
         """
         Function returns a dictionary with definitions of input data that are required by the component.
 
         :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
""" - return { - self.key_inputs: DataDefinition([-1, -1, 1], [list, list, str], "Batch of sentences, each represented as a list of words [BATCH_SIZE] x [SEQ_LENGTH] x [string]"), - } + if self.mode_reverse: + if self.use_input_distributions: + return { + self.key_inputs: DataDefinition([-1, -1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor with batch of probability distributions [BATCH_SIZE x SEQ_LENGTH x ITEM_SIZE]"), + } + else: + return { + self.key_inputs: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices of particular words [BATCH_SIZE x SEQ_LENGTH]"), + } + else: + return { + self.key_inputs: DataDefinition([-1, -1, 1], [list, list, str], "Batch of sentences, each represented as a list of words [BATCH_SIZE] x [SEQ_LENGTH] x [string]"), + } def output_data_definitions(self): """ @@ -66,25 +85,50 @@ def output_data_definitions(self): :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ - return { - self.key_outputs: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices of particular words [BATCH_SIZE x SEQ_LENGTH]"), - } + if self.mode_reverse: + return { + self.key_outputs: DataDefinition([-1, -1, 1], [list, list, str], "Batch of sentences, each represented as a list of words [BATCH_SIZE] x [SEQ_LENGTH] x [string]"), + } + else: + return { + self.key_outputs: DataDefinition([-1, -1], [torch.Tensor], "Batch of sentences represented as a single tensor of indices of particular words [BATCH_SIZE x SEQ_LENGTH]"), + } + def __call__(self, data_dict): """ - Encodes "inputs" in the format of list of tokens (for a single sample) - Stores result in "encoded_inputs" field of in data_dict. + Encodes inputs into outputs. + Depending on the mode (set by 'reverse' config param) calls sentences_to_tensor() (when False) or tensor_to_sentences() (when set to True). - :param data_dict: :py:class:`ptp.utils.DataDict` object containing (among others): + :param data_dict: :py:class:`ptp.datatypes.DataDict` object. + """ + if self.mode_reverse: + if self.use_input_distributions: + # Produce list of words. + self.tensor_distributions_to_sentences(data_dict) + else: + # Produce list of words. + self.tensor_indices_to_sentences(data_dict) + else: + # Produce indices. + self.sentences_to_tensor(data_dict) + + + def sentences_to_tensor(self, data_dict): + """ + Encodes "inputs" in the format of batch of list of words into a single tensor with corresponding indices. - - "inputs": expected input field containing list of words [BATCH_SIZE] x [SEQ_SIZE] x [string] + :param data_dict: :py:class:`ptp.datatypes.DataDict` object containing (among others): - - "encoded_targets": added output field containing list of indices [BATCH_SIZE x SEQ_SIZE] + - "inputs": expected input field containing list of lists of words [BATCH_SIZE] x [SEQ_SIZE] x [string] + + - "outputs": added output field containing tensor with indices [BATCH_SIZE x SEQ_SIZE] """ # Get inputs to be encoded. inputs = data_dict[self.key_inputs] + outputs_list = [] - # Process samples 1 by one. + # Process sentences 1 by 1. for sample in inputs: assert isinstance(sample, (list,)), 'This encoder requires input sample to contain a list of words' # Process list. @@ -102,3 +146,66 @@ def __call__(self, data_dict): output = self.app_state.LongTensor(outputs_list) # Create the returned dict. 
         data_dict.extend({self.key_outputs: output})
+
+    def tensor_indices_to_sentences(self, data_dict):
+        """
+        Decodes "inputs" in the format of a tensor with indices into a batch of lists of words.
+
+        :param data_dict: :py:class:`ptp.datatypes.DataDict` object containing (among others):
+
+        - "inputs": expected input field containing tensor with indices [BATCH_SIZE x SEQ_SIZE]
+
+        - "outputs": added output field containing list of lists of words [BATCH_SIZE] x [SEQ_SIZE] x [string]
+
+        """
+        # Get inputs to be changed to words.
+        inputs = data_dict[self.key_inputs].data.cpu().numpy().tolist()
+
+        outputs_list = []
+        # Process samples 1 by 1.
+        for sample in inputs:
+            # Process list.
+            output_sample = []
+            # "Decode" sample (list of indices).
+            for token in sample:
+                # Get word.
+                output_word = self.ix_to_word[token]
+                # Add word to outputs.
+                output_sample.append( output_word )
+            # Add sentence to batch.
+            outputs_list.append(output_sample)
+
+        # Create the returned dict.
+        data_dict.extend({self.key_outputs: outputs_list})
+
+    def tensor_distributions_to_sentences(self, data_dict):
+        """
+        Decodes "inputs" in the format of a tensor with probability distributions into a batch of lists of words.
+
+        :param data_dict: :py:class:`ptp.datatypes.DataDict` object containing (among others):
+
+        - "inputs": expected input field containing tensor with probability distributions [BATCH_SIZE x SEQ_SIZE x ITEM_SIZE]
+
+        - "outputs": added output field containing list of lists of words [BATCH_SIZE] x [SEQ_SIZE] x [string]
+
+        """
+        # Get inputs to be changed to words.
+        inputs = data_dict[self.key_inputs].max(2)[1].data.cpu().numpy().tolist()
+
+        outputs_list = []
+        # Process samples 1 by 1.
+        for sample in inputs:
+            # Process list.
+            output_sample = []
+            # "Decode" sample (list of indices).
+            for token in sample:
+                # Get word.
+                output_word = self.ix_to_word[token]
+                # Add word to outputs.
+                output_sample.append( output_word )
+            # Add sentence to batch.
+            outputs_list.append(output_sample)
+
+        # Create the returned dict.
+        data_dict.extend({self.key_outputs: outputs_list})
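To close, a short self-contained sketch of the de-indexing step that the final patch adds (the vocabulary below is made up; the real component imports its word mappings from globals). One detail worth flagging for a follow-up: ix_to_word[token] raises KeyError for indices absent from the mapping, so a guarded lookup similar to the one already used in calculate_BLEU() could make the decoder more robust:

import torch

# Hypothetical (word: index) vocabulary.
word_to_ix = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3}
ix_to_word = dict((v, k) for k, v in word_to_ix.items())

# Batch of two index sequences [BATCH_SIZE x SEQ_LENGTH].
indices = torch.tensor([[1, 2, 3], [2, 3, 0]])

sentences = []
for sample in indices.data.cpu().numpy().tolist():
    # Guarded lookup: map unknown indices to a marker instead of raising.
    sentences.append([ix_to_word.get(ix, "<unk>") for ix in sample])

print(sentences)  # [['the', 'cat', 'sat'], ['cat', 'sat', '<pad>']]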