In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import sys
sys.path.append('..')

from nmt.datasets import Vocab, batch_iter
from nmt.networks import CharEmbedding, Encoder

from typing import List, Tuple

### Setting up everything we need before introducing attention

In [3]:
## Sample data: four short source sentences, tokenised on whitespace.
sentences_words_src = [
    'Human: What do we want?'.split(),
    'Computer: Natural language processing!'.split(),
    'Human: When do we want it?'.split(),
    'Computer: When do we want what?'.split(),
]

## Targets are the same sentences wrapped in start/end markers (fresh list objects).
sentences_words_tgt = [['<s>', *words, '</s>'] for words in sentences_words_src]

In [4]:
## Setting up vocab
## Builds source and target vocabularies from the sample corpora. Later cells use
## `vocab.src` / `vocab.tgt`; the log below reports token and character table sizes.
vocab = Vocab.build(sentences_words_src, sentences_words_tgt)

Initializing source vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]
Initializing target vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]


In [5]:
## Generating a batch
## batch_size is the corpus size, so the generator yields one batch holding every pair.
## With a single batch, shuffling only makes the printed order vary between runs, so it
## is disabled to keep Restart & Run All reproducible.
data = list(zip(sentences_words_src, sentences_words_tgt))
data_generator = batch_iter(
    data=data,
    batch_size=len(data),
    shuffle=False
)
batch_src, batch_tgt = next(data_generator)
print(batch_src)

[['Human:', 'When', 'do', 'we', 'want', 'it?'], ['Computer:', 'When', 'do', 'we', 'want', 'what?'], ['Human:', 'What', 'do', 'we', 'want?'], ['Computer:', 'Natural', 'language', 'processing!']]


In [6]:
## Getting source lengths for encoder: number of words in each batch sentence
source_length = list(map(len, batch_src))
print(source_length)

[6, 6, 5, 4]


In [7]:
## Preparing input and output tensors at character level (tokens=False).
## Per the sizes printed below the layout looks like (max sentence len, batch, 21);
## NOTE(review): presumably 21 is a fixed max word length - confirm in Vocab.to_tensor.
char_tensors_src, char_tensors_tgt = (
    vocab.src.to_tensor(batch_src, tokens=False),
    vocab.tgt.to_tensor(batch_tgt, tokens=False),
)
print(f"src char tensor size = {char_tensors_src.size()}; tgt char tensor size = {char_tensors_tgt.size()}")

src char tensor size = torch.Size([6, 4, 21]); tgt char tensor size = torch.Size([8, 4, 21])


In [8]:
## Defining Encoder over the source character vocabulary
encoder = Encoder(
    hidden_size=1024,
    embedding_dim=300,
    num_embeddings=vocab.src.length(tokens=False),  # size of the source char vocab
    char_padding_idx=vocab.src.pad_char_idx,
)

In [9]:
## Getting encoder output: per-step outputs plus the final (hidden, cell) pair
(char_enc_hidden, (char_hidden, char_cell)) = encoder(char_tensors_src, source_length)
tuple(t.shape for t in (char_enc_hidden, char_hidden, char_cell))

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]), torch.Size([4, 1024]))

Note: the encoder output has 6 timesteps, the length of the longest source sentence in the batch.

In [11]:
## Component of decoder - Target embedding layer (character-level, 300-d per word)
target_embedding = CharEmbedding(
    embedding_dim=300,
    char_embedding_dim=50,
    char_padding_idx=vocab.tgt.pad_char_idx,
    num_embeddings=vocab.tgt.length(tokens=False),
)

In [23]:
## Component of decoder - decoder LSTM Cell.
## Its input is the 300-d target embedding concatenated with the 1024-d o_prev.
sample_decoder_cell = nn.LSTMCell(300 + 1024, 1024, bias=True)

Decoding happens one time step at a time, so we start with the first target time step, `y_0` (the `<s>` tokens).

In [16]:
# First target time step, keeping the leading time axis of length 1
y_0 = char_tensors_tgt[0:1]

In [17]:
# Embed the first target step: per the shapes printed below, (1, batch, 21) char ids
# become (1, batch, 300) word vectors.
print("Target tensor shape:", y_0.shape)
embedded_y_0 = target_embedding(y_0)
print("Target embedded shape:", embedded_y_0.shape)

Target tensor shape: torch.Size([1, 4, 21])
Target embedded shape: torch.Size([1, 4, 300])


In [18]:
# o_prev: previous-step output vector concatenated into the decoder input; zeros before
# the first step. Batch size, device and dtype come from the embedded target step
# instead of a hard-coded 4 / "cpu", so the concat below keeps working if the
# batch size or device changes.
o_prev = torch.zeros(
    embedded_y_0.size(1), 1024,
    device=embedded_y_0.device, dtype=embedded_y_0.dtype,
)

In [20]:
# Decoder input for this step: embedded target step (batch, 300) ++ o_prev (batch, 1024)
ybar_t = torch.cat((embedded_y_0[0], o_prev), dim=-1)
print(ybar_t.shape)

torch.Size([4, 1324])


In [24]:
# One decoder step; the encoder's final (hidden, cell) pair is the initial decoder state.
dec_hidden, dec_cell = sample_decoder_cell(
    ybar_t,
    (char_hidden, char_cell),
)

In [25]:
(dec_hidden.size(), dec_cell.size())

(torch.Size([4, 1024]), torch.Size([4, 1024]))

We have encoder output corresponding to _all_ timesteps and decoder output corresponding to _one_ timestep.

# General Attention

### What is Attention? 

There are countless explanations of attention. Since attention is essentially about
information retention and retrieval, my favorite interpretation is this: attention maps
a query to a weighted combination of the values in a set, where each value's weight
reflects how relevant it is to the query.  
  
The following image shows the computation of attention and context vectors using dot product attention.

<img src="images/attention.png" />

  
### What is general attention? 

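Note that our decoder hidden states (the queries) have size `hidden_size` = 1024, but the encoder outputs have size 2 * `hidden_size` = 2048, so a plain dot product between them is not defined. Such a mismatch is common (for example, with a bidirectional encoder). General attention therefore learns a linear projection $W$ that maps the encoder outputs to the decoder's size, giving the score $\text{score}(h_i, s_j) = h_i^\top W s_j$.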


In [26]:
(char_enc_hidden.size(), dec_hidden.size())

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]))

Here, the encoder output $s$ = `char_enc_hidden` is the *set* we attend over,  
and the decoder hidden state $h_i$ = `dec_hidden` is the *query*.

In [27]:
# W: bias-free projection of 2048-d encoder outputs into the 1024-d decoder space
attention_projection = nn.Linear(2048, 1024, bias=False)

In [28]:
# Project every encoder step, W s_j: (batch, src_len, 2048) -> (batch, src_len, 1024)
enc_projection = attention_projection(char_enc_hidden)

In [29]:
enc_projection.size()

torch.Size([4, 6, 1024])

In [30]:
# Trailing axis so the query can be used in bmm: (batch, 1024) -> (batch, 1024, 1)
dec_hidden_unsqueezed = dec_hidden[..., None]

In [31]:
dec_hidden_unsqueezed.size()

torch.Size([4, 1024, 1])

In [46]:
# https://pytorch.org/docs/stable/generated/torch.bmm.html
# (batch, src_len, 1024) @ (batch, 1024, 1) -> (batch, src_len, 1): one score per source step
score = torch.bmm(enc_projection, dec_hidden_unsqueezed)

In [33]:
score.size()

torch.Size([4, 6, 1])

In [34]:
# (batch, src_len, 1) -> (batch, src_len). Guarded on rank so re-running this cell is a
# no-op instead of an IndexError from squeezing a dimension that no longer exists.
if score.dim() == 3:
    score = score.squeeze(dim=2)

In [37]:
score.size()

torch.Size([4, 6])

In [35]:
# Normalise the scores over source positions so each batch row sums to 1
attention_weights = torch.softmax(score, dim=1)

In [39]:
attention_weights.size()

torch.Size([4, 6])

In [40]:
# Add a query axis for bmm: (batch, src_len) -> (batch, 1, src_len). Guarded on rank so
# re-running this cell does not keep stacking axes (which would break the bmm below).
if attention_weights.dim() == 2:
    attention_weights = attention_weights.unsqueeze(dim=1)

In [41]:
attention_weights.size()

torch.Size([4, 1, 6])

https://pytorch.org/docs/stable/generated/torch.bmm.html
> If input is a (b×n×m) tensor, mat2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.

In [42]:
# (batch, 1, src_len) @ (batch, src_len, 2048) -> (batch, 1, 2048): weighted sum of the
# unprojected encoder outputs
context_vector = torch.bmm(attention_weights, char_enc_hidden)

In [43]:
context_vector.size()

torch.Size([4, 1, 2048])

In [44]:
# Drop the singleton query axis: (batch, 1, 2048) -> (batch, 2048)
context_vector = torch.squeeze(context_vector, 1)

In [45]:
context_vector.size()

torch.Size([4, 2048])

## Putting it together

In [47]:
class Attention(nn.Module):
    """General (multiplicative) attention over encoder outputs.

    Encoder outputs are projected by a bias-free linear layer from ``in_features``
    to ``out_features`` so they can be dotted with the decoder hidden state. The
    scores are softmax-normalised over source positions and used to average the
    *unprojected* encoder outputs into a context vector.

    Args:
        in_features: size of each encoder output vector (2048 in this notebook).
        out_features: size of the decoder hidden state (1024 in this notebook).
    """

    def __init__(self, in_features: int, out_features: int):
        super(Attention, self).__init__()
        self.linear = nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=False
        )

    def forward(self,
                enc_hidden: torch.Tensor,
                dec_hidden_t: torch.Tensor,
                enc_masks: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``enc_hidden`` using the decoder state ``dec_hidden_t``.

        Args:
            enc_hidden: encoder outputs, (batch, src_len, in_features).
            dec_hidden_t: decoder hidden state at step t, (batch, out_features).
            enc_masks: optional mask broadcastable to (batch, src_len); non-zero /
                True entries mark padding positions, which get zero attention weight.
                If every position of a row is masked, that row's weights are NaN.

        Returns:
            A tuple ``(attention_weights, context_vector)``:
            attention_weights is (batch, 1, src_len), with each row summing to 1;
            context_vector is (batch, in_features), the weighted sum of ``enc_hidden``.
        """
        enc_projection = self.linear(enc_hidden)                      # (b, src_len, out)
        dec_hidden_unsqueezed_t = dec_hidden_t.unsqueeze(dim=2)       # (b, out, 1)
        score = enc_projection.bmm(dec_hidden_unsqueezed_t).squeeze(dim=2)  # (b, src_len)

        if enc_masks is not None:
            # Out-of-place masked_fill keeps autograd aware of the edit (writing through
            # `.data` hides it), and masked_fill needs a bool mask (`.byte()` masks are
            # deprecated / rejected by current PyTorch).
            score = score.masked_fill(enc_masks.bool(), float('-inf'))

        # exp(-inf) == 0, so masked positions receive exactly zero weight.
        attention_weights = F.softmax(score, dim=1).unsqueeze(dim=1)  # (b, 1, src_len)

        context_vector = attention_weights.bmm(enc_hidden).squeeze(dim=1)  # (b, in)

        return attention_weights, context_vector

### Wait, what is that mask thing? 

It's the encoder mask.  

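Because sentences are processed in batches, shorter sentences are padded with pad tokens up to the length of the longest sentence in the batch. We don't want attention to focus on those padding positions, so we set their scores to $-\infty$. Since $\exp(-\infty) = 0$, the softmax then assigns them zero weight. (If every position in a row were masked, the softmax would produce NaN, but that cannot happen with real, non-empty sentences.)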