In [8]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

In [9]:
def pad_sentences(sents, pad_id):
    """
    Right-pad every sentence in a mini-batch with ``pad_id`` so that all
    sentences have the same word length (the length of the longest one).

    Args:
        sents: list(list(int)), a list of a list of word ids
        pad_id: the word id of the "<pad>" token

    Returns:
        aug_sents: torch.LongTensor of shape (len(sents), max_len), where
            max_len is the length of the longest sentence; row i holds
            sents[i] followed by pad_id repeated (max_len - len(sents[i])) times.

    Raises:
        RuntimeError: if ``sents`` is empty (raised by ``pad_sequence``).
    """
    # Force an integer dtype. torch.tensor([]) defaults to float32, and
    # pad_sequence allocates its output with the first sequence's dtype, so an
    # empty first sentence would otherwise turn the whole batch into floats.
    sequences = [torch.tensor(seq, dtype=torch.long) for seq in sents]
    aug_sents = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    return aug_sents

In [10]:
# Smoke check: a lone sentence is already the longest, so no padding is added.
pad_sentences(sents=[[3, 3, 4, 2, 1, 3]], pad_id=3)

tensor([[3, 3, 4, 2, 1, 3]])

In [11]:
# Additional tests for pad_sentences: each case is printed AND checked against
# the expected padded batch, so a regression fails loudly instead of silently.
test_cases = [
    ([[3, 3, 4, 2, 1, 3]], 3, [[3, 3, 4, 2, 1, 3]]),             # single sequence, no padding
    ([[1, 2, 3], [4, 5]], 0, [[1, 2, 3], [4, 5, 0]]),            # differing lengths -> pad with 0
    ([[7], [8, 9, 10, 11]], 9, [[7, 9, 9, 9], [8, 9, 10, 11]]),  # differing lengths, custom pad_id
]
for case_sents, case_pad_id, expected in test_cases:
    padded = pad_sentences(case_sents, case_pad_id)
    print(padded)
    assert padded.tolist() == expected, (case_sents, case_pad_id, padded)

# Edge case: empty batch. pad_sequence rejects it with a RuntimeError; catch only
# that so any other, unexpected failure still surfaces.
try:
    print(pad_sentences([], 0))
except RuntimeError as e:
    print('empty batch error:', type(e).__name__, e)


tensor([[3, 3, 4, 2, 1, 3]])
tensor([[1, 2, 3],
        [4, 5, 0]])
tensor([[ 7,  9,  9,  9],
        [ 8,  9, 10, 11]])
empty batch error: RuntimeError received an empty list of sequences
