# Multi-Headed Self-Attention Transformer Decoder
![Multi-Head Attention](./images/multi-head-attention.png)

We will build an autoregressive, multi-headed self-attention language model that predicts the next token in a sequence of text. This demonstration uses the shakespeare_char dataset, which can be found in the data folder.

A multi-headed self-attention block runs several smaller single-headed self-attention blocks in parallel and concatenates their outputs. Each head can extract information from a different representation subspace, so the block can capture more diverse and complex patterns in the input than a single head, which can only focus on one subspace at a time.

Each head uses a head size equal to the number of embedding dimensions divided by the number of heads. This ensures that the output dimension of the multi-headed block matches that of a single-headed block whose head size is `n_embed`, provided that `num_heads` divides `n_embed` without a remainder. For further explanation of multi-headed attention, see Andrej Karpathy's video [2].
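
The head-size arithmetic above can be checked with a few lines of JAX. This is a minimal sketch that uses placeholder arrays of zeros in place of real head outputs; the batch size `B = 2` is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

n_embed, num_heads = 32, 4

# The split only works cleanly when num_heads divides n_embed
assert n_embed % num_heads == 0
head_size = n_embed // num_heads  # 32 // 4 = 8

B, T = 2, 8  # illustrative batch size and sequence length
# Placeholder outputs, one (B, T, head_size) array per head
head_outputs = [jnp.zeros((B, T, head_size)) for _ in range(num_heads)]

# Concatenating along the last axis restores the embedding dimension
combined = jnp.concatenate(head_outputs, axis=-1)
print(combined.shape)  # (2, 8, 32)
```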


### References:
- [1] [GPT colab notebook](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [2] [Video: multi-headed self-attention](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=4919s)
- [3] [Attention Is All You Need](https://arxiv.org/abs/1706.03762)



In [1]:
import os
import requests
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import value_and_grad

from helper_funcs import get_batch, generate, masked_fill, loss_fn

In [2]:
n_embed = 32 # Number of embedding dimensions
batch_size = 32 # How many independent sequences will we process in parallel?
block_size = 8 # What is the maximum context length for predictions?
num_heads = 4 # Number of heads in the multi-headed block

rng_key = jax.random.PRNGKey(128)

## Data Preparation

In [3]:
# download the tiny shakespeare dataset
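input_file_path = './data/shakespeare_char/input.txt'
if not os.path.exists(input_file_path):
    # create the data folder if it is missing, otherwise open() below fails
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)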
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    l = np.array(l)
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and validation splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

train_ids = jnp.array(train_ids, dtype=jnp.uint16)
val_ids = jnp.array(val_ids, dtype=jnp.uint16)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


In [4]:
print(decode(train_ids[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
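
The training loop later draws batches with the `get_batch` helper, whose implementation is not shown in this notebook. As a rough sketch of what such a function might do, the hypothetical `sample_batch` below picks random start offsets and returns input windows together with targets shifted one position to the right. The name, signature, and sampling strategy are assumptions for illustration, not the helper's actual code.

```python
import jax
import jax.numpy as jnp

def sample_batch(ids, key, batch_size, block_size):
    """Hypothetical stand-in for get_batch: random windows plus shifted targets."""
    # Latest valid start leaves room for block_size inputs and one extra target
    starts = jax.random.randint(key, (batch_size,), 0, len(ids) - block_size)
    offsets = jnp.arange(block_size)
    x = ids[starts[:, None] + offsets]      # (batch_size, block_size)
    y = ids[starts[:, None] + offsets + 1]  # same windows, shifted by one token
    return x, y

# Toy data: with ids = 0..99, every target should equal its input plus one
ids = jnp.arange(100)
xb, yb = sample_batch(ids, jax.random.PRNGKey(0), batch_size=4, block_size=8)
print(xb.shape, yb.shape)
```

Each target `yb[b, t]` is the token that follows `xb[b, t]` in the text, so a single window of length `block_size` yields `block_size` next-token training examples.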


## Build the Attention Model

In [5]:
class Head(nn.Module):
    """
    A single-headed self-attention decoder block.
    Takes the combined token and position embedding as input,
    then calculates the key and query values.
    The query and key are multiplied to calculate the
    attention scores/affinities. Scores for future positions
    are then masked to -inf, so they receive zero weight after
    the softmax; this ensures the model can't "cheat". The input
    is also used to calculate the values, which are aggregated
    by multiplying them with the attention weights.
    """
    head_size: int

    @nn.compact
    def __call__(self, x):
        B,T,C = x.shape
        key = nn.Dense(self.head_size, use_bias=False)
        k = key(x) # (B, T, head_size)
        query = nn.Dense(self.head_size, use_bias=False)
        q = query(x) # (B, T, head_size)
        # compute attention scores ("affinities")
        weights = q @ k.transpose((0, -1, -2)) * self.head_size**-0.5 # (B, T, hs) @ (B, hs, T) ---> (B, T, T)
        tril = jnp.tril(jnp.ones(shape=(T, T), dtype=bool))
        tril = jnp.repeat(tril[None, ...], repeats=B, axis=0)
        weights = masked_fill(tril, weights, -jnp.inf)
        weights = jax.nn.softmax(weights, axis=-1)
        # perform the weighted aggregation of the values
        value = nn.Dense(self.head_size, use_bias=False)
        v = value(x)
        out = weights @ v
        return out
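
The masking step in `Head` can be illustrated in isolation. The sketch below uses `jnp.where` as a stand-in for the `masked_fill` helper, whose implementation is not shown here, so the exact call is an assumption. It starts from uniform raw scores to make the effect of the mask easy to read.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jnp.zeros((T, T))  # uniform raw affinities, for illustration only

# Lower-triangular mask: position t may attend to positions 0..t
mask = jnp.tril(jnp.ones((T, T), dtype=bool))

# Assumed equivalent of masked_fill: keep allowed scores, set the rest to -inf
masked = jnp.where(mask, scores, -jnp.inf)

# exp(-inf) = 0, so future positions get exactly zero weight
weights = jax.nn.softmax(masked, axis=-1)
print(weights)
```

Every row still sums to one, and row `t` spreads its weight evenly over the `t + 1` positions it is allowed to see. The diagonal is always unmasked, so no row is entirely `-inf` and the softmax stays well defined.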

In [12]:
class MultiHeadedAttention(nn.Module):
    """
    Combines multiple heads of scaled self-attention
    in parallel, then concatenates the heads' outputs.
    """
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x):
        # Create a list of num_heads heads
        heads = [Head(self.head_size) for _ in range(self.num_heads)]
        # Provide the same input for each head
        heads_out = [h(x) for h in heads]
        # Concatenate along the feature axis: (B, T, head_size) per head ---> (B, T, num_heads * head_size)
        combined_out = jnp.concatenate(heads_out, axis=-1)
        return combined_out

In [13]:
class AttentionLanguageModel(nn.Module):
    """
    Multi-headed self-attention language model.
    Uses the previous tokens in the sequence to 
    determine the probabilities of the next token.
    Processes the combined position and token embedding
    through a multi-headed self-attention decoder block, 
    which is then processed through a dense layer to 
    acquire the token logits.
    The logits can then be processed through a softmax
    function to calculate the token probabilities.
    """
    vocab_size: int
    n_embed: int
    block_size: int
    num_heads: int
    
    @nn.compact
    def __call__(self, index_seq):
        B, T = index_seq.shape

        # Look up a learned embedding vector for each token
        token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed) 
        token_emb = token_embedding_table(index_seq) # (B, T, C)

        position_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embed) 
        pos_emb = position_embedding_table(jnp.arange(T)) # (T, C)

        x = token_emb + pos_emb # (B, T, C)

        head_size = self.n_embed // self.num_heads
        sa_heads = MultiHeadedAttention(self.num_heads, head_size)
        x = sa_heads(x) # apply multi-headed self-attention (B, T, C)

        lm_head = nn.Dense(self.vocab_size)
        logits = lm_head(x) # (B, T, vocab_size)

        return logits

In [14]:
model = AttentionLanguageModel(vocab_size, n_embed, block_size, num_heads)
dummy_x = jnp.zeros(shape=(batch_size, block_size), dtype=jnp.uint16)
variables = model.init(rng_key, dummy_x)

In [15]:
out = model.apply(variables, dummy_x)
print(out.shape)

(32, 8, 65)
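
To make the step from logits to a next token concrete, the sketch below converts the logits at the last position into probabilities and samples one token index. Random logits stand in for real model output, and this shows a single sampling step under those assumptions; it is not the implementation of the `generate` helper, which is not shown in this notebook.

```python
import jax
import jax.numpy as jnp

vocab_size = 65
key = jax.random.PRNGKey(0)
logits_key, sample_key = jax.random.split(key)

# Stand-in for model output with shape (B, T, vocab_size)
logits = jax.random.normal(logits_key, (1, 8, vocab_size))

# Only the last time step predicts the token that comes next
last_logits = logits[:, -1, :]                # (B, vocab_size)
probs = jax.nn.softmax(last_logits, axis=-1)  # probabilities sum to 1 per row

# Sample one token index per sequence; categorical takes unnormalized logits
next_token = jax.random.categorical(sample_key, last_logits, axis=-1)  # (B,)
print(probs.shape, next_token.shape)
```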


## Text Generation Before Training

In [16]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

FeRkiTwf.'jtMwetQ
y;:zYFfVmFiPuyZaYrv,wzii Tgh,j.rE.'srlHn&QEy'sjb33e&?K3,FEgIeNBOm zv;vZlQlNm.mqbbM


## Train the Model

In [17]:
optimizer = optax.adamw(learning_rate=1e-2)
opt_state = optimizer.init(variables)

In [18]:
steps = 100

for step in range(steps):
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = value_and_grad(loss_fn, argnums=(0))(
        variables, 
        model.apply,
        xb, 
        yb
    )
    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates) 

    print(f"Epoch: {step}, Loss: {loss :.4f}")

Epoch: 0, Loss: 4.2002
Epoch: 1, Loss: 4.1314
Epoch: 2, Loss: 4.0487
Epoch: 3, Loss: 3.9995
Epoch: 4, Loss: 3.9011
Epoch: 5, Loss: 3.7680
Epoch: 6, Loss: 3.6813
Epoch: 7, Loss: 3.5244
Epoch: 8, Loss: 3.4426
Epoch: 9, Loss: 3.4472
Epoch: 10, Loss: 3.4264
Epoch: 11, Loss: 3.4981
Epoch: 12, Loss: 3.1786
Epoch: 13, Loss: 3.2236
Epoch: 14, Loss: 3.6526
Epoch: 15, Loss: 3.4811
Epoch: 16, Loss: 3.3162
Epoch: 17, Loss: 3.2343
Epoch: 18, Loss: 3.2692
Epoch: 19, Loss: 3.3367
Epoch: 20, Loss: 3.1827
Epoch: 21, Loss: 3.2085
Epoch: 22, Loss: 3.1524
Epoch: 23, Loss: 3.1159
Epoch: 24, Loss: 3.1055
Epoch: 25, Loss: 3.1955
Epoch: 26, Loss: 3.2321
Epoch: 27, Loss: 3.0995
Epoch: 28, Loss: 3.1132
Epoch: 29, Loss: 3.1211
Epoch: 30, Loss: 3.1240
Epoch: 31, Loss: 3.1205
Epoch: 32, Loss: 3.0943
Epoch: 33, Loss: 3.0640
Epoch: 34, Loss: 3.0053
Epoch: 35, Loss: 3.0887
Epoch: 36, Loss: 2.9509
Epoch: 37, Loss: 2.9637
Epoch: 38, Loss: 2.8833
Epoch: 39, Loss: 3.0016
Epoch: 40, Loss: 2.8503
Epoch: 41, Loss: 3.0264
Ep
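
As a sanity check on the first logged loss: a model that assigns equal probability to all 65 characters has a cross-entropy of ln(65) ≈ 4.174, which is close to the 4.2002 reported at step 0. The sketch below computes this baseline. It assumes `loss_fn` is a mean cross-entropy over next-token targets, which is not confirmed here because the helper's code is not shown.

```python
import jax
import jax.numpy as jnp

vocab_size = 65

# Equal logits correspond to a uniform distribution over the vocabulary
logits = jnp.zeros((1, vocab_size))
targets = jnp.array([0])  # any target gives the same loss under a uniform guess

# Cross-entropy: negative log-probability assigned to the correct token
log_probs = jax.nn.log_softmax(logits, axis=-1)
loss = -log_probs[jnp.arange(targets.shape[0]), targets].mean()

expected = jnp.log(vocab_size)  # ln(65), about 4.174
print(float(loss), float(expected))
```

A loss well below this baseline indicates that the model has learned something beyond guessing uniformly.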

## Text Generation After Training

In [19]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

Tot hte; hef preer tusl ous blrlee.

N Gowrvee.


S yorc.

o
L:
Fe've,
:

Tetscshe Yfodot d muj inOt
