# Attention is All You Need

This notebook implements the transformer model introduced in the paper "Attention Is All You Need" [1] and trains it on a small dataset of Pink Floyd lyrics [2].

### References

1. A. Vaswani et al., "Attention Is All You Need", arXiv, Dec. 05, 2017. doi: 10.48550/arXiv.1706.03762.
2. J. Robson, "Pink Floyd Lyrics", retrieved from [Kaggle](https://www.kaggle.com/datasets/joaorobson/pink-floyd-lyrics/code).
3. R. Sennrich, B. Haddow, A. Birch, "Neural Machine Translation of Rare Words with Subword Units", 2016. doi: 10.48550/arXiv.1508.07909.

## Dataset and Tokenization

We start by preprocessing the dataset and training a tokenizer.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
from typing import List

In [3]:
data = pd.read_csv("./assets/data/pink_floyd_lyrics.csv")
data.head()

Unnamed: 0,album,song_title,year,lyrics
0,The Piper at the Gates of Dawn,Astronomy Domine,1967-08-05,"""Moon in both [houses]...""...Scorpio, [Arabian..."
1,The Piper at the Gates of Dawn,Lucifer Sam,1967-08-05,"Lucifer Sam, siam cat\nAlways sitting by your ..."
2,The Piper at the Gates of Dawn,Matilda Mother,1967-08-05,There was a king who ruled the land\nHis Majes...
3,The Piper at the Gates of Dawn,Flaming,1967-08-05,Alone in the clouds all blue\nLying on an eide...
4,The Piper at the Gates of Dawn,Pow R. Toc H.,1967-08-05,TCH TCH\nAHH (AHH)\nTCH TCH\nAHH AHH\nDoi doi\...


In [4]:
data = data[~data["album"].isin(["The Piper at the Gates of Dawn", "A Saucerful of Secrets"])]

### Dataset Cleaning

This version of the dataset is quite noisy and contains unformatted lyrics (the [Pink Floyd dataset on Hugging Face](https://huggingface.co/datasets/huggingartists/pink-floyd), for example, is in a better state). Nevertheless, to demonstrate a reasonably realistic preprocessing flow, we stick to this particular version.

In [5]:
df = data.drop(columns=["album", "song_title", "year"])
df = df.dropna()

df = df.replace(r"\((.*?)\),? ?", "", regex=True)   # remove round brackets and their content
df = df.replace(r"\[(.*?)\],? ?", "", regex=True)   # remove square brackets and their content
df = df.replace(r'["“”…]', "", regex=True)          # remove quotation marks and the ellipsis character
df = df.replace(r"\.{3,}", "...", regex=True)       # collapse runs of three or more dots into "..."
df = df.replace(r"(\*.*\*)", "", regex=True)        # remove sound effects between *
df = df.replace(r"[:\-.!?]", " ", regex=True)       # replace :, -, ., !, ? with a space
df = df.replace(r"\\ n", "\n", regex=True)          # repair ill-formatted newlines ("\ n")
df = df.replace(r"\\", "", regex=True)              # remove stray backslashes
df = df.replace(r"\n+", "\n", regex=True)           # collapse multiple newlines
df = df.replace(r" +", " ", regex=True)             # collapse multiple spaces
df = df.replace(r"\n ", "\n", regex=True)           # remove leading spaces after a newline

df["lyrics"] = df["lyrics"].str.lower()                          # lowercase
df["lyrics"] = df["lyrics"].str.strip("-. ")                     # strip leading/trailing hyphens, dots, and spaces
df["lyrics"] = df["lyrics"].str.replace(r"\n", " ", regex=True)  # flatten each song onto a single line

lyrics = "\n".join(df["lyrics"])                                 # one song per line; the newline separates songs

with open("./assets/data/pink_floyd_lyrics.txt", "w", encoding="utf-8") as f:
    f.write(lyrics)
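
To see what the substitutions above do to a single line, here is a small sketch that applies a subset of them with the standard `re` module. The input string is made up for illustration and is not taken from the dataset, and `clean_line` is a hypothetical helper, not part of the notebook's pipeline.

```python
import re

# A subset of the cleaning rules used above, applied in the same order.
RULES = [
    (r"\((.*?)\),? ?", ""),   # round brackets and their content
    (r"\[(.*?)\],? ?", ""),   # square brackets and their content
    (r'["“”…]', ""),          # quotation marks and the ellipsis character
    (r"(\*.*\*)", ""),        # sound effects between *
    (r"[:\-.!?]", " "),       # punctuation replaced by a space
    (r" +", " "),             # collapse multiple spaces
]


def clean_line(text: str) -> str:
    """Apply the rules one after another, then lowercase and strip."""
    for pattern, replacement in RULES:
        text = re.sub(pattern, replacement, text)
    return text.lower().strip("-. ")


# Hypothetical noisy line, made up for this example.
raw = '[Chorus] Shine on, you "crazy" diamond! (echo) *wind*'
cleaned = clean_line(raw)
print(cleaned)  # shine on, you crazy diamond
```

Note that commas survive the cleaning, since none of the rules target them.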

### Tokenization

Next, we tokenize the cleaned lyrics. Following the original paper, we use byte-pair encoding (BPE) [3].
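
The core idea of BPE is to start from single characters and repeatedly merge the most frequent adjacent pair of symbols into a new symbol. The sketch below illustrates that loop on a toy corpus in plain Python. It is a simplified illustration under our own assumptions (whitespace-split words, ties broken by first occurrence); the function names are hypothetical and it does not reproduce the internals of the `tokenizers` library used below.

```python
from collections import Counter


def most_frequent_pair(words):
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(words, pair):
    """Replace every occurrence of `pair` by a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


def learn_bpe(corpus, num_merges):
    """Learn up to `num_merges` merge rules from a whitespace-split corpus."""
    words = Counter(tuple(word) for word in corpus.split())
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        merges.append(pair)
        words = merge_pair(words, pair)
    return merges, words


toy_merges, toy_words = learn_bpe("shine on you crazy diamond shine on", num_merges=3)
print(toy_merges)
```

On this toy corpus the first merge is `("o", "n")`, because "on" appears both as a word and inside "diamond". Each learned merge adds one entry to the vocabulary, which is how the vocabulary size is controlled.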

In [6]:
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
from tokenizers.normalizers import NFD, Lowercase, Strip, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

In [7]:
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = normalizers.Sequence([NFD(), StripAccents(), Lowercase(), Strip()])
tokenizer.pre_tokenizer = Whitespace()

In [8]:
trainer = BpeTrainer(special_tokens=["[BOS]"], show_progress=False)
tokenizer.train_from_iterator([lyrics], trainer=trainer)
print(f"Vocabulary Size: {tokenizer.get_vocab_size()}")

Vocabulary Size: 4102


## Training

In this section, we train a decoder-only transformer model to predict the next token on our dataset.
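
Next-token prediction means that the targets are the inputs shifted by one position, and the loss is the negative log-probability the model assigns to the correct next token. The following sketch shows both ideas in plain Python. The token ids are made up, the helper names are hypothetical, and the vocabulary size of 4102 is the one reported by the tokenizer above.

```python
import math


def next_token_pairs(token_ids):
    """Inputs are all tokens but the last; targets are the same tokens shifted by one."""
    return token_ids[:-1], token_ids[1:]


def cross_entropy(probs, target):
    """Negative log-probability assigned to the correct next token."""
    return -math.log(probs[target])


# Hypothetical token ids, chosen only for illustration.
inputs, targets = next_token_pairs([5, 9, 2, 7])
print(inputs, targets)  # [5, 9, 2] [9, 2, 7]

# Loss of a model that spreads its probability uniformly over the vocabulary.
vocab_size = 4102
uniform = [1.0 / vocab_size] * vocab_size
baseline = cross_entropy(uniform, targets[0])
print(f"{baseline:.3f}")  # 8.319
```

The uniform baseline of ln(4102) ≈ 8.319 is a useful reference point when reading the training losses below: a model only carries information about the lyrics once its loss drops under this value.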

In [55]:
import torch
import torch.nn.functional as F

In [56]:
from microai.models.transformer import TransformerConfig, Transformer

In [135]:
config = TransformerConfig(
    vocab_size=tokenizer.get_vocab_size(),
    d_model=64,
    num_heads=8,
    context_size=128,
    dropout=0.2,
    decoder_layers=4,
)

In [136]:
model = Transformer(config)

In [137]:
lr = 3e-4
batch_size = 32
epochs = 2000
eval_freq = 50
weight_decay = 1e-2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [138]:
params = {k: v for k, v in model.named_parameters() if v.requires_grad}

params_decay = [v for _, v in params.items() if v.dim() >= 2]
params_no_decay = [v for _, v in params.items() if v.dim() < 2]

optimizer = torch.optim.Adam([
    {"params": params_decay, "weight_decay": weight_decay},
    {"params": params_no_decay, "weight_decay": 0.0}
], lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

In [110]:
def get_batch(tokens: List[List[int]], batch_size: int):
    pad_id = tokenizer.token_to_id("[BOS]")

    def _pad(sequence: List[int], size: int):
        # left-pad with [BOS] so that every sequence in the batch has the same length
        return [pad_id] * (size - len(sequence)) + sequence

    # sample batch_size distinct sequences at random
    ids = torch.randperm(len(tokens))[:batch_size].tolist()
    batch = [tokens[i] for i in ids]
    max_size = max(len(seq) for seq in batch)
    batch = torch.tensor([_pad(seq, max_size) for seq in batch])

    # inputs are all tokens but the last; targets are the inputs shifted by one
    return batch[:, :-1], batch[:, 1:]

In [88]:
@torch.no_grad()
def estimate_loss(tokens: List[List[int]], batch_size: int = 5, num_batches: int = 25):
    losses = []
    model.eval()  # disable dropout while evaluating

    for _ in range(num_batches):
        x, y = get_batch(tokens, batch_size=batch_size)
        x, y = x.to(device), y.to(device)

        y_pred = model(x)
        loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), y.view(-1))
        losses.append(loss.item())

    model.train()  # restore training mode
    return np.mean(losses)

In [20]:
def tokenize_data(data: pd.DataFrame, context_size: int):
    items = []

    for _, row in data.iterrows():
        tokens = tokenizer.encode(row["lyrics"]).ids
        # split each song into consecutive chunks of at most context_size tokens;
        # stepping by context_size avoids an empty trailing chunk when the length divides evenly
        for start in range(0, len(tokens), context_size):
            items.append(tokens[start: start + context_size])

    return items

In [21]:
tokens = tokenize_data(df, config.context_size)

train_chunk = 0.9
# use a set so that the membership tests below are O(1)
train_ids = set(torch.randperm(len(tokens))[:int(len(tokens) * train_chunk)].tolist())

train_tokens = [seq for i, seq in enumerate(tokens) if i in train_ids]
test_tokens = [seq for i, seq in enumerate(tokens) if i not in train_ids]

In [139]:
# note: each "epoch" here is a single optimization step on one random batch
for epoch in range(1, epochs + 1):
    x, y = get_batch(train_tokens, batch_size=batch_size)
    x, y = x.to(device), y.to(device)

    y_pred = model(x)
    loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % eval_freq == 0 or epoch == 1:
        train_loss, test_loss = estimate_loss(train_tokens), estimate_loss(test_tokens)
        print(f"Epoch: {epoch}, Train Loss: {train_loss:.3f}, Test Loss: {test_loss:.3f}")

Epoch: 1, Train Loss: 10.188, Test Loss: 9.912
Epoch: 50, Train Loss: 5.298, Test Loss: 4.907
Epoch: 100, Train Loss: 5.022, Test Loss: 4.531
Epoch: 150, Train Loss: 4.546, Test Loss: 4.525
Epoch: 200, Train Loss: 4.376, Test Loss: 3.997
Epoch: 250, Train Loss: 4.408, Test Loss: 4.140
Epoch: 300, Train Loss: 3.971, Test Loss: 3.803
Epoch: 350, Train Loss: 3.975, Test Loss: 3.813
Epoch: 400, Train Loss: 4.006, Test Loss: 3.732
Epoch: 450, Train Loss: 3.625, Test Loss: 3.910
Epoch: 500, Train Loss: 3.893, Test Loss: 3.877
Epoch: 550, Train Loss: 3.605, Test Loss: 3.954
Epoch: 600, Train Loss: 3.722, Test Loss: 3.903
Epoch: 650, Train Loss: 3.586, Test Loss: 3.800
Epoch: 700, Train Loss: 3.834, Test Loss: 3.810
Epoch: 750, Train Loss: 3.367, Test Loss: 3.959
Epoch: 800, Train Loss: 3.513, Test Loss: 3.693
Epoch: 850, Train Loss: 3.435, Test Loss: 4.001
Epoch: 900, Train Loss: 3.452, Test Loss: 3.803
Epoch: 950, Train Loss: 3.655, Test Loss: 3.967
Epoch: 1000, Train Loss: 3.514, Test Loss:

In [23]:
@torch.no_grad()
def generate(model: Transformer, prompt: str, context_size: int = 8, max_length: int = 1000):
    context = torch.tensor(tokenizer.encode(prompt).ids, device=device)
    model.eval()

    # sample one token at a time until the sequence reaches max_length tokens
    while context.size(0) < max_length:
        logits = model(context[-context_size:].unsqueeze(0))
        probs = F.softmax(logits[0, -1, :], dim=-1)  # distribution over the next token
        token = torch.multinomial(probs, num_samples=1)
        context = torch.cat((context, token), dim=0)

    model.train()
    return tokenizer.decode(context.tolist())

In [24]:
print(generate(model, "shine on ", context_size=config.context_size, max_length=15))
print(generate(model, "time ", context_size=config.context_size, max_length=15))
print(generate(model, "money ", context_size=config.context_size, max_length=15))

shine on ground to make the weak in the animals become the lived now at
time has the same we lie out is who for all anced ’ s are
money you want and high is in you feel narrow hey you ’ ll to
