# Attention is All You Need

This notebook implements the transformer model introduced in the paper "Attention Is All You Need" [1]. The model is trained on a small dataset of Pink Floyd lyrics [2]. The first version is a decoder-only model trained on a language modeling task. The second version is a full encoder-decoder architecture, where the encoder processes song titles and the decoder is trained to predict the lyrics.

### References

1. A. Vaswani et al., “Attention Is All You Need.” arXiv, Dec. 05, 2017. doi: 10.48550/arXiv.1706.03762.
2. J. Robson, "Pink Floyd Lyrics", retrieved from [Kaggle](https://www.kaggle.com/datasets/joaorobson/pink-floyd-lyrics/code).
3. R. Sennrich, B. Haddow, A. Birch, "Neural Machine Translation of Rare Words with Subword Units", 2016. doi: 10.48550/arXiv.1508.07909

## Dataset and Tokenization

We start by preprocessing the dataset and training a tokenizer.

In [1]:
import numpy as np
import pandas as pd

In [12]:
from typing import List, Tuple

In [3]:
data = pd.read_csv("./assets/data/pink_floyd_lyrics.csv")
data.head()

| Unnamed: 0 | album | song_title | year | lyrics |
| --- | --- | --- | --- | --- |
| 0 | The Piper at the Gates of Dawn | Astronomy Domine | 1967-08-05 | "Moon in both [houses]..."...Scorpio, [Arabian... |
| 1 | The Piper at the Gates of Dawn | Lucifer Sam | 1967-08-05 | Lucifer Sam, siam cat\nAlways sitting by your ... |
| 2 | The Piper at the Gates of Dawn | Matilda Mother | 1967-08-05 | There was a king who ruled the land\nHis Majes... |
| 3 | The Piper at the Gates of Dawn | Flaming | 1967-08-05 | Alone in the clouds all blue\nLying on an eide... |
| 4 | The Piper at the Gates of Dawn | Pow R. Toc H. | 1967-08-05 | TCH TCH\nAHH (AHH)\nTCH TCH\nAHH AHH\nDoi doi\... |


We additionally exclude the two albums involving Syd Barrett to obtain a more coherent corpus that better reflects Pink Floyd's later style.

In [4]:
data = data[~data["album"].isin(["The Piper at the Gates of Dawn", "A Saucerful of Secrets"])]

### Dataset Cleaning

This version of the dataset is quite noisy and contains many unformatted lyrics (see e.g. the [Pink Floyd dataset on Hugging Face](https://huggingface.co/datasets/huggingartists/pink-floyd) for a cleaner version). To compensate, we next clean and normalize the data.

In [171]:
df = data.drop(columns=["album", "year"])
df = df.dropna()

df = df.replace(r"\((.*?)\),? ?", "", regex=True)  # remove round brackets and content
df = df.replace(r"\[(.*?)\],? ?", "", regex=True)  # remove square brackets and content
df = df.replace("[\"“”…]", "", regex=True)         # remove quotation marks and ellipsis characters
df = df.replace(r"\.{3,}", "...", regex=True)      # replace runs of three or more dots with three dots
df = df.replace(r"(\*.*\*)", "", regex=True)       # remove sound effects between *
df = df.replace(r"[\:\-\.\!\?]", " ", regex=True)  # replace :, -, ., !, ? with a space
df = df.replace("\\\\ n", "\n", regex=True)        # remove ill-formatted newlines
df = df.replace("\\\\", "", regex=True)            # remove \
df = df.replace("(\\n)+", "\\n", regex=True)       # remove multiple newlines
df = df.replace(" +", " ", regex=True)             # remove multiple spaces
df = df.replace("\n ", "\n", regex=True)           # remove leading spaces after newline

df["lyrics"] = df["lyrics"].str.lower()            # lowercase
df["lyrics"] = df["lyrics"].str.strip("-. ")       # strip leading/trailing dashes, dots, and spaces
df["lyrics"] = df["lyrics"].str.replace("\\n", " ", regex=True)

df["song_title"] = df["song_title"].str.lower() 
df["song_title"] = df["song_title"].str.strip("-. ")

lyrics = [l for l in df.lyrics]
lyrics = "[BOS]".join(lyrics)                      # add BOS token between songs
titles = "[BOS]".join([t for t in df.song_title])

with open("./assets/data/pink_floyd_lyrics.txt", "w") as f: f.write(lyrics)
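
To see what these patterns do, here is a minimal sketch that applies a few of them to a single made-up string with the standard `re` module. The helper name `clean_line` and the sample line are illustrative assumptions and are not part of the notebook's pipeline.

```python
import re

def clean_line(text: str) -> str:
    """Apply a subset of the cleaning patterns to one string (illustrative only)."""
    text = re.sub(r"\((.*?)\),? ?", "", text)   # drop (...) and its content
    text = re.sub(r"\[(.*?)\],? ?", "", text)   # drop [...] and its content
    text = re.sub(r"[\:\-\.\!\?]", " ", text)   # replace punctuation with a space
    text = re.sub(r" +", " ", text)             # collapse repeated spaces
    return text.lower().strip("-. ")            # lowercase and trim the ends

sample = "[Chorus] Shine on, you crazy diamond! (ooh)"   # made-up example line
print(clean_line(sample))
```

Note that commas are not in the punctuation set, so they survive cleaning and later become tokens of their own.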

### Tokenization

Next, we tokenize the lyrics and titles. Following the original paper, we use byte-pair encoding [3]. We additionally introduce two special tokens: the `[BOS]` token marks the beginning of each song, and the `[PAD]` token is used to pad batches during training.
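
For intuition, the core of byte-pair encoding can be sketched in a few lines of plain Python: repeatedly find the most frequent adjacent symbol pair and merge it into a new symbol. This is a simplified illustration on a made-up toy corpus, not the implementation used by the `tokenizers` library; the helper names `most_frequent_pair` and `merge_pair` are assumptions.

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs over all words, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get)

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

# Toy corpus (made up): word -> count, each word split into characters.
words = {tuple("shine"): 3, tuple("shone"): 2, tuple("time"): 2}
merges = []
for _ in range(3):
    pair = most_frequent_pair(words)
    merges.append(pair)
    words = merge_pair(words, pair)
print(merges)
```

Each merge adds one entry to the vocabulary, so the number of merges controls the final vocabulary size.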

In [172]:
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
from tokenizers.normalizers import NFD, Lowercase, Strip, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

In [173]:
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = normalizers.Sequence([NFD(), StripAccents(), Lowercase(), Strip()])
tokenizer.pre_tokenizer = Whitespace()

In [174]:
trainer = BpeTrainer(special_tokens=["[BOS]", "[PAD]"], show_progress=False)
tokenizer.train_from_iterator([lyrics + titles], trainer=trainer)
print(f"Vocabulary Size: {tokenizer.get_vocab_size()}")

Vocabulary Size: 4155


## Training

With a trained tokenizer, we turn our attention to the model. The transformer is defined in the `microai.models.transformer` module. Because we are working with a very small dataset, we scale down the model accordingly.

### Training Setup

In this section, we configure the model and optimizer for training and evaluation.

In [9]:
import torch
import torch.nn.functional as F

In [10]:
from microai.models.transformer import TransformerConfig, Transformer

All model-related parameters are encapsulated in a `TransformerConfig` class.

In [28]:
config = TransformerConfig(
    style="decoder",
    vocab_size=tokenizer.get_vocab_size(),
    d_model=64,
    num_heads=8,
    context_size=128,
    dropout=0.2,
    decoder_layers=4,
)
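
As a reminder of what each attention head computes, the following is a minimal pure-Python sketch of single-head scaled dot-product attention with a causal mask, as described in the paper [1]. It assumes that `microai` follows the paper and splits `d_model` evenly across heads; the module's source is not shown here, and the function names `softmax` and `causal_attention` are illustrative.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v are lists of T vectors. Position t attends only to positions 0..t.
    """
    d_k = len(q[0])
    out = []
    for t in range(len(q)):
        # scores for visible positions only; future positions are masked out
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d_k)
                  for s in range(t + 1)]
        weights = softmax(scores)
        out.append([sum(w * v[s][j] for s, w in enumerate(weights))
                    for j in range(len(v[0]))])
    return out

d_model, num_heads = 64, 8
head_dim = d_model // num_heads   # assumes an even split of d_model across heads

x = [[1.0, 0.0], [0.0, 1.0]]      # toy sequence of two 2-dimensional vectors
print(head_dim, causal_attention(x, x, x))
```

Under that assumption, each of the 8 heads in the decoder-only config works with 8-dimensional queries, keys, and values.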

The model itself is instantiated from this config. We additionally define a few hyperparameters used during training.

In [29]:
lr = 3e-4
batch_size = 32
epochs = 2000
eval_freq = 50
weight_decay = 1e-2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(config).to(device)

Weight decay is applied to parameters with two or more dimensions (e.g., the weight matrices of linear layers), but not to biases and other 1D parameters (e.g., in layer normalization).

In [None]:
params = {k: v for k, v in model.named_parameters() if v.requires_grad}

params_decay = [v for _, v in params.items() if v.dim() >= 2]
params_no_decay = [v for _, v in params.items() if v.dim() < 2]

optimizer = torch.optim.Adam([
    {"params": params_decay, "weight_decay": weight_decay},
    {"params": params_no_decay, "weight_decay": 0.0}
], lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)
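
The effect of the `StepLR` schedule can be worked out by hand: the learning rate is multiplied by `gamma` once every `step_size` scheduler steps. A small sketch of that arithmetic (the helper `step_lr` is illustrative and not part of PyTorch):

```python
def step_lr(base_lr: float, step: int, step_size: int = 500, gamma: float = 0.1) -> float:
    """Learning rate after `step` scheduler steps under a StepLR-style decay."""
    return base_lr * gamma ** (step // step_size)

for step in (0, 499, 500, 1000, 1500):
    print(step, step_lr(3e-4, step))
```

With the settings above, the rate falls from 3e-4 to 3e-5 after 500 steps and reaches 3e-7 after 1500 steps, so updates late in training are very small.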

### Utility Functions

Here, we group a few utility functions used during training.

In [13]:
def get_batch(tokens: List[List[int]], batch_size: int):
    def _pad(sequence: List[int], size: int):
        # left-pad with [BOS] so that all sequences in the batch share one length
        return [tokenizer.token_to_id("[BOS]")] * (size - len(sequence)) + sequence
    
    # sample batch_size distinct sequences at random
    ids = torch.randperm(len(tokens))[:batch_size].tolist()
    batch = [tokens[i] for i in ids]
    max_size = max(len(seq) for seq in batch)
    batch = torch.tensor([_pad(seq, max_size) for seq in batch], device=device)

    # inputs are all tokens but the last; targets are the inputs shifted by one
    return batch[:, :-1], batch[:, 1:]
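
The padding and the input/target shift are easy to check on plain lists. The sketch below mirrors the idea without tensors; `pad_and_shift` and the toy token ids are illustrative assumptions.

```python
from typing import List, Tuple

def pad_and_shift(seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad sequences to a common length, then build next-token pairs."""
    size = max(len(s) for s in seqs)
    padded = [[pad_id] * (size - len(s)) + s for s in seqs]
    inputs = [row[:-1] for row in padded]    # every token except the last
    targets = [row[1:] for row in padded]    # the same tokens shifted left by one
    return inputs, targets

x, y = pad_and_shift([[5, 6, 7, 8], [9, 10]], pad_id=0)   # toy token ids
print(x)
print(y)
```

At every position, the target is the token that follows the corresponding input token, which is exactly the next-token prediction objective.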

In [14]:
@torch.no_grad()
def estimate_loss(tokens: List[List[int]], batch_size: int = 5, num_batches: int = 25):
    losses = []
    
    for _ in range(num_batches):
        x, y = get_batch(tokens, batch_size=batch_size)

        y_pred = model(x)
        loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), y.view(-1))
        losses.append(loss.item())

    return np.mean(losses) 

In [15]:
def tokenize_data(data: pd.DataFrame, context_size: int, encode_song_title: bool = False):
    items = []

    for _, row in data.iterrows():
        tokens = tokenizer.encode(row["lyrics"]).ids
        title_tokens = tokenizer.encode(row["song_title"]).ids

        # split each song into consecutive chunks of at most context_size tokens;
        # stepping by context_size avoids an empty trailing chunk when the song
        # length is an exact multiple of context_size
        for start in range(0, len(tokens), context_size):
            item_tokens = tokens[start: start + context_size]
            item = (title_tokens, item_tokens) if encode_song_title else item_tokens
            items.append(item)

    return items
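
The chunking step on its own can be sketched as follows; `chunk` is an illustrative helper operating on a plain list of integers rather than real token ids.

```python
from typing import List

def chunk(tokens: List[int], context_size: int) -> List[List[int]]:
    """Split a token list into consecutive chunks of at most `context_size` tokens."""
    return [tokens[i:i + context_size] for i in range(0, len(tokens), context_size)]

print(chunk(list(range(10)), 4))
```

The last chunk of a song is usually shorter than `context_size`, which is why batches need padding.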

### Tokens

The training set comprises $90\%$ of the tokenized chunks; the remaining $10\%$ is held out for testing.

In [35]:
tokens = tokenize_data(df, config.context_size)

train_chunk = 0.9
train_ids = torch.randperm(len(tokens))[:int(len(tokens) * train_chunk)].tolist()

train_tokens = [tokens[i] for i in range(len(tokens)) if i in train_ids]
test_tokens = [tokens[i] for i in range(len(tokens)) if i not in train_ids]
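
An equivalent split can be written with the standard library alone. The sketch below is one possible variant, not the notebook's code: `split_items` is a hypothetical helper, and it uses a set for the membership test so that the split stays fast for larger corpora.

```python
import random
from typing import List, Sequence, Tuple

def split_items(items: Sequence, train_frac: float, seed: int = 0) -> Tuple[List, List]:
    """Shuffle indices once and split the items into train and test subsets."""
    idx = list(range(len(items)))
    random.Random(seed).shuffle(idx)
    cut = int(len(items) * train_frac)
    train_idx = set(idx[:cut])                 # set membership is O(1)
    train = [x for i, x in enumerate(items) if i in train_idx]
    test = [x for i, x in enumerate(items) if i not in train_idx]
    return train, test

train, test = split_items(list(range(100)), 0.9)
print(len(train), len(test))
```

Because the split is made over chunks rather than songs, chunks from the same song can end up in both subsets.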

### Training

We train the model for 2000 epochs, where each epoch is a single optimization step on one randomly sampled batch. The train and test losses are estimated every 50 epochs, the latter on the held-out test set.

In [36]:
for epoch in range(1, epochs + 1):
    x, y = get_batch(train_tokens, batch_size=batch_size)
    x, y = x.to(device), y.to(device)

    y_pred = model(x)
    loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % eval_freq == 0 or epoch == 1:
        train_loss, test_loss = estimate_loss(train_tokens), estimate_loss(test_tokens)
        print(f"Epoch: {epoch}, Train Loss: {train_loss:.3f}, Test Loss: {test_loss:.3f}")

Epoch: 1, Train Loss: 10.188, Test Loss: 9.912
Epoch: 50, Train Loss: 5.298, Test Loss: 4.907
Epoch: 100, Train Loss: 5.022, Test Loss: 4.531
Epoch: 150, Train Loss: 4.546, Test Loss: 4.525
Epoch: 200, Train Loss: 4.376, Test Loss: 3.997
Epoch: 250, Train Loss: 4.408, Test Loss: 4.140
Epoch: 300, Train Loss: 3.971, Test Loss: 3.803
Epoch: 350, Train Loss: 3.975, Test Loss: 3.813
Epoch: 400, Train Loss: 4.006, Test Loss: 3.732
Epoch: 450, Train Loss: 3.625, Test Loss: 3.910
Epoch: 500, Train Loss: 3.893, Test Loss: 3.877
Epoch: 550, Train Loss: 3.605, Test Loss: 3.954
Epoch: 600, Train Loss: 3.722, Test Loss: 3.903
Epoch: 650, Train Loss: 3.586, Test Loss: 3.800
Epoch: 700, Train Loss: 3.834, Test Loss: 3.810
Epoch: 750, Train Loss: 3.367, Test Loss: 3.959
Epoch: 800, Train Loss: 3.513, Test Loss: 3.693
Epoch: 850, Train Loss: 3.435, Test Loss: 4.001
Epoch: 900, Train Loss: 3.452, Test Loss: 3.803
Epoch: 950, Train Loss: 3.655, Test Loss: 3.967
Epoch: 1000, Train Loss: 3.514, Test Loss:
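
To put the logged values in context, two reference numbers help: the cross-entropy of a model that guesses uniformly over the vocabulary, and the perplexity that corresponds to a given loss. The sketch below computes both, taking the vocabulary size of 4155 reported earlier and a loss of 3.8 as a round figure close to the later test losses in the log.

```python
import math

vocab_size = 4155                      # vocabulary size reported by the tokenizer
uniform_loss = math.log(vocab_size)    # cross-entropy of a uniform guess, in nats
perplexity = math.exp(3.8)             # perplexity at a loss of about 3.8 nats

print(f"uniform loss: {uniform_loss:.3f}, perplexity at 3.8: {perplexity:.1f}")
```

A uniform guess gives a loss of about 8.33, and a loss of 3.8 corresponds to a perplexity of roughly 45, meaning the model is on average about as uncertain as a uniform choice among 45 tokens. The first logged loss is above the uniform baseline, which can happen when randomly initialized logits are far from uniform.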

## Evaluation

To evaluate the model, we prompt it with a few words and ask it to continue the sequence up to a predefined maximum length. The results are far from perfect, but the model has learned basic relationships between words and does produce "Pink Floyd-like" sentences.

In [38]:
@torch.no_grad()
def generate(model: Transformer, prompt: str, context_size: int = 8, max_length: int = 1000):
    context = torch.tensor(tokenizer.encode(prompt).ids, device=device)
    model.eval()

    # sample one token at a time until the sequence reaches max_length tokens
    while context.size(0) < max_length:
        # condition only on the most recent context_size tokens
        logits = model(context[-context_size:].unsqueeze(0))
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs[:, -1, :].flatten(), num_samples=1).item()
        context = torch.cat((context, torch.tensor([token], device=device)), dim=0)

    model.train()
    return tokenizer.decode(context.tolist())
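
The sampling step (softmax followed by a multinomial draw) can be reproduced with the standard library to see how logits translate into token frequencies. This is an illustrative sketch with made-up logits; `sample_next` is a hypothetical helper, not part of the notebook.

```python
import math
import random
from typing import List

def sample_next(logits: List[float], rng: random.Random) -> int:
    """Sample a token id with probability proportional to softmax(logits)."""
    m = max(logits)                                   # subtract the max for stability
    weights = [math.exp(x - m) for x in logits]       # unnormalized softmax weights
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)
counts = [0, 0, 0]
for _ in range(1000):
    counts[sample_next([2.0, 1.0, -1.0], rng)] += 1
print(counts)
```

Tokens with higher logits are drawn more often, but lower-scoring tokens still appear occasionally, which is why repeated calls with the same prompt produce different continuations.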

In [24]:
print(generate(model, "shine on ", context_size=config.context_size, max_length=15))
print(generate(model, "time ", context_size=config.context_size, max_length=15))
print(generate(model, "money ", context_size=config.context_size, max_length=15))

shine on ground to make the weak in the animals become the lived now at
time has the same we lie out is who for all anced ’ s are
money you want and high is in you feel narrow hey you ’ ll to


## Encoder-Decoder Structure

To provide a proof of concept for the full encoder-decoder transformer, we continue with language modeling, but now condition it on the title of a song, which is passed through the encoder layers of the model. This is a very crude task to train a model on, but it serves the purpose of demonstrating the full transformer architecture.

### Utility Functions

As before, we group a few utility functions used throughout training. The only difference from the previous version is that the song title is incorporated into the model's input.

In [275]:
def get_batch_with_titles(tokens: List[Tuple[List[int], List[int]]], batch_size: int):
    def _pad(sequence: List[int], size: int):
        return [tokenizer.token_to_id("[PAD]")] * (size - len(sequence)) + sequence
    
    ids = torch.randperm(len(tokens))[:batch_size].tolist()

    # extract and pad lyrics
    lyrics = [tokens[i][1] for i in range(len(tokens)) if i in ids]
    max_size = max([len(i) for i in lyrics])
    lyrics = [_pad(i, max_size) for i in lyrics]

    # extract and pad titles
    titles = [tokens[i][0] for i in range(len(tokens)) if i in ids]
    titles = [_pad(i, max_size - 1) for i in titles]

    lyrics = torch.tensor(lyrics, device=device)
    titles = torch.tensor(titles, device=device)
    return titles, lyrics[:, :-1], lyrics[:, 1:]

In [276]:
@torch.no_grad()
def estimate_loss_with_titles(tokens: List[Tuple[List[int], List[int]]], batch_size: int = 5, num_batches: int = 25):
    losses = []
    
    for _ in range(num_batches):
        titles, lyrics, targets = get_batch_with_titles(tokens, batch_size=batch_size)

        y_pred = model((titles, lyrics))
        loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), targets.flatten())
        losses.append(loss.item())

    return np.mean(losses) 

### Training

Next, we initialize and train the full transformer model.

In [327]:
config = TransformerConfig(
    style="encoder-decoder",
    vocab_size=tokenizer.get_vocab_size(),
    d_model=64,
    num_heads=2,
    context_size=128,
    dropout=0.1,
    decoder_layers=2,
)

In [328]:
tokens = tokenize_data(df, config.context_size, encode_song_title=True)

train_chunk = 0.9
train_ids = torch.randperm(len(tokens))[:int(len(tokens) * train_chunk)].tolist()

train_tokens = [tokens[i] for i in range(len(tokens)) if i in train_ids]
test_tokens = [tokens[i] for i in range(len(tokens)) if i not in train_ids]

In [329]:
lr = 3e-4
batch_size = 32
epochs = 2000
eval_freq = 50
weight_decay = 1e-2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(config).to(device)

In [330]:
params = {k: v for k, v in model.named_parameters() if v.requires_grad}

params_decay = [v for _, v in params.items() if v.dim() >= 2]
params_no_decay = [v for _, v in params.items() if v.dim() < 2]

optimizer = torch.optim.Adam([
    {"params": params_decay, "weight_decay": weight_decay},
    {"params": params_no_decay, "weight_decay": 0.0}
], lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

In [331]:
for epoch in range(1, epochs + 1):
    titles, lyrics, targets = get_batch_with_titles(train_tokens, batch_size=batch_size)

    y_pred = model((titles, lyrics))
    loss = F.cross_entropy(y_pred.view((-1, y_pred.size(-1))), targets.flatten())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % eval_freq == 0 or epoch == 1:
        train_loss, test_loss = estimate_loss_with_titles(train_tokens), estimate_loss_with_titles(test_tokens)
        print(f"Epoch: {epoch}, Train Loss: {train_loss:.3f}, Test Loss: {test_loss:.3f}")

Epoch: 1, Train Loss: 8.757, Test Loss: 8.629
Epoch: 50, Train Loss: 5.182, Test Loss: 4.173
Epoch: 100, Train Loss: 4.440, Test Loss: 4.131
Epoch: 150, Train Loss: 4.546, Test Loss: 4.155
Epoch: 200, Train Loss: 4.240, Test Loss: 3.441
Epoch: 250, Train Loss: 4.371, Test Loss: 3.543
Epoch: 300, Train Loss: 3.858, Test Loss: 3.774
Epoch: 350, Train Loss: 3.776, Test Loss: 3.707
Epoch: 400, Train Loss: 3.962, Test Loss: 3.251
Epoch: 450, Train Loss: 3.790, Test Loss: 3.554
Epoch: 500, Train Loss: 3.648, Test Loss: 3.443
Epoch: 550, Train Loss: 3.725, Test Loss: 3.408
Epoch: 600, Train Loss: 3.812, Test Loss: 3.545
Epoch: 650, Train Loss: 3.725, Test Loss: 3.586
Epoch: 700, Train Loss: 3.642, Test Loss: 3.700
Epoch: 750, Train Loss: 3.745, Test Loss: 3.229
Epoch: 800, Train Loss: 3.643, Test Loss: 3.934
Epoch: 850, Train Loss: 3.681, Test Loss: 3.469
Epoch: 900, Train Loss: 3.675, Test Loss: 3.214
Epoch: 950, Train Loss: 3.653, Test Loss: 3.553
Epoch: 1000, Train Loss: 3.683, Test Loss: 

### Evaluation

As before, we ask the model to generate a sentence, this time conditioned on a song title.

In [282]:
@torch.no_grad()
def generate_with_titles(model: Transformer, title: str, context_size: int = 8, max_length: int = 1000):
    pad_token = tokenizer.token_to_id("[PAD]")
    title_tokens = torch.tensor(tokenizer.encode(title).ids, device=device)

    # start decoding from a single [BOS] token
    context = torch.tensor(tokenizer.encode("[BOS]").ids, device=device)
    model.eval()

    for _ in range(max_length):    
        max_size = min(max(len(title_tokens), len(context)), context_size)
        
        # left-pad (or crop) the decoder context and the title to the same length
        context = F.pad(context, pad=(max_size - len(context), 0), mode="constant", value=pad_token)
        context = context[-max_size:].unsqueeze(0)

        title_input = F.pad(title_tokens, pad=(max_size - len(title_tokens), 0), mode="constant", value=pad_token)
        title_input = title_input[-max_size:].unsqueeze(0)
        
        logits = model((title_input, context[:, -context_size:]))
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs[:, -1, :].flatten(), num_samples=1).item()

        context = torch.cat((context.flatten(), torch.tensor([token], device=device)), dim=0)

    model.train()
    return tokenizer.decode(context.flatten().tolist())

In [232]:
print(generate_with_titles(model, "about money", context_size=config.context_size, max_length=15))
print(generate_with_titles(model, "about life", context_size=config.context_size, max_length=15))
print(generate_with_titles(model, "about war", context_size=config.context_size, max_length=15))

filled drops a call to me have the silver gone , has on the sorted
ya gone right shine ll sensation be gloom man tell you have about and labyrin
there can time be how black look cleared slight went of ’ s eins strayed
