In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Seed the global RNG so the random toy input below is reproducible.
torch.manual_seed(1337)

# Toy tensor dimensions (presumably batch, time, channels).
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))

# Last expression: display the shape of the toy input.
x.shape


torch.Size([4, 8, 2])

In [3]:
# Goal: xbow[b, t] = mean of x[b, i] over all i <= t (a running average along T).
# Reference implementation with explicit loops; one prefix mean per (b, t) pair.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # x[b, :t + 1] has shape (t + 1, C); averaging over dim 0 leaves a (C,) vector.
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

In [5]:
# Same running average expressed as one matrix multiply.
# Start from a lower-triangular matrix of ones, then normalise each row to sum to 1
# so that row t averages positions 0..t.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
# wei is (T, T) and is broadcast over the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)
# Last expression: check the result against the loop-based xbow.
torch.allclose(xbow, xbow2)

True

In [6]:
# Display the row-normalised lower-triangular weights:
# row t holds 1/(t+1) in columns 0..t and 0 in the remaining columns.
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [7]:
# version 3 using softmax
# Lower-triangular matrix of ones; zeros mark the positions to be masked out.
tril = torch.ones(T, T).tril()
# Start from all-zero scores and set masked positions to -inf. After softmax the
# -inf entries get weight 0 and the remaining equal scores share the weight uniformly.
wei = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)
# wei is (T, T) and is broadcast over the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow3 = torch.matmul(wei, x)
# Last expression: check the result against the loop-based xbow.
torch.allclose(xbow, xbow3)

True

In [9]:
# version 4: self attention!
# Re-seed and rebuild the toy input with a wider channel dimension.
torch.manual_seed(1337)
B, T, C = 4,8,32
x = torch.randn(B,T,C)

# single head attention
# Every position is projected to a key, a query and a value by bias-free linear layers.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, C) -> (B, T, head_size)
q = query(x) # (B, T, C) -> (B, T, head_size)
v = value(x) # (B, T, C) -> (B, T, head_size)

# scaled dot product attention
# Scores are multiplied by 1/sqrt(head_size) so their magnitude does not grow with
# head_size; unscaled scores push the softmax below towards a one-hot distribution.
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # (B, T, head_size) @ (B, head_size, T) = (B, T, T)

# Masking: entries where tril == 0 (column index > row index) are set to -inf,
# so after softmax each position puts zero weight on later positions.
tril = torch.tril(torch.ones((T,T)))
wei = wei.masked_fill(tril == 0, float("-inf")) # (T, T) mask broadcast over the batch
wei = F.softmax(wei, dim=-1) # each row sums to 1
out = wei @ v # (B, T, T) @ (B, T, head_size) = (B, T, head_size)
print(out.shape)



torch.Size([4, 8, 16])
