In [1]:
# Move the working directory up one level before importing project modules.
# NOTE(review): presumably the notebook lives in a sub-directory and `utils` sits
# in the parent — confirm. Re-running this cell changes directory again, so run
# it exactly once per kernel session.
import os; os.chdir('..')
import torch
import torch.nn as nn
from torch.nn import functional as F

# NOTE(review): the wildcard import hides which names this notebook takes from `utils`.
from utils import *
from boring_utils.utils import init_graph, set_seed, get_device, cprint, tprint

# Fix the seed so the random tensors sampled below are reproducible
# (assumes set_seed seeds torch's RNG — TODO confirm).
set_seed(42)

# Math Trick of Self-Attention

At the current time step, a token can only communicate with tokens from the past (and with itself).

In [2]:
# Toy input: one random tensor laid out as (batch, time, channels).
B = 4  # batch size
T = 8  # time steps
C = 2  # channels
x = torch.rand((B, T, C))
x.shape

torch.Size([4, 8, 2])

## Version 1: direct implementation

In [3]:
# bag of words: xbow[b, t] is the mean of x[b, 0..t]
xbow = torch.zeros(B, T, C)

for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t+1, C): steps 0..t, current step included
        xbow[b, t] = xprev.mean(dim=0)  # (C,)


In [4]:
# Show the first batch element of x (T rows, C columns) to compare against xbow[0].
x[0]

tensor([[0.8823, 0.9150],
        [0.3829, 0.9593],
        [0.3904, 0.6009],
        [0.2566, 0.7936],
        [0.9408, 0.1332],
        [0.9346, 0.5936],
        [0.8694, 0.5677],
        [0.7411, 0.4294]])

In [5]:
# xbow[0, a] is the running average of x[0, :a+1] (current step included)
# (0.8823 + 0.3829) / 2 = xbow[0, 1][0]
# (0.8823 + 0.3829 + 0.3904) / 3 = xbow[0, 2][0]
# (0.8823 + 0.3829 + 0.3904 + 0.2566) / 4 = xbow[0, 3][0]
xbow[0]

tensor([[0.8823, 0.9150],
        [0.6326, 0.9372],
        [0.5519, 0.8251],
        [0.4780, 0.8172],
        [0.5706, 0.6804],
        [0.6313, 0.6659],
        [0.6653, 0.6519],
        [0.6748, 0.6241]])

In [6]:
# Spot-check one position: the mean over steps 0..a should equal xbow[0, a].
# NOTE(review): exact `==` on floats is brittle; torch.allclose (used in later
# cells) is the tolerant comparison.
a = 3
x[0, :a+1].sum(dim=0) / (a + 1) == xbow[0, a]

tensor([True, True])

## Version 2: `torch.tril`

In [7]:
# An all-ones matrix on the left sums every row of `b` into each row of `c`.
a = torch.ones((3, 3))
b = torch.randint(low=0, high=10, size=(3, 2)).float()
c = torch.matmul(a, b)

print(a, '=' * 10, b, '=' * 10, c, sep='\n')

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 1.],
        [3., 0.],
        [1., 1.]])
tensor([[4., 2.],
        [4., 2.],
        [4., 2.]])


Matrix multiplication can do the aggregation for us: this is a clever way to implement xbow!

In [8]:
# Lower-triangular ones: row i of `c` is the sum of rows 0..i of `b`.
a = torch.ones(3, 3).tril()
b = torch.randint(low=0, high=10, size=(3, 2)).float()
c = torch.matmul(a, b)

print(a, '=' * 10, b, '=' * 10, c, sep='\n')

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[7., 9.],
        [4., 3.],
        [8., 9.]])
tensor([[ 7.,  9.],
        [11., 12.],
        [19., 21.]])


In [9]:
# Normalize every row of the lower-triangular matrix so it sums to 1;
# row i of `c` is then the mean of rows 0..i of `b`.

a = torch.ones(3, 3).tril()
a = a.div(a.sum(dim=1, keepdim=True))
b = torch.randint(low=0, high=10, size=(3, 2)).float()
c = torch.matmul(a, b)

print(a, '=' * 10, b, '=' * 10, c, sep='\n')

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[3., 7.],
        [8., 1.],
        [4., 1.]])
tensor([[3.0000, 7.0000],
        [5.5000, 4.0000],
        [5.0000, 3.0000]])


In [10]:
# Row t holds uniform weights 1/(t+1) over steps 0..t and zeros for future steps.
wei = torch.ones(T, T).tril()
wei = wei.div(wei.sum(dim=1, keepdim=True))
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [11]:
# wei is (T, T); matmul broadcasts it across the batch dimension of x.
xbow2 = wei @ x  # (T, T) @ (B, T, C) -> (B, T, C)

# https://pytorch.org/docs/master/generated/torch.allclose.html#torch.allclose
# The matrix-multiply version should match the explicit loop up to float tolerance.
torch.allclose(xbow, xbow2)

True

## Version 3: softmax

In [12]:
# Causal mask: ones on and below the diagonal, zeros above it.
tril = torch.ones(T, T).tril()
# print('tril')
# print(tril)
# print('=' * 10)

# Start from all-zero affinities, i.e. every position is scored equally.
wei = torch.zeros(T, T)
print('wei zero', wei, '=' * 10, sep='\n')

# Entries where the mask is 0 (future positions) are replaced by -inf.
wei = wei.masked_fill(tril == 0, float("-inf"))
print('wei masked', wei, '=' * 10, sep='\n')

wei zero
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
wei masked
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [13]:
# Row-wise softmax over the masked scores gives the averaging weights.
wei = torch.softmax(wei, dim=-1)
print('wei softmax', wei, sep='\n')

xbow3 = torch.matmul(wei, x)

# Compare against the explicit-loop result.
torch.allclose(xbow, xbow3)

wei softmax
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

# Add self-attention

- https://en.wikipedia.org/wiki/Attention_(machine_learning)
- https://youtu.be/XhWdv7ghmQQ?t=306

Note:
- 01:11:38 note 1: attention as communication (DAG)
- 01:12:46 note 2: attention has no notion of space, operates over sets
- 01:13:40 note 3: there is no communication across batch dimension
- 01:14:14 note 4: encoder blocks vs. decoder blocks
- 01:15:39 note 5: attention vs. self-attention vs. cross-attention
- 01:16:56 note 6: "scaled" self-attention. why divide by sqrt(head_size)

Increase the number of channels `C` to 32.

## Not scaled

In [21]:
# Build one unscaled self-attention head and show its raw and masked scores.
B, T, C = 4, 8, 32  # batch size, time steps, channels
x = torch.rand(B, T, C)
# print(x.shape)

# lower-triangular (T, T) matrix of ones, used below as the causal mask
tril = torch.tril(torch.ones(T, T))

# a single head performing self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
# `value` is created here but only applied later, in `v = value(x)`
value = nn.Linear(C, head_size, bias=False)

# The default initialization method for nn.Linear is Kaiming uniform
# print(key.weight[0])

# this will break the attention
# nn.init.ones_(key.weight)
# nn.init.ones_(query.weight)
# nn.init.ones_(value.weight)

k = key(x)  # (B, T, head_size) with x as input
q = query(x)  # (B, T, head_size) with x as input

# dot product of every query with every key (not yet scaled)
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1)  # (B, T, T)
print('wei with key and query')
print(wei.shape)
print(wei[0])
print('=' * 10)

# causal mask: entries where tril == 0 (future positions) are set to -inf
# disable this to make nodes communicate with each other fully
wei = wei.masked_fill(tril == 0, float("-inf"))
print('wei masked')
print(wei[0])

wei with key and query
torch.Size([4, 8, 8])
tensor([[-0.4756, -0.4828, -0.6072, -0.5931, -0.3054, -0.4079, -0.5894, -0.2347],
        [-0.5114, -0.3956, -0.3761, -0.4249, -0.2592, -0.4078, -0.5005, -0.2828],
        [-0.6586, -0.4215, -0.7198, -1.0157, -0.5398, -0.6188, -0.6399, -0.3925],
        [-0.5370, -0.5140, -0.6704, -0.5307, -0.3981, -0.8240, -0.7268, -0.4975],
        [-0.4817, -0.3755, -0.1522, -0.3692, -0.1520, -0.2697, -0.3343,  0.1223],
        [-0.6250, -0.6073, -0.5404, -0.7432, -0.5004, -0.5817, -0.4660, -0.1888],
        [-0.7550, -0.7983, -0.4413, -0.4475, -0.6948, -0.7523, -0.6267, -0.2331],
        [-0.8731, -0.7870, -0.5414, -0.7434, -0.7863, -0.9081, -0.6232, -0.3012]],
       grad_fn=<SelectBackward0>)
wei masked
tensor([[-0.4756,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.5114, -0.3956,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.6586, -0.4215, -0.7198,    -inf,    -inf,    -inf,    -inf,    -inf],
       

In [22]:
# Normalize the masked scores into attention weights along the last dimension.
wei = torch.softmax(wei, dim=-1)
print('wei softmax', wei[0], sep='\n')

# Aggregate the value projections using the attention weights.
v = value(x)  # (B, T, head_size) with x as input
out = torch.matmul(wei, v)
print(out.shape)  # (B, T, head_size)

wei softmax
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4711, 0.5289, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3117, 0.3951, 0.2932, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2561, 0.2621, 0.2241, 0.2577, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1663, 0.1850, 0.2313, 0.1861, 0.2313, 0.0000, 0.0000, 0.0000],
        [0.1620, 0.1649, 0.1763, 0.1440, 0.1835, 0.1692, 0.0000, 0.0000],
        [0.1268, 0.1214, 0.1735, 0.1724, 0.1346, 0.1271, 0.1441, 0.0000],
        [0.1027, 0.1120, 0.1432, 0.1170, 0.1121, 0.0992, 0.1319, 0.1820]],
       grad_fn=<SelectBackward0>)
torch.Size([4, 8, 16])


In [23]:
# Display the attention weights of the first batch element.
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4711, 0.5289, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3117, 0.3951, 0.2932, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2561, 0.2621, 0.2241, 0.2577, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1663, 0.1850, 0.2313, 0.1861, 0.2313, 0.0000, 0.0000, 0.0000],
        [0.1620, 0.1649, 0.1763, 0.1440, 0.1835, 0.1692, 0.0000, 0.0000],
        [0.1268, 0.1214, 0.1735, 0.1724, 0.1346, 0.1271, 0.1441, 0.0000],
        [0.1027, 0.1120, 0.1432, 0.1170, 0.1121, 0.0992, 0.1319, 0.1820]],
       grad_fn=<SelectBackward0>)

## Scaled

Why scaling is important: the attention scores are fed into softmax, which becomes sharply peaked (close to one-hot) when its inputs have a large variance.

In [24]:
# Softmax depends on the scale of its input, not only on the ordering of the values.
tmp = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

print(torch.var(tmp))
print(tmp.softmax(dim=-1))

# multiplying the inputs by 8 raises the variance and makes the softmax sharper
print(torch.var(tmp * 8))
print((tmp * 8).softmax(dim=-1))

tensor(0.0950)
tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor(6.0800)
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


In [25]:
# Standard-normal keys and queries: inspect the variance of their raw dot products.
k = torch.randn((B, T, head_size))
q = torch.randn((B, T, head_size))
wei = torch.matmul(q, k.transpose(-2, -1))

print(torch.var(k))
print(torch.var(q))
print(torch.var(wei))  # each score sums head_size (= 16) products, so the variance is ~16x larger

tensor(1.0338)
tensor(0.9856)
tensor(16.7034)


In [26]:
# Same experiment, with the scores multiplied by head_size ** -0.5.
k = torch.randn((B, T, head_size))
q = torch.randn((B, T, head_size))
wei = torch.matmul(q, k.transpose(-2, -1)) * head_size ** -0.5

print(torch.var(k))
print(torch.var(q))
print(torch.var(wei))

tensor(1.0058)
tensor(1.0581)
tensor(1.1634)
