In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F

B: batch dimension  
T: sequence length  
D: feature dimension (hidden_size)  

In [59]:
torch.manual_seed(1337)
B,T,D = 4,8,2
x = torch.randn(B,T,D)


We average the previous and current tokens.  
![image](images/bow.png)

In [60]:
xbow = torch.zeros((B,T,D))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1,:] #(T,D)
        xbow[b,t,:] = torch.mean(xprev,dim=0) # reduce the time dimension

In [61]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [62]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

Use matrix multiplication to calculate the mean of the previous tokens.

In [63]:
a = torch.tril(torch.ones(3,3)) # lower traingular matrix
a = a / torch.sum(a,dim=1,keepdim=True) # sum up the rows
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('c=')
print(c)


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])
c=
tensor([[8.0000, 6.0000],
        [6.5000, 4.0000],
        [5.6667, 4.0000]])


In [64]:
# a stands for weight matrix
# b stands for input matrix
# c stands for output matrix
wei = torch.tril(torch.ones(T,T))
wei  = wei / torch.sum(wei,dim=1,keepdim=True) # sum up the rows
print(f'weight matrix: \n{wei}')
xbow2 = wei @ x #(T,T) @ (B,T,D) -> (B,T,D)
xbow2[0]


weight matrix: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

We can also use softmax to produce the attention weights.

In [65]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros(T,T)
wei = wei.masked_fill(tril == 0, float('-inf')) # mask out the upper triangle (masked-attention)
wei = F.softmax(wei,dim=-1) # reduce the time dimension
print(f'weight matrix: \n{wei}')
xbow3 = wei @ x
xbow3[0]


weight matrix: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Self attention

Previous weight matrix simply averages the previous tokens and current tokens.  
Can we modify the weight matrix to pay more attention to specific token?  
Yes, we can use self-attention machanism to do so.  

In [66]:
B,T,D = 1,8,32
x = torch.randn(B,T,D)

head_size = 16
key = nn.Linear(D, head_size)
query = nn.Linear(D, head_size)

k = key(x) # (B,T,D) -> (B,T,head_size)
q = query(x) # (B,T,D) -> (B,T,head_size)

# previous, we simply set weight matrix to be all zeros
wei = k @ q.transpose(2,1) # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # mask out the upper triangle (masked-attention)
wei = F.softmax(wei,dim=-1)
print(f'weight matrix: \n{wei}')

weight matrix: 
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8568, 0.1432, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9035, 0.0319, 0.0646, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0794, 0.7826, 0.0262, 0.1117, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2599, 0.0619, 0.0666, 0.5537, 0.0579, 0.0000, 0.0000, 0.0000],
         [0.1648, 0.0910, 0.0811, 0.1171, 0.1276, 0.4185, 0.0000, 0.0000],
         [0.5038, 0.0824, 0.0079, 0.2029, 0.0508, 0.0840, 0.0683, 0.0000],
         [0.1233, 0.1467, 0.1079, 0.1274, 0.0938, 0.3208, 0.0458, 0.0343]]],
       grad_fn=<SoftmaxBackward0>)


Previously, We do dot product between attention weights and the embedding of the tokens. 
Here, we perform dot product between the attention weights and "values" of the tokens.  The value is produced by passing the embedding through a linear layer.

In [67]:
value = nn.Linear(D,head_size)
v = value(x) # (B,T,D) -> (B,T,head_size)
o = wei @ v # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
o.shape # (B,T,head_size)

torch.Size([1, 8, 16])

In order to recover the dimension of the output, we pass the dot product through a linear layer called output layer.

In [68]:
out = nn.Linear(head_size,D)
o = out(o) # (B,T,head_size) -> (B,T,D)
o.shape # (B,T,head_size)
print(f'the input dimension is {x.shape} and the output dimension is {o.shape}')

the input dimension is torch.Size([1, 8, 32]) and the output dimension is torch.Size([1, 8, 32])


# Scaled product attention
since softmax

In [69]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
print('The variance of k is', torch.var(k).item())
print('The variance of q is', torch.var(q).item())
wei = k @ q.transpose(2,1)
print('The variance of wei is', torch.var(wei).item())


The variance of k is 1.1766988039016724
The variance of q is 0.8625323176383972
The variance of wei is 16.994678497314453


In [70]:
wei = k @ q.transpose(2,1) / (head_size**0.5)
print('The variance of wei is', torch.var(wei).item())

The variance of wei is 1.0621674060821533
