# Understanding the Transformer Model

The Transformer model is a groundbreaking architecture in AI, especially for Natural Language Processing (NLP) tasks like language translation, text generation, and more. It was introduced in 2017 by Vaswani et al. in a paper titled "Attention Is All You Need". This model was designed to handle sequential data (like sentences) in a much more efficient way than previous models like Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks.

## Key Components of the Transformer

The transformer architecture has a few critical building blocks that set it apart:

### a. Self-Attention Mechanism

What is Attention? The self-attention mechanism helps the model figure out which words in a sentence are most important relative to other words. For example, in the sentence "The cat sat on the mat.", the model needs to understand that "sat" relates closely to both "cat" and "mat."

Why Self-Attention? Traditional RNNs processed words one-by-one (sequentially), which makes it harder to understand long-range dependencies. With self-attention, the model can look at all words in the sentence at the same time and weigh their relationships.

How Does It Work? Self-attention takes in three components for each word:
- Query (Q)
- Key (K)
- Value (V)

Each word computes attention scores by comparing its Query with the Keys of every word in the sentence, itself included. The scores are normalized with a softmax into weights that say how much focus to place on each word, and the word's output is the weighted sum of all the Value vectors.
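
To make the arithmetic concrete, here is a minimal sketch of scaled dot-product self-attention in plain Python, without PyTorch. The function names and the toy numbers are illustrative assumptions. A real model produces Q, K and V through learned linear projections.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that are positive and sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(queries, keys, values):
    """Scaled dot-product attention over lists of equal-length vectors."""
    d_k = len(keys[0])
    outputs = []
    for q in queries:
        # Compare this word's Query with every word's Key (including its own).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # The output is the weighted sum of all Value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
    return outputs

# Toy input: 3 "words", each with a 2-dimensional Q, K and V (made-up numbers).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result = self_attention(Q, K, V)
```

Because the weights are positive and sum to 1, every output vector is a blend of the Value vectors and never falls outside their range.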

### b. Multi-Head Attention

Instead of having just one attention mechanism, the transformer model uses multiple attention heads. Each head learns a different aspect of the relationships between words in the sentence.

Why Multiple Heads? With multiple heads, the model can capture different patterns in the relationships. For instance, one head may focus on subject-verb relationships, while another might focus on noun-adjective pairings.
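
One common way to implement multiple heads, and the one assumed in this sketch, is to slice each word's `d_model`-sized vector into `h` chunks of size `d_model // h`, let each head attend over its own chunk, and concatenate the results. The helper names below are hypothetical and only show the split and merge steps.

```python
def split_heads(vector, h):
    """Split one d_model-sized vector into h chunks of size d_model // h."""
    d_model = len(vector)
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_model // h
    return [vector[i * d_k:(i + 1) * d_k] for i in range(h)]

def merge_heads(chunks):
    """Concatenate the per-head chunks back into one d_model-sized vector."""
    return [x for chunk in chunks for x in chunk]

word_vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]  # d_model = 8 (toy size)
heads = split_heads(word_vector, h=4)                     # 4 heads, d_k = 2
restored = merge_heads(heads)
```

Splitting keeps the total amount of computation roughly the same as a single full-width head, while giving each head its own subspace to specialize in.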

### c. Positional Encoding

Transformers process input data all at once (in parallel), so they don't inherently know the order of the words in a sentence.

To fix this, positional encoding is added to the input embeddings to represent the position of each word in the sequence. This encoding ensures the model understands the order of the words, which is crucial for meaning.
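
As an illustration, the sinusoidal encoding from the original paper can be sketched in plain Python. The function name and the tiny sizes are assumptions made for readability. Even dimensions use a sine, odd dimensions a cosine, and each pair of dimensions shares one frequency.

```python
import math

def positional_encoding(seq_len, d_model):
    """Return a seq_len x d_model table of sinusoidal position encodings."""
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
            angle = pos / (10000 ** ((i - i % 2) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

pe = positional_encoding(seq_len=4, d_model=6)
```

Each row of `pe` is added element-wise to the embedding of the word at that position, so two identical words at different positions end up with different input vectors.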

### d. Feed-Forward Neural Networks

After the attention layers, the transformer uses feed-forward layers (fully connected layers) to process the output from the attention mechanism and transform the data further.

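The same two-layer network is applied to every position independently: it expands each word's vector to a larger hidden size, applies a non-linearity, and projects it back. This lets the model learn non-linear transformations that attention alone cannot express.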
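
A rough sketch of one such block for a single position, in plain Python: a linear layer, a ReLU, and a second linear layer. The weights below are made-up numbers and the helper names are illustrative. In a trained model the weights are learned.

```python
def linear(x, weights, bias):
    """y = x @ weights + bias for one input vector (weights is in_dim x out_dim)."""
    return [sum(xi * w_row[j] for xi, w_row in zip(x, weights)) + bias[j]
            for j in range(len(bias))]

def relu(x):
    """Keep positive values, replace negative ones with zero."""
    return [max(0.0, v) for v in x]

def feed_forward(x, w1, b1, w2, b2):
    """Expand to the hidden size, apply ReLU, project back to the input size."""
    return linear(relu(linear(x, w1, b1)), w2, b2)

# Toy sizes: input size 2, hidden size 3.
w1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -0.5]]   # 2 x 3
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 x 2
b2 = [0.0, 0.0]
out = feed_forward([1.0, -1.0], w1, b1, w2, b2)
print(out)  # [2.0, 1.0]
```

The hidden vector here is `[1.0, -3.0, 1.0]` before the ReLU and `[1.0, 0.0, 1.0]` after it, which shows the non-linearity discarding the negative component.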

## Why Transformers Are So Effective
### a. Parallelization

One of the key advantages of transformers is that they process all the words in a sentence simultaneously (in parallel). Unlike RNNs or LSTMs, which process one word at a time, transformers don’t have to wait for the previous word’s processing to be completed before moving on to the next one.
This makes transformers much faster to train and scalable to large datasets.

### b. Long-Range Dependencies

Transformers are better at capturing long-range dependencies in text. For instance, in a sentence like “The man who was standing at the door left the house,” the model can connect “man” with “left” even though there are words in between. Traditional models like RNNs struggle with such dependencies, especially when sentences are long.

### c. Flexibility in Tasks

Transformers can handle a variety of tasks like translation, summarization, question-answering, and text generation.
This flexibility comes from the encoder-decoder architecture, whose two halves can be used together or separately. For instance, the decoder part can be used for generating text, while the encoder can be used for understanding and classifying text.

## How Transformers are Used in Real Life

Transformers have revolutionized the field of AI, and they're the backbone of many cutting-edge models:

- GPT (Generative Pre-trained Transformer): This is the model behind AI tools like ChatGPT. It uses a transformer decoder architecture to generate human-like text.
- BERT (Bidirectional Encoder Representations from Transformers): BERT uses the transformer encoder to understand context and perform tasks like sentiment analysis, named entity recognition, and more.
- T5 (Text-to-Text Transfer Transformer): T5 treats every NLP problem as a text-to-text problem. For example, translating text from one language to another is treated as "Translate English to French: [text]."

## Make my own Transformer Model from scratch

### Transformer Architecture: Encoder-Decoder Structure

The transformer model is typically divided into two parts:

- Encoder: This part processes the input data (e.g., a sentence in one language).
- Decoder: This part generates the output data (e.g., the translation of that sentence).

Each of these parts is made up of layers that are stacked on top of each other:
- Encoder Layers: Each encoder layer consists of two main components:
    - Multi-head self-attention
    - Feed-forward neural network
- Decoder Layers: Each decoder layer has three main components:
    - Masked multi-head self-attention (the mask ensures that the decoder cannot "cheat" by looking ahead at future words)
    - Multi-head cross-attention over the encoder output
    - Feed-forward neural network
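
The masking idea can be sketched in a few lines of plain Python. The helper names are hypothetical, and the value `-1e9` is simply a large negative number standing in for minus infinity, so that the softmax assigns those positions a weight close to zero.

```python
def build_causal_mask(size):
    """mask[i][j] is True when position i is allowed to attend to position j."""
    return [[j <= i for j in range(size)] for i in range(size)]

def apply_mask(scores, mask, blocked=-1e9):
    """Replace blocked scores with a very negative number before the softmax."""
    return [[s if allowed else blocked for s, allowed in zip(score_row, mask_row)]
            for score_row, mask_row in zip(scores, mask)]

mask = build_causal_mask(3)
masked = apply_mask([[0.5, 0.2, 0.1]] * 3, mask)
```

The first position can only see itself, the second sees the first two, and the last sees everything, which matches how text is generated one word at a time.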

In [105]:
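import torch
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging in train_model
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm  # progress bar in train_model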

data_folder = 'model/data'

In [106]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
    
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)

In [107]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq = seq
        self.dropout = nn.Dropout(dropout)
        
        # Create a matrix of shape (seq, d_model)
        pe = torch.zeros(seq, d_model)
        
        # Create a vector of shape (seq)
        position = torch.arange(0, seq, dtype=torch.float).unsqueeze(1) # (seq, 1)
        
        # Create a vector of shape (d_model / 2)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term) # sin(position / (10000 ** (2i / d_model)))
        
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term) # cos(position / (10000 ** (2i / d_model)))
        
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0) # (1, seq, d_model)
        
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq, d_model)
        return self.dropout(x)

In [108]:
class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features)) # alpha (multiplicative) is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features)) # bias (additive) is a learnable parameter

    def forward(self, x):
        # x: (batch, seq, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim = -1, keepdim = True) # (batch, seq, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim = -1, keepdim = True) # (batch, seq, 1)
        # eps is to prevent dividing by zero or when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [109]:
class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

    def forward(self, x):
        # (batch, seq, d_model) --> (batch, seq, d_ff) --> (batch, seq, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

In [110]:
class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model # Embedding vector size
        self.h = h # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq, d_k) --> (batch, h, seq, seq)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq, seq) # Apply softmax
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq, seq) --> (batch, h, seq, d_k)
        # return attention scores which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q) # (batch, seq, d_model) --> (batch, seq, d_model)
        key = self.w_k(k) # (batch, seq, d_model) --> (batch, seq, d_model)
        value = self.w_v(v) # (batch, seq, d_model) --> (batch, seq, d_model)

        # (batch, seq, d_model) --> (batch, seq, h, d_k) --> (batch, h, seq, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # Combine all the heads together
        # (batch, h, seq, d_k) --> (batch, seq, h, d_k) --> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq, d_model) --> (batch, seq, d_model)  
        return self.w_o(x)

In [111]:
class ResidualConnection(nn.Module):
    
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [112]:
class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

In [113]:
class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [114]:
class DecoderBlock(nn.Module):

    def __init__(
        self, 
        features: int, 
        self_attention_block: MultiHeadAttentionBlock, 
        cross_attention_block: MultiHeadAttentionBlock, 
        feed_forward_block: FeedForwardBlock, 
        dropout: float
    ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

In [115]:
class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

In [116]:
class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (batch, seq, d_model) --> (batch, seq, vocab_size)
        return self.proj(x)

In [117]:
class Transformer(nn.Module):

    def __init__(
        self, 
        encoder: Encoder, 
        decoder: Decoder, 
        src_embed: InputEmbeddings, 
        tgt_embed: InputEmbeddings, 
        src_pos: PositionalEncoding, 
        tgt_pos: PositionalEncoding, 
        projection_layer: ProjectionLayer
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # (batch, seq, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def project(self, x):
        # (batch, seq, vocab_size)
        return self.projection_layer(x)

In [118]:
def build_transformer(
    src_vocab_size: int, 
    tgt_vocab_size: int, 
    src_seq: int, 
    tgt_seq: int, 
    d_model: int=512, 
    N: int=6, 
    h: int=8, 
    dropout: float=0.1, 
    d_ff: int=2048
) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq, dropout)
    
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)
    
    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    
    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    
    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    
    return transformer

### What is a Tokenizer?

A tokenizer is a tool or algorithm used to break down text into smaller units called tokens. These tokens can be words, subwords, or even characters, depending on the type of tokenizer used.

#### a. Why Tokenize?

Text Processing: Computers cannot directly understand raw text. Tokenization transforms text into smaller, more manageable parts (tokens) that AI models can process.

Building Blocks: Tokens are the building blocks for further processing, like embedding (turning tokens into numerical vectors) or training models.

#### b. Types of Tokenization

Word-level Tokenization:
- Breaks the text into individual words.
- Example: "I love pizza!" → ["I", "love", "pizza", "!"]
- This is simple but doesn't handle variations well (e.g., "running" and "ran" would be treated as separate words).

Character-level Tokenization:
- Breaks the text into individual characters.
- Example: "love" → ['l', 'o', 'v', 'e']
- This is useful for languages with complex word formations but results in more tokens.

Subword Tokenization: (***that's what we will use here***)
- Breaks words into smaller meaningful units, like prefixes or suffixes.
- Example: "unhappiness" → ["un", "happiness"]
- This is often used in Byte Pair Encoding (BPE) or WordPiece, and helps handle unknown or rare words.
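
The first two strategies are simple enough to sketch with the standard library. The regular expression below is one possible choice (an assumption, not the only way to split words): it keeps runs of word characters together and treats each punctuation mark as its own token.

```python
import re

def word_tokenize(text):
    """Split into words and keep each punctuation mark as its own token."""
    return re.findall(r"\w+|[^\w\s]", text)

def char_tokenize(text):
    """Split into individual characters."""
    return list(text)

print(word_tokenize("I love pizza!"))  # ['I', 'love', 'pizza', '!']
print(char_tokenize("love"))           # ['l', 'o', 'v', 'e']
```

Subword tokenization sits between these two extremes and needs a training step to decide which character sequences are worth keeping as single tokens.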

<br>

> **What is BPE Tokenizer?**
> 
> Byte Pair Encoding (BPE) is a subword tokenization method that iteratively merges the most frequent pairs of bytes or characters in a corpus of text. It is commonly used to break down words into smaller, more frequent subword units, making it easier for models to handle rare or unseen words.
> 
> **Usage**
> 
> Handling Rare Words: BPE can break down rare or unknown words into subword units that the model has seen during training. This helps avoid the "out-of-vocabulary" problem.
>
> Balance: It balances between character-level and word-level tokenization, capturing frequent subword patterns without creating too many tokens (like character-level tokenization).
>
> Compression: BPE helps compress the vocabulary size and reduces memory usage for handling words in NLP tasks.
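
The merge loop described above can be sketched as follows. This is a simplified illustration under stated assumptions: words start as sequences of single characters, there is no end-of-word marker, and ties between equally frequent pairs go to the pair seen first. The function name `learn_bpe` is hypothetical.

```python
from collections import Counter

def learn_bpe(words, num_merges):
    """Learn up to num_merges merge rules from a list of words."""
    # Represent every word as a tuple of symbols, starting from single characters.
    corpus = Counter(tuple(word) for word in words)
    merges = []
    for _ in range(num_merges):
        # Count how often each adjacent symbol pair occurs across the corpus.
        pair_counts = Counter()
        for symbols, freq in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Rewrite every word, fusing each occurrence of the best pair.
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_corpus[tuple(merged)] += freq
        corpus = new_corpus
    return merges, corpus

merges, corpus = learn_bpe(["low", "low", "lower", "lowest"], num_merges=2)
print(merges)  # [('l', 'o'), ('lo', 'w')]
```

After two merges the frequent stem "low" has become a single symbol, while the rarer endings "er" and "est" are still spelled out character by character.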

#### c. Why is Tokenization Important?

Model Input: AI models can only work with numbers. Tokenization converts text into tokens that can then be transformed into numerical vectors.

Handling Vocabulary: Tokenizers help manage vocabulary size, especially when dealing with complex or large text corpora.

Preprocessing: It's an essential part of preprocessing before passing text to models like BERT, GPT, etc.
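
As a small illustration of the step from tokens to numbers, the sketch below builds a vocabulary and looks tokens up in it. The function names are hypothetical, and the special-token names and their order are assumptions chosen for this example.

```python
def build_vocab(tokens, specials=("[SOS]", "[EOS]", "[PAD]", "[UNK]")):
    """Assign IDs: special tokens first, then each new token in order of appearance."""
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok in tokens:
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

def to_ids(tokens, vocab):
    """Map tokens to IDs, falling back to [UNK] for anything unseen."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

vocab = build_vocab(["I", "love", "pizza", "!"])
ids = to_ids(["I", "love", "pasta", "!"], vocab)
print(ids)  # [4, 5, 3, 7]
```

The word "pasta" was never seen when the vocabulary was built, so it maps to the ID of `[UNK]`. An embedding layer then turns each ID into a vector.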

--- 

> **Special tokens used below:**
> 
> [**SOS**] → Signals the start of a sentence (important for decoding).
> 
> [**EOS**] → Indicates the end of a sentence (used in labels).
> 
> [**PAD**] → Pads shorter sentences to match seq_len (ignored during computation).
>
> [**UNK**] → Unknown token = word or subword is not in the tokenizer's vocabulary.
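
To show how these tokens fit together, here is a minimal plain-Python sketch that wraps a sentence in `[SOS]` and `[EOS]`, pads it to a fixed length, and builds the matching padding mask. The IDs 0, 1 and 2 mirror the ones used in the tokenizer below. The function names are illustrative.

```python
SOS, EOS, PAD = 0, 1, 2  # IDs assumed for this sketch

def pad_sequence(token_ids, seq_len):
    """Wrap token IDs in [SOS] ... [EOS] and pad with [PAD] up to seq_len."""
    n_pad = seq_len - len(token_ids) - 2  # two slots are taken by SOS and EOS
    if n_pad < 0:
        raise ValueError("Sentence is too long for the specified sequence length.")
    return [SOS] + token_ids + [EOS] + [PAD] * n_pad

def padding_mask(padded):
    """1 for real tokens, 0 for padding, so attention can ignore the padding."""
    return [int(tok != PAD) for tok in padded]

padded = pad_sequence([7, 8, 9], seq_len=8)
mask = padding_mask(padded)
print(padded)  # [0, 7, 8, 9, 1, 2, 2, 2]
print(mask)    # [1, 1, 1, 1, 1, 0, 0, 0]
```

Padding makes every sentence in a batch the same length, and the mask keeps the padded positions from influencing the attention weights.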

In [119]:
class BPETokenizer:
    def __init__(self, num_merges=100):
        self.num_merges = num_merges
        self.vocab = {}
        self.merges = {}

        # Special tokens
        self.special_tokens = {
            "[SOS]": 0,
            "[EOS]": 1,
            "[PAD]": 2
        }

    def train(self, corpus):
        """
        Train the BPE tokenizer on a given text corpus.
        """
        word_freqs = {}
        for sentence in corpus:
            if sentence is None:
                continue
            words = sentence.split()
            for word in words:
                word_freqs[word] = word_freqs.get(word, 0) + 1

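        # Initialize vocabulary with words. Word IDs start after the special tokens,
        # so an extra special token such as [UNK] cannot share an ID with a word.
        offset = len(self.special_tokens)
        self.vocab = {word: idx + offset for idx, word in enumerate(word_freqs.keys())}
        self.vocab.update(self.special_tokens)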

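        # Merge operations
        # Each key of word_freqs is a single word without whitespace, so word.split()
        # yields one symbol, no adjacent pairs are found, and the loop stops on its
        # first pass. The resulting vocabulary is therefore word-level.
        for _ in range(self.num_merges):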
            pairs = {}
            for word, freq in word_freqs.items():
                symbols = word.split()
                for i in range(len(symbols) - 1):
                    pair = (symbols[i], symbols[i + 1])
                    pairs[pair] = pairs.get(pair, 0) + freq

            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            new_symbol = " ".join(best_pair)
            self.vocab[new_symbol] = len(self.vocab)
            self.merges[best_pair] = new_symbol

            new_word_freqs = {}
            for word, freq in word_freqs.items():
                symbols = word.split()
                i = 0
                while i < len(symbols) - 1:
                    pair = (symbols[i], symbols[i + 1])
                    if pair == best_pair:
                        symbols[i] = new_symbol
                        del symbols[i + 1]
                    i += 1
                new_word_freqs[" ".join(symbols)] = freq
            word_freqs = new_word_freqs

    def encode(self, text):
        """
        Encode text into token IDs.
        """
        words = text.split()
        while len(words) > 1:
            pairs = [(words[i], words[i + 1]) for i in range(len(words) - 1)]
            best_pair = None
            for pair in pairs:
                if pair in self.merges:
                    best_pair = pair
                    break
            if best_pair is None:
                break
            new_symbol = self.merges[best_pair]
            new_words = []
            i = 0
            while i < len(words):
                if i < len(words) - 1 and (words[i], words[i + 1]) == best_pair:
                    new_words.append(new_symbol)
                    i += 2
                else:
                    new_words.append(words[i])
                    i += 1
            words = new_words

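        # Return only the word IDs: TranslationDataset adds [SOS], [EOS] and [PAD] itself.
        # Out-of-vocabulary words map to [UNK] when that token is registered; otherwise they are skipped.
        unk_id = self.special_tokens.get("[UNK]")
        ids = [self.vocab.get(w, unk_id) for w in words]
        return [i for i in ids if i is not None]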

    def decode(self, token_ids):
        """
        Decode token IDs back into text.
        """
        inv_vocab = {v: k for k, v in self.vocab.items()}
        return " ".join([inv_vocab[token] for token in token_ids if token not in self.special_tokens.values()])

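    def token_to_id(self, token):
        """
        Get the ID of a token (special tokens are checked first). Returns -1 if the token is unknown.
        """
        return self.special_tokens.get(token, self.vocab.get(token, -1))

    def get_vocab_size(self):
        """
        Number of entries in the vocabulary, special tokens included (called by train_model).
        """
        return len(self.vocab)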


In [120]:
import torch
import csv

class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_raw, tokenizer_src, tokenizer_tgt, src_lang="fr", tgt_lang="en", seq_len=50):
        super().__init__()
        self.seq_len = seq_len
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor([0], dtype=torch.int64)  # Define manually since custom tokenizer has no ID mapping
        self.eos_token = torch.tensor([1], dtype=torch.int64)
        self.pad_token = torch.tensor([2], dtype=torch.int64)

        # Load dataset
        self.dataset = []
        for row in dataset_raw: 
            self.dataset.append({
                "src_text": row[src_lang],
                "tgt_text": row[tgt_lang]
            })

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_text = self.dataset[idx]['src_text']
        tgt_text = self.dataset[idx]['tgt_text']

        # Tokenize using custom BPE tokenizer
        enc_input_tokens = self.tokenizer_src.encode(src_text)
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text)

        # Ensure sequence length constraints
        enc_padding = self.seq_len - len(enc_input_tokens) - 2  # -2 for SOS and EOS
        dec_padding = self.seq_len - len(dec_input_tokens) - 1  # -1 for SOS

        if enc_padding < 0 or dec_padding < 0:
            raise ValueError("Sentence is too long for the specified sequence length.")

        # Prepare inputs with SOS/EOS/PAD
        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * enc_padding, dtype=torch.int64),
        ])

        decoder_input = torch.cat([
            self.sos_token,
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            torch.tensor([self.pad_token] * dec_padding, dtype=torch.int64),
        ])

        label = torch.cat([
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * dec_padding, dtype=torch.int64),
        ])

        # Ensure shapes are correct
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """
    Create a triangular mask for the decoder. 
    This ensures that the model only attends to previous tokens and cannot "see the future" during training.
    """
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

In [121]:
from pathlib import Path

def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq": 350,
        "d_model": 512,
        "datasource": 'FrancophonIA/french-to-english',
        "lang_src": "fr",
        "lang_tgt": "en",
        "model_folder": "runs",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "model/data/tokenizers/tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

def get_weights_file_path(config, epoch: str):
    model_folder = f"{data_folder}/{config['model_folder']}/{config['datasource']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    # Must be the same folder that get_weights_file_path writes to
    model_folder = f"{data_folder}/{config['model_folder']}/{config['datasource']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])

In [122]:
import json
from pathlib import Path

def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    if not tokenizer_path.exists():
        print(f"Training new tokenizer for {lang}...")
        
        # Initialize tokenizer with special tokens
        tokenizer = BPETokenizer(num_merges=100)
        tokenizer.special_tokens.update({
            "[UNK]": 3
        })

        # Train the tokenizer
        tokenizer.train(get_all_sentences(ds, lang))

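        # Save tokenizer as JSON (create the target folder first if it is missing)
        tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
        with open(tokenizer_path, "w", encoding="utf-8") as f: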
            json.dump({
                "vocab": tokenizer.vocab,
                "merges": tokenizer.merges,
                "special_tokens": tokenizer.special_tokens
            }, f)
    else:
        print(f"Loading tokenizer from {tokenizer_path}...")
        
        # Load tokenizer from file
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        
        tokenizer = BPETokenizer()
        tokenizer.vocab = data["vocab"]
        tokenizer.merges = data["merges"]
        tokenizer.special_tokens = data["special_tokens"]

    return tokenizer

In [130]:
def get_ds(config):
    # It only has the train split, so we divide it ourselves
    ds_raw = load_dataset(f"{config['datasource']}", "default", split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    # train_ds = train_ds_raw.set_format(type="torch")
    # val_ds = val_ds_raw.set_format(type="torch")
    train_ds = TranslationDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq'])
    val_ds = TranslationDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq'])

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        # BPETokenizer.encode already returns a plain list of IDs
        src_ids = tokenizer_src.encode(item[config['lang_src']])
        tgt_ids = tokenizer_tgt.encode(item[config['lang_tgt']])
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [124]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq"], config['seq'], d_model=config['d_model'])
    return model

In [125]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item[lang]

In [126]:
def train_model(config):
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)
    if (device == 'cuda'):
        # device is still a plain string here, so query the first CUDA device by index
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3} GB")
    elif (device == 'mps'):
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{data_folder}/{config['model_folder']}/{config['datasource']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Labels are target-language IDs, so take the padding ID from the target tokenizer
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device) # (b, seq)
            decoder_input = batch['decoder_input'].to(device) # (B, seq)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq, seq)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq, d_model)
            proj_output = model.project(decoder_output) # (B, seq, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

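        # Run validation at the end of every epoch
        # (run_validation is not defined in the cells shown here and must exist before train_model is called)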
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

In [None]:
config = get_config()
train_model(config)

Using device: mps
Device name: <mps>
Loading tokenizer from model/data/tokenizers/tokenizer_fr.json...
Loading tokenizer from model/data/tokenizers/tokenizer_en.json...
