In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

B: batch dimension  
T: sequence length  
D: feature dimension (hidden_size)  

In [3]:
torch.manual_seed(1337)
# B: batch dimension, T: sequence length, D: feature dimension (hidden size)
B,T,D = 4,8,2
x = torch.randn(B,T,D)
xbow = torch.zeros((B,T,D))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1,:] #(T,D)
        xbow[b,t,:] = torch.mean(xprev,dim=0) # reduce the time dimension

x.shape == xbow.shape

True

We average the previous and current tokens.  
![image](images/bow.png)

In [5]:
print(f'The first batch of x is \n{x[0]}')
print(f'The first batch of xbow is \n{xbow[0]}')

The first batch of x is 
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
The first batch of xbow is 
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


Use matrix multiplication to calculate the mean of the previous tokens.

In [6]:
a = torch.tril(torch.ones(3,3)) # lower traingular matrix
a = a / torch.sum(a,dim=1,keepdim=True) # sum up the rows
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'a=\n{a}')
print(f'b=\n{b}')
print(f'c=\n{c}')


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])
c=
tensor([[8.0000, 6.0000],
        [6.5000, 4.0000],
        [5.6667, 4.0000]])


In [8]:
# a stands for attention matrix
# b stands for input matrix
# c stands for output matrix
attn = torch.tril(torch.ones(T,T))
attn = attn / torch.sum(attn,dim=1,keepdim=True) # sum up the rows
print(f'attn matrix: \n{attn}')
xbow2 = attn @ x #(T,T) @ (B,T,D) -> (B,T,D)
print(f'The first batch of xbow2 is \n{xbow2[0]}')

attn matrix: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
The first batch of x is 
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
The first batch of xbow2 is 
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
  

We can also use softmax to produce the attention weights.

In [9]:
tril = torch.tril(torch.ones(T,T))
attn = torch.zeros(T,T)
attn = attn.masked_fill(tril == 0, float('-inf')) # mask out the upper triangle (masked-attention)
attn = F.softmax(attn,dim=-1) # reduce the time dimension
xbow3 = attn @ x
print(f'attn matrix: \n{attn}')
print(f'The first batch of xbow3 is \n{xbow3[0]}')

attn matrix: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
The first batch of xbow3 is 
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


### Self attention

Previous attention matrix simply averages the previous tokens and current tokens.  
Can we modify the weight matrix to pay more attention to specific token?  
Yes! We can use self-attention machanism to do so.  

In [11]:
B,T,D = 1,8,32
x = torch.randn(B,T,D)

head_size = 16
key = nn.Linear(D, head_size)
query = nn.Linear(D, head_size)

k = key(x) # (B,T,D) -> (B,T,head_size)
q = query(x) # (B,T,D) -> (B,T,head_size)

# previous, we simply set weight matrix to be all zeros
attn = k @ q.transpose(2,1) # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)
tril = torch.tril(torch.ones(T,T))
attn = attn.masked_fill(tril == 0, float('-inf')) # mask out the upper triangle (masked-attention)
attn = F.softmax(attn,dim=-1)
print(f'attention matrix: \n{attn}')

attention matrix: 
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8203, 0.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3286, 0.3266, 0.3448, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3526, 0.0940, 0.5211, 0.0323, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8307, 0.0526, 0.0739, 0.0089, 0.0339, 0.0000, 0.0000, 0.0000],
         [0.1564, 0.0692, 0.2802, 0.2136, 0.1546, 0.1260, 0.0000, 0.0000],
         [0.2718, 0.0714, 0.4514, 0.1176, 0.0164, 0.0538, 0.0176, 0.0000],
         [0.0171, 0.0574, 0.1648, 0.3710, 0.0882, 0.0989, 0.1188, 0.0838]]],
       grad_fn=<SoftmaxBackward0>)


Previously, we do dot product between attention weights and the embedding of the tokens. 
Here, we perform dot product between the attention weights and "values" of the tokens.  The value is produced by passing the embedding through a linear layer.

In [12]:
value = nn.Linear(D,head_size)
v = value(x) # (B,T,D) -> (B,T,head_size)
o = attn @ v # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
o.shape # (B,T,head_size)

torch.Size([1, 8, 16])

In order to recover the dimension of the output, we pass the dot product through a linear layer called output layer. This layer is optional since nowadays the dimension of dot product between weight matrix and value matrix is designed to be the same as the input dimension.

In [68]:
out = nn.Linear(head_size,D)
o = out(o) # (B,T,head_size) -> (B,T,D)
o.shape # (B,T,head_size)
print(f'the input dimension is {x.shape} and the output dimension is {o.shape}')

the input dimension is torch.Size([1, 8, 32]) and the output dimension is torch.Size([1, 8, 32])


### Scaled product attention
We expect the variance to be close to 1, but after dot product, the variance will be close to head_size. We therefore divide the dot product by sqrt(head_size) to scale the variance back to 1.

In [19]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
print('The variance of k is', torch.var(k).item())
print('The variance of q is', torch.var(q).item())
attn = k @ q.transpose(2,1)
print('The variance of attn is', torch.var(attn).item())


The variance of k is 1.0172395706176758
The variance of q is 0.9533872604370117
The variance of attn is 18.86295509338379


In [17]:
attn = k @ q.transpose(2,1) / (head_size**0.5)
print('The variance of attn is', torch.var(attn).item())

The variance of attn is 1.0233604907989502
