<img src="transformer_architecture.png" alt="Transformer Architecture" width="500"/>

A classic RNN updates its hidden state and output as
$$ h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h) $$
$$ y_t = W_{hy} h_t + b_y $$

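In principle the hidden state can carry information across arbitrarily many steps, but in practice vanishing and exploding gradients leave it with mostly short-term memory. LSTMs and GRUs tried to fix this with gating; today the dominant approach is self-attention.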

**Self-attention** has memory bounded by the sequence length (the context window), but within that window every token can attend directly to every other token. The inputs are no longer passed through a hidden layer one step at a time; instead, the whole sequence, with positional encodings added, goes through a self-attention layer at once.

**Attention Mechanism** can be thought of as a memory of key-value pairs: when it is queried, it returns a weighted combination of the values, where each value's weight depends on how well its key matches the query.
 
The formulas are simple. Given a query $\mathbf{q}$ and key-value pairs $(\mathbf{k}_i, \mathbf{v}_i)$:
$$ \text{output} = \sum_i \text{softmax}_i\big(\alpha(\mathbf{q}, \mathbf{k}_i)\big)\, \mathbf{v}_i $$

where the scoring function $\alpha$ can be as simple as the dot product $\alpha(\mathbf{q}, \mathbf{k}) = \mathbf{q} \cdot \mathbf{k}$,
or (additive, Bahdanau-style attention) a projection of both to a hidden dimension using $\left(\mathbf{W_k}, \mathbf{W_q} \right)$, so we end up with
$$ \alpha(\mathbf{q}, \mathbf{k}) = \mathbf{w}^\mathsf{T} \tanh(\mathbf{W_k}\mathbf{k} + \mathbf{W_q}\mathbf{q}) $$

- Note: the learned vector $\mathbf{w}$ (often written $\mathbf{v}$ in Bahdanau's paper) is not a value vector; in the Bahdanau Seq2Seq code below it is the `energy_layer` projection. The values there are the encoder outputs themselves.
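A minimal NumPy sketch of the two scoring functions above, for a single query attending over a handful of random keys and values. The names (`dot_score`, `additive_score`, `attend`) and the shapes are illustrative, not from this notebook; the additive version mirrors what `query_layer`, `key_layer` and `energy_layer` compute in the Bahdanau code below, minus the biases.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability; the result sums to 1.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def dot_score(q, K):
    # q: (d,), K: (n, d) -> one score per key, shape (n,)
    return K @ q

def additive_score(q, K, W_q, W_k, w):
    # Additive (Bahdanau-style) score w^T tanh(W_k k_i + W_q q) for every key k_i.
    # W_q: (h, d_q), W_k: (h, d_k), w: (h,) -> scores of shape (n,)
    return np.tanh(K @ W_k.T + W_q @ q) @ w

def attend(scores, V):
    # Weighted sum of the values, with weights softmax(scores).
    weights = softmax(scores)
    return weights @ V, weights

rng = np.random.default_rng(0)
d, h, n = 4, 8, 5
q = rng.normal(size=d)
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, 3))  # values may have a different size than keys

out_dot, w_dot = attend(dot_score(q, K), V)
W_q, W_k, w = rng.normal(size=(h, d)), rng.normal(size=(h, d)), rng.normal(size=h)
out_add, w_add = attend(additive_score(q, K, W_q, W_k, w), V)
print(out_dot.shape, float(round(w_dot.sum(), 6)))  # (3,) 1.0
```

Both variants only change how the scores are computed; the softmax and the weighted sum of values are shared.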

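To keep positional information we add positional encodings, where $i$ is the position, $j$ indexes pairs of dimensions and $d$ is the model dimension:
$$ P_{i, 2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \qquad P_{i, 2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right) $$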

<img src="positional_encoding_visualization.png" alt="Positional Encoding for Different Depths" width="500"/>

And so we end up with a new input representation that sums the token embeddings $\mathbf{W}$ and the positional encodings $\mathbf{P}$ element-wise (both are $n \times d$ for a sequence of $n$ tokens): $$ \mathbf{X} = \mathbf{W} + \mathbf{P} $$
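Here is a small NumPy sketch of the sinusoidal encoding and the sum $\mathbf{X} = \mathbf{W} + \mathbf{P}$, assuming an even model dimension $d$ and using the base 10000 from the original Transformer paper (the same constant the `PositionalEncoding` class later in this notebook uses). The token embeddings are random stand-ins.

```python
import numpy as np

def sinusoidal_encoding(max_len, d):
    # P[i, 2j] = sin(i / 10000^(2j/d)), P[i, 2j+1] = cos(i / 10000^(2j/d)); d assumed even
    positions = np.arange(max_len)[:, None]            # (max_len, 1)
    two_j = np.arange(0, d, 2)[None, :]                 # (1, d/2): the even indices 2j
    angles = positions / np.power(10000.0, two_j / d)   # (max_len, d/2)
    P = np.zeros((max_len, d))
    P[:, 0::2] = np.sin(angles)
    P[:, 1::2] = np.cos(angles)
    return P

seq_len, d = 6, 8
rng = np.random.default_rng(0)
W = rng.normal(size=(seq_len, d))  # stand-in token embeddings (random, for illustration)
P = sinusoidal_encoding(seq_len, d)
X = W + P                          # element-wise sum, same shape as W
```

Each pair of columns is a sine/cosine at its own frequency, so nearby positions get similar vectors while every entry stays in $[-1, 1]$.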
 


## The Transformer body

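**Self-attention**, where $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ are all linear projections of the same input $\mathbf{X}$ and $d_k$ is the key dimension:
$$ \text{attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\mathsf{T}}{\sqrt{d_k}}\right) \mathbf{V} $$
Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension and pushing the softmax into regions with tiny gradients.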
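A minimal NumPy sketch of scaled dot-product self-attention on one unbatched sequence. The projection matrices are random placeholders for learned weights, and the row-vector convention ($\mathbf{X}\mathbf{W}$) is an assumption that matches the $\mathbf{Q}\mathbf{K}^\mathsf{T}$ form of the formula.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n_q, n_k)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V, weights         # output: (n_q, d_v)

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 16, 8
X = rng.normal(size=(n, d_model))
# In self-attention Q, K and V are all projections of the same input X.
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = attention(X @ W_q, X @ W_k, X @ W_v)
```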


When we run several attention heads in parallel, each with its own learned projections, the model can attend to different representation subspaces at once. This is multi-head attention:
$$ \text{head}_i = \text{attention}(\mathbf{W_q^i}\mathbf{Q}, \mathbf{W_k^i}\mathbf{K}, \mathbf{W_v^i}\mathbf{V}) $$
$$ \text{multihead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathbf{W_O}\,\text{concat}(\text{head}_1, \dots, \text{head}_h) $$
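The multi-head formula can be sketched in NumPy as follows. This is an illustration under assumptions: it uses the row-vector convention (inputs right-multiplied by the projections, so the output projection is applied as `concat(...) @ W_o`), random weights, no batching, and the common choice $d_{head} = d_{model}/h$.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V

def multihead(Q, K, V, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v: lists of h per-head projections, each (d_model, d_head)
    # W_o: (h * d_head, d_model), applied to the concatenated heads
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(W_q, W_k, W_v)]
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(0)
n, d_model, h = 5, 16, 4
d_head = d_model // h  # keeps the total size of all heads equal to d_model
W_q = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
W_k = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
W_v = [rng.normal(size=(d_model, d_head)) for _ in range(h)]
W_o = rng.normal(size=(h * d_head, d_model))
X = rng.normal(size=(n, d_model))
Y = multihead(X, X, X, W_q, W_k, W_v, W_o)  # self-attention: Q = K = V = X
```

Splitting $d_{model}$ across heads keeps the cost close to one full-width head while letting each head learn its own attention pattern.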


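**Masked Attention:** During training the decoder sees the whole target sentence, so we hide future tokens: position $i$ may only attend to positions $j \le i$. We add a mask $\mathbf{M}$ with $M_{ij} = 0$ for $j \le i$ and $M_{ij} = -\infty$ for $j > i$; after the softmax, the masked positions receive zero weight.
$$ \text{maskedAttention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\mathsf{T}}{\sqrt{d_k}} + \mathbf{M}\right) \mathbf{V} $$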
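A NumPy sketch of the causal mask and masked attention for one sequence (the function names are illustrative). Adding $-\infty$ before the softmax makes the exponentials of the future positions exactly 0, so each row still sums to 1 over the allowed positions.

```python
import numpy as np

def causal_mask(n):
    # M[i, j] = 0 if j <= i (past or present), -inf if j > i (future)
    M = np.zeros((n, n))
    M[np.triu_indices(n, k=1)] = -np.inf
    return M

def masked_attention(Q, K, V, M):
    scores = Q @ K.T / np.sqrt(Q.shape[-1]) + M
    # The row max is finite because the diagonal is never masked.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)  # exp(-inf) = 0 for future positions
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 4, 8
X = rng.normal(size=(n, d))
out, weights = masked_attention(X, X, X, causal_mask(n))
```

The first token can only attend to itself, so its weight row is `[1, 0, 0, 0]`.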


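**AddAndNorm and ResidualConnections**
As we can see in the transformer image, each sublayer has a skip connection (as in ResNets) that adds its input $\mathbf{X}$ to its output, and the sum is layer-normalized: $\mathbf{Z} = \text{LayerNorm}(\mathbf{X} + \text{Sublayer}(\mathbf{X}))$. LayerNorm normalizes each position's features to mean 0 and standard deviation 1, then rescales them with a learned gain $\mathbf{g}$ and bias $\mathbf{b}$: $\text{LayerNorm}(\mathbf{h}) = \frac{\mathbf{g}}{\sigma} \odot (\mathbf{h} - \mu) + \mathbf{b}$, where $\mu$ and $\sigma$ are the mean and standard deviation over the feature dimension.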
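A minimal NumPy sketch of AddAndNorm, assuming the post-norm ordering $\text{LayerNorm}(\mathbf{X} + \text{Sublayer}(\mathbf{X}))$ of the original Transformer and a small `eps` for numerical stability; the sublayer output is a random stand-in.

```python
import numpy as np

def layer_norm(x, gain, bias, eps=1e-5):
    # Normalize each position (row) over its features, then rescale with the learned gain and bias.
    mu = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return gain * (x - mu) / sigma + bias

def add_and_norm(x, sublayer_out, gain, bias):
    # Residual connection followed by LayerNorm.
    return layer_norm(x + sublayer_out, gain, bias)

rng = np.random.default_rng(0)
n, d = 3, 8
X = rng.normal(size=(n, d))
Z = rng.normal(size=(n, d))  # stand-in for a sublayer's output, e.g. attention(X)
out = add_and_norm(X, Z, gain=np.ones(d), bias=np.zeros(d))
```

With the gain at 1 and the bias at 0, every row of `out` has mean 0 and standard deviation (almost exactly) 1; training then learns how much of that normalization to undo.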

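**Positionwise FFNs**: the same two-layer network (a linear transformation, a ReLU, and another linear transformation) is applied to every position independently. The weights are shared across positions within a layer but differ from layer to layer.
$$ \text{FFN}(\mathbf{x}) = \text{ReLU}(\mathbf{x}\mathbf{W_1} + \mathbf{b_1})\mathbf{W_2} + \mathbf{b_2} $$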
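The "position-wise" part can be checked directly: applying the FFN to the whole $n \times d_{model}$ matrix gives the same result as applying it to each row separately. The sketch below uses random weights and a small `d_ff` chosen only for illustration.

```python
import numpy as np

def ffn(x, W1, b1, W2, b2):
    # FFN(x) = ReLU(x W1 + b1) W2 + b2, applied to every row (position) of x
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
n, d_model, d_ff = 5, 8, 32
W1, b1 = rng.normal(size=(d_model, d_ff)), rng.normal(size=d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), rng.normal(size=d_model)
X = rng.normal(size=(n, d_model))

Y_all = ffn(X, W1, b1, W2, b2)                         # whole sequence at once
Y_rows = np.stack([ffn(x, W1, b1, W2, b2) for x in X])  # one position at a time
```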

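**Encoder-Decoder**: as in the image, the encoder stack turns the source sequence into a sequence of representations (the *memory*). Each decoder layer applies masked self-attention, then encoder-decoder (cross) attention, where the queries come from the decoder and the keys and values come from the encoder memory, then the positionwise FFN, each followed by AddAndNorm.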
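A NumPy sketch of the encoder-decoder (cross) attention step, with the learned query/key/value projections left out for brevity (an assumption to keep it short). The source and target lengths differ, which is why the weight matrix is $T \times S$ rather than square.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
S, T, d = 7, 4, 16
memory = rng.normal(size=(S, d))  # encoder output for a source sentence of length S
dec = rng.normal(size=(T, d))     # decoder states after masked self-attention (length T)

# Queries come from the decoder; keys and values come from the encoder memory.
weights = softmax(dec @ memory.T / np.sqrt(d))  # (T, S): one distribution over source positions per target position
context = weights @ memory                      # (T, d)
```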


# Machine Translation

## First, we create a Bahdanau Seq2Seq Bidirectional GRU with attention
Here $(f)$ is the forward direction (reading $x_1, \dots, x_T$) and $(b)$ the backward one (reading $x_T, \dots, x_1$).

Update Gate:
$$ z_t^{(f)} = \sigma(W_z^{(f)} \cdot [h_{t-1}^{(f)}, x_t]), \qquad z_t^{(b)} = \sigma(W_z^{(b)} \cdot [h_{t+1}^{(b)}, x_t]) $$
Reset Gate:
$$ r_t^{(f)} = \sigma(W_r^{(f)} \cdot [h_{t-1}^{(f)}, x_t]), \qquad r_t^{(b)} = \sigma(W_r^{(b)} \cdot [h_{t+1}^{(b)}, x_t]) $$
Hidden State:
$$ \tilde{h}_t^{(f)} = \tanh(W_h^{(f)} \cdot [r_t^{(f)} \odot h_{t-1}^{(f)}, x_t]), \qquad h_t^{(f)} = (1 - z_t^{(f)}) \odot h_{t-1}^{(f)} + z_t^{(f)} \odot \tilde{h}_t^{(f)} $$
$$ \tilde{h}_t^{(b)} = \tanh(W_h^{(b)} \cdot [r_t^{(b)} \odot h_{t+1}^{(b)}, x_t]), \qquad h_t^{(b)} = (1 - z_t^{(b)}) \odot h_{t+1}^{(b)} + z_t^{(b)} \odot \tilde{h}_t^{(b)} $$

**Combined Hidden State (concat)**
$$ h_t = [h_t^{(f)}, h_t^{(b)}] $$
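A minimal NumPy sketch of the bidirectional GRU equations above (biases omitted, as in the equations; weights random). PyTorch's `nn.GRU` documentation uses slightly different conventions (for example, $z_t$ weights $h_{t-1}$ rather than the candidate state), so this sketch follows the equations here rather than reproducing `nn.GRU` exactly.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, W_z, W_r, W_h):
    # Each W_*: (H, H + D), acting on the concatenation [h, x]
    hx = np.concatenate([h, x])
    z = sigmoid(W_z @ hx)                                # update gate
    r = sigmoid(W_r @ hx)                                # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r * h, x]))  # candidate state
    return (1 - z) * h + z * h_tilde

def run_direction(xs, params, H):
    # Run the GRU over xs in the given order, starting from a zero state.
    h = np.zeros(H)
    states = []
    for x in xs:
        h = gru_step(x, h, *params)
        states.append(h)
    return states

rng = np.random.default_rng(0)
T, D, H = 5, 3, 4
xs = rng.normal(size=(T, D))
fwd_params = [rng.normal(size=(H, H + D)) for _ in range(3)]
bwd_params = [rng.normal(size=(H, H + D)) for _ in range(3)]

h_f = run_direction(xs, fwd_params, H)              # h_t^(f) depends on x_1..x_t
h_b = run_direction(xs[::-1], bwd_params, H)[::-1]  # h_t^(b) depends on x_t..x_T
h = np.stack([np.concatenate([f, b]) for f, b in zip(h_f, h_b)])  # (T, 2H)
```

Each row of `h` therefore sees the whole sentence: the left context through the forward half and the right context through the backward half.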

In [66]:
import os
import pickle
import random

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [3]:

# Splitting data into train, validation, and test sets
SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']
tokenized_data_path = 'tokenized_data.pt'
en_vocab_path = 'en_vocab.pkl'
fr_vocab_path = 'fr_vocab.pkl'

with open(en_vocab_path, 'rb') as f:
    en_vocab = pickle.load(f)

with open(fr_vocab_path, 'rb') as f:
    fr_vocab = pickle.load(f)

en_data, fr_data = torch.load(tokenized_data_path)

VALID_PCT = 0.1
TEST_PCT = 0.1

train_data = []
valid_data = []
test_data = []

random.seed(6547)
for en_tensor_, fr_tensor_ in zip(en_data, fr_data):
    en_tensor_ = torch.tensor(en_tensor_)
    fr_tensor_ = torch.tensor(fr_tensor_)
    random_draw = random.random()
    if random_draw <= VALID_PCT:
        valid_data.append((en_tensor_, fr_tensor_))
    elif random_draw <= VALID_PCT + TEST_PCT:
        test_data.append((en_tensor_, fr_tensor_))
    else:
        train_data.append((en_tensor_, fr_tensor_))

print(f"""
  Training pairs: {len(train_data):,}
Validation pairs: {len(valid_data):,}
      Test pairs: {len(test_data):,}""")

# Define special tokens indices
PAD_IDX = en_vocab['<pad>']
BOS_IDX = en_vocab['<bos>']
EOS_IDX = en_vocab['<eos>']

# Ensure that special tokens are the same in both vocabularies
for en_id, fr_id in zip([en_vocab[token] for token in SPECIALS], [fr_vocab[token] for token in SPECIALS]):
    assert en_id == fr_id


# Function to generate a batch of data
def generate_batch(data_batch):
    en_batch, fr_batch = [], []
    for (en_item, fr_item) in data_batch:
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
        fr_batch.append(torch.cat([torch.tensor([BOS_IDX]), fr_item, torch.tensor([EOS_IDX])], dim=0))

    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=False)
    fr_batch = pad_sequence(fr_batch, padding_value=PAD_IDX, batch_first=False)

    return en_batch, fr_batch


# Create DataLoaders
BATCH_SIZE = 16

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch)

# Verify the DataLoader output: map ids back to tokens with inverse vocabularies
en_itos = {idx: token for token, idx in en_vocab.items()}
fr_itos = {idx: token for token, idx in fr_vocab.items()}
for i, (en_id, fr_id) in enumerate(train_iter):
    print('\nEnglish:', ' '.join(en_itos[idx.item()] for idx in en_id[:, 0]))
    print('French:', ' '.join(fr_itos[idx.item()] for idx in fr_id[:, 0]))
    if i == 4: break



  Training pairs: 108,111
Validation pairs: 13,648
      Test pairs: 13,525

English: <bos> this time you ve gone too far . <eos> <pad> <pad>
French: <bos> cette fois tu es alle trop loin . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad>

English: <bos> open up . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
French: <bos> ouvre moi ! <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>

English: <bos> do you think that eating breakfast every day is important ? <eos> <pad>
French: <bos> penses tu que prendre un petit dejeuner chaque jour soit important ? <eos> <pad> <pad> <pad> <pad>

English: <bos> how tall you are ! <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
French: <bos> comme tu es grand ! <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>

English: <bos> it keeps you on your toes . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
French: <bos> ca t oblige a rester vigilante . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


In [64]:
MAX_SENTENCE_LENGTH = 20
FILTER_TO_BASIC_PREFIXES = False
SAVE_DIR = os.path.join(".", "simple_models")

ENCODER_EMBEDDING_DIM = 256
ENCODER_HIDDEN_SIZE = 256
DECODER_EMBEDDING_DIM = 256
DECODER_HIDDEN_SIZE = 256

In [58]:
class BahdanauEncoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout_p):
        super().__init__()
        self.input_dim = input_dim  # ~= Vocabulary size
        self.embedding_dim = embedding_dim
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.dropout_p = dropout_p

        # (S, B) token ids -> (S, B, Emb). nn.Embedding is a lookup table whose weights are initialized from N(0, 1).
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.gru = nn.GRU(embedding_dim, encoder_hidden_dim, bidirectional=True)
        self.linear = nn.Linear(encoder_hidden_dim * 2,
                                decoder_hidden_dim)  # Maps concatenated hidden states (from both directions) to the decoder's hidden.
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        # The embedded sequence is processed by the bidirectional GRU. 
        # The outputs contain the hidden states for all time steps, while hidden contains the final hidden states from both directions.
        outputs, hidden = self.gru(embedded)

        # hidden has shape (2 * num_layers, B, Enc); the last two rows belong to the last layer:
        #   hidden[-2, :, :] is the final hidden state of the forward GRU (after reading the sentence left to right),
        #   hidden[-1, :, :] is the final hidden state of the backward GRU (after reading it right to left).
        # torch.hstack concatenates them along the feature dimension -> (B, 2 * Enc),
        # and self.linear maps the result to (B, Dec) to initialize the decoder's hidden state.
        hidden = torch.tanh(self.linear(torch.hstack((hidden[-2, :, :], hidden[-1, :, :]))))
        return outputs, hidden


<img src="seq2seq.png">

1. The encoder outputs at every source position are used as both keys $\mathbf{k}$ and values $\mathbf{v}$.
2. The decoder's hidden state from time $t-1$ is used as the query $\mathbf{q}$.
3. The attention output $\mathbf{o}$, the context variable (a weighted sum of the encoder outputs), is concatenated with the embedding of the previous target token and fed to the decoder GRU at time $t$.
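The three steps can be sketched for one decoder time step in NumPy, using the same time-major `(S, B, ...)` layout and source padding mask as `BahdanauAttentionQKV` below. The pad index, token ids and weights are made up for illustration, and biases are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
S, B, H = 6, 2, 4  # source length, batch size, hidden size
PAD = 1            # assumed pad index, standing in for this notebook's PAD_IDX

src = np.array([[5, 7], [8, 9], [4, PAD], [3, PAD], [6, PAD], [2, PAD]])  # (S, B), time-major
src_mask = src != PAD                                                        # (S, B), True = real token

enc_out = rng.normal(size=(S, B, 2 * H))  # keys and values: bidirectional encoder outputs
dec_hidden = rng.normal(size=(B, H))      # query: decoder state from step t-1

W_q = rng.normal(size=(H, H))
W_k = rng.normal(size=(2 * H, H))
w = rng.normal(size=H)                    # plays the role of energy_layer

energies = np.tanh(dec_hidden @ W_q + enc_out @ W_k) @ w  # (B,H) + (S,B,H) broadcast -> (S,B,H) -> (S,B)
energies = np.where(src_mask, energies, -np.inf)          # padded positions get zero weight
weights = np.exp(energies - energies.max(axis=0))         # softmax over the source dimension (axis 0)
weights /= weights.sum(axis=0)
context = np.einsum('sb,sbe->be', weights, enc_out)       # (B, 2H): context fed to the decoder GRU
```

The second sentence has only two real tokens, so all of its attention mass lands on the first two source positions.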

In [59]:
class BahdanauAttentionQKV(nn.Module):
    def __init__(self, hidden_size, query_size=None, key_size=None, dropout_p=0.15):
        super().__init__()
        self.hidden_size = hidden_size
        self.query_size = hidden_size if query_size is None else query_size

        # assume bidirectional encoder, but can specify otherwise
        self.key_size = 2 * hidden_size if key_size is None else key_size
        self.query_layer = nn.Linear(self.query_size, self.hidden_size)
        self.key_layer = nn.Linear(self.key_size, self.hidden_size)
        self.energy_layer = nn.Linear(self.hidden_size, 1)  # the vector in front of the tanh: maps each H-dim vector to a scalar score (energy)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, decoder_last_hidden, encoder_outputs, src_mask=None):
        # (B, H)
        query_out = self.query_layer(decoder_last_hidden)
        # (Src, B, 2*H) --> (Src, B, H)
        key_out = self.key_layer(encoder_outputs)
        # (B, H) + (Src, B, H) = (Src, B ,H)
        energy_input = torch.tanh(query_out + key_out)
        # (Src, B, H) --> (Src, B, 1) --> (Src, B)
        energies = self.energy_layer(energy_input).squeeze(2)  # squeeze(2), not squeeze(): a bare squeeze() would also drop Src or B when either equals 1
        # if a mask is provided, remove masked tokens from softmax calc
        if src_mask is not None:
            energies = energies.masked_fill(src_mask == 0, float("-inf"))
        # softmax over the length dimension
        weights = F.softmax(energies, dim=0)  # dim 0 is Src: each batch element gets a distribution over its source positions
        # return as (B, Src) as expected by later multiplication
        return weights.transpose(0, 1)

In [76]:
class BahdanauDecoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, encoder_hidden_dim,
                 decoder_hidden_dim, attention, dropout_p):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.attention = attention  # allowing for custom attention
        self.gru = nn.GRU((encoder_hidden_dim * 2) + embedding_dim,
                          decoder_hidden_dim)
        self.out = nn.Linear((encoder_hidden_dim * 2) + embedding_dim + decoder_hidden_dim,
                             output_dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, hidden, encoder_outputs, src_mask=None):
        '''
        Decode an encoder's output. 

        B: batch size
        S: source sentence length
        T: target sentence length
        O: output size (target vocab size)
        Enc: encoder hidden dim
        Dec: decoder hidden dim
        Emb: embedding dim

        Inputs:
          - input: a vector of length B giving the most recent decoded token
          - hidden: a (B, Dec) most recent RNN hidden state
          - encoder_outputs: (S, B, 2*Enc) sequence of outputs from encoder RNN

        Outputs:
          - output: logits for next token in the sequence (B, O)
          - hidden: a new (B, Dec) RNN hidden state
          - attentions: (B, S) attention weights for the current token over the source sentence
        '''

        # (B) --> (1, B)
        input = input.unsqueeze(0)

        embedded = self.dropout(self.embedding(input))

        attentions = self.attention(hidden, encoder_outputs, src_mask)

        # (B, S) --> (B, 1, S)
        a = attentions.unsqueeze(1)

        # (S, B, 2*Enc) --> (B, S, 2*Enc)
        encoder_outputs = encoder_outputs.transpose(0, 1)

        # weighted encoder representation
        # (B, 1, S) @ (B, S, 2*Enc) = (B, 1, 2*Enc)
        weighted = torch.bmm(a, encoder_outputs)

        # (B, 1, 2*Enc) --> (1, B, 2*Enc)
        weighted = weighted.transpose(0, 1)

        # concat (1, B, Emb) and (1, B, 2*Enc)
        # results in (1, B, Emb + 2*Enc)
        rnn_input = torch.cat((embedded, weighted), dim=2)

        output, hidden = self.gru(rnn_input, hidden.unsqueeze(0))

        assert (output == hidden).all()

        # get rid of empty leading dimensions
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

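        # concatenate the pieces above
        # (B, Dec), (B, 2*Enc), and (B, Emb)
        # result is (B, Dec + 2*Enc + Emb); the context vector and the target embedding act as
        # skip connections around the GRU, so the output layer sees all three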
        linear_input = torch.cat((output, weighted, embedded), dim=1)

        # (B, Dec + 2*Enc + Emb) --> (B, O)
        output = self.out(linear_input)

        return output, hidden.squeeze(0), attentions

In [77]:
# class BahdanauDecoder(nn.Module):
#     def __init__(self, output_dim, embedding_dim, encoder_hidden_dim,
#                  decoder_hidden_dim, attention, dropout_p=0.15):
#         super().__init__()
#         self.embedding_dim = embedding_dim
#         self.output_dim = output_dim
#         self.encoder_hidden_dim = encoder_hidden_dim
#         self.decoder_hidden_dim = decoder_hidden_dim
#         self.dropout_p = dropout_p
#         
#         self.embedding = nn.Embedding(output_dim, embedding_dim)
#         self.attention = attention  # allowing for custom attention
#         self.gru = nn.GRU((encoder_hidden_dim * 2) + embedding_dim, decoder_hidden_dim)     # Now only forward direction
#         self.out = nn.Linear((encoder_hidden_dim * 2) + embedding_dim + decoder_hidden_dim, output_dim)
#         self.dropout = nn.Dropout(dropout_p)
#         
#     def forward(self, input, dec_prev_hidden, encoder_outputs, src_mask=None):
#         # (B) --> (1, B)
#         input = input.unsqueeze(0)
#         embedded = self.dropout(self.embedding(input))
#         attentions = self.attention(dec_prev_hidden, encoder_outputs, src_mask)
#         # (B, S) --> (B, 1, S)
#         a = attentions.unsqueeze(1)
#         # (S, B, 2*Enc) --> (B, S, 2*Enc)
#         encoder_outputs = encoder_outputs.transpose(0, 1)
#         # weighted encoder representation
#         # (B, 1, S) @ (B, S, 2*Enc) = (B, 1, 2*Enc)
#         weighted = torch.bmm(a, encoder_outputs)    # enc_out = values here
#         # (B, 1, 2*Enc) --> (1, B, 2*Enc)
#         weighted = weighted.transpose(0, 1)
#         # concat (1, B, Emb) and (1, B, 2*Enc)
#         # results in (1, B, Emb + 2*Enc)
#         rnn_input = torch.cat((embedded, weighted), dim=2)
#         output, dec_current_hidden = self.gru(rnn_input, dec_prev_hidden.unsqueeze(0))
#         
#         assert (output == dec_prev_hidden)
#         
#         # get rid of empty leading dimensions
#         embedded = embedded.squeeze(0)
#         output = output.squeeze(0)
#         weighted = weighted.squeeze(0)
#         
#         # concatenate the pieces above
#         # (B, Dec), (B, 2*Enc), and (B, Emb)
#         # result is (B, Dec + 2*Enc + Emb)
#         # To take the final decision we add two skip connections so that
#         #   we do so w.r.t. dec_cur_out, weighted_enc_cur_out, targ_emb  
#         linear_input = torch.cat((output, weighted, embedded), dim=1)
#         # (B, Dec + 2*Enc + Emb) --> (B, O)
#         output = self.out(linear_input)
#         return output, dec_current_hidden(0), attentions

In [78]:
class BahdanauSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.device = device
        self.tgt_vocab_size = decoder.output_dim

    def forward(self, src, tgt, src_mask=None, teacher_forcing_ratio=0.5, return_attentions=False):

        tgt_length, batch_size = tgt.shape

        # store decoder outputs
        outputs = torch.zeros(tgt_length, batch_size, self.tgt_vocab_size).to(self.device)
        # attentions = torch.zeros(tgt_length, batch_size, )

        encoder_outputs, hidden = self.encoder(src)
        hidden = hidden.squeeze(1)  # B, 1, Enc --> B, Enc (if necessary)

        # start with <bos> as the decoder input
        decoder_input = tgt[0, :]
        attentions = []

        for t in range(1, tgt_length):
            decoder_output, hidden, attention = self.decoder(decoder_input, hidden, encoder_outputs, src_mask)
            outputs[t] = decoder_output
            teacher_force = random.random() < teacher_forcing_ratio
            top_token = decoder_output.max(1)[1]
            decoder_input = (tgt[t] if teacher_force else top_token)
            attentions.append(attention.unsqueeze(-1))

        if return_attentions:
            return outputs, torch.cat(attentions, dim=-1)
        else:
            return outputs

In [79]:
enc = BahdanauEncoder(input_dim=len(en_vocab),
                      embedding_dim=ENCODER_EMBEDDING_DIM,
                      encoder_hidden_dim=ENCODER_HIDDEN_SIZE,
                      decoder_hidden_dim=DECODER_HIDDEN_SIZE,
                      dropout_p=0.15)

attn = BahdanauAttentionQKV(DECODER_HIDDEN_SIZE)

dec = BahdanauDecoder(output_dim=len(fr_vocab),
                      embedding_dim=DECODER_EMBEDDING_DIM,
                      encoder_hidden_dim=ENCODER_HIDDEN_SIZE,
                      decoder_hidden_dim=DECODER_HIDDEN_SIZE,
                      attention=attn,
                      dropout_p=0.15)

seq2seq = BahdanauSeq2Seq(enc, dec, device)

In [80]:
class MultipleOptimizer(object):
    def __init__(self, *op):
        self.optimizers = op

    def zero_grad(self):
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        for op in self.optimizers:
            op.step()

In [81]:
def train(model, iterator, optimizer, loss_fn, device, clip=None):
    model.train()
    if model.device != device:
        model = model.to(device)

    epoch_loss = 0
    with tqdm(total=len(iterator), leave=False) as t:
        for i, (src, tgt) in enumerate(iterator):
            src_mask = (src != PAD_IDX).to(device)
            src = src.to(device)
            tgt = tgt.to(device)

            optimizer.zero_grad()

            output = model(src, tgt, src_mask)

            loss = loss_fn(output[1:].view(-1, output.shape[2]),
                           tgt[1:].view(-1))

            loss.backward()

            if clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), clip)

            optimizer.step()
            epoch_loss += loss.item()

            avg_loss = epoch_loss / (i + 1)
            t.set_postfix(loss='{:05.3f}'.format(avg_loss),
                          ppl='{:05.3f}'.format(np.exp(avg_loss)))
            t.update()

    return epoch_loss / len(iterator)

In [82]:
def evaluate(model, iterator, loss_fn, device):
    model.eval()
    if model.device != device:
        model = model.to(device)

    epoch_loss = 0
    with torch.no_grad():
        with tqdm(total=len(iterator), leave=False) as t:
            for i, (src, tgt) in enumerate(iterator):
                src_mask = (src != PAD_IDX).to(device)
                src = src.to(device)
                tgt = tgt.to(device)

                output = model(src, tgt, src_mask, teacher_forcing_ratio=0)
                loss = loss_fn(output[1:].view(-1, output.shape[2]),
                               tgt[1:].view(-1))

                epoch_loss += loss.item()

                avg_loss = epoch_loss / (i + 1)
                t.set_postfix(loss='{:05.3f}'.format(avg_loss),
                              ppl='{:05.3f}'.format(np.exp(avg_loss)))
                t.update()

    return epoch_loss / len(iterator)

In [83]:
def count_params(model, return_int=False):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if return_int:
        return params
    else:
        print("There are {:,} trainable parameters in this model.".format(params))

# Training Time

In [None]:
count_params(seq2seq)
enc_optim = torch.optim.AdamW(seq2seq.encoder.parameters(), lr=1e-4)
dec_optim = torch.optim.AdamW(seq2seq.decoder.parameters(), lr=1e-4)
optims = MultipleOptimizer(enc_optim, dec_optim)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
N_EPOCHS = 20
CLIP = 10  # clipping value, or None to prevent gradient clipping
EARLY_STOPPING_EPOCHS = 2

if not os.path.exists(SAVE_DIR):
    print(f"Creating directory {SAVE_DIR}")
    os.mkdir(SAVE_DIR)

model_path = os.path.join(SAVE_DIR, 'bahdanau_en_fr.pt')
bahdanau_metrics = {}
best_valid_loss = float("inf")
early_stopping_count = 0
for epoch in tqdm(range(N_EPOCHS), leave=False, desc="Epoch"):
    train_loss = train(seq2seq, train_iter, optims, loss_fn, device, clip=CLIP)
    valid_loss = evaluate(seq2seq, valid_iter, loss_fn, device)

    if valid_loss < best_valid_loss:
        tqdm.write(f"Checkpointing at epoch {epoch + 1}")
        best_valid_loss = valid_loss
        torch.save(seq2seq.state_dict(), model_path)
        early_stopping_count = 0
    else:
        early_stopping_count += 1

    bahdanau_metrics[epoch + 1] = dict(
        train_loss=train_loss,
        train_ppl=np.exp(train_loss),
        valid_loss=valid_loss,
        valid_ppl=np.exp(valid_loss)
    )

    if early_stopping_count == EARLY_STOPPING_EPOCHS:
        tqdm.write(f"Early stopping triggered in epoch {epoch + 1}")
        break

There are 21,449,906 trainable parameters in this model.


Checkpointing at epoch 1

In [None]:
seq2seq.load_state_dict(torch.load(model_path, map_location=device))
bahdanau_metrics_df = pd.DataFrame(bahdanau_metrics).T

plt.figure(figsize=(10, 6))
plt.plot(bahdanau_metrics_df['train_loss'], label="Training", color='gray', linestyle='solid', lw=2.5)
plt.plot(bahdanau_metrics_df['valid_loss'], label="Validation", color='gray', linestyle='dashed', lw=2.5)
plt.legend()
plt.title("Bahdanau Attention: Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

plt.figure(figsize=(10, 6))
plt.plot(bahdanau_metrics_df['train_ppl'], label="Training", color='gray', linestyle='solid', lw=2.5)
plt.plot(bahdanau_metrics_df['valid_ppl'], label="Validation", color='gray', linestyle='dashed', lw=2.5)
plt.legend()
plt.title("Bahdanau Attention: Perplexity")
plt.xlabel("Epoch")
plt.ylabel("Perplexity")
plt.show()

# Transformer time
<img src="transformer_architecture.png" width="500">

In [None]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx}
    Args:
        d_model: the embed dim (required).
        dropout_p: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=100).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """
    def __init__(self, d_model, dropout_p=0.1, max_len=100):
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.) / d_model))  # 1 / 10000^(2i/d_model) for each even index 2i; a step other than 2 would not match the 0::2 / 1::2 slices below
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        # pe.shape -> (l, 1, d)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, d_model, num_attention_heads,
                 num_encoder_layers, num_decoder_layers, dim_feedforward,
                 max_seq_length, pos_dropout, transformer_dropout):
        super().__init__()
        self.d_model = d_model
        self.embed_src = nn.Embedding(input_dim, d_model)
        self.embed_tgt = nn.Embedding(output_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)
        self.transformer = nn.Transformer(d_model=d_model, nhead=num_attention_heads,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward, dropout=transformer_dropout)
        self.output = nn.Linear(d_model, output_dim)
        
    def forward(self, src, tgt,
                src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """
        Forward pass for the Transformer model.

        The attention masks (src_mask, tgt_mask) are square because in self-attention every token of a sequence
        attends to every token of the same sequence, including itself: element (i, j) controls whether token i
        may attend to token j. The key padding masks are not square; they have one entry per
        (batch element, position) and mark the padding tokens to ignore.

        :param src: Tensor of shape (S, N), where S is the source sequence length and N is the batch size.
                    This tensor contains the indices of the source sequence tokens in the source vocabulary.
        :param tgt: Tensor of shape (T, N), where T is the target sequence length and N is the batch size.
                    This tensor contains the indices of the target sequence tokens in the target vocabulary.
        :param src_mask: Tensor of shape (S, S) or None. Optional attention mask for the encoder self-attention;
                         usually None, since the encoder may look at the whole source sentence.
        :param tgt_mask: Tensor of shape (T, T) or None. Attention mask for the decoder self-attention,
                         typically a causal mask that prevents attending to future tokens.
        :param src_key_padding_mask: Tensor of shape (N, S) or None. Marks padding tokens in the source sequences
                                     (True = ignore), ensuring they do not affect the attention mechanism.
        :param tgt_key_padding_mask: Tensor of shape (N, T) or None. Marks padding tokens in the target sequences
                                     (True = ignore), ensuring they do not affect the attention mechanism.
        :param memory_key_padding_mask: Tensor of shape (N, S) or None. Marks padding tokens in the encoder outputs
                                        (memory) during decoding (True = ignore); usually the same as src_key_padding_mask.
        :return: Tensor of shape (T, N, V), where V is the target vocabulary size (output_dim).
                 This tensor contains the output logits of the Transformer model for each token in the target sequence.
        """
        src_embedded = self.embed_src(src) * np.sqrt(self.d_model)
        tgt_embedded = self.embed_tgt(tgt) * np.sqrt(self.d_model)
        
        src_embedded = self.pos_enc(src_embedded)
        tgt_embedded = self.pos_enc(tgt_embedded)
        
        output = self.transformer(src_embedded, tgt_embedded,
                                  src_mask=src_mask,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=memory_key_padding_mask)
        
        return self.output(output)

In [None]:
transformer = TransformerModel(input_dim=len(en_vocab), output_dim=len(fr_vocab), d_model=256,
                               num_attention_heads=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048,
                               max_seq_length=32, pos_dropout=0.15, transformer_dropout=0.3
                               )
transformer = transformer.to(device)
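To call this model, the masks have to be built in the layouts `nn.Transformer` expects: a `(T, T)` float mask with `-inf` above the diagonal for `tgt_mask`, and `(N, L)` boolean key padding masks where `True` marks a position to ignore. The NumPy sketch below builds them from time-major `(L, N)` batches like the ones `generate_batch` produces. The token ids are made up, `PAD_IDX = 1` is an assumed value, and the `tgt[:-1]` / `tgt[1:]` shift is the usual teacher-forcing setup rather than something this notebook has defined yet. In PyTorch the arrays can be converted with `torch.from_numpy`, and `nn.Transformer.generate_square_subsequent_mask` builds the same causal mask.

```python
import numpy as np

PAD_IDX = 1  # assumed; in this notebook PAD_IDX comes from en_vocab['<pad>']

def square_subsequent_mask(size):
    # Float mask added to the attention scores: 0 on and below the diagonal, -inf above it.
    mask = np.zeros((size, size), dtype=np.float32)
    mask[np.triu_indices(size, k=1)] = -np.inf
    return mask

def key_padding_mask(batch):
    # batch: (L, N) time-major token ids -> (N, L) boolean, True where the position is padding
    return (batch == PAD_IDX).T

# Illustrative token ids, time-major (L, N) with N = 2 sentences
src = np.array([[2, 2], [10, 11], [3, 12], [PAD_IDX, 3]])                      # (S=4, N=2)
tgt = np.array([[2, 2], [20, 21], [22, 3], [3, PAD_IDX], [PAD_IDX, PAD_IDX]])  # (T=5, N=2)

tgt_input = tgt[:-1]  # decoder input drops the last token; the loss targets would be tgt[1:]
tgt_mask = square_subsequent_mask(tgt_input.shape[0])  # (T-1, T-1)
src_key_padding_mask = key_padding_mask(src)           # (N, S)
tgt_key_padding_mask = key_padding_mask(tgt_input)     # (N, T-1)
memory_key_padding_mask = src_key_padding_mask         # cross-attention ignores source padding too
```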