## Part 2: Building self-attention from Bigram Model flaws

In this chapter we extend our bigram model. Its main limitation is that it predicts the next token only from the directly preceding token, regardless of how much context is available. For example, if the context offers 5 preceding tokens, the model ignores 4 of them and uses only the last one. We now improve on this by designing a mechanism that takes every token inside a context window of configurable length into account.

For example: given the tokens `["h", "e", "l", "l"]`, instead of using only the last `"l"` to predict the next token, we use all 4 tokens to predict `"o"`.
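One block of T + 1 tokens yields T training examples, one per position: the context is every token up to position t and the target is the token at position t + 1. The sketch below assumes hypothetical toy ids that are simply each letter's position in the alphabet (h=7, e=4, l=11, o=14), not ids from the Shakespeare vocabulary used later.

```python
import torch

# Hypothetical toy ids for "hello": each letter's position in the alphabet
block = torch.tensor([7, 4, 11, 11, 14])
x, y = block[:-1], block[1:]  # inputs and next-token targets, each of length T = 4

# One (context, target) pair per position t.
# A bigram model only looks at x[t]; a context-aware model may use all of x[:t+1].
pairs = [(x[: t + 1].tolist(), y[t].item()) for t in range(len(x))]
for context, target in pairs:
    print(f"context {context} -> target {target}")
```

The last pair, `[7, 4, 11, 11] -> 14`, is the "hell" to "o" example from above.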

For this, let's first import the necessary libraries and reuse the code snippets from the first part so we don't need to rewrite them.

In [1]:
# Imports
import torch

# Dataset processing
!curl.exe --output shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
data = torch.tensor(encode(text), dtype=torch.long)

# Generate train and test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Set hyperparameters
B,T,C = 4,8,2  # Batch, Block size (timesteps), Number of channels

def get_batch(split):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(data) - T, size=(B,))  # B random start positions, one per sequence in the batch
    x = torch.stack([data[i:i + T] for i in idxs])
    y = torch.stack([data[i+1:i+T+1] for i in idxs])
    return x,y  



### Making tokens attend to each other

The easiest way to make tokens (vectors) "attend" to each other is to average over them: each position aggregates the information of itself and all preceding tokens into one representation.
The code below illustrates this for a dummy sample from our Shakespeare text.

In [2]:
import torch.nn as nn

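def generate_embedding(x):
    # Fake, untrained "embedding" for illustration only: a random linear map
    # from the T token ids to T*C numbers, reshaped to (T, C) by the caller
    torch.manual_seed(42)
    fake_emb = nn.Linear(T, T*C)
    return fake_emb(x)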
    
b_x, b_y = get_batch('train')
x,y = b_x[0],b_y[0]
emb_x = generate_embedding(x.to(dtype=torch.float32)).view(T,C)
    
avg_context = []
for t in range(T):
    context = x[:t+1]
    emb_context = emb_x[:t+1]
    target = y[t]
    print(f"To predict: {target} we use this context: {context}")
    print(f"Encoded context: \n{emb_context}")
    avg = torch.mean(emb_context.to(dtype=torch.float32), 0)
    print(f"Averaged context: \n{avg} \n")
    avg_context.append(avg.detach().numpy())


To predict: 50 we use this context: tensor([1])
Encoded context: 
tensor([[32.1431, 18.6732]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([32.1431, 18.6732], grad_fn=<MeanBackward1>) 

To predict: 43 we use this context: tensor([ 1, 50])
Encoded context: 
tensor([[ 32.1431,  18.6732],
        [ -1.3044, -47.4114]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([ 15.4194, -14.3691], grad_fn=<MeanBackward1>) 

To predict: 39 we use this context: tensor([ 1, 50, 43])
Encoded context: 
tensor([[ 32.1431,  18.6732],
        [ -1.3044, -47.4114],
        [ 16.9987,  22.9851]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([15.9458, -1.9177], grad_fn=<MeanBackward1>) 

To predict: 60 we use this context: tensor([ 1, 50, 43, 39])
Encoded context: 
tensor([[ 32.1431,  18.6732],
        [ -1.3044, -47.4114],
        [ 16.9987,  22.9851],
        [-28.0273, -12.7631]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([ 4.9525, -4.6290], grad_fn=<MeanBackward1>) 

To predic

This position-wise averaging of the context can be computed very efficiently with a trick based on matrix multiplication.
Standard matrix multiplication is shown below:

In [3]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()

In [4]:
print(f"A:\n{a}\nB:\n{b}")
c = a @ b # @ = matrix multiplication (https://en.wikipedia.org/wiki/Matrix_multiplication)
print(f"C:\n{c}")

A:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
C:
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
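Element-wise, each entry of C is the dot product of a row of A with a column of B. Because every entry of A is 1 here, every row of C is just the column sums of the B printed above:

$$
C_{ij} = \sum_{k} A_{ik} B_{kj} \;\overset{A_{ik}=1}{=}\; \sum_{k} B_{kj},
\qquad \text{e.g. } C_{i1} = 2 + 6 + 6 = 14,\quad C_{i2} = 7 + 4 + 5 = 16.
$$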


Using a lower triangular matrix, we can compute the for loop over T in parallel.
When A is lower triangular, row t of C = A @ B is the sum of rows 0 to t of B, i.e. a cumulative sum over the rows of B:

In [5]:
a_triangular = torch.tril(a)

c = a_triangular @ b
print(f"C:\n{c}")

# Think of B as a sequence of 3 tokens with 2 channels each: row t of C is the sum of rows 0..t of B.

print(f"B:\n{b}")

C:
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])


We can take this one step further by introducing row-wise factors that set **the weight** of each element in the sum, which turns the row-wise addition into **row-wise averaging**.
For that, we normalize each row of the lower triangular matrix so that it sums to 1:

In [6]:
a_triangular_average = a_triangular / torch.sum(a_triangular, 1, keepdim=True)
a_triangular_average

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [7]:
# And now we obtain a row-wise average
c = a_triangular_average @ b
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
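Written out with 1-indexed positions, the normalized lower triangular matrix has entries $W_{ts} = 1/t$ for $s \le t$ and $0$ otherwise, so each output row is the mean of the first $t$ rows of B:

$$
c_t = \sum_{s=1}^{T} W_{ts}\, b_s = \frac{1}{t} \sum_{s=1}^{t} b_s,
\qquad \text{e.g. } c_2 = \tfrac{1}{2}\big((2, 7) + (6, 4)\big) = (4,\ 5.5).
$$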

### Efficient matrix implementation of attention
We use this principle to replace the for-loop from the beginning and implement our first attention mechanism based on matrix multiplication:

In [8]:
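# Optimized implementation with matrix multiplication
import numpy as np

W = torch.tril(torch.ones(size=(T,T)))
W = W / torch.sum(W, 1, keepdim=True)  # each row sums to 1 -> uniform average over tokens 0..t

out = W @ emb_x  # shape: (T,T) @ (T,C) --> (T,C)

# Stack the list of numpy arrays first; torch.tensor() on a list of ndarrays is slow and emits a warning
torch.allclose(torch.from_numpy(np.stack(avg_context)), out, atol=1e-07)
# We obtain the same result, but much more efficiently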



True

In [9]:
# Alternative Version: Using Softmax to create the percentages of contribution of each token
import torch.nn.functional as F

tril = torch.tril(torch.ones(size=(T,T)))
W = torch.zeros((T,T))
W = W.masked_fill(tril == 0, float('-inf'))
print(W)
# Now if we apply row-wise softmax we obtain the same weight matrix W as before
W = F.softmax(W, dim=-1)
print(W)


tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
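The softmax version is worth the detour because it keeps working once the scores are no longer all zero. Since `exp(-inf) = 0`, masked positions always get exactly zero weight, and the visible positions share the probability mass according to their scores. The sketch below checks both claims; the non-uniform scores in the second part are made-up values:

```python
import torch
import torch.nn.functional as F

T = 4
tril = torch.tril(torch.ones(T, T))

# Version 1: divide each row of the lower triangular matrix by its row sum
w_norm = tril / tril.sum(dim=1, keepdim=True)

# Version 2: set future positions to -inf and apply a row-wise softmax.
# Equal scores (all 0) on the visible positions give a uniform distribution.
scores = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
w_softmax = F.softmax(scores, dim=-1)
print(torch.allclose(w_norm, w_softmax))  # True

# Made-up, non-uniform scores for position 2: the mask still zeroes the
# future position, but the visible positions now get different weights.
row_scores = torch.tensor([2.0, 1.0, 0.0, 0.0]).masked_fill(tril[2] == 0, float("-inf"))
row_weights = F.softmax(row_scores, dim=-1)
print(row_weights)  # last entry is 0, the rest sum to 1
```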


We can now predict the next token from the information of all preceding tokens (within the context window).
However, there is still a problem: our affinities (the weights of the individual tokens in the average) are hardcoded to the same value, i.e. a uniform distribution.

Instead, we want them to be **data dependent**: information gathered from the past is naturally not equally important for every future token.
Therefore, we improve the model by **LEARNING** these affinities between tokens (i.e. learning how much each past token contributes to the prediction).

### Self-Attention: learning affinities

Self-attention learns these affinities by creating multiple learned representations of each input token:
1. Query: "What am I looking for?" (intuition)
2. Key: "What do I contain?" (intuition)

The affinities between the tokens are then computed by performing the dot product of the query and key token representations!

Concretely, for one token: its query (its "What am I looking for?" representation) is dot-multiplied with the key (the "What do I contain/can offer?" representation) of every token it may attend to, which in masked self-attention means every preceding token and itself.
The resulting scores quantify the affinity of those tokens to this token. This is then done for every token, i.e. for every query.
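For a single token, this boils down to a handful of dot products followed by a softmax. The sketch below uses random query and key vectors (the sizes and the chosen position `t` are arbitrary) and checks that the scores match row `t` of the full matrix `q @ k.T`, which is how the next cells compute them for all tokens at once:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, head_size = 5, 4
q = torch.randn(T, head_size)  # one query vector per token
k = torch.randn(T, head_size)  # one key vector per token

t = 3  # the token whose query we look at (0-indexed)

# Dot product of token t's query with the keys of tokens 0..t (no future tokens)
scores = torch.stack([q[t] @ k[s] for s in range(t + 1)])
weights = F.softmax(scores, dim=-1)  # affinities of token t to tokens 0..t

# The same scores are row t of q @ k^T, restricted to columns 0..t
full = q @ k.T
print(torch.allclose(scores, full[t, : t + 1], atol=1e-6))  # True
print(weights, weights.sum())
```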

### Implementing a single self-attention head

In [10]:
torch.manual_seed(1337)

B,T,C = 4,8,32 # batch, time, channels per token
inputs = torch.randn(B,T,C) # input of the attention head

head_size = 16 # channel dimension of the key, query representations of the input
key_repr = nn.Linear(C, head_size, bias=False)
query_repr = nn.Linear(C, head_size, bias=False)

# creating the key and query representations:
k = key_repr(inputs)
q = query_repr(inputs)

print(f"Key representation: {k.shape}")
print(f"Query representation: {q.shape}")

Key representation: torch.Size([4, 8, 16])
Query representation: torch.Size([4, 8, 16])


In [11]:
# To compute our learned W, we now take the dot product of every query with every key,
# which can again be done efficiently with (you guessed it) matrix multiplication!

W = q @ k.transpose(-2,-1)  # (B,T,16) @ (B,16,T) --> (B,T,T)
print(f"W.shape: {W.shape}")

# Because we are still discussing masked self-attention (decoder-style model), we have to apply the mask as above
# The reason is the same as before: we must not use information from future tokens
tril = torch.tril(torch.ones(T,T))
W = W.masked_fill(tril == 0, float('-inf'))
W = F.softmax(W, dim=-1)

W[0]

W.shape: torch.Size([4, 8, 8])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [12]:
# After computing the affinities there is an additional learned representation: the value
value_repr = nn.Linear(C, head_size, bias=False)
v = value_repr(inputs)

# This representation is multiplied with the computed affinities with the following intuition:
# "If you find me interesting (determined by the weight matrix W, computed through key and query representations), 
# here is what I will communicate to you!"

output = W @ v
print(f"Output.shape: {output.shape}")
output[0]

Output.shape: torch.Size([4, 8, 16])


tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0
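The cells above can be packaged into a reusable module. The sketch below is one way to do that; the class name `Head` and its constructor arguments are illustrative choices, not an API from the source. Unlike the cells above, it also applies the 1/sqrt(head_size) scaling described in the notes below, and it stores the causal mask with `register_buffer` so the mask moves with the module (e.g. to a GPU) without being a learned parameter:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Head(nn.Module):
    """One masked (decoder-style) self-attention head; an illustrative sketch."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask: part of the module state, but not trained
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Affinities, scaled by 1/sqrt(head_size)
        W = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        W = W.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        W = F.softmax(W, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        return W @ v       # (B, T, head_size)


torch.manual_seed(1337)
head = Head(n_embd=32, head_size=16, block_size=8)
x = torch.randn(4, 8, 32)
out = head(x)
print(out.shape)  # torch.Size([4, 8, 16])
```

Because of the mask, changing the last token of a sequence leaves the outputs at all earlier positions unchanged.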

Notes (directly taken from [here](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=M5CvobiQ0pLr)):
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Examples across the batch dimension are processed completely independently and never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally multiplies `W` by 1/sqrt(head_size), i.e. divides it by sqrt(head_size). This way, when the inputs Q and K have unit variance, W has unit variance too, so the softmax stays diffuse and does not saturate too much.
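The variance argument in the last note can be checked numerically. A minimal sketch, assuming unit-variance random queries and keys and leaving out the causal mask for simplicity: each entry of `q @ k^T` sums `head_size` products of unit-variance terms, so its variance is close to `head_size`, and scaling brings it back close to 1. Larger logits also make the softmax peakier, which is the saturation the note warns about:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, head_size = 256, 8, 16
q = torch.randn(B, T, head_size)  # unit-variance queries
k = torch.randn(B, T, head_size)  # unit-variance keys

W_unscaled = q @ k.transpose(-2, -1)       # (B, T, T)
W_scaled = W_unscaled * head_size ** -0.5  # same as dividing by sqrt(head_size)

print(W_unscaled.var().item())  # close to head_size = 16
print(W_scaled.var().item())    # close to 1

# Average of the largest softmax weight per row: higher for the unscaled logits
peak_unscaled = F.softmax(W_unscaled, dim=-1).max(dim=-1).values.mean()
peak_scaled = F.softmax(W_scaled, dim=-1).max(dim=-1).values.mean()
print(peak_unscaled.item(), peak_scaled.item())
```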