math trick inside self-attention

In [61]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [62]:
# Toy input: a small random batch of sequences, seeded so the numbers are reproducible.
torch.manual_seed(1337)
B = 4   # batch size
T = 8   # time steps (sequence length)
C = 2   # channels per time step
x = torch.randn((B, T, C))
print(x.shape)

torch.Size([4, 8, 2])


In [63]:
# Goal: xbow[b, t] = mean_{i <= t} x[b, i], i.e. a causal running average over time.
# Deliberately naive double loop; it serves as the reference for the vectorised versions.
xbow = torch.zeros(B, T, C)     # "bag of words": one averaged vector per (batch, time) position
for b in range(B):              # every batch element independently
    for t in range(T):          # every time step
        xprev = x[b, :t + 1]             # rows 0..t of this batch element, shape (t+1, C)
        xbow[b, t] = xprev.mean(dim=0)   # average over the time dimension

In [64]:
# version 2: the same running average expressed as a single matrix multiply
wei = torch.ones(T, T).tril()           # lower-triangular ones: row t selects positions 0..t
wei /= wei.sum(dim=1, keepdim=True)     # normalise each row to sum to 1 -> uniform averaging weights
# wei is (T, T) and is broadcast across the batch: (B, T, T) @ (B, T, C) --> (B, T, C)
xbow2 = wei @ x
torch.allclose(xbow, xbow2)             # expected to match the loop-based xbow

True

In [65]:
# Display the loop result and the matrix-multiply result side by side to compare them.
xbow[0], xbow2[0]       # first batch element of each

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [66]:
# Raw (un-averaged) inputs of the first batch element.
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [67]:
xbow[0]      # row t is the average of rows 0..t of x[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [68]:
# Small worked example: multiplying by a row-normalised lower-triangular matrix
# turns a matrix product into a running average over rows.
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
a /= a.sum(dim=1, keepdim=True)             # row i holds 1/(i+1) in its first i+1 entries
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b                                   # row i of c is the mean of rows 0..i of b
print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [69]:
# version 3: the same averaging weights, this time produced by a softmax
tril = torch.ones(T, T).tril()
# Start from all-zero affinities; positions where tril == 0 (the future) become -inf.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
# exp(-inf) = 0, and the remaining equal scores give uniform weights over positions 0..t.
wei = torch.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

k: what I hold

q: what I am looking for

v: the vectors we actually aggregate. We do not aggregate the raw x directly; instead each token offers its value: "here's what I will communicate to you".

Notes:

Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information
with a weighted sum from all nodes that point to them, with data-dependent weights.

There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.

Each example across the batch dimension is of course processed completely independently; examples never "talk" to each other.

"self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries are still produced from x, but the keys and values come from a different, external source.

4 separate pools of 8 nodes (with B = 4 and T = 8, each batch element is its own pool of 8 tokens)

self-attention: q, k and v all come from the same x

cross-attention: q comes from x; k and v come from a different source

"scaled" attention: additionally multiply wei by 1/sqrt(head_size) (i.e. divide it by sqrt(head_size)). This makes it so that when the inputs q and k have unit variance, wei has unit variance too, so softmax stays diffuse and does not saturate too much.


In [70]:
# version 4: self-attention (a single head)
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn((B, T, C))

head_size = 16
# Three bias-free linear projections of the same input x. They are created in
# key / query / value order so the seeded initialisation stays the same.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q, v = key(x), query(x), value(x)    # each is (B, T, head_size)

# Scaled affinities between every query and every key:
# (B, T, head_size) @ (B, head_size, T) --> (B, T, T), multiplied by head_size**-0.5.
wei = (q @ k.transpose(-2, -1)) * head_size ** -0.5

# Unlike version 3, the affinities are data-dependent rather than all zeros.
# Causal mask: positions where tril == 0 (the future) get -inf, hence weight 0 after softmax.
tril = torch.ones(T, T).tril()
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

# Aggregate the value vectors v (not the raw x) with the attention weights.
out = wei @ v   # (B, T, head_size), single head

out.shape

torch.Size([4, 8, 16])

In [71]:
# NOTE(review): labelled "raw products", but if the version-4 cell ran last, wei was already
# masked and softmaxed there, so this shows the same values as the "after softmax" cell.
wei[0]      # wei is (4, 8, 8); wei[0] is the (8, 8) matrix of the first batch element

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2034, 0.1663, 0.1455, 0.3369, 0.0000, 0.0000, 0.0000],
        [0.1259, 0.2490, 0.1324, 0.1062, 0.3141, 0.0724, 0.0000, 0.0000],
        [0.1598, 0.1990, 0.1140, 0.1125, 0.1418, 0.1669, 0.1061, 0.0000],
        [0.0845, 0.1197, 0.1078, 0.1537, 0.1086, 0.1146, 0.1558, 0.1553]],
       grad_fn=<SelectBackward0>)

In [72]:
# after softmax: each row sums to 1; masked (future) positions above the diagonal are 0
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2034, 0.1663, 0.1455, 0.3369, 0.0000, 0.0000, 0.0000],
        [0.1259, 0.2490, 0.1324, 0.1062, 0.3141, 0.0724, 0.0000, 0.0000],
        [0.1598, 0.1990, 0.1140, 0.1125, 0.1418, 0.1669, 0.1061, 0.0000],
        [0.0845, 0.1197, 0.1078, 0.1537, 0.1086, 0.1146, 0.1558, 0.1553]],
       grad_fn=<SelectBackward0>)

In [73]:
# Variance over all elements of the keys k.
k.var()

tensor(0.3164, grad_fn=<VarBackward0>)

In [74]:
# Variance over all elements of the queries q.
q.var()

tensor(0.3386, grad_fn=<VarBackward0>)

In [75]:
# NOTE(review): assuming the cells ran in order, wei here is the masked softmax output of the
# version-4 cell, so this is the variance of attention probabilities, not of the scaled
# raw scores q @ k^T that the "scaled attention" note talks about.
wei.var()

tensor(0.0287, grad_fn=<VarBackward0>)

In [76]:
# Softmax of a small-magnitude vector: the resulting probabilities stay fairly diffuse.
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [2]:
class BatchNorm1d:
    """Normalise along dimension 1, then apply an element-wise scale and shift.

    NOTE(review): despite the class name, the statistics are taken along
    dimension 1 (within each row), not across the batch dimension 0. For a
    2-D ``(batch, dim)`` input, every example is therefore normalised over
    its own features, independently of the other examples in the batch.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        """Create the scale/shift tensors for ``dim`` features.

        ``eps`` is added to the variance before the square root.
        ``momentum`` is accepted but not used anywhere in this class.
        """
        self.eps = eps
        self.gamma = torch.ones(dim)    # scale, starts as all ones
        self.beta = torch.zeros(dim)    # shift, starts as all zeros

    def __call__(self, x):
        """Return the normalised, scaled and shifted ``x``; also stored on ``self.out``."""
        mu = x.mean(1, keepdim=True)        # mean along dimension 1
        sigma2 = x.var(1, keepdim=True)     # variance along dimension 1 (torch.var default estimator)
        denom = torch.sqrt(sigma2 + self.eps)
        normed = (x - mu) / denom           # roughly zero mean / unit variance along dimension 1
        self.out = self.gamma * normed + self.beta
        return self.out

    def parameters(self):
        """Return the scale and shift tensors as ``[gamma, beta]``."""
        return [self.gamma, self.beta]


In [3]:
import torch
# Run the module once on a seeded random input and inspect the resulting shape.
torch.manual_seed(1337)
module = BatchNorm1d(100)
x = module(torch.randn(32, 100))    # 32 vectors of 100 dimensions each, passed through the module
x.shape

torch.Size([32, 100])