In [49]:
import torch
from torch import nn

In [59]:
torch.manual_seed(1337)

# batch, time, channels
# batch, input tokens, "some information per each input token"
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# we now want to somehow couple these 8 tokens, so that they can 
# share information between each other (without lookahead)

# the easiest way to do that is to just average the previous tokens
# of course it's also the most lossy way

xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # we pick batch, and for each token in this batch
        # gather all tokens leading up to it and token itself
        xprev = x[b, :t+1]
        xbow[b, t] = xprev.mean(0)

print("Naive for-loop average")
print(xbow[0])

# this can also be sped up using matmuls
# we basically seed the multiplicative accumulators
# and multiply the inputs with mean-making-matrix
tril = torch.tril(torch.ones((8, 8)))
tril = tril / tril.sum(1, keepdim=True)
xbow2 = tril @ x

print("\nMatmul average")
print(xbow2[0])

# the normalization of tril elements can also be done
# using softmax, while filling 0 elements as -inf
tril = torch.tril(torch.ones((8, 8)))
mat = torch.zeros(8, 8)
mat = mat.masked_fill(tril == 0, float("-inf"))
mat = mat.softmax(1)
xbow3 = mat @ x

print("\nMatmul softmax average")
print(xbow2[0])

Naive for-loop average
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

Matmul average
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

Matmul softmax average
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [75]:
torch.manual_seed(1337)

# now that we know how to gather information, we can properly
# get into how self-attention works
# batch, tokens, embeddings
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# the idea is to build query and keys vectors
# their dot product will build the average matrix
# that we had before

# we add linear layers that we will pass x through
# so that the vectors can learn to behave like
# queries and keys

head_size = 16
key   = nn.Linear(C, head_size, bias=False) # B, T, C -> B, T, 16
query = nn.Linear(C, head_size, bias=False) # B, T, C -> B, T, 16
value = nn.Linear(C, head_size, bias=False) # B, T, C -> B, T, 16

k = key(x)
q = query(x)

# we now want a dot product to give us an attentiom matrix
# as before, we want this matrix to be (B, T, T)
# as both vectors are (B, T, 16), we need to transpose
# so: (B, T, 16) x (B, 16, T) -> (B, T, T)
# torch transpose is not ordered, only asks what dims

mat = k @ q.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))
mat = mat.masked_fill(tril == 0, float("-inf"))
mat = mat.softmax(-1)

# softmax(qk) @ v 

v = value(x)
att = mat @ v # [B, T, T] x [B, T, 16]

# each row shows how interested the token equal to row id 
# was with each of tokens up to its position
# ie. first row shows that first token was maximally interested
# in token 1, becasue it cant see any future tokens, etc

mat[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5877, 0.4123, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4457, 0.2810, 0.2733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2220, 0.7496, 0.0175, 0.0109, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0379, 0.0124, 0.0412, 0.0630, 0.8454, 0.0000, 0.0000, 0.0000],
        [0.5497, 0.2187, 0.0185, 0.0239, 0.1831, 0.0062, 0.0000, 0.0000],
        [0.2576, 0.0830, 0.0946, 0.0241, 0.1273, 0.3627, 0.0507, 0.0000],
        [0.0499, 0.1052, 0.0302, 0.0281, 0.1980, 0.2657, 0.1755, 0.1474]],
       grad_fn=<SelectBackward0>)

In [None]:
# Q: Difference between encoder and decoder block
# the main difference between a decoder and encoder blocks is 
# allowing a lookahead, by removing the tril masking
# in an encoder all tokens see each other

# Q: Self-attention vs cross-attention
# self-attention means that q, k, v come from same source x
# q, k can come from encoder blocks, while x comes
# from decoder blocks -- then it is called cross attention

# Q: Scaling by the sqrt(head_size)
# scaling is again used to keep the variance in check

# Q: Multihead attention
# usually the head size is scaled by some x, and then additional x
# attention heads are created, then you concatenate the outputs
# and receive same shaped outputs, but heads can learn different 
# things