## Part 3: Training a Model with Multi-Head Attention

In this chapter we build upon our derivation of a single attention head:

- We create a config class to store hyperparameters
- We create a single attention head class
- We build the attention head into the model and adapt the model's `forward` and `generate` functions
- We extend the single-head attention model to a multi-head attention model

To store and manage hyperparameters, we create a config class so that all parameters are easily accessible in one object.

In [1]:
class Config:
    
    def __init__(self, 
                 vocab_size: int = 65,
                 n_embd: int = 32,
                 block_size: int = 32,
                 num_heads: int = 1,
                 batch_size: int = 16
                 ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_size = n_embd // num_heads
        self.batch_size = batch_size
        
config = Config()
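Because `head_size` is computed with integer division, a combination such as `n_embd=32, num_heads=5` would give heads of size 6 whose concatenated output has only 30 features, which only fails later with a shape mismatch in `lm_head`. A minimal sketch of a stricter variant using a dataclass; `CheckedConfig` is a hypothetical name and not part of this notebook:

```python
from dataclasses import dataclass, field


@dataclass
class CheckedConfig:
    """Hypothetical variant of Config that rejects n_embd values not divisible by num_heads."""
    vocab_size: int = 65
    n_embd: int = 32
    block_size: int = 32
    num_heads: int = 1
    batch_size: int = 16
    head_size: int = field(init=False)  # derived, not passed by the caller

    def __post_init__(self):
        # fail early instead of silently producing heads that do not add up to n_embd
        if self.n_embd % self.num_heads != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by num_heads ({self.num_heads})"
            )
        self.head_size = self.n_embd // self.num_heads


cfg = CheckedConfig(num_heads=4)
print(cfg.head_size)  # 8
```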


Next, we bundle the attention mechanism derived in Part 2 into a single attention head class.

In [2]:
# Let's modify our simple bigram model to a simple attention model 
# The model uses an attention head, which was derived in part 2

import torch
import torch.nn as nn
import torch.nn.functional as F
 
# Implementation of a single attention head   
    
class Head(nn.Module):
    """Single Attention Head"""
    def __init__(self, config):
        super().__init__()
        self.key_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.query_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.value_repr = nn.Linear(config.n_embd, config.head_size, bias=False)
        # Register buffer for attention mask
        self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))
        
    def forward(self, x: torch.Tensor):
        B, T, C = x.shape
        
        # create learned representations
        k = self.key_repr(x)
        q = self.query_repr(x)
        v = self.value_repr(x)
        
        # compute attention scores ('affinities' between tokens)
        W = q @ k.transpose(-2, -1)  # (B,T,hs) @ (B,hs,T) -> (B,T,T)
        W = W * k.shape[-1] ** -0.5  # scale by 1/sqrt(head_size) to keep softmax from saturating
        # tril is (block_size, block_size); slice it to (T,T) because T can be shorter than block_size, e.g. during generation
        W = W.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T)
        W = F.softmax(W, dim=-1)  # (B,T,T)
        
        # compute weighted aggregation of values
        out = W @ v  # (B,T,T) @ (B,T,hs) -> (B,T,hs)
        return out
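The scaling factor in `Head.forward` keeps the attention scores at roughly unit variance. If queries and keys have unit-variance entries, their dot product over `head_size` dimensions has a variance of about `head_size`, and large scores push the softmax towards a one-hot distribution. Multiplying by `head_size ** -0.5` undoes this, which is why the factor is tied to the key/query dimension. A quick numerical check with random tensors (the batch and sequence sizes are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
B, T = 4, 32  # arbitrary batch size and sequence length for the experiment

for head_size in (8, 32, 128):
    q = torch.randn(B, T, head_size)  # unit-variance queries
    k = torch.randn(B, T, head_size)  # unit-variance keys
    raw = q @ k.transpose(-2, -1)     # (B, T, T), variance grows roughly like head_size
    scaled = raw * head_size ** -0.5  # variance stays close to 1
    print(head_size, round(raw.var().item(), 1), round(scaled.var().item(), 2))
```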

Now we can augment our bigram model with an attention head, so that it uses the information of all preceding tokens in the sequence (up to `block_size` of them) to predict the next token, instead of only the most recent one.
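"All preceding tokens" is enforced by the lower-triangular mask in `Head`: scores for future positions are set to `-inf`, so the softmax gives them zero weight and every row of weights still sums to 1. The slice `tril[:T, :T]` is what lets the same buffer work for sequences shorter than `block_size`. A small sketch with random scores standing in for `q @ k^T` (the sizes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

block_size = 8  # size of the mask buffer, as registered in Head.__init__
T = 5           # a shorter sequence, e.g. early in generation
tril = torch.tril(torch.ones(block_size, block_size))

scores = torch.randn(1, T, T)  # stand-in for the raw attention scores q @ k^T
scores = scores.masked_fill(tril[:T, :T] == 0, float('-inf'))  # hide future positions
weights = F.softmax(scores, dim=-1)  # (1, T, T)

# row t has non-zero weights only in columns 0..t
print(weights[0])
```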

In [3]:
# For this we also modify the embedding of the input:
# we add a positional embedding that gives each token information about its position in the sequence (the time dimension).
# This is needed because attention by itself is a position-agnostic communication mechanism: it does not take the position of a token into account.

# In addition to the self-attention head, we add a final linear layer called lm_head.
# It projects the output back to the size of the vocabulary, so we can compute a probability distribution over all characters in the vocabulary.

class SimpleAttentionModel(nn.Module):
    
    def __init__(self, config: Config):
        super().__init__()
        self.token_embeddings_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings_table = nn.Embedding(config.block_size, config.n_embd)
        self.sa_head = Head(config=config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
        self.block_size = config.block_size
        
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor = None):
        B, T = inputs.shape
        
        # inputs and targets are both (B,T) tensors of integers
        tok_emb = self.token_embeddings_table(inputs)  # (B,T,C)
        pos_emb = self.position_embeddings_table(torch.arange(T, device=inputs.device))  # (T,C)
        
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x)  # apply one head of self-attention (B,T,C)
        logits = self.lm_head(x)  # (B,T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    @torch.no_grad()  # no gradients are needed when sampling
    def generate(self, inputs, max_new_tokens):
        # Generate new tokens one at a time,
        # using up to the last block_size tokens, weighted via self-attention,
        # to decide on the next generated token
        
        for _ in range(max_new_tokens):
            
            # crop the inputs to at most the last block_size tokens;
            # necessary because the position embedding table only has entries for block_size positions
            model_inputs = inputs[:, -self.block_size:]
            
            logits, _ = self(model_inputs)  # shape: (B,T,C)
            # For generation, we only need the predictions from the last position
            probs = F.softmax(logits[:, -1, :], dim=-1)  # shape: (B,C)
            
            # Sample from the probability distribution to get the next token
            inputs_next = torch.multinomial(probs, num_samples=1)  # shape: (B,1)
            
            # Append the new token to our sequence
            inputs = torch.cat((inputs, inputs_next), dim=1)  # shape: (B,T+1)
        return inputs
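To see why the positional embedding is needed, consider attention without the causal mask and without any position information: permuting the input tokens simply permutes the outputs in the same way, so the mechanism cannot distinguish token orders. A minimal sketch with random projection matrices (an unmasked single head written with plain tensors, not the `Head` class above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C = 6, 16
x = torch.randn(T, C)  # T token embeddings with no positional information added
# random key, query and value projections, scaled to keep values moderate
Wq, Wk, Wv = (torch.randn(C, C) / C ** 0.5 for _ in range(3))


def attend(x):
    """Unmasked scaled dot-product self-attention over the rows of x."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    w = F.softmax(q @ k.T * C ** -0.5, dim=-1)
    return w @ v


perm = torch.randperm(T)
out = attend(x)
out_perm = attend(x[perm])
# shuffling the tokens only shuffles the outputs: order carries no information
print(torch.allclose(out[perm], out_perm, atol=1e-5))  # True
```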

### Training the single-head attention model

Unlike the bigram model, this model can learn the affinities between tokens during training, which should result in better performance.

In [4]:
# This cell adds the context from Part 1 so that we can run a training comparable to the one for the bigram model.

# Read dataset
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    
# How we encode text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Dataset encoding
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

# Generating train and test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Sample a random batch from the train or validation split
def get_batch(split):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(data) - config.block_size, size=(config.batch_size,))
    x = torch.stack([data[i:i + config.block_size] for i in idxs])
    y = torch.stack([data[i+1:i+config.block_size+1] for i in idxs])
    return x,y  


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65
torch.Size([1115394]) torch.int64
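`get_batch` returns targets that are the inputs shifted one position to the right, so every position in a block is a training example for next-token prediction. With a toy dataset whose token ids equal their positions (a hypothetical stand-in for the encoded Shakespeare text; `get_batch_toy` is an illustrative copy of `get_batch`), the shift is directly visible:

```python
import torch

torch.manual_seed(0)
data = torch.arange(100)  # toy "dataset": token at position i has id i
block_size, batch_size = 8, 4


def get_batch_toy(data):
    # same sampling logic as get_batch above, without the train/val switch
    idxs = torch.randint(len(data) - block_size, size=(batch_size,))
    x = torch.stack([data[i:i + block_size] for i in idxs])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in idxs])
    return x, y


x, y = get_batch_toy(data)
print(x[0])
print(y[0])  # each target is the token that follows the input at the same position
```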


In [5]:
model = SimpleAttentionModel(config=config)

inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(decoded_output)


DVjnTECva?Hhb,UtXzUKPwsIzbzPC'qNi:HxgdvgQsvqTQLnaSoPVt'CVNQ?mUcQturygSPLL!$G!yUO;kaag.BABlZXhsnU&h?GzybTJxnCD,tKzTF.!Kf!L.sR?Wwwxrz-GDVyNFqkMXGojcZCKVTszL,L3SXo!X?SyzU:nfb-
;HiLMY&MSYIt&aKySluW
-Hghf??eN3NzIznZEzovK3
!&bilA,wjh'izmwh?EdOkKnFNE,jpCYjW.tt'yspzADffWHG-eSfjWk:3ybSmEJeXeWBvYwnbPrCmKWsYSZnRdKpSlOMxAU tAm SW.FVDjmTLGz?PVfK?dk
zhxNrgBjYgxN;QgHtH ;xrHAeTlwy!LMD$yNjK?
Pj3X$XTKZKTh pciO.jmVnJEwwAQgHk&XpUW3K&DjtfI,L
INwMezmnGWVJKMdoCdjSIKLrIgmA ltKNmHeG;&3pnh DFeFCMu,IPJ3-i:cZjtxiR.'kMJr3T:


In [6]:

# Now we train our simple attention head model!
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for steps in range(10000):

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.371469497680664
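The printed value is the loss of the last training batch only, so it fluctuates from batch to batch and says nothing about the validation split. A common remedy is to average the loss over many batches from both splits with gradients disabled. A sketch, where `estimate_loss` is a hypothetical helper that is not defined elsewhere in this notebook:

```python
import torch


@torch.no_grad()
def estimate_loss(model, get_batch, eval_iters=200):
    """Hypothetical helper: mean loss over eval_iters batches for each split."""
    model.eval()  # switch to evaluation mode (matters once dropout etc. is added)
    out = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[i] = loss.item()
        out[split] = losses.mean().item()
    model.train()  # restore training mode
    return out
```

Calling `estimate_loss(model, get_batch)` after training would return averaged train and validation losses, a fairer basis for comparing models than a single batch.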


We can see that the loss on the final training batch (about 2.37) is lower than the bigram model's loss of roughly 2.57 after the same number of training steps.
This is a nice improvement: learning the affinities between tokens does add value.
However, if we now generate text with the trained model, it is still not readable Shakespeare, so there is room for improvement.
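To put these numbers in perspective, a model that assigns the same probability 1/65 to every character has a cross-entropy loss of ln 65, about 4.17, so both the bigram and the single-head attention model are far below that chance level:

```python
import math

vocab_size = 65
# cross-entropy (in nats) of a model that predicts every character with probability 1/vocab_size
uniform_loss = -math.log(1 / vocab_size)
print(round(uniform_loss, 3))  # 4.174
```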

In [7]:
inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(len(decoded_output))
print(decoded_output)

501

Yo ma nllia ftitre he thache than beenig;
Bury thay at dat sens houthe weral fail thre omarth indrousr; woth thourk thitork, whofur ill gnd nethey thidefatond tho oerd: thug his amy tohus nof fl, the ad thitachis,
Theatolo, ot yofrr loren lod tsenco; suisllf GRIITOLeer! heixcofr fret.

KI our,
GRer, saly bthit, rorwe maud ne yit anve.

GRS: nd saighich hart ilsat om?
Coprph fimas den icut:
YYour ienid.

LARour, myou, ma thinerere; sher ove, I wod ble
I nd dy:
Cobe thalie, be;
PUS:
He, my, la men


### Expand to Multi-Head Attention

Instead of using only one attention head, we can use several in parallel. The idea is that each head can learn to attend to a different subspace of the input.
The outputs of all heads are then concatenated, and the combined output can be more expressive than the output of a single attention head.

In [8]:
# Next, let's try multi-head attention: the input is no longer passed through only one attention head, but through several heads in parallel.
# Their outputs are then concatenated, giving the same final shape (B,T,C) as before.

class MultiHeadAttention(nn.Module):
    
    def __init__(self, config: Config):
        super().__init__()
        self.heads = nn.ModuleList([Head(config) for _ in range(config.num_heads)])
        
    def forward(self, input):
        # In the forward pass the input is passed through all heads in parallel and the outputs are concatenated.
        # For the concatenated output to have dimension n_embd, each head works in a subspace of head_size = n_embd // num_heads
        return torch.cat([h(input) for h in self.heads], dim=-1)
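Looping over a `ModuleList` of heads is easy to read but launches separate small matrix multiplications per head. A common, more efficient formulation computes all heads at once with a single projection and tensor reshapes. The sketch below is a hypothetical alternative (`BatchedMultiHeadAttention` is not used elsewhere in this notebook); it computes the same kind of causal multi-head attention followed by concatenation, with the key, query and value weights of all heads packed into one matrix:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedMultiHeadAttention(nn.Module):
    """Hypothetical alternative: all heads computed together via reshapes."""

    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_size = n_embd // num_heads
        # one Linear produces queries, keys and values for all heads at once
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)  # each (B, T, C)
        # split C into (num_heads, head_size) and move the head dim next to the batch dim
        q, k, v = (t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
                   for t in (q, k, v))  # each (B, nh, T, hs)
        w = q @ k.transpose(-2, -1) * self.head_size ** -0.5  # (B, nh, T, T)
        w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        out = w @ v  # (B, nh, T, hs)
        # concatenate the heads: (B, nh, T, hs) -> (B, T, nh * hs) = (B, T, C)
        return out.transpose(1, 2).contiguous().view(B, T, C)


mha = BatchedMultiHeadAttention(n_embd=32, num_heads=4, block_size=32)
x = torch.randn(2, 10, 32)
print(mha(x).shape)  # torch.Size([2, 10, 32])
```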

In [9]:
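# Create a config with 4 heads; each head then works in a subspace of head_size = n_embd // num_heads = 8
config = Config(num_heads=4)

# In the multi-head attention model we replace the single attention head with the multi-head module.
# The rest of the model stays the same.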
class MultiHeadAttentionModel(SimpleAttentionModel):
    
    def __init__(self, config):
        super().__init__(config)
        self.sa_head = MultiHeadAttention(config)
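Switching from one head of size 32 to four heads of size 8 does not change the number of attention parameters: the key, query and value projections together still map 32 features to 32 features. A quick count that mirrors the bias-free projections in `Head` (`attention_params` is an illustrative helper, not part of the notebook):

```python
import torch.nn as nn

n_embd = 32


def attention_params(num_heads):
    head_size = n_embd // num_heads
    # each head has key, query and value projections n_embd -> head_size without bias
    heads = nn.ModuleList(
        nn.ModuleList(nn.Linear(n_embd, head_size, bias=False) for _ in range(3))
        for _ in range(num_heads)
    )
    return sum(p.numel() for p in heads.parameters())


print(attention_params(1), attention_params(4))  # 3072 3072
```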

In [10]:
multi_model = MultiHeadAttentionModel(config=config)

In [None]:
optimizer = torch.optim.AdamW(multi_model.parameters(), lr=1e-3)

for steps in range(10000):

    xb, yb = get_batch('train')
    logits, loss = multi_model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.6412746906280518


In [12]:
inputs = torch.zeros((1,1), dtype=torch.long)
decoded_output = decode(multi_model.generate(inputs=inputs, max_new_tokens=500)[0].tolist())
print(len(decoded_output))
print(decoded_output)

# And with that we have derived the Multi-head self-attention mechanism from scratch!

501

mMinTortt ntowe wheour, l, cen tteise bois.
Car horl gheome an?
Tthisses.


Bod chou wauverhit tar:

O fs aave, imjattouquro wathe har avend
Arsend s d'swe bd chonthigawsbar h we lounounthr ghlingbu.

N p


VHowirrs lind te t'
LAthod CUSo tot ad w fe
TAnond;
OThimder an ollythe:
An e G y,
RHof, oluvit pe matord anowar vie, bor bamed wor:
EF: Gpun! us be m?
LEVICORCis erthoir,
SE: sev a gruns thouseapue helgth,
I.
S'le.
MLo, pludinaref 'Flellt pys sin Whet w d l, averyoom h towarlen eris tim! mon
