In [None]:
import torch
import torch.nn as nn

import math

from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset

# Purpose
I am going to develop an attention-based translation model (attention in both the encoder and the decoder).

From **SCRATCH**!!!!**!!!**

English -> Italian

In [4]:
# Fetch the English-Italian configuration of the OPUS-100 corpus.
ds = load_dataset(path="Helsinki-NLP/opus-100", name="en-it")

# Look at one raw training record to see how a translation pair is stored.
example = ds["train"][100]
print(example)

Generating test split: 100%|██████████| 2000/2000 [00:00<00:00, 156419.25 examples/s]
Generating train split: 100%|██████████| 1000000/1000000 [00:00<00:00, 2461701.71 examples/s]
Generating validation split: 100%|██████████| 2000/2000 [00:00<00:00, 681169.96 examples/s]

{'translation': {'en': "What's going on?", 'it': 'Che succede?'}}





Process the dataset into training and validation sets

In [5]:
# Batched helper: receives many examples at once and flattens their
# nested translation dicts into two parallel text columns.
def extract_translations(batch):
  """Split a batch of translation records into English and Italian lists.

  Args:
    batch: mapping whose 'translation' entry is a sequence of dicts, each
      holding an 'en' and an 'it' string.

  Returns:
    A dict with 'en_text' and 'it_text' lists, aligned index by index.
  """
  english, italian = [], []
  for pair in batch['translation']:
    english.append(pair['en'])
    italian.append(pair['it'])
  return {'en_text': english, 'it_text': italian}

# Run the extraction over each split. batched=True hands the function many
# examples per call, which is what keeps the map fast.
processed_ds_train = ds['train'].map(extract_translations, batched=True)
processed_ds_val = ds['validation'].map(extract_translations, batched=True)

# Pull the flat sentence columns out of the mapped splits:
# x = English source text, y = Italian target text.
x_train, y_train = processed_ds_train['en_text'], processed_ds_train['it_text']
x_val, y_val = processed_ds_val['en_text'], processed_ds_val['it_text']

Map: 100%|██████████| 1000000/1000000 [00:04<00:00, 236131.25 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 167340.42 examples/s]


In [6]:
# Sanity check: show the first five English sentences and their Italian
# counterparts to confirm the two columns are aligned.
print("English (x):", x_train[:5])
print("Italian (y):", y_train[:5])

English (x): ['- Thanks, buddy.', 'Say it.', 'Sodium triphosphate (sodium tripolyphosphates)', 'Surely, he is ardent in his love of wealth.', 'ANNEX I']
Italian (y): ['- Grazie, amico.', 'Dillo.', 'Trifosfato di sodio (tripolifosfato di sodio)', 'Invero è avido per amore delle ricchezze!', 'ALLEGATO I']


Character-level tokenizer, just like in Karpathy's tutorials

In [7]:
# Concatenate every English sentence followed by every Italian sentence so the
# vocabulary covers the characters of both languages.
all_text = "".join([*x_train, *y_train])

# One vocabulary entry per distinct character, in sorted order.
chars = sorted(set(all_text))

# Character -> id table; the padding token takes the next free id.
stoi = dict(zip(chars, range(len(chars))))
stoi['<PAD>'] = len(chars)

# Reverse table: id -> character (or '<PAD>').
itos = {idx: token for token, idx in stoi.items()}

Config

In [8]:
# Number of token ids: every distinct character plus the '<PAD>' entry.
vocab_size = len(stoi)
# Width of the embedding vectors.
d_embd = 128
# Maximum sequence length in characters; also used to filter the sentences.
context_window = 128

Pick the first `sample_size` samples for experimenting

In [9]:
# How many leading training pairs to consider for the quick experiments below.
sample_size = 100

In [10]:
# Keep only the sentence pairs in which BOTH sides fit in the context window.
max_len = context_window
filtered_pairs = [
    (en, it)
    for en, it in zip(x_train[:sample_size], y_train[:sample_size])
    if len(en) <= max_len and len(it) <= max_len
]

# Split the surviving pairs back into parallel English / Italian lists
# (both are simply empty when nothing survived the filter).
x_train_s = [en for en, _ in filtered_pairs]
y_train_s = [it for _, it in filtered_pairs]

# Character-level encoding: one integer id per character.
x_encoded_s = [[stoi[char] for char in text] for text in x_train_s]
y_encoded_s = [[stoi[char] for char in text] for text in y_train_s]

# Pin the dtype: torch.tensor([]) would otherwise be inferred as float32, and
# pad_sequence takes its output dtype from the first sequence, so an empty
# sentence could silently turn the whole id batch into floats.
x_tensors = [torch.tensor(seq, dtype=torch.long) for seq in x_encoded_s]
y_tensors = [torch.tensor(seq, dtype=torch.long) for seq in y_encoded_s]

# Right-pad every sequence with the <PAD> id up to the longest one in its batch.
x_t_s = pad_sequence(x_tensors, batch_first=True, padding_value=stoi['<PAD>'])
y_t_s = pad_sequence(y_tensors, batch_first=True, padding_value=stoi['<PAD>'])

print("Encoded English tensor (x):", x_t_s)
print("Encoded Italian tensor (y):", y_t_s)
print("Shape of x:", x_t_s.shape)
print("Shape of y:", y_t_s.shape)
print(f"Filtered from {sample_size}")

Encoded English tensor (x): tensor([[  13,    0,   52,  ..., 1290, 1290, 1290],
        [  51,   65,   89,  ..., 1290, 1290, 1290],
        [  51,   79,   68,  ..., 1290, 1290, 1290],
        ...,
        [  59,   39,   47,  ..., 1290, 1290, 1290],
        [  47,   72,   12,  ..., 1290, 1290, 1290],
        [  41,   83,    0,  ..., 1290, 1290, 1290]])
Encoded Italian tensor (y): tensor([[  13,    0,   39,  ..., 1290, 1290, 1290],
        [  36,   73,   76,  ..., 1290, 1290, 1290],
        [  52,   82,   73,  ..., 1290, 1290, 1290],
        ...,
        [  43,   65,   75,  ..., 1290, 1290, 1290],
        [  13,    0,   47,  ..., 1290, 1290, 1290],
        [  37,    7,    0,  ..., 1290, 1290, 1290]])
Shape of x: torch.Size([93, 69])
Shape of y: torch.Size([93, 94])
Filtered from 100


Embedding (Token and Positional)

In [None]:
class TokenEmbedding(nn.Module):
  """Lookup table that turns token ids into dense vectors.

  Args:
    vocab_size: number of rows in the table (one per token id).
    d_embd: width of each embedding vector.
    padding_idx: optional token id handed to nn.Embedding as its padding
      index. For the translation task this should be stoi['<PAD>'].
  """

  def __init__(self, vocab_size, d_embd, padding_idx=None):
    super().__init__()
    self.embd = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=d_embd,
        padding_idx=padding_idx,
    )

  def forward(self, x):
    """Return the embedding vector for every token id in `x`."""
    token_vectors = self.embd(x)
    return token_vectors

In [None]:
class PositionalEmbedding(nn.Module):
  """Learned embedding indexed by position in the sequence.

  Args:
    n_tokens: number of positions the table can represent.
    d_embd: width of each positional vector.
  """

  def __init__(self, n_tokens, d_embd):
    super().__init__()
    self.embd = nn.Embedding(n_tokens, d_embd)

  def forward(self, x):
    """Return one positional vector per sequence position of `x`.

    Only the size of dimension 1 of `x` is used, so `x` may be either the
    raw id batch (B, T) or an already embedded batch (B, T, C). The result
    has shape (T, d_embd) and broadcasts over the batch when added to the
    token embeddings.
    """
    # Read the sequence length by index instead of unpacking exactly three
    # dims, which raised ValueError for a 2-D batch of token ids.
    T = x.shape[1]
    pos = torch.arange(T, device=x.device)
    return self.embd(pos)

Single Head (with optional masking)

In [None]:
class Head(nn.Module):
  """Single scaled dot-product attention head (self- or cross-attention).

  Args:
    d_embd: size of the last dimension of the inputs.
    head_size: output size of the query / key / value projections.
    dropout: dropout probability applied to the attention weights.
    block_size: side length of the pre-built lower-triangular causal mask.
      When omitted, the notebook-level `context_window` is used, which is
      what the mask size was previously hard-wired to.
  """

  def __init__(self, d_embd, head_size, dropout=0.1, block_size=None):
    super().__init__()
    if block_size is None:
      block_size = context_window  # fall back to the notebook config value
    self.query = nn.Linear(d_embd, head_size, bias=False)
    self.key = nn.Linear(d_embd, head_size, bias=False)
    self.value = nn.Linear(d_embd, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, src=None, key_padding_mask=None, causal_mask=False):
    """Attend from `x` to `src` (or to `x` itself when `src` is None).

    Args:
      x: query-side input of shape (B, q_pos, d_embd).
      src: optional key/value-side input of shape (B, k_pos, d_embd); when
        given, the head performs cross-attention.
      key_padding_mask: optional boolean tensor of shape (B, k_pos) in which
        True marks key positions that must receive no attention.
      causal_mask: when True, a query may only attend to keys at the same or
        earlier positions. The mask is sliced with q_pos on both axes, so it
        is only meaningful when k_pos == q_pos (self-attention).

    Returns:
      Tensor of shape (B, q_pos, head_size).
    """
    _, q_pos, _ = x.shape
    q = self.query(x)

    # Keys and values come from `src` for cross-attention, else from `x`.
    kv_input = x if src is None else src
    k = self.key(kv_input)
    v = self.value(kv_input)

    qk = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(k.size(-1))) # (B, q_pos, k_pos)

    if causal_mask: # for self attention
      # Note: k_pos = q_pos
      qk = qk.masked_fill(self.tril[:q_pos, :q_pos] == 0, float('-inf')) # (B, q_pos, q_pos)

    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and `if key_padding_mask:` raises RuntimeError.
    if key_padding_mask is not None:
      expanded_mask = key_padding_mask.unsqueeze(1) # (B, 1, k_pos)
      qk = qk.masked_fill(expanded_mask, float('-inf')) # (B, q_pos, k_pos)

    attn = torch.softmax(qk, dim=-1)
    attn = self.dropout(attn)
    out = attn @ v
    return out

Testing the implementation

In [None]:
# Build the two embedding tables. The token table gets the <PAD> id as its
# padding index, as the TokenEmbedding definition says the translation task needs.
tok_emb = TokenEmbedding(vocab_size, d_embd, padding_idx=stoi['<PAD>'])
pos_emb = PositionalEmbedding(context_window, d_embd)

# Embed the English batch once and reuse the results below.
tok_out = tok_emb(x_t_s)    # (B, T, d_embd)
# PositionalEmbedding.forward unpacks three dimensions, so hand it the embedded
# batch rather than the 2-D id tensor.
pos_out = pos_emb(tok_out)  # (T, d_embd), broadcasts over the batch
final_emb = tok_out + pos_out

print(tok_out.shape)
print(pos_out.shape)
print(final_emb.shape)

In [None]:
# True wherever a position holds the <PAD> id.
# The encoder consumes the English source batch ...
encoder_key_padding_mask = (x_t_s == stoi['<PAD>'])
# ... while the decoder consumes the Italian target batch, which is padded to a
# different length, so its mask must be built from y_t_s rather than x_t_s.
decoder_key_padding_mask = (y_t_s == stoi['<PAD>'])


In [None]:
# Masking options are arguments of Head.forward, not of the constructor, so both
# heads are built the same way.
s_a = Head(d_embd, d_embd)
c_a = Head(d_embd, d_embd)

# Self-attention over the embedded English batch.
print(s_a(final_emb).shape)

# Cross-attention smoke test. `src` has to be an embedded (B, T, d_embd) tensor,
# not raw token ids, because it is fed to the key/value linear layers.
# The embedded English batch stands in for an encoder output here, and its
# padded positions are hidden from the queries.
print(c_a(final_emb, src=final_emb, key_padding_mask=(x_t_s == stoi['<PAD>']), causal_mask=False).shape)
