In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

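Notation:
- `B`: batch size (number of sequences processed in parallel)
- `T`: sequence length (number of time steps / tokens)
- `D`: feature dimension of each token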

In [10]:
torch.manual_seed(1337)  # fixed seed so x is reproducible
B, T, D = 4, 8, 2  # batch size, sequence length, feature dimension
x = torch.randn(size=(B, T, D))  # toy token features: B sequences of T tokens, D features each
x.shape

torch.Size([4, 8, 2])

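Each token is replaced by the average of itself and all previous tokens (a "bag of words" over the prefix):

$$\text{xbow}[b,t] = \frac{1}{t+1}\sum_{i=0}^{t} x[b,i], \qquad t = 0, \dots, T-1$$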

In [11]:
# Version 1: explicit loops. xbow[b, t] is the mean of x[b, 0..t] (current token plus all earlier ones).
xbow = torch.zeros(B, T, D)
for b in range(B):
    for t in range(T):
        prefix = x[b, :t + 1]  # (t+1, D): tokens 0..t of sequence b
        xbow[b, t] = prefix.mean(dim=0)  # average over the time axis -> (D,)

In [12]:
x[0, :, :]  # batch element 0: its (T, D) raw token features

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [13]:
xbow[0, :, :]  # batch element 0 after averaging: row t is the mean of x[0, :t+1]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

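The same running mean can be computed with a single matrix multiplication. Multiply by a lower-triangular matrix whose rows are normalised to sum to 1, so that row $t$ of the product is the average of input rows $0, \dots, t$.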

In [14]:
torch.manual_seed(1337)
# Lower-triangular ones, with each row divided by its sum: row i of `a` averages rows 0..i.
a = torch.tril(torch.ones(3, 3))
row_sums = a.sum(dim=1, keepdim=True)  # (3, 1): number of ones in each row
a = a / row_sums
b = torch.randint(0, 10, (3, 2)).float()  # small integer-valued matrix so the averages are easy to read
c = a @ b  # row i of c is the mean of rows 0..i of b
for label, tensor in (('a=', a), ('b=', b), ('c=', c)):
    print(label)
    print(tensor)


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])
c=
tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])


In [22]:
# Version 2: the same running mean as one (batched) matrix multiply.
wei = torch.tril(torch.ones(T, T))  # lower-triangular ones: row t selects tokens 0..t
wei = wei / wei.sum(dim=1, keepdim=True)  # normalise each row to sum to 1, so row t averages t+1 tokens
xbow2 = wei @ x  # (T,T) @ (B,T,D) -> (B,T,D); wei broadcasts over the batch dimension
torch.allclose(xbow, xbow2)  # expected True: matches the explicit-loop result

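## Version 3: softmax

Start from all-equal scores, and set the future (upper-triangular) positions to $-\infty$. A row-wise softmax turns each row into uniform weights over positions $0, \dots, t$ and zero elsewhere. These are exactly the averaging weights used above.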

In [26]:
# Version 3: build the averaging weights with a masked softmax.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)  # all-equal scores -> uniform weights after softmax
wei = wei.masked_fill(tril == 0, float('-inf'))  # mask out the upper triangle (future positions) -> weight 0
# Normalise each row over the last axis. wei is 2-D, so the previous dim=2 raised IndexError.
wei = F.softmax(wei, dim=-1)
print(wei)
xbow3 = wei @ x  # (T,T) @ (B,T,D) -> (B,T,D)
torch.allclose(xbow, xbow3)  # expected True: same result as the loop version

True

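# Self-attention

Instead of fixed uniform weights, each token produces a query and a key, and the attention weights come from query–key dot products. The weights therefore depend on the data.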

In [37]:
torch.manual_seed(1337)
B, T, D = 4, 8, 32
x = torch.randn(B, T, D)

head_size = 16
key = nn.Linear(D, head_size)
query = nn.Linear(D, head_size)

k = key(x)    # (B,T,D) -> (B,T,head_size)
q = query(x)  # (B,T,D) -> (B,T,head_size)

# Attention scores Q K^T: row i holds query i's affinity to every key j.
# The order matters: with the causal mask and a row-wise softmax, row i must belong to query i.
# (Not yet scaled by 1/sqrt(head_size).)
wei = q @ k.transpose(-2, -1)  # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # causal mask: query i may only see keys j <= i
wei = F.softmax(wei, dim=-1)  # normalise over keys; each row sums to 1
wei.shape  # (B,T,T)

torch.Size([4, 8, 8])

In [39]:
value = nn.Linear(D, head_size)
v = value(x)  # per-token values: (B,T,D) -> (B,T,head_size)
out = wei.matmul(v)  # attention-weighted sum of values: (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
out.shape  # (B,T,head_size)

torch.Size([4, 8, 16])

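# Scaled dot-product attention

Softmax is sensitive to the magnitude of its inputs. If the entries of $q$ and $k$ have unit variance, their dot product over `head_size` dimensions has variance of roughly `head_size`. Scores that large push the softmax toward a one-hot distribution, so each token ends up attending to essentially a single other token. Dividing the scores by $\sqrt{\text{head\_size}}$ brings their variance back to about 1, as the cells below show.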

In [40]:
# Unit-variance random keys and queries, to see how the raw dot product grows with head_size.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
print('The variance of k is', torch.var(k).item())
print('The variance of q is', torch.var(q).item())
wei = q @ k.transpose(-2, -1)  # (B,T,T) scores, Q K^T as in standard attention; each entry sums head_size products
print('The variance of wei is', torch.var(wei).item())


The variance of k is 1.046860933303833
The variance of q is 0.8975691199302673
The variance of wei is 17.88959312438965


In [41]:
# Scale the Q K^T scores by 1/sqrt(head_size) so their variance stays near 1 before the softmax.
wei = q @ k.transpose(-2, -1) / (head_size**0.5)
print('The variance of wei is', torch.var(wei).item())

The variance of wei is 1.118099570274353
