In [1]:
# How can we improve the simple Bigram model?
# Currently, we predict the next token based only on the directly preceding token.

# How can we make predictions based on all past tokens?

import torch, numpy
import torch.nn.functional as F
import torch.nn as nn
# Fix the RNG so the toy tensor below is reproducible across runs.
torch.manual_seed(1337)

# Toy input: B sequences of T tokens each, every token a C-dimensional vector.
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))

In [2]:
x.size()  # (B, T, C)

torch.Size([4, 8, 2])

In [3]:
# Short insertion about attention in decoder-like models:
# In decoder-like models, the attention we perform is *masked* self-attention.

# Why is that?
# When generating new tokens from a given input sequence the model can only use the past tokens (the input) as reference for deciding what the next token should be
# So we have to mimic this behavior during training; concretely this means that for predicting a token at timestep t the model is only allowed to "attend" to the tokens from position 0 up to t - 1

# As demonstrated in this file, when we train a decoder we can turn one sample of k tokens into k-1 training examples by using all of its prefixes:
# (input: t_0, label: t_1), (input: t_0, t_1, label: t_2), ...

# => Information only flows from previous tokens to the current timestep; we cannot use any information from the future, since that is exactly what we are trying to predict.

In [4]:
# How can tokens "attend" to each other?

# The simplest form of communication: each token takes the plain average of itself
# and every token before it, i.e. x_avg[b, t] = mean(x[b, 0..t]).
x_avg = torch.zeros_like(x)

for batch_idx in range(B):          # every sample in the batch
    for time_idx in range(T):       # every position within that sample
        # tokens 0..time_idx (inclusive) -> shape (time_idx + 1, C)
        history = x[batch_idx, : time_idx + 1]
        x_avg[batch_idx, time_idx] = history.mean(dim=0)

# The loops above repeat that averaging for every timestep of every sample in the batch.

In [5]:
# There is a mathematical trick that makes the computation above much more efficient: matrix multiplication.

torch.manual_seed(42)
a = torch.ones(3, 3)                                     # 3x3 matrix of ones
b = torch.randint(low=0, high=10, size=(3, 2)).float()   # 3x2 matrix of random integers in [0, 10)

In [6]:
c = torch.matmul(a, b)  # same as `a @ b` (https://en.wikipedia.org/wiki/Matrix_multiplication)
# Because A is all ones, every row of C equals the column sums of B.
print(f"A:\n{a}\nB:\n{b}")
print(f"C:\n{c}")

A:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
C:
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [7]:
# Instead of a plain matrix of ones we can use a lower-triangular one:
# torch.tril keeps the diagonal and everything below it and zeroes the rest.
# Multiplying it with B gives a running (cumulative) sum over the rows of B:
# row i of C is the sum of rows 0..i of B.
a_triangular = torch.tril(a)

c = a_triangular @ b
print(f"C:\n{c}")

# Think of B as a sequence of 3 tokens with 2 channels each: row i of C aggregates
# token i and every token before it, but never a token that comes after it.
# Sanity check: this is exactly a cumulative sum along the token (row) dimension.
assert torch.equal(c, torch.cumsum(b, dim=0))

print(f"B:\n{b}")

C:
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
B:
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])


In [8]:
# We can take this one step further: instead of a running *sum*, compute a running *average*.
# Normalizing each row of the triangular matrix by its number of ones makes every row sum to 1.

a_triangular_average = a_triangular / a_triangular.sum(dim=1, keepdim=True)
a_triangular_average

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [9]:
# Multiplying by the averaging matrix yields a running (row-wise cumulative) mean of B.
c = torch.matmul(a_triangular_average, b)
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])

In [10]:
# We use this principle to simplify our double for-loop attention mechanism from the beginning:

In [11]:
# Compare the naive loop against the matrix-multiplication trick on the toy input `x`.
x_avg = torch.zeros_like(x)

# Naive implementation: explicit loops over batch and time.
# The loop indices get dedicated names so they do not overwrite the 3x2 matrix `b`
# defined earlier (otherwise re-running the cells that use `b` would fail).
for batch_idx in range(B):
    for time_idx in range(T):
        x_prev = x[batch_idx, :time_idx+1]
        x_avg[batch_idx, time_idx] = torch.mean(x_prev, 0)


# Optimized implementation with matrix multiplication:
# W is a (T, T) lower-triangular matrix whose row t holds 1/(t+1) in columns 0..t.
W = torch.tril(torch.ones(size=(T,T)))
W = W / torch.sum(W, 1, keepdim=True)

# (T,T) @ (B,T,C) -> (B,T,C): W is broadcast across the batch dimension.
out = W @ x

print(x_avg[0])
print(out[0])
torch.allclose(x_avg, out, atol=1e-07)

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


True

In [12]:
# Alternative version: let softmax produce each token's share of the contribution.
# Start from all-zero scores and block every "future" position (above the diagonal) with -inf.

tril = torch.tril(torch.ones(size=(T,T)))
W = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
W

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [13]:
# Applying a row-wise softmax to the masked scores gives the same averaging matrix as before:
# exp(0) = 1 for allowed positions and exp(-inf) = 0 for masked ones, then each row is normalized.
# The masked scores are rebuilt from `tril` instead of reusing `W`, so re-running this cell
# does not apply softmax a second time (which would no longer yield a uniform average).
W = F.softmax(torch.zeros((T, T)).masked_fill(tril == 0, float('-inf')), dim=-1)
W

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [14]:
# We can now make predictions about the next token based on the information from all past tokens (within a fixed context window).
# However this can still be improved:

# The problem with this dummy interpretation of attention is that the affinities between the tokens are hardcoded.
# Each token contributes the exact same amount to predicting the next token (uniform distribution).

# However: we want this to be data dependent; information that we gather from the past naturally is not equally important for different future tokens.
# Therefore, we can improve by LEARNING these affinities between tokens
# (learn the amount of contribution each past token has for the prediction).

In [15]:
# Self-attention solves this problem of learning affinities by creating multiple learned representations of the input tokens.
# 1. Query: "What am I looking for?" (intuition)
# 2. Key: "What do I contain?" (intuition)

# The affinities between the tokens are then computed by performing the dot product of the query and key token representations!

# Concretely, for one token: its query (its "What am I looking for?" representation) performs a dot product with the key representation ("What do I contain / what can I offer?") of every token it is allowed to see.
# This quantifies the affinity of each of those tokens to this token. (The same is then done for every token, i.e. for every query!)

In [16]:

# Sketch of a single attention head
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # batch, time, channels per token
inputs = torch.randn(B, T, C)  # input of the attention head

# The head projects every C-dimensional token into a head_size-dimensional space.
head_size = 16
# Creation order matters for reproducibility: both layers draw their initial weights
# from the seeded RNG (key first, then query).
key_repr = nn.Linear(C, head_size, bias=False)
query_repr = nn.Linear(C, head_size, bias=False)

# Key and query representations of every token: (B, T, C) -> (B, T, head_size)
k, q = key_repr(inputs), query_repr(inputs)

print(f"Key representation: {k.shape}")
print(f"Query representation: {q.shape}")

Key representation: torch.Size([4, 8, 16])
Query representation: torch.Size([4, 8, 16])


In [17]:
# The learned weight matrix W holds the dot product of every query with every key.
# A batched matrix multiplication computes all of them at once:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
W = torch.matmul(q, k.transpose(-2, -1))
print(f"W.shape: {W.shape}")

# This is still masked (decoder-style) self-attention, so future positions are blocked
# with -inf before the softmax, exactly as before: no information from future tokens.
tril = torch.tril(torch.ones(T, T))
W = F.softmax(W.masked_fill(tril == 0, float('-inf')), dim=-1)

W[0]

W.shape: torch.Size([4, 8, 8])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [18]:
# Besides keys and queries, the head learns a third projection of its input: the value.
value_repr = nn.Linear(C, head_size, bias=False)
v = value_repr(inputs)  # (B, T, head_size)

# Intuition: "If you find me interesting (as scored by W, built from keys and queries),
# here is what I will communicate to you!"
# Each output row is therefore a W-weighted mix of the value vectors.
output = torch.matmul(W, v)  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print(f"Output.shape: {output.shape}")
output[0]

Output.shape: torch.Size([4, 8, 16])


tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0

Notes (directly taken from [here](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=M5CvobiQ0pLr)):
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally divides `W` by sqrt(head_size) (i.e. multiplies it by 1/sqrt(head_size)). This way, when the inputs Q, K have unit variance, W has unit variance too, so the softmax stays diffuse and does not saturate too much.