# Transformers in PyTorch

The idea of this notebook is to explain how transformers are coded in PyTorch. We will take as reference the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper and this [video](https://www.youtube.com/watch?v=ISNdQcPhsts). The transformer we are going to build is rather simple: it translates sentences from English to Spanish.

First we will build the transformer component by component and then we will work on the training loop and inference.

## The Transformer

In order to build the transformer, we will have to build all the inner components first. The base components of the transformer are:
- Input Embeddings
- Positional Encoding
- Layer Normalization
- Feed Forward Block
- Multi-Head Attention Block
- Residual Connection

Then come the encoder and the decoder, each composed of a stack of encoder or decoder blocks, and finally a projection layer.

![Transformer Architecture](assets/transformer-network.png)


In [1]:
import torch
import torch.nn as nn
import math

### Input Embeddings
This layer assigns a vector to each token of the input sequence. These vectors are learned during training and represent the "meaning" of the token (or word).

PyTorch already provides this functionality as `nn.Embedding`; we wrap it in our own module so that we can also multiply the embeddings by $\sqrt{d_{model}}$, as the original paper does.
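As a quick illustration of what the wrapper computes, here is a minimal sketch that uses `nn.Embedding` directly with toy sizes (`d_model = 8` and `vocab_size = 10` are arbitrary choices for the example): each token id selects one row of the learned weight matrix, and the result is scaled by $\sqrt{d_{model}}$.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab_size = 8, 10  # toy sizes chosen for illustration

embedding = nn.Embedding(vocab_size, d_model)

# A batch of 2 sentences, each 3 token ids long
token_ids = torch.tensor([[1, 4, 7], [2, 2, 9]])

vectors = embedding(token_ids) * math.sqrt(d_model)  # (batch, seq_len, d_model)
print(vectors.shape)  # torch.Size([2, 3, 8])

# The same token id always maps to the same (scaled) row of the weight matrix
print(torch.equal(vectors[1, 0], vectors[1, 1]))  # True
```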

In [2]:
class InputEmbedding (nn.Module) :
    
    def __init__(self, d_model: int, vocab_size: int) -> None :
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        
    def forward (self, x) :
        return self.embedding(x) * math.sqrt(self.d_model) # as specified in the original paper

### Positional Encoding

The positional encoding adds a vector to each embedding in order to encode the position of the token in the sentence (e.g. first, second, ...). There are many ways to achieve this, but here we will use the sinusoidal vectors proposed in the original paper, calculated with the following functions:

![Positional Encodings Functions](assets/positional-encoding-functions.png)

Where $pos$ is the position of the token in the sentence and $i$ indexes the embedding dimensions in pairs: dimension $2i$ uses the sine and dimension $2i+1$ uses the cosine.
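Written out, the paper's functions are $PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$ and $PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$. The code below computes the frequency $1/10000^{2i/d_{model}}$ in log space, as $\exp(2i \cdot (-\ln 10000) / d_{model})$. This small pure-Python sketch (toy values `d_model = 8`, `pos = 3`) checks that both forms give the same encodings.

```python
import math

d_model = 8  # toy embedding size
pos = 3      # position of the token in the sentence


def pe_direct(pos, i, d_model):
    """Return (PE(pos, 2i), PE(pos, 2i+1)) exactly as the paper writes them."""
    angle = pos / (10000 ** (2 * i / d_model))
    return math.sin(angle), math.cos(angle)


def pe_log_space(pos, i, d_model):
    """Same values, with the frequency computed as exp(2i * -ln(10000) / d_model)."""
    div_term = math.exp(2 * i * (-math.log(10000.0) / d_model))
    return math.sin(pos * div_term), math.cos(pos * div_term)


for i in range(d_model // 2):
    print(i, pe_direct(pos, i, d_model), pe_log_space(pos, i, d_model))
```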

In [3]:
class PositionalEncoding(nn.Module) :
    
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None :
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(seq_len, d_model) # (seq_len, d_model)
        
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space as exp(2i * -ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        
        self.register_buffer("pe", pe) # Save it to the state file, but not as a parameter
        
    def forward (self, x) :
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

### Layer Normalization

This component normalizes each input (each vector corresponding to a token) so its values have mean 0 and variance 1. Then it scales the values with a parameter $\alpha$ and shifts them with a parameter $\beta$.

The purpose of this block is to stabilize and accelerate training, since the inputs to the next block stay in a consistent range.
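The normalization can be checked numerically. This sketch uses toy shapes and follows the convention of PyTorch's `F.layer_norm` (biased variance, `eps` inside the square root), which differs slightly from the notebook's module below (unbiased `std`, `eps` added to the standard deviation). The scale $\alpha = 1$ and shift $\beta = 0$ are the usual initial values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model), toy sizes
eps = 1e-6

# Normalize every token vector over its last (feature) dimension
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
x_hat = (x - mean) / torch.sqrt(var + eps)

alpha, beta = 1.0, 0.0  # scale and shift at their initial values
y = alpha * x_hat + beta

print(y.mean(dim=-1).abs().max())  # close to 0 for every token
print((y ** 2).mean(dim=-1))       # close to 1 for every token

# Cross-check against PyTorch's functional layer norm (no affine parameters)
print(torch.allclose(y, F.layer_norm(x, (16,), eps=eps), atol=1e-5))  # True
```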

In [4]:
class LayerNormalization(nn.Module) :
    
    def __init__(self, eps: float = 10**-6) -> None :
        super().__init__()
        self.eps = eps # numerical stability
        
        self.alpha = nn.Parameter(torch.ones(1)) # Multiplied (scale), starts at 1
        self.beta = nn.Parameter(torch.zeros(1)) # Added (shift), starts at 0
        
    def forward (self, x) :
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.beta

### Feed Forward Block

This block is a simple two-layer fully connected network with a ReLU activation in between. It is applied to each position (token) separately and identically.
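"Separately and identically" means the same weights are applied to every token and tokens do not interact inside this block. A small sketch with toy sizes and no dropout (so the output is deterministic) shows that changing one token leaves the outputs at the other positions unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32  # toy sizes (the paper uses 512 and 2048)

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(1, 4, d_model)  # (batch, seq_len, d_model)
out = ffn(x)
print(out.shape)  # torch.Size([1, 4, 8])

# Replace only the last token; the first three outputs stay the same
x2 = x.clone()
x2[0, 3] = torch.randn(d_model)
out2 = ffn(x2)
print(torch.allclose(out[0, :3], out2[0, :3]))  # True
```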

In [5]:
class FeedForwardBlock (nn.Module) :
    
    def __init__ (self, d_model: int, d_ff: int, dropout: float) -> None :
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward (self, x) :
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

### Multi-Head Attention Block

This block is the key innovation of transformers: its objective is to update each embedding with context from the other tokens in the sentence. We use multiple heads so that each one works on a different slice of the projected embedding, which lets the model capture different traits and aspects of each word.

This implementation covers self-attention, masked self-attention and cross-attention. The three differ only in which tensors are passed as query, key and value, and in the mask that is applied, so a single block can handle all of them.
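The core computation, $\mathrm{softmax}(QK^T / \sqrt{d_k})V$ with an optional mask, can be sketched on random toy tensors already split into heads. The causal (lower-triangular) mask used here is the masked self-attention case: blocked scores are set to a large negative number so that their softmax weight becomes zero.

```python
import math

import torch

torch.manual_seed(0)
batch, h, seq_len, d_k = 1, 2, 4, 8  # toy sizes

q = torch.randn(batch, h, seq_len, d_k)
k = torch.randn(batch, h, seq_len, d_k)
v = torch.randn(batch, h, seq_len, d_k)

# Causal mask: position i may only attend to positions 0..i (1 = keep, 0 = block)
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch, h, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, -1e9)         # blocked positions get ~zero weight
weights = scores.softmax(dim=-1)
out = weights @ v                                     # (batch, h, seq_len, d_k)

print(out.shape)      # torch.Size([1, 2, 4, 8])
print(weights[0, 0])  # upper triangle is 0 and every row sums to 1
```

For cross-attention, `q` would come from the decoder and `k`, `v` from the encoder output; the computation is otherwise identical.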

In [6]:
class MultiHeadAttentionBlock(nn.Module) :
    
    def __init__(self, d_model: int, h: int, dropout: float) -> None :
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h" 
        
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout) :
        d_k = query.shape[-1]
        
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
            
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len)
        
        if dropout is not None:
            attention_scores = dropout(attention_scores)
            
        return (attention_scores @ value)
    
    def forward (self, q, k, v, mask) :
        query = self.w_q(q) 
        key = self.w_k(k)
        value = self.w_v(v)
        
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        
        x = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        return self.w_o(x)

### Residual Connection

This last component handles the residual connections that appear in the diagram (the "Add & Norm" boxes): it adds the input of a sublayer to its output and applies layer normalization. Wrapping this in a module keeps the forward methods of the encoder and decoder blocks compact.
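The pattern can be sketched as a function that takes the sublayer as a callable. This uses `nn.LayerNorm` as a stand-in for the notebook's `LayerNormalization` and follows the paper's formulation $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$. Passing a lambda is what lets the encoder and decoder blocks supply extra arguments such as a mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8  # toy size

norm = nn.LayerNorm(d_model)
dropout = nn.Dropout(0.1)


def add_and_norm(x, sublayer):
    """Post-norm residual as in the paper: LayerNorm(x + Dropout(Sublayer(x)))."""
    return norm(x + dropout(sublayer(x)))


# Any callable mapping (batch, seq_len, d_model) to the same shape can be the sublayer,
# e.g. a lambda that closes over extra arguments
linear = nn.Linear(d_model, d_model)
x = torch.randn(2, 5, d_model)
y = add_and_norm(x, lambda t: linear(t))
print(y.shape)  # torch.Size([2, 5, 8])
```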

In [7]:
class ResidualConnection(nn.Module) :
    def __init__(self, dropout: float) -> None :
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()
        
    def forward(self, x, sublayer) :
        return self.norm(x + self.dropout(sublayer(x))) # LayerNorm(x + Dropout(Sublayer(x))), as in the paper

### Encoder Block

An encoder block processes the encoder inputs and generates new embeddings that are then processed either by the next encoder block or by the decoder.

The encoder usually stacks several encoder blocks ($N = 6$ in the original paper).
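For comparison, PyTorch ships a built-in encoder stack. This sketch (toy sizes, not the notebook's classes) stacks `N` copies of `nn.TransformerEncoderLayer` and shows that the shape `(batch, seq_len, d_model)` is preserved through the stack. Note that `src_key_padding_mask` marks positions to *ignore* with `True`, the opposite of the 1 = keep convention used by the masks in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h, d_ff, N = 16, 4, 64, 2  # toy sizes; the paper uses 512, 8, 2048, 6

layer = nn.TransformerEncoderLayer(d_model, h, dim_feedforward=d_ff, dropout=0.1, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=N)  # N independent copies of the layer

x = torch.randn(3, 10, d_model)             # (batch, seq_len, d_model)
pad = torch.zeros(3, 10, dtype=torch.bool)  # True marks padding positions
pad[:, 7:] = True                           # last 3 positions are padding

out = encoder(x, src_key_padding_mask=pad)
print(out.shape)  # torch.Size([3, 10, 16])
```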

In [8]:
class EncoderBlock (nn.Module) :
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None :
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ ResidualConnection(dropout) for _ in range(2) ])
        
    def forward(self, x, src_mask) :
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

### Encoder

Next, the encoder module applies a list of encoder blocks one after the other.

In [9]:
class Encoder (nn.Module) :
    def __init__ (self, layers: nn.ModuleList) -> None :
        super().__init__()
        self.layers = layers
        
        
    def forward(self, x, mask) :
        for layer in self.layers :
            x = layer(x, mask)
            
        return x

### Decoder Block

The decoder block generates new embeddings for the decoder input tokens using both masked self-attention and cross-attention with the encoder output (that is, attention over the tokens of the encoder input).

As with the encoder, the decoder usually stacks several decoder blocks.
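In cross-attention the queries come from the decoder while the keys and values come from the encoder output, so the two sequences may have different lengths. This sketch uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the notebook's `MultiHeadAttentionBlock` (toy sizes) to show the resulting shapes: one output vector per decoder position, with attention weights spread over the source positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 16, 4
src_len, tgt_len = 7, 5  # the two sequences may have different lengths

cross_attn = nn.MultiheadAttention(d_model, h, batch_first=True)

encoder_output = torch.randn(2, src_len, d_model)  # keys and values
decoder_state = torch.randn(2, tgt_len, d_model)   # queries

out, weights = cross_attn(decoder_state, encoder_output, encoder_output)
print(out.shape)      # torch.Size([2, 5, 16]): one vector per decoder position
print(weights.shape)  # torch.Size([2, 5, 7]): weights over the source tokens (averaged over heads)
```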

In [10]:
class DecoderBlock (nn.Module) :
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None :
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ ResidualConnection(dropout) for _ in range(3) ])
        
    def forward (self, x, encoder_output, src_mask, tgt_mask) :
        x = self.residual_connections[0](x, lambda x : self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        
        return x

### Decoder

Next, the decoder module applies all the decoder blocks in sequence.

In [11]:
class Decoder (nn.Module) :
    
    def __init__ (self, layers: nn.ModuleList) -> None :
        super().__init__()
        self.layers = layers
        
    def forward (self, x, encoder_output, src_mask, tgt_mask) :
        for layer in self.layers :
            x = layer(x, encoder_output, src_mask, tgt_mask)
            
        return x

### Projection Layer

This layer projects the decoder output embeddings onto the target vocabulary, producing for each position the (log-)probability of picking each token. It consists of a linear layer followed by a log-softmax.
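Because the projection returns log-probabilities while the training loop uses `nn.CrossEntropyLoss` (which applies its own log-softmax), it is worth checking that nothing is counted twice. Log-softmax is idempotent (the log-sum-exp of log-probabilities is 0), so this toy sketch shows that cross-entropy on logits, cross-entropy on log-probabilities and `nn.NLLLoss` on log-probabilities all agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 6
logits = torch.randn(4, vocab_size)  # 4 positions, raw scores from the linear layer
targets = torch.tensor([0, 3, 5, 1])

log_probs = torch.log_softmax(logits, dim=-1)
print(log_probs.exp().sum(dim=-1))  # each row sums to 1

# CrossEntropyLoss applies log_softmax internally; applying it to log-probabilities
# changes nothing because log_softmax is idempotent
ce_on_logits = nn.CrossEntropyLoss()(logits, targets)
ce_on_log_probs = nn.CrossEntropyLoss()(log_probs, targets)
nll = nn.NLLLoss()(log_probs, targets)
print(torch.allclose(ce_on_logits, ce_on_log_probs), torch.allclose(ce_on_logits, nll))  # True True
```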

In [12]:
class ProjectionLayer (nn.Module) :
    
    def __init__(self, d_model: int, vocab_size: int) -> None :
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x) :
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)

### Transformer
Finally, everything comes together in the transformer module. We won't build a `forward` method, because we want to be able to run the encoder, the decoder and the projection separately.

In [13]:
class Transformer (nn.Module) :
    
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, tgt_embed: InputEmbedding, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None :
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        
    def encode (self, src, src_mask) :
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode (self, encoder_output, src_mask, tgt, tgt_mask) :
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def project (self, x) :
        return self.projection_layer(x)

In [14]:
def build_transformer (src_vocab_size: int, tgt_vocab_size: int,
                       src_seq_len: int, tgt_seq_len: int,
                       d_model: int = 512, N: int = 6, h: int = 8, d_ff: int = 2048,
                       dropout: float = 0.1) -> Transformer :
    
    # Create the embedding layers
    src_embed = InputEmbedding(d_model, src_vocab_size)
    tgt_embed = InputEmbedding(d_model, tgt_vocab_size)
    
    # Create the positional encoding layers (redundant as we just need one)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N) :
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
        
    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N) :
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)
        
    # Create the encoder and the decoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))
    
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    
    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    
    # Initialize the parameters (Make training faster)
    for p in transformer.parameters() :
        if p.dim() > 1 :
            nn.init.xavier_uniform_(p)
            
    return transformer

## Training

### Tokenizer

To use our model we first need a tokenizer: a function that splits our text into tokens from our vocabulary. To build that vocabulary, we need to train the tokenizer on our dataset.

We will use the Hugging Face Tokenizers library to build a word-level tokenizer.
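To see what training produces, here is a sketch on a hypothetical three-sentence corpus, using the same tokenizer setup as the function below. Special tokens are placed at the start of the vocabulary in the order given, and with `min_frequency=2` any word seen only once is left out and encoded as `[UNK]`.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

# Hypothetical toy corpus for illustration
corpus = [
    "the dog runs",
    "the dog sleeps",
    "the cat runs",
]

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
tokenizer.train_from_iterator(corpus, trainer)

# Special tokens come first in the vocabulary, in the order given to the trainer
print([tokenizer.token_to_id(t) for t in ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]])

# "cat" and "sleeps" appear only once (below min_frequency), so they become [UNK]
print(tokenizer.encode("the cat sleeps").tokens)
```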

In [15]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

In [16]:
def get_all_sentences (ds, lang) :
    for item in ds :
        yield item["translation"][lang]

def get_or_build_tokenizer (config, ds, lang) :
    tokenizer_path = Path(config["tokenizer_file"].format(lang))
    
    if not Path.exists(tokenizer_path) :
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer)
        tokenizer.save(str(tokenizer_path))
    else :
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
        
    return tokenizer

### Dataset

We are going to use the [Opus Books](https://huggingface.co/datasets/Helsinki-NLP/opus_books) dataset, which contains sentence pairs taken from books in many languages. Our model will translate from English to Spanish.

To load the dataset we will use Hugging Face's Datasets library, together with the `Dataset` and `DataLoader` classes from PyTorch.

In [17]:
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset

First we define the class for our dataset. It transforms each raw example into tensors that we can feed to the model, and it also provides the masks used during training.
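The sequence layout built by the dataset class can be sketched with plain lists, using hypothetical token ids (the real ids come from the tokenizers). The source gets `[SOS]` and `[EOS]`, the decoder input gets only `[SOS]`, and the label gets only `[EOS]`, so the label is the decoder input shifted left by one position: at every step the decoder is trained to predict the next token.

```python
# Hypothetical token ids for illustration; the real ids come from the tokenizers
SOS, EOS, PAD = 2, 3, 1
seq_len = 8

src_ids = [11, 12, 13]  # tokenized source sentence
tgt_ids = [21, 22]      # tokenized target sentence

# Source: [SOS] tokens [EOS] then padding (2 special tokens)
encoder_input = [SOS] + src_ids + [EOS] + [PAD] * (seq_len - len(src_ids) - 2)
# Decoder input: [SOS] tokens then padding (1 special token)
decoder_input = [SOS] + tgt_ids + [PAD] * (seq_len - len(tgt_ids) - 1)
# Label: tokens [EOS] then padding, i.e. the decoder input shifted left by one
label = tgt_ids + [EOS] + [PAD] * (seq_len - len(tgt_ids) - 1)

print(encoder_input)  # [2, 11, 12, 13, 3, 1, 1, 1]
print(decoder_input)  # [2, 21, 22, 1, 1, 1, 1, 1]
print(label)          # [21, 22, 3, 1, 1, 1, 1, 1]
```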

In [18]:
class BilingualDataset (Dataset) :
    
    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len
        
        self.sos_token = torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id("[PAD]")], dtype=torch.int64)
        
    def __len__ (self) :
        return len(self.ds)
    
    def __getitem__(self, index):
        src_target_pair = self.ds[index]
        src_text = src_target_pair["translation"][self.src_lang]
        tgt_text = src_target_pair["translation"][self.tgt_lang]
        
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids 
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids 
        
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
        
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0 :
            raise ValueError("Sentence is too long")
        
        # Add SOS, EOS and PAD to source text
        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)
        ])
        
        # Add SOS and PAD to the decoder input
        decoder_input = torch.cat([
            self.sos_token,
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
        ])
        
        # Add EOS to the label (What we expect as output from the decoder)
        label = torch.cat([
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
        ])
        
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        
        return {
            "encoder_input": encoder_input, # (seq_len)
            "decoder_input": decoder_input, # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, 1, seq_len) & (1, seq_len, seq_len)
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
        }
        
def causal_mask (size) :
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
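The `decoder_mask` combines the padding mask of shape `(1, 1, seq_len)` with the causal mask of shape `(1, seq_len, seq_len)`; broadcasting produces a `(1, seq_len, seq_len)` mask. This sketch repeats that computation on a hypothetical five-token decoder input whose last two positions are padding (`PAD = 1` is an assumed id): each row allows only earlier or current positions that are not padding.

```python
import torch

PAD = 1  # hypothetical pad id
decoder_input = torch.tensor([2, 21, 22, PAD, PAD])  # (seq_len,)
seq_len = decoder_input.size(0)

pad_mask = (decoder_input != PAD).unsqueeze(0).unsqueeze(0).int()  # (1, 1, seq_len)
causal = torch.triu(torch.ones(1, seq_len, seq_len), diagonal=1).type(torch.int) == 0  # (1, seq_len, seq_len)

decoder_mask = pad_mask & causal  # broadcasts to (1, seq_len, seq_len)
print(decoder_mask.shape)  # torch.Size([1, 5, 5])
print(decoder_mask[0])     # row i: 1 where column j <= i and j is not padding
```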

Now we will build the function that will do the following:
- Load the dataset from HuggingFace
- Create or load the tokenizers
- Split between train and validation (using random_split from torch)
- Build the dataset and dataloader objects
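One detail worth knowing about the split step: `random_split` draws a new random split on every run unless it is given a seeded generator. When training is resumed with `preload`, an unseeded split could move some validation sentences into the training set. This sketch (a list of integers stands in for the raw dataset, and the seed `42` is arbitrary) shows the `generator` argument making the split reproducible.

```python
import torch
from torch.utils.data import random_split

data = list(range(100))  # stand-in for the raw dataset
train_size = int(0.9 * len(data))
val_size = len(data) - train_size

# Passing a seeded generator makes the split reproducible across runs
train_a, val_a = random_split(data, [train_size, val_size], generator=torch.Generator().manual_seed(42))
train_b, val_b = random_split(data, [train_size, val_size], generator=torch.Generator().manual_seed(42))

print(len(train_a), len(val_a))   # 90 10
print(list(val_a) == list(val_b))  # True: same seed, same split
```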

In [19]:
def get_ds (config) :
    
    # Load the raw dataset
    ds_raw = load_dataset("opus_books", f"{config['lang_src']}-{config['lang_tgt']}", split="train")
    
    # Build the tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config["lang_tgt"])
    
    # Split between training and validation sets
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
    
    # Create dataset objects
    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config["lang_src"], config["lang_tgt"], config["seq_len"])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config["lang_src"], config["lang_tgt"], config["seq_len"])
    
    # Build the dataloaders
    train_dataloader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
    
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

### Configuration

We have been using the configuration parameter for a while; now it's time to define it. Feel free to change the values so that training fits on your machine.
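The checkpoint path is assembled from two configuration keys, `model_folder` and `model_basename`, so both must be present (a missing key raises `KeyError` when the first checkpoint is saved). This sketch repeats the path logic of `get_weights_file_path` with hypothetical values to show the resulting file name.

```python
from pathlib import Path


def get_weights_file_path(config, epoch: str) -> str:
    # Same logic as the notebook: <model_folder>/<model_basename><epoch>.pt
    return str(Path(".") / config["model_folder"] / f"{config['model_basename']}{epoch}.pt")


# Hypothetical values for illustration
config = {"model_folder": "weights", "model_basename": "tmodel_"}
print(get_weights_file_path(config, f"{3:02d}"))  # weights/tmodel_03.pt on POSIX systems
```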

In [20]:
def get_config() :
    return {
        "batch_size": 4,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "lang_src": "en",
        "lang_tgt": "es",
        "model_folder": "tmodel_",
        "model_basename": "tmodel_", # required by get_weights_file_path
        "preload": None,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }
    
def get_weights_file_path(config, epoch: str) :
    model_folder = config["model_folder"]
    model_basename = config["model_basename"]
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path(".") / model_folder / model_filename)

### Greedy Decoding

Here we define the inference function of the model, which is needed to run validation during training (it may help to read the [Training Loop](#training-loop) section first).

We run the encoder only once. Then, at every step, the decoder takes the tokens generated so far together with the encoder output and produces the next token. The process ends when the \[EOS\] token is generated or the maximum length is reached.
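The control flow of greedy decoding can be sketched in pure Python. The scorer below is hypothetical and stands in for `model.decode` plus `model.project`: it favors token `4 + len(tokens)` and emits `[EOS]` once the sequence holds four tokens. The ids `SOS = 2` and `EOS = 3` are also assumptions. The encoder output is passed in once and reused at every step, and the loop stops at `[EOS]` or at `max_len`.

```python
SOS, EOS = 2, 3  # hypothetical special-token ids


def next_token_scores(encoder_output, tokens, vocab_size=10):
    """Hypothetical stand-in for decode + project: one score per vocabulary entry."""
    scores = [0.0] * vocab_size
    if len(tokens) >= 4:
        scores[EOS] = 1.0
    else:
        scores[4 + len(tokens)] = 1.0
    return scores


def greedy_decode(encoder_output, max_len):
    tokens = [SOS]
    while len(tokens) < max_len:
        scores = next_token_scores(encoder_output, tokens)
        next_tok = max(range(len(scores)), key=scores.__getitem__)  # argmax (greedy choice)
        tokens.append(next_tok)
        if next_tok == EOS:
            break
    return tokens


print(greedy_decode(encoder_output=None, max_len=10))  # [2, 5, 6, 7, 3]
print(greedy_decode(encoder_output=None, max_len=3))   # [2, 5, 6]: stopped by max_len
```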

In [21]:
def greedy_decode (model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device) :
    sos_idx = tokenizer_src.token_to_id("[SOS]")
    eos_idx = tokenizer_src.token_to_id("[EOS]")
    
    # Precompute the encoder output and reuse it for every token
    encoder_output = model.encode(source, source_mask)
    
    # Initialize the decoder input with the SOS token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    
    while True :
        if decoder_input.size(1) == max_len :
            break
        
        # Build the mask for the target (decoder input)
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        
        # Calculate the output of the decoder
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        
        # Get the next token
        prob = model.project(out[:, -1])
        
        # Select the token with the max probability (greedy search)
        _, next_word = torch.max(prob, dim=1)
        
        decoder_input = torch.cat([
            decoder_input,
            torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        ], dim=1)
        
        if next_word == eos_idx :
            break
        
    return decoder_input.squeeze(0)

### Validation

This validation loop runs inference on a few sentences from the validation set so that we can qualitatively assess how good the translations are.

In [22]:
def run_validation (model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2) :
    model.eval()
    count = 0
    
    # Width of the console window (default value)
    console_width = 80
    
    with torch.no_grad() :
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)
            
            assert encoder_input.size(0) == 1, "Batch size must be one for validation"
            
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
            
            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
            
            # Print to the console
            print_msg("-"*console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")
            
            if count == num_examples :
                break

### Training Loop

Now we will create the function that trains our model. To speed up training, it uses CUDA as an accelerator when it is available.
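Inside the loop, the logits of shape `(batch, seq_len, vocab_size)` and the labels of shape `(batch, seq_len)` are flattened before computing the loss, and `ignore_index` excludes padding positions. This toy sketch (hypothetical `PAD = 1`, no label smoothing for simplicity) checks that the loss equals the cross-entropy over the non-padding positions only, and that changing the logits at a padded position leaves it unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, vocab_size, PAD = 2, 4, 7, 1  # toy sizes; PAD is a hypothetical id

logits = torch.randn(batch, seq_len, vocab_size)  # (batch, seq_len, vocab_size)
label = torch.tensor([[4, 5, 3, PAD],
                      [6, 3, PAD, PAD]])          # (batch, seq_len)

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

# Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,) as in the training loop
loss = loss_fn(logits.view(-1, vocab_size), label.view(-1))

# Same value computed only over the non-padding positions
keep = label.view(-1) != PAD
manual = nn.functional.cross_entropy(logits.view(-1, vocab_size)[keep], label.view(-1)[keep])
print(torch.allclose(loss, manual))  # True

# Changing the prediction at a padded position does not change the loss
logits2 = logits.clone()
logits2[0, 3] += 100.0
loss2 = loss_fn(logits2.view(-1, vocab_size), label.view(-1))
print(torch.allclose(loss, loss2))  # True
```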

In [23]:
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

In [26]:
def get_model(config, vocab_src_len, vocab_tgt_len) :
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config["seq_len"], config["d_model"])
    return model

def train_model (config) :
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
    
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device) # parameters must live on the same device as the inputs
    
    # Tensorboard
    writer = SummaryWriter(config["experiment_name"])
    
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
    
    initial_epoch = 0
    global_step = 0
    
    if config["preload"] :
        model_filename = get_weights_file_path(config, config["preload"])
        print(f"Preloading model {model_filename}")
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state["model_state_dict"]) # restore the model weights
        initial_epoch = state["epoch"] + 1
        optimizer.load_state_dict(state["optimizer_state_dict"])
        global_step = state["global_step"]
        
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1).to(device)
    
    for epoch in range(initial_epoch, config["num_epochs"]) :
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch {epoch:02d}")
        for batch in batch_iterator :
            encoder_input = batch["encoder_input"].to(device) # (batch, seq_len)
            decoder_input = batch["decoder_input"].to(device) # (batch, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (batch, 1, 1, seq_len) 
            decoder_mask = batch["decoder_mask"].to(device) # (batch, 1, seq_len, seq_len)
            
            # Run the tensors through the transformer
            encoder_output = model.encode(encoder_input, encoder_mask) # (batch, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (batch, seq_len, d_model)
            projection_output = model.project(decoder_output) # (batch, seq_len, tgt_vocab_size)
            
            label = batch["label"].to(device) # (batch, seq_len)
            
            # (batch, seq_len, tgt_vocab_size) --> (batch * seq_len, tgt_vocab_size)
            loss = loss_fn(projection_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
            
            # Log the loss
            writer.add_scalar("train_loss", loss.item(), global_step)
            
            # Backpropagate the loss
            loss.backward()
            
            # Update the weights
            optimizer.step()
            optimizer.zero_grad()
            
            global_step += 1
            
        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step
        }, model_filename)
        
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config["seq_len"], device, lambda msg: batch_iterator.write(msg), global_step, writer)
            

Now we can train the model:

In [None]:
config = get_config()
train_model(config)