From e331a390158842fbfc33281c7801be3230e6b48c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 20 May 2024 16:01:56 +0300
Subject: [PATCH] Reimplement fixed size batching

Fixed size batching entails using
- "batch_type: sents" to fix the batch dimension to batch_size, and
- "pad_to_max_length: true" together with "max_length" to fix the
  sequence length dimension.
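
For example, a configuration along these lines enables fixed size
batching (the batch_size and max_length values below are illustrative
only; batch_type and batch_size are pre-existing options, while
pad_to_max_length and max_length are added by this patch):

    batch_type: sents
    batch_size: 64
    pad_to_max_length: true
    max_length: 200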

Closes #67
---
 mammoth/inputters/dataloader.py            | 26 +++++++++++++++++-
 mammoth/inputters/dataset.py               | 30 ++++++++++++++++++---
 mammoth/modules/layer_stack_encoder.py     | 11 ++++----
 mammoth/opts.py                            |  2 ++
 mammoth/tests/test_look_ahead_bucketing.py | 31 ++++++++++++++++++++++
 5 files changed, 90 insertions(+), 10 deletions(-)

diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py
index d9e6d38c..39358553 100644
--- a/mammoth/inputters/dataloader.py
+++ b/mammoth/inputters/dataloader.py
@@ -22,7 +22,10 @@ def build_dataloader(
         loader = InferenceBatcher(dataset, batch_size)
     else:
         if batch_type == 'sents':
-            raise NotImplementedError()
+            loader = SentenceMinibatcher(
+                dataset=dataset,
+                batch_size=batch_size,
+            )
         elif batch_type == 'tokens':
             loader = SimpleLookAheadBucketing(
                 dataset=dataset,
@@ -153,6 +156,27 @@ def __iter__(self):
             )
 
 
+class SentenceMinibatcher():
+    """
+    Arguments:
+        dataset: mammoth.inputters.ParallelCorpus
+        batch_size:
+            The size of each minibatch, in sentences.
+    """
+    def __init__(self, dataset, batch_size):
+        self.batch_size = batch_size
+        self.collate_fn = dataset.collate_fn
+        self._sie = ScoredInfiniteExamples(dataset, score_fn=lambda x: 1)
+
+    def __iter__(self):
+        while True:
+            minibatch = []
+            for _ in range(self.batch_size):
+                _, example = self._sie.next()
+                minibatch.append(example)
+            yield self.collate_fn(minibatch)
+
+
 class DynamicDatasetIter(object):
     """Yield batch from (multiple) plain text corpus.
diff --git a/mammoth/inputters/dataset.py b/mammoth/inputters/dataset.py
index 5f2779d5..df51d9b0 100644
--- a/mammoth/inputters/dataset.py
+++ b/mammoth/inputters/dataset.py
@@ -89,6 +89,7 @@ def __init__(
         offset=None,
         is_train=False,
         task=None,
+        max_length=None,
     ):
         self.src_file = src_file
         self.tgt_file = tgt_file
@@ -102,6 +103,7 @@ def __init__(
         self.offset = offset
         self.is_train = is_train
         self.corpus_id = task.corpus_id
+        self.max_length = max_length
 
     # FIXME: most likely redundant with mammoth.transforms.tokenize
     def _tokenize(self, string, side='src'):
@@ -121,6 +123,18 @@ def _numericalize(self, tokens, side='src'):
             ], device='cpu')
         return indices
 
+    def _pad_sequence(self, tensors: list, padding_value: int = 0):
+        padded = None
+        if self.max_length is not None:
+            padded = torch.full((self.max_length, len(tensors)), padding_value, device='cpu')
+            for idx, tensor in enumerate(tensors):
+                if tensor.numel() > self.max_length:
+                    tensor = tensor[:self.max_length]
+                padded[:tensor.numel(), idx] = tensor
+        else:
+            padded = pad_sequence(tensors, padding_value=padding_value)
+        return padded.unsqueeze(-1)
+
     def to(self, device):
         self.device = device
         return self
@@ -158,14 +172,17 @@ def collate_fn(self, examples):
         has_tgt = 'tgt' in examples[0].keys()
         src_padidx = self.vocabs['src'][DefaultTokens.PAD]
         tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
-        src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
-        src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
+        if self.max_length is None:
+            src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
+        else:
+            src_lengths = torch.tensor([min(ex['src'].numel(), self.max_length) for ex in examples], device='cpu')
+        src = (self._pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx), src_lengths)
         if has_tgt:
-            tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+            tgt = self._pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx)
             if 'labels' not in examples[0].keys():
                 labels = tgt
             else:
-                labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+                labels = self._pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx)
         else:
             tgt = None
             labels = None
@@ -190,6 +207,10 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     )
     transforms_to_apply = [transforms_cls[trf_name] for trf_name in transforms_to_apply]
 
+    max_length = None
+    if opts.pad_to_max_length:
+        assert opts.max_length is not None and opts.max_length > 0, 'Please provide a --max_length'
+        max_length = opts.max_length
     # build Dataset proper
     dataset = ParallelCorpus(
         corpus_opts["path_src"] if is_train else corpus_opts["path_valid_src"],
@@ -201,6 +222,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
         offset=corpus_opts.get('offset', None),
         is_train=is_train,
         task=task,
+        max_length=max_length,
     )
     return dataset
 
diff --git a/mammoth/modules/layer_stack_encoder.py b/mammoth/modules/layer_stack_encoder.py
index aec0754e..d70a90a8 100644
--- a/mammoth/modules/layer_stack_encoder.py
+++ b/mammoth/modules/layer_stack_encoder.py
@@ -8,13 +8,14 @@
 
 
 class LayerStackEncoder(EncoderBase):
-    def __init__(self, embeddings, encoders):
+    def __init__(self, embeddings, encoders, max_length=None):
         super().__init__()
 
         self.embeddings = embeddings
         self.encoders: nn.ModuleList[nn.ModuleDict] = encoders
         self._adapter_to_stack: Dict[str, int] = dict()
         self._active: List[str] = []
+        self._max_length = max_length
 
     @classmethod
     def from_opts(cls, opts, embeddings, task_queue_manager):
@@ -48,7 +49,7 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
                 is_normformer=opts.normformer,
             )
             encoders.append(stacks)
-        return cls(embeddings, encoders)
+        return cls(embeddings, encoders, opts.max_length)
 
     @classmethod
     def from_trans_opt(cls, opts, embeddings, task):
@@ -79,7 +80,7 @@ def from_trans_opt(cls, opts, embeddings, task):
                 is_normformer=opts.normformer,
             )
             encoders.append(stacks)
-        return cls(embeddings, encoders)
+        return cls(embeddings, encoders, max_length=None)
 
     def update_dropout(self, dropout, attention_dropout):
         self.embeddings.update_dropout(dropout)
@@ -91,7 +92,7 @@ def forward(self, src, lengths=None, **kwargs):
         # wrapper embeds src and creates mask
         emb = self.embeddings(src)
         emb = emb.transpose(0, 1).contiguous()
-        mask = ~sequence_mask(lengths).unsqueeze(1)
+        mask = ~sequence_mask(lengths, max_len=self._max_length).unsqueeze(1)
 
         output = emb
         for active_id, stacks in zip(self._active, self.encoders):
@@ -144,7 +145,7 @@ def add_adapter(
         self._adapter_to_stack[name] = layer_stack_index
         if layer_stack_index >= len(self.encoders):
             raise ValueError(
-                f'No layer stack with index {layer_stack_index}. There are {len(len(self.encoders))} layer stacks'
+                f'No layer stack with index {layer_stack_index}. There are {len(self.encoders)} layer stacks'
             )
         if len(module_ids) == 0:
             raise Exception(f'Adapter {adapter_group} {sub_id} has no module_ids')
diff --git a/mammoth/opts.py b/mammoth/opts.py
index b6c690aa..dfdc7858 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -617,6 +617,8 @@ def _add_train_general_opts(parser):
         choices=["sents", "tokens"],
         help="Batch grouping for batch_size. Standard is sents. Tokens will do dynamic batching",
     )
+    group.add('--pad_to_max_length', '-pad_to_max_length', action='store_true', help='Pad each batch to max_length.')
+    group.add('--max_length', '-max_length', type=int, default=None, help='Maximum sequence length.')
     group.add(
         '--task_distribution_strategy',
         '-task_distribution_strategy',
diff --git a/mammoth/tests/test_look_ahead_bucketing.py b/mammoth/tests/test_look_ahead_bucketing.py
index e74bcf9f..6a97f79f 100644
--- a/mammoth/tests/test_look_ahead_bucketing.py
+++ b/mammoth/tests/test_look_ahead_bucketing.py
@@ -71,3 +71,34 @@ def test_simple_lookeahead_bucketing(max_batch_size, lookahead_minibatches):
         examples_read.extend(batch)
     # Check that the stream was cycled
     assert len(examples_read) > len(stream)
+
+
+@pytest.mark.parametrize(
+    'batch_size',
+    [1, 5, 12, 2048],
+)
+def test_sentence_minibatcher(batch_size):
+    stream = MockStream([
+        hashabledict({
+            'src': tuple([letter for _ in range(i)]),
+            'tgt': tuple([letter for _ in range(j)]),
+        })
+        for letter in 'xyz'
+        for i, j in product(range(1, 11), range(1, 11))
+    ])
+    lab = build_dataloader(
+        stream,
+        batch_size=batch_size,
+        batch_type='sents',
+        cycle=True,
+        as_iter=False
+    )
+    examples_read = []
+    batches = iter(lab)
+    for _ in range(1000):
+        batch = next(batches)
+        print(batch)
+        assert len(batch) == batch_size
+        examples_read.extend(batch)
+    # Check that the stream was cycled
+    assert len(examples_read) > len(stream)
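
Note, for illustration only (not part of the patch): the fixed-shape padding
that _pad_sequence implements above (truncate each sequence to max_length,
then pad to exactly max_length) can be sketched with plain torch and toy
values as follows:

    import torch

    max_length, pad_idx = 6, 0
    # two toy token-id sequences, of length 3 and length 10
    tensors = [torch.tensor([5, 7, 9]), torch.tensor([4] * 10)]
    # allocate a (max_length, batch) buffer filled with the padding index
    padded = torch.full((max_length, len(tensors)), pad_idx, dtype=torch.long)
    for idx, tensor in enumerate(tensors):
        if tensor.numel() > max_length:
            tensor = tensor[:max_length]  # truncate over-long sequences
        padded[:tensor.numel(), idx] = tensor
    padded = padded.unsqueeze(-1)
    print(padded.shape)  # torch.Size([6, 2, 1]) == (max_length, batch, 1)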