In [10]:
import torch
import torch.nn.functional as F
from torch import nn

In [2]:
torch.manual_seed(1337)

B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

In [3]:
# we want x[b, t] = mean(i <= t) x[b, i]
# note that with this approach there is information loss (for example positional information)
# also this way to compute is inefficient.
xbow = torch.zeros(B, T, C) # b  
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [4]:
# better way to compute 
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (T, T) @ (B, T, C)

In [5]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))

# Setting to -inf means that token from the past can't communicate with token of the future. 
wei = wei.masked_fill(tril == 0, float("-inf"))

wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x

In [22]:
# self-attention
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# What we are doing here is checking the affinity between a set of keys and a set of queries.
# wei[i, j] is the affinity between the query i and key j
k = key(x) # B, T, head_size
q = query(x)
wei = q @ k.transpose(-2, -1) # / head_size**1/2   # B, T, T 

tril = torch.tril(torch.ones((T, T)))

# Now we block the comunication from the future to the past, if I'm looking at the 5 word, we need the past context but future tokens can't
# influence past tokens. The graph (in this specific case of language model) is a direct graph of past tokens pointing to future tokens and not
# viceversa.
wei = wei.masked_fill(tril == 0, float("-inf"))

# Normalize with a good distribution
wei = F.softmax(wei, dim=-1)

# v is the "what x wants to communicate and in what way"
v = value(x)
out = wei @ v


# So basically we have 3 informations: 
# what we are looking for (query) 
# what we are offering (key)
# in what way we encode the information (value)

# 1) Attention is a communication mechanism that can be applied to direct graphs.
# 2) Attention has no notion of position, this must be encoded in other ways.
# 3) There is no communication between batch dimension.
# 4) Attention is not limited to be applied on parts of the input; there are cases where full communication could be done (the graph is complete).
# (look also encoder vs decoder)
# 5) The attention mechanism can be implemented in different ways: self-attention (the source is the same for k, q, v), 
# cross-attention (different sources for k, q, v), etc.
# 6) The scaled version softmax(q @ k.transpose / head_size**1/2) is done to mantain the variance of the distribution and avoid the softmax
# converge to onehot vector. Think about what softmax does to value and see what happens when there are extreme values (we get higher probability).

In [23]:
out.shape

torch.Size([4, 8, 16])

In [21]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4029, 0.5971, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3030, 0.6489, 0.0481, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1680, 0.5172, 0.1575, 0.1572, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3082, 0.0732, 0.5796, 0.0154, 0.0236, 0.0000, 0.0000, 0.0000],
        [0.0453, 0.5189, 0.0089, 0.0113, 0.0045, 0.4111, 0.0000, 0.0000],
        [0.1629, 0.1323, 0.0992, 0.1143, 0.3090, 0.0110, 0.1714, 0.0000],
        [0.0207, 0.0220, 0.0072, 0.0171, 0.0152, 0.0180, 0.0511, 0.8487]],
       grad_fn=<SelectBackward0>)