This repository was archived by the owner on Jul 18, 2024. It is now read-only.

Commit 7c97fe2

Merge pull request #18 from aasseman/feat/fixed-sentence-padding
Add fixed padding option to sentence_embeddings, sentence_indexer
2 parents: bf85186 + 633553d

File tree: 5 files changed, +63 −2 lines

configs/default/components/models/sentence_embeddings.yml

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,12 @@ import_word_mappings_from_globals: False
 # Flag informing whether word mappings will be exported to globals (LOADED)
 export_word_mappings_to_globals: False
 
+# Fixed padding length
+# -1  -> For each batch, automatically pad to the length of the longest sequence in the batch
+#        (variable from batch to batch)
+# > 0 -> Pad each sequence to the chosen length (fixed for all batches)
+fixed_padding: -1
+
 # File containing pretrained embeddings (LOADED)
 # Empty means that no embeddings will be loaded.
 pretrained_embeddings_file: ''

configs/default/components/text/sentence_indexer.yml

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,12 @@ import_word_mappings_from_globals: False
 # Flag informing whether word mappings will be exported to globals (LOADED)
 export_word_mappings_to_globals: False
 
+# Fixed padding length
+# -1  -> For each batch, automatically pad to the length of the longest sequence in the batch
+#        (variable from batch to batch)
+# > 0 -> Pad each sequence to the chosen length (fixed for all batches)
+fixed_padding: -1
+
 # Operation mode. If 'reverse' is True, then it will change indices into words (LOADED)
 reverse: False
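
The same fixed_padding option is added to both configs above. To make the two modes concrete, here is a minimal Python sketch; the toy token lists are invented, and pad_trunc_list is the helper this PR adds to ptp/components/utils/word_mappings.py (see the last file below):

from ptp.components.utils.word_mappings import pad_trunc_list

batch = [[4, 8, 15], [16, 23], [42]]

# fixed_padding: -1 -> pad to the longest sequence of *this* batch (length 3 here),
# so the padded length may differ from batch to batch.
longest = max(len(sample) for sample in batch)
dynamic = [sample + [0] * (longest - len(sample)) for sample in batch]
assert dynamic == [[4, 8, 15], [16, 23, 0], [42, 0, 0]]

# fixed_padding: 5 -> every sample is padded (or truncated) in-place to length 5,
# so the padded length is the same for all batches.
for sample in batch:
    pad_trunc_list(sample, 5)
assert batch == [[4, 8, 15, 0, 0], [16, 23, 0, 0, 0], [42, 0, 0, 0, 0]]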

ptp/components/models/sentence_embeddings.py

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,7 @@
 from ptp.data_types.data_definition import DataDefinition
 
 import ptp.components.utils.embeddings as emb
+from ptp.components.utils.word_mappings import pad_trunc_list
 
 
 class SentenceEmbeddings(Model, WordMappings):
@@ -56,6 +57,9 @@ def __init__(self, name, config):
         self.key_inputs = self.stream_keys["inputs"]
         self.key_outputs = self.stream_keys["outputs"]
 
+        # Force padding to a fixed length
+        self.fixed_padding = self.config['fixed_padding']
+
         # Retrieve embeddings size from configuration and export it to globals.
         self.embeddings_size = self.config['embeddings_size']
         self.globals["embeddings_size"] = self.embeddings_size
@@ -120,6 +124,11 @@ def forward(self, data_dict):
                 # Add index to outputs.
                 output_sample.append( output_index )
 
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             #indices_list.append(self.app_state.FloatTensor(output_sample))
             indices_list.append(self.app_state.LongTensor(output_sample))
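
Why this matters downstream: with dynamic padding the time dimension of the batch tensor changes from batch to batch, while a fixed padding gives every batch the same shape. A shape-only sketch (the SentenceEmbeddings component itself is not instantiated here; the index values are invented):

import torch

# Two batches with different longest sequences.
batch_a = [torch.LongTensor([1, 2, 3]), torch.LongTensor([4])]
batch_b = [torch.LongTensor([5, 6, 7, 8, 9]), torch.LongTensor([10])]

# Dynamic padding (fixed_padding: -1): each batch is padded to its own longest sequence.
a = torch.nn.utils.rnn.pad_sequence(batch_a, batch_first=True)
b = torch.nn.utils.rnn.pad_sequence(batch_b, batch_first=True)
assert a.shape == (2, 3) and b.shape == (2, 5)

# With fixed_padding: 7, pad_trunc_list first brings every sample to length 7,
# so both batches would come out with shape (2, 7) instead.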

ptp/components/text/sentence_indexer.py

Lines changed: 24 additions & 2 deletions
@@ -19,6 +19,7 @@
 from ptp.components.component import Component
 from ptp.components.mixins.word_mappings import WordMappings
 from ptp.data_types.data_definition import DataDefinition
+from ptp.components.utils.word_mappings import pad_trunc_list
 
 
 class SentenceIndexer(Component, WordMappings):
@@ -50,6 +51,9 @@ def __init__(self, name, config):
         # Read mode from the configuration.
         self.mode_reverse = self.config['reverse']
 
+        # Force padding to a fixed length
+        self.fixed_padding = self.config['fixed_padding']
+
         if self.mode_reverse:
             # We will need reverse (index:word) mapping.
             self.ix_to_word = dict((v,k) for k,v in self.word_to_ix.items())
@@ -140,10 +144,16 @@ def sentences_to_tensor(self, data_dict):
                 # Add index to outputs.
                 output_sample.append( output_index )
 
-            outputs_list.append(output_sample)
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
+            outputs_list.append(self.app_state.LongTensor(output_sample))
 
         # Transform the list of lists to tensor.
-        output = self.app_state.LongTensor(outputs_list)
+        # output = self.app_state.LongTensor(outputs_list)
+        output = torch.nn.utils.rnn.pad_sequence(outputs_list, batch_first=True)
         # Create the returned dict.
         data_dict.extend({self.key_outputs: output})
 
@@ -172,6 +182,12 @@ def tensor_indices_to_sentences(self, data_dict):
                 output_word = self.ix_to_word[token]
                 # Add index to outputs.
                 output_sample.append( output_word )
+
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             # Add sentence to batch.
             outputs_list.append(output_sample)
 
@@ -204,6 +220,12 @@ def tensor_distributions_to_sentences(self, data_dict):
                 output_word = self.ix_to_word[token]
                 # Add index to outputs.
                 output_sample.append( output_word )
+
+            # Apply fixed padding to all sequences if requested
+            # Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
+            if self.fixed_padding > 0:
+                pad_trunc_list(output_sample, self.fixed_padding)
+
             # Add sentence to batch.
             outputs_list.append(output_sample)
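
Two details of the new sentences_to_tensor path are worth spelling out: torch.nn.utils.rnn.pad_sequence only accepts a list of tensors, which is why each output_sample is now wrapped in a LongTensor inside the loop, and it pads with 0 by default, matching the default fill value of pad_trunc_list. A minimal sketch with invented index values:

import torch

# Per-sample index lists, already converted to tensors inside the loop.
outputs_list = [torch.LongTensor([3, 1, 4, 1]), torch.LongTensor([5, 9])]

# pad_sequence pads the shorter samples with 0 and stacks them into one tensor.
output = torch.nn.utils.rnn.pad_sequence(outputs_list, batch_first=True)
print(output)
# tensor([[3, 1, 4, 1],
#         [5, 9, 0, 0]])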

ptp/components/utils/word_mappings.py

Lines changed: 18 additions & 0 deletions
@@ -135,3 +135,21 @@ def save_word_mappings_to_csv_file(logger, folder, filename, word_to_ix, fieldna
         writer.writerow({fieldnames[0]:k, fieldnames[1]: v})
 
     logger.info("Saved mappings of size {} to file '{}'".format(len(word_to_ix), file_path))
+
+def pad_trunc_list(l: list, length: int, value = 0):
+    """
+    Will apply padding / clipping to the list to meet the requested length.
+    Works on the list in-place.
+
+    :param l: List to manipulate
+
+    :param length: Target length
+
+    :param value: Value to fill when padding. Default is int(0).
+
+    :return: None
+    """
+    if len(l) < length:
+        l.extend([value]*(length-len(l)))
+    elif len(l) > length:
+        del l[length:]
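
Usage sketch for the new helper (in-place mutation, returns None; the sample lists are invented):

from ptp.components.utils.word_mappings import pad_trunc_list

short = [1, 2]
pad_trunc_list(short, 4)                  # shorter than target: pad with the default 0
assert short == [1, 2, 0, 0]

too_long = [1, 2, 3, 4, 5]
pad_trunc_list(too_long, 4)               # longer than target: clip the tail
assert too_long == [1, 2, 3, 4]

words = ["hello", "world"]
pad_trunc_list(words, 4, value='<PAD>')   # custom fill value, e.g. for word lists
assert words == ["hello", "world", '<PAD>', '<PAD>']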
