<a href="https://colab.research.google.com/github/SumayT9/transformer_implementations/blob/main/attn_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the project directory on Drive can be
# used as the working directory below.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Absolute path: the previous relative path ("drive/MyDrive/projects") only
# resolved when the cwd was /content, so re-running this cell after the first
# chdir raised FileNotFoundError. Drive is mounted at /content/drive above.
PROJECT_DIR = "/content/drive/MyDrive/projects"
os.chdir(PROJECT_DIR)

# Implementing "Attention Is All You Need": Portuguese-to-English translation


In [None]:
import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
# Maximum token length for both source (pt) and target (en) sequences: the
# tokenizers pad to this length and the positional-encoding tables are this long.
MAX_SEQ_LEN = 128

### Dataset preparation

In [None]:
import tensorflow_datasets as tfds

# Load the TED talks Portuguese->English corpus as (pt, en) sentence pairs.
examples, metadata = tfds.load(
    'ted_hrlr_translate/pt_to_en',
    as_supervised=True,
    with_info=True,
)
train_examples = examples['train']
val_examples = examples['validation']

In [None]:
def train_dataloader(batch_size, shuffle_buffer=None):
    """Return an iterable of numpy (pt_batch, en_batch) pairs over the training split.

    Each batch element is a UTF-8 encoded byte string (callers ``.decode('utf-8')``).

    Args:
        batch_size: number of sentence pairs per batch.
        shuffle_buffer: if a positive int, shuffle examples with a buffer of that
            size, reshuffling on every pass over the data. ``None`` (default)
            keeps the original deterministic order.
    """
    dataset = train_examples
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    return tfds.as_numpy(dataset.batch(batch_size))


def val_dataloader(batch_size):
    """Return an iterable of numpy (pt_batch, en_batch) pairs over the validation split, in order."""
    return tfds.as_numpy(val_examples.batch(batch_size))

In [None]:
# Peek at the first training pair to see the raw byte-string format.
dataloader = train_dataloader(1)
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    pt_examples, en_examples = first_batch
    print(pt_examples, en_examples, sep="\n")

[b'e quando melhoramos a procura , tiramos a \xc3\xbanica vantagem da impress\xc3\xa3o , que \xc3\xa9 a serendipidade .']
[b'and when you improve searchability , you actually take away the one advantage of print , which is serendipity .']


### Downloading tokenizers

In [None]:
from transformers import AutoTokenizer

# English (target) tokenizer: encodings start with [CLS] and end with [SEP], and
# padding uses [PAD] (see the decoded sample printed below).
tokenizer_en = AutoTokenizer.from_pretrained("bert-base-uncased", bos_token="<BOS>")
# Portuguese (source) tokenizer; presumably GPT-2 style, judging by the model name.
# NOTE(review): neither bos_token="<BOS>" override is used anywhere visible in
# this notebook -- confirm whether they are needed.
tokenizer_pt = AutoTokenizer.from_pretrained("pierreguillou/gpt2-small-portuguese", bos_token="<BOS>")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def tokenize_pt(text):
    """Encode a Portuguese sentence into exactly MAX_SEQ_LEN token ids.

    Short inputs are padded and long inputs are truncated by the tokenizer.
    Without ``truncation=True``, long inputs came back longer than MAX_SEQ_LEN,
    and slicing them afterwards could drop a trailing special token.
    """
    return tokenizer_pt.encode(text, padding="max_length", truncation=True, max_length=MAX_SEQ_LEN)


def tokenize_en(text):
    """Encode an English sentence into exactly MAX_SEQ_LEN token ids ([CLS] ... [SEP] [PAD]...).

    Truncation is done by the tokenizer, so the closing [SEP] is kept even for
    long sentences (slicing the ids afterwards would cut it off).
    """
    return tokenizer_en.encode(text, padding="max_length", truncation=True, max_length=MAX_SEQ_LEN)

In [None]:
# Round-trip the first 21 training pairs through both tokenizers to inspect the
# sub-word splits and the padding.
dataloader = train_dataloader(1)
for count, (pt_examples, en_examples) in enumerate(dataloader, start=1):
    en_ids = tokenize_en(en_examples[0].decode('utf-8'))
    print(tokenizer_en.batch_decode(en_ids))
    pt_ids = tokenize_pt(pt_examples[0].decode('utf-8'))
    print(tokenizer_pt.batch_decode(pt_ids))
    if count > 20:
        break

['[CLS]', 'and', 'when', 'you', 'improve', 'search', '##ability', ',', 'you', 'actually', 'take', 'away', 'the', 'one', 'advantage', 'of', 'print', ',', 'which', 'is', 'ser', '##end', '##ip', '##ity', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[

### Model assembly

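Inputs: token ids of shape `batch_size x seq_len`.

They pass through an embedding layer of size `d_model` (512 in the paper; 128 in the model trained below), giving `batch_size x seq_len x d_model`.

Positional encodings are added to the embeddings.

The result passes through multi-head attention: the embeddings are linearly projected to queries, keys and values, split into N heads, and scaled dot-product attention is computed for every head.

Finally, the result passes through a position-wise feed-forward network.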

In [None]:
import torch
import torch.nn as nn
import math

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PositionalEncoding(nn.Module):
    """Adds the fixed sinusoidal position signal from "Attention Is All You Need".

        PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))

    Dropout is applied to the *sum* of the input embeddings and the encodings,
    as in the paper. The previous version dropped out (and rescaled) only the
    positional encodings, so the token embeddings never saw dropout.

    Args:
        embed_dim: model width (odd values are supported).
        max_len: longest sequence this module can encode.
        dropout: dropout probability (default 0.1, as before).
    """

    def __init__(self, embed_dim, max_len, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.d1 = nn.Dropout(p=dropout)
        # Shape (1, max_len, embed_dim) so it broadcasts over the batch axis;
        # non-persistent, so it is rebuilt rather than stored in checkpoints.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Add positional encodings to ``x`` (b, seq_len, embed_dim) and apply dropout."""
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")
        return self.d1(x + self.pe[:, :seq_len])

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_heads``
    heads of width ``embed_dim // n_heads``, attended independently and
    concatenated back to ``embed_dim``.

    Args:
        embed_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        out_proj: if True, apply the paper's final output projection W^O after
            concatenating the heads. Defaults to False so that existing
            checkpoints (which have no such layer) still load.
    """

    def __init__(self, embed_dim, n_heads, out_proj=False):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.w_q = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_k = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_v = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_o = nn.Linear(self.embed_dim, self.embed_dim) if out_proj else None

    def expand_mask(self, mask):
        """Broadcast ``mask`` to 4-D (batch, heads, q_len, k_len).

        A 2-D mask (q_len, k_len) gains batch and head axes, a 3-D mask
        (batch, q_len, k_len) gains a head axis, and a 4-D mask is returned as-is.
        """
        assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        while mask.ndim < 4:
            mask = mask.unsqueeze(0)
        return mask

    def split(self, vec):
        """Reshape (b, seq_len, embed_dim) -> (b, n_heads, seq_len, head_dim)."""
        b, seq_len, _ = vec.shape
        return vec.reshape(b, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, q=None, k=None, v=None, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, (b, q_len, embed_dim). Required.
            k: keys, (b, k_len, embed_dim); defaults to ``q`` (self-attention).
            v: values, (b, k_len, embed_dim); defaults to ``k``.
            mask: optional, broadcastable to (b, n_heads, q_len, k_len). Positions
                where it is 0/False are excluded from the softmax.

        Returns:
            Tensor of shape (b, q_len, embed_dim).
        """
        if q is None:
            raise ValueError("q (queries) is required")
        if k is None:
            k = q
        if v is None:
            v = k
        q, k, v = self.split(self.w_q(q)), self.split(self.w_k(k)), self.split(self.w_v(v))
        b, _, q_len, head_dim = q.shape
        scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores.masked_fill(self.expand_mask(mask) == 0, -1e9)
        attn = torch.softmax(scores, dim=-1) @ v  # (b, n_heads, q_len, head_dim)
        out = attn.permute(0, 2, 1, 3).reshape(b, q_len, self.embed_dim)
        if self.w_o is not None:
            out = self.w_o(out)
        return out


class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention, then a position-wise FFN.

    Each sub-layer uses the paper's post-norm residual form
    ``x = LayerNorm(x + Dropout(Sublayer(x)))``. The previous version computed
    ``x + LayerNorm(Dropout(Sublayer(x)))``, which never normalizes the
    residual stream itself, so its scale can grow with depth. Parameter names,
    and therefore state_dict keys, are unchanged.

    Args:
        embed_dim: model width.
        num_heads: attention heads.
        ff_mult: hidden width of the FFN as a multiple of embed_dim (default 4).
        dropout: dropout probability on each sub-layer output (default 0.1).
    """

    def __init__(self, embed_dim, num_heads, ff_mult=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.self_attn = MultiHeadAttention(embed_dim=embed_dim, n_heads=num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_mult * embed_dim),
            nn.ReLU(),
            nn.Linear(ff_mult * embed_dim, embed_dim),
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        """x: (b, seq_len, embed_dim); ``mask`` is passed to the self-attention."""
        x = self.ln1(x + self.d1(self.self_attn(q=x, k=x, v=x, mask=mask)))
        x = self.ln2(x + self.d2(self.ff(x)))
        return x


class Encoder(nn.Module):
    """Source-side stack: token embedding, positional encoding, then ``num_layers`` EncoderBlocks.

    Args:
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of EncoderBlocks.
        vocab_size: source vocabulary size; defaults to ``tokenizer_pt.vocab_size``.
        max_len: longest supported input; defaults to ``MAX_SEQ_LEN``.
    """

    def __init__(self, embed_dim, num_heads, num_layers, vocab_size=None, max_len=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = tokenizer_pt.vocab_size
        if max_len is None:
            max_len = MAX_SEQ_LEN
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([EncoderBlock(embed_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        """Encode token ids ``x`` (b, seq_len) into (b, seq_len, embed_dim) states.

        ``mask`` is forwarded unchanged to every layer's self-attention.
        """
        x = self.pos_enc(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

class DecoderBlock(nn.Module):
    """One Transformer decoder layer: masked self-attention, encoder-decoder
    attention, then a position-wise FFN.

    Each sub-layer uses the paper's post-norm residual form
    ``x = LayerNorm(x + Dropout(Sublayer(x)))``. The previous version never
    normalized the residual stream itself. Parameter names are unchanged.

    Args:
        embed_dim: model width.
        num_heads: attention heads.
        ff_mult: hidden width of the FFN as a multiple of embed_dim (default 4).
        dropout: dropout probability on each sub-layer output (default 0.1).
    """

    def __init__(self, embed_dim, num_heads, ff_mult=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.masked_self_attn = MultiHeadAttention(embed_dim=embed_dim, n_heads=num_heads)
        self.enc_dec_attn = MultiHeadAttention(embed_dim=embed_dim, n_heads=num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_mult * embed_dim),
            nn.ReLU(),
            nn.Linear(ff_mult * embed_dim, embed_dim),
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)
        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)
        self.d3 = nn.Dropout(p=dropout)

    def create_target_mask(self, target_input):
        """Causal mask of shape (1, seq_len, seq_len): 1 where query i may attend key j (j <= i).

        The diagonal is included so every position can attend to itself. The
        previous version subtracted the diagonal, which left query row 0 fully
        masked. Softmax over a row of identical -1e9 scores is uniform, so
        position 0 attended to *every* position, future tokens included. Later
        layers then leaked that future information to all positions during
        training, which cannot happen during step-by-step decoding.
        """
        seq_len = target_input.shape[1]
        mask = torch.tril(torch.ones(seq_len, seq_len, device=target_input.device))
        return mask.unsqueeze(0)

    def forward(self, encoder_output, target_input, mask=None, src_mask=None):
        """Run one decoder layer.

        Args:
            encoder_output: (b, src_len, embed_dim) encoder states.
            target_input: (b, tgt_len, embed_dim) target embeddings or the
                previous layer's output.
            mask: optional target key-padding mask, broadcastable to
                (b, 1, tgt_len, tgt_len); multiplied into the causal mask.
            src_mask: optional source key-padding mask, broadcastable to
                (b, 1, tgt_len, src_len). It is used in encoder-decoder
                attention so padded source positions are ignored (previously
                cross-attention attended to source padding).

        Returns:
            (b, tgt_len, embed_dim) tensor.
        """
        target_mask = self.create_target_mask(target_input)
        if mask is not None:
            target_mask = target_mask * mask
        self_attn = self.masked_self_attn(q=target_input, k=target_input, v=target_input, mask=target_mask)
        x = self.ln1(target_input + self.d1(self_attn))
        cross_attn = self.enc_dec_attn(q=x, k=encoder_output, v=encoder_output, mask=src_mask)
        x = self.ln2(x + self.d2(cross_attn))
        x = self.ln3(x + self.d3(self.ff(x)))
        return x


class Decoder(nn.Module):
    """Target-side stack: token embedding, positional encoding, then ``num_layers`` DecoderBlocks.

    Args:
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of DecoderBlocks.
        vocab_size: target vocabulary size; defaults to ``tokenizer_en.vocab_size``.
        max_len: longest supported target; defaults to ``MAX_SEQ_LEN``.
    """

    def __init__(self, embed_dim, num_heads, num_layers, vocab_size=None, max_len=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = tokenizer_en.vocab_size
        if max_len is None:
            max_len = MAX_SEQ_LEN
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([DecoderBlock(embed_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x, encoder_output, mask=None, src_mask=None):
        """Decode target ids ``x`` (b, tgt_len) against ``encoder_output``.

        ``mask`` (target padding) and ``src_mask`` (source padding) are passed to
        every DecoderBlock. Returns (b, tgt_len, embed_dim).
        """
        x = self.pos_enc(self.embedding(x))
        for layer in self.layers:
            x = layer(encoder_output=encoder_output, target_input=x, mask=mask, src_mask=src_mask)
        return x


class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping Portuguese token ids to English next-token logits.

    Args:
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: depth of both the encoder and the decoder.
    """

    def __init__(self, embed_dim, num_heads, num_layers):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.in_pad_val = tokenizer_pt.pad_token_id
        self.out_pad_val = tokenizer_en.pad_token_id
        self.encoder = Encoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)
        self.decoder = Decoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)
        self.out_linear = nn.Linear(embed_dim, tokenizer_en.vocab_size)
        # Not used by forward(), which returns raw logits (CrossEntropyLoss applies
        # log-softmax itself); kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_seq, target_seq):
        """Compute decoder logits with teacher forcing.

        Args:
            input_seq: (b, src_len) source token ids.
            target_seq: (b, tgt_len) target token ids fed to the decoder.

        Returns:
            (b, tgt_len, vocab_size) unnormalized logits.
        """
        # (b, 1, 1, len) key-padding masks: broadcast over heads and query positions.
        input_mask = (input_seq != self.in_pad_val).unsqueeze(1).unsqueeze(1)
        target_mask = (target_seq != self.out_pad_val).unsqueeze(1).unsqueeze(1)
        encoder_output = self.encoder(input_seq, mask=input_mask)
        decoder_output = self.decoder(
            x=target_seq, encoder_output=encoder_output, mask=target_mask, src_mask=input_mask
        )
        return self.out_linear(decoder_output)


In [None]:
def translate(pt_text, model, max_len=None):
    """Greedily translate one Portuguese sentence into English.

    Decoding starts from [CLS] and appends the arg-max next token one step at a
    time. It stops when [SEP] is produced or ``max_len`` tokens exist.
    Dropout is disabled and no gradients are tracked while decoding. The
    model's previous train/eval mode is restored afterwards.

    Args:
        pt_text: Portuguese source sentence.
        model: a ``Transformer``.
        max_len: maximum target length in tokens, including [CLS]; defaults to
            MAX_SEQ_LEN and must not exceed the model's positional table.

    Returns:
        The decoded English string (with [CLS]/[SEP] markers, as before, but
        without trailing [PAD] filler).
    """
    if max_len is None:
        max_len = MAX_SEQ_LEN
    model_device = next(model.parameters()).device
    start_id = tokenizer_en.cls_token_id
    # The English tokenizer is BERT, whose end-of-sequence marker is [SEP]; it
    # has no EOS token, so the old `eos_token_id` check never fired and decoding
    # always ran to the full length.
    end_id = tokenizer_en.sep_token_id
    if end_id is None:
        end_id = tokenizer_en.eos_token_id

    src_ids = tokenize_pt(pt_text)[:MAX_SEQ_LEN]
    src = torch.tensor([src_ids], dtype=torch.long, device=model_device)
    generated = [start_id]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            while len(generated) < max_len:
                tgt = torch.tensor([generated], dtype=torch.long, device=model_device)
                # Logits at the last position predict the next token; argmax of
                # logits equals argmax of softmax, so the softmax is skipped.
                next_token = model(src, tgt)[0, -1].argmax().item()
                generated.append(next_token)
                if next_token == end_id:
                    break
    finally:
        model.train(was_training)
    return tokenizer_en.decode(generated)


In [None]:
def validate(model, n_examples=6):
    """Print teacher-forced predictions next to the references for a few validation pairs.

    For each of the first ``n_examples`` validation sentences, the full
    reference target is fed to the decoder (teacher forcing) and the arg-max
    token at every position is decoded. Prediction k is the model's guess for
    reference token k+1, so the two printed lines are offset by one position.
    Runs without gradients and in eval mode. The model's previous train/eval
    mode is restored afterwards.

    Args:
        model: a ``Transformer``.
        n_examples: number of validation pairs to print (default 6, as before).
    """
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for count, (pt_batch, en_batch) in enumerate(val_dataloader(1)):
                if count >= n_examples:
                    break
                pt_tokens = np.array([tokenize_pt(pt.decode('utf-8'))[:MAX_SEQ_LEN] for pt in pt_batch])
                en_tokens = np.array([tokenize_en(en.decode('utf-8'))[:MAX_SEQ_LEN] for en in en_batch])
                out_logits = model(
                    torch.from_numpy(pt_tokens).to(model_device),
                    torch.from_numpy(en_tokens).to(model_device),
                )
                # argmax of the logits equals argmax of their softmax.
                token_ids = out_logits.argmax(dim=-1).cpu()
                print(tokenizer_en.batch_decode(en_tokens))
                print(tokenizer_en.batch_decode(token_ids))
    finally:
        model.train(was_training)


In [None]:
class CustomScheduler():
    """Learning-rate schedule from "Attention Is All You Need", wrapping an optimizer.

        lr(step) = factor * d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)

    This is a linear warm-up over ``n_warmup_steps`` steps, followed by
    inverse-square-root decay. ``step()`` increments the step counter, writes
    the new lr into every param group, then calls ``optimizer.step()``.

    Args:
        optimizer: the wrapped torch optimizer.
        d_model: model width used by the schedule.
        n_warmup_steps: length of the linear warm-up.
        factor: optional multiplier on the whole schedule (1.0 reproduces the paper).
    """

    def __init__(self, optimizer, d_model, n_warmup_steps, factor=1.0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.factor = factor
        self.n_steps = 0

    def get_lr(self):
        """Learning rate for the current step count, or 0.0 before the first step.

        Previously this raised ZeroDivisionError when called before ``step()``,
        because ``0 ** -0.5`` is undefined.
        """
        n_steps = self.n_steps
        if n_steps == 0:
            return 0.0
        return self.factor * (self.d_model ** -0.5) * min(
            n_steps ** -0.5, n_steps * self.n_warmup_steps ** -1.5
        )

    def zero_grad(self):
        """Clear the gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()

    def step(self):
        """Advance one step: set the new lr on every param group, then step the optimizer."""
        self.n_steps += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.optimizer.step()

In [None]:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 0
EMBED_DIM, NUM_HEADS, NUM_LAYERS = 128, 8, 4
N_WARMUP_STEPS = 4000
BATCH_SIZE = 64
NUM_EPOCHS = 100
EVAL_EVERY = 20  # epochs between validation printouts and checkpoints
CHECKPOINT_PATH = "model.pt"
SAMPLE_SENTENCE = "A grande e forte máquina estava muito azul e lenta."

torch.manual_seed(SEED)
model = Transformer(EMBED_DIM, NUM_HEADS, NUM_LAYERS).to(device)

# The lr given here is a placeholder: CustomScheduler overwrites it on every step.
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomScheduler(optimizer, EMBED_DIM, N_WARMUP_STEPS)

# Loss is summed (not averaged) over all non-padding target tokens in a batch.
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer_en.pad_token_id, reduction="sum")


def encode_batch(pt_batch, en_batch):
    """Tokenize a batch of byte-string pairs into (pt_ids, en_ids, en_labels) tensors on `device`.

    en_labels are the next-token targets: en_ids shifted one position left,
    with the last column set to the [PAD] id so the loss ignores it.
    """
    pt_ids = np.array([tokenize_pt(pt.decode('utf-8'))[:MAX_SEQ_LEN] for pt in pt_batch])
    en_ids = np.array([tokenize_en(en.decode('utf-8'))[:MAX_SEQ_LEN] for en in en_batch])
    en_labels = np.full_like(en_ids, tokenizer_en.pad_token_id)
    en_labels[:, :-1] = en_ids[:, 1:]
    return (
        torch.from_numpy(pt_ids).to(device),
        torch.from_numpy(en_ids).to(device),
        torch.from_numpy(en_labels).to(device),
    )


dataloader = train_dataloader(BATCH_SIZE)
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    for pt_batch, en_batch in dataloader:
        pt_ids, en_ids, en_labels = encode_batch(pt_batch, en_batch)

        scheduler.zero_grad()
        out_logits = model(pt_ids, en_ids)  # (batch, seq_len, vocab_size)
        loss = criterion(out_logits.reshape(-1, out_logits.size(-1)), en_labels.reshape(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        scheduler.step()
        epoch_loss += loss.item()

    # Also evaluate and checkpoint after the final epoch; previously the last
    # NUM_EPOCHS % EVAL_EVERY - 1 epochs of training (19 here) were never saved.
    if epoch % EVAL_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print(f"Epoch {epoch+1} loss: {epoch_loss}")
        with torch.no_grad():
            model.eval()
            validate(model)
            print("______________")
            print(translate(SAMPLE_SENTENCE, model))
            print("--------------")
        torch.save(model.state_dict(), CHECKPOINT_PATH)

Epoch 1 loss: 7193931.670654297
['[CLS] did they eat fish and chips? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
['and? are???? [SEP]????????????????????????????????????????????????????????????????????????????????? [SEP] [SEP]?????????????????????????????????????']
['[CLS] i was always worried about being caught and sent

In [None]:
# Greedy translation of a sample sentence with the trained model.
sample_sentence = "A grande e forte máquina estava muito azul e lenta."
print(translate(sample_sentence, model))

[CLS] and and strong and hotels, quite black and slow. [SEP] was too. [SEP] and dangerous. [SEP] does was even and done. [SEP], right? and very, and it's loud. [SEP] traditionally difference. [SEP]quicy. [SEP] course. [SEP] a great. [SEP]x. [SEP] much know. [SEP]'' [SEP] correctly. [SEP] duration. [SEP] awful can be yeah. [SEP]'[SEP] much - year. [SEP] far scene? [SEP] much so [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP]'[SEP]x. [SEP] [SEP] [SEP] [SEP] naturally. [SEP]king. [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] can be done. [SEP] [SEP] [SEP]
