# Scaled Self-Attention Transformer Decoder

We will build an autoregressive, single-headed self-attention language model that predicts the next token in a sequence of text. This demonstration uses the shakespeare_char dataset, which can be found in the data folder.

In the Bigram model, the raw token embedding was used directly to perform the prediction. Here we combine the token embedding with a positional embedding, process the combined embedding through a single-headed self-attention block, and then pass the result through a dense layer to get the logits used for prediction. The self-attention head mechanism is shown in the 'Self-Attention Mathematical Trick' section; further explanation is provided in Andrej Karpathy's video: [THE CRUX OF THE VIDEO: version 4: self-attention](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=3720s).


### References:
- [GPT colab notebook](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [Video: "scaled" self-attention](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=4616s)
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762)



In [1]:
import os
import requests
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import value_and_grad

from helper_funcs import get_batch, generate, masked_fill, loss_fn

In [2]:
n_embed = 32
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

rng_key = jax.random.PRNGKey(128)

## Data Preparation

In [3]:
# download the tiny shakespeare dataset
input_file_path = os.path.join('data', 'shakespeare_char', 'input.txt')
if not os.path.exists(input_file_path):
    # create the target folder if it is missing so the download can be written
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    l = np.array(l)
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

train_ids = jnp.array(train_ids, dtype=jnp.uint16)
val_ids = jnp.array(val_ids, dtype=jnp.uint16)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


In [4]:
print(decode(train_ids[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## Self-Attention Mathematical Trick

In [5]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
a = jnp.tril(jnp.ones(shape=(3, 3)))
a = a / jnp.sum(a, axis=1, keepdims=True)
b = random.randint(key=rng_key, shape=(3, 2), minval=0, maxval=10)
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
--
b=
[[5 7]
 [3 2]
 [9 9]]
--
c=
[[5.        7.       ]
 [4.        4.5      ]
 [5.666667  6.0000005]]
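
Each row of `c` above is the average of the rows of `b` seen so far: for example the second row is (5+3)/2 = 4 and (7+2)/2 = 4.5. The following is a minimal sketch that checks this against a directly computed running mean. The values of `b` are copied from the printed output so the check does not depend on the random key, and `running_mean` is a name introduced here for illustration.

```python
import jax.numpy as jnp

# Same lower-triangular averaging matrix as in the toy example above.
a = jnp.tril(jnp.ones((3, 3)))
a = a / jnp.sum(a, axis=1, keepdims=True)

# Values copied from the printed `b` so the check is reproducible.
b = jnp.array([[5.0, 7.0], [3.0, 2.0], [9.0, 9.0]])

# Running mean computed directly: cumulative sum divided by the number of rows so far.
counts = jnp.arange(1, b.shape[0] + 1)[:, None]
running_mean = jnp.cumsum(b, axis=0) / counts

# The matrix product and the running mean should agree up to float rounding.
print(jnp.allclose(a @ b, running_mean))
```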


In [6]:
B,T,C = 4,8,2 # batch, time, channels
x = random.normal(key=rng_key, shape=(B, T, C))
x.shape

(4, 8, 2)

In [7]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = np.zeros(shape=(B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1,C)
        xbow[b,t] = np.mean(xprev, axis=0)

In [8]:
# version 2: using matrix multiply for a weighted aggregation
weights = jnp.tril(jnp.ones(shape=(T, T)))
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
xbow2 = weights @ x # (T, T) @ (B, T, C) ----> (B, T, C), weights broadcast over the batch
jnp.allclose(xbow, xbow2)

Array(True, dtype=bool)

In [9]:
# version 3: use Softmax
tril = jnp.tril(jnp.ones(shape=(T, T), dtype=bool))
weights = jnp.zeros(shape=(T, T))
weights = masked_fill(tril, weights, -jnp.inf)
weights = jax.nn.softmax(weights, axis=-1)
xbow3 = weights @ x
jnp.allclose(xbow, xbow3)

Array(True, dtype=bool)
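
The `masked_fill` helper is imported from `helper_funcs`, whose source is not shown here. Judging only from how it is called above, it plausibly keeps the entries where the mask is `True` and writes the fill value everywhere else. The sketch below is written under that assumption; `masked_fill_sketch` is a hypothetical stand-in, not the actual helper.

```python
import jax
import jax.numpy as jnp

def masked_fill_sketch(mask, x, fill_value):
    """Assumed behaviour: keep x where mask is True, write fill_value elsewhere."""
    return jnp.where(mask, x, fill_value)

T = 4
tril = jnp.tril(jnp.ones((T, T), dtype=bool))

# Start from all-zero scores, then hide the future (upper triangle) with -inf.
scores = jnp.zeros((T, T))
masked = masked_fill_sketch(tril, scores, -jnp.inf)

# Softmax turns -inf into weight 0 and spreads the rest evenly over the visible past.
weights = jax.nn.softmax(masked, axis=-1)

# Reference: the row-normalised lower-triangular matrix from version 2.
uniform = jnp.tril(jnp.ones((T, T)))
uniform = uniform / uniform.sum(axis=1, keepdims=True)

print(jnp.allclose(weights, uniform))
```

Because the scores start at zero, the softmax reproduces the plain running average. Once the scores become data-dependent, the same masking step yields a learned weighted average over the past.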

In [10]:
# version 4: self-attention!
B,T,C = 4,8,32 # batch, time, channels
x = random.normal(key=rng_key, shape=(B, T, C))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Dense(head_size, use_bias=False)
query = nn.Dense(head_size, use_bias=False)
value = nn.Dense(head_size, use_bias=False)

rng_key, subkey = jax.random.split(rng_key)
key_variables = key.init(subkey, x)
rng_key, subkey = jax.random.split(rng_key)
query_variables = query.init(subkey, x)
rng_key, subkey = jax.random.split(rng_key)
value_variables = value.init(subkey, x)

k = key.apply(key_variables, x)   # (B, T, 16)
q = query.apply(query_variables, x) # (B, T, 16)
weights =  q @ k.transpose((0, -1, -2)) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = jnp.tril(jnp.ones(shape=(T, T), dtype=bool))
tril = jnp.repeat(tril[None, ...], repeats=B, axis=0)
weights = masked_fill(tril, weights, -jnp.inf)
weights = jax.nn.softmax(weights, axis=-1)

v = value.apply(value_variables, x)
out = weights @ v

out.shape


(4, 8, 16)

In [11]:
weights[0]

Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [8.3365250e-01, 1.6634753e-01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [2.7693647e-01, 7.2302037e-01, 4.3186730e-05, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [1.3745368e-04, 5.7609759e-02, 2.2170721e-03, 9.4003570e-01,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [9.7506523e-01, 2.3148456e-03, 7.4237287e-03, 9.0312017e-03,
        6.1650318e-03, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [1.5689149e-01, 2.0382376e-02, 2.6865533e-01, 2.7895829e-01,
        9.8608084e-02, 1.7650440e-01, 0.0000000e+00, 0.0000000e+00],
       [9.9750141e-06, 3.4491016e-04, 3.6039501e-02, 1.0690327e-01,
        8.4371493e-07, 2.5878797e-05, 8.5667557e-01, 0.0000000e+00],
       [1.2620859e-01, 6.4960361e-02, 1.4

Notes:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- The examples across the batch dimension are processed completely independently and never "talk" to each other.
- In an "encoder" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and it is usually used in autoregressive settings, like language modeling.
- "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).
- "Scaled" attention additionally multiplies `weights` by 1/sqrt(head_size), i.e. divides by sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, `weights` will have unit variance too, and the softmax will stay diffuse and not saturate too much. This is illustrated below.

In [12]:
rng_key, subkey = jax.random.split(rng_key)
k = random.normal(key=subkey, shape=(B, T, C))
rng_key, subkey = jax.random.split(rng_key)
q = random.normal(key=subkey, shape=(B, T, C))

weights =  q @ k.transpose((0, -1, -2)) * C**-0.5

In [13]:
k.var()

Array(1.0384574, dtype=float32)

In [14]:
q.var()

Array(1.0208337, dtype=float32)

In [15]:
weights.var()

Array(0.96584, dtype=float32)

In [16]:
jax.nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5]), axis=-1)

Array([0.19249782, 0.1426059 , 0.23511738, 0.1426059 , 0.287173  ],      dtype=float32)

In [17]:
jax.nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5])*8, axis=-1) # gets too peaky, converges to one-hot

Array([0.03260834, 0.00295816, 0.1615102 , 0.00295816, 0.79996514],      dtype=float32)
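
The unit-variance claim from the notes can be checked with a short derivation. It assumes, as an idealisation that the random-normal `q` and `k` above satisfy, that the components of a query $q$ and a key $k$ are independent with zero mean and unit variance. Here $d$ denotes the size of the last dimension (`C = 32` in the cell above, `head_size` inside the model).

```latex
\begin{aligned}
q \cdot k &= \sum_{i=1}^{d} q_i k_i, \qquad
\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \\
\operatorname{Var}(q_i k_i) &= \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1, \\
\operatorname{Var}(q \cdot k) &= \sum_{i=1}^{d} \operatorname{Var}(q_i k_i) = d, \\
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d}}\right) &= \frac{d}{d} = 1 .
\end{aligned}
```

This agrees with the measured `weights.var()` of roughly 0.97 above. Without the scaling, the variance would grow to about $d$, and the softmax would be pushed toward the peaky, near one-hot regime shown in the last cell.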

In [18]:
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>

## Build the Attention Model

In [19]:
class Head(nn.Module):
    """
    A single-headed self-attention decoder block.
    Takes the combined token and position embedding as input,
    then calculates the keys and queries.
    The queries and keys are multiplied to calculate the
    attention scores/affinities. The scores for future
    positions are then set to -inf so they receive zero
    weight after the softmax, which ensures the model
    can't "cheat". The input is then used to calculate
    the values, which are aggregated by multiplying
    them with the weights.
    The aggregated values are the output of the head.
    """
    head_size: int

    @nn.compact
    def __call__(self, x):
        B,T,C = x.shape
        key = nn.Dense(self.head_size, use_bias=False)
        k = key(x) # (B, T, head_size)
        query = nn.Dense(self.head_size, use_bias=False)
        q = query(x) # (B, T, head_size)
        # compute attention scores ("affinities")
        weights =  q @ k.transpose((0, -1, -2)) * self.head_size**-0.5 # (B, T, head_size) @ (B, head_size, T) ---> (B, T, T)
        tril = jnp.tril(jnp.ones(shape=(T, T), dtype=bool))
        tril = jnp.repeat(tril[None, ...], repeats=B, axis=0)
        weights = masked_fill(tril, weights, -jnp.inf)
        weights = jax.nn.softmax(weights, axis=-1)
        # perform the weighted aggregation of the values
        value = nn.Dense(self.head_size, use_bias=False)
        v = value(x)
        out = weights @ v
        return out
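
The triangular mask in `Head` is what keeps the model from "cheating". The sketch below re-implements the same computation as a plain function, without Flax, so that the causal property can be checked directly: changing the last token must leave the outputs at all earlier positions untouched. The function `causal_attention` and the random weight matrices are illustrative assumptions, not part of the notebook's model.

```python
import jax
import jax.numpy as jnp

def causal_attention(x, w_q, w_k, w_v):
    """Scaled, causally masked single-head attention on x of shape (B, T, C)."""
    q = x @ w_q  # (B, T, head_size)
    k = x @ w_k  # (B, T, head_size)
    v = x @ w_v  # (B, T, head_size)
    head_size = w_q.shape[-1]
    scores = q @ jnp.swapaxes(k, -1, -2) * head_size**-0.5  # (B, T, T)
    T = x.shape[1]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)  # hide future positions
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v  # (B, T, head_size)

B, T, C, head_size = 2, 8, 32, 16
kx, kq, kk, kv, kn = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (B, T, C))
w_q = jax.random.normal(kq, (C, head_size)) * C**-0.5
w_k = jax.random.normal(kk, (C, head_size)) * C**-0.5
w_v = jax.random.normal(kv, (C, head_size)) * C**-0.5

out = causal_attention(x, w_q, w_k, w_v)

# Perturb only the final time step and recompute.
x_changed = x.at[:, -1, :].add(jax.random.normal(kn, (B, C)))
out_changed = causal_attention(x_changed, w_q, w_k, w_v)

print(out.shape)
# Earlier positions cannot see the last token, so their outputs should match.
print(jnp.allclose(out[:, :-1], out_changed[:, :-1]))
```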

In [20]:
class AttentionLanguageModel(nn.Module):
    """
    Single-headed self-attention language model.
    Uses the previous tokens in the sequence (up to
    block_size of them) to determine the probabilities
    of the next token.
    Processes the combined position and token embedding
    through a self-attention decoder block, which is then
    processed through a dense layer to acquire the token logits.
    The logits can then be processed through a softmax
    function to calculate the token probabilities.
    """
    vocab_size: int
    n_embed: int
    block_size: int
    head_size: int
    
    @nn.compact
    def __call__(self, index_seq):
        B, T = index_seq.shape

        # Each token index is mapped to an n_embed-dimensional embedding vector via a lookup table
        token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed) 
        token_emb = token_embedding_table(index_seq) # (B, T, C)

        position_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embed) 
        pos_emb = position_embedding_table(jnp.arange(T)) # (T, C)

        x = token_emb + pos_emb # (B, T, C)

        sa_head = Head(self.head_size)
        x = sa_head(x) # apply one head of self-attention (B, T, head_size)

        lm_head = nn.Dense(self.vocab_size)
        logits = lm_head(x) # (B, T, vocab_size)

        return logits
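
The model adds `pos_emb` because attention on its own has no notion of order. The sketch below illustrates this with a plain-JAX stand-in for the model's embedding and attention steps. Without position embeddings, shuffling the earlier tokens leaves the output at the last position unchanged. With them, the output changes. All tables and weights here are random and purely illustrative, and `causal_attention` and `last_output` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def causal_attention(x, w_q, w_k, w_v):
    """Scaled, causally masked single-head attention on x of shape (B, T, C)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ jnp.swapaxes(k, -1, -2) * w_q.shape[-1]**-0.5
    T = x.shape[1]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    weights = jax.nn.softmax(jnp.where(mask, scores, -jnp.inf), axis=-1)
    return weights @ v

vocab_size, block_size, n_embed, head_size = 65, 8, 32, 16
k_tok, k_pos, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 5)
token_table = jax.random.normal(k_tok, (vocab_size, n_embed))
pos_table = jax.random.normal(k_pos, (block_size, n_embed))
w_q = jax.random.normal(kq, (n_embed, head_size)) * n_embed**-0.5
w_k = jax.random.normal(kk, (n_embed, head_size)) * n_embed**-0.5
w_v = jax.random.normal(kv, (n_embed, head_size)) * n_embed**-0.5

# Same multiset of tokens; only the order of the first seven differs.
tokens = jnp.array([[3, 10, 25, 7, 1, 40, 12, 5]])
shuffled = jnp.array([[12, 1, 7, 40, 3, 25, 10, 5]])

def last_output(idx, use_positions):
    x = token_table[idx]  # (B, T, n_embed)
    if use_positions:
        # (T, n_embed) broadcasts across the batch dimension
        x = x + pos_table[jnp.arange(idx.shape[1])]
    return causal_attention(x, w_q, w_k, w_v)[:, -1]

same_without = jnp.allclose(last_output(tokens, False), last_output(shuffled, False), atol=1e-5)
same_with = jnp.allclose(last_output(tokens, True), last_output(shuffled, True), atol=1e-5)
print(same_without, same_with)
```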

In [21]:
model = AttentionLanguageModel(vocab_size, n_embed, block_size, head_size)
dummy_x = jnp.zeros(shape=(batch_size, block_size), dtype=jnp.uint16)
variables = model.init(rng_key, dummy_x)

In [22]:
out = model.apply(variables, dummy_x)
print(out.shape)

(32, 8, 65)


## Text Generation Pre-Training

In [23]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

Tvg.'jtMvetQ
y;:zYEfUmEhOuyYaXqu,wzhi Sfh,i3qD-'rqjGm&PDy'sja33d&?J3,EEgIdMBOm zu;vZlPkMm.lqbbLmqhFJ
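
The `generate` helper comes from `helper_funcs` and its source is not shown, so the meaning of each positional argument in the call above (in particular the literal `1`) is not confirmed here. The following is a minimal sketch of how autoregressive sampling is commonly done, under the assumption that the apply function maps `(variables, index_seq)` to logits of shape `(B, T, vocab_size)`. `generate_sketch` and `dummy_apply` are hypothetical names, and the signature deliberately differs from the real helper.

```python
import jax
import jax.numpy as jnp

def generate_sketch(variables, apply_fn, index_seq, rng_key, block_size, max_new_tokens):
    """Hypothetical sampler: append one sampled token per step to index_seq (B, T)."""
    for _ in range(max_new_tokens):
        # The position table only covers block_size positions, so crop the context.
        context = index_seq[:, -block_size:]
        logits = apply_fn(variables, context)      # (B, T, vocab_size)
        last_logits = logits[:, -1, :]             # prediction for the next token
        rng_key, subkey = jax.random.split(rng_key)
        next_token = jax.random.categorical(subkey, last_logits, axis=-1)  # (B,)
        next_token = next_token[:, None].astype(index_seq.dtype)
        index_seq = jnp.concatenate([index_seq, next_token], axis=1)
    return index_seq

# Stand-in for model.apply: puts essentially all probability on token 3.
def dummy_apply(variables, idx):
    B, T = idx.shape
    logits = jnp.full((B, T, 65), -1e9)
    return logits.at[:, :, 3].set(0.0)

start = jnp.zeros((1, 1), dtype=jnp.int32)
out = generate_sketch(None, dummy_apply, start, jax.random.PRNGKey(0),
                      block_size=8, max_new_tokens=10)
print(out.shape)
```

Sampling from the distribution, rather than always taking the most likely token, is why an untrained model emits the varied gibberish shown above.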


## Train the Model

In [24]:
optimizer = optax.adamw(learning_rate=1e-2)
opt_state = optimizer.init(variables)
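
The training loop below relies on `get_batch` and `loss_fn` from `helper_funcs`, which are not shown in this notebook. The sketch below is one plausible reading of them, inferred only from their call sites: random windows of length `block_size` with targets shifted by one character, and a mean cross-entropy over all positions. The `_sketch` names and function bodies are assumptions, not the actual helpers.

```python
import jax
import jax.numpy as jnp

def get_batch_sketch(data, rng_key, batch_size, block_size):
    """Hypothetical batcher: random windows; targets are the inputs shifted by one."""
    starts = jax.random.randint(rng_key, (batch_size,), 0, len(data) - block_size)
    offsets = jnp.arange(block_size)
    x = data[starts[:, None] + offsets]      # (batch_size, block_size)
    y = data[starts[:, None] + offsets + 1]  # next-token targets
    return x, y

def loss_fn_sketch(variables, apply_fn, xb, yb):
    """Hypothetical loss: mean next-token cross-entropy over all positions."""
    logits = apply_fn(variables, xb)                 # (B, T, vocab_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets = yb[..., None].astype(jnp.int32)
    picked = jnp.take_along_axis(log_probs, targets, axis=-1)
    return -jnp.mean(picked)

# Toy data and a stand-in "model" that predicts a uniform distribution over 65 tokens.
data = jnp.arange(1000) % 65
xb, yb = get_batch_sketch(data, jax.random.PRNGKey(0), batch_size=4, block_size=8)
uniform_apply = lambda variables, idx: jnp.zeros(idx.shape + (65,))
loss = loss_fn_sketch(None, uniform_apply, xb, yb)
print(xb.shape, float(loss))
```

A uniform guess over the 65 characters gives a loss of ln 65 ≈ 4.174, which is close to the first loss value logged by the training loop, as expected for a freshly initialised model.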

In [25]:
steps = 100

for step in range(steps):
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = value_and_grad(loss_fn, argnums=(0))(
        variables, 
        model.apply,
        xb, 
        yb
    )
    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates) 

    print(f"Epoch: {step}, Loss: {loss :.4f}")

Epoch: 0, Loss: 4.1797
Epoch: 1, Loss: 4.1357
Epoch: 2, Loss: 4.1036
Epoch: 3, Loss: 4.0445
Epoch: 4, Loss: 3.9720
Epoch: 5, Loss: 3.9320
Epoch: 6, Loss: 3.8340
Epoch: 7, Loss: 3.7325
Epoch: 8, Loss: 3.6364
Epoch: 9, Loss: 3.7192
Epoch: 10, Loss: 3.6024
Epoch: 11, Loss: 3.5051
Epoch: 12, Loss: 3.3887
Epoch: 13, Loss: 3.3672
Epoch: 14, Loss: 3.4806
Epoch: 15, Loss: 3.3368
Epoch: 16, Loss: 3.3332
Epoch: 17, Loss: 3.2800
Epoch: 18, Loss: 3.1807
Epoch: 19, Loss: 3.2574
Epoch: 20, Loss: 3.2909
Epoch: 21, Loss: 3.3127
Epoch: 22, Loss: 3.2338
Epoch: 23, Loss: 3.2703
Epoch: 24, Loss: 3.2303
Epoch: 25, Loss: 3.2596
Epoch: 26, Loss: 3.2458
Epoch: 27, Loss: 3.2228
Epoch: 28, Loss: 3.2018
Epoch: 29, Loss: 3.1379
Epoch: 30, Loss: 3.2039
Epoch: 31, Loss: 3.0974
Epoch: 32, Loss: 3.1259
Epoch: 33, Loss: 3.0729
Epoch: 34, Loss: 3.1007
Epoch: 35, Loss: 3.0092
Epoch: 36, Loss: 3.1513
Epoch: 37, Loss: 3.0775
Epoch: 38, Loss: 3.0723
Epoch: 39, Loss: 3.0594
Epoch: 40, Loss: 3.3393
Epoch: 41, Loss: 3.2729
Ep

## Text Generation Post-Training

In [26]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

ag, g hart,
T:
llo nth anotd, I pe uusurkde,
TOurnm, by
J
O faul?
A bthsurthe.
Coetilc mun he ur, ht
