<a href="https://colab.research.google.com/github/Lmalviya/machineTranslationTask/blob/main/TransformerFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Scope of the notebook:
  1. Build a vanilla Transformer from scratch.
  2. Train it for language translation from English to Italian.
  3. For choosing the next token, implement greedy search and beam search.


# Build Transformer from scratch

In [None]:
import torch
import torch.nn as nn

import math

### Word embedding

In [None]:
class WordEmbedding(nn.Module):
  """Look up a dense vector for every token id and scale it by sqrt(d_model).

  The sqrt(d_model) scaling follows "Attention Is All You Need" (section 3.4).

  Args:
    d_model: size of each embedding vector.
    vocab_size: number of distinct token ids in the vocabulary.
  """
  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.d_model, self.vocab_size = d_model, vocab_size
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    scale = math.sqrt(self.d_model)
    return self.embedding(x) * scale


### Positional Embedding

In [None]:
class PostionalEmbedding(nn.Module):
  """Add fixed sinusoidal positional encodings to token embeddings, then apply dropout.

  Following "Attention Is All You Need" (section 3.5):
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  Args:
    d_model: embedding size of each token (odd sizes are supported).
    seq_len: maximum number of positions that can be encoded.
    dropout_p: dropout probability applied after adding the encoding.
  """
  def __init__(self, d_model: int, seq_len: int, dropout_p: float):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout(dropout_p)

    pe = torch.zeros(seq_len, d_model)  # one encoding row per position
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
    # 1 / 10000^(2i / d_model), computed in log space for numerical stability;
    # there are ceil(d_model / 2) frequencies, one per even column.
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    # Odd columns number floor(d_model / 2): trim the frequencies so an odd d_model
    # does not cause a shape mismatch (a no-op when d_model is even).
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

    pe = pe.unsqueeze(0)  # (1, seq_len, d_model) so it broadcasts over the batch
    # A buffer is saved in state_dict and moved by .to(device), but is not learnable.
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Add the encodings of the first x.size(1) positions to x and apply dropout.

    x is indexed as (batch, seq, d_model); seq must not exceed seq_len.
    """
    x = x + (self.pe[:, :x.size(1), :].requires_grad_(False))
    return self.dropout(x)


### Layer normalization

In [None]:
class LayerNorm(nn.Module):
  """Normalize over the last dimension, then apply a learnable scale (gamma) and shift (beta).

  gamma and beta let the network re-introduce scale/offset when plain normalization
  would be too restrictive.

  Args:
    eps: small constant added to the standard deviation to avoid division by zero.
    features: optional size of the last dimension. When given, gamma and beta hold one
      value per feature (the usual layer-norm formulation). When None (default), a single
      shared scalar gamma and beta are used, which keeps existing checkpoints compatible.
  """
  def __init__(self, eps:float = 10**-6, features=None):
    super().__init__()
    self.eps = eps
    param_shape = (1,) if features is None else (features,)
    self.gamma = nn.Parameter(torch.ones(param_shape))  # multiplicative
    self.beta = nn.Parameter(torch.zeros(param_shape))  # additive

  def forward(self, x):
    mean = x.mean(dim=-1, keepdim=True)  # keepdim so the result broadcasts against x
    # torch.std defaults to the unbiased (n - 1) estimator; eps is added to the std, not the variance.
    std = x.std(dim=-1, keepdim=True)
    return self.gamma*(x - mean)/(std + self.eps) + self.beta



### Feed forward block

In [None]:
class FeedForwardBlock(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output feature size.
    d_ff: hidden (inner) feature size.
    dropout_p: dropout probability applied to the hidden activations.
  """
  def __init__(self, d_model: int, d_ff: int, dropout_p: float =0.1):
    super().__init__()
    # linear_1 is created before linear_2 so seeded initialisation is unchanged.
    self.linear_1 = nn.Linear(d_model, d_ff)  # W1, b1: expand
    self.linear_2 = nn.Linear(d_ff, d_model)  # W2, b2: project back
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, x):
    # last dim: d_model -> d_ff
    hidden = self.linear_1(x)
    hidden = self.relu(hidden)
    hidden = self.dropout(hidden)
    # last dim: d_ff -> d_model
    return self.linear_2(hidden)


### Multi-head Attention

In [None]:
class MultiheadAttentionBlock(nn.Module):
  """Multi-head attention ("Attention Is All You Need", section 3.2.2).

  q, k and v are projected by w_q / w_k / w_v and split along the feature dimension
  into `head` heads of size d_k = d_model // head. Every head therefore sees the whole
  sequence but only a slice of each token's features. Head outputs are re-joined and
  projected by w_o. The most recent attention weights are kept in `self.attention_score`.

  Args:
    d_model: model / embedding size; must be divisible by `head`.
    head: number of attention heads.
    dropout_p: dropout probability applied to the attention weights.
  """
  def __init__(self, d_model: int, head: int, dropout_p: float):
    super().__init__()
    self.head = head
    self.d_model = d_model
    self.attention_score = None  # filled in by forward()

    assert d_model % head == 0,  'd_model is not divisible by head'

    self.d_k = d_model // head
    # Creation order (q, k, v, o) is kept so seeded initialisation is unchanged.
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout_p)

  @staticmethod
  def attention(query, key, value, mask, dropout: nn.Dropout):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Entries where mask == 0 get a score of -1e9 before the softmax, so they receive
    (near) zero weight. mask and dropout may be None.

    Returns:
      (attended values, attention weights)
    """
    scale = math.sqrt(query.shape[-1])
    scores = (query @ key.transpose(-2, -1)) / scale
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)

    weights = scores.softmax(dim=-1)
    if dropout is not None:
      weights = dropout(weights)
    return weights @ value, weights

  def _split_heads(self, t):
    # (batch, seq_len, d_model) -> (batch, seq_len, head, d_k) -> (batch, head, seq_len, d_k)
    batch, length = t.shape[0], t.shape[1]
    return t.view(batch, length, self.head, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask):
    """Attend from q to (k, v); mask blocks the interactions where it is 0 (may be None)."""
    query = self._split_heads(self.w_q(q))
    key = self._split_heads(self.w_k(k))
    value = self._split_heads(self.w_v(v))

    attended, self.attention_score = MultiheadAttentionBlock.attention(query, key, value, mask, self.dropout)

    # (batch, head, seq_len, d_k) -> (batch, seq_len, head, d_k) -> (batch, seq_len, d_model);
    # contiguous() is required because view() needs a contiguous layout after transpose.
    merged = attended.transpose(1, 2).contiguous().view(attended.shape[0], -1, self.head * self.d_k)
    return self.w_o(merged)


## Residual connection

In [None]:
class ResidualConnection(nn.Module):
  """Pre-norm residual wrapper: returns x + dropout(sublayer(norm(x))).

  Args:
    dropout_p: dropout probability applied to the sublayer output.
  """
  def __init__(self, dropout_p: float):
    super().__init__()
    # Registration order (dropout, then norm) is kept as-is.
    self.dropout = nn.Dropout(dropout_p)
    self.norm = LayerNorm()

  def forward(self, x, sublayer):
    normed = self.norm(x)
    update = self.dropout(sublayer(normed))
    return x + update

## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: residual self-attention followed by a residual feed-forward block.

  Args:
    self_attention_block: attention module used with q = k = v = x.
    feed_forward_block: position-wise feed-forward module.
    dropout_p: dropout probability for both residual connections.
  """
  def __init__(self, self_attention_block: MultiheadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout_p: float):
    super().__init__()
    self.self_attention_block  = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residualConnection = nn.ModuleList([ResidualConnection(dropout_p) for _ in range(2)])

  def forward(self, x, src_mask):
    attn_residual, ff_residual = self.residualConnection
    x = attn_residual(x, lambda h: self.self_attention_block(h, h, h, src_mask))
    return ff_residual(x, self.feed_forward_block)


## Encoder

In [None]:
class Encoder(nn.Module):
  """Run the input through every encoder block in order, then apply a final LayerNorm.

  Args:
    layers: the encoder blocks, applied in list order.
  """
  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, mask):
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    return self.norm(hidden)


## Decoder Block

In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: residual masked self-attention, residual cross-attention, residual feed-forward.

  Args:
    self_attention_block: attention over the target sequence (masked by tgt_mask).
    cross_attention_block: attention from decoder states to the encoder output.
    feed_forward_block: position-wise feed-forward module.
    dropout_p: dropout probability for the residual connections.
  """
  def __init__(self, self_attention_block: MultiheadAttentionBlock, cross_attention_block: MultiheadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout_p: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    self.residualConnection = nn.ModuleList([ResidualConnection(dropout_p) for _ in range(3)])
    self.dropout = nn.Dropout(dropout_p)  # not used in forward(); kept so the module layout is unchanged

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    # src_mask masks the source (encoder) positions; tgt_mask masks the target positions.
    self_residual, cross_residual, ff_residual = self.residualConnection
    x = self_residual(x, lambda h: self.self_attention_block(h, h, h, tgt_mask))
    # Queries come from the decoder; keys and values come from the encoder output.
    x = cross_residual(x, lambda h: self.cross_attention_block(h, encoder_output, encoder_output, src_mask))
    return ff_residual(x, self.feed_forward_block)


### Decoder

In [None]:
class Decoder(nn.Module):
  """Run the target states through every decoder block in order, then apply a final LayerNorm.

  Args:
    layers: the decoder blocks, applied in list order.
  """
  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    hidden = x
    for block in self.layers:
      hidden = block(hidden, encoder_output, src_mask, tgt_mask)
    return self.norm(hidden)

### Last decoder output layer (projection layer)

In [None]:
class ProjectionLayer(nn.Module):
  """Map decoder states to log-probabilities over the target vocabulary.

  Args:
    d_model: size of the incoming decoder states.
    vocab_size: size of the target vocabulary.
  """
  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    logits = self.proj(x)  # last dim: d_model -> vocab_size
    return logits.log_softmax(dim=-1)


### Transformer


In [None]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer exposed as separate encode / decode / projection steps.

  Keeping the steps separate lets inference encode the source once and reuse the
  result at every decoding step.
  """
  def __init__(self, encoder: Encoder, decoder: Decoder, src_embedding: WordEmbedding, tgt_embedding: WordEmbedding, src_pos: PostionalEmbedding, tgt_pos: PostionalEmbedding, proj: ProjectionLayer):
    super().__init__()
    # Registration order fixes the order of model.parameters() (and hence the saved
    # optimizer state layout), so it is kept identical.
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embedding
    self.tgt_embed = tgt_embedding
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projectionLayer = proj

  def encode(self, src, src_mask):
    """Embed and position-encode source ids, then run the encoder stack."""
    embedded = self.src_pos(self.src_embed(src))
    return self.encoder(embedded, src_mask)

  def decode(self, encoder_output, src_mask, tgt, tgt_mask):
    """Embed and position-encode target ids, then run the decoder stack against encoder_output."""
    embedded = self.tgt_pos(self.tgt_embed(tgt))
    return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

  def projection(self, x):
    """Turn decoder states into log-probabilities over the target vocabulary."""
    return self.projectionLayer(x)


## Build Transformer

In [None]:
def buildTransformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, head: int= 8, d_ff: int= 2048, N: int=6, dropout_p: float = 0.1):
  """Assemble an encoder-decoder Transformer and Xavier-initialise its weight matrices.

  Args:
    src_vocab_size: vocabulary size of the source tokenizer.
    tgt_vocab_size: vocabulary size of the target tokenizer.
    src_seq_len: maximum source length supported by the positional encoding.
    tgt_seq_len: maximum target length supported by the positional encoding.
    d_model: embedding / hidden size.
    head: number of attention heads (must divide d_model).
    d_ff: inner size of each feed-forward block.
    N: number of encoder blocks and, separately, of decoder blocks.
    dropout_p: dropout probability used throughout.

  Returns:
    The assembled Transformer.
  """
  # Token embeddings and positional encodings for each side.
  src_embedding = WordEmbedding(d_model, src_vocab_size)
  tgt_embedding = WordEmbedding(d_model, tgt_vocab_size)
  src_pos_embed = PostionalEmbedding(d_model, src_seq_len, dropout_p)
  tgt_pos_embed = PostionalEmbedding(d_model, tgt_seq_len, dropout_p)

  # Each block gets its own freshly-created sub-layers (no weight sharing between blocks).
  # Argument evaluation is left to right, so sub-layers are created in the same order as before.
  encoder_blocks = [
      EncoderBlock(
          MultiheadAttentionBlock(d_model, head, dropout_p),  # self-attention
          FeedForwardBlock(d_model, d_ff, dropout_p),
          dropout_p,
      )
      for _ in range(N)
  ]
  decoder_blocks = [
      DecoderBlock(
          MultiheadAttentionBlock(d_model, head, dropout_p),  # masked self-attention
          MultiheadAttentionBlock(d_model, head, dropout_p),  # cross-attention over the encoder output
          FeedForwardBlock(d_model, d_ff, dropout_p),
          dropout_p,
      )
      for _ in range(N)
  ]

  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))
  projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

  transformer = Transformer(encoder, decoder, src_embedding, tgt_embedding, src_pos_embed, tgt_pos_embed, projection_layer)

  # Xavier-uniform init for every parameter with more than one dimension (linear weights
  # and embedding tables); 1-D parameters such as biases keep their defaults.
  # `xavier_uniform` (no trailing underscore) is deprecated; use the in-place variant.
  for p in transformer.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return transformer

# Build Language Translation from English to Italian using Transformer

In [None]:
!pip3 install datasets --quiet
!pip3 install torchmetrics



In [None]:
import os
import sys
from typing import Any
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path #create absolut path using relative path
import warnings

from torch.utils.tensorboard import SummaryWriter
import torchmetrics

### Tokenizer

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the `lang` side of every translation pair in `ds`."""
    yield from (pair['translation'][lang] for pair in ds)

def buildTokenizer(config, ds, lang):
    """Return a word-level tokenizer for `lang`, loading it from disk when available.

    The file path is config['tokenizer_file'] formatted with `lang`. If that file does
    not exist, a new WordLevel tokenizer (whitespace pre-tokenization, words seen fewer
    than 2 times map to [UNK]) is trained on the `lang` sentences of `ds` and saved there.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

### Bilingual dataset (source data from Hugging Face)

In [None]:
class BilingualDataset(Dataset):
  """Map-style dataset turning raw translation pairs into fixed-length model inputs.

  Each item of `dataset` is read as {'translation': {src_lang: str, tgt_lang: str}}.

  Args:
    dataset: indexable collection of translation pairs.
    tokenizer_src: tokenizer for the source side (needs `encode(text).ids`).
    tokenizer_tgt: tokenizer for the target side (needs `encode(text).ids` and `token_to_id`).
    src_lang: key selecting the source sentence.
    tgt_lang: key selecting the target sentence.
    seq_len: fixed length every returned sequence is padded to.
  """
  def __init__(self, dataset, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    super().__init__()
    self.dataset = dataset
    self.tokenizer_src = tokenizer_src
    self.tokenizer_tgt = tokenizer_tgt
    self.src_lang = src_lang
    self.tgt_lang = tgt_lang
    self.seq_len = seq_len

    # Special-token ids are taken from the target tokenizer and also used on the encoder
    # side. This presumes both tokenizers assign the same ids to [SOS]/[PAD]/[EOS].
    self.sos_token = torch.tensor([self.tokenizer_tgt.token_to_id('[SOS]')], dtype=torch.int64)
    self.pad_token = torch.tensor([self.tokenizer_tgt.token_to_id('[PAD]')], dtype=torch.int64)
    self.eos_token = torch.tensor([self.tokenizer_tgt.token_to_id('[EOS]')], dtype=torch.int64)

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, index: Any):
    """Return the padded tensors and masks for pair `index`.

    Layout (length seq_len each):
      encoder_input: [SOS] src ids [EOS] [PAD]...
      decoder_input: [SOS] tgt ids [PAD]...
      label:         tgt ids [EOS] [PAD]...   (decoder_input shifted left by one)

    Raises:
      ValueError: if either sentence does not fit in seq_len once special tokens are added.
    """
    src_target_text = self.dataset[index]
    src_text = src_target_text['translation'][self.src_lang]
    tgt_text = src_target_text['translation'][self.tgt_lang]

    enc_input_tokens = self.tokenizer_src.encode(src_text).ids  # list of token ids
    dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

    enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # room for [SOS] and [EOS]
    dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1  # room for [SOS] (input) / [EOS] (label)

    if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
      raise ValueError('Sentence is too long')

    enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
    dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
    # pad_token has shape (1,), so repeat(n) yields n pad ids (an empty tensor when n == 0).
    enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
    dec_padding = self.pad_token.repeat(dec_num_padding_tokens)

    encoder_input = torch.cat([self.sos_token, enc_ids, self.eos_token, enc_padding])
    decoder_input = torch.cat([self.sos_token, dec_ids, dec_padding])
    label = torch.cat([dec_ids, self.eos_token, dec_padding])

    assert encoder_input.size(0) == self.seq_len
    assert decoder_input.size(0) == self.seq_len
    assert label.size(0) == self.seq_len

    return {
        'encoder_input': encoder_input, # (seq_len)
        'decoder_input': decoder_input, # (seq_len)
        'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
        # non-pad positions AND causal (no peeking ahead): (1, seq_len) & (1, seq_len, seq_len)
        'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
        'label': label, # (seq_len)
        'src_text': src_text,
        'tgt_text': tgt_text,
    }


def causal_mask(size):
  """Return a (1, size, size) boolean mask letting position i see only positions <= i.

  True entries (on and below the diagonal) are allowed; False entries (strictly above
  the diagonal) are future positions. The attention code in this notebook fills
  positions where the mask equals 0 with -1e9 before the softmax.
  """
  # triu(..., diagonal=1): 1 strictly above the diagonal (future tokens), 0 on and below it.
  future = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)

  # Invert it so allowed positions (on/below the diagonal) become True.
  return future == 0



In [None]:
def get_data(config):
  """Load the corpus, build tokenizers, split 90/10 and wrap both splits in DataLoaders.

  Also prints the longest tokenized source and target sentence in the whole corpus,
  which helps when choosing config['seq_len'].

  Returns:
    (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)
  """
  src_lang, tgt_lang = config['lang_src'], config['lang_tgt']
  ds_raw = load_dataset(config['datasource'], f"{src_lang}-{tgt_lang}", split='train')

  tokenizer_src = buildTokenizer(config, ds_raw, src_lang)
  tokenizer_tgt = buildTokenizer(config, ds_raw, tgt_lang)

  # 90% training / 10% validation
  n_train = int(0.9*len(ds_raw))
  train_raw, val_raw = random_split(ds_raw, [n_train, len(ds_raw) - n_train])

  def wrap(subset):
    return BilingualDataset(subset, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

  train_ds = wrap(train_raw)
  val_ds = wrap(val_raw)

  longest_src = longest_tgt = 0
  for pair in ds_raw:
    translation = pair['translation']
    longest_src = max(longest_src, len(tokenizer_src.encode(translation[src_lang]).ids))
    longest_tgt = max(longest_tgt, len(tokenizer_tgt.encode(translation[tgt_lang]).ids))

  print(f"Max Length of source sentence: {longest_src}")
  print(f"Max Length of target sentence : {longest_tgt}")

  # Validation decodes one sentence at a time, hence batch_size=1.
  train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
  val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

### Get the model

In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
  """Build an untrained Transformer sized by `config` for the given vocabulary sizes."""
  return buildTransformer(
      vocab_src_len,
      vocab_tgt_len,
      config['seq_len'],           # source length
      config['seq_len'],           # target length
      config['d_model'],
      config['attention_head'],
      config['hidden_layer_dim'],  # d_ff
      config['blocks'],            # N
  )

### Configurations

In [None]:
def get_config():
  """Return a fresh dict of hyper-parameters and paths for English->Italian training."""
  return dict(
      # decoding
      beam_size=3,
      # optimisation
      batch_size=8,
      num_epochs=1,
      lr=10**-4,
      # model size
      seq_len=350,
      d_model=512,
      attention_head=2,
      hidden_layer_dim=1024,  # d_ff: inner size of each feed-forward block
      blocks=2,  # number of encoder blocks and of decoder blocks
      # data
      lang_src='en',
      lang_tgt='it',
      datasource='opus_books',
      # checkpoints / logging
      model_folder='weights',
      model_basename='tmodel_',
      preload=None,  # None, 'latest', or an epoch string such as '05'
      tokenizer_file="tokenizer_{0}.json",
      experiment_name='runs/tmodel',
  )

def get_weights_file_path(config, epoch: str):
  """Return '<datasource>_<model_folder>/<model_basename><epoch>.pt' as a string."""
  folder = Path('.') / f"{config['datasource']}_{config['model_folder']}"
  checkpoint = folder / f"{config['model_basename']}{epoch}.pt"
  return str(checkpoint)

def latest_weights_file_path(config):
  """Return the lexicographically last checkpoint Path in the weights folder, or None if there is none."""
  folder = Path(f"{config['datasource']}_{config['model_folder']}")
  checkpoints = sorted(folder.glob(f"{config['model_basename']}*"))
  return checkpoints[-1] if checkpoints else None

### Train the model

In [None]:
def train_model(config):
  """Train the translation Transformer described by `config`, checkpointing after every epoch.

  Builds the data, tokenizers and model, and optionally resumes from a checkpoint.
  config['preload'] is None, 'latest', or an epoch string such as '05'. Training uses
  teacher forcing and label-smoothed cross-entropy. After each epoch it runs validation
  and saves model + optimizer state.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"using device: {device}")
  # Compare device.type: a torch.device never compares equal to a plain string.
  if device.type == 'cuda':
    print(f"Device name: {torch.cuda.get_device_name(device.index)}")
    print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory/ 1024**3} GB")
  elif device.type == 'mps':
    print(f"Device name: <mps>")

  # create folder to store models weights
  Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)
  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_data(config)
  model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

  # Tensorboard
  writer = SummaryWriter(config['experiment_name'])

  optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

  # Resume support: restore model, optimizer and counters from a checkpoint if requested.
  initial_epoch = 0
  global_step = 0
  preload = config['preload']
  if preload == 'latest':
    model_filename = latest_weights_file_path(config)
  elif preload:
    model_filename = get_weights_file_path(config, preload)
  else:
    model_filename = None

  if model_filename:
    print(f"preloading model: {model_filename}")
    # map_location lets a GPU-saved checkpoint be resumed on CPU (and vice versa).
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    initial_epoch = state['epoch'] + 1
    optimizer.load_state_dict(state['optimizer_state_dict'])
    global_step = state['global_step']
  else:
    print("No model to preload, starting from scratch")

  # Labels are target-side ids, so padding is identified with the target tokenizer.
  # label_smoothing moves 0.1 of the probability mass from the true token to the others,
  # which reduces over-confidence.
  loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

  # training loop
  for epoch in range(initial_epoch, config['num_epochs']):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f'Processing Epoch: {epoch:02d}')
    for batch in batch_iterator:

      encoder_input = batch['encoder_input'].to(device) # (batch, seq_len)
      decoder_input = batch['decoder_input'].to(device) # (batch, seq_len)
      encoder_mask = batch['encoder_mask'].to(device) # (batch, 1, 1, seq_len)
      decoder_mask = batch['decoder_mask'].to(device) # (batch, 1, seq_len, seq_len)

      # Run the tensors through the transformer
      encoder_output = model.encode(encoder_input, encoder_mask) # (batch, seq_len, d_model)
      decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (batch, seq_len, d_model)
      projection = model.projection(decoder_output) # (batch, seq_len, tgt_vocab_size)

      label = batch['label'].to(device) # (batch, seq_len)

      # (batch, seq_len, tgt_vocab_size) --> (batch*seq_len, tgt_vocab_size)
      loss = loss_fn(projection.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
      batch_iterator.set_postfix({ "loss": f"{loss.item():6.3f}" })

      # log the loss
      writer.add_scalar("train loss", loss.item(), global_step)
      writer.flush()

      # backpropagate the loss and update the weights
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      global_step += 1

    run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

    # Save model AND optimizer state at the end of every epoch: Adam keeps per-parameter
    # moment estimates that are needed to resume training faithfully.
    model_file_path = get_weights_file_path(config, f'{epoch:02d}')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step
    }, model_file_path)

  writer.close()



### Greedy Search for next token

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Translate one source sentence by always taking the highest-scoring next token.

    The encoder runs once; the decoder is re-run on the growing prefix at each step.

    Returns:
        1-D tensor of target ids beginning with [SOS]. It ends with [EOS] unless
        max_len tokens were produced first.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    memory = model.encode(source, source_mask)  # reused at every step
    tokens = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while tokens.size(1) != max_len:
        step_mask = causal_mask(tokens.size(1)).type_as(source_mask).to(device)
        hidden = model.decode(memory, source_mask, tokens, step_mask)

        # Score only the last position to choose the next token.
        scores = model.projection(hidden[:, -1])
        _, next_word = torch.max(scores, dim=1)

        next_token = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        tokens = torch.cat([tokens, next_token], dim=1)

        if next_word == eos_idx:
            break

    return tokens.squeeze(0)

### Beam Search for next token

In [None]:

def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Translate one source sentence with beam search.

    Keeps the `beam_size` highest-scoring partial translations. A candidate's score is
    the sum of its tokens' log-probabilities (model.projection returns log_softmax
    output). Candidates that already end with [EOS] are carried over unchanged, so they
    keep competing with the candidates that are still growing.

    Returns:
        1-D tensor of target ids of the best candidate, beginning with [SOS].
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    # Scores are summed log-probabilities, so the initial hypothesis starts at log(1) = 0.
    candidates = [(decoder_initial_input, 0.0)]

    while True:

        # A candidate reaching max_len means we have decoded max_len steps: stop.
        if any(cand.size(1) == max_len for cand, _ in candidates):
            break

        new_candidates = []

        for candidate, score in candidates:

            if candidate[0][-1].item() == eos_idx:
                # Finished hypothesis: keep it as-is instead of silently dropping it.
                new_candidates.append((candidate, score))
                continue

            # Build the candidate's mask and run the decoder on the whole prefix
            candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)
            out = model.decode(encoder_output, source_mask, candidate, candidate_mask)
            # log-probabilities of the next token, shape (1, vocab_size)
            prob = model.projection(out[:, -1])
            # topk cannot request more entries than the vocabulary has
            k = min(beam_size, prob.size(-1))
            topk_prob, topk_idx = torch.topk(prob, k, dim=1)
            for i in range(k):
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)  # (1, 1)
                token_prob = topk_prob[0][i].item()
                new_candidate = torch.cat([candidate, token], dim=1)
                # Log-probabilities add, because probabilities multiply
                new_candidates.append((new_candidate, score + token_prob))

        # Keep only the beam_size best candidates
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]

        # If all the candidates have reached the eos token, stop
        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the best candidate
    return candidates[0][0].squeeze()

### Validation

In [None]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examble=2, beam_size=None):
  """Translate a few validation sentences with greedy and beam search, print them and log metrics.

  Args:
    model: Transformer exposing encode / decode / projection.
    validation_ds: iterable of batches of size 1 (asserted below).
    tokenizer_src: source tokenizer (forwarded to the decoders).
    tokenizer_tgt: target tokenizer, used to decode predicted ids to text.
    max_len: maximum number of decoded tokens.
    device: device the inputs are moved to.
    print_msg: callable used to print each line (e.g. tqdm's write).
    global_step: x-axis value for the TensorBoard scalars.
    writer: SummaryWriter, or a falsy value to skip metric logging.
    num_examble: number of sentences to translate before stopping.
    beam_size: beam width for beam search. If None, uses the module-level
      `config['beam_size']` when such a global exists, otherwise 3.
  """
  if beam_size is None:
    # Backward compatibility: this function previously read the notebook-global `config`.
    beam_size = globals().get('config', {}).get('beam_size', 3)

  model.eval()
  count = 0

  source_texts = []
  expected = []
  predicted_by_greedy = []
  predicted_by_beam = []

  with torch.no_grad():
    for batch in validation_ds:
      count += 1
      encoder_input = batch['encoder_input'].to(device) # (batch, seq_len)
      encoder_mask = batch['encoder_mask'].to(device) # (batch, 1, 1, seq_len)

      assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

      model_output_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
      model_output_beam = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

      source_text = batch['src_text'][0]
      target_text = batch['tgt_text'][0]
      model_output_text_greedy  = tokenizer_tgt.decode(model_output_greedy.detach().cpu().numpy())
      model_output_text_beam = tokenizer_tgt.decode(model_output_beam.detach().cpu().numpy())

      source_texts.append(source_text)
      expected.append(target_text)
      predicted_by_greedy.append(model_output_text_greedy)
      predicted_by_beam.append(model_output_text_beam)

      # print to console
      print_msg('-'*80)
      print_msg(f"{f'SOURCE: ':>12}{source_text}")
      print_msg(f"{f'TARGET: ':>12}{target_text}")
      print_msg(f"{f'PREDICTED GREEDY: ':>12}{model_output_text_greedy}")
      print_msg(f"{f'PREDICTED BEAM: ':>12}{model_output_text_beam}")

      if count == num_examble:
        print_msg('-'*80)
        break

  if writer:
    # Character error rate
    metric = torchmetrics.CharErrorRate()
    cer_greedy = metric(predicted_by_greedy, expected)
    cer_beam = metric(predicted_by_beam, expected)
    writer.add_scalar('validation cer with greedy search', cer_greedy, global_step)
    writer.add_scalar('validation cer with beam search', cer_beam, global_step)

    # Word error rate
    metric = torchmetrics.WordErrorRate()
    wer_greedy = metric(predicted_by_greedy, expected)
    wer_beam = metric(predicted_by_beam, expected)
    writer.add_scalar('validation wer with greedy search', wer_greedy, global_step)
    writer.add_scalar('validation wer with beam search', wer_beam, global_step)

    # BLEU takes a list of reference translations per prediction. Wrap each target
    # explicitly; depending on the torchmetrics version, a bare string could be iterated
    # character by character.
    references = [[text] for text in expected]
    metric = torchmetrics.BLEUScore()
    bleu_greedy = metric(predicted_by_greedy, references)
    bleu_beam = metric(predicted_by_beam, references)
    writer.add_scalar('validation BLEU score with greedy search', bleu_greedy, global_step)
    writer.add_scalar('validation BLEU score with beam search', bleu_beam, global_step)

    writer.flush()

  # References
  # https://github.com/hkproj/pytorch-transformer
  # https://www.youtube.com/watch?v=ISNdQcPhsts


In [None]:
if __name__ == '__main__':
  # run_validation looks up a module-level `config` for the beam size, so bind it at module scope.
  config = get_config()
  warnings.filterwarnings("ignore")  # keep notebook output free of library warnings
  train_model(config)

using device: cuda
Max Length of source sentence: 309
Max Length of target sentence : 1
No model to preload, starting from scratch


Processing Epoch: 00:   3%|▎         | 113/3638 [00:09<05:10, 11.36it/s, loss=1.259]


KeyboardInterrupt: 