In [1]:
import torch
from torch.nn import Embedding, LSTM
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [2]:
# The dataset is a list of sequences ("sentences").
# Their elements can be anything that fits into a torch tensor; usually they are
# indices of words in some vocabulary.
# Index 0 is commonly reserved for the padding token. It appears once explicitly, on purpose,
# at the end of the last sequence, to compare it with the padding that pad_sequence() adds.
# NOTE: sequence lengths are computed with len() later on, so this explicit 0 counts toward
# its sequence's length and would be kept by packing (and fed to the LSTM), unlike the
# automatically added padding.
DATA = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [4, 6, 2, 9, 0]]

In [3]:
# pad_sequence() needs torch tensors; this conversion could also live in the dataset's __getitem__.
# torch.as_tensor returns an existing tensor unchanged, so re-running this cell is harmless
# (torch.tensor would copy-construct tensors from tensors and emit a UserWarning).
DATA = [torch.as_tensor(seq) for seq in DATA]
# vocabulary size for the embedding layer, including index 0 (the padding token)
NUM_WORDS = 10
# every index must be a valid row of the embedding table
assert all(seq.numel() == 0 or int(seq.max()) < NUM_WORDS for seq in DATA), 'index outside vocabulary'

In [4]:
# Seed torch's global RNG up front so the randomly initialized weights created later
# (and therefore every printed number in this notebook) are identical between runs.
SEED = 0
torch.manual_seed(SEED)

# sizes of the toy setup
BATCH_SIZE = 3      # sequences per batch; the 4 sequences in DATA deliberately don't divide evenly
EMB_DIM = 2         # size of each embedding vector
LSTM_DIM = 5        # size of the LSTM hidden state

In [5]:
class MinimalDataset(Dataset):
    """Map-style dataset that exposes an in-memory, indexable collection of samples.

    Samples are returned exactly as stored; no transformation is applied.
    """

    def __init__(self, data) -> None:
        super().__init__()
        # underlying indexable container holding one sample per position
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        sample = self.data[index]
        return sample

In [6]:
def identity_collate(samples):
    """Return the batch (a list of samples) unchanged, so sequences of different lengths
    can share one batch without being stacked into a single tensor."""
    return samples


dataset = MinimalDataset(DATA)
# len(DATA) is deliberately not divisible by BATCH_SIZE, so the last batch is smaller;
# this checks that the pipeline behaves consistently across batch sizes
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=identity_collate,
)

In [7]:
# The identity collate_fn is what allows data points of varying length in one batch:
# each batch stays a plain Python list of 1-D tensors of different lengths.
first_batch = next(iter(data_loader))
print(first_batch)

[tensor([1, 2, 3]), tensor([4, 5]), tensor([6, 7, 8, 9])]


In [8]:
# Expected batches: [ [1, 2, 3], [4, 5], [6, 7, 8, 9] ], then [ [4, 6, 2, 9, 0] ].
# This only works because collate_fn returns the batch unchanged. The default collate
# function stacks the tensors of a batch into a single tensor, which requires equal sizes,
# and otherwise fails with something like:
# RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 2 in dimension 1 ...

# iterate through the whole dataset, one batch at a time
for batch_no, batch in enumerate(data_loader):
    print(f'{batch_no}, {batch}')

0, [tensor([1, 2, 3]), tensor([4, 5]), tensor([6, 7, 8, 9])]
1, [tensor([4, 6, 2, 9, 0])]


In [9]:
# playing around with padding (= unpacking) and packing (= unpadding)
print('padding and [un]packing')
# a new iterator is created on every call and shuffle=False, so this always gets you
# the first batch of the dataset:
batch = next(iter(data_loader))

padding and [un]packing


In [10]:
print(f'batch: \n{batch}\n')
# Packing needs the true (unpadded) length of every sequence, so record it before padding.
lens = [len(seq) for seq in batch]

batch: 
[tensor([1, 2, 3]), tensor([4, 5]), tensor([6, 7, 8, 9])]



In [11]:
# Pad every sequence with 0 up to the longest one in the batch -> shape (batch, max_len).
padded = pad_sequence(batch, batch_first=True, padding_value=0)
print(f' [0] padded: \n{padded}\n')

 [0] padded: 
tensor([[1, 2, 3, 0],
        [4, 5, 0, 0],
        [6, 7, 8, 9]])



In [12]:
# pytorch <1.1.0 does not support enforce_sorted=False and you would have to sort the sequences manually
packed = pack_padded_sequence(padded, lens, batch_first=True, enforce_sorted=False)
print(f' [1] packed: \n{packed}\n')
# pad_packed_sequence returns a tuple (padded tensor, lengths)
padded2 = pad_packed_sequence(packed, batch_first=True)
print(f' [2] padded: \n{padded2}\n')
# pad(pack(pad(x))) == pad(x) as pad() and pack() are inverse to each other.
# torch.equal also requires identical shapes (torch.all(torch.eq(...)) would broadcast),
# and the recovered lengths must match the ones we packed with.
assert torch.equal(padded2[0], padded)
assert padded2[1].tolist() == lens

 [1] packed: 
PackedSequence(data=tensor([6, 1, 4, 7, 2, 5, 8, 3, 9]), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))

 [2] padded: 
(tensor([[1, 2, 3, 0],
        [4, 5, 0, 0],
        [6, 7, 8, 9]]), tensor([3, 2, 4]))



In [13]:
# putting everything together: dataset - data_loader - padding - embedding - packing - lstm - unpacking (padding)
print('embedding')
# take the first batch (a new iterator each time, shuffle=False):
batch = next(iter(data_loader))
# or, to process the whole dataset, run the following cells inside a loop:
# for batch in data_loader:

embedding


In [14]:
print(f'------------------------\nbatch: \n{batch}\n')
# true lengths of the sequences, needed later by pack_padded_sequence
lens = [len(seq) for seq in batch]

------------------------
batch: 
[tensor([1, 2, 3]), tensor([4, 5]), tensor([6, 7, 8, 9])]



In [15]:
# Randomly initialized layers (RNG seeded earlier): one EMB_DIM-sized vector per vocabulary
# index, followed by an LSTM that takes batch-first input.
# NOTE(review): no padding_idx is given, so index 0 gets an ordinary trainable embedding;
# Embedding(..., padding_idx=0) would keep it fixed at zero if that is desired.
embedding = Embedding(num_embeddings=NUM_WORDS, embedding_dim=EMB_DIM)
lstm = LSTM(input_size=EMB_DIM, hidden_size=LSTM_DIM, batch_first=True)

In [23]:
# display the embedding layer's repr: Embedding(vocabulary size, embedding dim)
# NOTE(review): this cell was executed out of order; it only displays the layer, so that is harmless
embedding

Embedding(10, 2)

In [16]:
# Step 1: pad all sequences of the batch with 0 to a common length -> shape (batch, max_len).
padded = pad_sequence(batch, batch_first=True, padding_value=0)
print(f'> pad: \n{padded}\n')

> pad: 
tensor([[1, 2, 3, 0],
        [4, 5, 0, 0],
        [6, 7, 8, 9]])



In [17]:
# Step 2: add the embedding dimension -> shape (batch, max_len, EMB_DIM).
# Padded positions are embedded too (all get the vector of index 0); packing will drop them.
pad_embed = embedding(padded)
print(f'> pad_embed: \n{pad_embed}\n')

> pad_embed: 
tensor([[[ 0.4913, -0.2041],
         [ 0.1665,  0.8744],
         [-0.1435, -0.1116],
         [-0.3561,  0.4372]],

        [[-0.6136,  0.0316],
         [-0.4927,  0.2484],
         [-0.3561,  0.4372],
         [-0.3561,  0.4372]],

        [[ 0.6181, -0.4128],
         [-0.8411, -2.3160],
         [-0.1023,  0.7924],
         [-0.2897,  0.0525]]], grad_fn=<EmbeddingBackward>)



In [18]:
# Step 3: pack into one flat sequence of EMB_DIM-sized elements, dropping padded positions.
# enforce_sorted=False lets PyTorch sort by length internally (sorted_indices) and remember
# how to restore the original order (unsorted_indices).
pad_embed_pack = pack_padded_sequence(
    pad_embed,
    lengths=lens,
    batch_first=True,
    enforce_sorted=False,
)
print(f'> pad_embed_pack: \n{pad_embed_pack}\n')

> pad_embed_pack: 
PackedSequence(data=tensor([[ 0.6181, -0.4128],
        [ 0.4913, -0.2041],
        [-0.6136,  0.0316],
        [-0.8411, -2.3160],
        [ 0.1665,  0.8744],
        [-0.4927,  0.2484],
        [-0.1023,  0.7924],
        [-0.1435, -0.1116],
        [-0.2897,  0.0525]], grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))



In [19]:
# Step 4: run the packed sequence through the LSTM.
# The result is a tuple: (packed outputs for every timestep, (h_n, c_n)).
pad_embed_pack_lstm = lstm(pad_embed_pack)
print(f'> pad_embed_pack_lstm: \n{pad_embed_pack_lstm}\n')

> pad_embed_pack_lstm: 
(PackedSequence(data=tensor([[-2.8078e-02, -7.5184e-02, -1.5413e-01, -4.4770e-02,  1.2383e-02],
        [-3.6822e-02, -6.6412e-02, -1.4248e-01, -4.7936e-02,  2.4047e-02],
        [-8.1972e-02, -5.3362e-02, -1.9377e-01,  3.4363e-02, -2.3898e-04],
        [-5.0494e-02, -2.0331e-01, -4.3655e-01,  3.7436e-02, -1.6537e-01],
        [-9.4117e-02, -8.0996e-02, -1.1510e-01, -9.3560e-02,  7.8516e-02],
        [-1.3141e-01, -9.9222e-02, -2.7083e-01,  6.3981e-02,  6.1639e-03],
        [-1.5719e-01, -1.5504e-01, -2.6151e-01,  4.5258e-02,  5.2963e-03],
        [-1.0997e-01, -1.5025e-01, -2.4058e-01, -6.0991e-03,  5.5633e-02],
        [-1.4524e-01, -1.9434e-01, -3.1854e-01,  7.6890e-02,  3.8153e-03]],
       grad_fn=<CatBackward>), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])), (tensor([[[-0.1100, -0.1502, -0.2406, -0.0061,  0.0556],
         [-0.1314, -0.0992, -0.2708,  0.0640,  0.0062],
         [-0.1452, -0.1943, -0

In [20]:
# Step 5: unpack (= pad) the per-timestep LSTM outputs again. This works because the
# PackedSequence remembers the lengths and ordering used when packing.
# Element [0] of the LSTM result is "out", the hidden state after each timestep.
packed_out = pad_embed_pack_lstm[0]
pad_embed_pack_lstm_pad = pad_packed_sequence(packed_out, batch_first=True)
print(f'> pad_embed_pack_lstm_pad: \n{pad_embed_pack_lstm_pad}\n')

> pad_embed_pack_lstm_pad: 
(tensor([[[-3.6822e-02, -6.6412e-02, -1.4248e-01, -4.7936e-02,  2.4047e-02],
         [-9.4117e-02, -8.0996e-02, -1.1510e-01, -9.3560e-02,  7.8516e-02],
         [-1.0997e-01, -1.5025e-01, -2.4058e-01, -6.0991e-03,  5.5633e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-8.1972e-02, -5.3362e-02, -1.9377e-01,  3.4363e-02, -2.3898e-04],
         [-1.3141e-01, -9.9222e-02, -2.7083e-01,  6.3981e-02,  6.1639e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-2.8078e-02, -7.5184e-02, -1.5413e-01, -4.4770e-02,  1.2383e-02],
         [-5.0494e-02, -2.0331e-01, -4.3655e-01,  3.7436e-02, -1.6537e-01],
         [-1.5719e-01, -1.5504e-01, -2.6151e-01,  4.5258e-02,  5.2963e-03],
         [-1.4524e-01, -1.9434e-01, -3.1854e-01,  7.6890e-02,  3.8153e-03]]],
       grad_fn=<IndexSelectBackward>), tensor([3, 2, 

In [21]:
# Usually we only need the final LSTM state of each sequence, i.e. the state after it has
# processed the whole (unpadded) sentence. The LSTM already returns it, so no unpacking is needed:
seq, final_state = pad_embed_pack_lstm
ht, ct = final_state
# h_n is indexed by layer first, so [-1] selects the last layer; rows follow the original batch order
print(f'lstm last state without unpacking:\n{ht[-1]}')

lstm last state without unpacking:
tensor([[-0.1100, -0.1502, -0.2406, -0.0061,  0.0556],
        [-0.1314, -0.0992, -0.2708,  0.0640,  0.0062],
        [-0.1452, -0.1943, -0.3185,  0.0769,  0.0038]],
       grad_fn=<SelectBackward>)


In [22]:
# ... which is the same as picking, for every sequence, the unpacked LSTM output at its
# last real (non-padded) timestep.
# New names are used so that `lens` (a list of ints in earlier cells) is not silently rebound.
outs, out_lens = pad_embed_pack_lstm_pad
# Row i, timestep out_lens[i] - 1. This works for any batch size, including the smaller
# final batch; a hard-coded .view((BATCH_SIZE, -1)) fails or mis-shapes there.
rows = torch.arange(outs.size(0), device=outs.device)
last_states = outs[rows, (out_lens - 1).to(outs.device)]
print(f'lstm last state after unpacking:\n'
      f'{last_states}')
# check the claim above against the final hidden state returned by the LSTM
assert torch.allclose(last_states, ht[-1])
# i.e. the last non-masked/padded/null state
# so, you probably shouldn't unpack the sequence if you don't need to

lstm last state after unpacking:
tensor([[-0.1100, -0.1502, -0.2406, -0.0061,  0.0556],
        [-0.1314, -0.0992, -0.2708,  0.0640,  0.0062],
        [-0.1452, -0.1943, -0.3185,  0.0769,  0.0038]], grad_fn=<ViewBackward>)
