### Attention 1: Average of previous tokens


In [26]:
import torch

# Toy input for the attention demos: B sequences of T tokens, each token a C-dim vector.
torch.manual_seed(1337)  # fixed seed so the printed values below are reproducible
B, T, C = (4, 8, 2)      # batch, time (sequence length), channels
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [27]:
"""
Method 1: nested loop over batch and token
say for batch=0 and token=5
It will sum t=0 to t=5 then get the mean on dim 0
This will be bag of words at batch=0
"""
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

print(x[0])
print('\n')
print(xbow[0])

"""
so x[0][0][0] = 0.1808 and xbow[0][0][0] = -0.3596
(0.1808 - 0.3596)/2 = -0.0894 = xbow[0][1][0]
"""

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


'\nso x[0][0][0] = 0.1808 and xbow[0][0][0] = -0.3596\n(0.1808 - 0.3596)/2 = -0.0894 = xbow[0][1][0]\n'

In [21]:
"""
Method 2: Matrix multiplication
Ones matrix @ matrix = sums up every row
or output[0][0] = input[0][0] + input[1][0] + input[2][0] + ... (upto input[T][0]) 

this means the first token has attended all the tokens which we dont want. Each token should only be able to see behind. So we use lower triangular matrix
This results in every token can attend upto that token

output[0][0] = input[0][0]
output[1][0] = input[0][0] + input[1][0]
...
"""

torch.manual_seed(1337)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
print(f"print a=\n{a}")
print(f"\nprint b=\n{b}")
c = a@b
print(f"\nmatmul with ones matrix:\n{c}")
l = torch.tril(torch.ones(3, 3))
print(f"\nLower triangular matrix:\n{l}")
d = l@b
print(f"\nmatmul with lower triangular matrix:\n{d}")

print a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

print b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

matmul with ones matrix:
tensor([[12., 10.],
        [12., 10.],
        [12., 10.]])

Lower triangular matrix:
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

matmul with lower triangular matrix:
tensor([[ 5.,  7.],
        [ 7.,  7.],
        [12., 10.]])


In [24]:
"""
Since we are using mean of all the previous tokens as attention
What we can do is find the mean of the lower triangual matrix before matmul
"""
l = torch.tril(torch.ones(3, 3))
l = l/torch.sum(l, 1, keepdim=True)
print(f"mean over dim=1 for lower triangular matrix:\n{l}")
print(f"\nprint b=\n{b}")
d = l@b
print(f"\nmean of previous tokens:\n{d}")

mean over dim=1 for lower triangular matrix:
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

print b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

mean of previous tokens:
tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])


In [None]:
# Let's match it with our original xbow, computed from the same data x.
# Row-normalizing the lower triangular mask makes row t hold 1/(t+1) for tokens 0..t.
# (No seeding needed: nothing in this cell is random.)
mask = torch.tril(torch.ones(T, T))
wei = mask / mask.sum(1, keepdim=True)  # (T, T); every row sums to 1
# (T, T) @ (B, T, C): torch broadcasts wei across the batch dim -> (B, T, C)
xbow2 = wei @ x
torch.allclose(xbow, xbow2)

In [35]:
"""
Method 3: using -inf and softmax instead of doing mean in the lower triangular matrix
why?
- This is the best method since we can use a weights matrix
- instead of average we can use probablilites over the weights
"""
tril = torch.tril(torch.ones((T, T)))
print(f"Tril:\n{tril}")
wei = torch.zeros((T, T)) # We are currently intializing with zeros but later it will have proper weights
wei = wei.masked_fill(tril==0, float('-inf'))
print(f"Weights Mask:\n{wei}")

# Now we use softmax to convert the weights matrix into probs.
# Since we have initialized the weights as Zeros matrix we will get the average over the rows
wei = torch.nn.functional.softmax(wei, dim=-1)
print(f"Probs:\n{wei}")

xbow3 = wei @ x
torch.allclose(xbow2, xbow3)

Tril:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
Weights Mask:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
Probs:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.25

True

### Attention 2: Weighted aggregation of past tokens or Self-Attention 
Instead of a uniform average over the previous tokens, each token should gather information from the past in a data-dependent way, because some past tokens are more relevant than others.

**Self Attention**  
Every token will emit: *Query and Key Vectors*  
Query -> "What am I looking for?"  
Key ->  "What do I contain?"  

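The affinities between tokens are the dot products of queries with keys: **wei = Query @ Keyᵀ**, which is then masked and passed through softmax.  
Every token also emits a *value* vector. The output is the `wei`-weighted sum of the values: **out = wei @ Value**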

**Notes:**  
- Attention has no notion of space: it simply operates on a set of vectors. That's why we need positional encoding.
- Examples in a batch never "talk" to each other; each example is processed independently.
- For an "encoder" block, drop the `tril` mask (i.e. allow every position) so that all tokens can "talk" to each other. Decoder-only / autoregressive models need the lower triangular mask so that tokens cannot see the future.
- "Self-Attention" is when keys, queries, and values all come from the same input. "Cross-Attention" is when the queries come from one input while the keys and values come from a different one (e.g. an encoder's output).
- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size), i.e. divides it by sqrt(head_size). Without the scaling, for unit-variance queries and keys the variance of `wei` would be on the order of head_size (~16 here, since head_size = 16). Softmax would then saturate toward one-hot vectors.

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Single head of causal (decoder-style) self-attention on a random toy batch.
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn((B, T, C))

head_size = 16
# Creation order matters: each nn.Linear draws its initial weights from the seeded RNG.
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project every token into query / key / value spaces, each (B, T, head_size).
q, k, v = query(x), key(x), value(x)

# Affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T). Scaling by 1/sqrt(head_size) keeps
# the score variance from growing with head_size, so softmax does not saturate.
wei = (q @ k.transpose(1, 2)) * head_size**-0.5
print(f"Weights[0]:\n{wei[0]}")

# Causal mask: token t may only attend to tokens 0..t.
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)
print(f"Weights[0] probs:\n{wei[0]}")

out = wei @ v  # weighted sum of values: (B, T, T) @ (B, T, hs) -> (B, T, hs)
out.shape

Weights[0]:
tensor([[-0.4407, -0.8334, -0.2557,  0.1959, -0.3142, -0.0782,  0.2719, -0.4511],
        [-0.3253, -0.4139, -0.3152, -0.2004,  0.0047,  0.6038,  0.4913, -0.1031],
        [ 0.1413,  0.0260,  0.0191, -0.0842, -0.1970, -0.0276, -0.0655, -0.2077],
        [ 0.5404,  0.8446, -0.0953, -0.2124, -0.3301, -0.2483, -0.0789,  0.1475],
        [-0.2668, -0.5456, -0.2461, -0.1401,  0.5091,  0.8362,  0.1523, -0.1997],
        [ 0.4908,  0.2604, -0.3576, -0.2925,  0.2160, -0.6307,  0.3154, -0.1464],
        [ 0.2691, -0.0139,  0.0187, -0.3232,  0.0930,  0.3547, -0.1371,  0.1608],
        [-0.1132,  0.0732, -0.2387, -0.2565,  0.2314,  0.3049,  0.2012,  0.1576]],
       grad_fn=<SelectBackward0>)
Weights[0] probs:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5221, 0.4779, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3602, 0.3210, 0.3188, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2980, 0.4039, 0.1578, 0.1404, 0.0000, 0.0000, 

torch.Size([4, 8, 16])