### Self Attention's Math Trick

In [1]:
import torch
from torch.nn import functional as F
import torch.nn as nn

B,T,C = 4,8,2
x = torch.rand(B, T, C)
print(x.shape)
print(x)

torch.Size([4, 8, 2])
tensor([[[0.7667, 0.5314],
         [0.9172, 0.2774],
         [0.3465, 0.1333],
         [0.7567, 0.7931],
         [0.0519, 0.1533],
         [0.1922, 0.9974],
         [0.3706, 0.7383],
         [0.5901, 0.1120]],

        [[0.4926, 0.9296],
         [0.4528, 0.9448],
         [0.4835, 0.5699],
         [0.6518, 0.5521],
         [0.2763, 0.4441],
         [0.1384, 0.8170],
         [0.1880, 0.5782],
         [0.1035, 0.7034]],

        [[0.8070, 0.4398],
         [0.9748, 0.6560],
         [0.8835, 0.9323],
         [0.0752, 0.4822],
         [0.4767, 0.9107],
         [0.2862, 0.2678],
         [0.3134, 0.9763],
         [0.2066, 0.9792]],

        [[0.9492, 0.9015],
         [0.0651, 0.0087],
         [0.4753, 0.3830],
         [0.4324, 0.9958],
         [0.9457, 0.4595],
         [0.8539, 0.7081],
         [0.3116, 0.0152],
         [0.8553, 0.4381]]])


In [2]:
xbow = torch.zeros((B, T, C))
for batch in range(B):
    for token in range(T):
        token_sub_sequence = x[batch, :token+1]
        xbow[batch, token] = torch.mean(token_sub_sequence, 0)
xbow

tensor([[[0.7667, 0.5314],
         [0.8419, 0.4044],
         [0.6768, 0.3140],
         [0.6968, 0.4338],
         [0.5678, 0.3777],
         [0.5052, 0.4810],
         [0.4860, 0.5177],
         [0.4990, 0.4670]],

        [[0.4926, 0.9296],
         [0.4727, 0.9372],
         [0.4763, 0.8148],
         [0.5202, 0.7491],
         [0.4714, 0.6881],
         [0.4159, 0.7096],
         [0.3833, 0.6908],
         [0.3484, 0.6924]],

        [[0.8070, 0.4398],
         [0.8909, 0.5479],
         [0.8884, 0.6760],
         [0.6851, 0.6276],
         [0.6434, 0.6842],
         [0.5839, 0.6148],
         [0.5452, 0.6664],
         [0.5029, 0.7055]],

        [[0.9492, 0.9015],
         [0.5072, 0.4551],
         [0.4965, 0.4310],
         [0.4805, 0.5722],
         [0.5735, 0.5497],
         [0.6203, 0.5761],
         [0.5762, 0.4960],
         [0.6111, 0.4887]]])

In [3]:
## VECTORIZING THE AVERAGE OPERATION
weights = torch.tril(torch.ones(T, T))
weights = weights / torch.sum(weights, 1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [4]:
# weights is (T, T) and x is (B, T, C)
# pytorch recognizes that and creates a batch dimension so that its
# (B, T, T) @ (B, T, C)
weights @ x

tensor([[[0.7667, 0.5314],
         [0.8419, 0.4044],
         [0.6768, 0.3140],
         [0.6968, 0.4338],
         [0.5678, 0.3777],
         [0.5052, 0.4810],
         [0.4860, 0.5177],
         [0.4990, 0.4670]],

        [[0.4926, 0.9296],
         [0.4727, 0.9372],
         [0.4763, 0.8148],
         [0.5202, 0.7491],
         [0.4714, 0.6881],
         [0.4159, 0.7096],
         [0.3833, 0.6908],
         [0.3484, 0.6924]],

        [[0.8070, 0.4398],
         [0.8909, 0.5479],
         [0.8884, 0.6760],
         [0.6851, 0.6276],
         [0.6434, 0.6842],
         [0.5839, 0.6148],
         [0.5452, 0.6664],
         [0.5029, 0.7055]],

        [[0.9492, 0.9015],
         [0.5072, 0.4551],
         [0.4965, 0.4310],
         [0.4805, 0.5722],
         [0.5735, 0.5497],
         [0.6203, 0.5761],
         [0.5762, 0.4960],
         [0.6111, 0.4887]]])

In [5]:
# checks if values are equivalent
torch.allclose(xbow, weights @ x)

True

In [6]:
## USING SOFTMAX
# we are building intution for how we go about encoding contextual information into each token by the previous tokens only
# now instead of just taking the average of the token subsequence (inclusive) to be the current token, we can use softmax to give us a normalized decimal contribution of each previous token to the current token
tril = torch.tril(torch.ones(T, T))
print(tril)
# these are set to 0 right now BUT
# these weights will be the attention pattern that we learn in the future
# this is after the embeddings of all tokens have been mapped to the key query space in the columns and rows respectively and the dot products of these respective columns' and rows' key/query vectors have been computed
# each (tr, tc) in the weights matrix is the (W_k @ T_n) @ (W_q @ T_n)
# key_r = W_k @ T_n
# query_c = W_q @ T_n
# now that we understand the weight matrix we will apply masking so that future token information doesnt get encoded into any token
# the form the weights matrix is in, in this notebook is the transpose of the form it is in, in the 3b1b video
weights = torch.zeros((T, T))
# setting the upper triangle of the weights matrix (which can be seen as 0s in tril) to -inf because after applying softmax the -inf's go to 0 and the rest of the values add up to 1 and we have a valid probability distribution
weights = torch.masked_fill(weights, tril == 0, float("-inf"))
print(weights)
# -1 dim means we do it across the last dimension which is the row dimension
# we can also just use 1 as the dim to specify the row dimension (for a 2x2 weights matrix like we have here) but -1 is generalizable
weights = F.softmax(weights, dim=-1)
print(weights)
softmaxBowWOW = weights @ x
softmaxBowWOW 

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,

tensor([[[0.7667, 0.5314],
         [0.8419, 0.4044],
         [0.6768, 0.3140],
         [0.6968, 0.4338],
         [0.5678, 0.3777],
         [0.5052, 0.4810],
         [0.4860, 0.5177],
         [0.4990, 0.4670]],

        [[0.4926, 0.9296],
         [0.4727, 0.9372],
         [0.4763, 0.8148],
         [0.5202, 0.7491],
         [0.4714, 0.6881],
         [0.4159, 0.7096],
         [0.3833, 0.6908],
         [0.3484, 0.6924]],

        [[0.8070, 0.4398],
         [0.8909, 0.5479],
         [0.8884, 0.6760],
         [0.6851, 0.6276],
         [0.6434, 0.6842],
         [0.5839, 0.6148],
         [0.5452, 0.6664],
         [0.5029, 0.7055]],

        [[0.9492, 0.9015],
         [0.5072, 0.4551],
         [0.4965, 0.4310],
         [0.4805, 0.5722],
         [0.5735, 0.5497],
         [0.6203, 0.5761],
         [0.5762, 0.4960],
         [0.6111, 0.4887]]])

In [21]:
# IMPLEMENTING A SINGLE HEAD OF ATTENTION - one query and one key
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# SINGLE HEAD
# dimensions of key-query space
head_size = 16
# key and query Linear layers map from the 32 length vector embeddings to the key-query space vectors
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# here we map the token embeddings to the key-query space
K = key(x) # (B, T, head_size)
Q = query(x) # (B, T, head_size)
# now we need to create the attention pattern matrix which is a T, T matrix for each group of tokens in a batch
# tranpose method for K: (B, T, head_size) => (B, head_size, T)
# we do this tranpose because currently for any batch for Q we have 4 groups of 8, length 16 (head_size) vectors 
# the attention pattern is a square matrix and to create that square matrix attention pattern for each element of the batch we need to tranpose K
# we need calculate the dot product between pairs of key query vectors for one token
# if we do this correctly we should get a scalar value for each key query vector pair for each token for each token
# and for that reason we need to tranpose the K matrix in such a way that we get 4 attention patterns (because thats 1 (T, T) matrix for each item in the batch)
# the reason we multiply by the square root of head size is because the variance of the weights matrix is very high after the key query dot product even if the key and query matrices' varainces are low
# and if we run softmax on the rows with high variance, softmax will converge to one hot vectors
# this is because say for example if we have large magnitude negative values and large magnitude positive values, by themselves they aren't too far apart, but softmax enhances that disparity sharpens it and the large magnitude positive values will converege on one while low values will converge on 0
# for this reason we divide by sqrt(head_size) for normalization to control this variance
# variance drops much lower after applying this term roughly to the same as the original key and query matrices
weights = Q @ K.transpose(-2, -1) * (head_size ** -0.5) # (B, T, T)

tril = torch.tril(torch.ones(T, T))

# weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

V = value(x)
out = weights @ V # (B, T, head_size) which is (4, 8, 16) here

'''
out = weights @ x
print(out.shape)
print(out)
'''

torch.Size([4, 8, 32])
tensor([[[-1.1131,  0.5709, -0.1162,  ...,  0.7472,  0.6243, -0.2236],
         [-0.9526,  1.1547, -0.2072,  ...,  0.3372,  0.4379, -0.6110],
         [-0.6856,  0.3799,  0.0517,  ...,  0.1590,  0.2242,  0.0108],
         ...,
         [-1.5999,  0.7860, -0.8684,  ...,  1.2739,  1.7156, -0.3900],
         [-1.2997,  0.9643, -0.6678,  ...,  0.8460,  1.1088, -0.3689],
         [-0.8024,  0.9467, -0.0853,  ...,  0.2087,  0.0361, -0.0566]],

        [[ 1.4529,  0.2413,  0.2447,  ...,  0.9274, -2.4062, -0.2953],
         [ 0.8658, -0.1639,  0.0378,  ...,  0.5807, -1.5448, -0.1430],
         [ 0.4775, -0.6827,  0.0276,  ...,  0.1490, -0.1697,  0.1547],
         ...,
         [ 1.7115, -0.6649,  0.6615,  ...,  0.7575,  0.4393,  0.5986],
         [ 0.0145,  0.2786,  1.2535,  ...,  0.0973, -0.5672,  0.1354],
         [ 0.4402,  0.0713,  0.6413,  ...,  0.0179, -0.4803, -0.1015]],

        [[ 0.6063,  0.9585, -0.3751,  ..., -1.4635,  0.1501,  0.0207],
         [-0.0700,  0.