In [263]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

# Math
import math

# HuggingFace libraries 
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# Pathlib 
from pathlib import Path

# typing
from typing import Any

# Library for progress bars in loops
from tqdm import tqdm

# Importing library of warnings
import warnings

# Architecture

![Image](https://shreyansh26.github.io/assets/img/posts_images/attention/arch.PNG)

# Tokenizer

![Tokenizer](https://i.ytimg.com/vi/hL4ZnAWSyuU/sddefault.jpg)

In [264]:
def build_tokenizer(config, ds, lang):
    
    # Crating a file path for the tokenizer 
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    
    # Checking if Tokenizer already exists
    if not Path.exists(tokenizer_path): 
        
        # If it doesn't exist, we create a new one
        tokenizer = Tokenizer(WordLevel(unk_token = '[UNK]')) # Initializing a new world-level tokenizer
        tokenizer.pre_tokenizer = Whitespace() # We will split the text into tokens based on whitespace
        
        # Creating a trainer for the new tokenizer
        trainer = WordLevelTrainer(special_tokens = ["[UNK]", "[PAD]", 
                                                     "[SOS]", "[EOS]"], min_frequency = 2) # Defining Word Level strategy and special tokens
        
        # Training new tokenizer on sentences from the dataset and language specified 
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer = trainer)
        tokenizer.save(str(tokenizer_path)) # Saving trained tokenizer to the file path specified at the beginning of the function
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path)) # If the tokenizer already exist, we load it
    return tokenizer # Returns the loaded tokenizer or the trained tokenizer

In [265]:
def get_config():
    return{
        'batch_size': 8,
        'num_epochs': 20,
        'lr': 10**-4,
        'seq_len': 350,
        'd_model': 512, # Dimensions of the embeddings in the Transformer. 512 like in the "Attention Is All You Need" paper.
        'lang_src': 'en',
        'lang_tgt': 'it',
        'model_folder': 'weights',
        'model_basename': 'translation_model_',
        'preload': None,
        'tokenizer_file': 'tokenizer_{0}.json',
        'experiment_name': 'runs/translation_model',
        'encoder_layers': 6,
        'decoder_layers': 6,
        'p_drop': 0.1,
        'dff': 2048,
        'n_heads': 8
    }

config = get_config()

In [266]:
def get_all_sentences(ds, lang):
    for pair in ds:
        yield pair['translation'][lang]
        
tokenizer_src = None
tokenizer_tgt = None
train_ds = None
def get_ds(config):
    
    # Loading the train portion of the OpusBooks dataset.
    # The Language pairs will be defined in the 'config' dictionary we will build later
    ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split = 'train') 
    
    # Building or loading tokenizer for both the source and target languages 
    tokenizer_src = build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = build_tokenizer(config, ds_raw, config['lang_tgt'])
    
    # Splitting the dataset for training and validation 
    train_ds_size = int(0.9 * len(ds_raw)) # 90% for training
    val_ds_size = len(ds_raw) - train_ds_size # 10% for validation
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size]) # Randomly splitting the dataset
                                    
    # Processing data with the BilingualDataset class, which we will define below
    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
                                    
    # Iterating over the entire dataset and printing the maximum length found in the sentences of both the source and target languages
    max_len_src = 0
    max_len_tgt = 0
    for pair in ds_raw:
        src_ids = tokenizer_src.encode(pair['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_src.encode(pair['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))
        
    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    
    # Creating dataloaders for the training and validadion sets
    # Dataloaders are used to iterate over the dataset in batches during training and validation
    train_dataloader = DataLoader(train_ds, batch_size = config['batch_size'], shuffle = True) # Batch size will be defined in the config dictionary
    val_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)
    
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt # Returning the DataLoader objects and tokenizers

def casual_mask(size):
        # Creating a square matrix of dimensions 'size x size' filled with ones
        mask = torch.triu(torch.ones(1, size, size), diagonal = 1).type(torch.int)
        return mask == 0

In [267]:
class BilingualDataset(Dataset):
    
    # This takes in the dataset contaning sentence pairs, the tokenizers for target and source languages, and the strings of source and target languages
    # 'seq_len' defines the sequence length for both languages
    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        
        # Defining special tokens by using the target language tokenizer
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

        
    # Total number of instances in the dataset (some pairs are larger than others)
    def __len__(self):
        return len(self.ds)
    
    # Using the index to retrive source and target texts
    def __getitem__(self, index: Any) -> Any:
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]
        
        # Tokenizing source and target texts 
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
        
        # Computing how many padding tokens need to be added to the tokenized texts 
        # Source tokens
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # Subtracting the two '[EOS]' and '[SOS]' special tokens
        # Target tokens
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # Subtracting the '[SOS]' special token
        
        # If the texts exceed the 'seq_len' allowed, it will raise an error. This means that one of the sentences in the pair is too long to be processed
        # given the current sequence length limit (this will be defined in the config dictionary below)
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError('Sentence is too long')
         
        # Building the encoder input tensor by combining several elements
        encoder_input = torch.cat(
            [
            self.sos_token, # inserting the '[SOS]' token
            torch.tensor(enc_input_tokens, dtype = torch.int64), # Inserting the tokenized source text
            self.eos_token, # Inserting the '[EOS]' token
            torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype = torch.int64) # Addind padding tokens
            ]
        )
        
        # Building the decoder input tensor by combining several elements
        decoder_input = torch.cat(
            [
                self.sos_token, # inserting the '[SOS]' token 
                torch.tensor(dec_input_tokens, dtype = torch.int64), # Inserting the tokenized target text
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype = torch.int64) # Addind padding tokens
            ]
        
        )
        
        # Creating a label tensor, the expected output for training the model
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype = torch.int64), # Inserting the tokenized target text
                self.eos_token, # Inserting the '[EOS]' token 
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype = torch.int64) # Adding padding tokens
                
            ]
        )
        
        # Ensuring that the length of each tensor above is equal to the defined 'seq_len'
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        
        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input, 
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).int(),
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).int() & (causal_mask(decoder_input.size(0)).squeeze(0)), 
            'label': label,
            'src_text': src_text,
            'tgt_text': tgt_text
        }    

def causal_mask(size):
        # Creating a square matrix of dimensions 'size x size' filled with ones
        mask = torch.triu(torch.ones(1, size, size), diagonal = 1).type(torch.int)
        return mask == 0

# Token Embedding

In [268]:
class TokenEmbeddings(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embeddings(x) * math.sqrt(self.d_model) # Normalizing the variance of the embedding
        

# Positional Encoding

<p style="
    margin-bottom: 5; 
    font-size: 22px;
    font-weight: 300;
    font-family: 'Helvetica Neue', sans-serif;
    color: #000000; 
  ">
    \begin{equation}
    \text{Even Indices } (2i): \quad \text{PE(pos, } 2i) = \sin\left(\frac{\text{pos}}{10000^{2i / d_{model}}}\right)
    \end{equation}
</p>

<p style="
    margin-bottom: 5; 
    font-size: 22px;
    font-weight: 300;
    font-family: 'Helvetica Neue', sans-serif;
    color: #000000; 
  ">
    \begin{equation}
    \text{Odd Indices } (2i + 1): \quad \text{PE(pos, } 2i + 1) = \cos\left(\frac{\text{pos}}{10000^{2i / d_{model}}}\right)
    \end{equation}
</p>

In [269]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        encodings = torch.zeros(seq_len, d_model)
        pos = torch.arange(0, seq_len, dtype = torch.float).unsqueeze(1) # [seq_len, 1]
        denominator = (torch.arange(0, d_model, 2).float() / d_model)
        denominator = torch.pow(denominator, 10000)

        encodings[:, 0::2] = torch.sin(pos/denominator)
        encodings[:, 1::2] = torch.cos(pos/denominator)
        encodings = encodings.unsqueeze(0) # add a batch dimension
        self.register_buffer('encodings', encodings) # Buffer is a tensor not considered as a model parameter


    def forward(self, x):
        return x + (self.encodings[:, :x.shape[1], :]).requires_grad_(False)
        

# Layer Norm

In [270]:
class LayerNorm(nn.Module):
    def __init__(self, eps: float = 1e-9):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x-mean) / (std+self.eps) + self.bias

# FFW

In [271]:
class FeedForwardBlock(nn.Module):
    
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        # First linear transformation
        self.linear_1 = nn.Linear(d_model, d_ff) # W1 & b1
        self.dropout = nn.Dropout(dropout) # Dropout to prevent overfitting
        # Second linear transformation
        self.linear_2 = nn.Linear(d_ff, d_model) # W2 & b2
        
    def forward(self, x):
        # (Batch, seq_len, d_model) --> (batch, seq_len, d_ff) -->(batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

# Multi Head Attention

<center>
    <img src = "https://i.imgur.com/JqJVrsj.png" width = 1556, height= 959>
<p style = "font-size: 16px;
            font-family: 'Georgia', serif;
            text-align: center;
            margin-top: 10px;"></p>
</center>

In [272]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads: int, d_model: int):
        super().__init__()
        self.n_heads = n_heads
        self.w_key = nn.Linear(d_model, d_model)
        self.w_query = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)

    def attention(self, k, q, v, mask):
        d_k = q.shape[-1]
        affinities = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            affinities.masked_fill_(mask == 0, -1e9)
        affinities = affinities.softmax(dim=-1)
        value = affinities @ v
        return value
        

    def forward(self, q, k, v, mask):
        key = self.w_key(q)
        query = self.w_query(k)
        value = self.w_value(v)

        # split embedding dim for each heads
        new_d_model = config['d_model'] // self.n_heads
        k_chunks = torch.split(key, new_d_model, dim=-1)
        q_chunks = torch.split(query, new_d_model, dim=-1)
        v_chunks = torch.split(value, new_d_model, dim=-1)

        output_heads = []
        for i in range(self.n_heads):
            output_heads.append(self.attention(k_chunks[i], q_chunks[i], v_chunks[i], mask))

        concat_out = torch.cat(output_heads, dim=-1)
        return self.w_out(concat_out)
        

In [273]:
class ResidualConnection(nn.Module):
    def __init__(self):
        super().__init__()
        self.layernorm = LayerNorm()

    def forward(self, x, sub_layer):
        return x + sub_layer(self.layernorm(x))

# Encoder
<center>
    <img src = "https://www.researchgate.net/profile/Ehsan-Amjadian/publication/352239001/figure/fig1/AS:1033334390013952@1623377525434/Detailed-view-of-a-transformer-encoder-block-It-first-passes-the-input-through-an.jpg" width = 400, height= 400>
<p style = "font-size: 16px;
            font-family: 'Georgia', serif;
            text-align: center;
            margin-top: 10px;">Encoder block. Source: <a href = "https:///figure/Detailed-view-of-a-transformer-encoder-block-It-first-passes-the-input-through-an_fig1_352239001">researchgate.net</a>.</p>
</center>

In [274]:
class EncoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet_mha = ResidualConnection()
        self.resnet_ffw = ResidualConnection()
        self.mha = MultiHeadAttention(n_heads=config['n_heads'], d_model=config['d_model'])
        self.ffw = FeedForwardBlock(d_model=config['d_model'], d_ff=config['dff'], dropout=config['p_drop'])

    def forward(self, x, src_mask):
        x = self.resnet_mha(x, lambda x : self.mha(x, x, x, src_mask))
        x = self.resnet_ffw(x, lambda x : self.ffw(x))
        return x

In [275]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock() for _ in range(config['encoder_layers'])]
        )
        
    def forward(self, x, src_mask):
        for block in self.encoder_blocks:
            x = block(x, src_mask)
        return x

# Decoder

<center>
    <img src = "https://res.cloudinary.com/edlitera/image/upload/c_fill,f_auto/v1680629118/blog/gz5ccspg3yvq4eo6xhrr" width = 400, height= 400>
<p style = "font-size: 16px;
            font-family: 'Georgia', serif;
            text-align: center;
            margin-top: 10px;"></p>
</center>

In [276]:
class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet_blocks = nn.ModuleList([ResidualConnection() for _ in range(3)])
        self.self_mha = MultiHeadAttention(n_heads=config['n_heads'], d_model=config['d_model'])
        self.cross_mha = MultiHeadAttention(n_heads=config['n_heads'], d_model=config['d_model'])
        self.ffw = FeedForwardBlock(d_model=config['d_model'], d_ff=config['dff'], dropout=config['p_drop'])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.resnet_blocks[0](x, lambda x: self.self_mha(x, x, x, tgt_mask))
        x = self.resnet_blocks[1](x, lambda x: self.cross_mha(encoder_output, x, encoder_output, src_mask))
        x = self.resnet_blocks[2](x, lambda x: self.ffw(x))
        return x

In [277]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder_block = nn.ModuleList(
            [DecoderBlock() for _ in range(config['decoder_layers'])]
        )
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.decoder_block:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return x 

# Transformer

In [278]:
class TranslationTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_embeddings = TokenEmbeddings(config['d_model'], tokenizer_src.get_vocab_size())
        self.encoder = Encoder()
        self.decoder_embeddings = TokenEmbeddings(config['d_model'], tokenizer_tgt.get_vocab_size())
        self.decoder = Decoder()
        self.positional_encodings = PositionalEncoding(config['d_model'], config['seq_len'])
        self.projection = nn.Linear(config['d_model'], tokenizer_tgt.get_vocab_size())

    def encode(self, encoder_inp, src_mask):
        encoder_embeddings = self.encoder_embeddings(encoder_inp)
        encoder_embeddings = self.positional_encodings(encoder_embeddings)
        encoder_output = self.encoder(encoder_embeddings, src_mask)
        return encoder_output

    def decode(self, encoder_output, decoder_inp, src_mask, tgt_mask):

        decoder_embeddings = self.decoder_embeddings(decoder_inp)
        decoder_embeddings = self.positional_encodings(decoder_embeddings)
        decoder_output = self.decoder(decoder_embeddings, encoder_output, src_mask, tgt_mask)
        output = torch.log_softmax(self.projection(decoder_output), dim = -1)
        return output

    def generate(self, encoder_inp, src_mask):
        encoder_output = self.encoder(encoder_inp, src_mask)

        sos_idx = tokenizer_tgt.token_to_id('[SOS]')
        eos_idx = tokenizer_tgt.token_to_id('[EOS]')
        decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(encoder_inp).to(device)

        while decoder_input.shape[1] < config['seq_len']:
            decoder_output = self.decode(encoder_output, decoder_inp, label, src_mask, causal_mask(decoder_inp.shape[1]).type_as(src_mask))
            output_token = torch.argmax(decoder_output[1, -1, :], dim=1) #greedy sampling
            decoder_input = torch.cat([decoder_input, torch.empty(1,1).fill_(output_token).type_as(encoder_inp).to(device)], dim=1)

            if output_token == eos_idx:
                break

        return decoder_input.squeeze(0)    

In [279]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, writer, num_examples=2):
    model.eval() # Setting model to evaluation mode
    count = 0 # Initializing counter to keep track of how many examples have been processed
    
    console_width = 80 # Fixed witdh for printed messages
    
    # Creating evaluation loop
    with torch.no_grad(): # Ensuring that no gradients are computed during this process
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            
            # Ensuring that the batch_size of the validation set is 1
            assert encoder_input.size(0) ==  1, 'Batch size must be 1 for validation.'
            
            # Applying the 'greedy_decode' function to get the model's output for the source text of the input batch
            model_out = model.generate(encoder_input, encoder_mask)
            
            # Retrieving source and target texts from the batch
            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0] # True translation 
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy()) # Decoded, human-readable model output
            
            # Printing results
            print_msg('-'*console_width)
            print_msg(f'SOURCE: {source_text}')
            print_msg(f'TARGET: {target_text}')
            print_msg(f'PREDICTED: {model_out_text}')
            
            # After two examples, we break the loop
            if count == num_examples:
                break

In [280]:
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
device = "cuda" if torch.cuda.is_available() else "cpu"

Max length of source sentence: 309
Max length of target sentence: 274


In [281]:
data = train_dataloader.dataset.__getitem__(100)
print(data['decoder_mask'].shape)
print(data['encoder_mask'].shape)

torch.Size([350, 350])
torch.Size([1, 350])


In [282]:
model = TranslationTransformer()
    
# Initialize the parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [283]:
def train_model():
    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps = 1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)

    for epoch in range(config['num_epochs']):
        batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:
            model.train()
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, decoder_input, encoder_mask, decoder_mask)

            label = batch['label'].to(device)
            
            loss = loss_fn(decoder_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            
            # Updating progress bar
            batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})
            
            writer.add_scalar('train loss', loss.item())
            writer.flush()
            
            # Performing backpropagation
            loss.backward()
            
            # Updating parameters based on the gradients
            optimizer.step()
            
            # Clearing the gradients to prepare for the next batch
            optimizer.zero_grad()

            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), writer)
         
        # Saving model
        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        # Writting current model state to the 'model_filename'
        torch.save({
            'epoch': epoch, # Current epoch
            'model_state_dict': model.state_dict(),# Current model state
            'optimizer_state_dict': optimizer.state_dict(), # Current optimizer state
            'global_step': global_step # Current global step 
        }, model_filename)    

In [None]:
train_model()

Processing epoch 00:   0%|          | 0/3638 [00:14<?, ?it/s, loss=nan]