In [None]:
import math, copy, os, time, enum, argparse

import matplotlib.pyplot as plt
import seaborn

import torch
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.hub import download_url_to_file

from torchtext.data import Dataset, BucketIterator, Field, Example
from torchtext.data.utils import interleave_keys
from torchtext import datasets
from torchtext.data import Example
import spacy

from nltk.translate.bleu_score import corpus_bleu

In [None]:
# Hyper-parameters of the baseline transformer configuration.
BASELINE_MODEL_NUMBER_OF_LAYERS = 6
BASELINE_MODEL_DIMENSION = 512
BASELINE_MODEL_NUMBER_OF_HEADS = 8
BASELINE_MODEL_DROPOUT_PROB = 0.1
BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1

# Working directories, all rooted at the current working directory.
CHECKPOINTS_PATH = os.path.join(os.getcwd(), 'models', 'checkpoints')
BINARIES_PATH = os.path.join(os.getcwd(), 'models', 'binaries')
DATA_DIR_PATH = os.path.join(os.getcwd(), 'data')

# Create the directories up front; exist_ok keeps the cell re-runnable.
for _dir_path in (CHECKPOINTS_PATH, BINARIES_PATH, DATA_DIR_PATH):
    os.makedirs(_dir_path, exist_ok = True)

# Special tokens. The pad token previously lacked its closing '>' ('<pad'),
# which made it inconsistent with the other two markers.
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
PAD_TOKEN = '<pad>'

# PART 1 : Understanding the model.

In [None]:
class Transformer(torch.nn.Module):
    """Encoder-decoder transformer mapping source/target token ids to target log probabilities.

    Args:
        model_dim: Width of every token representation.
        src_vocab_size: Size of the source embedding table.
        tar_vocab_size: Size of the target embedding table and of the generator output.
        n_heads: Head count forwarded to MultiHeadAttention.
        n_layers: Number of layers in the encoder stack and in the decoder stack.
        p_dropout: Dropout probability forwarded to the sub-modules.
        log_att_w: Accepted for interface compatibility; this class does not use it.
    """

    def __init__(self, model_dim, src_vocab_size, tar_vocab_size, n_heads, n_layers, p_dropout, log_att_w = False):
        super().__init__()

        # Embeds source/target token ids into embedding vectors.
        self.src_embedding = Embedding(src_vocab_size, model_dim)
        self.tar_embedding = Embedding(tar_vocab_size, model_dim)

        # Positional encodings (the class defined in this notebook is PositionEncoding).
        self.src_pos_embedding = PositionEncoding(model_dim, p_dropout)
        self.tar_pos_embedding = PositionEncoding(model_dim, p_dropout)

        # Prototype modules: they end up deep-copied when the layers are cloned into stacks.
        # NOTE(review): MultiHeadAttention is not defined in the visible part of the notebook;
        # its (model_dim, n_heads, p_dropout) signature is assumed - confirm.
        mha = MultiHeadAttention(model_dim, n_heads, p_dropout)
        pwn = PositionwiseFeedForwardNet(model_dim, p_dropout)
        encoder_layer = EncoderLayer(model_dim, p_dropout, mha, pwn)
        decoder_layer = DecoderLayer(model_dim, p_dropout, mha, pwn)

        # Encoder/Decoder stacks.
        self.encoder = Encoder(encoder_layer, n_layers)
        self.decoder = Decoder(decoder_layer, n_layers)

        # Converts target token representations into log probability vectors of target vocabulary size.
        # Log probabilities are used because torch's KLDivLoss expects them.
        self.decoder_generator = DecoderGenerator(model_dim, tar_vocab_size)
        self.init_params()

    def init_params(self):
        """Apply Xavier-uniform initialisation to every parameter with more than one dimension."""
        for p in self.parameters():
            # Biases and LayerNorm weights are 1-D and keep their default initialisation.
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

    def forward(self, src_token_ids_batch, tar_token_ids_batch, src_mask, tar_mask):
        """Encode the source batch, then decode the target batch against it.

        Returns:
            The log probabilities produced by `decode`, flattened to (B*T, V).
        """
        src_repr_batch = self.encode(src_token_ids_batch, src_mask)
        tar_log_probs = self.decode(tar_token_ids_batch, src_repr_batch, tar_mask, src_mask)
        return tar_log_probs

    def encode(self, src_token_ids_batch, src_mask):
        """Embed the source token ids and run them through the encoder stack.

        Args:
            src_token_ids_batch: Source token ids of shape (B, S); the embedding layer asserts 2 dimensions.
            src_mask: Mask forwarded unchanged to the encoder; shape not visible here - TODO confirm.

        Returns:
            Source representations of shape (B, S, D).
        """
        src_embeddings_batch = self.src_embedding(src_token_ids_batch) # (B, S) -> (B, S, D)
        src_embeddings_batch = self.src_pos_embedding(src_embeddings_batch) # add positional encodings.
        src_repr_batch = self.encoder(src_embeddings_batch, src_mask) # forward pass through the encoder.
        return src_repr_batch

    def decode(self, tar_token_ids_batch, src_repr_batch, tar_mask, src_mask):
        """Embed the target token ids, decode them against the source representations and score the vocabulary.

        Args:
            tar_token_ids_batch: Target token ids of shape (B, T).
            src_repr_batch: Encoder output as returned by `encode`.
            tar_mask: Mask forwarded to the decoder for target self-attention.
            src_mask: Mask forwarded to the decoder for attention over the source.

        Returns:
            Log probabilities reshaped to (B*T, V) so they can be fed to a KL-divergence loss.
        """
        # The embedding layer only takes the token ids.
        tar_embeddings_batch = self.tar_embedding(tar_token_ids_batch) # (B, T) -> (B, T, D)
        tar_embeddings_batch = self.tar_pos_embedding(tar_embeddings_batch) # add positional encodings.

        # Shape (B, T, D): batch size, longest target sequence length, model dimension.
        tar_repr_batch = self.decoder(tar_embeddings_batch, src_repr_batch, tar_mask, src_mask)
        # Shape (B, T, V) with V the target vocabulary size.
        # The decoder generator does linear projection + log softmax.
        tar_log_probs = self.decoder_generator(tar_repr_batch)
        # Flatten to (B*T, V).
        tar_log_probs = tar_log_probs.reshape(-1, tar_log_probs.shape[-1])
        return tar_log_probs

1. Source token ids are embedded into sequences of vectors.
2. The encoder takes a batch of these embedded source sequences.
3. The encoder mixes the source sequences via attention through its stack of layers (6 in the baseline transformer).
4. The final output is consumed by the decoder.

In [None]:
class Encoder(torch.nn.Module):
    """Stack of cloned encoder layers followed by a final LayerNorm.

    Args:
        encoder_layer: Prototype EncoderLayer; it is deep-copied `n_layers` times.
        n_layers: Number of layers in the stack.
    """

    def __init__(self, encoder_layer, n_layers):
        super().__init__()
        assert isinstance(encoder_layer, EncoderLayer), f'Expected Encoder layer, got {type(encoder_layer)} instead !'

        # Independent copies of the prototype layer.
        self.encoder_layers = get_clones(encoder_layer, n_layers)
        # The layer exposes its width as `model_dimension` (same attribute name the Decoder reads).
        self.norm = torch.nn.LayerNorm(encoder_layer.model_dimension)

    def forward(self, src_embeddings_batch, src_mask):
        """Run the embeddings through every encoder layer, then normalise the result.

        Args:
            src_embeddings_batch: Embedded source batch.
            src_mask: Mask handed to each layer; its role is to ignore padded token
                representations in the multi-head self-attention module.

        Returns:
            The layer-normalised output of the last encoder layer.
        """
        src_representations_batch = src_embeddings_batch
        for encoder_layer in self.encoder_layers:
            src_representations_batch = encoder_layer(src_representations_batch, src_mask)
        return self.norm(src_representations_batch)

In [None]:
class EncoderLayer(torch.nn.Module):
    """One encoder layer: a self-attention sublayer followed by a position-wise feed-forward sublayer.

    Args:
        model_dim: Width of the token representations.
        p_dropout: Dropout probability used by the sublayer wrappers.
        multihead_att: Attention module called with `query`, `key`, `value` and `mask` keyword arguments.
        pointwise_net: Feed-forward module called with the representations batch.
    """

    def __init__(self, model_dim, p_dropout, multihead_att, pointwise_net):
        super().__init__()
        n_sublayers_encoder = 2
        self.sublayers = get_clones(SublayerLogic(model_dim, p_dropout), n_sublayers_encoder)

        self.multihead_att = multihead_att
        self.pointwise_net = pointwise_net

        # Previously assigned from an undefined name (`model_dimension`), which raised NameError.
        # The width is exposed under both spellings so either attribute can be used to size a LayerNorm.
        self.model_dimension = model_dim
        self.model_dim = model_dim

    def forward(self, src_repr_batch, src_mask):
        """Apply self-attention then the feed-forward net, each wrapped by its sublayer logic.

        Args:
            src_repr_batch: Source representations entering the layer.
            src_mask: Mask passed to the attention module.

        Returns:
            The source representations after both sublayers.
        """
        # Wrap attention in a one-argument callable so both sublayers share the same interface:
        # the sublayer logic receives the data and the operation to run on it.
        encoder_self_attention = lambda srb: self.multihead_att(query = srb, key = srb, value = srb, mask = src_mask)
        src_repr_batch = self.sublayers[0](src_repr_batch, encoder_self_attention)
        src_repr_batch = self.sublayers[1](src_repr_batch, self.pointwise_net)

        return src_repr_batch


 1. The decoder takes a batch of target sequences of embedded tokens.
 2. It mixes them via attention through its stack of layers (6 in the baseline transformer), while also attending to the source token representations.
 3. The final target token representations are sent into the decoder generator.
 4. The decoder generator converts them to log probabilities over the target vocabulary.

The decoder uses causal masking to prevent tokens from looking into the future.

<img src="images/causal_mask.PNG">

In [None]:
class Decoder(torch.nn.Module):
    """Stack of cloned decoder layers followed by a final LayerNorm.

    Args:
        decoder_layer: Prototype DecoderLayer; it is deep-copied `n_layers` times.
        n_layers: Number of layers in the stack.
    """

    def __init__(self, decoder_layer, n_layers):
        super().__init__()
        assert isinstance(decoder_layer, DecoderLayer), f'Expected DecoderLayer, got {type(decoder_layer)} !'

        # Stored under the name `forward` iterates over (it was previously saved as
        # `encoder_layers`, so the forward pass failed with AttributeError).
        self.decoder_layers = get_clones(decoder_layer, n_layers)
        self.norm = torch.nn.LayerNorm(decoder_layer.model_dimension)

    def forward(self, tar_embedding_batch, src_repr_batch, tar_mask, src_mask):
        """Run the target embeddings through every decoder layer, then normalise the result.

        Args:
            tar_embedding_batch: Embedded target batch.
            src_repr_batch: Source representations each layer attends to.
            tar_mask: Target mask; masks pad tokens and future tokens.
            src_mask: Source mask handed to each layer.

        Returns:
            The layer-normalised output of the last decoder layer.
        """
        tar_repr_batch = tar_embedding_batch

        # Forward pass through the decoder stack.
        for decoder_layer in self.decoder_layers:
            tar_repr_batch = decoder_layer(tar_repr_batch, src_repr_batch, tar_mask, src_mask)

        return self.norm(tar_repr_batch)
    
class DecoderLayer(torch.nn.Module):
    """One decoder layer: target self-attention, attention over the source, then a feed-forward net.

    Args:
        model_dim: Width of the token representations.
        p_dropout: Dropout probability used by the sublayer wrappers.
        multihead_att: Attention module; two independent deep copies of it are kept.
        pointwise_net: Feed-forward module called with the representations batch.
    """

    def __init__(self, model_dim, p_dropout, multihead_att, pointwise_net):
        super().__init__()
        n_sublayers_decoder = 3
        self.sublayers = get_clones(SublayerLogic(model_dim, p_dropout), n_sublayers_decoder)
        # Separate copies so self-attention and source-attention do not share weights.
        self.tar_multihead_att = copy.deepcopy(multihead_att)
        self.src_multihead_att = copy.deepcopy(multihead_att)
        self.pointwise_net = pointwise_net
        self.model_dimension = model_dim

    def forward(self, tar_repr_batch, src_repr_batch, tar_mask, src_mask):
        """Apply the three sublayers in order and return the updated target representations.

        Args:
            tar_repr_batch: Target representations entering the layer.
            src_repr_batch: Source representations used as key and value of the second attention.
            tar_mask: Mask passed to the target self-attention.
            src_mask: Mask passed to the attention over the source.
        """
        # Masks and source representations are not arguments of the lambdas; they are captured by closure.
        srb = src_repr_batch
        decoder_tar_self_att = lambda trb: self.tar_multihead_att(query = trb, key = trb, value = trb, mask = tar_mask)
        decoder_src_att = lambda trb: self.src_multihead_att(query = trb, key = srb, value = srb, mask = src_mask)

        # Self-attention sublayer, then source-attending sublayer, then position-wise feed-forward sublayer.
        tar_repr_batch = self.sublayers[0](tar_repr_batch, decoder_tar_self_att)
        tar_repr_batch = self.sublayers[1](tar_repr_batch, decoder_src_att)
        tar_repr_batch = self.sublayers[2](tar_repr_batch, self.pointwise_net)
        return tar_repr_batch

The **decoder generator**:
1. Projects the final decoder token representations from the model dimension to the target vocabulary size (D --> V).
2. Applies log softmax.

In [None]:
class DecoderGenerator(torch.nn.Module):
    """Maps decoder token representations to log probabilities over the vocabulary.

    Args:
        model_dim: Size of the last dimension of the incoming representations.
        vocab_size: Number of vocabulary entries scored for each token.
    """

    def __init__(self, model_dim, vocab_size):
        super().__init__()
        # Projects the last dimension from model_dim (D) to vocab_size (V).
        self.linear = torch.nn.Linear(model_dim, vocab_size)
        # Normalises over the last (vocabulary) dimension.
        self.log_softmax = torch.nn.LogSoftmax(dim = -1)

    def forward(self, tar_repr_batch):
        """Return log-softmax scores; the last dimension of the output has size vocab_size."""
        vocab_scores = self.linear(tar_repr_batch)
        return self.log_softmax(vocab_scores)

# Positional encoding

This is what the positional encoding traditionally looks like:
<img src="images/pos_encoding.jpg">

In [None]:
class SublayerLogic(torch.nn.Module):
    """Wraps a sublayer with layer normalisation, dropout and a residual connection.

    Args:
        model_dim: Size of the normalised (last) dimension.
        p_dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, model_dim, p_dropout):
        super().__init__()
        self.norm = torch.nn.LayerNorm(model_dim)
        self.dropout = torch.nn.Dropout(p = p_dropout)

    def forward(self, repr_batch, sublayer_module):
        """Return `repr_batch + dropout(sublayer_module(norm(repr_batch)))`.

        The original paper applies layer norm after the residual connection; here it is
        applied before the sublayer, which the author reports as more effective.
        """
        normalised_batch = self.norm(repr_batch)
        sublayer_output = sublayer_module(normalised_batch)
        return repr_batch + self.dropout(sublayer_output)


In [None]:
class PositionwiseFeedForwardNet(torch.nn.Module):
    """Two-layer feed-forward net applied to the last dimension of the representations.

    "Position-wise" because the same network is applied independently to every token's
    representation, i.e. it is equivalent to looping over the batch and sequence
    dimensions and applying the network to each token vector.

    Args:
        model_dim: Input and output width.
        p_dropout: Dropout probability applied after the activation.
        width_mult: Hidden width expressed as a multiple of model_dim.
    """

    def __init__(self, model_dim, p_dropout, width_mult = 4):
        super().__init__()
        hidden_dim = width_mult*model_dim
        self.linear1 = torch.nn.Linear(model_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, model_dim)

        # Dropout here is not mentioned in the paper but is commonly used against overfitting.
        self.dropout = torch.nn.Dropout(p = p_dropout)
        self.relu = torch.nn.ReLU()

    def forward(self, repr_batch):
        """Expand, activate, drop out, then project back; the input shape (B, S/T, D) is preserved."""
        hidden = self.linear1(repr_batch)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [None]:
class Embedding(torch.nn.Module):
    """Token-id lookup table whose vectors are scaled by sqrt(model_dim).

    Args:
        vocab_size: Number of rows in the lookup table.
        model_dim: Width of each embedding vector.
    """

    def __init__(self, vocab_size, model_dim):
        super().__init__()
        self.embeddings_table = torch.nn.Embedding(vocab_size, model_dim)
        self.model_dim = model_dim

    def forward(self, token_ids_batch):
        """Look up the ids of a (B, S/T) batch and return scaled vectors of shape (B, S/T, D).

        Raises:
            AssertionError: If `token_ids_batch` is not 2-dimensional.
        """
        assert token_ids_batch.ndim == 2, f'Expected : (batch_size, max_token_seq_length), got {token_ids_batch.shape}'
        # Every token id maps to one row of the table.
        looked_up = self.embeddings_table(token_ids_batch)
        # Scaling by the square root of the model dimension, as stated in the paper.
        scale = math.sqrt(self.model_dim)
        return looked_up*scale

In [None]:
class PositionEncoding(torch.nn.Module):
    """Adds fixed sinusoidal positional encodings to token embeddings, then applies dropout.

    Args:
        model_dim: Width of the embeddings the encodings are added to.
        p_dropout: Dropout probability applied to the sum.
        expected_max_seq_length: Number of positions pre-computed in the table.
    """

    def __init__(self, model_dim, p_dropout, expected_max_seq_length = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p = p_dropout)
        # As suggested in the paper, the sinusoid frequencies form a geometric progression.
        position_id = torch.arange(0, expected_max_seq_length).unsqueeze(1) # (L, 1)
        frequencies = torch.pow(10000., -torch.arange(0, model_dim, 2, dtype = torch.float)/model_dim) # (ceil(D/2),)
        angles = position_id*frequencies # broadcasts to (L, ceil(D/2))

        positional_encodings_table = torch.zeros(expected_max_seq_length, model_dim)
        # Fill every row (the sine assignment previously targeted row 0 only, which is a shape mismatch).
        positional_encodings_table[:, 0::2] = torch.sin(angles) # sine on even columns.
        # For an odd model_dim there is one fewer odd column than even ones, hence the slice.
        positional_encodings_table[:, 1::2] = torch.cos(angles[:, :model_dim//2]) # cosine on odd columns.

        # Registered as a buffer so the table is saved in the state dict even though it is not trainable.
        self.register_buffer('positional_encodings_table', positional_encodings_table)

    def forward(self, embeddings_batch):
        """Return `dropout(embeddings_batch + encodings)` for a (B, S/T, D) batch.

        Raises:
            AssertionError: If the input is not 3-dimensional with last dimension equal to the
                table width, or if the sequence is longer than the pre-computed table.
        """
        assert embeddings_batch.ndim == 3 and embeddings_batch.shape[-1] == self.positional_encodings_table.shape[1], f'Expected (batch_size, max_token_sequence_length, model_dimension and got {embeddings_batch.shape}'
        assert embeddings_batch.shape[1] <= self.positional_encodings_table.shape[0], f'Sequence length {embeddings_batch.shape[1]} exceeds the {self.positional_encodings_table.shape[0]} pre-computed positions'
        # Slice to (S/T, D); broadcasting expands it to (B, S/T, D) in the addition below.
        positional_encodings = self.positional_encodings_table[:embeddings_batch.shape[1]]
        # Dropout is applied to the sum of the positional encodings and the token embeddings.
        return self.dropout(embeddings_batch + positional_encodings)


# Backward-compatible alias for code that refers to the class as `PositionalEncoding`.
PositionalEncoding = PositionEncoding

In [None]:
def get_clones(module, n_deep_clones):
    """Return a ModuleList holding `n_deep_clones` independent deep copies of `module`.

    Args:
        module: The torch module to replicate; it is not itself placed in the list.
        n_deep_clones: Number of copies to create.

    Returns:
        torch.nn.ModuleList of deep copies, so each clone's weights can be updated independently.
    """
    # The loop bound previously referenced an undefined name (`n_deep_copies`).
    return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(n_deep_clones)])