# Translation using Bahdanau Attention

Major parts of this notebook are copied from our encoder-decoder notebook: [Translation with Encoder-Decoder RNN](http://localhost:5173/blocks/deep_learning/sequence_modelling/pytorch_implementations/encoder_decoder_translation). Refer to it if some parts seem unfamiliar.

In [1]:
import re
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

We once again use the English-German Anki dataset. 

In [3]:
!cd ../datasets/ && { curl -O https://www.manythings.org/anki/deu-eng.zip ; cd -; }

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9376k  100 9376k    0     0   913k      0  0:00:10  0:00:10 --:--:-- 1381k
/home/petruschka/repos/World4AI/website/src/notebooks/attention


In [4]:
!rm -rf ../datasets/deu_eng/
!unzip ../datasets/deu-eng.zip -d ../datasets/deu_eng

Archive:  ../datasets/deu-eng.zip
  inflating: ../datasets/deu_eng/deu.txt  
  inflating: ../datasets/deu_eng/_about.txt  


We use a simple tokenizer that lowercases the text, strips leading and trailing whitespace, and inserts a space before the punctuation marks `.`, `!` and `?` so that they become separate tokens. We also add the special tokens `<sos>` and `<eos>` to indicate the start and end of a sentence.

In [5]:
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s
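
To make the tokenizer's behaviour concrete, the sketch below repeats the two functions from the cell above and runs them on made-up sentences. It also shows one detail worth knowing: `split(' ')` keeps empty strings when a sentence contains consecutive spaces. The `split()` variant at the end is only a possible alternative, not something the notebook uses, and whether double spaces occur in the Anki file at all is an assumption to check.

```python
import re

# same functions as in the cell above, repeated so the snippet runs on its own
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    return s

def tokenizer(s):
    s = normalize(s)
    s = s.split(' ')
    s.insert(0, '<sos>')
    s.append('<eos>')
    return s

# punctuation becomes its own token
simple = tokenizer("Go on.")
print(simple)

# made-up sentence with two consecutive spaces: split(' ') yields an empty token
spaced = tokenizer("I am  here!")
print(spaced)

# hypothetical alternative: split() without arguments collapses runs of whitespace
robust = ['<sos>'] + normalize("I am  here!").split() + ['<eos>']
print(robust)
```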

In [6]:
def read_pairs(max_len=20):
    print("Reading lines...")
    en_seq = []
    de_seq = []
    with open('../datasets/deu_eng/deu.txt', 'r', encoding='utf-8') as file:
        print(f"Tokenizing and removing sentences larger than {max_len}")
        for line in file:
            pairs = line.split('\t')
            
            en_sentence, de_sentence = tokenizer(pairs[0]), tokenizer(pairs[1])
            
            if len(en_sentence) <= max_len and len(de_sentence) <= max_len:
                en_seq.append(en_sentence)
                de_seq.append(de_sentence)
        print(f"The dataset has {len(en_seq)} pairs")
        return en_seq, de_seq

In [7]:
en_seq, de_seq = read_pairs()

Reading lines...
Tokenizing and removing sentences larger than 20
The dataset has 255279 pairs


We divide our dataset into the train, validation and test sets using sklearn.

In [8]:
from sklearn.model_selection import train_test_split

In [9]:
# split into train (80%), validation (10%) and test (10%) sets
train_en, test_val_en, train_de, test_val_de = train_test_split(en_seq, de_seq, test_size=0.2)
val_en, test_en, val_de, test_de = train_test_split(test_val_en, test_val_de, test_size=0.5)

In [10]:
class PairDataset(Dataset):
    def __init__(self, en, de):
        assert len(en) == len(de)
        self.en = en
        self.de = de
    
    def __len__(self):
        return len(self.en)
    
    def __getitem__(self, idx):
        return self.en[idx], self.de[idx]

In [11]:
train_dataset = PairDataset(train_en, train_de)
val_dataset = PairDataset(val_en, val_de)
test_dataset = PairDataset(test_en, test_de)

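We create an English and a German vocabulary using torchtext. Tokens that appear fewer than 5 times in the training set are left out of the vocabulary and mapped to `<unk>`.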

In [12]:
from collections import Counter, OrderedDict

In [13]:
en_counter = Counter()
de_counter = Counter()

for line in train_en:
    en_counter.update(line)

for line in train_de:
    de_counter.update(line)

In [14]:
en_sorted_by_freq_tuples = sorted(en_counter.items(), key=lambda x: x[1], reverse=True)
en_ordered_dict = OrderedDict(en_sorted_by_freq_tuples)

de_sorted_by_freq_tuples = sorted(de_counter.items(), key=lambda x: x[1], reverse=True)
de_ordered_dict = OrderedDict(de_sorted_by_freq_tuples)

In [15]:
import torchtext
en_vocab = torchtext.vocab.vocab(en_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)
de_vocab = torchtext.vocab.vocab(de_ordered_dict, min_freq = 5, specials=['<pad>', '<unk>', '<sos>', '<eos>'], special_first = True)

en_vocab.set_default_index(1)
de_vocab.set_default_index(1)
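
For readers without torchtext at hand, the following is a rough pure-Python sketch of what the vocabulary does: special tokens come first, rarer tokens are dropped, and unknown tokens fall back to the `<unk>` index. The helpers `build_vocab` and `encode` and the toy sentences are hypothetical and only approximate the torchtext behaviour used above.

```python
from collections import Counter

SPECIALS = ['<pad>', '<unk>', '<sos>', '<eos>']

def build_vocab(sentences, min_freq=1):
    """Hypothetical helper: map tokens to indices, specials first."""
    counter = Counter(token for sentence in sentences for token in sentence)
    itos = list(SPECIALS)
    # most_common() iterates from the most to the least frequent token
    for token, freq in counter.most_common():
        if freq >= min_freq and token not in SPECIALS:
            itos.append(token)
    stoi = {token: idx for idx, token in enumerate(itos)}
    return stoi, itos

def encode(tokens, stoi):
    """Hypothetical helper: unknown tokens get the index of <unk>."""
    unk_idx = stoi['<unk>']
    return [stoi.get(token, unk_idx) for token in tokens]

toy_sentences = [
    ['<sos>', 'the', 'cat', '<eos>'],
    ['<sos>', 'the', 'dog', '<eos>'],
    ['<sos>', 'the', 'cat', 'sleeps', '<eos>'],
]
stoi, itos = build_vocab(toy_sentences, min_freq=2)
print(itos)

# 'bird' was never seen, so it is encoded as <unk>
encoded = encode(['<sos>', 'the', 'bird', '<eos>'], stoi)
print(encoded)
```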

The collate function is required to zero-pad shorter sentences to the length of the longest sentence in a batch.

In [16]:
def collate(batch):
    en, de = [], []
    for en_token, de_token in batch:
        en.append(torch.tensor(en_vocab(en_token), dtype=torch.int64))
        de.append(torch.tensor(de_vocab(de_token), dtype=torch.int64))
    en_padded = nn.utils.rnn.pad_sequence(en, batch_first=True)
    de_padded = nn.utils.rnn.pad_sequence(de, batch_first=True)
    return en_padded, de_padded
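
As a small illustration of the padding step, the sketch below calls `pad_sequence` on two made-up index sequences of different length. The token indices are arbitrary and only chosen for the example.

```python
import torch
import torch.nn as nn

# two made-up index sequences of different length
short = torch.tensor([2, 7, 3], dtype=torch.int64)
long = torch.tensor([2, 5, 9, 4, 3], dtype=torch.int64)

# batch_first=True gives shape (batch, max_len); the default padding value is 0,
# which is also the index of <pad> in the vocabularies above
padded = nn.utils.rnn.pad_sequence([short, long], batch_first=True)
print(padded.shape)
print(padded)
```

Because `collate` pads the English and German batches separately, the two tensors of one batch can end up with different sequence lengths.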

In [17]:
BATCH_SIZE=128
train_dataloader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2,
                              drop_last=True,
                              collate_fn=collate)
val_dataloader = DataLoader(dataset=val_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)
test_dataloader = DataLoader(dataset=test_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=2,
                              drop_last=False,
                              collate_fn=collate)

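We use just one layer for the encoder and the decoder to keep the calculations simple. Note that the paper uses a bidirectional RNN as the encoder, whereas the encoder below is a unidirectional LSTM, which is why its outputs have size `hidden_size` rather than `2 * hidden_size`. We could achieve better performance with several layers or a bidirectional encoder, but the purpose of this notebook is not to achieve state-of-the-art results, but to provide a simple and intuitive explanation.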

In [18]:
class Encoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim=128, hidden_size=128):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
        
    def forward(self, x):
        x = self.embedding(x)
        outputs, (h_n, c_n) = self.lstm(x)
        return outputs, h_n, c_n

The decoder looks a little more complicated, so let's try and explain its components.

1. The `energy` linear layer produces the scores (energies) that are fed into the softmax function, which in turn outputs the attention weights. To create the input to this layer we concatenate each encoder output with the current decoder hidden state `h`. The layer therefore maps vectors of size 128*2 (encoder output + hidden state) to a single value per encoder output.

2. The context is the attention-weighted sum of the encoder outputs. We concatenate the context with the decoder embedding and pass the result through the `combine` linear layer. Its output is used as the input to the `LSTMCell`.

3. The last fully connected layer `fc` is responsible for producing logits. These will be used in greedy sampling.
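
The attention computation from items 1 and 2 can be traced on random tensors. The following is a minimal sketch with toy sizes (batch of 2, 5 source positions, hidden size 4), not the trained model; the variable names and sizes are chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, src_len, hidden_size = 2, 5, 4  # toy sizes, assumptions for this sketch

encoder_outputs = torch.randn(batch_size, src_len, hidden_size)  # (B, T, H)
h = torch.randn(batch_size, hidden_size)                         # decoder state, (B, H)

energy_layer = nn.Linear(hidden_size * 2, 1)

# repeat the decoder state once per source position
h_repeated = h.unsqueeze(1).repeat(1, src_len, 1)                # (B, T, H)
energy_input = torch.cat((encoder_outputs, h_repeated), dim=2)   # (B, T, 2H)

# one score per source position, normalized over the source dimension
energy = energy_layer(energy_input).squeeze(2)                   # (B, T)
attention = torch.softmax(energy, dim=1)                         # (B, T)

# weighted sum of encoder outputs: (B, 1, T) @ (B, T, H) -> (B, 1, H)
context = (attention.unsqueeze(1) @ encoder_outputs).squeeze(1)  # (B, H)

print(attention.sum(dim=1))  # each row of weights sums to 1
print(context.shape)
```

Two caveats: the single linear layer is a simplification of the additive score in the paper, which passes the combined states through a `tanh` before projecting to a scalar, and like the decoder in this notebook the sketch does not mask padded source positions.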

In [19]:
class Decoder(nn.Module):
    
    def __init__(self, num_embeddings, embedding_dim=128, hidden_size=128):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm_cell = nn.LSTMCell(input_size=embedding_dim, hidden_size=hidden_size)
        self.energy = nn.Linear(hidden_size*2, 1)
        self.combine = nn.Linear(hidden_size*2, hidden_size)
        self.fc = nn.Linear(hidden_size, num_embeddings)
    
    def forward(self, x, h, c, encoder_outputs):
        embedding = self.embedding(x) 
        h, c = h.squeeze(0), c.squeeze(0)
        
        energy_input = h.unsqueeze(1).repeat(1, encoder_outputs.shape[1], 1)
        energy_input = torch.cat((encoder_outputs, energy_input), dim=2)
        energy = self.energy(energy_input).squeeze(2)
        attention = torch.softmax(energy, dim=1)
        
        context = attention.unsqueeze(1) @ encoder_outputs
        context = context.squeeze(1)
        
        x = self.combine(torch.cat((context, embedding), dim=1))
        (h_n, c_n) = self.lstm_cell(x, (h, c))
        logits = self.fc(h_n)
        return logits, h_n, c_n

The rest of the implementation is the same as in the encoder-decoder notebook.

In [20]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.teacher_forcing_ratio = teacher_forcing_ratio
    
    def forward(self, en_sequence, de_sequence):
        batch_size, sequence_len, num_de_embeddings = de_sequence.size()[0], de_sequence.size()[1], self.decoder.embedding.num_embeddings
        
        # minus 1 due to fewer predictions as inputs, we don't predict <sos>
        outputs = torch.zeros(batch_size, sequence_len-1, num_de_embeddings, device=DEVICE)

        encoder_outputs, h_n, c_n = self.encoder(en_sequence)
        inp = de_sequence[:, 0]
        for i in range(1, sequence_len):
            logits, h_n, c_n = self.decoder(inp, h_n, c_n, encoder_outputs)
            outputs[:, i-1] = logits
            
            force = random.random() < self.teacher_forcing_ratio
            if force:
                inp = de_sequence[:, i]
            else:
                inp = logits.argmax(dim=1)
        
        return outputs
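
The teacher forcing decision in the loop above can be isolated in a few lines. This is a sketch with a hypothetical helper `choose_next_input` and string placeholders instead of token tensors. It only illustrates how the ratio controls which input the decoder sees next.

```python
import random

def choose_next_input(target_token, predicted_token, teacher_forcing_ratio, rng=random):
    """Hypothetical helper: return the ground-truth token with
    probability teacher_forcing_ratio, otherwise the model's own prediction."""
    if rng.random() < teacher_forcing_ratio:
        return target_token
    return predicted_token

rng = random.Random(0)  # seeded only to make the sketch reproducible

# ratio 1.0 always feeds the ground truth, ratio 0.0 always feeds the prediction
always_forced = {choose_next_input("target", "prediction", 1.0, rng) for _ in range(100)}
never_forced = {choose_next_input("target", "prediction", 0.0, rng) for _ in range(100)}

# with ratio 0.5 roughly half of the steps use the ground truth
picks = [choose_next_input("target", "prediction", 0.5, rng) for _ in range(1000)]
share = picks.count("target") / len(picks)
print(always_forced, never_forced, share)
```

In the `forward` method above the random draw happens once per time step, so the choice applies to the whole batch at that step rather than to individual sentences.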
        

In [21]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    loss_sum = 0
    num_iterations = 0

    # no need to calculate gradients
    with torch.inference_mode():
        for en_sequence, de_sequence in dataloader:
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence)
            
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]
            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_iterations+=1

    # we return the average loss
    return loss_sum/num_iterations
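
The reshape before the loss can be checked on toy tensors. The sketch below assumes a criterion created with `ignore_index=0` (the index of `<pad>` in the vocabularies above) and uses made-up sizes and labels. It compares the result with a manual average over the non-padding positions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 4, 6  # toy sizes, assumptions for this sketch

logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[4, 5, 3, 0],
                       [2, 3, 0, 0]])  # 0 stands for <pad>

criterion = nn.CrossEntropyLoss(ignore_index=0)

# CrossEntropyLoss expects (N, C) logits and (N,) labels, so we merge batch and time
flat_logits = logits.reshape(-1, vocab_size)  # (8, 6)
flat_labels = labels.reshape(-1)              # (8,)
loss = criterion(flat_logits, flat_labels)

# manual check: average the per-token loss over non-padding positions only
mask = flat_labels != 0
manual = F.cross_entropy(flat_logits[mask], flat_labels[mask], reduction='none').mean()
print(loss.item(), manual.item())
```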



In [22]:
def train(num_epochs, train_dataloader, val_dataloader, model, optimizer, criterion, scheduler=None):
    min_loss = float("inf")
    for epoch in range(num_epochs):
        loss_sum = 0
        num_iterations = 0
        for en_sequence, de_sequence in train_dataloader:
            model.train()

            optimizer.zero_grad()
            en_sequence = en_sequence.to(DEVICE)
            de_sequence = de_sequence.to(DEVICE)

            logits = model(en_sequence, de_sequence)
            # we don't actually predict the <sos> token
            labels = de_sequence[:, 1:]

            # we need to reshape in order to be able to use these tensors with CrossEntropyLoss
            logits = logits.reshape(-1, logits.size()[2])
            labels = labels.reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            loss_sum += loss.cpu().item()
            num_iterations += 1
        train_loss=loss_sum/num_iterations
        val_loss = track_performance(val_dataloader, model, criterion)
        if scheduler:
            scheduler.step(val_loss)
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}')
        
        if val_loss < min_loss:
            print("Saving Weights!")
            min_loss = val_loss
            torch.save({'encoder_weights': model.encoder.state_dict(), 'decoder_weights': model.decoder.state_dict()}, f='../temp/encoder_decoder.pt')


In [23]:
encoder = Encoder(num_embeddings=len(en_vocab), embedding_dim=128)
decoder = Decoder(num_embeddings=len(de_vocab), embedding_dim=128)
seq2seq = EncoderDecoder(encoder, decoder).to(DEVICE)

In [24]:
optimizer = optim.Adam(seq2seq.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='min',
                                                       patience=2,
                                                       verbose=True)

num_epochs=25
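
To see how the scheduler reacts to a stagnating validation loss, the sketch below feeds it a constant made-up loss. The tiny linear model exists only to give the optimizer some parameters. With `patience=2` the learning rate should stay unchanged while the number of epochs without improvement is at most 2 and be multiplied by `factor` once that number is exceeded.

```python
import torch.nn as nn
import torch.optim as optim

# placeholder model, only needed so the optimizer has parameters
model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 mode='min',
                                                 factor=0.1,
                                                 patience=2)

lrs = []
# made-up validation losses that never improve after the first epoch
for val_loss in [1.0, 1.0, 1.0, 1.0]:
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```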

The validation loss looks very similar to the one in the simple encoder-decoder implementation.

In [25]:
train(num_epochs, train_dataloader, val_dataloader, seq2seq, optimizer, criterion, scheduler)

Epoch:  1/25 | Train Loss: 4.66697 | Val Loss: 3.96499
Saving Weights!
Epoch:  2/25 | Train Loss: 3.67404 | Val Loss: 3.43004
Saving Weights!
Epoch:  3/25 | Train Loss: 3.22975 | Val Loss: 3.10357
Saving Weights!
Epoch:  4/25 | Train Loss: 2.92408 | Val Loss: 2.87413
Saving Weights!
Epoch:  5/25 | Train Loss: 2.69841 | Val Loss: 2.70660
Saving Weights!
Epoch:  6/25 | Train Loss: 2.51097 | Val Loss: 2.57303
Saving Weights!
Epoch:  7/25 | Train Loss: 2.37196 | Val Loss: 2.49242
Saving Weights!
Epoch:  8/25 | Train Loss: 2.25726 | Val Loss: 2.39603
Saving Weights!
Epoch:  9/25 | Train Loss: 2.15084 | Val Loss: 2.35311
Saving Weights!
Epoch: 10/25 | Train Loss: 2.07572 | Val Loss: 2.30002
Saving Weights!
Epoch: 11/25 | Train Loss: 2.00712 | Val Loss: 2.24609
Saving Weights!
Epoch: 12/25 | Train Loss: 1.94464 | Val Loss: 2.23760
Saving Weights!
Epoch: 13/25 | Train Loss: 1.88540 | Val Loss: 2.18921
Saving Weights!
Epoch: 14/25 | Train Loss: 1.85016 | Val Loss: 2.16090
Saving Weights!
Epoch:

In [26]:
weights = torch.load('../temp/encoder_decoder.pt')
encoder_weights = weights['encoder_weights']
decoder_weights = weights['decoder_weights']

In [27]:
encoder.load_state_dict(encoder_weights)
decoder.load_state_dict(decoder_weights)

<All keys matched successfully>

In [28]:
def translate_sentence(sentence, vocab, encoder, decoder, max_len=50):
    with torch.inference_mode():
        outputs = []
        
        start_token = ["<sos>"]
        end_token = ["<eos>"]
        start_idx = vocab(start_token)[0]
        end_idx = vocab(end_token)[0]
                
        encoder_outputs, h_n, c_n = encoder(sentence)
        inp = torch.tensor([start_idx], device=DEVICE)
        # greedy decoding; max_len guards against a model that never emits <eos>
        for _ in range(max_len):
            logits, h_n, c_n = decoder(inp, h_n, c_n, encoder_outputs)
            # restore the leading dimension that the decoder squeezes away
            h_n = h_n.unsqueeze(0)
            c_n = c_n.unsqueeze(0)
            inp = logits.argmax(dim=1)
            outputs.append(inp.cpu().item())
            if inp.item() == end_idx:
                break
        return outputs

In [29]:
en_sequence, de_sequence = next(iter(test_dataloader))
en_sequence = en_sequence.to(DEVICE)

Similar to the encoder-decoder implementation, the translation quality is far from perfect. But given our small model and the limited amount of data, the result is acceptable.

In [30]:
for i in range(10):
    en_sentence = en_sequence[i].unsqueeze(0)
    de_sentence = de_sequence[i].unsqueeze(0)
    translation = translate_sentence(en_sentence, de_vocab, encoder, decoder)
    print('-'*130)
    print(f'English Sentence: {en_vocab.lookup_tokens(en_sentence[0].cpu().tolist())}')
    print(f'German Translation: {de_vocab.lookup_tokens(de_sentence[0].cpu().tolist())}')
    print(f'Model Translation: {de_vocab.lookup_tokens(translation)}')
    

----------------------------------------------------------------------------------------------------------------------------------
English Sentence: ['<sos>', 'he', 'tried', 'to', 'approach', 'her', 'using', 'every', 'possible', 'means', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
German Translation: ['<sos>', 'er', 'versuchte', 'auf', 'jede', '<unk>', '<unk>', 'an', 'sie', '<unk>', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
Model Translation: ['er', 'versuchte', 'jeden', '<unk>', 'von', '<unk>', 'zu', '.', '<eos>']
----------------------------------------------------------------------------------------------------------------------------------
English Sentence: ['<sos>', 'please', "don't", 'touch', 'the', '<unk>', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad