diff --git a/configs/default/components/models/sentence_embeddings.yml b/configs/default/components/models/sentence_embeddings.yml
index ab3a6da..7ca8987 100644
--- a/configs/default/components/models/sentence_embeddings.yml
+++ b/configs/default/components/models/sentence_embeddings.yml
@@ -25,6 +25,12 @@ import_word_mappings_from_globals: False
 # Flag informing whether word mappings will be exported to globals (LOADED)
 export_word_mappings_to_globals: False
 
+# Fixed padding length
+# -1 -> For each batch, automatically pad to the length of the longest sequence of the batch
+#       (variable from batch to batch)
+# > 0 -> Pad each sequence to the chosen length (fixed for all batches)
+fixed_padding: -1
+
 # File containing pretrained embeddings (LOADED)
 # Empty means that no embeddings will be loaded.
 pretrained_embeddings_file: ''
diff --git a/configs/default/components/text/sentence_indexer.yml b/configs/default/components/text/sentence_indexer.yml
index 0921bc7..8ec714f 100644
--- a/configs/default/components/text/sentence_indexer.yml
+++ b/configs/default/components/text/sentence_indexer.yml
@@ -25,6 +25,12 @@ import_word_mappings_from_globals: False
 # Flag informing whether word mappings will be exported to globals (LOADED)
 export_word_mappings_to_globals: False
 
+# Fixed padding length
+# -1 -> For each batch, automatically pad to the length of the longest sequence of the batch
+#       (variable from batch to batch)
+# > 0 -> Pad each sequence to the chosen length (fixed for all batches)
+fixed_padding: -1
+
 # Operation mode. If 'reverse' is True, then it will change indices into words (LOADED)
 reverse: False
 
diff --git a/ptp/components/models/sentence_embeddings.py b/ptp/components/models/sentence_embeddings.py
index 6004e2a..78c3065 100644
--- a/ptp/components/models/sentence_embeddings.py
+++ b/ptp/components/models/sentence_embeddings.py
@@ -25,6 +25,7 @@
 from ptp.data_types.data_definition import DataDefinition
 
 import ptp.components.utils.embeddings as emb
 
+from ptp.components.utils.word_mappings import pad_trunc_list
 
 class SentenceEmbeddings(Model, WordMappings):
@@ -56,6 +57,9 @@ def __init__(self, name, config):
         self.key_inputs = self.stream_keys["inputs"]
         self.key_outputs = self.stream_keys["outputs"]
 
+        # Force padding to a fixed length
+        self.fixed_padding = self.config['fixed_padding']
+
         # Retrieve embeddings size from configuration and export it to globals.
         self.embeddings_size = self.config['embeddings_size']
         self.globals["embeddings_size"] = self.embeddings_size
@@ -120,6 +124,11 @@ def forward(self, data_dict):
                 # Add index to outputs.
                 output_sample.append( output_index )
 
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             #indices_list.append(self.app_state.FloatTensor(output_sample))
             indices_list.append(self.app_state.LongTensor(output_sample))
 
diff --git a/ptp/components/text/sentence_indexer.py b/ptp/components/text/sentence_indexer.py
index 7cb0ece..4450e83 100644
--- a/ptp/components/text/sentence_indexer.py
+++ b/ptp/components/text/sentence_indexer.py
@@ -19,6 +19,7 @@
 from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
 
+from ptp.components.utils.word_mappings import pad_trunc_list
 
 class SentenceIndexer(Component, WordMappings):
@@ -50,6 +51,9 @@ def __init__(self, name, config):
         # Read mode from the configuration.
         self.mode_reverse = self.config['reverse']
 
+        # Force padding to a fixed length
+        self.fixed_padding = self.config['fixed_padding']
+
         if self.mode_reverse:
             # We will need reverse (index:word) mapping.
             self.ix_to_word = dict((v,k) for k,v in self.word_to_ix.items())
@@ -140,10 +144,16 @@ def sentences_to_tensor(self, data_dict):
                 # Add index to outputs.
                 output_sample.append( output_index )
 
-            outputs_list.append(output_sample)
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
+            outputs_list.append(self.app_state.LongTensor(output_sample))
 
         # Transform the list of lists to tensor.
-        output = self.app_state.LongTensor(outputs_list)
+        # output = self.app_state.LongTensor(outputs_list)
+        output = torch.nn.utils.rnn.pad_sequence(outputs_list, batch_first=True)
 
         # Create the returned dict.
         data_dict.extend({self.key_outputs: output})
@@ -172,6 +182,12 @@ def tensor_indices_to_sentences(self, data_dict):
                 output_word = self.ix_to_word[token]
                 # Add index to outputs.
                 output_sample.append( output_word )
+
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             # Add sentence to batch.
             outputs_list.append(output_sample)
 
@@ -204,6 +220,12 @@ def tensor_distributions_to_sentences(self, data_dict):
                 output_word = self.ix_to_word[token]
                 # Add index to outputs.
                 output_sample.append( output_word )
+
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             # Add sentence to batch.
             outputs_list.append(output_sample)
 
diff --git a/ptp/components/utils/word_mappings.py b/ptp/components/utils/word_mappings.py
index d43abf6..5b94350 100644
--- a/ptp/components/utils/word_mappings.py
+++ b/ptp/components/utils/word_mappings.py
@@ -135,3 +135,21 @@ def save_word_mappings_to_csv_file(logger, folder, filename, word_to_ix, fieldna
             writer.writerow({fieldnames[0]:k, fieldnames[1]: v})
 
     logger.info("Saved mappings of size {} to file '{}'".format(len(word_to_ix), file_path))
+
+def pad_trunc_list(l: list, length: int, value = 0):
+    """
+    Will apply padding / clipping to list to meet requested length.
+    Works on the list in-place.
+
+    :param l: List to manipulate
+
+    :param length: Target length
+
+    :param value: Value to fill when padding. Default is int(0).
+
+    :return: None
+    """
+    if len(l) < length:
+        l.extend([value]*(length-len(l)))
+    elif len(l) > length:
+        del l[length:]
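
A quick illustration (not part of the patch) of the in-place behavior of the new pad_trunc_list() helper added to ptp/components/utils/word_mappings.py; the sample index lists below are made up:

# Illustration only: pad_trunc_list() pads or truncates a list in-place.
from ptp.components.utils.word_mappings import pad_trunc_list

indices = [4, 8, 15]
pad_trunc_list(indices, 5)   # shorter than requested -> padded with the default value 0
assert indices == [4, 8, 15, 0, 0]

indices = [1, 2, 3, 4, 5, 6]
pad_trunc_list(indices, 4)   # longer than requested -> clipped to the requested length
assert indices == [1, 2, 3, 4]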