<a href="https://colab.research.google.com/github/MaazMikail/Vanilla-Transformer-Notebook/blob/main/Vanilla_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class InputEmbeddings(nn.Module):
  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.emb = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    return self.emb(x) + torch.sqrt(torch.tensor(self.d_model))

In [3]:
embed = InputEmbeddings(512, 4096)
res = embed(torch.tensor([143, 891, 1000, 482, 18, 12]))
res.shape

torch.Size([6, 512])

In [4]:
class PositionalEncoding(nn.Module):

  def __init__(self, d_model: int, seq_len: int, dropout: float):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout()

    self.pe = torch.zeros(self.seq_len, self.d_model)

    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0,self.d_model,2).float() * (-math.log(10000.0)/self.d_model))

    self.pe[:, 0::2] = torch.sin(position * div_term)
    self.pe[:, 1::2] = torch.cos(position * div_term)

    self.pe = self.pe.unsqueeze(0)
    self.pe.requires_grad = False


    self.register_buffer('pos', self.pe)


  def forward(self, x):


    x = x + self.pe

    return self.dropout(x)



In [5]:
pos = PositionalEncoding(512, 6,0)
#res = res.unsqueeze(0)
print(res.shape)

torch.Size([6, 512])


In [6]:
embs_pos = pos(res)
embs_pos.shape

torch.Size([1, 6, 512])

In [7]:
class LayerNorm(nn.Module):

  def __init__(self, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1)) # mult
    self.bias = nn.Parameter(torch.ones(1)) # added

  def forward(self, x):
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim = -1, keepdim=True)
    return self.alpha * (x- mean) / (std + self.eps) + self.bias



In [8]:
ln = LayerNorm()
emb_norm = ln(embs_pos)
emb_norm.shape

torch.Size([1, 6, 512])

In [9]:
emb_norm[:, :, -1].mean()

tensor(1.0178, grad_fn=<MeanBackward0>)

In [10]:
embs_pos[:, :, -1].mean()

tensor(23.7750, grad_fn=<MeanBackward0>)

In [11]:
class FeedForward(nn.Module):

  def __init__(self, d_model: int, d_ff:int, dropout: float):
    super().__init__()
    self.ff_1 = nn.Linear(d_model, d_ff)
    self.ff_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    return self.ff_2(self.dropout(nn.functional.relu(self.ff_1(x))))

In [12]:
ff = FeedForward(512, 2048, 0)
ff_out = ff(embs_pos)
ff_out.shape

torch.Size([1, 6, 512])

In [13]:
class MultiHeadAttentionBlock(nn.Module):

  def __init__(self, d_model: int, h: int, dropout: float):
    super().__init__()
    self.d_model = d_model
    self.h = h
    self.dropout = nn.Dropout()
    assert self.d_model % h == 0, "d_model is not divisible by h"
    self.d_k = d_model // h

    self.w_q, self.w_k, self.w_v = [nn.Linear(d_model, d_model) for _ in range(3)]

    self.w_o = nn.Linear(d_model, d_model)

  @staticmethod
  def attention(query, key, value, mask, dropout: nn.Dropout):

    d_k = query.shape[-1]
    attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(d_k)
    if mask is not None:
      attention_scores.masked_fill_(mask==0, -1e9)
    attention_scores = attention_scores.softmax(dim=-1)
    if dropout is not None:
      attention_scores = dropout(attention_scores)

    return (attention_scores @ value), attention_scores

  def forward(self, q, k, v, mask):
    query = self.w_q(q)
    key = self.w_k(k)
    value = self.w_v(v)

    query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1,2)
    key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1,2)
    value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1,2)

    x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

    x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h*self.d_k)

    return self.w_o(x)


In [14]:
mah = MultiHeadAttentionBlock(512, 8, 0)
mah.w_q.weight.shape, mah.w_q.bias.shape

(torch.Size([512, 512]), torch.Size([512]))

In [15]:
mah_res = mah(embs_pos, embs_pos, embs_pos, None)
mah_res.shape

torch.Size([1, 6, 512])

In [16]:
x, scores = MultiHeadAttentionBlock.attention(mah_res, mah_res, mah_res, None, None)

In [78]:
scores.shape

torch.Size([1, 8, 6, 6])

In [91]:
x.shape

torch.Size([1, 8, 6, 64])

In [21]:
class ResidualConnection(nn.Module):

  def __init__(self, dropout: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNorm()

  def forward(self, x, sublayer):
    return x + self.dropout(sublayer(self.norm(x)))

In [22]:
rc = ResidualConnection(0)
rc

ResidualConnection(
  (dropout): Dropout(p=0, inplace=False)
  (norm): LayerNorm()
)

In [23]:
class EncoderBlock(nn.Module):

  def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForward, dropout: float):
    super().__init__()
    self.attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])


  def forward(self, x, src_mask):
    x = self.residual_connections[0](x, lambda x: self.attention_block(x,x,x,src_mask))
    x = self.residual_connections[1](x, self.feed_forward_block)

    return x


In [24]:
en = EncoderBlock(mah, ff, 0)

In [25]:
all_result = en(embs_pos, None)

In [26]:
all_result.shape

torch.Size([1, 6, 512])

In [27]:
class Encoder(nn.Module):

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

In [28]:
encoder = Encoder(nn.ModuleList([EncoderBlock(mah, ff, 0) for _ in range(5)]))
encoder

Encoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderBlock(
      (attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward_block): FeedForward(
        (ff_1): Linear(in_features=512, out_features=2048, bias=True)
        (ff_2): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0, inplace=False)
      )
      (residual_connections): ModuleList(
        (0-1): 2 x ResidualConnection(
          (dropout): Dropout(p=0, inplace=False)
          (norm): LayerNorm()
        )
      )
    )
  )
  (norm): LayerNorm()
)

In [29]:
encoder_result = encoder(embs_pos, None)
encoder_result.shape

torch.Size([1, 6, 512])

In [30]:
class DecoderBlock(nn.Module):

  def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward: FeedForward, dropout: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x,x,x,tgt_mask))
    x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
    x = self.residual_connections[2](x, self.feed_forward_block)

    return x

In [31]:
dec_block = DecoderBlock(mah, mah, ff, 0)
dec_block

DecoderBlock(
  (self_attention_block): MultiHeadAttentionBlock(
    (dropout): Dropout(p=0.5, inplace=False)
    (w_q): Linear(in_features=512, out_features=512, bias=True)
    (w_k): Linear(in_features=512, out_features=512, bias=True)
    (w_v): Linear(in_features=512, out_features=512, bias=True)
    (w_o): Linear(in_features=512, out_features=512, bias=True)
  )
  (cross_attention_block): MultiHeadAttentionBlock(
    (dropout): Dropout(p=0.5, inplace=False)
    (w_q): Linear(in_features=512, out_features=512, bias=True)
    (w_k): Linear(in_features=512, out_features=512, bias=True)
    (w_v): Linear(in_features=512, out_features=512, bias=True)
    (w_o): Linear(in_features=512, out_features=512, bias=True)
  )
  (feed_forward_block): FeedForward(
    (ff_1): Linear(in_features=512, out_features=2048, bias=True)
    (ff_2): Linear(in_features=2048, out_features=512, bias=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (residual_connections): ModuleList(
    (0-2): 3 x Resi

In [32]:
class Decoder(nn.Module):

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, tgt_mask)
    return self.norm(x)

In [33]:
dec = Decoder(nn.ModuleList([DecoderBlock(mah,mah, ff, 0) for _ in range(5)]))
dec

Decoder(
  (layers): ModuleList(
    (0-4): 5 x DecoderBlock(
      (self_attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (cross_attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward_block): FeedForward(
        (ff_1): Linear(in_features=512, out_features=2048, bias=True)
        (ff_2): Linear(in_features=2048, out_features=512,

In [34]:
decoder_output = dec(embs_pos, encoder_result, None, None)
decoder_output.shape

torch.Size([1, 6, 512])

In [35]:
class ProjectionLayer(nn.Module):

  def __init__(self, d_model:int, vocab_size: int):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    x = self.proj(x)

    return torch.log_softmax(x, dim=-1)


In [36]:
proj = ProjectionLayer(512, 4096)
proj = proj(decoder_output)
proj.shape

torch.Size([1, 6, 4096])

In [37]:
class TransformerBlock(nn.Module):

  def __init__(self, encoder: Encoder, decoder: Decoder, src_emb: InputEmbeddings, tgt_emb: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, proj: ProjectionLayer):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_emb = src_emb
    self.tgt_emb = tgt_emb
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.proj = proj

  def encode(self, src, src_mask):
    src = self.src_embed(src)
    src = self.src_pos(src)
    return self.encoder(src, src_mask)

  def decode(self, encoder_output, src_mask, tgt, tgt_mask):
    tgt = self.tgt_embed(tgt)
    tgt = self.tgt_pos(tgt)
    return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    return self.proj(x)




In [38]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len:int, d_ff: int = 2048, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.0):

  # embed layer
  src_embed = InputEmbeddings(d_model, src_vocab_size)
  tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

  # pos embs
  src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
  tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)


  # encoder
  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
    feed_forward_block = FeedForward(d_model, d_ff, dropout)
    encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
    encoder_blocks.append(encoder_block)

  # decoder
  decoder_blocks = []
  decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
  decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
  feed_forward_block = FeedForward(d_model, d_ff, dropout)
  decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
  decoder_blocks.append(decoder_block)

  # combine
  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))


  # projection
  projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

  # transformer

  transformer = TransformerBlock(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

  for p in transformer.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)


  return transformer


In [39]:
tf = build_transformer(4096, 4096, 6, 6)

In [40]:
tf

TransformerBlock(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (attention_block): MultiHeadAttentionBlock(
          (dropout): Dropout(p=0.5, inplace=False)
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_o): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward_block): FeedForward(
          (ff_1): Linear(in_features=512, out_features=2048, bias=True)
          (ff_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.0, inplace=False)
            (norm): LayerNorm()
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (l

In [36]:
# Get named learnable parameters
learnable_named_params = [(name, param) for name, param in tf.named_parameters() if param.requires_grad]

for name, param in learnable_named_params:
    print(f"Parameter name: {name}, value: {param.shape}")


Parameter name: encoder.layers.0.attention_block.w_q.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_q.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_k.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_k.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_v.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_v.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_o.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_o.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.feed_forward_block.ff_1.weight, value: torch.Size([2048, 512])
Parameter name: encoder.layers.0.feed_forward_block.ff_1.bias, value: torch.Size([2048])
Parameter name: encoder.layers.0.feed_forward_block.ff_2.weight, value: torch.Size([512, 2048])
Parameter name: enc

In [41]:
# Calculate the total number of learnable parameters
learnable_params = [param for param in tf.parameters() if param.requires_grad]
total_learnable_params = sum(p.numel() for p in learnable_params)
print(f"Total Learnable params : {(total_learnable_params / 1000000):.1f} Million")


Total Learnable params : 29.4 Million


#TRAINING

In [42]:
!pip install datasets
!pip install transformers

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2

In [43]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path

from torch.utils.data import Dataset, DataLoader, random_split

In [44]:
def get_all_sentences(ds, lang):
  for item in ds:
    yield item['translation'][lang]


def get_or_build_tokenizer(config, ds, lang):
  tokenizer_path = Path(config['tokenizer_file'].format(lang))
  if not Path.exists(tokenizer_path):
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))

  else:
    tokenizer = Tokenizer.from_file(str(tokenizer_path))

  return tokenizer

In [45]:
class BilingialDataset(Dataset):
  def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    super().__init__()

    self.ds = ds
    self.tokenizer_src = tokenizer_src
    self.tokenizer_tgt = tokenizer_tgt
    self.src_lang = src_lang
    self.tgt_lang = tgt_lang


    self.sos_token = torch.Tensor([tokenizer_src.token_to_id(['[SOS]'])], dtype=torch.int64)
    self.eos_token = torch.Tensor([tokenizer_src.token_to_id(['[EOS]'])], dtype=torch.int64)
    self.pad_token = torch.Tensor([tokenizer_src.token_to_id(['[PAD]'])], dtype=torch.int64)


  def __len__(self):
    return len(self.ds)

  def __getitem__(self, index: int):
    src_target_pair = self.ds[index]
    src_text = src_target_pair['translation'][self.src_lang]
    tgt_text = src_target_pair['translation'][self.tgt_lang]

    enc_input_tokens =self.tokenizer_src.encode(src_text).ids
    dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

    enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
    dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

    if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
      raise ValueError("Sentence is too long")

    encoder_input = torch.cat(

                              [self.sos_token,
                               torch.tensor(enc_input_tokens, dtype=torch.int64),
                               self.eos_token,
                               torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)]
    )

    decoder_input = torch.cat(

                              [self.sos_token,
                               torch.tensor(dec_input_tokens, dtype=torch.int64),
                               torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)]
    )

    label = torch.cat(
        [
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
        ]
    )


    assert encoder_input.size(0) == self.seq_len
    assert decoder_input.size(0) == self.seq_len
    assert label.size(0) == self.seq_len


    return {"encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).unsuqeeze(0).int() & casual_mask(decoder_input.size(0)),
            "label":label,
            "src_text": src_text,
            "tgt_text": tgt_text}






In [46]:
def casual_mask(size):
  mask = torch.triu(torch.ones(1, size, size),diagonal=1).type(torch.int)
  return mask == 0

In [47]:
def get_ds(config):
  ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')

  tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
  tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

  train_ds_size = int(0.9 * len(ds_raw))
  val_ds_size = len(ds_raw) - train_ds_size
  train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

  train_ds = BilingialDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
  val_ds = BilingialDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

  max_len_src = 0
  max_len_tgt = 0

  for item in ds_raw:
    src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
    tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
    max_len_src = max(max_len_src, len(src_ids))
    max_len_tgt = max(max_len_tgt, len(tgt_ids))


  print("Max length of source: ", max_len_src)
  print("Max length of tgt: ", max_len_tgt)


  train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
  val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)


  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt







In [48]:
def get_model(config, vocab_src_len, vocab_tgt_len):
  model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], config['d_model'])

  return model

In [49]:
from pathlib import Path

def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

def get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])

In [51]:
from torch.utils.tensorboard import SummaryWriter
import tqdm

In [52]:
def train_model(config):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print("Using device: ", device)

  Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
  model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
  writer = SummaryWriter(config['experiment_name'])


  optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

  initial_epoch = 0
  global_step = 0
  if config['preload']:
    model_filename = get_weights_file_path(config, config['preload'])
    print(f"pre-loading model: {model_filename}")
    state = torch.load(model_filename)
    initial_epoch = state["epoch"] + 1
    optimizer.load_state_dict(state['optimizer_state_dict'])
    global_step = state['global_step']


  loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

  for epoch in range(initial_epoch, config['num_epochs']):
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch: {epoch:02d}')

    for batch in batch_iterator:
      encoder_input = batch['encoder_input'].to(device)
      decoder_input = batch['decoder_input'].to(device)
      encoder_mask = batch['encoder_mask'].to(device)
      decoder_mask = batch['decoder_mask'].to(device)

      encoder_output = model.encode(encoder_input, encoder_mask)
      decoder_output = model.decode(encoder_output, decoder_input, decoder_mask)
      proj_output = model.project(decoder_output)

      label = batch['label'].to(device)
      loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

      batch_iterator.set_postfix({f"loss": f"{loss.item:6.3f}"})
      writer.add_scaler("train loss", loss.item(), global_step)
      writer.flush()

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      global_step+=1

    model_filename = get_weights_file_path(config, f'{epoch:02d}')

    torch.save({

                'epoch':epoch,
                'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'global_step': global_step}, model_filename

    )