In [16]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['gigantic_string','tiny_str','medium_str']

# make <pad> idx 0
vocab = ['<pad>'] + sorted(set(''.join(seqs)))

# make model
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()

vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]

# get the length of each seq in your batch
seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs]).cuda()
print("seq_lengths", seq_lengths, seq_lengths.shape)
# dump padding everywhere, and place seqs on the left.
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long().cuda()
print("seq_tensor", seq_tensor, seq_tensor.shape)
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
	seq_tensor[idx, :seqlen] = torch.LongTensor(seq)


# SORT YOUR TENSORS BY LENGTH!
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
print("seq_lengths", seq_lengths.shape)
seq_tensor = seq_tensor[perm_idx]

# utils.rnn lets you give (B,L,D) tensors where B is the batch size, L is the maxlength, if you use batch_first=True
# Otherwise, give (L,B,D) tensors
seq_tensor = seq_tensor.transpose(0,1) # (B,L,D) -> (L,B,D)

# embed your sequences
seq_tensor = embed(seq_tensor)

# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())
#print("packed_input: ", packed_input)
print("packed length: ", len(packed_input))
# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)
print("packed length output: ", len(packed_output))
# unpack your output if required
output, _ = pad_packed_sequence(packed_output)
print ("output: ", output.shape)

# Or if you just want the final hidden state?
#print ("final_hidden_state: ", ht[-1])

# REMEMBER: Your outputs are sorted. If you want the original ordering
# back (to compare to some gt labels) unsort them
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]
print ("final output: ", output, output.shape)

seq_lengths tensor([15,  8, 10], device='cuda:0') torch.Size([3])
seq_tensor tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0') torch.Size([3, 15])
seq_lengths torch.Size([3])
packed length:  4
packed length output:  4
output:  torch.Size([15, 3, 5])
final output:  tensor([[[-0.4176,  0.1017,  0.1668,  0.1087,  0.0277],
         [-0.1694,  0.0083,  0.0499, -0.0340, -0.1278],
         [-0.3136, -0.0012,  0.0833, -0.2994, -0.1930]],

        [[-0.5143,  0.1591,  0.1161,  0.0153,  0.0206],
         [ 0.0041, -0.0337, -0.2034, -0.3268, -0.5914],
         [-0.3694, -0.0056, -0.0800,  0.0097,  0.0205]],

        [[-0.2977,  0.0608, -0.0493, -0.1456, -0.2551],
         [-0.1150, -0.0728, -0.2370, -0.5131, -0.3397],
         [-0.3136,  0.0368, -0.1173, -0.3246, -0.3586]]], device='cuda:0',
       grad_fn=<IndexBackward>) torch.Size([3, 3, 5])
