In [1]:
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [2]:
# Goal: run an LSTM over a single batch made of these three
# variable-length character sequences.
seqs = [
    'long_str',  # 8 characters
    'tiny',      # 4 characters
    'medium',    # 6 characters
]

In [4]:
# Character vocabulary: index 0 is reserved for the padding token, followed by
# every distinct character found in `seqs`, in sorted order.
vocab = ['<pad>'] + sorted({char for seq in seqs for char in seq})
# Build the token -> index mapping once so each lookup is O(1) instead of a
# linear `vocab.index` scan for every character.
char_to_idx = {tok: idx for idx, tok in enumerate(vocab)}
# Each sequence re-expressed as a list of vocabulary indices.
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]


In [5]:
# Embedding table: one 4-dimensional vector per vocabulary entry.
embed = Embedding(num_embeddings=len(vocab), embedding_dim=4)
# LSTM that consumes the 4-dim embeddings and keeps a 5-dim hidden state;
# batch_first=True means inputs/outputs are laid out (batch, seq, feature).
lstm = LSTM(
    input_size=4,
    hidden_size=5,
    batch_first=True,
)

In [6]:
# Length of every sequence, in the same order as `vectorized_seqs`.
seq_lengths = LongTensor([len(tokens) for tokens in vectorized_seqs])

In [7]:
# Zero-filled (batch, max_len) index tensor. Index 0 is '<pad>' in `vocab`,
# so any position not overwritten later is already padding.
# Allocated directly as int64: the deprecated `Variable` wrapper and the
# float -> long round trip are unnecessary on current PyTorch.
seq_tensor = torch.zeros(
    (len(vectorized_seqs), int(seq_lengths.max())),
    dtype=torch.long,
)

In [8]:
# Copy each sequence's token indices into the left part of its row; the
# remaining positions keep their initial value of 0 (the '<pad>' index).
for row, tokens in enumerate(vectorized_seqs):
    length = seq_lengths[row]
    seq_tensor[row, :length] = LongTensor(tokens)

In [13]:
# Reorder the batch longest-first, as pack_padded_sequence requires when
# called with its default enforce_sorted=True.
# NOTE: this rebinds `seq_lengths` and `seq_tensor` to their sorted versions.
# Re-running the cell on already-sorted data therefore yields the identity
# permutation in `perm_idx`, not the permutation that originally sorted the batch.
seq_lengths, perm_idx = torch.sort(seq_lengths, dim=0, descending=True)
seq_tensor = seq_tensor[perm_idx]
(seq_lengths, perm_idx)

(tensor([8, 6, 4]), tensor([0, 1, 2]))

In [10]:
# Look up an embedding vector for every token index in the padded batch
# (padding positions included), adding a trailing embedding dimension of size 4.
embedded_seq_tensor = embed(seq_tensor)

In [11]:
# Pack the padded, length-sorted batch so the LSTM only processes the real
# (non-padding) timesteps of each sequence. Lengths are moved to the CPU
# before being handed over.
packed_input = pack_padded_sequence(
    embedded_seq_tensor,
    lengths=seq_lengths.cpu().numpy(),
    batch_first=True,
)
packed_input

PackedSequence(data=tensor([[-2.1994e-01, -3.4274e-01, -4.9565e-01, -8.3763e-01],
        [-5.9977e-02,  7.9111e-03,  1.3742e+00,  4.4563e-01],
        [-1.9601e+00, -1.0830e+00,  5.9390e-03,  1.4341e+00],
        [ 9.8643e-01, -1.1596e+00,  1.1308e+00, -1.1160e+00],
        [ 5.6375e-01, -3.1504e-01,  2.8399e-01,  9.3115e-01],
        [-1.1338e+00,  5.7008e-01, -3.0532e-02,  8.0733e-01],
        [-1.4077e+00, -5.2582e-01, -2.3200e-01,  2.0413e+00],
        [ 3.8409e-01,  2.0369e+00,  1.6379e-01, -8.6935e-01],
        [-1.4077e+00, -5.2582e-01, -2.3200e-01,  2.0413e+00],
        [-7.4631e-01,  6.8849e-02, -2.0879e-01,  2.0581e+00],
        [-1.1338e+00,  5.7008e-01, -3.0532e-02,  8.0733e-01],
        [-7.9554e-02,  1.6705e+00, -1.1925e+00, -2.7017e-01],
        [-1.0493e+00,  1.7665e+00, -2.4513e-01, -8.1481e-01],
        [-1.9334e-01,  1.1458e+00,  1.5379e+00, -2.8905e-01],
        [-1.4825e-01,  1.6241e+00,  1.4917e-04, -2.3513e-01],
        [-5.9977e-02,  7.9111e-03,  1.3742e+00,  4

In [14]:
# Run the LSTM over the packed batch. It returns the packed per-timestep
# outputs plus the final (hidden, cell) state pair.
packed_output, final_states = lstm(packed_input)
ht, ct = final_states

In [15]:
# Convert the packed LSTM output back into a padded (batch, seq, feature)
# tensor, together with the true length of each sequence.
output, input_sizes = pad_packed_sequence(
    sequence=packed_output,
    batch_first=True,
)

In [16]:
# Show the final hidden state taken from the last entry along ht's first axis:
# one row per sequence, in the batch's length-sorted order (not necessarily the
# original order of `seqs`).
print(ht[-1])

tensor([[-0.2423,  0.1180,  0.1261, -0.0246, -0.3053],
        [-0.1770,  0.1730, -0.1704,  0.0487, -0.3208],
        [-0.1375,  0.1541, -0.0232, -0.1059, -0.4230]],
       grad_fn=<SelectBackward0>)
