In [11]:
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# this is the naive way
# attention is calculated by dot product of query and key
# query is the information from previous time steps
# key is the information from current time step

# single head of self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)  # (B, T, head_size)
q = query(x)  # (B, T, head_size)

# dot product of query and key
# (B, T, 16) @ (B, 16, T) = (B, T, T)
wei = q @ k.transpose(-2, -1) * head_size ** -0.5

tril = torch.tril(torch.ones(T, T))
# these will have actual weight as future iterations have interests in past in varying degrees
# wei = torch.zeros((T, T))
# we dont really care about future
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=1)
v = value(x)  # (B, T, head_size)
xbow3 = wei @ v


torch.Size([4, 8, 2])

In [12]:
# version 1
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # t, C
        xbow[b, t] = torch.mean(xprev, 0)


In [13]:
# version 2
wei = torch.tril(torch.ones(T, T))
wei = wei / torch.sum(wei, 1, keepdim=True)
xbow2 = wei @ x

torch.allclose(xbow, xbow2)


True

In [14]:
# version 3
tril = torch.tril(torch.ones(T, T))
# these will have actual weight as future iterations have interests in past in varying degrees
wei = torch.zeros((T, T))
# we dont really care about future
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=1)
xbow3 = wei @ x

torch.allclose(xbow, xbow3)


True

In [15]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)
a, b, c


(tensor([[1.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000],
         [0.3333, 0.3333, 0.3333]]),
 tensor([[2., 7.],
         [6., 4.],
         [6., 5.]]),
 tensor([[2.0000, 7.0000],
         [4.0000, 5.5000],
         [4.6667, 5.3333]]))