In [1]:
import torch
import torch.nn as nn 
import math


# Make A Transformer Model

## InputEmbedding

In [6]:
class InputEmbedding(nn.Module):
    
    def __init__(self, 
                 d_model:int,   # dimension model
                 vocab_size:int   # vocabulary size  
                 ) -> None:
        
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)
    

In [7]:
d_model =  512
vocab_size = 10000


InputEmbedding(
    d_model=d_model,
    vocab_size=vocab_size
)

InputEmbedding(
  (embedding): Embedding(10000, 512)
)

## Positional Encoding

In [285]:
class PositionalEncoding(nn.Module):
    
    def __init__(self,
                 d_model:int,
                 seq_len:int,
                 dropout:float
                 ) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        
        
        # initilize the positional encoding 
        pe = torch.zeros(seq_len, d_model)
        
        # positional 
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        
        # dominator part of positional encoding 
        dominator = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # enven indices in positional encoding 
        pe[:, 0::2] = torch.sin(position * dominator)
        
        # cos indices in positional encoding 
        pe[:, 1::2] = torch.cos(position * dominator)
        
        # unsqueze the dimension 
        pe = pe.unsqueeze(0)
        
        # take minimum pattern in the positional encoding 
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

In [286]:
PositionalEncoding(
    d_model=512,
    seq_len=1000,
    dropout=0.1
)

PositionalEncoding(
  (dropout): Dropout(p=0.1, inplace=False)
)

## MultiHeadAttention

In [287]:
class MultiHeadAttentionBlock(nn.Module):
    
    def __init__(self,
                 h:int,
                 d_model:int,
                 dropout:float) -> None:
        super().__init__()
        self.h = h 
        self.d_model = d_model
        
        
        # make sure the dimension of the d_model 
        assert d_model % h == 0, "d_model is not divisible by h" 
        
        # make sure d_model is dibsible by h 
        self.d_k = d_model // h 
        
        # make query, key, value and output of the word 
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    @staticmethod    
    def attention(query, key, value, mask, dropout:nn.Dropout):
        
        # query is the list dimension 
        d_k = query.shape[-1]
        
        # transform the dimension of the query and key 
        attention_score = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        
        # apply the mask in the attention score 
        if mask is not None:
            attention_score.masked_fill_(mask==0,
                                           -1e9)
            
        # apply the dropout in the attention score 
        if dropout is not None:
            attention_score = dropout(attention_score)
            return (attention_score @ value), attention_score
        
        
        
    def forward(self, q, k, v, mask):
        
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        
        
        # calculation of attention 
        x, self.attention_score = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # transpose the attention 
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_k * self.h)
        
        return self.w_o(x)        

In [288]:
MultiHeadAttentionBlock(
    h=8,
    d_model=512,
    dropout=0.1
)

MultiHeadAttentionBlock(
  (w_q): Linear(in_features=512, out_features=512, bias=False)
  (w_k): Linear(in_features=512, out_features=512, bias=False)
  (w_v): Linear(in_features=512, out_features=512, bias=False)
  (w_o): Linear(in_features=512, out_features=512, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)

## LayerNormalization

In [289]:
class LayerNormalization(nn.Module):
    
    def __init__(self,
                 fetures:int,
                 eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps 
        self.alpha = nn.Parameter(torch.ones(fetures))
        self.bias = nn.Parameter(torch.zeros(fetures))
        
    def forward(self, x):
        
        mean = x.mean(dim = -1,
                      keepdim=True)
        
        
        std = x.std(dim=-1, keepdim=True)
        
        return self.alpha * (x - mean) / (std + self.eps) + self.bias
    

In [290]:
LayerNormalization(
    fetures=20
)

LayerNormalization()

## FeedForward

In [291]:
class FeedForwardBlock(nn.Module):
    
    def __init__(self,
                 d_model:int,
                 d_ff:int,
                 dropout:float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model,
                                  d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
        
    def forward(self, x):
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

In [292]:
FeedForwardBlock(
    d_model=512,
    d_ff=2048,
    dropout=0.1
)

FeedForwardBlock(
  (linear_1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear_2): Linear(in_features=2048, out_features=512, bias=True)
)

## Residual Connection

In [293]:
class ResidualConnection(nn.Module):
    
    def __init__(self,
                 features:int,
                 dropout:float
                 ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [294]:
ResidualConnection(
    features=2048,
    dropout=0.1
)

ResidualConnection(
  (dropout): Dropout(p=0.1, inplace=False)
  (norm): LayerNormalization()
)

## EncoderBlock

In [295]:
class EncoderBlock(nn.Module):
    
    def __init__(self,
                 features:int,
                 self_attention_block:MultiHeadAttentionBlock,
                 feed_forward_block:FeedForwardBlock,
                 dropout:float
                 ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
        
    def forward(self, x, src_mask):
        
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connection[1](x, self.feed_forward_block)
        
        return x 
        
        

In [296]:
[ResidualConnection(features=2048, dropout=0.1) for i in range(2)]


[ResidualConnection(
   (dropout): Dropout(p=0.1, inplace=False)
   (norm): LayerNormalization()
 ),
 ResidualConnection(
   (dropout): Dropout(p=0.1, inplace=False)
   (norm): LayerNormalization()
 )]

## Encoder

In [320]:
class Encoder(nn.Module):
    
    def __init__(self,
                 features:int,
                 layers:nn.ModuleList
                 ) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
        
        
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [321]:
Encoder(features=512,
        layers=3)

Encoder(
  (norm): LayerNormalization()
)

## DecoderBlock

In [322]:
class DecoderBlock(nn.Module):
    
    def __init__(self,
                 features:int,
                 self_attention_block:MultiHeadAttentionBlock,
                 cross_attention_block:MultiHeadAttentionBlock,
                 feed_forward_block:FeedForwardBlock,
                 dropout
                 ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])
        
        
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        
        return x
        

In [323]:
class Decoder(nn.Module):
    
    def __init__(self,
                 features:int,
                 layers:nn.ModuleList
                 ) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
        
        
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
            

## ProjectionLayer

In [324]:
class ProjectionLayer(nn.Module):
    
    def __init__(self,
                 d_model,
                 vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
        
    def forward(self, x):
        return self.proj(x)

In [325]:

ProjectionLayer(d_model=512,
                vocab_size=2000)

ProjectionLayer(
  (proj): Linear(in_features=512, out_features=2000, bias=True)
)

## Transformer

In [326]:
class Transformer(nn.Module):
    
    def __init__(self,
                 encoder:Encoder,
                 decoder:Decoder,
                 src_embed:InputEmbedding,
                 tgt_embed:InputEmbedding,
                 src_pos:PositionalEncoding,
                 tgt_pos:PositionalEncoding,
                 projection_layer:ProjectionLayer
                 ) -> None:
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        
        
    def encode(self, 
               src,
               src_mask):
        
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(self, 
                encoder_output:torch.Tensor,
                src_mask:torch.Tensor,
                tgt:torch.Tensor,
                tgt_mask:torch.Tensor
                 ):
        
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    
    def project(self, x):
        return self.projection_layer(x)

## Build Transformer

In [327]:
def build_Transformer(
                    src_vocab_size:int,
                    tgt_vocab_size:int,
                    src_seq_len:int,
                    tgt_seq_len:int,
                    d_model:int=512,
                    N:int=6,
                    h:int=8,
                    dropout:float=0.1,
                    d_ff:int=2048
                ) -> Transformer:
    
    
    # create the embedding layers 
    src_embed = InputEmbedding(d_model, src_vocab_size)
    tgt_embed = InputEmbedding(d_model, tgt_vocab_size)
    
    # create a positional embedding 
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    
    
    # create a encoder blocks 
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(h=h,
                                                               d_model=d_model,
                                                               dropout=dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(features=d_model,
                                     self_attention_block=encoder_self_attention_block,
                                     feed_forward_block=feed_forward_block,
                                     dropout=dropout)
        encoder_blocks.append(encoder_block)
        
        
    # create a decoder blocks 
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(h=h,
                                                               d_model=d_model,
                                                               dropout=dropout)
        
        decoder_cross_attention_block = MultiHeadAttentionBlock(h=h,
                                                                d_model=d_model,
                                                                dropout=dropout)
        
        feed_forward_block = FeedForwardBlock(d_model=d_model, 
                                              d_ff=d_ff,
                                              dropout=dropout)
        
        decoder_block = DecoderBlock(features=d_model,
                                     self_attention_block=decoder_self_attention_block,
                                     cross_attention_block=decoder_cross_attention_block,
                                     feed_forward_block=feed_forward_block,
                                     dropout=dropout)
        
        decoder_blocks.append(decoder_block)
        
    # create the encoder and decoder 
    encoder = Encoder(features=d_model,
                      layers=nn.ModuleList(encoder_blocks))
    
    decoder = Decoder(features=d_model,
                      layers=nn.ModuleList(decoder_blocks))
    
    
    # create a projection layer 
    projection_layer = ProjectionLayer(d_model=d_model,
                                       vocab_size=tgt_vocab_size)
    
    
    # create a transformer 
    transformer = Transformer(encoder=encoder,
                              decoder=decoder,
                              src_embed=src_embed,
                              tgt_embed=tgt_embed,
                              src_pos=src_pos,
                              tgt_pos=tgt_pos,
                              projection_layer=projection_layer
                              )
    
    # initilize the parameter using xavier uniform 
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
            
            
    return transformer
    
    
    

# Train The Model

In [328]:
from pathlib import Path

In [329]:
def get_config():
    return {
        "batch_size":8,
        "num_epochs":20,
        "lr":10**-4,
        "seq_len":350,
        "d_model":512,
        "datasource":"opus_books",
        "lang_src":"en",
        "lang_tgt":"it",
        "model_folder":"weights",
        "model_basename":"tmodel_",
        "preload":"latest",
        "tokenizer_file":"tokenizer_{0}.json",
        "experiment_name":"runs/tmodel"
    }
    
    
# Allow to find the Path where we need to save the weights

def get_weights_file_path(config, epoch:str):
    
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    
    return str(Path('.') / model_folder / model_filename)
    

In [330]:
# find the latest weight file in the weights

def latest_weights_file_path(config):
    
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config["model_basename"]}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    
    if len(weights_files) == 0:
        return None 
    
    weights_files.sort()
    return str(weights_files[-1])

In [331]:
numbers = [4, 2, 9, 1, 5, 6]
numbers.sort(reverse=True)
print(numbers)


[9, 6, 5, 4, 2, 1]


In [332]:
get_config, get_weights_file_path, latest_weights_file_path

(<function __main__.get_config()>,
 <function __main__.get_weights_file_path(config, epoch: str)>,
 <function __main__.latest_weights_file_path(config)>)

In [333]:
from torch.utils.data import Dataset

In [334]:
class BilingualDataset(Dataset):
    
    def __init__(self,
                 ds,
                 tokenizer_src,
                 tokenizer_tgt,
                 src_lang,
                 tgt_lang,
                 seq_len
                 ) -> None:
        super().__init__()
        
        self.ds = ds 
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len
        
        
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)
        
        
    # we also need to define the length method of this dataset 
    def __len__(self):
        return len(self.ds)
    
    
    @staticmethod
    def causal_mask(size):
        
        mask = torch.triu(
            torch.ones(1, size, size), diagonal=1
        ).type(torch.int)
        
        return mask == 0
    
    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]
        
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
        
        
        # calculate how may padding toolkeys we need to add for the encoder and decoder size 
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 
        
        # we will only add <s> and </s> only on the label 
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
        
        # make sure the number of padding tokens is not negative 
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")
        
        
        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)
        ],
            dim=0)
        
        
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )
        
        
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )
        
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        
        
        return {
            "encoder_input":encoder_input,
            "decoder_input":decoder_input,
            "encoder_mask":(encoder_input != self.pad_token).unsqueeze(0).int(),
            "decoder_mask":(decoder_input != self.pad_token).unsqueeze(0).int() & self.causal_mask(decoder_input.size(0)),
            "label":label,
            "src_text":src_text,
            "tgt_text":tgt_text
        }
        
        
    
    
        
    

In [335]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path


In [336]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

In [337]:
def get_or_build_tokenizer(config,
                           ds,
                           lang):
    
    
    
    
    
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        
        
        trainer = WordLevelTrainer(
            special_tokens = [
                "[UNK]",
                "[PAD]",
                "[SOS]",
                "[EOS]"
            ], min_frequency=2
        )
        
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
        
        
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
        
        
    return tokenizer
    

In [338]:
from torch.utils.data import Dataset, DataLoader, random_split
BilingualDataset, BilingualDataset.causal_mask

(__main__.BilingualDataset,
 <function __main__.BilingualDataset.causal_mask(size)>)

In [339]:
def get_ds(config):
    
    ds_raw = load_dataset(
        'opus_books',
        f"{config['lang_src']}-{config['lang_tgt']}",
        split='train'
    )
    
    # build tokenizers 
    tokenizer_src = get_or_build_tokenizer(config=config,
                                           ds=ds_raw,
                                           lang=config['lang_src']
                                           )
    
    
    
    tokenizer_tgt = get_or_build_tokenizer(config=config,
                                           ds=ds_raw,
                                           lang=config["lang_src"]
                                           )
    
    
    
    
    # keep 90% for training 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
    
    
    train_ds = BilingualDataset(ds=train_ds_raw,
                                tokenizer_src=tokenizer_src,
                                tokenizer_tgt=tokenizer_tgt,
                                src_lang=config["lang_src"],
                                tgt_lang=config["lang_tgt"],
                                seq_len=config["seq_len"]
                                )
    
    
    val_ds = BilingualDataset(ds=val_ds_raw,
                              tokenizer_src=tokenizer_src,
                              tokenizer_tgt=tokenizer_tgt,
                              src_lang=config["lang_src"],
                              tgt_lang=config["lang_tgt"],
                              seq_len=config["seq_len"]
                              )
    
    
    # find the maximum length of each sentence in the source and target sentence 
    max_len_src = 0 
    max_len_tgt = 0 
    
    
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item["translation"][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item["translation"][config["lang_tgt"]]).ids
        
        
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))
        
    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max_length of target sentence: {max_len_tgt}")
    
    train_dataloader = DataLoader(
        train_ds,
        batch_size=config['batch_size'],
        shuffle=True
    )
    
    val_dataloader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=True
    )
    
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
        
        
    
    
    

In [340]:
build_Transformer

<function __main__.build_Transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> __main__.Transformer>

In [341]:
def get_model(config,
              vocab_src_len,
              vocab_tgt_len):
    
    model = build_Transformer(
        vocab_src_len,
        vocab_tgt_len,
        config["seq_len"],
        config["seq_len"],
        d_model=config["d_model"]
        
    )
    
    return model

## Train the Model

In [342]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import warnings

def train_model(config):
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    
    
    # make sure the weights folder exists 
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)
    
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config=config)
    model = get_model(config=config,
                      vocab_src_len=tokenizer_src.get_vocab_size(),
                      vocab_tgt_len=tokenizer_tgt.get_vocab_size()
                      ).to(device)
    
    
    # Tensorboard 
    # config['experiment_name'] = 'my_experiment'
    print("config", config)

    writer = SummaryWriter(config['experiment_name'])
    # print("config", config)
    
    
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['lr'],
        eps = 1e-9
    )
    
    
    
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    
    if model_filename:
        print(f"Preload model {model_filename}")
        
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    else:
        print("No model to preload, starting from scratch")
        
        
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_src.token_to_id('[PAD]'),
        label_smoothing=0.1
    ).to(device)
    
    
    
    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        
        for batch in batch_iterator:
            
            # finally we get the tensor 
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            
            
            # Run the tensors throught the encoder, decoder and the projection layer 
            encoder_output = model.encode(encoder_input,
                                          encoder_mask)
            
            decoder_output = model.decode(encoder_output,
                                          encoder_mask,
                                          decoder_input,
                                          decoder_mask)
            
            proj_output = model.project(decoder_output)
            
            
            # compare the output with the label 
            label = batch['label'].to(device)
            
            
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
            
            
            # log the loss 
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()
            
            # Backpropagate the loss 
            loss.backward()
            
            # update the weights 
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            
            global_step += 1
            
        # Run validation at the end of every epoch 
        run_validation(model, 
                       val_dataloader, 
                       tokenizer_src, 
                       tokenizer_tgt, 
                       config['seq_len'], 
                       device, 
                       lambda msg: batch_iterator.write(msg), 
                       global_step, 
                       writer)
        
        
        
        # save the model at the end of every epoch
        model_filename = get_weights_file_path(config=config,
                                               epoch=f"{epoch:02d}")
        
        
        # take a snapshort 
        torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)
        
        
        
if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config()
    train_model(config)
         
    

cpu


Max length of source sentence: 309
Max_length of target sentence: 274
config {'batch_size': 8, 'num_epochs': 20, 'lr': 0.0001, 'seq_len': 350, 'd_model': 512, 'datasource': 'opus_books', 'lang_src': 'en', 'lang_tgt': 'it', 'model_folder': 'weights', 'model_basename': 'tmodel_', 'preload': 'latest', 'tokenizer_file': 'tokenizer_{0}.json', 'experiment_name': 'runs/tmodel'}
No model to preload, starting from scratch


Processing Epoch 00:  37%|███▋      | 1344/3638 [4:05:01<6:49:41, 10.72s/it, loss=3.123]

## Validation Model

In [None]:
import os 


def greedy_decode(model,
                  source,
                  source_mask,
                  tokenizer_src,
                  tokenizer_tgt,
                  max_len,
                  device):
    
    
    
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')
    
    
    # precompute the encoder output and reuse if for every step 
    encoder_output = model.encoder(source,
                                   source_mask)
    
    
    # initilize the decoder input wht the sos token 
    decoder_input = torch.empty(
        1,
        1
    ).fill_(sos_idx).type_as(source).to(device)
    
    
    while True:
        if decoder_input.size(1) == max_len:
            break 
        
        decoder_mask = BilingualDataset.causal_mask(size=decoder_input.size(1)).type_as(source_mask).to(device)
        
        
        out = model.decode(encoder_output, 
                           source_mask,
                           decoder_input,
                           decoder_mask)
        
        
        # get next token, probability of the next token using the linear layer 
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        
        decoder_input = torch.cat([
            decoder_input, 
            torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        ], dim = 1)
        
        
        # next word is equal to end of the sentence token then stop the loop 
        if next_word == eos_idx:
            break
        
        
    return decoder_input.squeeze(0)

In [None]:
import torchmetrics

def run_validation(model,
                   validation_ds,
                   tokenizer_src,
                   tokenizer_tgt,
                   max_len,
                   device,
                   print_msg,
                   global_step,
                   writer,
                   num_examples=2
                   ):
    
    
    model.eval()
    count = 0
    
    source_texts = []
    expected = []
    pridicted = []
    
    try:
        
        # get the console window with 
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console.width = int(console_width)
            
    
    except:
        
        # if we can't get the console width use 80 as default 
        console_width = 80 
        
    with torch.no_grad():
        
        for batch in validation_ds:
            
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            
            # check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
            
            # so as you remember where we want to inference the model we need to calculate the encoder output only onece 
            model_out = greedy_decode(
                model=model,
                source=encoder_input,
                source_mask=encoder_mask,
                tokenizer_src=tokenizer_src,
                tokenizer_tgt=tokenizer_tgt,
                max_len=max_len,
                device=device
            )
            
            
            # we wan to compare this model output with what we expected so with label so let's append 
            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decoder(model_out.detach().cpu().numpy())
            
            # save into the prespective list 
            source_texts.append(source_text)
            expected.append(target_text)
            pridicted.append(model_out_text)
            
            
            
            print_msg('-'*console_width)    # can we print some bars 
            print_msg(f"{f'SOURCE: ':>12}{source_text}")  # source text message
            print_msg(f"{f'TARGET: ':>12}{target_text}")  # target text message
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")  # pridicted text message 
            
            
            # if we have already processed number of example then we just break 
            if count == num_examples:
                print_msg('-'*console_width)
                break
            
            
    if writer:
        
        # Evaluate the character error rate 
        # Compute the char error rate 
        metric = torchmetrics.CharErrorRate()
        cer = metric(pridicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()
        
        
        # compute the word error rate 
        metric = torchmetrics.WordErrorRate()
        wer = metric(pridicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()
        
        
        # compute the BLEU metric for the translation task 
        metric = torchmetrics.BLEUScore()
        bleu = metric(pridicted, expected)
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()
        
    