## Part 2: Deriving Self-attention from Bigram Model Flaws

In Part 1 we concluded that a large limiting factor of bigram models is that the context on which next tokens are predicted itself can only be one token in size.
This means we can predict the next token only based on the directly preceding token, regardless of the context length. So, e.g., if the context offers 5 preceding tokens, we just ignore them and use the latest one. Now we will try to improve this by designing a mechanism that can take context into account within a context window of variable length.

To this end, let's reuse the dataset, the encoder and the `get_batch` function from part 1.

In [1]:
import torch

# Get dataset
!curl.exe --output shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
 
# Create encoder and encode dataset
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
data = torch.tensor(encode(text), dtype=torch.long)

# Generate train and test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Set hpyerparameters
B,T,C = 4,3,2  # Batch, Block size (timesteps), Number of channels

def get_batch(split):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(data) - T, size=(T,))
    x = torch.stack([data[i:i + T] for i in idxs])
    y = torch.stack([data[i+1:i+T+1] for i in idxs])
    return x,y  

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1089k  100 1089k    0     0  4930k      0 --:--:-- --:--:-- --:--:-- 4951k


### How to gather information from multiple tokens?

The easiest way to take into account information from multiple tokens is to average their information into one combined represenation.
The code below ilustrates this for a dummy sample from our shakespeare text; we first need to transform the input tokens into a vector representation, then we average them element-wise.

We keep the result of this naive averaging implementation to demonstrate in the next step that this can be implement much more efficiently.

In [2]:
import torch.nn as nn
torch.manual_seed(42)

def generate_embedding(x):
    fake_emb = nn.Linear(T,T*C)
    return fake_emb(x)
    
b_x, b_y = get_batch(train_data)    
x,y = b_x[0],b_y[0]
emb_x = generate_embedding(x.to(dtype=torch.float32)).view(T,C)
    
result_avg = []
for t in range(T):
    context = x[:t+1]
    emb_context = emb_x[:t+1]
    target = y[t]
    print(f"To predict: {target} we use this context: {context}")
    print(f"Encoded context: \n{emb_context}")
    avg = torch.mean(emb_context.to(dtype=torch.float32), 0)
    print(f"Averaged context: \n{avg} \n")
    result_avg.append(avg.detach().numpy())


To predict: 39 we use this context: tensor([44])
Encoded context: 
tensor([[23.1113, 21.4887]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([23.1113, 21.4887], grad_fn=<MeanBackward1>) 

To predict: 41 we use this context: tensor([44, 39])
Encoded context: 
tensor([[23.1113, 21.4887],
        [ 5.7481, 32.7760]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([14.4297, 27.1324], grad_fn=<MeanBackward1>) 

To predict: 43 we use this context: tensor([44, 39, 41])
Encoded context: 
tensor([[ 23.1113,  21.4887],
        [  5.7481,  32.7760],
        [ 17.0037, -17.1892]], grad_fn=<SliceBackward0>)
Averaged context: 
tensor([15.2877, 12.3585], grad_fn=<MeanBackward1>) 



As already pointed out above, this sample-wise averaging of the context can be computed very efficiently through a mathematical trick via matrix multiplication.
For this let's quickly recap standard matrix multiplication again below:

In [3]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()

In [4]:
print(f"A:\n{a}\nB:\n{b}")
c = a @ b # @ = matrix multiplication (https://en.wikipedia.org/wiki/Matrix_multiplication)
print(f"C:\n{c}")

A:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
C:
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In the inner foor loop of our naive implementation we summed up the token embeddings element wise and then divided them by the number of embedding vectors to obtain the average.
The loop itself dictated the number of tokem embeddings (number of preceding characters) for over which we are averaging.

Through an efficient trick we can summarize this inner foor loop in parallel by using a lower triangular matrix!

By using tril (lower triangular matrix) without any modification on a matrix of token embeddings we effictively perform the element-wise summation in our naive for loop without the averaging step.

In [5]:
a_triangular = torch.tril(a)
print(f"Tril: \n{a_triangular}")
print(f"Token embeddings:\n{b}")

c = a_triangular @ b
print(f"Row-wise summation:\n{c}")

Tril: 
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
Token embeddings:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
Row-wise summation:
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


When we additionally introduce row-wise factors in the lower triangular matrix, our element-wise addition can become element-wise averaging!
The factors quantify **the weight** of each number in the elemtent-wise addition; when we choose the weight to be 1/n, where n = number of embedding vectors, we obtain the element-wise average of the embedding vectors.

In [6]:
a_triangular_average = a_triangular / torch.sum(a_triangular, 1, keepdim=True)
a_triangular_average

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [7]:
# And now we obtain a row-wise average
c = a_triangular_average @ b
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])

### Efficient Matrix-like Implementation

Let's use our efficient formulation on our sample from the naive averaging foor-loop to verify that we indeed obtain the same result.

In [8]:
# Optimized implementation with matrix multiplication
W = torch.tril(torch.ones(size=(T,T)))
W = W / torch.sum(W, 1, keepdim=True)

out = W @ emb_x  # shape: (T,T) @ (T,C) --> (T,C)

torch.allclose(torch.tensor(result_avg), out, atol=1e-07)
# We obtain the same result, but much more efficient

  torch.allclose(torch.tensor(result_avg), out, atol=1e-07)


True

Very good!

Small note: We can derive the lower triangular averaging matrix W also by using the softmax function.

In [9]:
# Alternative Version: Using Softmax to create the percentages of contribution of each token
import torch.nn.functional as F

tril = torch.tril(torch.ones(size=(T,T)))
W = torch.zeros((T,T))
W = W.masked_fill(tril == 0, float('-inf'))
print(W)

# Now if we apply row-wise softmax we obtain the same weight matrix W as before
W = F.softmax(W, dim=-1)
print(W)


tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


We achieved now that we can peform predictions about the next token **based on the information of all passed tokens** (inside a certain window size).
However, there is still a problem: our factors (the weights of the individual tokens during the averaging) are hardcoded to the exact same amount, i.e. uniform distribution.

However: we want this to be **data dependent**; information that we gather from the past naturally is not equally important for different future tokens.
Therefore, we can improve by **LEARNING** these factors (learn the amount of contribution each past token has for the prediction).

### Self-Attention

Now we are already at the heart of what makes self-attention so powerful. 
Through learning these factors we can find out the "affinities" between tokens and therefore can generate **content dependent representations** that more provide relevant information for deciding upon the next token.

Self-attention solves this problem of learning affinities by creating multiple learned representation of the input tokens.
1. Query: "What am I looking for?" (intuition)
2. Key: "What do I contain?" (intuition)

The affinities between the tokens are then computed by performing the dot product of the query and key token representations!

Concretely for one token: One query (one "What am I looking for" token representation) peforms a dot product with the key representation ("What do I contain/can offer" representation) of every other token.
Which quantifies the affinity of all other tokens to this token. (And then this mechanism is done for every token (every query)!)

### Implementation of a Single Self-Attention Head

With all what we have learned so far let's go ahead and implement a single self-attention head. 
As already discussed in self-attention multiple **learned representations** are used, concretely **key, query, value**.

In the following we will implement these representations and develop an intutation for why exactly these representations are used.

First let's talk about key and query. Here the input is passed through two linear layers.
These layers represent the learnable parameters that determine the representation of the input.
Mr. Karpathy described their intuition as follows:
- Key: "**What am I looking for**" representation of a token
- Query: "**What do I contain/can offer**" representation of a token

In [10]:
torch.manual_seed(1337)

B,T,C = 4,8,32 # batch, time, channels per token
inputs = torch.randn(B,T,C) # input of the attention head

head_size = 16 # channel dimension of the key, query representations of the input
key_repr = nn.Linear(C, head_size, bias=False)
query_repr = nn.Linear(C, head_size, bias=False)

# creating the key and query representations:
k = key_repr(inputs)
q = query_repr(inputs)

print(f"Key representation: {k.shape}")
print(f"Query representation: {q.shape}")

Key representation: torch.Size([4, 8, 16])
Query representation: torch.Size([4, 8, 16])


Together, the key and query representation compute the weight matrix W. So we have achieved that our affinties between tokens are now in a **learned** matrix representation that is **input dependent**!

In [11]:
# To compute our learned W we have to calculate the dot product of every query with every key, 
# which can again be effectively done with matrix multiplication!

W = q @ k.transpose(-2,-1)  # (B,T,16) @ (B,16,T) --> (B,T,T)
print(f"W.shape: {W.shape}")

# Now because we are still discussing masked self-attention (decoder style model) we have to perform the masking (as in the first part)
# The reason is the same as before: We don't want to use information from future tokens
tril = torch.tril(torch.ones(T,T))
W = W.masked_fill(tril == 0, float('-inf'))
W = F.softmax(W, dim=-1)

W[0]

W.shape: torch.Size([4, 8, 8])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

After computing the affinities, we then use an additional learned representation, the **value** representation, to determine the output of self-attention.
The intuition Mr. Karpathy provided is the following:

"If you find me interesting (determined by the weight matrix W, computed through key and query representations), **here is what I will communicate to you!**"

In [12]:
value_repr = nn.Linear(C, head_size, bias=False)
v = value_repr(inputs)

result_self_attention = W @ v
print(f"Output.shape: {result_self_attention.shape}")
result_self_attention[0]

Output.shape: torch.Size([4, 8, 16])


tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0

And with that we have derived self-attention!

Here are some additional notes that I took directly from Mr. Karpathy's tutorial [here](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=M5CvobiQ0pLr):

- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `W` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, W will be unit variance too and Softmax will stay diffuse and not saturate too much.