In [19]:
import torch
import torch.nn.functional as F

# The mathematical trick in self-attention

In [2]:
# consider the following toy example:
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch size, sequence length, and number of classes
x = torch.randn(B, T, C) # input sequence
x.shape

torch.Size([4, 8, 2])

In [4]:
# version 1
# We want x[b,t] = mean_{i<=t} x[b,i] for all b,t
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b,t] = torch.mean(xprev, dim=0) # (C,)

In [13]:
# version 2
wei = torch.tril(torch.ones((T, T))) # (T, T)
wei = wei / wei.sum(dim=1, keepdim=True) # (T, T)
xbow2 = wei @ x # (T, C) @ (B, T, C) -> (B, T, C)
torch.allclose(xbow, xbow2)

True

In [22]:
# version 3: use Softmax
tril = torch.tril(torch.ones((T, T))) # (T, T)
wei = torch.zeros((T, T)) # (T, T)
wei = wei.masked_fill(tril == 0, float('-inf')) # (T, T)
wei = F.softmax(wei, dim=-1) # (T, T)
xbow3 = wei @ x # (T, C) @ (B, T, C) -> (B, T, C)

In [12]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / a.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [27]:
import torch.nn as nn

# version 4: self-attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch size, sequence length, and emb dim
x = torch.randn(B, T, C) # input sequence

# let's implement a single Head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)

wei = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)



tril = torch.tril(torch.ones((T, T))) # (T, T)
wei = torch.zeros((T, T)) # (T, T)
wei = wei.masked_fill(tril == 0, float('-inf')) # (T, T)
wei = F.softmax(wei, dim=-1) # (T, T)

v = value(x) # (B, T, head_size)

out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [28]:
k.var()

tensor(0.3164, grad_fn=<VarBackward0>)

In [29]:
wei.var()

tensor(0.0273)