In [2]:
import sys
from pathlib import Path

from dataclasses import dataclass
import math

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace


In [3]:
@dataclass
class ScratchConfig:

    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer_file: str = "./data/opus_books/tokenizer_{0}.json"
    src_lang: str = "fr"
    tgt_lang: str = "en"
    batch_size: int = 16
    max_seq_len: int = 350

    # Model
    d_model: int = 512
    num_layers: int = 6
    num_heads: int = 8
    d_ff: int = 2048
    dropout: float = 0.1
    
    weights_folder = "./weights"
    epochs_save = "scratch_epoch_{0:03d}.pt"
    log_dir = "./logs/scratch"

    # Training
    epochs: int = 10
    lr: float = 1e-4
    label_smoothing: float = 0.1

# Embeddings

## Input Embeddings

Input Embeddings are basically, just the vector form of a token(word), which are used as representation for that particular token. The embedding capture the semantic and syntatic properties of the words, and also have an emerging property that words with similar meaning will have similar vectors, here similar means the distance to such vectors is less compared to some random vector.

These vectors can be of any size, the larger the vector size, the better meaning it can capture, but also consumes more memory. In original transformer the embedding size used was 512, whereas in BERT-base its 768, and in latest GPT models its even larger like 2048, 4096 etc.

![Input Embeddings Example](../images/input_emb.png)

In [4]:
class InputEmbedding(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x: torch.FloatTensor | torch.cuda.FloatTensor) -> torch.FloatTensor | torch.cuda.FloatTensor:
        # x: (batch_size, seq_len) -> (batch_size, seq_len, d_model)

        # (from Paper) In the embedding layers, we multiply those weights by √d_model
        return self.embedding(x) * math.sqrt(self.d_model)

In [5]:
# test

input_embedding = InputEmbedding(d_model=2, vocab_size=100)
x = torch.randint(low=0, high=10, size=(3,5)) # (batch_size, seq_len)
print(x)
print(input_embedding(x)) # (batch_size, seq_len, d_model)
print(input_embedding(x).shape)

tensor([[1, 1, 6, 6, 9],
        [9, 3, 3, 7, 6],
        [7, 5, 4, 7, 8]])
tensor([[[-0.5992, -0.1990],
         [-0.5992, -0.1990],
         [ 0.6258, -0.0069],
         [ 0.6258, -0.0069],
         [ 0.6561, -1.5271]],

        [[ 0.6561, -1.5271],
         [-0.2786, -0.2459],
         [-0.2786, -0.2459],
         [ 1.0389,  0.5644],
         [ 0.6258, -0.0069]],

        [[ 1.0389,  0.5644],
         [-0.6462, -0.5525],
         [ 0.4011,  1.6332],
         [ 1.0389,  0.5644],
         [ 0.1460,  0.1679]]], grad_fn=<MulBackward0>)
torch.Size([3, 5, 2])


## Positional Embedding

Since transformers, uses attention, which is "permutation-equivariant", which simply means that the if you change the order of the input, the order of output will change in the same, and there won't be any difference in the values apart from order change. The self-attention does weighted sum of all elements, since sum is same irrespective of the order of the sequence, the output will be same.

But in language, the order of words does matter. The sequence "Man cooks Turkey" and "Turkey cooks Man" has totally different meaning, but for self-attention it doesn't matter the input is "1,2,3" or "3,2,1", and will give same results. But we want the embeddings to be different when the different order is used.

Hence we use positional embeddings, where we add the same dimension vector to each token. In original transformer we used fixed embedding, which can be pre-calculate for each position in the sequence and for each embedding vector value.

But there are positional embeddings which can be learnt during training, and are set as weights, and there are other positional embeddings like `rotary` embeddings which depends on the relative position of the tokens rather than absolute position.

![Positional Embeddings](../images/pos_emb.png)

![Positional Encoding](../images/pos_enc.png)

In [6]:
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model: int, seq_len:int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Weights of shape (seq_len, d_model)
        pe = torch.zeros((self.seq_len, self.d_model))

        # Position Vector
        position = torch.arange(0, seq_len, dtype=float).unsqueeze(1) # (seq_len, 1)

        # Division term using `log` for numerical stability
        # It is mathematically equal to the above formula.
        # e^(2i * ln(10_000)/d) == 10_000^(2i/d)
        # negative sign to make it denominator
        div_term = torch.exp( torch.arange(0,d_model,2).float() * (-math.log(10_000.)/self.d_model) ) # (d_model/2)

        # sin() for even position and cos() for odd position
        # for each token all even position in the position embedding
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0) # (1, seq_len, d_model)

        # register positional encoding as a persistent buffer within the module
        # automatically tracked by the module and included in the state dictionary
        # when saving or loading the model
        self.register_buffer('pe', pe)

    def forward(self, x: torch.FloatTensor | torch.cuda.FloatTensor) -> torch.FloatTensor | torch.cuda.FloatTensor:
        # x: (batch_size, seq_len, d_model)  -> (batch_size, seq_len, d_model)
        # Positional Encodings here are fixed, so setting the gradients to False inplace.
        # Using x.shape[1], to use PEs from 0 to the size of input sequence.
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False) # (batch_size, seq_len, d_model)
        return self.dropout(x)

In [7]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 2
positional_embedding = PositionalEmbedding(d_model=embedding_dim, seq_len=seq_len, dropout=0.1)
x = torch.randn((batch_size,seq_len,embedding_dim))
print(x.shape) # (batch_size, seq_len, d_model)
print(positional_embedding(x).shape) # (batch_size, seq_len, d_model)

torch.Size([3, 5, 2])
torch.Size([3, 5, 2])


# Normalization

The purpose of this normalization is to stabilize the network, speed up convergence, and reduce the sensitivity to the initialization of the model parameters. 

Normalization helps in dealing with the problem of internal covariate shift, where the distribution of each layer's inputs changes during training, as the parameters of the previous layers change. This can slow down the training process and make it harder for the network to converge. Normalization mitigates this problem by ensuring that the layer's inputs are more stable.

The transformer uses `LayerNormalization` over the `BatchNormalization` as to avoid calculating statistics over the entire batch, which can slow down training, as it has to bring all the batches on single machine to get statistics and send it back to different machines after normalization. Whereas Layer Normalization, normalizes values across each feature in a single example.

In [8]:
class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 10e-6) -> None:
        super().__init__()
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(features)) # multiplied
        self.beta = nn.Parameter(torch.zeros(features)) # added

    def forward(self, x):
        # x : (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        mean = x.mean(dim=-1, keepdim=True) # (batch_size, seq_len, 1)
        std = x.std(dim=-1, keepdim=True) # (batch_size, seq_len, 1)

        return self.gamma * ( (x - mean) / (std + self.eps) ) + self.beta

In [9]:
# test
batch_size = 3
seq_len = 5
embedding_dim = 2
layer_norm = LayerNormalization(features=embedding_dim, eps=10e-6)
x = torch.rand((batch_size,seq_len,embedding_dim)) # (batch_size, seq_len, d_model)
print(x.shape) # (batch_size, seq_len, d_model)
print(layer_norm(x).shape) # (batch_size, seq_len, d_model)

torch.Size([3, 5, 2])
torch.Size([3, 5, 2])


# FeedForward Block

Vanilla 2 layer fully connected network along with dropout, to transform the representation obtained from the self-attention mechanism, and adding depth to the model.

The self-attention mechanism helps the model understand the context of each word in the sequence, the FFN help the model to represent the words by taking account the words individual meaning and its context.

In [10]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff, bias=True) # W1, b1
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model, bias=True) # W2, b2

    def forward(self, x: torch.FloatTensor | torch.cuda.FloatTensor) -> torch.FloatTensor | torch.cuda.FloatTensor:
        # x: (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)

        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
        x = torch.relu(self.linear1(x))
        x = self.dropout(x)
        # (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        return self.linear2(x)


In [11]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 2
d_ff = 4

ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)
x = torch.rand((batch_size,seq_len,embedding_dim)) 
print(x.shape) # (batch_size, seq_len, d_model)
print(ffn_block(x).shape) # (batch_size, seq_len, d_model)

torch.Size([3, 5, 2])
torch.Size([3, 5, 2])


# Attentions

Self-attention, is a mechanism that allows each token in the input to interact with every other token. It computes a weighted sum of all input tokens' representations for each token, where the weights are determined by the input tokens themselves. In simple words, for each token, it asks all the other tokens `how much related you are to me?`, and the relationship is given by a weight. This relationship is produced by doing simple dot product between two tokens. If two tokens are having high relationship, the attention score between them will be higher and vice-versa.

The self-attention mechanism allows the model to focus on different parts of the input sequence when processing each token, which helps it understand the context and dependencies between words in a sentence, even if they are far apart.

At the end, it then multiplies this attention-score matrix, with the final `value` embeddings, to acheive each token final value. This final value of the token now contains the relationship/context of itself with all other tokens based on the attention score.

![Multi Head Attention](../images/mh_attn.png)

In [12]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float, verbose: bool =False) -> None:
        super().__init__()
        self.d_model = d_model # 512
        self.num_heads = num_heads # 8
        self.dropout = nn.Dropout(dropout)
        self.verbose = verbose

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads # 512//8 = 64
        self.w_q = nn.Linear(in_features=d_model, out_features=d_model, bias=True) # Wq
        self.w_k = nn.Linear(in_features=d_model, out_features=d_model, bias=True) # Wk
        self.w_v = nn.Linear(in_features=d_model, out_features=d_model, bias=True) # Wv

        self.w_o = nn.Linear(in_features=num_heads*self.d_k, out_features=d_model, bias=True) # Wo


    @staticmethod
    def attention(key, query, value, mask, dropout) -> torch.FloatTensor | torch.cuda.FloatTensor:
        # (batch_size, num_heads, seq_len, d_k) -> (batch_size, num_heads, seq_len, d_k)
        d_k =  query.shape[-1]

        attention_scores = query @ key.transpose(-2, -1) # (batch_size, num_heads, seq_len, seq_len)
        attention_scores /= math.sqrt(d_k) # (batch_size, num_heads, seq_len, seq_len)
        if mask is not None:
            # Replace all values in attention scores, where mask is 0, with -1e9
            attention_scores.masked_fill_(mask == 0, -1e9) # (batch_size, num_heads, seq_len, seq_len)
        
        attention_scores = attention_scores.softmax(dim=-1) # (batch_size, num_heads, seq_len, seq_len)
        
        attention_scores = dropout(attention_scores) # (batch_size, num_heads, seq_len, seq_len)

        # (batch_size, num_heads, seq_len, seq_len) @ (batch_size, num_heads, seq_len, d_k)
        # -> (batch_size, num_heads, seq_len, d_k)
        return attention_scores @ value, attention_scores

    def forward(self,
                q: torch.FloatTensor | torch.cuda.FloatTensor,
                k: torch.FloatTensor | torch.cuda.FloatTensor,
                v: torch.FloatTensor | torch.cuda.FloatTensor,
                mask: torch.FloatTensor | torch.cuda.FloatTensor,
                ) -> torch.FloatTensor | torch.cuda.FloatTensor:
        # q: (batch_size, seq_len, d_model)
        query = self.w_q(q) # (batch_size, seq_len, d_model)
        # k: (batch_size, seq_len, d_model)
        key = self.w_k(k) # (batch_size, seq_len, d_model)
        # v: (batch_size, seq_len, d_model)
        value = self.w_v(v) # (batch_size, seq_len, d_model)

        batch_size, seq_len = query.shape[:2]
        # using einops.rearrange: rearrange(query, 'b s (h d) -> b h s d', h=self.num_heads)
        
        query = query.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1,2) # (batch_size, num_heads, seq_len, d_k)
        # During Decoder cross-attn, seq_len of query and key/value can be different, hence using `-1` for key/value
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2) # (batch_size, num_heads, -1, d_k)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2) # (batch_size, num_heads, seq_len, d_k)

        if self.verbose:
            print("Query shape", query.shape)
            print("Key shape", key.shape)
            print("Value shape", value.shape)

        x, attn_scores = MultiHeadAttentionBlock.attention(key, query, value, mask, self.dropout)
        # x: (batch_size, num_heads, seq_len, d_k)
        if self.verbose:
            print("Attention Scores shape", attn_scores.shape)
            print("Attention Output shape", x.shape)

        # continguous to make the tensor in contiguous block of memory
        x = x.transpose(1,2).contiguous() # (batch_size, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads*self.d_k)

        if self.verbose:
            print("MHA Final Output shape", x.shape)

        return self.w_o(x) # attn_scores # (batch_size, seq_len, d_model), (batch_size, num_heads, seq_len, seq_len)

         

In [13]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()
x = attn_block(x,x,x, mask=None)
print()
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 8])

Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 8])


## Residual/Skip Connection

In [14]:
class ResidualConnection(nn.Module):
    def __init__(self, features:int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features=features)

    def forward(self, x, sublayer):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        return x + self.dropout(sublayer(self.norm(x))) # In recent changes/implementations, norm is performed before sublayer.


In [15]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 2
d_ff = 4

res_block = ResidualConnection(features=embedding_dim, dropout=0.1)

x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = res_block(x, lambda x: FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)(x))
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 2])

Output Shape torch.Size([3, 5, 2])


# Encoder

![Encoder Block](../images/encoder.png)

[Source](https://kikaben.com/transformers-encoder-decoder/)

## One Encoder Block

In [16]:
class EncoderBlock(nn.Module):

    def __init__(self,
                 features: int,
                 self_attn_block: MultiHeadAttentionBlock,
                 feed_foward_block: FeedForwardBlock,
                 dropout: float,
                 verbose: bool = False) -> None:
        super().__init__()
        self.self_attn_block = self_attn_block
        self.feed_foward_block = feed_foward_block
        # Two Residual Connections, one after MHA and other after FFN
        self.skip = nn.ModuleList(
            [
                ResidualConnection(features=features, dropout=dropout) for _ in range(2)
            ]
        )
        self.verbose = verbose

    def forward(self, x, src_mask):
        # x: (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        x = self.skip[0](x, lambda x: self.self_attn_block(x, x, x, src_mask)) # x + dropout( MHA( LayerNorm(x) ) )
        if self.verbose:
            print("EncoderBlock MHA Output shape", x.shape)
        x = self.skip[1](x, self.feed_foward_block) # x + dropout( FFN( LayerNorm(x) ) )
        if self.verbose:
            print("EncoderBlock FFN Output shape", x.shape)

        return x

In [17]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4
d_ff = 8

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)
res_block = ResidualConnection(features=embedding_dim, dropout=0.1)
enc_block = EncoderBlock(features=embedding_dim, self_attn_block=attn_block, 
                         feed_foward_block=ffn_block, dropout=0.1,
                         verbose=True)

x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = enc_block(x, src_mask=None)
print()
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 8])

Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
EncoderBlock MHA Output shape torch.Size([3, 5, 8])
EncoderBlock FFN Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 8])


## Encoder Stack

In [18]:
class Encoder(nn.Module):

    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features=features)

    def forward(self, x, mask):
        # x: (batch_size, seq_len, d_model)
        # mask: (batch_size, seq_len, seq_len)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [19]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4
d_ff = 8

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)
enc_block = EncoderBlock(features=embedding_dim, self_attn_block=attn_block, 
                         feed_foward_block=ffn_block, dropout=0.1,
                         verbose=True)
enc_stack = Encoder(features=embedding_dim, layers=nn.ModuleList([enc_block]*2))

x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = enc_stack(x, mask=None)
print()
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 8])

Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
EncoderBlock MHA Output shape torch.Size([3, 5, 8])
EncoderBlock FFN Output shape torch.Size([3, 5, 8])
Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
EncoderBlock MHA Output shape torch.Size([3, 5, 8])
EncoderBlock FFN Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 8])


# Decoder

![Decoder Block](../images/decoder.png)

[Source](https://kikaben.com/transformers-encoder-decoder/)

## Single Decoder Block

In [20]:
class DecoderBlock(nn.Module):

    def __init__(self, features: int,
                 self_attn_block: MultiHeadAttentionBlock,
                 cross_attn_block: MultiHeadAttentionBlock,
                 feed_forward_block: FeedForwardBlock,
                 dropout: float,
                 verbose: bool = False) -> None:
        super().__init__()
        self.self_attn_block = self_attn_block
        self.cross_attn_block = cross_attn_block
        self.feed_forward_block = feed_forward_block
        # Three Residual Connections, one after each block
        self.skip = nn.ModuleList(
            [
                ResidualConnection(features=features, dropout=dropout) for _ in range(3)
            ]
        )
        self.verbose = verbose

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # x: (batch_size, seq_len, d_model)
        # encoder_output: (batch_size, seq_len, d_model)
        # src_mask: (batch_size, seq_len, seq_len)
        # tgt_mask: (batch_size, seq_len, seq_len)

        x = self.skip[0](x, lambda x: self.self_attn_block(x, x, x, tgt_mask)) # Input + MHA -> LayerNorm -> Dropout
        if self.verbose:
            print("DecoderBlock Self-Attention Output shape", x.shape)
        # Input + MHA -> LayerNorm -> Dropout
        x = self.skip[1](x, lambda x: self.cross_attn_block(x, encoder_output, encoder_output, src_mask)) 
        if self.verbose:
            print("DecoderBlock Cross-Attention Output shape", x.shape)
        x = self.skip[2](x, self.feed_forward_block) # Input + FFN -> LayerNorm -> Dropout
        if self.verbose:
            print("DecoderBlock FFN Output shape", x.shape)

        return x
        

In [21]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4
d_ff = 8

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
cross_attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)

dec_block = DecoderBlock(features=embedding_dim, self_attn_block=attn_block,
                         cross_attn_block=cross_attn_block, feed_forward_block=ffn_block,
                         dropout=0.1, verbose=True)

x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = dec_block(x, x, src_mask=None, tgt_mask=None)
print()
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 8])

Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
DecoderBlock Self-Attention Output shape torch.Size([3, 5, 8])
Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
DecoderBlock Cross-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock FFN Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 8])


## Decoder Stack

In [22]:
class Decoder(nn.Module):

    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features=features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return self.norm(x)

In [23]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4
d_ff = 8

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
cross_attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=True)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)

dec_block = DecoderBlock(features=embedding_dim, self_attn_block=attn_block,
                         cross_attn_block=cross_attn_block, feed_forward_block=ffn_block,
                         dropout=0.1, verbose=True)

dec_stack = Decoder(features=embedding_dim, layers=nn.ModuleList([dec_block]*1))

x = torch.rand((batch_size,seq_len,embedding_dim)) 
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = dec_stack(x, x, src_mask=None, tgt_mask=None)
print()
print("Output Shape", x.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5, 8])

Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
DecoderBlock Self-Attention Output shape torch.Size([3, 5, 8])
Query shape torch.Size([3, 4, 5, 2])
Key shape torch.Size([3, 4, 5, 2])
Value shape torch.Size([3, 4, 5, 2])
Attention Scores shape torch.Size([3, 4, 5, 5])
Attention Output shape torch.Size([3, 4, 5, 2])
MHA Final Output shape torch.Size([3, 5, 8])
DecoderBlock Cross-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock FFN Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 8])


### Testing Live Inference

In [24]:
batch_size = 1
seq_len = 1
embedding_dim = 8
num_heads = 4
d_ff = 8

attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=False)
cross_attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=False)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)

dec_block = DecoderBlock(features=embedding_dim, self_attn_block=attn_block,
                         cross_attn_block=cross_attn_block, feed_forward_block=ffn_block,
                         dropout=0.1, verbose=False)

dec_stack = Decoder(features=embedding_dim, layers=nn.ModuleList([dec_block]*2))

x = torch.rand((batch_size,seq_len,embedding_dim)) # Start Token
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
while x.size(1) <= 5: # Generate next 5 tokens
    pred = dec_stack(x, x, src_mask=None, tgt_mask=None)
    print("Prediction Shape", pred.shape)
    x = torch.cat([x, pred[:,-1:,:]], dim=1) # Concat Last Prediction Token to Input    

Input Shape torch.Size([1, 1, 8])
Prediction Shape torch.Size([1, 1, 8])
Prediction Shape torch.Size([1, 2, 8])
Prediction Shape torch.Size([1, 3, 8])
Prediction Shape torch.Size([1, 4, 8])
Prediction Shape torch.Size([1, 5, 8])


# Projection

In [25]:
class ProjectLayer(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=d_model, out_features=vocab_size, bias=True)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model) -> (batch_size, seq_len, vocab_size)
        return torch.log_softmax(self.linear(x), dim=-1)

In [26]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
vocab_size = 10

proj_layer = ProjectLayer(d_model=embedding_dim, vocab_size=vocab_size)
x = torch.rand((batch_size,seq_len,embedding_dim))
print("Input Shape", x.shape) # (batch_size, seq_len, d_model)
print()

x = proj_layer(x)
print("Output Shape", x.shape) # (batch_size, seq_len, vocab_size)


Input Shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 10])


# Transformer

![Transformer](../images/transformer.png)

[Source](https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html)

In [27]:
class Transformer(nn.Module):
    def __init__(self,
                 encoder: Encoder,
                 decoder: Decoder,
                 src_embedding: InputEmbedding,
                 tgt_embedding: InputEmbedding,
                 src_pos_embedding: PositionalEmbedding,
                 tgt_pos_embedding: PositionalEmbedding,
                 project_layer: ProjectLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embedding = src_embedding
        self.tgt_embedding = tgt_embedding
        self.src_pos_embedding = src_pos_embedding
        self.tgt_pos_embedding = tgt_pos_embedding
        self.project_layer = project_layer

    def encode(self, src, src_mask):
        # src: (batch_size, seq_len) -> (batch_size, seq_len, d_model)
        x = self.src_embedding(src)
        x = self.src_pos_embedding(x)
        return self.encoder(x, src_mask)
    
    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        # encoder_output: (batch_size, seq_len, d_model)
        # src_mask: (batch_size, seq_len, seq_len)
        # tgt: (batch_size, seq_len) -> (batch_size, seq_len, d_model)
        x = self.tgt_embedding(tgt)
        x = self.tgt_pos_embedding(x)
        return self.decoder(x, encoder_output, src_mask, tgt_mask)
    
    def project(self, x):
        # x: (batch_size, seq_len, d_model) -> (batch_size, seq_len, vocab_size)
        return self.project_layer(x)

In [28]:
# test

batch_size = 3
seq_len = 5
embedding_dim = 8
num_heads = 4
d_ff = 8
vocab_size = 10

# Embeddings
src_embedding = InputEmbedding(d_model=embedding_dim, vocab_size=vocab_size)
tgt_embedding = InputEmbedding(d_model=embedding_dim, vocab_size=vocab_size)
src_pos_embedding = PositionalEmbedding(d_model=embedding_dim, seq_len=seq_len, dropout=0.1)
tgt_pos_embedding = PositionalEmbedding(d_model=embedding_dim, seq_len=seq_len, dropout=0.1)

# Encoder
attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=False)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)
enc_block = EncoderBlock(features=embedding_dim, self_attn_block=attn_block, 
                         feed_foward_block=ffn_block, dropout=0.1,
                         verbose=True)
enc_stack = Encoder(features=embedding_dim, layers=nn.ModuleList([enc_block]*2))

# Decoder
attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=False)
cross_attn_block = MultiHeadAttentionBlock(d_model=embedding_dim, num_heads=num_heads, dropout=0.1, verbose=False)
ffn_block = FeedForwardBlock(d_model=embedding_dim, d_ff=d_ff, dropout=0.1)

dec_block = DecoderBlock(features=embedding_dim, self_attn_block=attn_block,
                         cross_attn_block=cross_attn_block, feed_forward_block=ffn_block,
                         dropout=0.1, verbose=True)

dec_stack = Decoder(features=embedding_dim, layers=nn.ModuleList([dec_block]*2))

# Projection Layer
proj_layer = ProjectLayer(d_model=embedding_dim, vocab_size=vocab_size)

# Transformer
transformer = Transformer(encoder=enc_stack, decoder=dec_stack,
                          src_embedding=src_embedding, tgt_embedding=tgt_embedding,
                          src_pos_embedding=src_pos_embedding, tgt_pos_embedding=tgt_pos_embedding,
                          project_layer=proj_layer)

enc_x = torch.randint(low=0, high=vocab_size, size=(batch_size,seq_len))
dec_x = torch.randint(low=0, high=vocab_size, size=(batch_size,seq_len))
print("Input Shape", enc_x.shape) # (batch_size, seq_len)
print("Target Shape", dec_x.shape) # (batch_size, seq_len)
print()

enc_x = transformer.encode(enc_x, src_mask=None)
dec_x = transformer.decode(enc_x, src_mask=None, tgt=dec_x, tgt_mask=None)
out = transformer.project(dec_x)
print()
print("Output Shape", out.shape) # (batch_size, seq_len, d_model)

Input Shape torch.Size([3, 5])
Target Shape torch.Size([3, 5])

EncoderBlock MHA Output shape torch.Size([3, 5, 8])
EncoderBlock FFN Output shape torch.Size([3, 5, 8])
EncoderBlock MHA Output shape torch.Size([3, 5, 8])
EncoderBlock FFN Output shape torch.Size([3, 5, 8])
DecoderBlock Self-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock Cross-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock FFN Output shape torch.Size([3, 5, 8])
DecoderBlock Self-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock Cross-Attention Output shape torch.Size([3, 5, 8])
DecoderBlock FFN Output shape torch.Size([3, 5, 8])

Output Shape torch.Size([3, 5, 10])


## Building Transformer

In [29]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int,
                      src_max_len: int, tgt_max_len: int,
                      d_model: int = 512,
                      num_layers: int = 6,
                      num_heads: int = 8,
                      d_ff: int = 2048,
                      dropout: float = 0.1) -> Transformer:
    
    # Embedding Layers for Source and Target
    src_embedding = InputEmbedding(d_model=d_model, vocab_size=src_vocab_size)
    tgt_embedding = InputEmbedding(d_model=d_model, vocab_size=tgt_vocab_size)

    # Positional Embedding Layers for Source and Target
    src_pos_embedding = PositionalEmbedding(d_model=d_model, seq_len=src_max_len, dropout=dropout)
    tgt_pos_embedding = PositionalEmbedding(d_model=d_model, seq_len=tgt_max_len, dropout=dropout)

    # Encoder Blocks
    encoder_blocks = []
    for _ in range(num_layers):
        encoder_blocks.append(
            EncoderBlock(
                features = d_model,
                self_attn_block = MultiHeadAttentionBlock(d_model=d_model, num_heads=num_heads, dropout=dropout),
                feed_foward_block = FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
                dropout = dropout
            )
        )

    # Decoder Blocks
    decoder_blocks = []
    for _ in range(num_layers):
        decoder_blocks.append(
            DecoderBlock(
                features = d_model,
                self_attn_block = MultiHeadAttentionBlock(d_model=d_model, num_heads=num_heads, dropout=dropout),
                cross_attn_block = MultiHeadAttentionBlock(d_model=d_model, num_heads=num_heads, dropout=dropout),
                feed_forward_block = FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
                dropout = dropout
            )
        )

    # Encoder and Decoder
    encoder = Encoder(features=d_model, layers=nn.ModuleList(encoder_blocks))
    decoder = Decoder(features=d_model, layers=nn.ModuleList(decoder_blocks))

    # Project Layer
    projection_layer = ProjectLayer(d_model=d_model, vocab_size=tgt_vocab_size)

    transformer =  Transformer(encoder=encoder, decoder=decoder,
                       src_embedding=src_embedding, tgt_embedding=tgt_embedding,
                       src_pos_embedding=src_pos_embedding, tgt_pos_embedding=tgt_pos_embedding,
                       project_layer=projection_layer)
    
    # Initialize Weights using Xavier Initialization
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

# Tokenizers

In [30]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

In [31]:
def get_or_build_tokenizer(config, ds, lang):
    
    tokenizer_path = Path(config.tokenizer_file.format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        # WorldLevel Model Trainer, with vocab having words with frequency >= 2
        trainer = WordLevelTrainer(special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        Path(tokenizer_path).parent.mkdir(parents=True, exist_ok=True)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))

    return tokenizer        

In [None]:
# test

config = ScratchConfig()
dataset_en_fr = load_dataset('opus_books', f'en-fr', split='train')
print(dataset_en_fr[:4])
tokenizer = get_or_build_tokenizer(config, dataset_en_fr, "en")
print(tokenizer.encode("Hello World!").ids)
print(tokenizer.encode("Hello World!").tokens)
print(tokenizer.decode(tokenizer.encode("Hello World!").ids))

# Dataset

In [33]:
def casual_mask(size):
    # Create Upper Triangular Matrix of 1s
    # diagonal=1, to exclude diagonal elements
    # What is diagonal: A positive value excludes just as many diagonals above the main diagonal
    #    If diagonal=0, then the diagonal elements will be 1
    #    If diagonal=2, then the diagonal elements and one layer above also will be 0
    
    mask = torch.triu(torch.ones(size, size), diagonal=1) # or `return torch.tril(torch.ones(size,size)) == 1`
    return mask == 0

In [34]:
class BiLingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, max_seq_len) -> None:
        super().__init__()

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.max_seq_len = max_seq_len

        # self.sos_token =tokenizer_src.token_to_id("[SOS]")]
        # self.eos_token = torch.Tensor([tokenizer_src.token_to_id("[EOS]")], dtype = torch.long)
        # self.pad_token = torch.Tensor([tokenizer_src.token_to_id("[PAD]")], dtype = torch.long)

        self.sos_token = tokenizer_src.token_to_id("[SOS]")
        self.eos_token = tokenizer_src.token_to_id("[EOS]")
        self.pad_token = tokenizer_src.token_to_id("[PAD]")

    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, index: int):
        
        src_tgt_pair = self.ds[index]
        # Getting the Source and Target texts
        src_text = src_tgt_pair["translation"][self.src_lang]
        tgt_text = src_tgt_pair["translation"][self.tgt_lang]
        
        # Tokenize the Source and Target texts to integer ids.
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_tokens = self.tokenizer_src.encode(tgt_text).ids

        # If the length of the tokens is greater than max_seq_len, truncate the tokens
        enc_input_tokens = enc_input_tokens[:self.max_seq_len-2] # 2 for SOS and EOS
        dec_tokens = dec_tokens[:self.max_seq_len-1] # 1 for SOS/EOS 
            

        # Add SOS and EOS tokens to the Source Tokens
        enc_input_tokens = [self.sos_token] + enc_input_tokens + [self.eos_token]
        dec_input_tokens = [self.sos_token] + dec_tokens # No EOS token for Input
        dec_output_tokens = dec_tokens + [self.eos_token] # No SOS token for Target

        # Padding: Pad the tokens to max_seq_len

        # enc_input = torch.cat([
        #     torch.tensor(enc_input_tokens, dtype=torch.long),
        #     torch.tensor([self.pad_token] * self.max_seq_len - len(enc_input_tokens), dtype=torch.long)
        # ], dim=0)

        enc_input = F.pad(torch.tensor(enc_input_tokens, dtype=torch.long), 
                          (0, self.max_seq_len - len(enc_input_tokens)), # pad only on right side
                          value=self.pad_token
                          )
        
        dec_input = F.pad(torch.tensor(dec_input_tokens, dtype=torch.long), 
                    (0, self.max_seq_len - len(dec_input_tokens)), # pad only on right side
                    value=self.pad_token
                    )

        label = F.pad(torch.tensor(dec_output_tokens, dtype=torch.long), 
                    (0, self.max_seq_len - len(dec_output_tokens)), # pad only on right side
                    value=self.pad_token
                    )

        # 0 for padded tokens, 1 for non-padded tokens
        enc_mask = (enc_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        # 0 for padded tokens and causal tokens, 1 for remaining tokens
        dec_mask = (dec_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & casual_mask(dec_input.size(0))

        assert enc_input.shape[0] == self.max_seq_len
        assert dec_input.shape[0] == self.max_seq_len
        assert label.shape[0] == self.max_seq_len
 
        return {
            "encoder_input": enc_input, # (max_seq_len)
            "encoder_mask": enc_mask, # (1, 1, max_seq_len)
            "decoder_input": dec_input, # (max_seq_len)
            "decoder_mask": dec_mask, # (1, 1, max_seq_len)

            "label": label, # (max_seq_len)

            "src_text": src_text,
            "tgt_text": tgt_text
        }

### Dataset Test

In [35]:
config = ScratchConfig()
config.max_seq_len = 10

try:
    ds_raw = load_dataset('opus_books', f'{config.src_lang}-{config.tgt_lang}', split='train')
except ValueError:
    ds_raw = load_dataset('opus_books', f'{config.tgt_lang}-{config.src_lang}', split='train')

tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.src_lang)
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.tgt_lang)

print("SOS Token id:", tokenizer_src.token_to_id("[SOS]"))
print("EOS Token id:", tokenizer_src.token_to_id("[EOS]"))
print("PAD Token id:", tokenizer_src.token_to_id("[PAD]"))

d = BiLingualDataset(ds_raw, tokenizer_src, tokenizer_tgt, config.src_lang, config.tgt_lang, config.max_seq_len)

SOS Token id: 2
EOS Token id: 3
PAD Token id: 1


In [36]:
print("Encoder Input:", d[0]["encoder_input"], "Shape:", d[0]["encoder_input"].shape)
print("Encoder Mask:", d[0]["encoder_mask"], "Shape:", d[0]["encoder_mask"].shape)
print("Decoder Input:", d[0]["decoder_input"], "Shape:", d[0]["decoder_input"].shape)
print("Decoder Mask:", d[0]["decoder_mask"], "Shape:", d[0]["decoder_mask"].shape)
print("Label:", d[0]["label"], "Shape:", d[0]["label"].shape)
print("Source Text:", d[0]["src_text"])
print("Target Text:", d[0]["tgt_text"])

Encoder Input: tensor([  2,  82, 157, 774,   3,   1,   1,   1,   1,   1]) Shape: torch.Size([10])
Encoder Mask: tensor([[[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]], dtype=torch.int32) Shape: torch.Size([1, 1, 10])
Decoder Input: tensor([  2, 273,   0,   1,   1,   1,   1,   1,   1,   1]) Shape: torch.Size([10])
Decoder Mask: tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]], dtype=torch.int32) Shape: torch.Size([1, 10, 10])
Label: tensor([273,   0,   3,   1,   1,   1,   1,   1,   1,   1]) Shape: torch.Size([10])
Source Text: Le grand Meaulnes
Target Text: The Wanderer


## DataLoaders

In [37]:
def get_ds(config, verbose=False):
    
    try:
        ds_raw = load_dataset('opus_books', f'{config.src_lang}-{config.tgt_lang}', split='train')
    except ValueError:
        ds_raw = load_dataset('opus_books', f'{config.tgt_lang}-{config.src_lang}', split='train')

    src_lang = config.src_lang
    tgt_lang = config.tgt_lang

    # Build tokenizer
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

    # Split Datasets, 90%-10%
    train_ds_size = int(0.9* len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size

    train_ds, val_ds = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BiLingualDataset(train_ds, tokenizer_src, tokenizer_tgt,
                                src_lang, tgt_lang, config.max_seq_len)
    val_ds = BiLingualDataset(val_ds, tokenizer_src, tokenizer_tgt,
                                src_lang, tgt_lang, config.max_seq_len)
    
    if verbose:
        # Getting Maximum Source and Target Sequence Lengths from the Dataset
        max_src_len = 0
        max_tgt_len = 0
        for item in ds_raw:
            max_src_len = max(max_src_len, len(tokenizer_src.encode(item["translation"][src_lang]).ids))
            max_tgt_len = max(max_tgt_len, len(tokenizer_tgt.encode(item["translation"][tgt_lang]).ids))

        # Print the Max Lengths
        print(f"Max Source Length: {max_src_len}")
        print(f"Max Target Length: {max_tgt_len}")

    # DataLoaders
    train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)

    return train_dl, val_dl, tokenizer_src, tokenizer_tgt

In [38]:
# test
config = ScratchConfig()
train_dl, val_dl, tokenizer_src, tokenizer_tgt = get_ds(config, verbose=True)
print("Length of Train Loader:", len(train_dl))
print("Length of Validation Loader:", len(val_dl))
print("Source Vocab Size:", tokenizer_src.get_vocab_size())
print("Target Vocab Size:", tokenizer_tgt.get_vocab_size())

Max Source Length: 482
Max Target Length: 471
Length of Train Loader: 7149
Length of Validation Loader: 795
Source Vocab Size: 30000
Target Vocab Size: 30000


In [39]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(
        src_vocab_size=vocab_src_len, tgt_vocab_size=vocab_tgt_len,
        src_max_len=config.max_seq_len, tgt_max_len=config.max_seq_len,
        d_model=config.d_model, num_layers=config.num_layers,
        num_heads=config.num_heads, d_ff=config.d_ff,
        dropout=config.dropout
        )
    return model

# Training and Validation Loop

In [40]:
def batch2device(batch, device):
    encoder_input = batch["encoder_input"].to(device) # (batch_size, seq_len)
    decoder_input = batch["decoder_input"].to(device) # (batch_size, seq_len)
    label = batch["label"].to(device)

    encoder_mask = batch["encoder_mask"].to(device) # (batch_size, 1, 1, seq_len)
    decoder_mask = batch["decoder_mask"].to(device) # (batch_size, 1, seq_len, seq_len)

    return encoder_input, decoder_input, label, encoder_mask, decoder_mask

In [41]:
def train_epoch(config, model, device, data_loader, epoch, optimizer, criterion, writer):
    global global_iter
    model.train()
    train_iterator = tqdm(data_loader, desc=f"Training epoch {epoch}", total=len(data_loader))
    total_loss = 0
    for batch in train_iterator:
        encoder_input, decoder_input, label, encoder_mask, decoder_mask = batch2device(batch, device)

        # Forward Pass
        enocder_output = model.encode(encoder_input, encoder_mask) # (batch_size, seq_len, d_model)
        decoder_output = model.decode(enocder_output, encoder_mask, decoder_input, decoder_mask) # (batch_size, seq_len, d_model)
        proj_output = model.project(decoder_output) # (batch_size, seq_len, tgt_vocab_size)

        # Calculate Loss
        # Flatten the output and label tensors
        # prediction: (batch_size, seq_len, tgt_vocab_size) -> (batch_size*seq_len, tgt_vocab_size)
        # label: (batch_size, seq_len) -> (batch_size*seq_len)
        loss = criterion(proj_output.view(-1, proj_output.size(-1)), label.view(-1))
        train_iterator.set_postfix(loss=f"{loss.item():.3f}")
        total_loss += loss.item()

        # Tensorboard Logging
        writer.add_scalar("Loss/train", loss.item(), global_iter)
        writer.flush()
        global_iter += 1

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    total_loss /= len(data_loader)
    writer.add_scalar("Loss_epoch/train", total_loss, epoch)
    writer.flush() 
        
    # Saving the model
    if epoch % 10 == 0:
        Path(config.weights_folder).mkdir(parents=True, exist_ok=True)
        file_name = config.epochs_save.format(epoch)
        torch.save(
            {
                "epoch": epoch,
                "global_iter": global_iter,
                "text_gen_iter": text_gen_iter,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss.item()                    
            },
            Path(config.weights_folder) / file_name
        )
    

In [42]:
def generate(model, device, example, epoch, tokenizer_tgt, writer):
    global text_gen_iter
    model.eval()
    with torch.inference_mode():
        example_dev = list(batch2device(example, device))
        for i in range(len(example_dev)):
            example_dev[i] = example_dev[i][:1] # Take only first example in the batch
        encoder_input, decoder_input, label, encoder_mask, decoder_mask = example_dev
        # Getting SOS and EOS token ids from tokenizer
        sos_token = tokenizer_tgt.token_to_id("[SOS]")
        eos_token = tokenizer_tgt.token_to_id("[EOS]")
        max_len = encoder_input.size(1)
        # Encoder Forward Pass
        enocder_output = model.encode(encoder_input, encoder_mask)
        # Decoder Forward Pass
        decoder_input = torch.tensor([[sos_token]], dtype=torch.long).to(device) # (1, 1)
        decoder_text = []
        while True:
            decoder_output = model.decode(enocder_output, encoder_mask, decoder_input, None)
            proj_output = model.project(decoder_output)
            # Get the last prediction
            _, next_token = torch.max(proj_output[0,-1,:], dim=-1)
            decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
            decoder_text.append( tokenizer_tgt.decode( [next_token.item()] ) )

            if (decoder_input[0,-1] == eos_token) or (decoder_input.size(1) >= max_len):
                break
                
        # Tensorboard Logging
        writer.add_text("Text/Source", example["src_text"][0], text_gen_iter)
        writer.add_text("Text/Target", example["tgt_text"][0], text_gen_iter)
        writer.add_text("Text/Prediction", " ".join(decoder_text), text_gen_iter)
        writer.flush()

        text_gen_iter += 1

In [43]:
# model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
# example = list(batch2device(next(iter(val_dl)), "cpu"))
# for i in range(len(example)):
#     example[i] = example[i][:1] # Take only first example in the batch
#     print(example[i].shape)
#     # print(example[i][:1].shape)
# generate(model, "cpu", next(iter(val_dl)), 1, tokenizer_tgt, None)

In [44]:
def val_epoch(model, device, val_loader, epoch, criterion, tokenizer_tgt, writer):
    model.eval()
    val_iterator = tqdm(val_loader, desc=f"Validating epoch {epoch}", total=len(val_loader))
    random_indices = np.random.randint(0, len(val_loader), size=5)
    total_loss = 0
    with torch.inference_mode():
        for idx, batch in enumerate(val_iterator):

            if idx in random_indices:
                generate(model, device, batch, epoch, tokenizer_tgt, writer)
                
            encoder_input, decoder_input, label, encoder_mask, decoder_mask = batch2device(batch, device)

            # Forward Pass
            enocder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(enocder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Calculate Loss
            loss = criterion(proj_output.view(-1, proj_output.size(-1)), label.view(-1))
            val_iterator.set_postfix(loss=f"{loss.item():.3f}")
            total_loss += loss.item()

    total_loss /= len(val_loader)
    # Tensorboard Logging
    writer.add_scalar("Loss_epoch/val", total_loss, epoch)
    writer.flush()


In [47]:
def train(config):

    # DataLoaders
    train_dl, val_dl, tokenizer_src, tokenizer_tgt= get_ds(config)
    
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
    model = model.to(config.device)
    # Number of Parameters
    total_params = sum(p.numel() for p in model.parameters())
    # Number of source and target Input Embedding and Position Embedding Parameters
    src_emb_params = sum(p.numel() for p in model.src_embedding.parameters())
    tgt_emb_params = sum(p.numel() for p in model.tgt_embedding.parameters())
    print("Total Parameters:", total_params/1e6, "M parameters" )
    print("Source Embedding Parameters:", src_emb_params/1e6, "M parameters" )
    print("Target Embedding Parameters:", tgt_emb_params/1e6, "M parameters" )
    print("Total Embedding Parameters:", (src_emb_params + tgt_emb_params)/1e6, "M parameters" )
    print("Total Transformer Parameters:", (total_params - src_emb_params - tgt_emb_params)/1e6, "M parameters" )

    # Tensorboard Writer
    writer = SummaryWriter(log_dir=config.log_dir, filename_suffix=f"_{config.src_lang}_{config.tgt_lang}")
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=1e-9)
    # Loss
    criterion = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id("[PAD]"),
                                    label_smoothing=config.label_smoothing) # ignore padding tokens
    
    for epoch in range(1, config.epochs+1):
        train_epoch(config, model, config.device, train_dl, epoch, optimizer, criterion, writer)
        val_epoch(model, config.device, val_dl, epoch, criterion, tokenizer_tgt, writer)



# Training

In [48]:
global_iter = 0
text_gen_iter = 0
train(config)

Total Parameters: 90.250544 M parameters
Source Embedding Parameters: 15.36 M parameters
Target Embedding Parameters: 15.36 M parameters
Total Embedding Parameters: 30.72 M parameters
Total Transformer Parameters: 59.530544 M parameters


# Plots

#### Training Loss 10 Epochs

![Training Loss](../images/ScratchTrainLoss.png)

#### Val Loss 10 Epochs

![Validation Loss](../images/ScratchValLoss.png)