<a href="https://colab.research.google.com/github/Benj-admin/MAP583_X/blob/main/Lesson/Tutorial_packing_sequences.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Minimal tutorial on packing and unpacking sequences in PyTorch
# aka how to use `pack_padded_sequence` and  `pad_packed_sequence`

This is a jupyter version of [@Tushar-N 's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e) with comments from [@Harsh Trivedi repo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial) and adapted from [the dataflowr website](https://github.com/dataflowr/notebooks/blob/master/Module11/11_Tutorial_packing_sequences.ipynb).


In [1]:
# from https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
#     Step 1: Construct Vocabulary
#     Step 2: Load indexed data (list of instances, where each instance is list of character indices)
#     Step 3: Make Model
#  *  Step 4: Pad instances with 0s till max length sequence
#  *  Step 5: Sort instances by sequence length in descending order
#  *  Step 6: Embed the instances
#  *  Step 7: Call pack_padded_sequence with embeded instances and sequence lengths
#  *  Step 8: Forward with LSTM
#  *  Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector
#  *  Summary of Shape Transformations

In [2]:
# We want to run LSTM on a batch following 3 character sequences
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6

In [3]:
## Step 1: Construct Vocabulary ##
##------------------------------##
# make sure <pad> idx is 0
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))

In [4]:
vocab

['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

In [5]:
## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##
##-------------------------------------------------------------------------------------------------##
vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs]

In [6]:
vectorized_seqs

[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

In [7]:
## Step 3: Make Model ##
##--------------------##
embed = Embedding(len(vocab), 4) # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, num_layers=2, batch_first=True) # input_dim = 4, hidden_dim = 5

In [8]:
## Step 4: Pad instances with 0s till max length sequence ##
##--------------------------------------------------------##

# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = (torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [9]:
seq_lengths

tensor([8, 4, 6])

In [10]:
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [12,  5,  8, 14,  0,  0,  0,  0],
        [ 7,  3,  2,  5, 13,  7,  0,  0]])

In [11]:
## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

In [12]:
perm_idx

tensor([0, 2, 1])

In [13]:
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [ 7,  3,  2,  5, 13,  7,  0,  0],
        [12,  5,  8, 14,  0,  0,  0,  0]])

In [14]:
## Step 6: Embed the instances ##
##-----------------------------##

embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

In [15]:
embedded_seq_tensor

tensor([[[-2.7212e-02, -1.8416e-01,  1.4720e+00,  1.2781e+00],
         [-5.2917e-01,  8.1698e-01, -3.5535e-01,  9.0872e-01],
         [ 9.2419e-01,  1.1221e+00,  5.3511e-01, -2.7484e-02],
         [-6.2186e-01,  3.8420e-02, -1.5629e+00,  1.7817e+00],
         [-7.6745e-01, -9.3460e-01,  4.4326e-01, -4.6463e-01],
         [-6.3583e-02,  1.2693e+00, -1.3119e+00, -1.9940e+00],
         [-6.3914e-01, -1.3176e-01, -1.4167e-01,  5.7377e-01],
         [ 3.0209e-01, -2.0734e+00, -1.9709e+00, -4.1188e-01]],

        [[ 3.4736e-01, -8.8285e-01, -3.2616e+00,  6.5706e-01],
         [-6.2543e-02, -2.7879e-01, -8.7413e-01, -6.0503e-02],
         [-1.4458e+00,  5.9985e-01, -9.4143e-01, -6.4183e-01],
         [ 9.0825e-01, -1.6225e+00, -9.9412e-01,  8.9833e-01],
         [ 1.4871e+00, -5.3385e-02, -2.0026e-01, -1.2087e+00],
         [ 3.4736e-01, -8.8285e-01, -3.2616e+00,  6.5706e-01],
         [ 2.4081e-03,  4.0468e-01,  2.5185e-01, -2.3913e+00],
         [ 2.4081e-03,  4.0468e-01,  2.5185e-01, -2.3

In [16]:
embedded_seq_tensor.shape

torch.Size([3, 8, 4])

In [17]:
## Step 7: Call pack_padded_sequence with embeded instances and sequence lengths ##
##-------------------------------------------------------------------------------##

packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
# packed_input (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [18]:
packed_input.data.shape

torch.Size([18, 4])

In [19]:
## Step 8: Forward with LSTM ##
##---------------------------##

packed_output, (ht, ct) = lstm(packed_input)
# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

In [20]:
packed_output.data.shape

torch.Size([18, 5])

In [21]:
ht

tensor([[[ 0.4507,  0.3145,  0.1443, -0.2201,  0.3672],
         [ 0.2193,  0.3463,  0.0641, -0.0828,  0.5242],
         [ 0.1499,  0.1867, -0.1596, -0.0195,  0.2578]],

        [[-0.0088, -0.0409, -0.0770,  0.0813, -0.0619],
         [-0.0292, -0.0492, -0.0912,  0.1028, -0.0431],
         [ 0.0045, -0.0472, -0.0574,  0.1133, -0.0490]]],
       grad_fn=<StackBackward0>)

In [22]:
ht.shape

torch.Size([2, 3, 5])

In [23]:
ct

tensor([[[ 0.6435,  0.5615,  0.2564, -0.4226,  0.6030],
         [ 0.2882,  0.9806,  0.1696, -0.3333,  0.7945],
         [ 0.3515,  0.4045, -0.3672, -0.0361,  0.4653]],

        [[-0.0150, -0.0859, -0.1337,  0.1347, -0.1911],
         [-0.0482, -0.0981, -0.1607,  0.1683, -0.1249],
         [ 0.0070, -0.0922, -0.1011,  0.1852, -0.1497]]],
       grad_fn=<StackBackward0>)

In [24]:
ct.shape

torch.Size([2, 3, 5])

In [25]:
## Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector ##
##------------------------------------------------------------------------------------##

# unpack your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output:
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

In [26]:
output

tensor([[[ 0.0015,  0.0092, -0.0283,  0.0632, -0.0510],
         [-0.0059,  0.0013, -0.0549,  0.0985, -0.0620],
         [-0.0026, -0.0195, -0.0565,  0.1178, -0.0583],
         [-0.0230, -0.0254, -0.0815,  0.1261, -0.0534],
         [-0.0004, -0.0448, -0.0941,  0.1046, -0.0471],
         [ 0.0500, -0.0715, -0.0660,  0.0944, -0.0432],
         [ 0.0423, -0.0588, -0.0665,  0.0957, -0.0503],
         [-0.0088, -0.0409, -0.0770,  0.0813, -0.0619]],

        [[ 0.0158, -0.0090, -0.0381,  0.0692, -0.0366],
         [ 0.0222, -0.0239, -0.0638,  0.0959, -0.0468],
         [ 0.0495, -0.0454, -0.0606,  0.1087, -0.0464],
         [ 0.0013, -0.0338, -0.0863,  0.1034, -0.0528],
         [ 0.0197, -0.0631, -0.0834,  0.0884, -0.0425],
         [-0.0292, -0.0492, -0.0912,  0.1028, -0.0431],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0240, -0.0058, -0.0214,  0.0593, -0.0469],
         [-0.0120, -0.0077, -0.0597,  0.0820

In [27]:
output.shape

torch.Size([3, 8, 5])

In [28]:
# Or if you just want the final hidden state?
print(ht[-1])

## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)

tensor([[-0.0088, -0.0409, -0.0770,  0.0813, -0.0619],
        [-0.0292, -0.0492, -0.0912,  0.1028, -0.0431],
        [ 0.0045, -0.0472, -0.0574,  0.1133, -0.0490]],
       grad_fn=<SelectBackward0>)
