# Imports

In [1]:
import torch
from torch.nn.utils.rnn import (pad_packed_sequence,
                                pack_padded_sequence,
                                pad_sequence,
                                pack_sequence)

In [2]:
torch.__version__

'1.0.0'

# Pack/Padded Sequences Experimentation

In [3]:
t = torch.tensor([[1, 0, 0, ], [1, 2, 3], [5, 2, 0]])
t

tensor([[1, 0, 0],
        [1, 2, 3],
        [5, 2, 0]])

In [4]:
(t > 0).sum(dim=1)

tensor([1, 3, 2])

In [5]:
((t > 0).sum(dim=1) - 1)

tensor([0, 2, 1])

In [6]:
sorted_lengths, sorted_idxs= (t > 0).sum(dim=1).sort(descending=True)
sorted_lengths, sorted_idxs

(tensor([3, 2, 1]), tensor([1, 2, 0]))

In [7]:
sorted_t = t[sorted_idxs]
sorted_t

tensor([[1, 2, 3],
        [5, 2, 0],
        [1, 0, 0]])

In [8]:
packed = pack_padded_sequence(sorted_t, lengths=sorted_lengths, batch_first=True)
packed

PackedSequence(data=tensor([1, 5, 1, 2, 2, 3]), batch_sizes=tensor([3, 2, 1]))

In [9]:
padded = pad_packed_sequence(packed, batch_first=True)
padded

(tensor([[1, 2, 3],
         [5, 2, 0],
         [1, 0, 0]]), tensor([3, 2, 1]))

In [10]:
pad_sequence([torch.tensor([1]), torch.tensor([1, 2, 3])], batch_first=True)

tensor([[1, 0, 0],
        [1, 2, 3]])

In [11]:
padded[0][sorted_idxs.sort(descending=False)[1]]

tensor([[1, 0, 0],
        [1, 2, 3],
        [5, 2, 0]])

In [12]:
def pack_padded_collate(inputs):
    
    sorted_lengths, sorted_idxs = (inputs > 0).sum(dim=1).sort(descending=True)
    sorted_inputs = inputs[sorted_idxs]
    orig_idxs = sorted_idxs.sort(descending=False)[1]
    
    return sorted_inputs, sorted_lengths, orig_idxs

In [13]:
sorted_t, sorted_lengths, orig_idxs = pack_padded_collate(t)
sorted_t

tensor([[1, 2, 3],
        [5, 2, 0],
        [1, 0, 0]])

In [14]:
(sorted_t[orig_idxs] == t).all()

tensor(1, dtype=torch.uint8)

In [15]:
arr = [(1, 2), (3, 4)]
arr

[(1, 2), (3, 4)]

In [16]:
[a[0] for a in arr]

[1, 3]

In [17]:
list(zip(*[(a, b) for (a, b) in arr]))

[(1, 3), (2, 4)]

In [18]:
import itertools

In [19]:
list(itertools.chain.from_iterable(arr))

[1, 2, 3, 4]

In [31]:
packed.data.shape == torch.Size([6])

True

In [34]:
packed.

tensor([3, 2, 1])