## Part 3: Training a Model with Multi-Head Attention

In part 2 we derived the code and intutition behind self-attention for a single attention head.
In this last chapter we built upon this to do the following:

- Augment the bigram model to use a single Attention Head
- Expand the single attention head model to a multi-head attention model

But first, for storing and managing hyperparameters let's create a dummy config class so that all parameters are easily accessible in one object.

In [1]:
class Config:
    
    def __init__(self, 
                 vocab_size: int = 65,
                 n_embd: int = 32,
                 block_size: int = 32,
                 num_heads: int = 1,
                 batch_size = 16
                 ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_size = n_embd // num_heads
        self.batch_size = batch_size
        
config = Config()


### Transforming the Bigram Model

Now let's take our simple bigram model form the first part and transform into a more capable model that can utilize the full context by using a self-attention head.

For this we implement the following components in the following:
- A class for a single attention head, which contains the logic derived in part 2
- The modified Bigram Model class with:
    - **Self-attention head**
    - **Positional token embedding** (to introduce information about temporal dimension)
    - **Language modeling head** (additional final linear layer to project our output back to the dimension of the vocabulary in order to compute a probability distribution of all characters in the vocabulary). 

In [2]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy
    
class Head(nn.Module):
    """Single Attention Head"""
    def __init__(self, config):
        super().__init__()
        self.key_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.query_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.value_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        # Register buffer for attention mask
        self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))
        
    def forward(self, x: torch.Tensor):
        B, T, C = x.shape
        
        # create learned representations
        k = self.key_repr(x)
        q = self.query_repr(x)
        v = self.value_repr(x)
        
        # compute attention scores ('affinities between tokens')
        W = q @ k.transpose(-2,-1)  # (B,T,C) @ (B,C,T) -> (B,T,T)
        W *= C ** -0.5  # scaling of the dot product to keep softmax from saturating too much
        W = W.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T), we use :T here to make sure we never exceed the context window
        W = F.softmax(W, dim=-1) # (B,T,T)
        
        # compute weighted aggregation of values
        out = W @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out        

In [3]:
torch.manual_seed(1337)

class SimpleAttentionModel(nn.Module):
    
    def __init__(self, config: Config):
        super().__init__()
        self.token_embeddings_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings_table = nn.Embedding(config.block_size, config.n_embd)
        self.sa_head = Head(config=config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
        self.config = config
        
    def forward(self, inputs: torch.tensor, targets: torch.tensor = None):
        """ Forward pass of the model where we compute the raw preferences for what to be the next characters. """
        B, T = inputs.shape
        
        # inputs and targets are both (B,T) tensors of integers
        tok_emb = self.token_embeddings_table(inputs)
        pos_emb = self.position_embeddings_table(torch.arange(T)) # (T,C)
        
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x)  # apply one head of self-attention (B,T,C)
        logits = self.lm_head(x)  # (B,T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view (B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, inputs, max_new_tokens):
        """ Autoregressive langauge generation function, where based on a given start input we run a forward pass, 
        select the most probable next token, append it, and repeat the process max_new_tokens times. """       
        for _ in range(max_new_tokens):
            # make sure the inputs are at max block_size number of tokens
            # necessary because our position embedding can only encode position information for up to block_size tokens
            model_inputs = inputs[:, -self.config.block_size:]
            #print(f"Generate: inputs.shape = {inputs.shape}")
            logits, _ = self(model_inputs)  # shape: (B,T,C)
            # For generation, we only need the predictions from the last position
            probs = F.softmax(logits[:, -1, :], dim=-1)  # shape: (B,C)
            # Sample from the probability distribution to get the next token
            inputs_next = torch.multinomial(probs, num_samples=1)  # shape: (B,1)
            # Append the new token to our sequence
            inputs = torch.cat((inputs, inputs_next), dim=1)  # shape: (B,T+1)
        return inputs

### Training the Single Attention Head Model

Now we can train our more capable model. Unlike before we will now **learn the affinities between tokens**, which should result in better model performance than before.

For this one last time, let's reuse the dataset, encoder/decoder and `get_batch` function form part 1.

In [4]:
import torch

# Get dataset
!curl.exe --output shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    
# Create encoder and encode datasetchars = sorted(list(set(text)))
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
data = torch.tensor(encode(text), dtype=torch.long)

# Generating train and test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(data) - config.block_size, size=(config.batch_size,))
    x = torch.stack([data[i:i + config.block_size] for i in idxs])
    y = torch.stack([data[i+1:i+config.block_size+1] for i in idxs])
    return x,y  

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1089k  100 1089k    0     0  4437k      0 --:--:-- --:--:-- --:--:-- 4464k


Now let's create our model and as before check what the untrained output looks like.

In [5]:
model = SimpleAttentionModel(config=config)

inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(decoded_output)


j?vVOEQvMF-wXabyhCkpFww
.wIYwDcqAw-HXygXR$3OkWzxvSOH Zp;QZJg-BHm!kzlxQnSTrto cajDG3PmYXaZis
Jpz$nyZc!muYICgZ C3QBI:OVynfvMdMMgGk;RuebBuvK,:avxSvauFL:3RSPKafUQyfNkYHgDgCLU.Abfq'3h3tzCEw?$:mbA?W&!rFvYQb.c3O&b BPh;YiKYyT
hJnhh3QhK ZibtgnwNup?enzRuYwiLEKBPXz$VC'qBbQ3&!e.bAF:WdRKrkTlk
WdFMJqmbhDr!YCD Gzsys:zKRj .Dsdt tTgO'bov$po$raxDmx;e$3sCXqCs bj;I.-qWbeFV,:anA.-xbo;mCVtXxTEeaYCdO-h3:qDk?BH
FZjrcTbVpwTLN?rLFzXdV$k$'E-Tap!hH BhtuexSSS3U
Qui!G3nZ3mFKaDllY:JMSlr.
HiGxz
WeSrzE,m?3TfNBQBMSx?KDGt
RqRc
&l


As expected, it is pure gibberish, so now let's train it and compare the loss to the loss of the bigram model.

In [6]:
from tqdm import tqdm

# Now we train our simple attention head model!
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for steps in tqdm(range(10000)): # increase number of steps for good results...

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

100%|██████████| 10000/10000 [00:21<00:00, 473.28it/s]

2.3612353801727295





We can see that, compared to our bigram model loss with rougly 2.57, we can achieve a smaller loss with the same amount of training steps!
This is a nice improvement; learning the affinities does bring additional value.
Of course we performed not enough training to get perfect shakespeare, but fragments of language can now already be observed:

In [7]:
inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(len(decoded_output))
print(decoded_output)

501

AUCNIO:
FRMO:
Sate't tlan thak as tere my tof tinoer,
NORome at,
Whabmaref at umpacus than befer
VIOPu?
Shal
Tif than followre boly a thebr's; tour thy ipst, to whald:
At mongl-
t can
dos dors, Rant,
R ARWICHery hagh ronget an
Whig th astt, mse wrure thath go cow thize igt hlfy, boum owreel the; omonoot; the do bout foleas.

Mat mind,
Yo foagem io ome be'de's whanwo yon soso rime hsan's the, wo ous LTo CHES:
Towe so merd-d adr ance Bane, whyu me st the may, thes the I! wis, ron thout to don, who


### Expand to Multi-Head Attention

Instead of using only one attention head, we can also use multiple ones in parallel. The idea here is that each attention head "attends" to a different suspace of the input and then the output of all the attention heads is concatenated.

For this let's create a new attention class and also a new model class.

In [8]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, config: Config):
        super().__init__()
        self.heads = nn.ModuleList([Head(config) for _ in range(config.num_heads)])
        
    def forward(self, input):
        # During the forwad pass the input is passed in parallel through all the heads and afterwards is concatenated
        # To make sure that the concatenated output has the correct dimension, each of the works in a subspace of head_size = n_embd // num_heads
        return torch.cat([h(input) for h in self.heads], dim=-1)

In [9]:
torch.manual_seed(1337)
num_heads = 4

# In the multi-head attention model we replace the single attention module with the multi-head module.
# The residual model stays the same
class MultiHeadAttentionModel(SimpleAttentionModel):
    
    def __init__(self, config):
        super().__init__(config)
        self.sa_head = MultiHeadAttention(config)

Now for completeness we can again train a multi-head attention model for some steps and check the output.

In [10]:
multi_model = MultiHeadAttentionModel(config=config)
optimizer = torch.optim.AdamW(multi_model.parameters(), lr=1e-3)

for steps in tqdm(range(10000)): # increase number of steps for good results...

    xb, yb = get_batch('train')
    logits, loss = multi_model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

100%|██████████| 10000/10000 [00:21<00:00, 466.28it/s]

2.2932310104370117





In [11]:
inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(multi_model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(len(decoded_output))
print(decoded_output)

501

KINIel ak!
Thame om fitho ak cperon mimed founs wirgtha pt of me ar
Inke clon Lthinud ikes, wartan ber, n
O:
jowarkins anthedes ot folloth nouke ouss:
od ttlanere hatred icavou:
Gomas isthead:
What thinrdi fon thy lloe meil st the' bupers'ld, wheath wito the hel thilldovino the, ido theat lod theave:
He, eve some'l d'd.

Thme wron;
An:
Wis gerle wieadurd,
SNA didats ibert
Gene Byo rot sono us!!
Ound wed in le, sellt oret hato, mive fort sist st whier ndose fo it thath, wilitul rspicears buigit i


And with this we have derived Multi-Head Self-Attention from scratch! 