In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self, x):
      return self.embedding(x) * math.sqrt(self.embed_size)


In [3]:
class PositionalEncoding(nn.Module):

  def __init__(self, embed_size, max_len, dropout):
    super().__init__()
    self.embed_size = embed_size
    self.max_len = max_len
    self.dropout = nn.Dropout(dropout)
    pe = torch.zeros(max_len, embed_size)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)

    self.register_buffer('pe', pe)

  def forward(self, x):
    x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)

    return self.dropout(x)

In [4]:
class LayerNormalization(nn.Module):

  def __init__(self, eps = 10**-6):
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1))
    self.bias = nn.Parameter(torch.zeros(1))
  def forward(self, x):
    mean = x.mean(dim = -1, keepdim = True)
    std = x.std(dim = -1, keepdim = True)

    return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [5]:
class FeedForwardBlock(nn.Module):

  def __init__(self, d_model, d_ff, dropout):
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.dropout = nn.Dropout(dropout)
    self.linear_2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

In [6]:
class MultiHeadAttention(nn.Module):

  def __init__(self, d_model, h, dropout):
    super().__init__()
    self.d_model = d_model
    self.h = h
    assert d_model % h == 0, "d_model is not divisible by h"
    self.d_k = d_model // h
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(query, key, value, mask, dropout):
    d_k = query.shape[-1]

    attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(d_k)
    if mask is not None:
      attention_scores.masked_fill_(mask == 0, -1e9)
    attention_scores = attention_scores.softmax(dim = -1)
    if dropout is not None:
      attention_scores = dropout(attention_scores)
    return (attention_scores @ value), attention_scores

  def forward(self, q, k , v , mask):
    query = self.w_q(q)
    key = self.w_k(k)
    value = self.w_v(v)

    query = query.view(query.shape[0],query.shape[1],self.h,self.d_k).transpose(1,2)
    key = key.view(key.shape[0],key.shape[1],self.h,self.d_k).transpose(1,2)
    value = value.view(value.shape[0],value.shape[1],self.h,self.d_k).transpose(1,2)

    x, attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

    x = x.transpose(1,2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
    return self.w_o(x)

In [7]:
class ResidualLayer(nn.Module):
  def __init__(self, dropout):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNormalization()

  def forward(self, x, sublayer):
    return x + self.dropout(sublayer(self.norm(x)))

In [8]:
class EncoderBlock(nn.Module):
  def __init__(self, self_attention, feed_forward, dropout):
    super().__init__()
    self.self_attention = self_attention
    self.feed_forward = feed_forward
    self.residual_layers = nn.ModuleList([ResidualLayer(dropout) for _ in range(2)])

  def forward(self, x, src_mask):
    x = self.residual_layers[0](x, lambda x: self.self_attention(x, x, x, src_mask))
    x = self.residual_layers[1](x, self.feed_forward)

    return x

In [9]:
class Encoder(nn.Module):
  def __init__(self, layers):
    super().__init__()
    self.layers = nn.ModuleList(layers)
    self.norm = LayerNormalization()

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

In [10]:
class DecoderBlock(nn.Module):

  def __init__(self, self_attention, cross_attention, feed_forward, dropout):
    super().__init__()
    self.self_attention = self_attention
    self.cross_attention = cross_attention
    self.feed_forward = feed_forward
    self.residual_layers = nn.ModuleList([ResidualLayer(dropout) for _ in range(3)])

  def forward(self,x, encoder_output, src_mask, tgt_mask):
    x = self.residual_layers[0](x, lambda x: self.self_attention(x, x, x, tgt_mask))
    x = self.residual_layers[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask))
    x = self.residual_layers[2](x, self.feed_forward)
    return x

In [11]:
class Decoder(nn.Module):
  def __init__(self, layers):
    super().__init__()
    self.layers = nn.ModuleList(layers)
    self.norm = LayerNormalization()

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, tgt_mask)
    return self.norm(x)

In [12]:
class ProjectionLayer(nn.Module):

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    return torch.log_softmax(self.proj(x), dim = -1)

In [13]:
class Transformers(nn.Module):

  def __init__(self, encoder, decoder, src_embedding, tgt_embedding, src_pos, tgt_pos, projection):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embedding = src_embedding
    self.tgt_embedding = tgt_embedding
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection = projection

  def encode(self, src, src_mask):
    src = self.src_embedding(src)
    src = self.src_pos(src)

    return self.encoder(src, src_mask)

  def decode(self, encoder_output, src_mask, tgt, tgt_mask):
    tgt = self.tgt_embedding(tgt)
    tgt = self.tgt_pos(tgt)

    return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    return self.projection(x)

In [14]:
def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model = 512, N = 6, num_heads = 8, ff_dim = 2048, dropout = 0.1):
  src_embed = InputEmbeddings(src_vocab_size, d_model)
  tgt_embed = InputEmbeddings(tgt_vocab_size, d_model)

  src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
  tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

  encoder_blocks = []

  for _ in range(N):
    encoder_self_attention = MultiHeadAttention(d_model, num_heads, dropout)
    feed_forward = FeedForwardBlock(d_model, ff_dim, dropout)
    encoder_blocks.append(EncoderBlock(encoder_self_attention, feed_forward, dropout))

    decoder_blocks = []

    for _ in range(N):
      decoder_self_attention = MultiHeadAttention(d_model, num_heads, dropout)
      decoder_cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
      feed_forward = FeedForwardBlock(d_model, ff_dim, dropout)
      decoder_blocks.append(DecoderBlock(decoder_self_attention, decoder_cross_attention, feed_forward, dropout))

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    projection = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformers(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection)

    for p in transformer.parameters():
      if p.dim() > 1:
        nn.init.xavier_uniform_(p)

    return transformer

#Downstream Task: Translation

In [15]:
!pip install datasets tokenizers

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.1-py3-none-any.whl (471 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-an

In [16]:
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path

In [17]:
def causal_mask(size):
  mask = torch.triu(torch.ones(1,size,size), diagonal = 1).type(torch.int)
  return mask == 0

In [18]:
class BilingualDataset(Dataset):

  def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    super().__init__()
    self.seq_len = seq_len
    self.ds = ds
    self.tokenizer_src = tokenizer_src
    self.tokenizer_tgt = tokenizer_tgt
    self.src_lang = src_lang
    self.tgt_lang = tgt_lang
    self.sos_token = torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype = torch.int64)
    self.eos_token = torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype = torch.int64)
    self.pad_token = torch.tensor([tokenizer_src.token_to_id("[PAD]")], dtype = torch.int64)

  def __len__(self):
    return len(self.ds)

  def __getitem__(self, idx):
    src_target_pair = self.ds[idx]
    src_text = src_target_pair['translation'][self.src_lang]
    tgt_text = src_target_pair['translation'][self.tgt_lang]

    enc_input_tokens = self.tokenizer_src.encode(src_text).ids
    dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

    enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
    dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

    if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
      raise ValueError("Sentence is too long")

    encoder_input = torch.cat(
        [
            self.sos_token,
            torch.Tensor(enc_input_tokens),
            self.eos_token,
            torch.Tensor([self.pad_token] * enc_num_padding_tokens),
        ]
    )

    decoder_input = torch.cat(
        [
            self.sos_token,
            torch.Tensor(dec_input_tokens),
            torch.Tensor([self.pad_token] * dec_num_padding_tokens),
        ]
    )

    label = torch.cat(
        [
            torch.Tensor(dec_input_tokens),
            self.eos_token,
            torch.Tensor([self.pad_token] * dec_num_padding_tokens),
        ]
    )

    assert encoder_input.size(0) == self.seq_len
    assert decoder_input.size(0) == self.seq_len
    assert label.size(0) == self.seq_len

    return {
        'encoder_input': encoder_input,
        'decoder_input': decoder_input,
        'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
        'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
        'label': label,
        'src_text': src_text,
        'tgt_text': tgt_text,
    }


In [19]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
  sos_idx = tokenizer_tgt.token_to_id('[SOS]')
  eos_idx = tokenizer_tgt.token_to_id('[EOS]')

  encoder_output = model.encode(source, source_mask)
  decoder_input = torch.empty(1,1).fill_(sos_idx).type_as(source).to(device)
  while True:
    if decoder_input.size(1) == max_len:
      break

    decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

    out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

    prob = model.project(out[:,-1])
    _, next_word = torch.max(prob, dim = 1)
    decoder_input = torch.cat(
        [decoder_input, torch.empty(1,1).type_as(source).fill_(next_word.item()).to(device)], dim = 1
    )

    if next_word == eos_idx:
      break

  return decoder_input.squeeze(0)

In [20]:
def run_val(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples = 2):
  model.eval()
  count = 0

  source_texts = []
  expected = []
  predicted = []

  console_width = 80
  with torch.no_grad():
    for batch in validation_ds:
      count += 1
      encoder_input = batch['encoder_input'].to(device).long()
      encoder_mask = batch['encoder_mask'].to(device)
      decoder_input = batch['decoder_input'].to(device).long()
      decoder_mask = batch['decoder_mask'].to(device)
      label = batch['label'].to(device).long()

      model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
      source_text = batch['src_text'][0]
      target_text = batch['tgt_text'][0]
      model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

      source_texts.append(source_text)
      expected.append(target_text)
      predicted.append(model_out_text)

      print_msg("-"*console_width)
      print(f"I: {source_text}")
      print(f"T: {target_text}")
      print(f"O: {model_out_text}\n")

      if count == num_examples:
        break


In [21]:
def get_all_sentences(ds, lang):
  for item in ds:
    yield item['translation'][lang]

In [22]:
def get_or_build_tokenizer(config, ds, lang):
  tokenizer_path = Path(config['tokenizer_file'].format(lang))
  if not Path.exists(tokenizer_path):
    tokenizer = Tokenizer(WordLevel(unk_token = "[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency = 2)
    tokenizer.train_from_iterator(get_all_sentences(ds,lang), trainer = trainer)
    tokenizer.save(str(tokenizer_path))
  else:
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
  return tokenizer

In [23]:
def get_ds(config):
  ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split = 'train')
  tokenizer_src  = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
  tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

  train_ds_size = int(.9 * len(ds_raw))
  val_ds_size = len(ds_raw) - train_ds_size
  train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

  train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
  val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

  max_len_src = 0
  max_len_tgt = 0

  for item in ds_raw:
    src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
    tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
    max_len_src = max(max_len_src, len(src_ids))
    max_len_tgt = max(max_len_tgt, len(tgt_ids))

  print(f"Max length of source sentence: {max_len_src}")
  print(f"Max length of target sentence: {max_len_tgt}")

  train_dataloader = DataLoader(train_ds, batch_size = config['batch_size'], shuffle = True)
  val_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)

  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [24]:
def get_model(config, vocab_size_src, vocab_size_tgt):
  model = build_transformer(vocab_size_src, vocab_size_tgt, config['seq_len'], config['seq_len'], d_model = config['d_model'])
  return model

In [25]:
def get_config():
  return {
      "batch_size": 8,
      "num_epochs": 20,
      "lr": 10**-4,
      "seq_len": 350,
      "d_model": 512,
      "lang_src": "en",
      "lang_tgt": "it",
      "model_folder": "weights",
      "model_basename": "tmodel_",
      "preload": None,
      "tokenizer_file": "tokenizer_{0}.json",
      "experiment_name": "runs/tmodel"
  }

In [26]:
def get_weights_file_path(config, epoch):
  model_folder = config['model_folder']
  model_basename = config['model_basename']
  model_filename = f"{model_basename}{epoch}.pt"
  return str(Path('.') / model_folder / model_filename)

In [27]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [31]:
def train_model(config):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"Using device {device}")

  Path(config['model_folder']).mkdir(parents = True, exist_ok = True)
  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)

  model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

  writer = SummaryWriter(config['experiment_name'])

  optimizer = torch.optim.Adam(model.parameters(), lr = config['lr'], eps = 1e-9)

  initial_epoch = 0
  global_step = 0
  if config['preload']:
    model_filename = get_weights_file_path(config, config['preload'])
    print(f"Preloading model {model_filename}")
    state = torch.load(model_filename)
    model.load_state_dict(state['model_state_dict'])
    initial_epoch = state['epoch'] + 1
    optimizer.load_state_dict(state['optimizer_state_dict'])
    global_step = state['global_step']

  loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)

  for epoch in range(initial_epoch, config['num_epochs']):

    batch_iterator = tqdm(train_dataloader, desc = f"Processing Epoch {epoch:02d}")
    for batch in batch_iterator:
      model.train()
      encoder_input = batch['encoder_input'].to(device).long()
      decoder_input = batch['decoder_input'].to(device).long()
      encoder_mask = batch['encoder_mask'].to(device)
      decoder_mask = batch['decoder_mask'].to(device)

      encoder_output = model.encode(encoder_input, encoder_mask)
      decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
      proj_output = model.project(decoder_output)

      label = batch['label'].to(device)
      loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1).long())
      batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

      writer.add_scalar('train loss', loss.item(), global_step)
      writer.flush()
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      # run_val(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg))
      global_step += 1

    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'global_step': global_step
    }, model_filename)

In [None]:
if __name__ == '__main__':
  config = get_config()
  train_model(config)

Using device cuda
Max length of source sentence: 309
Max length of target sentence: 274


Processing Epoch 00:  52%|█████▏    | 1881/3638 [09:30<08:52,  3.30it/s, loss=6.316]

In [None]:
run_val(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg))
