# Bahdanau Attention (Additive Attention)

One of the motivations behind the Bahdanau attention approach was to remove the fixed-length context vector of the basic encoder-decoder approach. This bottleneck makes the basic encoder-decoder model underperform on long sentences. The encoder's final hidden state has to carry the memory of every previous element of the source sequence, and it forms the only fixed-dimension context vector the decoder sees. The Bahdanau attention approach works differently:

- First, we initialize the decoder state from the final states of the encoder, as usual.
- Then, at each decoding time step:
    - We use all of the encoder's hidden states and the previous decoder hidden state to calculate a context vector by applying an attention mechanism.
    - We concatenate the context vector with the (embedded) previous decoder output to create the input to the decoder.
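The steps above can be sketched in NumPy for a single sentence. This is a minimal illustration, not the notebook's model. The toy sizes, the random parameters `W`, `U` and `v`, the random stand-in embeddings and the stand-in recurrent update are all hypothetical. The sketch shows that a new context vector is computed at every decoding step from all encoder states, rather than one fixed vector being reused.

```python
import numpy as np

rng = np.random.default_rng(0)
src_len, enc_dim, dec_dim, emb_dim = 5, 4, 3, 2  # hypothetical toy sizes

enc_states = rng.normal(size=(src_len, enc_dim))  # all encoder hidden states
W = rng.normal(size=(dec_dim, dec_dim))           # projects the previous decoder state
U = rng.normal(size=(dec_dim, enc_dim))           # projects each encoder state
v = rng.normal(size=dec_dim)                      # turns each energy vector into a score


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def attend(dec_state):
    """Additive attention: one score per source position, softmax, weighted sum."""
    scores = np.tanh(dec_state @ W.T + enc_states @ U.T) @ v  # [src_len]
    weights = softmax(scores)                                 # [src_len], sums to 1
    return weights, weights @ enc_states                      # context: [enc_dim]


dec_state = rng.normal(size=dec_dim)  # stand-in for the state initialized from the encoder
contexts = []
for step in range(3):
    weights, context = attend(dec_state)
    prev_output_emb = rng.normal(size=emb_dim)  # stand-in embedding of the previous output token
    decoder_input = np.concatenate([context, prev_output_emb])  # [enc_dim + emb_dim]
    contexts.append(context)
    # Stand-in for the recurrent update (a GRU/LSTM cell in a real model)
    dec_state = np.tanh(rng.normal(size=(dec_dim, enc_dim + emb_dim)) @ decoder_input)
```

Because `dec_state` changes after every step, the attention weights and the context vector change with it.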

All the preprocessing steps are the same as those used in the seq2seq model, so let's start with them.

In [1]:
import os
import time
import math
import torch
import random
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from typing import Iterable, List
from torch.utils.data import DataLoader
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator as bvfi

## Tokenization and Vocabulary Building

In [2]:
SRC_LANG = 'de'
TGT_LANG = 'en'
specials = {'<UNK>': 0, '<PAD>': 1, '<SOS>': 2, '<EOS>': 3}

tokenizer = dict()
vocab = dict()

Create the source and target language tokenizers. Make sure to install the dependencies first:

```
pip install -U torchdata
pip install -U spacy
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
```

In [3]:
# !pip install -U torchdata
# !pip install -U spacy
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm

In [4]:
tokenizer[SRC_LANG] = get_tokenizer('spacy', language='de_core_news_sm')
tokenizer[TGT_LANG] = get_tokenizer('spacy', language='en_core_web_sm')

In [5]:
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    language_index = {SRC_LANG: 0, TGT_LANG: 1}

    for data_sample in data_iter:
        yield tokenizer[language](data_sample[language_index[language]])

In [6]:
for lang in [SRC_LANG, TGT_LANG]:
    train_iterator, valid_iterator, test_iterator = Multi30k()  # Load the train, valid and test splits
    vocab[lang] = bvfi(yield_tokens(train_iterator, lang), min_freq=1, specials=specials.keys(), special_first=True)

Set the `<UNK>` token index (0 here) as the default index. This index is returned when a token is not found in the vocabulary. If it is not set, querying an out-of-vocabulary token raises a `RuntimeError`.

In [7]:
for lang in [SRC_LANG, TGT_LANG]:
    vocab[lang].set_default_index(specials['<UNK>'])
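A plain-Python analogue shows what the default index changes. Without it, an unknown token raises an error; with it, the token maps to `<UNK>`. The `ToyVocab` class is hypothetical and only mimics this one behavior; it is not torchtext's implementation.

```python
class ToyVocab:
    """Hypothetical stand-in for a torchtext Vocab, for illustration only."""

    def __init__(self, tokens):
        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.default_index = None

    def set_default_index(self, index):
        self.default_index = index

    def __getitem__(self, token):
        if token in self.stoi:
            return self.stoi[token]
        if self.default_index is None:
            raise RuntimeError(f"Token {token!r} not found and default index is not set")
        return self.default_index

    def __call__(self, tokens):
        return [self[tok] for tok in tokens]


toy = ToyVocab(['<UNK>', '<PAD>', '<SOS>', '<EOS>', 'ein', 'hund'])
try:
    toy(['ein', 'katze'])
    raised = False
except RuntimeError:
    raised = True

toy.set_default_index(0)
ids = toy(['ein', 'katze'])  # 'katze' is out of vocabulary, so it maps to index 0
```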

## Encoder

The encoder architecture is the same as in the seq2seq model, with two differences:
- It uses a single-layer RNN (an LSTM here)
- The RNN is bidirectional (forward + backward)

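As in seq2seq, both the forward and backward hidden states are initialized to tensors of zeros (PyTorch's default when no initial state is passed). We get two final hidden states (context vectors), one from the forward RNN and one from the backward RNN. However, the decoder is unidirectional and needs a single initial hidden state. We therefore concatenate the two vectors and pass the result through a linear layer followed by tanh, which maps it to the decoder's hidden size.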

In [8]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embed = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim*2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        embedding = self.dropout(self.embed(src))  # [len(src), batch_size, emb_dim]
        output, (hidden, cell) = self.rnn(embedding)
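        # output = [len(src), batch_size, enc_hid_dim * 2]
        # hidden = cell = [n_layers * n_directions, batch_size, enc_hid_dim]
        # hidden[-2] is the final forward state, hidden[-1] the final backward state;
        # concatenate them and project to the decoder's hidden size
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))  # [batch_size, dec_hid_dim]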
        return output, hidden
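The bridge from the bidirectional encoder to the decoder can be reproduced in plain NumPy. This sketch uses toy, hypothetical sizes. `hidden[-2]` and `hidden[-1]` are the final forward and backward states of the last layer (PyTorch orders the directions this way), and `W` and `b` play the role of `self.fc`.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size, enc_hid_dim, dec_hid_dim = 2, 4, 3  # hypothetical toy sizes

# Final hidden states of a 1-layer bidirectional RNN: [n_layers * 2, batch, enc_hid_dim]
hidden = rng.normal(size=(2, batch_size, enc_hid_dim))
forward_final, backward_final = hidden[-2], hidden[-1]

# Equivalent of nn.Linear(enc_hid_dim * 2, dec_hid_dim): y = x @ W.T + b
W = rng.normal(size=(dec_hid_dim, enc_hid_dim * 2))
b = rng.normal(size=dec_hid_dim)

joined = np.concatenate([forward_final, backward_final], axis=1)  # [batch, enc_hid_dim * 2]
dec_init = np.tanh(joined @ W.T + b)                               # [batch, dec_hid_dim]
```

The linear layer is what lets `enc_hid_dim * 2` and `dec_hid_dim` differ. The tanh keeps the initial decoder state in the same range as a recurrent hidden state.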

## Attention

The attention layer decides which source words the decoder should pay most attention to when predicting the next word. It takes all of the encoder hidden states and only the previous hidden state of the decoder. The result is an attention vector with one element per source token. The softmax ensures that every element lies between 0 and 1 and that the elements sum to 1.
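In equation form, the attention layer and the decoder together compute the following at decoding step $t$. The symbols are introduced here for illustration. $s_{t-1}$ is the previous decoder hidden state, $h_i$ is the encoder output at source position $i$, and $T_x$ is the source length.

$$
e_{t,i} = v^\top \tanh\left(W_{\text{attn}} \begin{bmatrix} s_{t-1} \\ h_i \end{bmatrix} + b_{\text{attn}}\right), \qquad
\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{T_x} \exp(e_{t,j})}, \qquad
c_t = \sum_{i=1}^{T_x} \alpha_{t,i} \, h_i
$$

The symbols map onto the code as follows:
- $W_{\text{attn}}$ and $b_{\text{attn}}$ are the weight and bias of `self.attn`.
- $v^\top$ is the weight of `self.v`.
- $\alpha_{t,i}$ is the attention vector returned by the layer.
- $c_t$ is the weighted sum the decoder forms from it.

The paper writes the score as $v_a^\top \tanh(W_a s_{t-1} + U_a h_i)$. Splitting $W_{\text{attn}}$ column-wise into $[W_a \; U_a]$ gives the same expression, apart from the bias that `nn.Linear` adds.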

We calculate the energy between the encoder hidden states and the previous decoder hidden state. Their shapes do not match, because there are src_len encoder hidden states but only a single decoder hidden state. We therefore repeat the decoder hidden state src_len times and concatenate the two tensors along the feature dimension. The result is passed through a linear layer (`attn`) and a tanh activation. This `energy` measures how well each encoder hidden state matches the previous decoder hidden state.
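The repeat-and-concatenate step can be traced with NumPy arrays of toy, hypothetical sizes. `np.repeat` and `transpose` mirror `unsqueeze(1).repeat(1, src_len, 1)` and `permute(1, 0, 2)`. `W_attn` and `b_attn` are random stand-ins for the weights of `self.attn`.

```python
import numpy as np

rng = np.random.default_rng(2)
batch_size, src_len, enc_hid_dim, dec_hid_dim = 2, 5, 4, 3  # hypothetical toy sizes

dec_hidden = rng.normal(size=(batch_size, dec_hid_dim))  # previous decoder hidden state
encoder_outputs = rng.normal(size=(src_len, batch_size, enc_hid_dim * 2))

# Copy the single decoder state once per source position
repeated = np.repeat(dec_hidden[:, None, :], src_len, axis=1)  # [batch, src_len, dec_hid_dim]
enc = encoder_outputs.transpose(1, 0, 2)                       # [batch, src_len, enc_hid_dim * 2]
joined = np.concatenate([repeated, enc], axis=2)               # [batch, src_len, dec_hid_dim + enc_hid_dim * 2]

# Stand-in for self.attn followed by tanh
W_attn = rng.normal(size=(dec_hid_dim, dec_hid_dim + enc_hid_dim * 2))
b_attn = rng.normal(size=dec_hid_dim)
energy = np.tanh(joined @ W_attn.T + b_attn)                   # [batch, src_len, dec_hid_dim]
```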

The energy is currently a [src_len, dec_hid_dim] tensor for each example in the batch. We want a [src_len] tensor for each example, because the attention should be over the length of the source sentence. We achieve this by multiplying the energy by a [1, dec_hid_dim] tensor `v`.

We can think of `v` as the weights of a weighted sum over the dec_hid_dim features of each energy vector. This sum collapses every energy vector into a single score for its source token. After the softmax, these scores tell us how much we should attend to each token in the source sequence.

The parameters of `v` are initialized randomly but learned with the rest of the model via backpropagation. Note that `v` does not depend on time: the same `v` is used at every decoding time step. We implement `v` as a linear layer without a bias.
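A NumPy sketch with toy, hypothetical sizes and random weights makes two points. First, `v` reduces each energy vector to one score per source token, and the softmax over those scores gives valid attention weights. Second, the single linear layer over the concatenation $[s; h_i]$ equals the paper's form $W_a s + U_a h_i$ (plus bias) once its weight matrix is split column-wise.

```python
import numpy as np

rng = np.random.default_rng(3)
batch_size, src_len, enc2, dec_hid_dim = 2, 5, 8, 3  # enc2 = enc_hid_dim * 2 (toy sizes)

s = rng.normal(size=(batch_size, dec_hid_dim))     # previous decoder hidden state
h = rng.normal(size=(batch_size, src_len, enc2))   # encoder outputs, batch first
W_attn = rng.normal(size=(dec_hid_dim, dec_hid_dim + enc2))
b_attn = rng.normal(size=dec_hid_dim)
v = rng.normal(size=(1, dec_hid_dim))              # weight of nn.Linear(dec_hid_dim, 1, bias=False)

# Notebook version: one linear layer over the concatenation [s; h_i]
joined = np.concatenate([np.repeat(s[:, None, :], src_len, axis=1), h], axis=2)
energy_concat = np.tanh(joined @ W_attn.T + b_attn)  # [batch, src_len, dec_hid_dim]

# Paper-style version: split W_attn column-wise into W_a (for s) and U_a (for h_i)
W_a, U_a = W_attn[:, :dec_hid_dim], W_attn[:, dec_hid_dim:]
energy_split = np.tanh((s @ W_a.T)[:, None, :] + h @ U_a.T + b_attn)

scores = (energy_concat @ v.T).squeeze(2)           # [batch, src_len]: one score per source token
scores -= scores.max(axis=1, keepdims=True)         # subtract the max for numerical stability
attention = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
```

Computing `W_a s` once per decoding step and reusing it for every source position avoids repeating the decoder state. The concatenated form used in the notebook is simpler to write with a single `nn.Linear`.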

In [9]:
class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)
        
    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[0]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # Repeat decoder hidden state src_len times
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # hidden = [batch size, src len, dec hid dim]
        # encoder_outputs = [batch size, src len, enc hid dim * 2]
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch size, src len, dec hid dim]
        attention = self.v(energy).squeeze(2)  # [batch size, src len]
        return F.softmax(attention, dim=1)

## Decoder

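The decoder uses the attention vector `a` as the weights of a weighted sum over the encoder hidden states (`encoder_outputs`). The resulting weighted source vector (`weighted`) is the context vector. It is concatenated with the embedded input token to form the GRU input. The final prediction layer receives the GRU output, the context vector and the embedded input token.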

In [10]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, encoder_outputs): 
        #input = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        input = input.unsqueeze(0)  # [1, batch size]
        embedded = self.dropout(self.embedding(input))  # [1, batch size, emb dim]
        a = self.attention(hidden, encoder_outputs)  # [batch size, src len]
        a = a.unsqueeze(1)  # [batch size, 1, src len]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch size, src len, enc hid dim * 2]
        weighted = torch.bmm(a, encoder_outputs)  # [batch size, 1, enc hid dim * 2]
        weighted = weighted.permute(1, 0, 2)  # [1, batch size, enc hid dim * 2]
        rnn_input = torch.cat((embedded, weighted), dim = 2)  # [1, batch size, (enc hid dim * 2) + emb dim]
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        #output = [seq len, batch size, dec hid dim * n directions]
        #hidden = [n layers * n directions, batch size, dec hid dim]
        assert (output == hidden).all()
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc(torch.cat((output, weighted, embedded), dim = 1))  # [batch size, output dim]
        return prediction, hidden.squeeze(0)
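The weighted sum that the decoder computes with `torch.bmm` can be checked against an explicit loop. In this NumPy sketch with toy, hypothetical sizes, `np.matmul` plays the role of `torch.bmm`. The array `a` is a random, row-normalized stand-in for the attention output.

```python
import numpy as np

rng = np.random.default_rng(4)
batch_size, src_len, enc2 = 2, 5, 8  # enc2 = enc_hid_dim * 2 (toy sizes)

a = rng.random(size=(batch_size, src_len))
a /= a.sum(axis=1, keepdims=True)  # rows sum to 1, like the softmax output
encoder_outputs = rng.normal(size=(batch_size, src_len, enc2))  # already permuted to batch first

# bmm of a.unsqueeze(1) [batch, 1, src_len] with [batch, src_len, enc2]
weighted = np.matmul(a[:, None, :], encoder_outputs)  # [batch, 1, enc2]

# The same quantity as an explicit weighted sum over source positions
explicit = np.stack([
    sum(a[b, i] * encoder_outputs[b, i] for i in range(src_len))
    for b in range(batch_size)
])  # [batch, enc2]
```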

## Seq2Seq

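Unlike the basic seq2seq model, the encoder and decoder are not required to have the same hidden dimension. The encoder's final states are projected to `dec_hid_dim` by `Encoder.fc`, and the attention and decoder layers are sized for encoder outputs of `enc_hid_dim * 2`. In this notebook both dimensions happen to be 512.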

In [11]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, tgt, teacher_forcing_ratio = 0.5):
        #src = [src len, batch size]
        #tgt = [tgt len, batch size]
        batch_size = src.shape[1]
        tgt_len = tgt.shape[0]
        tgt_vocab_size = self.decoder.output_dim
        outputs = torch.zeros(tgt_len, batch_size, tgt_vocab_size).to(self.device)  # tensor to store decoder outputs
        encoder_outputs, hidden = self.encoder(src)
        input = tgt[0, :]  # First input to the decoder is the <SOS> token of each sequence
        for t in range(1, tgt_len):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output  # Store predictions in a tensor initialized above
            teacher_force = random.random() < teacher_forcing_ratio  # Decide if we are going to use teacher forcing
            top1 = output.argmax(1)  # Get the highest predicted token
            input = tgt[t] if teacher_force else top1
        return outputs

## Training Seq2Seq Model

In [12]:
INPUT_DIM = len(vocab[SRC_LANG])
OUTPUT_DIM = len(vocab[TGT_LANG])
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
attn = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(encoder, decoder, device).to(device)

In [14]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embed): Embedding(19214, 256)
    (rnn): LSTM(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): BahdanauAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(10837, 256)
    (rnn): GRU(1280, 512)
    (fc): Linear(in_features=1792, out_features=10837, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [15]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 34,345,557 trainable parameters
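The reported count can be reproduced by hand from the layer sizes in the printed model. This sketch assumes PyTorch's standard parameter layout:
- An LSTM direction has 4 gates and a GRU has 3.
- Each direction stores `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh`.
- `nn.Linear` adds a bias unless `bias=False`.

The helper functions are hypothetical.

```python
def embedding(num_embeddings, dim):
    return num_embeddings * dim


def linear(in_features, out_features, bias=True):
    return in_features * out_features + (out_features if bias else 0)


def rnn(input_size, hidden_size, gates, directions=1):
    # Per direction: weight_ih + weight_hh, plus the two bias vectors bias_ih and bias_hh
    per_direction = gates * hidden_size * (input_size + hidden_size) + 2 * gates * hidden_size
    return directions * per_direction


encoder = (embedding(19214, 256)
           + rnn(256, 512, gates=4, directions=2)  # bidirectional LSTM
           + linear(1024, 512))                    # fc bridge to the decoder
decoder = (linear(1536, 512)                       # attention: attn
           + linear(512, 1, bias=False)            # attention: v
           + embedding(10837, 256)
           + rnn(1280, 512, gates=3)               # GRU
           + linear(1792, 10837))                  # output layer
total = encoder + decoder
print(f'{total:,}')  # prints 34,345,557
```

The decoder's output layer alone has 1792 × 10837 weights plus 10837 biases, 19,430,741 parameters in total. That is more than half of the model, because the layer's size grows with the target vocabulary.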


In [16]:
optimizer = Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=specials['<PAD>'])
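`ignore_index` excludes padding positions from the loss. Below is a NumPy re-implementation for illustration. The helper `cross_entropy_ignore` is hypothetical, not a PyTorch function. It assumes the default `reduction='mean'` without class weights, under which PyTorch averages only over the non-ignored targets.

```python
import numpy as np

PAD_IDX = 1  # specials['<PAD>'] in this notebook


def cross_entropy_ignore(logits, targets, ignore_index):
    """Mean cross-entropy over positions whose target is not ignore_index."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    keep = targets != ignore_index
    return -log_probs[keep, targets[keep]].mean()


rng = np.random.default_rng(5)
logits = rng.normal(size=(4, 6))               # 4 target positions, toy vocabulary of 6
targets = np.array([5, 4, PAD_IDX, PAD_IDX])   # the last two positions are padding
loss = cross_entropy_ignore(logits, targets, PAD_IDX)
loss_first_two = cross_entropy_ignore(logits[:2], targets[:2], PAD_IDX)
```

The loss over the padded batch equals the loss over the real tokens alone, so sentence length differences inside a batch do not dilute the loss.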

In [17]:
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func


def tensor_transform(token_id: List[int]):
    return torch.cat((torch.tensor([specials['<SOS>']]), torch.tensor(token_id), torch.tensor([specials['<EOS>']])))

In [18]:
text_transform = {}
for ln in [SRC_LANG, TGT_LANG]:
    text_transform[ln] = sequential_transforms(tokenizer[ln], vocab[ln], tensor_transform)
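Each `text_transform[ln]` chains tokenization, vocabulary lookup and `<SOS>`/`<EOS>` wrapping. The pure-Python sketch below reproduces the same chain with hypothetical stand-ins: `str.split` replaces the spaCy tokenizer, a tiny dictionary replaces the Multi30k vocabulary, and lists replace tensors.

```python
UNK, SOS, EOS = 0, 2, 3  # same indices as `specials`


def sequential_transforms(*transforms):
    # Same composition helper as in the notebook
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func


TOY_STOI = {'ein': 4, 'hund': 5, 'sitzt': 6}  # hypothetical vocabulary


def toy_tokenizer(text):
    return text.lower().split()  # stand-in for the spaCy tokenizer


def toy_vocab(tokens):
    return [TOY_STOI.get(tok, UNK) for tok in tokens]  # unknown tokens map to <UNK>


def add_sos_eos(token_ids):
    return [SOS] + token_ids + [EOS]  # list version of tensor_transform


toy_transform = sequential_transforms(toy_tokenizer, toy_vocab, add_sos_eos)
ids = toy_transform('Ein Hund rennt')
```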

In [19]:
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANG](src_sample))
        tgt_batch.append(text_transform[TGT_LANG](tgt_sample))

    src_batch = pad_sequence(src_batch, padding_value=specials['<PAD>'])
    tgt_batch = pad_sequence(tgt_batch, padding_value=specials['<PAD>'])
    return src_batch, tgt_batch
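`pad_sequence` defaults to `batch_first=False`, so each padded batch has shape [max_len, batch_size]. This is the [src len, batch size] layout that the encoder and `Seq2Seq.forward` expect. The pure-Python analogue below is for illustration only; `pad_time_major` is a hypothetical helper, not part of PyTorch.

```python
PAD = 1  # specials['<PAD>']


def pad_time_major(seqs, padding_value):
    """Pad variable-length sequences into a [max_len][batch] grid (time-major layout)."""
    max_len = max(len(s) for s in seqs)
    return [[s[t] if t < len(s) else padding_value for s in seqs] for t in range(max_len)]


batch = [[2, 4, 5, 3], [2, 7, 3]]  # two tokenized sentences wrapped in <SOS>/<EOS>
padded = pad_time_major(batch, PAD)
```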

In [20]:
train_dataloader = DataLoader(train_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
valid_dataloader = DataLoader(valid_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
test_dataloader = DataLoader(test_iterator, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

In [21]:
def train(model, dataloader, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    batch_idx = 0
    for src, tgt in dataloader:
        src = src.to(device)  # [len(src), batch_size]
        tgt = tgt.to(device)  # [len(tgt), batch_size]
        optimizer.zero_grad()
        output = model(src, tgt)  # [len(tgt), batch_size, output_dim]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        tgt = tgt[1:].view(-1)
        loss = criterion(output, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        batch_idx += 1
    return epoch_loss / batch_idx
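Both `output` and `tgt` drop their first time step before flattening, for two reasons:
- `outputs[0]` is never written by `Seq2Seq.forward`, because the loop starts at `t = 1`.
- `tgt[0]` holds the `<SOS>` tokens, which the decoder is never asked to predict.

The NumPy sketch below uses toy, hypothetical sizes and shows that the flattened rows stay aligned.

```python
import numpy as np

tgt_len, batch_size, output_dim = 4, 2, 6                # toy sizes
outputs = np.zeros((tgt_len, batch_size, output_dim))    # outputs[0] stays zero, as in Seq2Seq.forward
outputs[1:] = np.random.default_rng(6).normal(size=(tgt_len - 1, batch_size, output_dim))
tgt = np.array([[2, 2], [4, 7], [5, 3], [3, 1]])         # row 0 holds the <SOS> tokens

flat_output = outputs[1:].reshape(-1, output_dim)        # [(tgt_len - 1) * batch_size, output_dim]
flat_tgt = tgt[1:].reshape(-1)                           # [(tgt_len - 1) * batch_size]
```

Row `k` of `flat_output` is the prediction for time step `1 + k // batch_size` of sentence `k % batch_size`, which is the same position as `flat_tgt[k]`.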

In [22]:
def evaluate(model, dataloader, criterion):
    model.eval()
    epoch_loss = 0
    batch_idx = 0
    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device)  # [len(src), batch_size]
            tgt = tgt.to(device)  # [len(tgt), batch_size]
            output = model(src, tgt, 0)  # Teacher forcing is turned off
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)  # [(len(tgt) - 1) * batch size, output_dim]
            tgt = tgt[1:].view(-1)  # Shape = [(len(tgt) - 1) * batch size]
            loss = criterion(output, tgt)
            epoch_loss += loss.item()
            batch_idx += 1
    return epoch_loss / batch_idx

In [23]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [24]:
EPOCHS = 10
CLIP = 1

if not os.path.exists('./../models'):
  os.mkdir('./../models')

In [25]:
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './../models/bahdanau.pt')
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}:{epoch_secs} | Train Loss: {train_loss:.3f} | Val Loss: {valid_loss:.3f}')

Epoch: 01 | Time: 1:49 | Train Loss: 5.177 | Val Loss: 5.005
Epoch: 02 | Time: 1:50 | Train Loss: 4.284 | Val Loss: 4.612
Epoch: 03 | Time: 1:51 | Train Loss: 3.766 | Val Loss: 4.290
Epoch: 04 | Time: 1:52 | Train Loss: 3.262 | Val Loss: 3.883
Epoch: 05 | Time: 1:53 | Train Loss: 2.804 | Val Loss: 3.730
Epoch: 06 | Time: 1:53 | Train Loss: 2.417 | Val Loss: 3.646
Epoch: 07 | Time: 1:52 | Train Loss: 2.087 | Val Loss: 3.553
Epoch: 08 | Time: 1:53 | Train Loss: 1.817 | Val Loss: 3.629
Epoch: 09 | Time: 1:52 | Train Loss: 1.640 | Val Loss: 3.623
Epoch: 10 | Time: 1:53 | Train Loss: 1.485 | Val Loss: 3.685
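The checkpoint saved to `./../models/bahdanau.pt` is the one with the lowest validation loss, which in this run is epoch 7 (3.553). After that, the validation loss creeps up while the training loss keeps falling, which suggests the model starts to overfit. As computed in the commented-out test cell, perplexity is the exponential of the average cross-entropy. The best validation loss therefore corresponds to a perplexity of roughly 35:

```python
import math

val_losses = [5.005, 4.612, 4.290, 3.883, 3.730, 3.646, 3.553, 3.629, 3.623, 3.685]  # from the log above
best_epoch = val_losses.index(min(val_losses)) + 1
best_ppl = math.exp(min(val_losses))
print(f'Best epoch: {best_epoch:02} | Val PPL: {best_ppl:7.3f}')
```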


In [26]:
# model.load_state_dict(torch.load('./../models/bahdanau.pt'))
# test_loss = evaluate(model, test_dataloader, criterion)

# print(f'Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f}')

## References

- [The Power of Attention in Deep Learning](https://www.youtube.com/watch?v=Qu81irGlR-0)
- [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/pdf/1409.0473.pdf)
- [The Bahdanau Attention Mechanism](https://machinelearningmastery.com/the-bahdanau-attention-mechanism/)