In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

import math
import copy
import os
import time
import enum
import argparse

# Visualization related imports
import matplotlib.pyplot as plt
import seaborn

# Deep learning related imports
import torch
import torch.nn as nn
from torch.optim import Adam

In [2]:
# !pip install -U torchtext
# !conda install -y torchtext==0.8.0 -c pytorch
# !pip install torchtext==0.8.1
# !pip uninstall torchtext -y
# !pip install torchtext

    


# Language Translation with ``nn.Transformer`` and torchtext

This tutorial shows:
    - How to train a translation model from scratch using Transformer.
    - How to use the torchtext library to access the [Multi30k](http://www.statmt.org/wmt16/multimodal-task.html#task1) dataset to train a translation model (English to German, as configured in the code below).


## Data Sourcing and Processing

The [torchtext library](https://pytorch.org/text/stable/) has utilities for creating datasets that can be easily
iterated through for the purposes of creating a language translation
model. In this example, we show how to use torchtext's inbuilt datasets,
tokenize a raw text sentence, build a vocabulary, and numericalize tokens into tensors. We will use the
[Multi30k dataset from torchtext library](https://pytorch.org/text/stable/datasets.html#multi30k)
that yields a pair of source-target raw sentences.

To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.




In [3]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List


# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
# Only the "train" and "valid" entries of the URL table are overridden here; any other entry
# keeps whatever value torchtext ships with.
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Translation direction: English source, German target.
# The tokenizer cell further down registers the English spaCy model under SRC_LANGUAGE and the
# German one under TGT_LANGUAGE, so flipping the direction means updating that cell as well.
SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'de'


# Registries keyed by language code and filled in by later cells:
#   token_transform[lang] -> tokenizer callable (raw sentence -> list of tokens)
#   vocab_transform[lang] -> vocabulary built from the training split
token_transform = {}
vocab_transform = {}




In [4]:
# Literal spellings of the begin-of-sequence, end-of-sequence and padding markers
BOS_TOKEN, EOS_TOKEN, PAD_TOKEN = '<bos>', '<eos>', '<pad>'

In [5]:
# !pip install -U torchdata
# !pip install -U spacy
# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm

Create the source and target language tokenizers. Make sure to install the dependencies first:

```bash
pip install -U torchdata
pip install -U spacy
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
```


In [6]:
# !pip install -U portalocker
import portalocker

In [7]:
# Register one spaCy-backed tokenizer per language code:
# the German model for the target side, the English model for the source side.
token_transform.update({
    TGT_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    SRC_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
})


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sentence pair produced by ``data_iter``.

    Args:
        data_iter: iterable of indexable samples; element 0 is read for ``SRC_LANGUAGE`` and
            element 1 for ``TGT_LANGUAGE``.
        language: language code selecting both the sample element and the tokenizer registered
            in ``token_transform``.

    Yields:
        The token list returned by ``token_transform[language]`` for each sample, in order.

    Raises:
        KeyError: when ``language`` has no registered tokenizer or is neither the source nor the
            target language (raised on the first ``next()`` call, since this is a generator).
    """
    # Resolve the tokenizer and the sample column once instead of once per sample
    tokenize = token_transform[language]
    column = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[column])

# Reserved vocabulary ids for the unknown, padding, begin-of-sequence and end-of-sequence symbols
UNK_IDX = 0
PAD_IDX = 1
BOS_IDX = 2
EOS_IDX = 3
# Listed in id order: the position of each symbol in this list must equal its id above
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    # The training iterator is created anew for every language
    # (presumably because a full pass over it exhausts it - verify against torchtext).
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

    # Build torchtext's Vocab object; the special symbols are placed first so they get ids 0..3
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )

    # Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
    # If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
    vocab_transform[ln].set_default_index(UNK_IDX)

## Seq2Seq Network using Transformer

Transformer is a Seq2Seq model introduced in the [“Attention is all you
need”](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
paper for solving machine translation tasks.
Below, we will create a Seq2Seq network that uses Transformer. The network
consists of three parts. The first part is the embedding layer. This layer converts a tensor of input indices
into the corresponding tensor of input embeddings. These embeddings are further augmented with positional
encodings to provide position information of input tokens to the model. The second part is the
actual [Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html) model.
Finally, the output of the Transformer model is passed through a linear layer followed by a log-softmax,
which gives log-probabilities for each token in the target language.




In [8]:
# torch.__version__

In [9]:
# print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))

In [10]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
# Prefer the second GPU ('cuda:1') when the machine actually has one, fall back to the first GPU
# otherwise, and to the CPU when CUDA is unavailable. Selecting 'cuda:1' unconditionally yields an
# invalid device on single-GPU hosts, which only surfaces later at the first ``.to(DEVICE)``.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')



In [11]:
class Transformer(nn.Module):
    """Encoder-decoder translation model assembled from the building blocks defined in this notebook.

    Args:
        model_dimension: width D of every token representation.
        src_vocab_size / trg_vocab_size: sizes of the source and target embedding tables.
        number_of_heads: number of attention heads in every attention module.
        number_of_layers: depth of both the encoder stack and the decoder stack.
        dropout_probability: dropout rate handed to every sub-module that takes one.
        log_attention_weights: forwarded to the attention modules (whether they cache their weights).

    Note:
        The attention implementation is selected by the notebook-level global ``attn_type``, which has
        to be defined before this class is instantiated: ``'qkv'`` picks ``MultiHeadedAttention``, any
        other value picks ``MultiHeadedAttention_KV``.
    """

    def __init__(self, model_dimension, src_vocab_size, trg_vocab_size, number_of_heads, number_of_layers, dropout_probability, log_attention_weights=False):
        super().__init__()

        # Token id -> embedding vector lookups for the two vocabularies
        self.src_embedding = Embedding(src_vocab_size, model_dimension)
        self.trg_embedding = Embedding(trg_vocab_size, model_dimension)

        # Positional information is added on top of the token vectors
        # (without it the model could not tell token order apart)
        self.src_pos_embedding = PositionalEncoding(model_dimension, dropout_probability)
        self.trg_pos_embedding = PositionalEncoding(model_dimension, dropout_probability)

        # Prototype modules: the layer/stack constructors below deep-copy them, so nothing is shared
        attention_cls = MultiHeadedAttention if attn_type == 'qkv' else MultiHeadedAttention_KV
        attention_proto = attention_cls(model_dimension, number_of_heads, dropout_probability, log_attention_weights)
        feed_forward_proto = PositionwiseFeedForwardNet(model_dimension, dropout_probability)
        encoder_proto = EncoderLayer(model_dimension, dropout_probability, attention_proto, feed_forward_proto)
        decoder_proto = DecoderLayer(model_dimension, dropout_probability, attention_proto, feed_forward_proto)

        # The two stacks, each holding number_of_layers copies of its prototype layer
        self.encoder = Encoder(encoder_proto, number_of_layers)
        self.decoder = Decoder(decoder_proto, number_of_layers)

        # Maps final target representations to log-probability vectors over the target vocabulary
        # (log space because, per the original notes, the loss used is nn.KLDivLoss)
        self.decoder_generator = DecoderGenerator(model_dimension, trg_vocab_size)

        self.init_params()

    def init_params(self):
        """Xavier-uniform initialise every parameter that has more than one dimension.

        Not described in the paper; the original author reports a large impact on model quality
        compared with PyTorch's default initialisation. One-dimensional parameters are left as-is.
        """
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def forward(self, src_token_ids_batch, trg_token_ids_batch, src_mask, trg_mask):
        """Run the encoder then the decoder; returns whatever ``decode`` returns."""
        encoded_source = self.encode(src_token_ids_batch, src_mask)
        return self.decode(trg_token_ids_batch, encoded_source, trg_mask, src_mask)

    # encode/decode are separate entry points so that translation code can encode the source once
    def encode(self, src_token_ids_batch, src_mask):
        """Embed the source token ids, add positional encodings and run the encoder stack.

        The result keeps the (B, S, D) layout: batch size, longest source sequence, model dimension.
        """
        embedded = self.src_pos_embedding(self.src_embedding(src_token_ids_batch))
        return self.encoder(embedded, src_mask)

    def decode(self, trg_token_ids_batch, src_representations_batch, trg_mask, src_mask):
        """Embed the target token ids, run the decoder stack and project onto the target vocabulary.

        Returns log-probabilities flattened to two dimensions, (B*T, V) for a (B, T, V) generator
        output, which is the layout the loss is fed with.
        """
        embedded = self.trg_pos_embedding(self.trg_embedding(trg_token_ids_batch))

        # (B, T, D): batch size, longest target sequence, model dimension
        decoded = self.decoder(embedded, src_representations_batch, trg_mask, src_mask)

        # (B, T, V): linear projection followed by log-softmax over the vocabulary
        log_probs = self.decoder_generator(decoded)

        # Collapse all leading dimensions, keeping the vocabulary dimension last
        return log_probs.reshape(-1, log_probs.shape[-1])
    
    
    
    

In [12]:
class Encoder(nn.Module):
    """Stack of independent copies of one encoder layer, followed by a final LayerNorm."""

    def __init__(self, encoder_layer, number_of_layers):
        super().__init__()
        assert isinstance(encoder_layer, EncoderLayer), f'Expected EncoderLayer got {type(encoder_layer)}.'

        # number_of_layers deep copies of the prototype, so every layer owns its weights
        self.encoder_layers = get_clones(encoder_layer, number_of_layers)
        self.norm = nn.LayerNorm(encoder_layer.model_dimension)

    def forward(self, src_embeddings_batch, src_mask):
        """Feed the embeddings through every layer in order and normalise the result.

        ``src_mask`` is handed unchanged to each layer (per the original notes it hides padded
        tokens from self-attention).
        """
        # Starts out as plain embeddings and becomes a richer representation after each layer
        hidden = src_embeddings_batch
        for layer in self.encoder_layers:
            hidden = layer(hidden, src_mask)

        # Final normalisation: the sublayers normalise their input rather than their output,
        # so the output of the last layer would otherwise stay un-normalised
        return self.norm(hidden)


class EncoderLayer(nn.Module):
    """One encoder layer: a self-attention sublayer followed by a position-wise feed-forward sublayer."""

    def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net):
        super().__init__()
        # One residual/normalisation wrapper per sublayer: [0] self-attention, [1] feed-forward
        self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability), 2)

        # Stored as given (not copied here)
        self.multi_headed_attention = multi_headed_attention
        self.pointwise_net = pointwise_net

        self.model_dimension = model_dimension

    def forward(self, src_representations_batch, src_mask):
        """Apply self-attention and then the feed-forward net, each through its sublayer wrapper."""

        # The wrappers expect a one-argument callable, so the mask is captured by this closure
        def self_attend(tokens):
            return self.multi_headed_attention(query=tokens, key=tokens, value=tokens, mask=src_mask)

        attention_sublayer, feed_forward_sublayer = self.sublayers
        attended = attention_sublayer(src_representations_batch, self_attend)
        return feed_forward_sublayer(attended, self.pointwise_net)

In [13]:
class Decoder(nn.Module):
    """Stack of independent copies of one decoder layer, followed by a final LayerNorm."""

    def __init__(self, decoder_layer, number_of_layers):
        super().__init__()
        assert isinstance(decoder_layer, DecoderLayer), f'Expected DecoderLayer got {type(decoder_layer)}.'

        # number_of_layers deep copies of the prototype, so every layer owns its weights
        self.decoder_layers = get_clones(decoder_layer, number_of_layers)
        self.norm = nn.LayerNorm(decoder_layer.model_dimension)

    def forward(self, trg_embeddings_batch, src_representations_batch, trg_mask, src_mask):
        """Feed the target embeddings through every layer in order and normalise the result.

        Every layer receives the same encoder output and the same two masks (per the original
        notes the target mask hides padding as well as future positions).
        """
        # Starts out as plain embeddings and becomes a richer representation after each layer
        hidden = trg_embeddings_batch
        for layer in self.decoder_layers:
            hidden = layer(hidden, src_representations_batch, trg_mask, src_mask)

        # Final normalisation: the sublayers normalise their input rather than their output,
        # so the output of the last layer would otherwise stay un-normalised
        return self.norm(hidden)


class DecoderLayer(nn.Module):
    """One decoder layer: target self-attention, attention over the encoder output, feed-forward net."""

    def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net):
        super().__init__()
        # One residual/normalisation wrapper per sublayer:
        # [0] target self-attention, [1] source attention, [2] feed-forward
        self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability), 3)

        # Two separate copies of the attention prototype, one for each attention role
        self.trg_multi_headed_attention = copy.deepcopy(multi_headed_attention)
        self.src_multi_headed_attention = copy.deepcopy(multi_headed_attention)
        self.pointwise_net = pointwise_net

        self.model_dimension = model_dimension

    def forward(self, trg_representations_batch, src_representations_batch, trg_mask, src_mask):
        """Apply the three sublayers in order and return the updated target representations."""
        # The wrappers expect one-argument callables; the masks and the encoder output are
        # captured by the closures below.

        def attend_to_target(tokens):
            return self.trg_multi_headed_attention(query=tokens, key=tokens, value=tokens, mask=trg_mask)

        def attend_to_source(tokens):
            return self.src_multi_headed_attention(
                query=tokens, key=src_representations_batch, value=src_representations_batch, mask=src_mask)

        self_attention_sublayer, source_attention_sublayer, feed_forward_sublayer = self.sublayers
        hidden = self_attention_sublayer(trg_representations_batch, attend_to_target)
        hidden = source_attention_sublayer(hidden, attend_to_source)
        return feed_forward_sublayer(hidden, self.pointwise_net)

In [14]:
class DecoderGenerator(nn.Module):
    """Output head: linear projection onto the vocabulary followed by log-softmax."""

    def __init__(self, model_dimension, vocab_size):
        super().__init__()

        self.linear = nn.Linear(model_dimension, vocab_size)

        # dim=-1: normalise over the last (vocabulary) dimension of the projected tensor.
        # Log-probabilities rather than probabilities because, per the original notes, they are
        # consumed by nn.KLDivLoss.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, trg_representations_batch):
        """Map (..., D) representations to (..., V) log-probabilities over the vocabulary."""
        vocab_scores = self.linear(trg_representations_batch)
        return self.log_softmax(vocab_scores)

In [15]:
class SublayerLogic(nn.Module):
    """Residual wrapper around an arbitrary sublayer: x + dropout(sublayer(layer_norm(x))).

    The paper normalises AFTER the residual addition; this implementation normalises the sublayer
    input instead, which the original author found to work better in experiments.
    """

    def __init__(self, model_dimension, dropout_probability):
        super().__init__()
        self.norm = nn.LayerNorm(model_dimension)
        self.dropout = nn.Dropout(p=dropout_probability)

    def forward(self, representations_batch, sublayer_module):
        """Run ``sublayer_module`` (a one-argument callable) inside the residual connection.

        Dropout on the sublayer output follows page 7, chapter 5.4 "Regularization" of the paper.
        """
        normalized = self.norm(representations_batch)
        sublayer_output = self.dropout(sublayer_module(normalized))
        return representations_batch + sublayer_output


class PositionwiseFeedForwardNet(nn.Module):
    """
        Two-layer feed-forward net applied to every token representation independently.

        "Position-wise" means the same weights act on each position on its own: conceptually a nested
        loop over the batch and sequence dimensions, which PyTorch performs implicitly because
        nn.Linear only transforms the last dimension.

    """
    def __init__(self, model_dimension, dropout_probability, width_mult=4):
        super().__init__()

        # Expand to width_mult * D, then project back to D
        hidden_dimension = width_mult * model_dimension
        self.linear1 = nn.Linear(model_dimension, hidden_dimension)
        self.linear2 = nn.Linear(hidden_dimension, model_dimension)

        # Not explicitly mentioned in the paper, but commonly used against over-fitting
        self.dropout = nn.Dropout(p=dropout_probability)
        self.relu = nn.ReLU()

    # representations_batch's shape = (B - batch size, S/T - max token sequence length, D - model dimension).
    def forward(self, representations_batch):
        """Compute linear2(dropout(relu(linear1(x)))); the output has the same shape as the input."""
        hidden = self.relu(self.linear1(representations_batch))
        return self.linear2(self.dropout(hidden))


#
# Input modules
#


class Embedding(nn.Module):
    """Token-id lookup table whose vectors are scaled by sqrt(model_dimension)."""

    def __init__(self, vocab_size, model_dimension):
        super().__init__()
        self.embeddings_table = nn.Embedding(vocab_size, model_dimension)
        self.model_dimension = model_dimension

    def forward(self, token_ids_batch):
        """Turn a (B, S/T) batch of token ids into (B, S/T, D) scaled embedding vectors.

        Raises:
            AssertionError: when the input is not two-dimensional.
        """
        assert token_ids_batch.ndim == 2, f'Expected: (batch size, max token sequence length), got {token_ids_batch.shape}'

        # Scaling stated in the paper: page 5, chapter 3.4 "Embeddings and Softmax"
        scale = math.sqrt(self.model_dimension)

        # One D-dimensional vector per token id
        return self.embeddings_table(token_ids_batch) * scale


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to token embeddings, then applies dropout.

    Args:
        model_dimension: size D of the last dimension of the embeddings (odd values are supported).
        dropout_probability: dropout rate applied to the sum of embeddings and encodings.
        expected_max_sequence_length: number of positions pre-computed in the table; longer
            sequences are rejected in ``forward``.
    """

    def __init__(self, model_dimension, dropout_probability, expected_max_sequence_length=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_probability)

        # (stated in the paper) Use sine functions whose frequencies form a geometric progression as position encodings,
        # (learning encodings will also work so feel free to change it!). Page 6, Chapter 3.5 "Positional Encoding"
        position_id = torch.arange(0, expected_max_sequence_length).unsqueeze(1)
        frequencies = torch.pow(10000., -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension)
        angles = position_id * frequencies  # one column per frequency, i.e. ceil(D / 2) columns

        positional_encodings_table = torch.zeros(expected_max_sequence_length, model_dimension)
        positional_encodings_table[:, 0::2] = torch.sin(angles)  # sine on even positions
        # cosine on odd positions; when D is odd there is one odd column fewer than there are
        # frequencies, so the last frequency is dropped (for even D the slice keeps every column)
        positional_encodings_table[:, 1::2] = torch.cos(angles[:, :model_dimension // 2])

        # Register buffer because we want to save the positional encodings table inside state_dict even though
        # these are not trainable (not model's parameters) so they otherwise would be excluded from the state_dict
        self.register_buffer('positional_encodings_table', positional_encodings_table)

    def forward(self, embeddings_batch):
        """Return dropout(embeddings + positional encodings) for a (B, S/T, D) batch.

        Raises:
            AssertionError: when the input is not 3-dimensional, its last dimension differs from
                the table width, or the sequence is longer than the pre-computed table.
        """
        assert embeddings_batch.ndim == 3 and embeddings_batch.shape[-1] == self.positional_encodings_table.shape[1], \
            f'Expected (batch size, max token sequence length, model dimension) got {embeddings_batch.shape}'
        # Fail with a clear message instead of an opaque broadcasting error further down
        assert embeddings_batch.shape[1] <= self.positional_encodings_table.shape[0], \
            f'Sequence length {embeddings_batch.shape[1]} exceeds the positional encoding table ' \
            f'({self.positional_encodings_table.shape[0]} positions).'

        # embedding_batch's shape = (B, S/T, D), where S/T max src/trg token-sequence length, D - model dimension
        # So here we get (S/T, D) shape which will get broad-casted to (B, S/T, D) when we try and add it to embeddings
        positional_encodings = self.positional_encodings_table[:embeddings_batch.shape[1]]

        # (stated in the paper) Applying dropout to the sum of positional encodings and token embeddings
        # Page 7, Chapter 5.4 "Regularization"
        return self.dropout(embeddings_batch + positional_encodings)


#
# Helper model functions
#


def get_clones(module, num_of_deep_copies):
    """Return an nn.ModuleList holding ``num_of_deep_copies`` deep copies of ``module``.

    Each copy owns its own parameters, so the copies can be trained independently; the module
    passed in is not part of the returned list.
    """
    clones = nn.ModuleList()
    for _ in range(num_of_deep_copies):
        clones.append(copy.deepcopy(module))
    return clones

In [16]:
class MultiHeadedAttention(nn.Module):
    """
        Multi-head scaled dot-product attention.

        This module already exists in PyTorch. It is implemented here from scratch because the PyTorch
        implementation is very generic/robust and therefore complicated, whereas only a limited use-case
        needs to be supported here.

        This is arguably the most important architectural component of the Transformer model.

        Additional note:
        Conceptually this is simple; the matrix implementation just makes it a bit less intuitive.
        Going through the code and writing the dimensions down on paper helps, as does this blog post:

        https://jalammar.github.io/illustrated-transformer/

        Optimization notes:

        qkv_nets could be replaced by Parameter(torch.empty(3 * model_dimension, model_dimension)) and one more
        matrix for the bias, which would be a bit more optimized. For the sake of easier understanding three
        "feed forward nets" (without activation, hence the quotation marks) are used instead.
        Conceptually both implementations are the same.

        PyTorch's query/key/value use the layout (max token sequence length, batch size, model dimension)
        whereas this module uses (batch size, max token sequence length, model dimension) because it is easier
        to understand and consistent with computer vision apps (batch dimension first, then channels (C) and
        the spatial dimensions height (H) and width (W) -> (B, C, H, W)).

        This has an optimization implication: PyTorch can reshape into (B*NH, S/T, HD)
        (where B - batch size, S/T - max src/trg sequence length, NH - number of heads, HD - head dimension)
        in a single step while this module only gets to (B, NH, S/T, HD) in a single step
        (calling contiguous() followed by view would work but incurs an additional matrix copy).

        Args:
            model_dimension: width D of the token representations; must be divisible by number_of_heads.
            number_of_heads: number of attention heads NH.
            dropout_probability: dropout rate applied to the attention weights.
            log_attention_weights: when truthy, the weights of the latest forward pass are kept in
                ``self.attention_weights``.

    """

    def __init__(self, model_dimension, number_of_heads, dropout_probability, log_attention_weights):
        super().__init__()
        assert model_dimension % number_of_heads == 0, 'Model dimension must be divisible by the number of heads.'

        self.head_dimension = model_dimension // number_of_heads
        self.number_of_heads = number_of_heads

        self.qkv_nets = get_clones(nn.Linear(model_dimension, model_dimension), 3)  # identity activation hence "nets"
        self.out_projection_net = nn.Linear(model_dimension, model_dimension)

        self.attention_dropout = nn.Dropout(p=dropout_probability)  # no pun intended, not explicitly mentioned in paper
        self.softmax = nn.Softmax(dim=-1)  # -1 stands for apply the softmax along the last dimension

        self.log_attention_weights = log_attention_weights  # should we log attention weights
        self.attention_weights = None  # cached for visualization purposes when logging is enabled

    def attention(self, query, key, value, mask):
        """Scaled dot-product attention on already projected, per-head tensors.

        Args:
            query, key, value: tensors of shape (B, NH, S/T, HD).
            mask: ``None`` to attend everywhere, or a tensor broadcastable to the scores' shape whose
                zero/False entries mark positions that must receive zero attention weight.

        Returns:
            (new token representations of shape (B, NH, S/T, HD), attention weights).
        """
        # Step 1: Scaled dot-product attention, Page 4, Chapter 3.2.1 "Scaled Dot-Product Attention"
        # Notation: B - batch size, S/T max src/trg token-sequence length, NH - number of heads, HD - head dimension
        # scores shape = (B, NH, S, S), (B, NH, T, T) or (B, NH, T, S) depending on which of the 3 contexts
        # (src self-attention, trg self-attention, source-attending attention) the module is used in
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dimension)

        # Step 2: Optionally mask tokens whose representations we want to ignore by writing -inf to the
        # corresponding locations (forces softmax to output 0 probability there).
        # Comparing against the scalar 0 works for boolean as well as numeric masks and needs no
        # helper tensor; a missing mask simply skips this step.
        if mask is not None:
            scores.masked_fill_(mask == 0, float("-inf"))

        # Step 3: Calculate the attention weights - how much should we attend to surrounding token representations
        attention_weights = self.softmax(scores)

        # Step 4: Not defined in the original paper apply dropout to attention weights as well
        attention_weights = self.attention_dropout(attention_weights)

        # Step 5: based on attention weights calculate new token representations
        # attention_weights shape = (B, NH, S, S)/(B, NH, T, T) or (B, NH, T, S), value shape = (B, NH, S/T, HD)
        # Final shape (B, NH, S, HD) for source MHAs or (B, NH, T, HD) target MHAs
        intermediate_token_representations = torch.matmul(attention_weights, value)

        return intermediate_token_representations, attention_weights  # attention weights for visualization purposes

    def forward(self, query, key, value, mask):
        """Project the inputs, attend per head, merge the heads and apply the output projection.

        Args:
            query, key, value: tensors of shape (B, S/T, D).
            mask: see ``attention``; may be ``None``.

        Returns:
            Tensor of shape (B, S/T, D) where S/T is the query's sequence length.
        """
        batch_size = query.shape[0]

        # Step 1: Input linear projection
        # Notation: B - batch size, NH - number of heads, S/T - max src/trg token-sequence length, HD - head dimension
        # Shape goes from (B, S/T, NH*HD) over (B, S/T, NH, HD) to (B, NH, S/T, HD) (NH*HD=D where D is model dimension)
        query, key, value = [net(x).view(batch_size, -1, self.number_of_heads, self.head_dimension).transpose(1, 2)
                             for net, x in zip(self.qkv_nets, (query, key, value))]

        # Step 2: Apply attention - compare query with key and use that to combine values (see the function for details)
        intermediate_token_representations, attention_weights = self.attention(query, key, value, mask)

        # Potentially, for visualization purposes, log the attention weights, turn off during training though!
        # (the original author ran into memory problems with this enabled by default)
        if self.log_attention_weights:
            self.attention_weights = attention_weights

        # Step 3: Reshape from (B, NH, S/T, HD) over (B, S/T, NH, HD) (via transpose) into (B, S/T, NHxHD) which is
        # the same shape as in the beginning of this forward function i.e. input to MHA (multi-head attention) module
        reshaped = intermediate_token_representations.transpose(1, 2).reshape(batch_size, -1, self.number_of_heads * self.head_dimension)

        # Step 4: Output linear projection
        token_representations = self.out_projection_net(reshaped)

        return token_representations

In [17]:
# !pip install positional-encodings
from positional_encodings.torch_encodings import PositionalEncoding2D

# Shape sanity check of the third-party 2-D positional encoding on an all-zero dummy tensor
p_enc_2d = PositionalEncoding2D(3)
y = torch.zeros((1, 200, 200, 3))
yy = p_enc_2d(y)  # encode once and reuse the result
print(yy.shape)  # same shape as the input: torch.Size([1, 200, 200, 3])


torch.Size([1, 200, 200, 3])


In [18]:
# yy

In [19]:
# plt.imshow(yy[0,:5,:5]); plt.show

In [20]:
# yy[0,:5,:5]
# generate pos embeddings offline

# Number of positional channels; also the input width of MultiHeadedAttention_KV.map_pos
pos_dim = 10

p_enc_2d = PositionalEncoding2D(pos_dim)

# Pre-compute the full 300 x 300 grid once, on DEVICE. MultiHeadedAttention_KV slices
# pos_embeddings[:, :n, :m, :] from it on every call, so neither the query nor the key
# sequence may be longer than 300 tokens.
z = torch.zeros(1, 300, 300, pos_dim)
pos_embeddings = p_enc_2d(z.to(DEVICE))
    

In [21]:
pos_embeddings.shape

torch.Size([1, 300, 300, 10])

In [22]:
class MultiHeadedAttention_KV(nn.Module):
    """Experimental multi-head attention variant with key-key scoring and a learned positional mix.

    Differences from ``MultiHeadedAttention``, as implemented below:

    * When the query and key sequences have the same length, the scores are computed from the
      projected keys alone (K @ K^T); otherwise the usual Q @ K^T is used.
    * When the notebook-level global ``attn_type`` equals ``'kv'``, every score is broadcast-added to
      the matching slice of the global ``pos_embeddings`` grid and the resulting ``pos_dim`` values
      are reduced back to a single score by the learned linear layer ``map_pos``.

    The globals ``pos_dim`` (read in ``__init__``), ``attn_type`` and ``pos_embeddings`` (read on every
    call) must be defined before the module is built/used.

    Args:
        model_dimension: width D of the token representations; must be divisible by number_of_heads.
        number_of_heads: number of attention heads NH.
        dropout_probability: dropout rate applied to the attention weights.
        log_attention_weights: when truthy, the weights of the latest forward pass are kept in
            ``self.attention_weights``.
    """

    def __init__(self, model_dimension, number_of_heads, dropout_probability, log_attention_weights):
        super().__init__()
        assert model_dimension % number_of_heads == 0, 'Model dimension must be divisible by the number of heads.'

        self.head_dimension = model_dimension // number_of_heads
        self.number_of_heads = number_of_heads

        # Three input projections, applied to query, key and value respectively (no activation)
        self.kv_nets = get_clones(nn.Linear(model_dimension, model_dimension), 3)

        # Collapses the pos_dim position-augmented copies of a score into one scalar score
        self.map_pos = nn.Linear(pos_dim, 1, bias=True)

        self.out_projection_net = nn.Linear(model_dimension, model_dimension)

        self.attention_dropout = nn.Dropout(p=dropout_probability)  # not explicitly mentioned in paper
        self.softmax = nn.Softmax(dim=-1)  # -1 stands for apply the softmax along the last dimension

        self.log_attention_weights = log_attention_weights  # should we log attention weights
        self.attention_weights = None  # cached for visualization purposes when logging is enabled

    def attention(self, query, key, value, mask):
        """Attention on already projected, per-head tensors of shape (B, NH, S/T, HD).

        Args:
            query, key, value: per-head tensors; query has n positions, key has m positions.
            mask: ``None`` to attend everywhere, or a tensor broadcastable to the (B, NH, n, m) scores
                whose zero/False entries mark positions that must receive zero attention weight.

        Returns:
            (new token representations, attention weights).
        """
        n = query.shape[2]
        m = key.shape[2]

        # Step 1: scaled dot-product scores, shape (B, NH, n, m).
        # NOTE(review): equal lengths are used as a proxy for self-attention; a cross-attention call
        # whose source and target happen to have the same length also takes the key-key path -
        # confirm this is intended.
        left_operand = key if n == m else query
        scores = torch.matmul(left_operand, key.transpose(-2, -1)) / math.sqrt(self.head_dimension)

        if attn_type == 'kv':
            pos = pos_embeddings[:, :n, :m, :]  # (1, n, m, pos_dim)
            # (B, NH, n, m, 1) + (1, 1, n, m, pos_dim) -> (B, NH, n, m, pos_dim)
            scores = scores.unsqueeze(-1) + pos.unsqueeze(0)
            # learned reduction over the positional channels -> back to (B, NH, n, m)
            scores = self.map_pos(scores).squeeze(-1)

        # Step 2: Optionally mask tokens whose representations we want to ignore by writing -inf to the
        # corresponding locations (forces softmax to output 0 probability there).
        # Comparing against the scalar 0 works for boolean as well as numeric masks and needs no
        # helper tensor; a missing mask simply skips this step.
        if mask is not None:
            scores.masked_fill_(mask == 0, float("-inf"))

        # Step 3: Calculate the attention weights - how much should we attend to surrounding token representations
        attention_weights = self.softmax(scores)

        # Step 4: Not defined in the original paper apply dropout to attention weights as well
        attention_weights = self.attention_dropout(attention_weights)

        # Step 5: based on attention weights calculate new token representations
        # attention_weights shape = (B, NH, n, m), value shape = (B, NH, m, HD) -> result (B, NH, n, HD)
        intermediate_token_representations = torch.matmul(attention_weights, value)

        return intermediate_token_representations, attention_weights  # attention weights for visualization purposes

    def forward(self, query, key, value, mask):
        """Project the inputs, attend per head, merge the heads and apply the output projection.

        Args:
            query, key, value: tensors of shape (B, S/T, D).
            mask: see ``attention``; may be ``None``.

        Returns:
            Tensor of shape (B, S/T, D) where S/T is the query's sequence length.
        """
        batch_size = query.shape[0]

        # Step 1: Input linear projection
        # Notation: B - batch size, NH - number of heads, S/T - max src/trg token-sequence length, HD - head dimension
        # Shape goes from (B, S/T, NH*HD) over (B, S/T, NH, HD) to (B, NH, S/T, HD) (NH*HD=D where D is model dimension)
        query, key, value = [net(x).view(batch_size, -1, self.number_of_heads, self.head_dimension).transpose(1, 2)
                             for net, x in zip(self.kv_nets, (query, key, value))]

        # Step 2: Apply attention (see the function for details)
        intermediate_token_representations, attention_weights = self.attention(query, key, value, mask)

        # Potentially, for visualization purposes, log the attention weights, turn off during training though!
        # (the original author ran into memory problems with this enabled by default)
        if self.log_attention_weights:
            self.attention_weights = attention_weights

        # Step 3: Reshape from (B, NH, S/T, HD) over (B, S/T, NH, HD) (via transpose) into (B, S/T, NHxHD) which is
        # the same shape as in the beginning of this forward function i.e. input to MHA (multi-head attention) module
        reshaped = intermediate_token_representations.transpose(1, 2).reshape(batch_size, -1, self.number_of_heads * self.head_dimension)

        # Step 4: Output linear projection
        token_representations = self.out_projection_net(reshaped)

        return token_representations

In [23]:
from nltk.translate.bleu_score import corpus_bleu

In [24]:
# Calculate the BLEU-4 score
def calculate_bleu_score(transformer, token_ids_loader, trg_field_processor):
    """Greedy-decode every batch of ``token_ids_loader`` and score the result with corpus BLEU-4.

    Args:
        transformer: the model; it must expose ``encode(src_ids, src_mask)`` plus whatever
            ``greedy_decoding`` calls on it.
        token_ids_loader: iterable of ``(src_ids, trg_ids)`` tensor pairs. Dims 0 and 1 of both
            tensors are swapped before use.
        trg_field_processor: target vocabulary offering ``lookup_tokens``.

    Returns:
        The score returned by ``corpus_bleu``, or ``0`` when scoring raised an exception
        (the exception is printed rather than silently discarded).
    """
    pad_token_id = PAD_IDX

    gt_sentences_corpus = []
    predicted_sentences_corpus = []

    # Decode in eval mode and hand the model back in whatever mode the caller had it in.
    was_training = transformer.training
    transformer.eval()

    ts = time.time()
    try:
        with torch.no_grad():
            for src_token_ids_batch, trg_token_ids_batch in token_ids_loader:
                src_token_ids_batch = src_token_ids_batch.to(DEVICE).transpose(0, 1)
                trg_token_ids_batch = trg_token_ids_batch.to(DEVICE).transpose(0, 1)

                # Optimization - compute the source token representations only once
                src_token_ids_batch = src_token_ids_batch.to(dtype=torch.int64)
                src_mask, _ = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id)
                src_representations_batch = transformer.encode(src_token_ids_batch, src_mask)

                predicted_sentences = greedy_decoding(transformer, src_representations_batch, src_mask, trg_field_processor)
                predicted_sentences_corpus.extend(predicted_sentences)  # add them to the corpus of translations

                # Get the token (not id) version of the GT (ground-truth) sentences, dropping padding.
                # Cast first so the vocabulary always receives plain Python ints.
                for target_sentence_ids in trg_token_ids_batch.to(dtype=torch.int64).tolist():
                    kept_ids = [token_id for token_id in target_sentence_ids if token_id != pad_token_id]
                    gt_sentences_corpus.append(trg_field_processor.lookup_tokens(kept_ids))
    finally:
        transformer.train(was_training)

    try:
        # corpus_bleu wants, per hypothesis, a *list* of references; we only have one each.
        references = [[sentence] for sentence in gt_sentences_corpus]
        bleu_score = corpus_bleu(references, predicted_sentences_corpus)
        print(f'BLEU-4 corpus score = {bleu_score}, corpus length = {len(gt_sentences_corpus)}, time elapsed = {time.time()-ts} seconds.')
    except Exception as error:
        # Best effort: keep the training loop alive, but say what actually went wrong.
        print(f'ERROR in BLEU!! ({error!r})')
        bleu_score = 0
    return bleu_score

In [25]:
def greedy_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=100):
    """
    Supports batch (decode multiple source sentences) greedy decoding.

    Decoding could be further optimized to cache old token activations because they can't look ahead and so
    adding a newly predicted token won't change old token's activations.

    Example: we input <s> and do a forward pass. We get intermediate activations for <s> and at the output at position
    0, after the doing linear layer we get e.g. token <I>. Now we input <s>,<I> but <s>'s activations will remain
    the same. Similarly say we now got <am> at output position 1, in the next step we input <s>,<I>,<am> and so <I>'s
    activations will remain the same as it only looks at/attends to itself and to <s> and so forth.

    Args:
        baseline_transformer: model exposing ``parameters()`` and
            ``decode(trg_ids, src_representations, trg_mask, src_mask)``.
        src_representations_batch: encoder output; its first dimension is the batch size.
        src_mask: source mask forwarded unchanged to ``decode``.
        trg_field_processor: target vocabulary supporting ``vocab[token]`` and ``lookup_tokens``.
        max_target_tokens: decoding stops once the fed target prefix reaches this length.

    Returns:
        One token list per source sentence, starting with ``BOS_TOKEN`` and cut right after the
        first ``EOS_TOKEN`` (kept whole when no EOS was produced).
    """
    device = next(baseline_transformer.parameters()).device
    pad_token_id = PAD_IDX
    batch_size = src_representations_batch.shape[0]

    # Every target sentence starts out as just the BOS token
    target_sentences_tokens = [[BOS_TOKEN] for _ in range(batch_size)]
    bos_token_id = trg_field_processor[BOS_TOKEN]
    trg_token_ids_batch = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)

    # Set to true for a particular target sentence once it reaches the EOS (end-of-sentence) token
    is_decoded = [False] * batch_size

    while True:
        trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id)
        # Shape = (B*T, V) where T is the current token-sequence length and V target vocab size
        predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask)

        # Extract only the indices of last token for every target sentence (we take every T-th token)
        num_of_trg_tokens = trg_token_ids_batch.shape[1]
        predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens-1::num_of_trg_tokens]

        # This is the "greedy" part of the greedy decoding:
        # We find indices of the highest probability target tokens and discard every other possibility
        most_probable_last_token_indices = torch.argmax(predicted_log_distributions, dim=-1)

        # Find target tokens associated with these indices (one vocabulary call for the whole batch)
        predicted_words = trg_field_processor.lookup_tokens(most_probable_last_token_indices.tolist())

        for idx, predicted_word in enumerate(predicted_words):
            target_sentences_tokens[idx].append(predicted_word)

            if predicted_word == EOS_TOKEN:  # once we find EOS token for a particular sentence we flag it
                is_decoded[idx] = True

        # ">=" (not "==") so that a limit below 1 still terminates the loop
        if all(is_decoded) or num_of_trg_tokens >= max_target_tokens:
            break

        # Prepare the input for the next iteration (merge old token ids with the new column of most probable token ids)
        new_column = most_probable_last_token_indices.to(device).unsqueeze(1)
        trg_token_ids_batch = torch.cat((trg_token_ids_batch, new_column), 1)

    # Post process the sentences - remove everything after the EOS token
    target_sentences_tokens_post = []
    for target_sentence_tokens in target_sentences_tokens:
        try:
            target_index = target_sentence_tokens.index(EOS_TOKEN) + 1
        except ValueError:
            # No EOS was generated before hitting the length limit: keep the whole sentence
            target_index = None

        target_sentences_tokens_post.append(target_sentence_tokens[:target_index])

    return target_sentences_tokens_post

In [26]:
# bleu_score = calculate_bleu_score(transformer, test_dataloader, vocab_transform[TGT_LANGUAGE])
# print(bleu_score)

During training, we need a subsequent-word (causal) mask that prevents the model from attending to
future words when making predictions. We also need masks to hide the
source and target padding tokens. Below, we define a function that takes care of both.




In [27]:
def generate_square_subsequent_mask(sz):
    """Return a (sz, sz) float mask on ``DEVICE`` that blocks attention to later positions.

    Entry (i, j) is 0.0 when j <= i and -inf when j > i.
    """
    blocked_everywhere = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=DEVICE)
    # triu with diagonal=1 keeps -inf strictly above the main diagonal and zeroes the rest
    return torch.triu(blocked_everywhere, diagonal=1)


def create_mask(src, tgt):
    """Build attention and padding masks for a source/target batch.

    Sequence lengths are read from dimension 1 of ``src`` and ``tgt``.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where ``src_mask`` is an
        all-False square boolean mask, ``tgt_mask`` is the subsequent-position mask and the two
        padding masks are True wherever the input equals ``PAD_IDX``.
    """
    src_seq_len, tgt_seq_len = src.shape[1], tgt.shape[1]

    # Nothing is hidden on the source side: every entry is False
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=DEVICE)
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)

    return src_mask, tgt_mask, src == PAD_IDX, tgt == PAD_IDX



In [28]:
class LabelSmoothingDistribution(nn.Module):
    """
        Builds smoothed target distributions: the ground-truth token receives "confidence_value"
        (1 - smoothing_value, usually 0.9) and the "smoothing_value" mass (usually 0.1) is spread over
        the remaining vocabulary entries, excluding the pad token.

        Check out playground.py for visualization of how the smooth target distribution looks like compared to one-hot.
    """

    def __init__(self, smoothing_value, pad_token_id, trg_vocab_size, device):
        assert 0.0 <= smoothing_value <= 1.0

        super().__init__()

        self.smoothing_value = smoothing_value
        self.confidence_value = 1.0 - smoothing_value

        self.device = device
        self.trg_vocab_size = trg_vocab_size
        self.pad_token_id = pad_token_id

    def forward(self, trg_token_ids_batch):
        """Map a column of target ids (one id per row) to a (rows, trg_vocab_size) distribution tensor."""
        num_rows = trg_token_ids_batch.shape[0]

        # -2: neither the pad token nor the ground-truth token takes part in sharing the smoothing mass,
        # both entries are overwritten right below
        mass_per_other_token = self.smoothing_value / (self.trg_vocab_size - 2)
        distributions = torch.full((num_rows, self.trg_vocab_size), mass_per_other_token, device=self.device)

        # Ground-truth entry gets the confidence value, the pad column gets nothing
        distributions.scatter_(1, trg_token_ids_batch, self.confidence_value)
        distributions[:, self.pad_token_id] = 0.

        # Rows whose target *is* the pad token carry no signal at all: zero out the whole row
        is_pad_target = trg_token_ids_batch == self.pad_token_id
        distributions.masked_fill_(is_pad_target, 0.)

        return distributions

Let's now define the parameters of our model and instantiate it. Below, we also
define our loss function (a KL-divergence loss computed against label-smoothed target distributions) and the optimizer used for training.




In [29]:
# SRC_VOCAB_SIZE
# vocab_transform[SRC_LANGUAGE]

In [30]:
# Display the torch device the notebook is currently configured to use.
DEVICE

device(type='cuda', index=1)

## Collation

As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
defined previously. Below we define our collate function that converts a batch of raw strings into batch tensors that
can be fed directly into our model.




In [31]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Chain ``transforms`` into one callable that applies them left to right.

    With no transforms the returned callable hands its input back unchanged.
    """
    def apply_in_order(txt_input):
        current = txt_input
        for step in transforms:
            current = step(current)
        return current

    return apply_in_order

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap ``token_ids`` in ``BOS_IDX`` / ``EOS_IDX`` and return them as a 1-D int64 tensor.

    The dtype is set explicitly: ``torch.tensor([])`` defaults to a floating-point dtype, so an
    empty token list (e.g. a blank sentence) would otherwise not produce integer ids.
    """
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

# Per-language text pipelines mapping a raw string to a 1-D tensor of vocabulary indices.
# ``token_transform`` and ``vocab_transform`` are defined outside this cell; the loop variable
# ``ln`` stays defined at module level after the loop finishes.
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], # Tokenization: raw string -> tokens
                                               vocab_transform[ln], # Numericalization: tokens -> indices
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of raw ``(source, target)`` string pairs into two padded index tensors.

    Trailing newlines are stripped before each sentence goes through its language's
    ``text_transform``; both sides are then padded with ``PAD_IDX`` via ``pad_sequence``.
    """
    encoded_pairs = [
        (text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")),
         text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))
        for src_sample, tgt_sample in batch
    ]
    src_tensors = [src_tensor for src_tensor, _ in encoded_pairs]
    tgt_tensors = [tgt_tensor for _, tgt_tensor in encoded_pairs]

    return (pad_sequence(src_tensors, padding_value=PAD_IDX),
            pad_sequence(tgt_tensors, padding_value=PAD_IDX))

In [32]:
def get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id):
    """Return the source padding mask, shaped (B, 1, 1, S), and the number of non-pad source tokens.

    The mask is True for real tokens and False for pad tokens, whose representations carry no
    information and should be ignored.
    """
    num_sentences = src_token_ids_batch.shape[0]

    is_real_token = src_token_ids_batch.ne(pad_token_id)
    src_mask = is_real_token.view(num_sentences, 1, 1, -1)

    return src_mask, src_mask.long().sum()


def get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id):
    """Return the target attention mask, shaped (B, 1, T, T), and the number of non-pad target tokens.

    A mask entry is True only when the attended-to token is not padding AND does not lie in the
    future of the attending position (True = attend, False = ignore).
    """
    num_sentences, sequence_length = trg_token_ids_batch.shape[0], trg_token_ids_batch.shape[1]

    # (B, 1, 1, T): True for real tokens
    trg_padding_mask = trg_token_ids_batch.ne(pad_token_id).view(num_sentences, 1, 1, -1)

    # (1, 1, T, T): lower triangle (diagonal included) is True, i.e. no looking forward
    all_true = torch.ones((1, 1, sequence_length, sequence_length), dtype=torch.bool,
                          device=trg_token_ids_batch.device)
    trg_no_look_forward_mask = torch.tril(all_true)

    # Both conditions have to hold; broadcasting gives the final (B, 1, T, T) shape
    trg_mask = trg_padding_mask & trg_no_look_forward_mask

    return trg_mask, trg_padding_mask.long().sum()


def get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch, pad_token_id, device):
    """Return ``(src_mask, trg_mask, num_src_tokens, num_trg_tokens)`` for one batch.

    Thin wrapper around the source- and target-specific helpers. ``device`` is accepted but not
    used by this function.
    """
    src_side = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id)
    trg_side = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id)

    (src_mask, num_src_tokens), (trg_mask, num_trg_tokens) = src_side, trg_side
    return src_mask, trg_mask, num_src_tokens, num_trg_tokens




Let's define the training and evaluation loops that will be called for each
epoch.




In [33]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    """Train ``model`` for one pass over the Multi30k training split.

    NOTE: besides its arguments this function reads the module-level ``label_smoothing``,
    ``kl_div_loss`` and ``custom_lr_optimizer`` objects. ``optimizer`` is only used to clear the
    gradients; the parameter update itself is applied through ``custom_lr_optimizer.step()``.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted while training, so the loader is not iterated a second time

    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

    for src, tgt in train_dataloader:
        # Swap dims 0 and 1 of the collated batches before feeding the model
        src = src.to(DEVICE).transpose(0, 1)
        tgt = tgt.to(DEVICE).transpose(0, 1)

        # Teacher forcing: the decoder input drops the last position of every target row
        tgt_input = tgt[:, :-1]

        src_mask, tgt_mask, _, _ = get_masks_and_count_tokens(src, tgt_input, PAD_IDX, DEVICE)

        src = src.to(dtype=torch.int64)
        tgt_input = tgt_input.to(dtype=torch.int64)

        logits = model(src, tgt_input, src_mask, tgt_mask)

        optimizer.zero_grad()

        # Ground truth is the target shifted by one position, flattened to one id per row
        trg_token_ids_batch_gt = tgt[:, 1:].reshape(-1, 1).to(dtype=torch.int64)
        smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt)  # these are regular probabilities
        loss = kl_div_loss(logits, smooth_target_distributions)
        loss.backward()

        custom_lr_optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


def evaluate(model):
    """Compute the mean per-batch loss of ``model`` on the Multi30k validation split.

    Runs in eval mode and without gradient tracking. Reads the module-level ``label_smoothing``
    and ``kl_div_loss`` objects. Exceptions raised by the model propagate to the caller instead of
    dropping into an interactive debugger.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # counted in the loop, so the loader is not iterated a second time

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():  # no backward pass here, so skip building the autograd graph
        for src, tgt in val_dataloader:
            # Swap dims 0 and 1 of the collated batches before feeding the model
            src = src.to(DEVICE).transpose(0, 1)
            tgt = tgt.to(DEVICE).transpose(0, 1)

            # The decoder input drops the last position of every target row
            tgt_input = tgt[:, :-1]

            src_mask, tgt_mask, _, _ = get_masks_and_count_tokens(src, tgt_input, PAD_IDX, DEVICE)

            src = src.to(dtype=torch.int64)
            tgt_input = tgt_input.to(dtype=torch.int64)

            logits = model(src, tgt_input, src_mask, tgt_mask)

            # Ground truth is the target shifted by one position, flattened to one id per row
            trg_token_ids_batch_gt = tgt[:, 1:].reshape(-1, 1).to(dtype=torch.int64)
            smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt)  # these are regular probabilities
            loss = kl_div_loss(logits, smooth_target_distributions)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches

In [34]:
class CustomLRAdamOptimizer:
    """
        Wraps an optimizer and overrides its learning rate on every step: a linear ramp during the
        warm-up steps, followed by a decay proportional to the inverse square root of the step number.

        Check out playground.py for visualization of the learning rate (visualize_custom_lr_adam).
    """

    def __init__(self, optimizer, model_dimension, num_of_warmup_steps):
        self.current_step_number = 0

        self.num_of_warmup_steps = num_of_warmup_steps
        self.model_size = model_dimension
        self.optimizer = optimizer

    def step(self):
        """Advance the step counter, write the new learning rate into every param group, then update."""
        self.current_step_number += 1
        new_learning_rate = self.get_current_learning_rate()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_learning_rate

        self.optimizer.step()  # apply gradients

    # Check out the formula at Page 7, Chapter 5.3 "Optimizer" and playground.py for visualization
    def get_current_learning_rate(self):
        """Learning rate for the current step number."""
        step = self.current_step_number

        decay_term = step ** (-0.5)
        warmup_term = step * self.num_of_warmup_steps ** (-1.5)

        return self.model_size ** (-0.5) * min(decay_term, warmup_term)

    def zero_grad(self):
        """Clear the gradients through the wrapped optimizer."""
        self.optimizer.zero_grad()

In [35]:
# torch.manual_seed(0)

# --- Experiment configuration -------------------------------------------------
NUM_EPOCHS = 20   # epochs per training run
BATCH_SIZE = 50   # batch size handed to every DataLoader
num_runs = 2      # repetitions of each configuration

# Vocabulary sizes, passed to the Transformer constructor in define_model()
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])

# Tag used in log lines and in the names of the result files
which_task = f'translation_{SRC_LANGUAGE}_{TGT_LANGUAGE}'

def define_model():
    """Create a freshly initialised model together with its loss and optimizers.

    Reads the module-level ``BASELINE_MODEL_*`` settings, ``SRC_VOCAB_SIZE``, ``TGT_VOCAB_SIZE``
    and ``DEVICE`` at call time.

    Returns:
        ``(transformer, kl_div_loss, optimizer, custom_lr_optimizer)``. Note that ``optimizer`` and
        the Adam instance wrapped by ``custom_lr_optimizer`` are two separate objects built over
        the same parameters.
    """
    model_settings = dict(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=SRC_VOCAB_SIZE,
        trg_vocab_size=TGT_VOCAB_SIZE,
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB,
    )
    transformer = Transformer(**model_settings).to(DEVICE)

    # Xavier-initialise every parameter that has more than one dimension
    for parameter in transformer.parameters():
        if parameter.dim() <= 1:
            continue
        nn.init.xavier_uniform_(parameter)

    kl_div_loss = nn.KLDivLoss(reduction='batchmean')  # gives better BLEU score than "mean"

    adam_settings = dict(betas=(0.9, 0.98), eps=1e-9)
    optimizer = torch.optim.Adam(transformer.parameters(), lr=0.001, **adam_settings)

    warmup_steps = 4000
    custom_lr_optimizer = CustomLRAdamOptimizer(
        Adam(transformer.parameters(), **adam_settings),
        BASELINE_MODEL_DIMENSION,
        warmup_steps,
    )

    return transformer, kl_div_loss, optimizer, custom_lr_optimizer


In [36]:
# file_name
# Hyper-parameters read by define_model() and by the label-smoothing module below
BASELINE_MODEL_DROPOUT_PROB = 0.1
BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1

# Test split used for BLEU scoring
test_iter = Multi30k(split='test', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Converts ground-truth token ids into smoothed target distributions for the loss
label_smoothing = LabelSmoothingDistribution(
    smoothing_value=BASELINE_MODEL_LABEL_SMOOTHING_VALUE,
    pad_token_id=PAD_IDX,
    trg_vocab_size=TGT_VOCAB_SIZE,
    device=DEVICE,
)

In [37]:
from timeit import default_timer as timer
import pickle
import copy
from torch.optim import Adam
import time

In [None]:
# Hyper-parameter sweep: for every (layers, dimension, heads) combination train `num_runs` models
# per attention label and pickle the collected histories.
results_dir = './results_new/nlp'
os.makedirs(results_dir, exist_ok=True)  # create it up front rather than failing after training

for BASELINE_MODEL_NUMBER_OF_LAYERS in [2]: #[1, 2]:
    for BASELINE_MODEL_DIMENSION in [64, 128, 256]:
        for BASELINE_MODEL_NUMBER_OF_HEADS in [1, 4]:

            results = {}

            # Built from adjacent f-strings: a backslash continuation *inside* one literal would
            # embed the next line's indentation (a run of spaces) into the file name.
            file_name = (f'{BASELINE_MODEL_NUMBER_OF_LAYERS}_{BASELINE_MODEL_DIMENSION}_'
                         f'{BASELINE_MODEL_NUMBER_OF_HEADS}_{BASELINE_MODEL_DROPOUT_PROB}_{which_task}')

            # NOTE(review): within this cell `attn_type` only labels the log line and the result
            # keys; it is not passed to define_model(). Confirm the model actually switches on it.
            for attn_type in ['kv', 'kv-nopos', 'qkv']:

                for run_no in range(num_runs):

                    print(f'{which_task}---------- norm type: {attn_type}, run number: {run_no} ------------------------------')

                    # These names are module-level on purpose: train_epoch() reads
                    # `kl_div_loss` and `custom_lr_optimizer` as globals.
                    transformer, kl_div_loss, optimizer, custom_lr_optimizer = define_model()

                    results[(attn_type, run_no)] = {
                        'time_spent': [],
                        'train_loss_history': [],
                        'val_loss_history': [],
                        'val_time_spent': [],
                        'test_acc': [],
                    }

                    start_time = time.time()

                    for epoch in range(1, NUM_EPOCHS+1):

                        train_loss = train_epoch(transformer, optimizer)
                        # Times are cumulative since the start of this run
                        results[(attn_type, run_no)]['time_spent'].append(time.time() - start_time)

                        val_loss = evaluate(transformer)
                        results[(attn_type, run_no)]['val_time_spent'].append(time.time() - start_time)

                        print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(time.time() - start_time):.3f}s"))

                        results[(attn_type, run_no)]['train_loss_history'].append(train_loss)
                        results[(attn_type, run_no)]['val_loss_history'].append(val_loss)

                        bleu_score = calculate_bleu_score(transformer, test_dataloader, vocab_transform[TGT_LANGUAGE])

                        results[(attn_type, run_no)]['test_acc'].append(bleu_score)

                    # Re-written after every run so partial results survive an interrupted sweep
                    with open(os.path.join(results_dir, f'{file_name}.pkl'), 'wb') as f:
                        pickle.dump(results, f)





# # reverse_model
# input = torch.randn(1,16,10).to(device)
# macs, params = profile(reverse_model, inputs=(input, ))
# print(macs/1e6, params/1e6)

# plt.imshow(attention[0,0].cpu().numpy()); plt.show()
# plt.imshow(attention[0,0].detach().cpu().numpy()); plt.show()

translation_en_de---------- norm type: kv, run number: 0 ------------------------------




Epoch: 1, Train loss: 3.680, Val loss: 2.772, Epoch time = 38.194s


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


BLEU-4 corpus score = 1.6010157907937408e-156, corpus length = 1000, time elapsed = 0.9775991439819336 seconds.
Epoch: 2, Train loss: 2.234, Val loss: 1.988, Epoch time = 75.060s
BLEU-4 corpus score = 0.06996482449497146, corpus length = 1000, time elapsed = 8.511854410171509 seconds.
Epoch: 3, Train loss: 1.810, Val loss: 1.705, Epoch time = 120.062s
BLEU-4 corpus score = 0.14205408466755146, corpus length = 1000, time elapsed = 6.2505576610565186 seconds.
Epoch: 4, Train loss: 1.598, Val loss: 1.537, Epoch time = 163.027s
BLEU-4 corpus score = 0.17695563116886215, corpus length = 1000, time elapsed = 8.70157766342163 seconds.
Epoch: 5, Train loss: 1.454, Val loss: 1.422, Epoch time = 208.162s
BLEU-4 corpus score = 0.23377143638444187, corpus length = 1000, time elapsed = 5.5308518409729 seconds.
Epoch: 6, Train loss: 1.345, Val loss: 1.358, Epoch time = 250.213s
BLEU-4 corpus score = 0.22666851966910048, corpus length = 1000, time elapsed = 12.768577098846436 seconds.
Epoch: 7, Train

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


BLEU-4 corpus score = 4.2849027717056804e-234, corpus length = 1000, time elapsed = 0.6220293045043945 seconds.
Epoch: 2, Train loss: 2.240, Val loss: 1.981, Epoch time = 70.162s
BLEU-4 corpus score = 0.04408060017866504, corpus length = 1000, time elapsed = 19.518733024597168 seconds.
Epoch: 3, Train loss: 1.786, Val loss: 1.676, Epoch time = 124.469s
BLEU-4 corpus score = 0.16496876520031528, corpus length = 1000, time elapsed = 3.7280468940734863 seconds.
Epoch: 4, Train loss: 1.544, Val loss: 1.474, Epoch time = 163.001s
BLEU-4 corpus score = 0.21923017838479877, corpus length = 1000, time elapsed = 10.401109218597412 seconds.
Epoch: 5, Train loss: 1.366, Val loss: 1.351, Epoch time = 208.205s
BLEU-4 corpus score = 0.28971574671657974, corpus length = 1000, time elapsed = 3.6849050521850586 seconds.
Epoch: 6, Train loss: 1.258, Val loss: 1.274, Epoch time = 246.623s
BLEU-4 corpus score = 0.2886067811366814, corpus length = 1000, time elapsed = 7.9016032218933105 seconds.
Epoch: 7, 

Now we have all the ingredients to train our model. Let's do it!




In [None]:
from timeit import default_timer as timer
NUM_EPOCHS = 20

# NOTE: this cell relies on `transformer` and `optimizer` already existing in the kernel
# (they are assigned at module level by the sweep cell earlier in the notebook).
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()

    val_loss = evaluate(transformer)

    epoch_summary = (f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
                     f"Epoch time = {(end_time - start_time):.3f}s")
    print(epoch_summary)

    # The score is printed inside calculate_bleu_score itself
    bleu_score = calculate_bleu_score(transformer, test_dataloader, vocab_transform[TGT_LANGUAGE])

In [None]:
# NOTE(review): `scores` is not defined in the visible part of this notebook; this scratch cell
# presumably plots one attention map left over from an interactive session - confirm or remove.
plt.imshow(scores[0,0].detach().cpu().numpy()); plt.show()

In [None]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily generate target ids for a single source sequence.

    Starts from ``start_symbol`` and appends the highest-scoring next id until ``EOS_IDX`` is
    produced or ``max_len - 1`` ids have been added. Returns the ids stacked along dimension 0.

    NOTE(review): this calls ``model.decode(ys, memory, tgt_mask)`` and ``model.generator``, while
    ``greedy_decoding`` in this notebook calls ``decode`` with four arguments - verify that this
    function still matches the model being used.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

    for _ in range(max_len - 1):
        memory = memory.to(DEVICE)

        causal_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
        causal_mask = causal_mask.to(DEVICE)

        out = model.decode(ys, memory, causal_mask).transpose(0, 1)
        prob = model.generator(out[:, -1])

        # index of the best-scoring entry, as a plain Python int
        next_word = torch.max(prob, dim=1)[1].item()

        new_entry = torch.ones(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, new_entry], dim=0)

        if next_word == EOS_IDX:
            break

    return ys



# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    """Translate one source-language sentence with ``greedy_decode`` and return the target text.

    The model is switched to eval mode. The sentence is converted to a column of token indices,
    decoded for at most ``num_tokens + 5`` steps, and the resulting ids are mapped back to tokens
    with the ``<bos>`` / ``<eos>`` markers removed.
    """
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]

    # All-False square mask over the source positions
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

    # .tolist() hands plain Python ints to the vocabulary lookup
    target_words = vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt_tokens.cpu().tolist())
    return " ".join(target_words).replace("<bos>", "").replace("<eos>", "")

In [None]:
# print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))

In [None]:
# print(translate(transformer, "Two men are sitting in front of a wall."))

## BLEU

In [None]:
from nltk.translate.bleu_score import corpus_bleu

In [None]:
# Translate the test split sentence by sentence and score it with word-level corpus BLEU.
transformer.eval()

# NOTE: despite the `val_` prefix these iterate over the *test* split.
val_iter = Multi30k(split='test', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# corpus_bleu expects, per sentence, a list of tokenised references and one tokenised hypothesis.
# Passing flat token lists / raw strings makes it score character n-grams instead.
all_ref = []   # all_ref[i]  == [reference_tokens]  (a single reference per sentence)
all_cand = []  # all_cand[i] == hypothesis_tokens

with torch.no_grad():
    for src, tgt in val_dataloader:
        # Collated batches hold one sentence per column
        for idx in range(src.shape[-1]):
            src_ids = src[:, idx].cpu().to(torch.int64).tolist()
            tgt_ids = tgt[:, idx].cpu().to(torch.int64).tolist()

            srs_sent = " ".join(vocab_transform[SRC_LANGUAGE].lookup_tokens(src_ids)).replace("<bos>", "").replace("<eos>", "").replace("<pad>", "")
            tgt_sent = " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt_ids)).replace("<bos>", "").replace("<eos>", "").replace("<pad>", "")

            # Nothing to compare against: skip before spending time on decoding
            if not tgt_sent.strip():
                continue

            trans = translate(transformer, srs_sent)

            all_ref.append([tgt_sent.split()])
            all_cand.append(trans.split())

bleu = corpus_bleu(all_ref, all_cand)*100
print(bleu)
    

In [None]:
# Scratch check: True when `tgt_sent` (left over from the BLEU loop) is empty or whitespace-only.
not tgt_sent.strip()

In [None]:
# Scratch cell: show the most recently collected candidate translation.
all_cand[-1]

In [None]:
# Scratch cell: BLEU (scaled by 100) over entries 1000-1013 of the collected references/candidates.
# NOTE(review): corpus_bleu expects a list of reference lists of tokens and tokenised hypotheses;
# confirm `all_ref` / `all_cand` are stored in that format before trusting this number.
corpus_bleu(all_ref[1000:1014], all_cand[1000:1014])*100

In [None]:
# Scratch cell: show the most recently collected reference entry.
all_ref[-1]

In [None]:
# Scratch cell: shape of the last source batch left over from the BLEU loop.
src.shape

In [None]:
# Scratch cell: first column of the last target batch left over from the BLEU loop.
tgt[:,0]

## References

1. Attention is all you need paper.
   https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
2. The annotated transformer. https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding

