# Sequence Tensor Batching with PyTorch

**|| Jonty Sinai ||** 30-04-2019

In [2]:
import random

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

import numpy as np

# One seed shared by every RNG this notebook imports, so that a fresh-kernel
# "Run All" is reproducible. Change it in one place only.
SEED = 1901

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # last expression: the cell displays the returned torch Generator

<torch._C.Generator at 0x7f5098118670>

### Test Batch of Input-Target Pairs

In [6]:
# Toy batch of (input_tensor, target_tensor) pairs in 1-to-1 correspondence.
# Every pair is filled with its own id (1..5) so that a sequence can be tracked
# by eye after sorting/padding; input and target lengths deliberately differ.
batch = [
    (torch.tensor(source_values), torch.tensor(target_values))
    for source_values, target_values in (
        ([1, 1, 1], [1, 1, 1, 1]),
        ([2, 2, 2, 2, 2], [2, 2, 2, 2]),
        ([3, 3], [3, 3, 3]),
        ([4, 4, 4], [4, 4, 4]),
        ([5, 5, 5, 5], [5, 5, 5, 5, 5]),
    )
]

batch

[(tensor([1, 1, 1]), tensor([1, 1, 1, 1])),
 (tensor([2, 2, 2, 2, 2]), tensor([2, 2, 2, 2])),
 (tensor([3, 3]), tensor([3, 3, 3])),
 (tensor([4, 4, 4]), tensor([4, 4, 4])),
 (tensor([5, 5, 5, 5]), tensor([5, 5, 5, 5, 5]))]

In [8]:
# Split the (input, target) pairs into two parallel tuples of sequences.
input_ = tuple(pair[0] for pair in batch)
target = tuple(pair[1] for pair in batch)

# Zero-pad each group to its own longest sequence. The result is sequence-first:
# one column per sequence, one row per time step.
input_orig = rnn_utils.pad_sequence(input_)
target_orig = rnn_utils.pad_sequence(target)

print('Padded input: original\n',
      input_orig,
      '\nPadded target: original\n',
      target_orig,
      sep='\n')

Padded input: original

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 0, 4, 5],
        [0, 2, 0, 0, 5],
        [0, 2, 0, 0, 0]])

Padded target: original

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 0, 0, 5],
        [0, 0, 0, 0, 5]])


### Option 1: Sort Input Independently and Resort Later

In [11]:
# True (unpadded) length of every input sequence, in original batch order.
lengths = torch.tensor(list(map(len, input_)), dtype=torch.long)

print(lengths)

tensor([3, 5, 2, 3, 4])


In [12]:
# Longest-first ordering of the batch: `sort_idx[k]` is the original position of
# the k-th longest input.
# NOTE(review): two inputs share length 3; the relative order of such ties is
# presumably not guaranteed by torch.sort -- confirm if the exact order matters.
sorted_lengths, sort_idx = torch.sort(lengths, descending=True)

print(sort_idx)

tensor([1, 4, 0, 3, 2])


In [15]:
# Lengths are sorted in descending order, so the first entry is the longest one.
max_length = sorted_lengths[0].item()

# The padded batch is sequence-first (time steps x sequences), so reordering the
# sequences means reordering columns. gather needs an index with the same shape
# as its output, hence the permutation is repeated once per time step.
input_sorted = torch.gather(input_orig, 1, sort_idx.repeat(max_length, 1))

print('Padded input: sorted\n', input_sorted, sep='\n')

Padded input: sorted

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 0],
        [2, 5, 0, 0, 0],
        [2, 0, 0, 0, 0]])


In [19]:
# Pack the longest-first padded batch together with its matching sorted lengths.
input_packed = rnn_utils.pack_padded_sequence(input_sorted, lengths=sorted_lengths)

print(input_packed)

PackedSequence(data=tensor([2, 5, 1, 4, 3, 2, 5, 1, 4, 3, 2, 5, 1, 4, 2, 5, 2]), batch_sizes=tensor([5, 5, 4, 2, 1]))


In [31]:
# Unpack back to a padded tensor (suppose the packed data has been transformed
# by a recurrent layer in between). pad_packed_sequence returns a pair; only the
# padded tensor is needed, so index it rather than unpacking the rest into `_`,
# which would rebind IPython's "last output" variable in the kernel namespace.
input_unpacked = rnn_utils.pad_packed_sequence(input_packed)[0]

# The unpacked input is still in longest-first order, whereas the padded target
# is in the original batch order -- the columns no longer correspond.
print('Padded input: unpacked\n')
print(input_unpacked)
print('\nPadded target: misaligned\n')
print(target_orig)

Padded input: unpacked

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 0],
        [2, 5, 0, 0, 0],
        [2, 0, 0, 0, 0]])

Padded target: misaligned

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 0, 0, 5],
        [0, 0, 0, 0, 5]])


In [22]:
# Neat trick: `sort_idx` is a permutation (the argsort of the lengths), and the
# argsort of a permutation is its inverse. `orig_idx[i]` is therefore the column
# of the sorted batch that holds original sequence i, which maps each sequence in
# the unpacked batch back to its original location.
# argsort is used directly instead of `_, orig_idx = sort_idx.sort()` so that
# IPython's `_` (last cell output) is not overwritten.
orig_idx = torch.argsort(sort_idx)

print(orig_idx)

tensor([2, 0, 4, 3, 1])


In [30]:
# Same repeat trick as for sorting: apply the inverse permutation to the columns
# at every time step to put the sequences back in their original batch order.
input_restored = torch.gather(input_unpacked, 1, orig_idx.repeat(max_length, 1))

print('Padded input: restored\n',
      input_restored,
      '\nPadded target: aligned\n',
      target_orig,
      sep='\n')

Padded input: restored

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 0, 4, 5],
        [0, 2, 0, 0, 5],
        [0, 2, 0, 0, 0]])

Padded target: aligned

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 0, 0, 5],
        [0, 0, 0, 0, 5]])


In [28]:
# Element-wise check that the sort -> pack -> unpack -> restore round trip
# reproduces the original padded input.
input_restored == input_orig

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]], dtype=torch.uint8)

### Option 2: Just Sort the Input and Targets Together

Since we will process the targets sequentially, we can simply sort the (input, target) pairs together by input length. The decoder won't care which sequence it sees first in the batch, only that the target sequences are correctly aligned with the encoder outputs batchwise.

In [32]:
# Work on a copy so `batch` keeps its original order, then put the pairs in
# longest-input-first order. Each (input, target) pair moves as a unit, and the
# sort is stable, so pairs with equal input length keep their relative order.
sorted_batch = list(batch)
sorted_batch.sort(key=lambda pair: len(pair[0]), reverse=True)
sorted_batch

[(tensor([2, 2, 2, 2, 2]), tensor([2, 2, 2, 2])),
 (tensor([5, 5, 5, 5]), tensor([5, 5, 5, 5, 5])),
 (tensor([1, 1, 1]), tensor([1, 1, 1, 1])),
 (tensor([4, 4, 4]), tensor([4, 4, 4])),
 (tensor([3, 3]), tensor([3, 3, 3]))]

In [33]:
# Split the already-sorted pairs into parallel tuples of inputs and targets.
input_presorted, target_presorted = zip(*sorted_batch)

# Input lengths are already in descending order because the batch was sorted on them.
# dtype is spelled out to match how `lengths` is built for Option 1.
presorted_lengths = torch.tensor([len(seq) for seq in input_presorted], dtype=torch.long)

# Use the `rnn_utils` alias like the rest of the notebook
# (it is the same function as nn.utils.rnn.pad_sequence).
padded_input_presorted = rnn_utils.pad_sequence(input_presorted)
padded_target_presorted = rnn_utils.pad_sequence(target_presorted)

print('Padded input: presorted\n')
print(padded_input_presorted)
print('\nPadded target: presorted\n')
print(padded_target_presorted)

Padded input: presorted

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 0],
        [2, 5, 0, 0, 0],
        [2, 0, 0, 0, 0]])

Padded target: presorted

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 0, 0],
        [0, 5, 0, 0, 0]])


In [36]:
# No gather step is needed here: the padded input was built longest-first,
# so it can be packed directly with its lengths.
packed_input_presorted = rnn_utils.pack_padded_sequence(
    padded_input_presorted,
    lengths=presorted_lengths,
)

print(packed_input_presorted)

PackedSequence(data=tensor([2, 5, 1, 4, 3, 2, 5, 1, 4, 3, 2, 5, 1, 4, 2, 5, 2]), batch_sizes=tensor([5, 5, 4, 2, 1]))


In [40]:
# As far as the recurrent unit is concerned, its input is unchanged:
# unpacking the Option 2 batch gives the same tensor as unpacking the Option 1 batch.
(rnn_utils.pad_packed_sequence(packed_input_presorted)[0]
 == rnn_utils.pad_packed_sequence(input_packed)[0])

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]], dtype=torch.uint8)

In [41]:
# Unpack back to a padded tensor (suppose the packed data has been transformed
# by a recurrent layer in between). Only the padded tensor is needed, so index
# the returned pair instead of unpacking the rest into `_`, which would rebind
# IPython's "last output" variable in the kernel namespace.
unpacked_input_presorted = rnn_utils.pad_packed_sequence(packed_input_presorted)[0]

# Inputs and targets were sorted together, so column k of both tensors belongs
# to the same pair and no restoring permutation is required.
print('Presorted input: unpacked\n')
print(unpacked_input_presorted)
print('\nPresorted target: aligned\n')
print(padded_target_presorted)

Presorted input: unpacked

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 0],
        [2, 5, 0, 0, 0],
        [2, 0, 0, 0, 0]])

Presorted target: aligned

tensor([[2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 4, 3],
        [2, 5, 1, 0, 0],
        [0, 5, 0, 0, 0]])
