In [17]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['gigantic_string','tiny_str','medium_str']

# make <pad> idx 0
vocab = ['<pad>'] + sorted(set(''.join(seqs)))

In [18]:
vocab

['<pad>', '_', 'a', 'c', 'd', 'e', 'g', 'i', 'm', 'n', 'r', 's', 't', 'u', 'y']

In [19]:
# make model
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()

In [20]:
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
vectorized_seqs

[[6, 7, 6, 2, 9, 12, 7, 3, 1, 11, 12, 10, 7, 9, 6],
 [12, 7, 9, 14, 1, 11, 12, 10],
 [8, 5, 4, 7, 13, 8, 1, 11, 12, 10]]

In [21]:
# get the length of each seq in your batch
seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs]).cuda()
seq_lengths

tensor([15,  8, 10], device='cuda:0')

In [22]:
# dump padding everywhere, and place seqs on the left.
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long().cuda()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
	seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

seq_tensor.shape

torch.Size([3, 15])

In [23]:
# SORT YOUR TENSORS BY LENGTH!
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
seq_lengths, seq_tensor

(tensor([15, 10,  8], device='cuda:0'),
 tensor([[ 6,  7,  6,  2,  9, 12,  7,  3,  1, 11, 12, 10,  7,  9,  6],
         [ 8,  5,  4,  7, 13,  8,  1, 11, 12, 10,  0,  0,  0,  0,  0],
         [12,  7,  9, 14,  1, 11, 12, 10,  0,  0,  0,  0,  0,  0,  0]],
        device='cuda:0'))

In [24]:
# utils.rnn lets you give (B,L,D) tensors where B is the batch size, L is the maxlength, if you use batch_first=True
# Otherwise, give (L,B,D) tensors
seq_tensor = seq_tensor.transpose(0,1) # (B,L,D) -> (L,B,D)

# embed your sequences
seq_tensor = embed(seq_tensor)

# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())

In [27]:
packed_input[0].shape, packed_input[1].shape

(torch.Size([33, 10]), torch.Size([15]))

In [29]:
#packed_input

In [32]:
# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)
packed_output[0].shape, packed_output[1].shape

(torch.Size([33, 5]), torch.Size([15]))

In [34]:
# unpack your output if required
output, _ = pad_packed_sequence(packed_output)
print (output.shape)

torch.Size([15, 3, 5])


In [35]:
# Or if you just want the final hidden state?
print (ht[-1].shape)

torch.Size([3, 5])


In [36]:
# REMEMBER: Your outputs are sorted. If you want the original ordering
# back (to compare to some gt labels) unsort them
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]
print (output)

tensor([[[-0.0042,  0.0596,  0.0292, -0.0732,  0.0511],
         [-0.2934,  0.0321,  0.0041,  0.1061, -0.0744],
         [-0.0915, -0.0569,  0.0357,  0.0068, -0.0530]],

        [[-0.0221,  0.1037,  0.2066, -0.0090, -0.0025],
         [-0.3393, -0.1646,  0.1647,  0.0874,  0.0589],
         [-0.4734,  0.0245,  0.0429,  0.4150, -0.1388]],

        [[-0.0840,  0.2060,  0.2354,  0.0935, -0.0371],
         [-0.0736,  0.1592,  0.0597, -0.1342, -0.0997],
         [-0.1283,  0.1173,  0.2353,  0.1197, -0.1147]]],
       device='cuda:0', grad_fn=<TakeBackward>)
