Reimplement fixed size batching
Fixed-size batching entails using
- "batch_type: sents" to fix the batch dimension to batch_size, and
- "pad_to_max_length: true" together with "max_length" to fix the
  sequence-length dimension.

Closes #67
Waino committed May 20, 2024
1 parent a4c0dfe commit e331a39
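
As a sketch of the intended effect (get_corpus and build_dataloader appear in the diff below; opts, task, and the vocab objects are placeholders), combining the two settings makes every batch come out with constant dimensions:

    # Hedged usage sketch, not part of the commit. Assumes opts carries
    # batch_type='sents', batch_size=64, pad_to_max_length=True, max_length=128.
    dataset = get_corpus(opts, task, src_vocab, tgt_vocab, is_train=True)
    loader = build_dataloader(
        dataset,
        batch_size=opts.batch_size,   # fixes the batch dimension to 64 sentences
        batch_type=opts.batch_type,   # 'sents' selects the new SentenceMinibatcher
        cycle=True,
        as_iter=False,
    )
    batch = next(iter(loader))
    # The padded source tensor is expected to have shape (max_length, batch_size, 1),
    # i.e. (128, 64, 1), regardless of the sentence lengths actually drawn.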
Showing 5 changed files with 90 additions and 10 deletions.
26 changes: 25 additions & 1 deletion mammoth/inputters/dataloader.py
@@ -22,7 +22,10 @@ def build_dataloader(
loader = InferenceBatcher(dataset, batch_size)
else:
if batch_type == 'sents':
-            raise NotImplementedError()
+            loader = SentenceMinibatcher(
+                dataset=dataset,
+                batch_size=batch_size,
+            )
elif batch_type == 'tokens':
loader = SimpleLookAheadBucketing(
dataset=dataset,
@@ -153,6 +156,27 @@ def __iter__(self):
)


class SentenceMinibatcher():
"""
Arguments:
dataset: mammoth.inputters.ParallelCorpus
batch_size:
The number of sentences in each minibatch.
"""
def __init__(self, dataset, batch_size):
self.batch_size = batch_size
self.collate_fn = dataset.collate_fn
self._sie = ScoredInfiniteExamples(dataset, score_fn=lambda x: 1)

def __iter__(self):
while True:
minibatch = []
for _ in range(self.batch_size):
_, example = self._sie.next()
minibatch.append(example)
yield self.collate_fn(minibatch)


class DynamicDatasetIter(object):
"""Yield batch from (multiple) plain text corpus.
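The batching logic in SentenceMinibatcher is tiny; here is a self-contained sketch of the same loop, with itertools.cycle standing in for ScoredInfiniteExamples and the collate step omitted:

    # Illustrative stand-in, not mammoth code: group an endlessly cycled
    # example stream into lists of exactly batch_size items.
    import itertools

    def sentence_minibatches(examples, batch_size):
        stream = itertools.cycle(examples)
        while True:
            yield [next(stream) for _ in range(batch_size)]

    batches = sentence_minibatches(['a', 'b', 'c'], batch_size=2)
    assert next(batches) == ['a', 'b']
    assert next(batches) == ['c', 'a']   # the stream wraps around, as in the tests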
30 changes: 26 additions & 4 deletions mammoth/inputters/dataset.py
@@ -89,6 +89,7 @@ def __init__(
offset=None,
is_train=False,
task=None,
max_length=None,
):
self.src_file = src_file
self.tgt_file = tgt_file
@@ -102,6 +103,7 @@ def __init__(
self.offset = offset
self.is_train = is_train
self.corpus_id = task.corpus_id
self.max_length = max_length

# FIXME: most likely redundant with mammoth.transforms.tokenize
def _tokenize(self, string, side='src'):
@@ -121,6 +123,18 @@ def _numericalize(self, tokens, side='src'):
], device='cpu')
return indices

def _pad_sequence(self, tensors: list, padding_value: int = 0):
padded = None
if self.max_length is not None:
padded = torch.full((self.max_length, len(tensors)), padding_value, device='cpu')
for idx, tensor in enumerate(tensors):
if tensor.numel() > self.max_length:
tensor = tensor[:self.max_length]
padded[:tensor.numel(), idx] = tensor
else:
padded = pad_sequence(tensors, padding_value=padding_value)
return padded.unsqueeze(-1)

def to(self, device):
self.device = device
return self
@@ -158,14 +172,17 @@ def collate_fn(self, examples):
has_tgt = 'tgt' in examples[0].keys()
src_padidx = self.vocabs['src'][DefaultTokens.PAD]
tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
-        src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
-        src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
+        if self.max_length is None:
+            src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
+        else:
+            src_lengths = torch.tensor([min(ex['src'].numel(), self.max_length) for ex in examples], device='cpu')
+        src = (self._pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx), src_lengths)
if has_tgt:
-            tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+            tgt = self._pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx)
if 'labels' not in examples[0].keys():
labels = tgt
else:
-                labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+                labels = self._pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx)
else:
tgt = None
labels = None
@@ -190,6 +207,10 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
)
transforms_to_apply = [transforms_cls[trf_name] for trf_name in transforms_to_apply]

max_length = None
if opts.pad_to_max_length:
assert opts.max_length is not None and opts.max_length > 0, 'Please provide a --max_length'
max_length = opts.max_length
# build Dataset proper
dataset = ParallelCorpus(
corpus_opts["path_src"] if is_train else corpus_opts["path_valid_src"],
Expand All @@ -201,6 +222,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
offset=corpus_opts.get('offset', None),
is_train=is_train,
task=task,
max_length=max_length,
)
return dataset

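A standalone sketch of the pad-or-truncate rule that _pad_sequence implements (torch only; the function name is local to this example):

    # Shorter sequences are right-padded with padding_value, longer ones are
    # truncated, so the result is always (max_length, batch, 1).
    import torch

    def pad_to_fixed(tensors, max_length, padding_value=0):
        padded = torch.full((max_length, len(tensors)), padding_value)
        for idx, tensor in enumerate(tensors):
            tensor = tensor[:max_length]              # truncate if too long
            padded[:tensor.numel(), idx] = tensor     # copy; the rest stays padding
        return padded.unsqueeze(-1)

    out = pad_to_fixed([torch.tensor([1, 2]), torch.tensor([3, 4, 5, 6])], max_length=3)
    assert out.shape == (3, 2, 1)   # columns are [1, 2, 0] and [3, 4, 5]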
11 changes: 6 additions & 5 deletions mammoth/modules/layer_stack_encoder.py
@@ -8,13 +8,14 @@


class LayerStackEncoder(EncoderBase):
-    def __init__(self, embeddings, encoders):
+    def __init__(self, embeddings, encoders, max_length=None):
super().__init__()

self.embeddings = embeddings
self.encoders: nn.ModuleList[nn.ModuleDict] = encoders
self._adapter_to_stack: Dict[str, int] = dict()
self._active: List[str] = []
self._max_length = max_length

@classmethod
def from_opts(cls, opts, embeddings, task_queue_manager):
@@ -48,7 +49,7 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
is_normformer=opts.normformer,
)
encoders.append(stacks)
-        return cls(embeddings, encoders)
+        return cls(embeddings, encoders, opts.max_length)

@classmethod
def from_trans_opt(cls, opts, embeddings, task):
@@ -79,7 +80,7 @@ def from_trans_opt(cls, opts, embeddings, task):
is_normformer=opts.normformer,
)
encoders.append(stacks)
-        return cls(embeddings, encoders)
+        return cls(embeddings, encoders, max_length=None)

def update_dropout(self, dropout, attention_dropout):
self.embeddings.update_dropout(dropout)
@@ -91,7 +92,7 @@ def forward(self, src, lengths=None, **kwargs):
# wrapper embeds src and creates mask
emb = self.embeddings(src)
emb = emb.transpose(0, 1).contiguous()
-        mask = ~sequence_mask(lengths).unsqueeze(1)
+        mask = ~sequence_mask(lengths, max_len=self._max_length).unsqueeze(1)

output = emb
for active_id, stacks in zip(self._active, self.encoders):
@@ -144,7 +145,7 @@ def add_adapter(
self._adapter_to_stack[name] = layer_stack_index
if layer_stack_index >= len(self.encoders):
raise ValueError(
-                f'No layer stack with index {layer_stack_index}. There are {len(len(self.encoders))} layer stacks'
+                f'No layer stack with index {layer_stack_index}. There are {len(self.encoders)} layer stacks'
)
if len(module_ids) == 0:
raise Exception(f'Adapter {adapter_group} {sub_id} has no module_ids')
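The max_len change above is what keeps the attention mask aligned with fixed-size padding; a sketch of the failure mode, using a local stand-in with the usual sequence_mask semantics (True for valid positions — not mammoth's actual implementation):

    import torch

    def sequence_mask(lengths, max_len=None):
        # stand-in for mammoth's helper, assumed semantics
        max_len = max_len if max_len is not None else int(lengths.max())
        return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

    lengths = torch.tensor([3, 5])
    assert sequence_mask(lengths).shape == (2, 5)             # width = lengths.max()
    assert sequence_mask(lengths, max_len=8).shape == (2, 8)  # width = max_length
    # With pad_to_max_length, activations have time dimension max_length (8 here),
    # so a (2, 5) mask would no longer line up; passing self._max_length fixes that.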
2 changes: 2 additions & 0 deletions mammoth/opts.py
@@ -617,6 +617,8 @@ def _add_train_general_opts(parser):
choices=["sents", "tokens"],
help="Batch grouping for batch_size. Standard is sents. Tokens will do dynamic batching",
)
group.add('--pad_to_max_length', '-pad_to_max_length', action='store_true')
group.add('--max_length', '-max_length', type=int, default=None, help='Maximum sequence length.')
group.add(
'--task_distribution_strategy',
'-task_distribution_strategy',
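For instance, the new options might be combined like this on the command line (script name, config file, and the -batch_size flag are illustrative; only -pad_to_max_length and -max_length are introduced by this commit):

    python train.py -config my.yaml -batch_type sents -batch_size 64 -pad_to_max_length -max_length 128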
31 changes: 31 additions & 0 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -71,3 +71,34 @@ def test_simple_lookeahead_bucketing(max_batch_size, lookahead_minibatches):
examples_read.extend(batch)
# Check that the stream was cycled
assert len(examples_read) > len(stream)


@pytest.mark.parametrize(
'batch_size',
[1, 5, 12, 2048],
)
def test_sentence_minibatcher(batch_size):
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(
stream,
batch_size=batch_size,
batch_type='sents',
cycle=True,
as_iter=False
)
examples_read = []
batches = iter(lab)
for _ in range(1000):
batch = next(batches)
print(batch)
assert len(batch) == batch_size
examples_read.extend(batch)
# Check that the stream was cycled
assert len(examples_read) > len(stream)
