<img src="images/00-image.png" alt="encoder" class="bg-primary" width="100%">


[Image Reference](https://www.planetware.com/tourist-attractions-/potsdam-d-br-pt.htm)

<h1><center> Attention Explained </center></h1>

Vision Transformer (ViT) paper: [Paper Reference](https://arxiv.org/abs/2010.11929)

In [1]:
import torch
import numpy
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

## Self Attention

In [2]:
# We need a way to parameterize each token so that we can rank the tokens by their importance

In [3]:
torch.manual_seed(42)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [5]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

## Version 1 : Using simple mathematical function

In [14]:
# Version 1: running average computed the slow, explicit way.
# x_bop[b, t] is the mean of tokens 0..t (inclusive) of sample b.
# bop: bag of pixel (average schemes)
x_bop = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        x_previous = x[b, : t + 1]            # tokens seen so far, shape (t+1, C)
        x_bop[b, t] = x_previous.mean(dim=0)  # average over the token axis

In [15]:
x[0, :1+1]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152]])

In [16]:
x_bop[0, 1]

tensor([-0.0894, -0.4926])

In [17]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [18]:
x_bop[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Replicating the above by combining Matmul and Tril

In [19]:
# Explaining concept of matmul - example
a = torch.ones(3,3)
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [20]:
b = torch.randint(0, 10, (3, 2)).float()
print(b)

tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])


In [22]:
c = a@b
print(c)

tensor([[17., 12.],
        [17., 12.],
        [17., 12.]])


In [24]:
# Explaining concept of Tril - example
torch.tril(torch.ones(8, 8))

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [25]:
# Combining tril and matmul:
# multiplying by a lower-triangular matrix of ones turns each output row
# into the running (cumulative) sum of the rows of `b` up to that position.
torch.manual_seed(42)
a = torch.ones(8, 8).tril()
b = torch.randint(low=0, high=10, size=(8, 2)).float()
c = torch.matmul(a, b)

for shown in (a, b, c):
    print(shown)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.],
        [14., 20.],
        [14., 23.],
        [22., 27.],
        [22., 31.],
        [23., 33.]])


In [26]:
# Averaging with tril: dividing every row of the lower-triangular matrix by
# its row sum makes each row sum to 1, so multiplying by it replaces each row
# of `b` with the average of all rows up to and including that position.
torch.manual_seed(42)
lower = torch.ones(8, 8).tril()
a = lower / lower.sum(dim=1, keepdim=True)
b = torch.randint(low=0, high=10, size=(8, 2)).float()
c = torch.matmul(a, b)

for shown in (a, b, c):
    print(shown)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333],
        [3.5000, 5.0000],
        [2.8000, 4.6000],
        [3.6667, 4.5000],
        [3.1429, 4.4286],
        [2.8750, 4.1250]])


### Introducing and Establishing the WEIGHT

In [27]:
# Now applying it
tril = torch.tril(torch.ones(T,T))

In [28]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [29]:
weight = tril / torch.sum(tril, 1, keepdim=True)
weight

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [33]:
x_bop2 = weight @ x

In [34]:
torch.allclose(x_bop2[0], x_bop[0])

True

In [37]:
x_bop[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [38]:
x_bop2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

- Now we have found a way to parameterize them such that we have our weight (average) separated

#

## Version 2 : Using softmax - another method but better

- The weight is built differently here, using softmax
- In this example softmax still does the same thing as above: it distributes the weight evenly over the visible tokens
- However, that is only the case because the mask is made of constant 0s and 1s; once the weight is data dependent (Q, K, V fed into softmax), softmax becomes the better choice

In [43]:
tril = torch.tril(torch.ones(T,T)) # because 8 tokens must result into 8X8
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [44]:
weight = torch.zeros(T,T)

In [45]:
weight = weight.masked_fill(tril == 0, float('-inf'))

In [46]:
weight

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [47]:
weight = torch.softmax(weight, dim=-1)
weight

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [48]:
x_bop3 = weight @ x

In [50]:
x_bop3[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [51]:
torch.allclose(x_bop2[0], x_bop3[0])

True

#

## Version 3: Final with self attention (Putting everything together)

- Here, the weight will be data dependent, which is what makes softmax very useful
- The idea is that, unlike the previous versions, we do not want the values to be a plain context average, because that would suggest that all tokens are equally important

Note below: for images, we are no longer working with a 2D image, so the arrangement of Channel, Width, Height no longer matters; everything is now 1D. Hence,
- B = batch, in image = channel
- T = time/sentence, in image = flattened R/G/B
- C = depth dimension, can be any value you wish, e.g. RGB

In [78]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

- Note: the QKV is used to build the weight, hence the weight must come out with block_size/context-length size (T x T)
- K and Q are both computed from the same input `x` (through separate linear layers); by transposing K, each token's query can be multiplied with the keys of all the other tokens
- From that product we can estimate the affinity between tokens

In [79]:
# Size of each key/query/value projection (16 was an alternative value tried here).
head_size = 3
# Three independent, bias-free linear maps from the C input features to head_size.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
q = query(x)  # shape = (B, T, 3)
k = key(x)    # shape = (B, T, 3)

In [80]:
k.shape

torch.Size([4, 8, 3])

In [81]:
k.transpose(-2, -1).shape

torch.Size([4, 3, 8])

In [82]:
weight = q@k.transpose(-2, -1)

In [83]:
weight.shape

torch.Size([4, 8, 8])

In [84]:
weight

tensor([[[ 1.8611e-01,  1.1800e-01,  1.5668e-01,  1.2040e-01, -4.4128e-02,
          -1.7572e-01,  6.6347e-02,  1.3252e-02],
         [ 2.1909e-01,  6.5659e-02,  3.5945e-01,  1.8460e-01, -1.3056e-01,
          -2.4246e-01,  6.7986e-02, -6.7991e-02],
         [-8.4862e-02,  1.2121e-01, -4.8961e-01, -1.5732e-01,  2.0796e-01,
           1.6520e-01, -6.0769e-03,  1.9369e-01],
         [ 6.1241e-02,  8.1695e-02, -5.0868e-02,  1.4532e-02,  3.1487e-02,
          -3.6983e-02,  2.7754e-02,  5.3281e-02],
         [ 6.4368e-02, -3.7803e-02,  2.4202e-01,  8.7650e-02, -9.9636e-02,
          -9.8991e-02,  1.2087e-02, -8.5132e-02],
         [-1.2658e-01, -1.1586e-01, -2.1478e-02, -6.1049e-02, -8.2046e-03,
           1.0220e-01, -5.0043e-02, -4.9649e-02],
         [ 8.0312e-02,  4.0801e-02,  9.1786e-02,  5.7878e-02, -2.9902e-02,
          -8.0747e-02,  2.7233e-02, -5.8286e-03],
         [ 1.2862e-01, -2.0452e-03,  3.0800e-01,  1.3213e-01, -1.2021e-01,
          -1.6207e-01,  3.4304e-02, -8.6237e-02]],

- The above has made our initial weight data driven
- Now we need to mask out the future positions using tril below

In [85]:
tril = torch.tril(torch.ones(T,T))
weight = weight.masked_fill(tril == 0, float('-inf')) # Comment out to obtain bi-directional effect
weight[0]

tensor([[ 0.1861,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2191,  0.0657,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0849,  0.1212, -0.4896,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0612,  0.0817, -0.0509,  0.0145,    -inf,    -inf,    -inf,    -inf],
        [ 0.0644, -0.0378,  0.2420,  0.0876, -0.0996,    -inf,    -inf,    -inf],
        [-0.1266, -0.1159, -0.0215, -0.0610, -0.0082,  0.1022,    -inf,    -inf],
        [ 0.0803,  0.0408,  0.0918,  0.0579, -0.0299, -0.0807,  0.0272,    -inf],
        [ 0.1286, -0.0020,  0.3080,  0.1321, -0.1202, -0.1621,  0.0343, -0.0862]],
       grad_fn=<SelectBackward0>)

In [86]:
weight = torch.softmax(weight, dim=-1)
weight[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5383, 0.4617, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3453, 0.4243, 0.2304, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2585, 0.2638, 0.2311, 0.2467, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2012, 0.1817, 0.2403, 0.2060, 0.1708, 0.0000, 0.0000, 0.0000],
        [0.1522, 0.1538, 0.1690, 0.1625, 0.1713, 0.1913, 0.0000, 0.0000],
        [0.1505, 0.1446, 0.1522, 0.1471, 0.1348, 0.1281, 0.1427, 0.0000],
        [0.1366, 0.1199, 0.1634, 0.1371, 0.1065, 0.1021, 0.1243, 0.1102]],
       grad_fn=<SelectBackward0>)

In [65]:
weight.shape

torch.Size([4, 8, 8])

In [66]:
v = value(x) # we aggregate the values not the exact token, it is also learnable
v.shape

torch.Size([4, 8, 3])

In [67]:
x_bop3 = weight @ v

In [68]:
x_bop3.shape

torch.Size([4, 8, 3])

## Putting it all together

In [87]:
class Head(nn.Module):
    """One head of scaled dot-product attention with a causal (lower-triangular) mask.

    Args:
        embed_dim: size of the last dimension of the tensors passed to ``forward``.
        head_size: output size of the key, query and value projections.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, dropout):
        super().__init__()
        # Bias-free linear projections of the inputs into the head's subspace.
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Compute the attention output.

        Args:
            query: tensor of size (batch, T_q, embed_dim).
            key: tensor of size (batch, T_k, embed_dim).
            value: tensor of size (batch, T_k, embed_dim).

        Returns:
            Tensor of size (batch, T_q, head_size).
        """
        key = self.key(key)        # (B, T_k, hs)
        query = self.query(query)  # (B, T_q, hs)

        # compute attention scores ("affinities")
        weight = query @ key.transpose(-2, -1)  # (B, T_q, hs) @ (B, hs, T_k) -> (B, T_q, T_k)

        weight = weight * key.shape[-1] ** -0.5  # scale factor: 1 / sqrt(head_size)

        # Causal mask: position t may only attend to positions <= t.
        # The mask is sized from the score matrix itself and created on the same
        # device, so it works on GPU and when query/key lengths differ.
        # Comment out the two lines below to obtain a bi-directional effect.
        q_len, k_len = weight.shape[-2:]
        causal_mask = torch.tril(torch.ones(q_len, k_len, device=weight.device))
        weight = weight.masked_fill(causal_mask == 0, float('-inf'))  # (B, T_q, T_k)

        weight = F.softmax(weight, dim=-1)  # (B, T_q, T_k), each row sums to 1

        weight = self.dropout(weight)

        # perform the weighted aggregation of the values
        value = self.value(value)  # (B, T_k, hs)
        out = weight @ value       # (B, T_q, T_k) @ (B, T_k, hs) -> (B, T_q, hs)

        return out