In [1]:
# PyTorch Tutorial: Merging Variable Length Batches
import torch
import torch.nn.functional as F

In [2]:
# Example data

# Two single-row batches with different sequence lengths (1 token vs. 2 tokens).
# Every position holds a real token, so each attention mask is all ones with
# the same shape and dtype as the corresponding input_ids.
batch1 = {'input_ids': torch.tensor([[1]])}
batch1['attention_mask'] = torch.ones_like(batch1['input_ids'])

batch2 = {'input_ids': torch.tensor([[2, 3]])}
batch2['attention_mask'] = torch.ones_like(batch2['input_ids'])


In [5]:
# Step 1: Determine the maximum sequence length

# The sequence length of each batch is the size of dimension 1 of its attention mask.
seq1, seq2 = (batch['attention_mask'].shape[1] for batch in (batch1, batch2))

print(seq1, seq2)
max_seq_len = seq1 if seq1 >= seq2 else seq2



1 2


In [20]:
# Step 2: Left-pad the sequences to a common length

# Target width: the longest sequence (dimension 1 of input_ids) across both batches.
max_seq_len = max(batch['input_ids'].shape[1] for batch in (batch1, batch2))

# Padding helper used to bring tensors of different sequence lengths to one common length.
def pad_to_max_len(tensor, max_len, pad_value=0, side='left'):
    """Pad a 2-D ``(batch, seq_len)`` tensor along the sequence dimension.

    The current length is read from ``tensor.shape[1]`` and the padding is
    applied to the last dimension, so the two agree only for 2-D input.

    Args:
        tensor: Tensor to pad; dimension 1 is treated as the sequence dimension.
        max_len: Target sequence length after padding.
        pad_value: Fill value for the new positions (default ``0``).
        side: ``'left'`` (default) inserts the padding before the existing
            values, keeping the data right-aligned; ``'right'`` appends it
            after them.

    Returns:
        A tensor whose sequence dimension has length ``max_len``.

    Raises:
        ValueError: If ``tensor`` is already longer than ``max_len`` or
            ``side`` is neither ``'left'`` nor ``'right'``.
    """
    padding_length = max_len - tensor.shape[1]
    if padding_length < 0:
        # A negative pad width makes F.pad crop the tensor, which would
        # silently drop tokens; fail loudly instead.
        raise ValueError(
            f"sequence length {tensor.shape[1]} exceeds max_len {max_len}"
        )

    # F.pad takes (left, right) widths for the last dimension.
    if side == 'left':
        pad_widths = (padding_length, 0)
    elif side == 'right':
        pad_widths = (0, padding_length)
    else:
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")

    return F.pad(tensor, pad_widths, mode='constant', value=pad_value)

# Bring both batches to the common width. The padded tensors replace the
# originals inside each batch dictionary, so batch1/batch2 are modified here.
batch1.update({
    key: pad_to_max_len(batch1[key], max_seq_len)
    for key in ('input_ids', 'attention_mask')
})
batch2.update({
    key: pad_to_max_len(batch2[key], max_seq_len)
    for key in ('input_ids', 'attention_mask')
})

# Stack the equally sized rows of both batches along the batch dimension (dim 0).
merged_input_ids = torch.cat(
    [batch['input_ids'] for batch in (batch1, batch2)], dim=0
)
merged_attention_mask = torch.cat(
    [batch['attention_mask'] for batch in (batch1, batch2)], dim=0
)

print("Merged Input IDs:", merged_input_ids, sep="\n")
print("Merged Attention Mask:", merged_attention_mask, sep="\n")

Merged Input IDs:
tensor([[0, 1],
        [2, 3]])
Merged Attention Mask:
tensor([[0, 1],
        [1, 1]])
