# Checking Hidden State Length Correspondence

**|| Jonty Sinai ||** 02-05-2019

One thing we want to do is ensure that hidden states have been propagated correctly for each sequence in the batch. To see why, consider the following (sorted) batch of input sequences:

In [2]:
import random

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

import numpy as np


random.seed(1901)
np.random.seed(1901)
torch.manual_seed(1901)

<torch._C.Generator at 0x1047b2e70>

In [3]:
test_inputs = [torch.tensor([1, 1, 1, 1, 1]),
               torch.tensor([2, 2, 2, 2]),
               torch.tensor([3, 3, 3]),
               torch.tensor([4, 4]),
               torch.tensor([5, 5])]

test_lengths = torch.tensor([len(seq) for seq in test_inputs], dtype=torch.long)

test_inputs_padded = rnn_utils.pad_sequence(test_inputs)
print(test_inputs_padded)

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 0, 0],
        [1, 2, 0, 0, 0],
        [1, 0, 0, 0, 0]])
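To make the layout of the padded tensor explicit, here is a minimal pure-Python sketch of time-major padding. The helper name `pad_time_major` is hypothetical and only mimics the default (`batch_first=False`) layout produced by `pad_sequence`; it is not how PyTorch implements it.

```python
def pad_time_major(sequences, pad_value=0):
    """Pad variable-length sequences into a max_len x batch_size grid.

    Hypothetical helper for illustration: row t holds time step t of every
    sequence, with pad_value wherever a sequence has already ended.
    """
    max_len = max(len(seq) for seq in sequences)
    return [
        [seq[t] if t < len(seq) else pad_value for seq in sequences]
        for t in range(max_len)
    ]


batch = [[1, 1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3], [4, 4], [5, 5]]
padded = pad_time_major(batch)

for row in padded:
    print(row)
```

Each column is one sequence and each row is one time step, which is why the zeros collect in the bottom-right corner.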


In [4]:
test_embedding = nn.Embedding(num_embeddings=6, embedding_dim=3, padding_idx=0)

test_rnn = nn.RNN(input_size=3, hidden_size=2)

test_h0 = torch.randn(1, 1, 2).repeat(1, 5, 1)

print(test_h0)  # same h_0 for every index in batch

tensor([[[-1.2594,  0.6924],
         [-1.2594,  0.6924],
         [-1.2594,  0.6924],
         [-1.2594,  0.6924],
         [-1.2594,  0.6924]]])


We'll now compare the final hidden states from two computations: first on the entire batch using packed sequences, then on each sequence one at a time.

### Batched Computation

In [5]:
test_embedded = test_embedding(test_inputs_padded)

print(test_embedded)

tensor([[[ 0.2400, -0.6511, -0.1632],
         [-0.2048,  0.2239,  0.7058],
         [-0.4830,  2.1294, -0.9805],
         [ 0.3202,  1.0725, -0.0869],
         [-1.1208,  2.4741,  0.7153]],

        [[ 0.2400, -0.6511, -0.1632],
         [-0.2048,  0.2239,  0.7058],
         [-0.4830,  2.1294, -0.9805],
         [ 0.3202,  1.0725, -0.0869],
         [-1.1208,  2.4741,  0.7153]],

        [[ 0.2400, -0.6511, -0.1632],
         [-0.2048,  0.2239,  0.7058],
         [-0.4830,  2.1294, -0.9805],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.2400, -0.6511, -0.1632],
         [-0.2048,  0.2239,  0.7058],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.2400, -0.6511, -0.1632],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]], grad_fn=<EmbeddingBackward>)


In [6]:
test_packed_input = rnn_utils.pack_padded_sequence(test_embedded, test_lengths)

test_output_packed, test_hn_batch = test_rnn(test_packed_input, test_h0)

print(test_hn_batch)

tensor([[[-0.0494, -0.3800],
         [-0.1983, -0.3963],
         [-0.6183, -0.8498],
         [-0.7652, -0.2736],
         [-0.5619, -0.8457]]], grad_fn=<StackBackward>)
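One way to picture what packing does: at each time step the RNN only processes the sequences that are still "alive". A packed sequence records this as a list of per-step batch sizes (exposed as the `batch_sizes` attribute of `PackedSequence`). The sketch below recomputes that list in pure Python; the function name `packed_batch_sizes` is hypothetical and the code assumes the lengths are sorted in descending order, as they are here.

```python
def packed_batch_sizes(lengths):
    """Number of sequences still active at each time step.

    Hypothetical helper: assumes `lengths` is sorted in descending order.
    """
    return [sum(1 for n in lengths if n > t) for t in range(max(lengths))]


lengths = [5, 4, 3, 2, 2]
sizes = packed_batch_sizes(lengths)

print(sizes)       # [5, 5, 3, 2, 1]
print(sum(sizes))  # 16, the total number of non-padding time steps
```

The hidden state of a sequence stops being updated as soon as it drops out of the active set, which is why the final hidden states above differ per sequence.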


### Sample-based Computation

Now we would like to confirm that we get the same hidden states when we pull the $j^{th}$ sequence out of the batch (by indexing the second dimension) and run it through the RNN on its own. We have the length of each sequence, so we can ignore the padding by only computing up to that length.

Let's see how we can grab one sequence from the embeddings which have dimension `max_seq_length x batch_size x embedding_dim`:

In [7]:
test_embedded[:, 2, :]  # all time steps from the third sequence

tensor([[-0.4830,  2.1294, -0.9805],
        [-0.4830,  2.1294, -0.9805],
        [-0.4830,  2.1294, -0.9805],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<SliceBackward>)

So far so good: the three non-padding rows are identical, matching the sequence `[3, 3, 3]`, and the remaining two rows are zero padding. Let's ensure that it has the right shape for the RNN:

In [8]:
print(test_embedded[:, 0, :].size())

torch.Size([5, 3])


We will need to view this tensor as a `seq_len x 1 x embedding_dim` sequence (i.e. a batch size of 1):

In [10]:
test_embedded[:, 2, :].view(5, 1, 3)

tensor([[[-0.4830,  2.1294, -0.9805]],

        [[-0.4830,  2.1294, -0.9805]],

        [[-0.4830,  2.1294, -0.9805]],

        [[ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000]]], grad_fn=<ViewBackward>)
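As an aside, the hard-coded `view(5, 1, 3)` can be avoided. The sketch below, which uses a stand-in tensor rather than the real `test_embedded`, suggests that `unsqueeze(1)` or a length-1 slice `2:3` give the same result without spelling out the sizes.

```python
import torch

# Stand-in for test_embedded: max_seq_length x batch_size x embedding_dim
embedded = torch.arange(5 * 5 * 3, dtype=torch.float32).view(5, 5, 3)

one_seq = embedded[:, 2, :]           # shape: 5 x 3

as_view = one_seq.view(5, 1, 3)       # explicit sizes, as in the cell above
as_unsqueeze = one_seq.unsqueeze(1)   # insert a batch dimension of size 1
as_slice = embedded[:, 2:3, :]        # slicing keeps the batch dimension

same = torch.equal(as_view, as_unsqueeze) and torch.equal(as_view, as_slice)
print(tuple(as_unsqueeze.shape), same)
```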

Remember that the hidden state is initialised as a `1 x batch_size x hidden_size` tensor with the same hidden state across the batch dimension. Let's grab this hidden state, which we will use when computing on one sequence at a time.

In [11]:
test_h0_single = test_h0[0, 0, :].view(1, 1, 2)

print(test_h0_single)

tensor([[[-1.2594,  0.6924]]])


Now loop through the batch:

In [13]:
test_outputs = []
test_hn_states = []

for j in range(5):
    seq_length = test_lengths[j]
    seq_embedded = test_embedded[:, j, :].view(5, 1, 3)[:seq_length]
    
    out_j, h_t_j = test_rnn(seq_embedded, test_h0_single)
    
    test_outputs.append(out_j)
    test_hn_states.append(h_t_j)

In [14]:
print('Unbatched:')
for j in range(5):
    print(test_hn_states[j])

print('\nBatched:')
print(test_hn_batch)

Unbatched:
tensor([[[-0.0494, -0.3800]]], grad_fn=<StackBackward>)
tensor([[[-0.1983, -0.3963]]], grad_fn=<StackBackward>)
tensor([[[-0.6183, -0.8498]]], grad_fn=<StackBackward>)
tensor([[[-0.7652, -0.2736]]], grad_fn=<StackBackward>)
tensor([[[-0.5619, -0.8457]]], grad_fn=<StackBackward>)

Batched:
tensor([[[-0.0494, -0.3800],
         [-0.1983, -0.3963],
         [-0.6183, -0.8498],
         [-0.7652, -0.2736],
         [-0.5619, -0.8457]]], grad_fn=<StackBackward>)


As we can see, the two sets of final hidden states are identical: with packed sequences, each hidden state is computed only up to the true length of its sequence.
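Comparing printed values by eye does not scale, so the same check can be automated with `torch.allclose`. The following is a self-contained sketch that rebuilds a comparable setup; the seed and variable names are chosen for this example, so the numbers will not match the ones printed above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

torch.manual_seed(0)  # arbitrary seed for this sketch, not the notebook's

lengths = torch.tensor([5, 4, 3, 2, 2])
sequences = [torch.full((n,), i + 1, dtype=torch.long)
             for i, n in enumerate(lengths.tolist())]
padded = rnn_utils.pad_sequence(sequences)  # max_len x batch

embedding = nn.Embedding(num_embeddings=6, embedding_dim=3, padding_idx=0)
rnn = nn.RNN(input_size=3, hidden_size=2)
h0 = torch.randn(1, 1, 2).repeat(1, 5, 1)  # same h_0 for every sequence

with torch.no_grad():
    embedded = embedding(padded)

    # Batched: pack, then run the whole batch at once.
    packed = rnn_utils.pack_padded_sequence(embedded, lengths)
    _, hn_batch = rnn(packed, h0)

    # Unbatched: run each sequence alone, truncated to its true length.
    hn_single = []
    for j, n in enumerate(lengths.tolist()):
        seq = embedded[:n, j, :].unsqueeze(1)  # n x 1 x embedding_dim
        _, hn_j = rnn(seq, h0[:, j:j + 1, :])
        hn_single.append(hn_j)
    hn_single = torch.cat(hn_single, dim=1)    # 1 x batch x hidden

matches = torch.allclose(hn_batch, hn_single, atol=1e-5)
print(matches)
```

A small tolerance is used because the batched and per-sequence code paths may differ by floating-point rounding.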

Similarly for the outputs:

In [24]:
test_output_unpacked, _ = rnn_utils.pad_packed_sequence(test_output_packed)

for j in range(5):
    print(f'\nSample {j}:\n')
    print('Unbatched:')
    print(test_outputs[j][:, 0, :])

    print('\nBatched:')
    print(test_output_unpacked[:, j, :])


Sample 0:

Unbatched:
tensor([[ 0.5373, -0.7590],
        [-0.5505,  0.0691],
        [ 0.1044, -0.4968],
        [-0.3366, -0.1353],
        [-0.0494, -0.3800]], grad_fn=<SliceBackward>)

Batched:
tensor([[ 0.5373, -0.7590],
        [-0.5505,  0.0691],
        [ 0.1044, -0.4968],
        [-0.3366, -0.1353],
        [-0.0494, -0.3800]], grad_fn=<SliceBackward>)

Sample 1:

Unbatched:
tensor([[ 0.6832, -0.8781],
        [-0.4492, -0.2140],
        [ 0.2100, -0.6302],
        [-0.1983, -0.3963]], grad_fn=<SliceBackward>)

Batched:
tensor([[ 0.6832, -0.8781],
        [-0.4492, -0.2140],
        [ 0.2100, -0.6302],
        [-0.1983, -0.3963],
        [ 0.0000,  0.0000]], grad_fn=<SliceBackward>)

Sample 2:

Unbatched:
tensor([[-0.0927, -0.9743],
        [-0.7976, -0.7854],
        [-0.6183, -0.8498]], grad_fn=<SliceBackward>)

Batched:
tensor([[-0.0927, -0.9743],
        [-0.7976, -0.7854],
        [-0.6183, -0.8498],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]], grad_fn=<SliceBackward>)

As we can see, the outputs have been computed correctly and the padding has been ignored: every time step beyond a sequence's length is filled with zeros.
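The zero-filling can also be checked programmatically by building a padding mask from the lengths that `pad_packed_sequence` returns. This is a sketch on a random stand-in input (the names `pad_mask` and `padding_is_zero` are introduced here for illustration), not a rerun of the cells above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

torch.manual_seed(0)  # arbitrary seed for this sketch

lengths = torch.tensor([5, 4, 3, 2, 2])
embedded = torch.randn(5, 5, 3)  # stand-in: max_len x batch x embedding_dim
rnn = nn.RNN(input_size=3, hidden_size=2)

with torch.no_grad():
    packed = rnn_utils.pack_padded_sequence(embedded, lengths)
    packed_out, _ = rnn(packed)  # h_0 defaults to zeros

out_unpacked, out_lengths = rnn_utils.pad_packed_sequence(packed_out)

# pad_mask[t, j] is True when time step t lies beyond the length of sequence j.
steps = torch.arange(out_unpacked.size(0)).unsqueeze(1)  # max_len x 1
pad_mask = steps >= out_lengths.unsqueeze(0)             # max_len x batch

padding_is_zero = bool((out_unpacked[pad_mask] == 0).all())
print(out_lengths.tolist(), int(pad_mask.sum()), padding_is_zero)
```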

## Garbage In, Garbage Out

Finally, let's see what happens when we don't pay attention to sequence length and just pass the padded tensor:

In [25]:
output_padded, hn_padded = test_rnn(test_embedded, test_h0)

In [29]:
print('Packed:')
for j in range(5):
    print(test_hn_states[j])

print('\nUnpacked: Mismatch')
print(hn_padded)

Packed:
tensor([[[-0.0494, -0.3800]]], grad_fn=<StackBackward>)
tensor([[[-0.1983, -0.3963]]], grad_fn=<StackBackward>)
tensor([[[-0.6183, -0.8498]]], grad_fn=<StackBackward>)
tensor([[[-0.7652, -0.2736]]], grad_fn=<StackBackward>)
tensor([[[-0.5619, -0.8457]]], grad_fn=<StackBackward>)

Unpacked: Mismatch
tensor([[[-0.0494, -0.3800],
         [-0.2015, -0.4695],
         [-0.1950, -0.5119],
         [-0.1184, -0.5026],
         [-0.2403, -0.4121]]], grad_fn=<StackBackward>)
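Only the first sequence, which spans the full padded length, keeps its hidden state; the others have been pushed through extra padding steps. If packing is not an option, one possible workaround is to read the output at each sequence's last valid time step. The sketch below assumes a single-layer, unidirectional RNN (where the output at step `t` is the hidden state at step `t`) and uses a random stand-in input; it is an illustration rather than a general recipe.

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

torch.manual_seed(0)  # arbitrary seed for this sketch

lengths = torch.tensor([5, 4, 3, 2, 2])
embedded = torch.randn(5, 5, 3)  # stand-in: max_len x batch x embedding_dim
rnn = nn.RNN(input_size=3, hidden_size=2)
h0 = torch.zeros(1, 5, 2)

with torch.no_grad():
    # Padded input: h_n is taken after the last time step, padding included.
    out_padded, hn_padded = rnn(embedded, h0)

    # Packed input: h_n is taken at each sequence's true last step.
    packed = rnn_utils.pack_padded_sequence(embedded, lengths)
    _, hn_packed = rnn(packed, h0)

# Pick out_padded[lengths[j] - 1, j] for every sequence j.
batch_index = torch.arange(lengths.size(0))
last_valid = out_padded[lengths - 1, batch_index]  # batch x hidden

recovered = torch.allclose(last_valid, hn_packed[0], atol=1e-5)
print(recovered)
```

This works because the outputs at valid time steps never depend on later inputs; only `h_n` and the outputs at padded positions are contaminated.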


In [30]:
for j in range(5):
    print(f'\nSample {j}:\n')
    print('Packed:')
    print(test_output_unpacked[:, j, :])

    print('\nUnpacked: Mismatch')
    print(output_padded[:, j, :])


Sample 0:

Packed:
tensor([[ 0.5373, -0.7590],
        [-0.5505,  0.0691],
        [ 0.1044, -0.4968],
        [-0.3366, -0.1353],
        [-0.0494, -0.3800]], grad_fn=<SliceBackward>)

Unpacked: Mismatch
tensor([[ 0.5373, -0.7590],
        [-0.5505,  0.0691],
        [ 0.1044, -0.4968],
        [-0.3366, -0.1353],
        [-0.0494, -0.3800]], grad_fn=<SliceBackward>)

Sample 1:

Packed:
tensor([[ 0.6832, -0.8781],
        [-0.4492, -0.2140],
        [ 0.2100, -0.6302],
        [-0.1983, -0.3963],
        [ 0.0000,  0.0000]], grad_fn=<SliceBackward>)

Unpacked: Mismatch
tensor([[ 0.6832, -0.8781],
        [-0.4492, -0.2140],
        [ 0.2100, -0.6302],
        [-0.1983, -0.3963],
        [-0.2015, -0.4695]], grad_fn=<SliceBackward>)

Sample 2:

Packed:
tensor([[-0.0927, -0.9743],
        [-0.7976, -0.7854],
        [-0.6183, -0.8498],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]], grad_fn=<SliceBackward>)

Unpacked: Mismatch
tensor([[-0.0927, -0.9743],
        [-0.7976, -0.7