max length padding/truncating #55

Closed
wants to merge 1 commit into from

mammoth/inputters/dataloader.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=
             n_buckets = 1
 
             def bucket_fn(_):
-                return 0
+                return 0, 0
 
             def numel_fn(_):
                 return 1
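
Note on the dataloader change: `bucket_fn` produces the pooling key for an example; under `batch_type == 'sents'` there is a single bucket, so a constant key is returned for every example. Making that constant a 2-tuple presumably keeps the key the same shape as the length-based keys used by token batching. A minimal, hypothetical sketch (names not taken from mammoth) of pooling by such a key:

# Hypothetical illustration only: group examples by the key returned by a
# bucket_fn-style callable; this is not mammoth code.
from collections import defaultdict

def pool_by_bucket(examples, bucket_fn):
    buckets = defaultdict(list)
    for ex in examples:
        buckets[bucket_fn(ex)].append(ex)
    return buckets

# With the patched bucket_fn every example lands in the single bucket (0, 0):
buckets = pool_by_bucket(['a', 'b', 'c'], lambda _: (0, 0))
assert list(buckets.keys()) == [(0, 0)]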

mammoth/inputters/dataset.py (30 changes: 26 additions & 4 deletions)
@@ -89,6 +89,7 @@ def __init__(
         offset=None,
         is_train=False,
         task=None,
+        max_length=None,
     ):
         self.src_file = src_file
         self.tgt_file = tgt_file
@@ -102,6 +103,7 @@ def __init__(
         self.offset = offset
         self.is_train = is_train
         self.corpus_id = task.corpus_id
+        self.max_length = max_length
 
     # FIXME: most likely redundant with mammoth.transforms.tokenize
     def _tokenize(self, string, side='src'):
@@ -153,19 +155,34 @@ def _cast(example_dict):
         examples = map(_cast, examples)
         yield from examples
 
+    def _pad_sequence(self, tensors: list, padding_value: int = 0):
+        padded = None
+        if self.max_length is not None:
+            padded = torch.full((self.max_length, len(tensors)), padding_value, device='cpu')
+            for idx, tensor in enumerate(tensors):
+                if tensor.numel() > self.max_length:
+                    tensor = tensor[:self.max_length]
+                padded[:tensor.numel(), idx] = tensor
+        else:
+            padded = pad_sequence(tensors, padding_value=padding_value)
+        return padded.unsqueeze(-1)
+
     # FIXME: some RNN archs require sorting src's by length
     def collate_fn(self, examples):
         has_tgt = 'tgt' in examples[0].keys()
         src_padidx = self.vocabs['src'][DefaultTokens.PAD]
         tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
-        src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
-        src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
+        if self.max_length is None:
+            src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
+        else:
+            src_lengths = torch.tensor([min(ex['src'].numel(), self.max_length) for ex in examples], device='cpu')
+        src = (self._pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx), src_lengths)
         if has_tgt:
-            tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+            tgt = self._pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx)
             if 'labels' not in examples[0].keys():
                 labels = tgt
             else:
-                labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+                labels = self._pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx)
         else:
             tgt = None
             labels = None
@@ -190,6 +207,10 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     )
     transforms_to_apply = [transforms_cls[trf_name] for trf_name in transforms_to_apply]
 
+    max_length = None
+    if opts.pad_to_max_length:
+        assert opts.max_length is not None and opts.max_length > 0, 'Please provide a --max_length'
+        max_length = opts.max_length
     # build Dataset proper
     dataset = ParallelCorpus(
         corpus_opts["path_src"] if is_train else corpus_opts["path_valid_src"],
@@ -201,6 +222,7 @@
         offset=corpus_opts.get('offset', None),
         is_train=is_train,
         task=task,
+        max_length=max_length,
     )
     return dataset

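To illustrate the new collate behaviour (a standalone sketch mirroring `_pad_sequence` above, not an import from mammoth): when `max_length` is set, every batch is padded, and if necessary truncated, to exactly `max_length` steps, so the result always has shape `(max_length, batch, 1)`; when it is `None`, the previous dynamic `pad_sequence` behaviour is kept.

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_or_truncate(tensors, padding_value=0, max_length=None):
    # Standalone copy of the logic in ParallelCorpus._pad_sequence above.
    if max_length is not None:
        padded = torch.full((max_length, len(tensors)), padding_value)
        for idx, tensor in enumerate(tensors):
            if tensor.numel() > max_length:
                tensor = tensor[:max_length]  # truncate overlong sequences
            padded[:tensor.numel(), idx] = tensor
    else:
        padded = pad_sequence(tensors, padding_value=padding_value)
    return padded.unsqueeze(-1)

batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9, 10, 11, 12, 13])]
print(pad_or_truncate(batch, padding_value=1, max_length=4).shape)  # torch.Size([4, 2, 1]): first padded, second truncated
print(pad_or_truncate(batch, padding_value=1).shape)                # torch.Size([6, 2, 1]): longest sequence sets the length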

mammoth/modules/layer_stack_encoder.py (7 changes: 4 additions & 3 deletions)
@@ -8,13 +8,14 @@
 
 
 class LayerStackEncoder(EncoderBase):
-    def __init__(self, embeddings, encoders):
+    def __init__(self, embeddings, encoders, max_length):
         super().__init__()
 
         self.embeddings = embeddings
         self.encoders: nn.ModuleList[nn.ModuleDict] = encoders
         self._adapter_to_stack: Dict[str, int] = dict()
         self._active: List[str] = []
+        self._max_length = max_length
 
     @classmethod
     def from_opts(cls, opts, embeddings, task_queue_manager):
@@ -48,7 +49,7 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
                     is_normformer=opts.normformer,
                 )
             encoders.append(stacks)
-        return cls(embeddings, encoders)
+        return cls(embeddings, encoders, opts.max_length)
 
     @classmethod
     def from_trans_opt(cls, opts, embeddings, task):
@@ -91,7 +92,7 @@ def forward(self, src, lengths=None, **kwargs):
         # wrapper embeds src and creates mask
         emb = self.embeddings(src)
         emb = emb.transpose(0, 1).contiguous()
-        mask = ~sequence_mask(lengths).unsqueeze(1)
+        mask = ~sequence_mask(lengths, max_len=self._max_length).unsqueeze(1)
 
         output = emb
         for active_id, stacks in zip(self._active, self.encoders):
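
Why the mask needs `max_len` here (a sketch assuming `sequence_mask` follows the usual OpenNMT-style signature `sequence_mask(lengths, max_len=None)`; the real mammoth utility may differ in details): once batches are padded to a fixed `max_length`, the time dimension of the embedded input can exceed `lengths.max()`, so the mask has to cover `max_length` positions rather than only the longest length in the batch.

import torch

def sequence_mask(lengths, max_len=None):
    # Minimal re-implementation for illustration only.
    max_len = max_len or int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_len), True = real token

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths).shape)             # torch.Size([2, 5]): too narrow for input padded to 8
print(sequence_mask(lengths, max_len=8).shape)  # torch.Size([2, 8]): matches the fixed padding length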

mammoth/opts.py (2 changes: 2 additions & 0 deletions)
@@ -794,6 +794,8 @@ def _add_train_general_opts(parser):
         choices=["sents", "tokens"],
         help="Batch grouping for batch_size. Standard is sents. Tokens will do dynamic batching",
     )
+    group.add('--pad_to_max_length', '-pad_to_max_length', action='store_true')
+    group.add('--max_length', '-max_length', type=int, default=100, help='Maximum prediction length.')
     group.add(
         '--normalization',
         '-normalization',
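
How the two new flags interact (a standalone sketch using plain argparse instead of mammoth's opts machinery): `--max_length` on its own only sets the value (default 100); padding/truncating is enabled only when `--pad_to_max_length` is also passed, mirroring the guard in `get_corpus` above, so a training run would add something like `--pad_to_max_length --max_length 128`.

import argparse

# Hypothetical mirror of the two options added above, outside mammoth.
parser = argparse.ArgumentParser()
parser.add_argument('--pad_to_max_length', action='store_true')
parser.add_argument('--max_length', type=int, default=100, help='Maximum prediction length.')

opts = parser.parse_args(['--pad_to_max_length', '--max_length', '128'])

# Same guard as in get_corpus: only pad/truncate when the flag is set.
max_length = None
if opts.pad_to_max_length:
    assert opts.max_length is not None and opts.max_length > 0, 'Please provide a --max_length'
    max_length = opts.max_length
print(max_length)  # 128; without --pad_to_max_length it stays None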